adjust.py 52 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from math import pi
  19. from typing import ClassVar, Optional, Union
  20. import torch
  21. from kornia.color import hsv_to_rgb, rgb_to_grayscale, rgb_to_hsv
  22. from kornia.core import ImageModule as Module
  23. from kornia.core import Parameter, Tensor, tensor
  24. from kornia.core.check import (
  25. KORNIA_CHECK,
  26. KORNIA_CHECK_IS_COLOR_OR_GRAY,
  27. KORNIA_CHECK_IS_TENSOR,
  28. )
  29. from kornia.utils.helpers import _torch_histc_cast
  30. from kornia.utils.image import perform_keep_shape_image, perform_keep_shape_video
  31. def adjust_saturation_raw(image: Tensor, factor: Union[float, Tensor]) -> Tensor:
  32. r"""Adjust color saturation of an image.
  33. Expecting image to be in hsv format already.
  34. """
  35. KORNIA_CHECK_IS_TENSOR(image, "Expected shape (*, H, W)")
  36. KORNIA_CHECK(isinstance(factor, (float, Tensor)), "Factor should be float or Tensor.")
  37. if isinstance(factor, float):
  38. # TODO: figure out how to create later a tensor without importing torch
  39. factor = torch.as_tensor(factor, device=image.device, dtype=image.dtype)
  40. elif isinstance(factor, Tensor):
  41. factor = factor.to(image.device, image.dtype)
  42. # make factor broadcastable
  43. while len(factor.shape) != len(image.shape):
  44. factor = factor[..., None]
  45. # unpack the hsv values
  46. h, s, v = torch.chunk(image, chunks=3, dim=-3)
  47. # transform the hue value and appl module
  48. s_out: Tensor = torch.clamp(s * factor, min=0, max=1)
  49. # pack back back the corrected hue
  50. out: Tensor = torch.cat([h, s_out, v], dim=-3)
  51. return out
  52. def adjust_saturation_with_gray_subtraction(image: Tensor, factor: Union[float, Tensor]) -> Tensor:
  53. r"""Adjust color saturation of an image by blending the image with its grayscaled version.
  54. The image is expected to be an RGB image or a gray image in the range of [0, 1].
  55. If it is an RGB image, returns blending of the image with its grayscaled version.
  56. If it is a gray image, returns the image.
  57. .. note::
  58. this is just a convenience function to have compatibility with Pil
  59. Args:
  60. image: Image/Tensor to be adjusted in the shape of :math:`(*, 3, H, W)`.
  61. factor: How much to adjust the saturation. 0 will give a black
  62. and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2.
  63. Return:
  64. Adjusted image in the shape of :math:`(*, 3, H, W)`.
  65. Example:
  66. >>> x = torch.ones(1, 3, 3, 3)
  67. >>> adjust_saturation_with_gray_subtraction(x, 2.).shape
  68. torch.Size([1, 3, 3, 3])
  69. >>> x = torch.ones(2, 3, 3, 3)
  70. >>> y = torch.tensor([1., 2.])
  71. >>> adjust_saturation_with_gray_subtraction(x, y).shape
  72. torch.Size([2, 3, 3, 3])
  73. """
  74. KORNIA_CHECK_IS_TENSOR(image, "Expected shape (*, H, W)")
  75. KORNIA_CHECK(isinstance(factor, (float, Tensor)), "Factor should be float or Tensor.")
  76. KORNIA_CHECK_IS_COLOR_OR_GRAY(image, "Image should be an RGB or gray image")
  77. if image.shape[-3] == 1:
  78. return image
  79. if isinstance(factor, float):
  80. # TODO: figure out how to create later a tensor without importing torch
  81. factor = torch.as_tensor(factor, device=image.device, dtype=image.dtype)
  82. elif isinstance(factor, Tensor):
  83. factor = factor.to(image.device, image.dtype)
  84. # make factor broadcastable
  85. while len(factor.shape) != len(image.shape):
  86. factor = factor[..., None]
  87. x_other: Tensor = rgb_to_grayscale(image)
  88. # blend the image with the grayscaled image
  89. x_adjusted: Tensor = (1 - factor) * x_other + factor * image
  90. # clamp the output
  91. out: Tensor = torch.clamp(x_adjusted, 0.0, 1.0)
  92. return out
  93. def adjust_saturation(image: Tensor, factor: Union[float, Tensor]) -> Tensor:
  94. r"""Adjust color saturation of an image.
  95. .. image:: _static/img/adjust_saturation.png
  96. The image is expected to be an RGB image in the range of [0, 1].
  97. Args:
  98. image: Image/Tensor to be adjusted in the shape of :math:`(*, 3, H, W)`.
  99. factor: How much to adjust the saturation. 0 will give a black
  100. and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2.
  101. saturation_mode: The mode to adjust saturation.
  102. Return:
  103. Adjusted image in the shape of :math:`(*, 3, H, W)`.
  104. .. note::
  105. See a working example `here <https://kornia.github.io/tutorials/nbs/image_enhancement.html>`__.
  106. Example:
  107. >>> x = torch.ones(1, 3, 3, 3)
  108. >>> adjust_saturation(x, 2.).shape
  109. torch.Size([1, 3, 3, 3])
  110. >>> x = torch.ones(2, 3, 3, 3)
  111. >>> y = torch.tensor([1., 2.])
  112. >>> adjust_saturation(x, y).shape
  113. torch.Size([2, 3, 3, 3])
  114. """
  115. # convert the rgb image to hsv
  116. x_hsv: Tensor = rgb_to_hsv(image)
  117. # perform the conversion
  118. x_adjusted: Tensor = adjust_saturation_raw(x_hsv, factor)
  119. # convert back to rgb
  120. out: Tensor = hsv_to_rgb(x_adjusted)
  121. return out
  122. def adjust_hue_raw(image: Tensor, factor: Union[float, Tensor]) -> Tensor:
  123. r"""Adjust hue of an image.
  124. Expecting image to be in hsv format already.
  125. """
  126. KORNIA_CHECK_IS_TENSOR(image, "Expected shape (*, H, W)")
  127. KORNIA_CHECK(
  128. isinstance(factor, (float, Tensor)),
  129. f"The factor should be a float number or Tensor in the range between [-PI, PI]. Got {type(factor)}",
  130. )
  131. if isinstance(factor, float):
  132. factor = torch.as_tensor(factor)
  133. factor = factor.to(image.device, image.dtype)
  134. # make factor broadcastable
  135. while len(factor.shape) != len(image.shape):
  136. factor = factor[..., None]
  137. # unpack the hsv values
  138. h, s, v = torch.chunk(image, chunks=3, dim=-3)
  139. # transform the hue value and appl module
  140. divisor: float = 2 * pi
  141. h_out: Tensor = torch.fmod(h + factor, divisor)
  142. # pack back back the corrected hue
  143. out: Tensor = torch.cat([h_out, s, v], dim=-3)
  144. return out
  145. def adjust_hue(image: Tensor, factor: Union[float, Tensor]) -> Tensor:
  146. r"""Adjust hue of an image.
  147. .. image:: _static/img/adjust_hue.png
  148. The image is expected to be an RGB image in the range of [0, 1].
  149. Args:
  150. image: Image to be adjusted in the shape of :math:`(*, 3, H, W)`.
  151. factor: How much to shift the hue channel. Should be in [-PI, PI]. PI
  152. and -PI give complete reversal of hue channel in HSV space in positive and negative
  153. direction respectively. 0 means no shift. Therefore, both -PI and PI will give an
  154. image with complementary colors while 0 gives the original image.
  155. Return:
  156. Adjusted image in the shape of :math:`(*, 3, H, W)`.
  157. .. note::
  158. See a working example `here <https://kornia.github.io/tutorials/nbs/image_enhancement.html>`__.
  159. Example:
  160. >>> x = torch.ones(1, 3, 2, 2)
  161. >>> adjust_hue(x, 3.141516).shape
  162. torch.Size([1, 3, 2, 2])
  163. >>> x = torch.ones(2, 3, 3, 3)
  164. >>> y = torch.ones(2) * 3.141516
  165. >>> adjust_hue(x, y).shape
  166. torch.Size([2, 3, 3, 3])
  167. """
  168. # convert the rgb image to hsv
  169. x_hsv: Tensor = rgb_to_hsv(image)
  170. # perform the conversion
  171. x_adjusted: Tensor = adjust_hue_raw(x_hsv, factor)
  172. # convert back to rgb
  173. out: Tensor = hsv_to_rgb(x_adjusted)
  174. return out
  175. def adjust_gamma(input: Tensor, gamma: Union[float, Tensor], gain: Union[float, Tensor] = 1.0) -> Tensor:
  176. r"""Perform gamma correction on an image.
  177. .. image:: _static/img/adjust_contrast.png
  178. The input image is expected to be in the range of [0, 1].
  179. Args:
  180. input: Image to be adjusted in the shape of :math:`(*, H, W)`.
  181. gamma: Non negative real number, same as y\gammay in the equation.
  182. gamma larger than 1 make the shadows darker, while gamma smaller than 1 make
  183. dark regions lighter.
  184. gain: The constant multiplier.
  185. Return:
  186. Adjusted image in the shape of :math:`(*, H, W)`.
  187. .. note::
  188. See a working example `here <https://kornia.github.io/tutorials/nbs/image_enhancement.html>`__.
  189. Example:
  190. >>> x = torch.ones(1, 1, 2, 2)
  191. >>> adjust_gamma(x, 1.0, 2.0)
  192. tensor([[[[1., 1.],
  193. [1., 1.]]]])
  194. >>> x = torch.ones(2, 5, 3, 3)
  195. >>> y1 = torch.ones(2) * 1.0
  196. >>> y2 = torch.ones(2) * 2.0
  197. >>> adjust_gamma(x, y1, y2).shape
  198. torch.Size([2, 5, 3, 3])
  199. """
  200. if not isinstance(input, Tensor):
  201. raise TypeError(f"Input type is not a Tensor. Got {type(input)}")
  202. if not isinstance(gamma, (float, Tensor)):
  203. raise TypeError(f"The gamma should be a positive float or Tensor. Got {type(gamma)}")
  204. if not isinstance(gain, (float, Tensor)):
  205. raise TypeError(f"The gain should be a positive float or Tensor. Got {type(gain)}")
  206. if isinstance(gamma, float):
  207. gamma = Tensor([gamma])
  208. if isinstance(gain, float):
  209. gain = Tensor([gain])
  210. gamma = gamma.to(input.device).to(input.dtype)
  211. gain = gain.to(input.device).to(input.dtype)
  212. if (gamma < 0.0).any():
  213. raise ValueError(f"Gamma must be non-negative. Got {gamma}")
  214. if (gain < 0.0).any():
  215. raise ValueError(f"Gain must be non-negative. Got {gain}")
  216. for _ in range(len(input.shape) - len(gamma.shape)):
  217. gamma = torch.unsqueeze(gamma, dim=-1)
  218. for _ in range(len(input.shape) - len(gain.shape)):
  219. gain = torch.unsqueeze(gain, dim=-1)
  220. # Apply the gamma correction
  221. x_adjust: Tensor = gain * torch.pow(input, gamma)
  222. # Truncate between pixel values
  223. out: Tensor = torch.clamp(x_adjust, 0.0, 1.0)
  224. return out
  225. def adjust_contrast(image: Tensor, factor: Union[float, Tensor], clip_output: bool = True) -> Tensor:
  226. r"""Adjust the contrast of an image tensor.
  227. .. image:: _static/img/adjust_contrast.png
  228. This implementation follows Szeliski's book convention, where contrast is defined as
  229. a `multiplicative` operation directly to raw pixel values. Beware that other frameworks
  230. might use different conventions which can be difficult to reproduce exact results.
  231. The input image and factor is expected to be in the range of [0, 1].
  232. .. tip::
  233. This is not the preferred way to adjust the contrast of an image. Ideally one must
  234. implement :func:`kornia.enhance.adjust_gamma`. More details in the following link:
  235. https://scikit-image.org/docs/dev/auto_examples/color_exposure/plot_log_gamma.html#sphx-glr-auto-examples-color-exposure-plot-log-gamma-py
  236. Args:
  237. image: Image to be adjusted in the shape of :math:`(*, H, W)`.
  238. factor: Contrast adjust factor per element
  239. in the batch. 0 generates a completely black image, 1 does not modify
  240. the input image while any other non-negative number modify the
  241. brightness by this factor.
  242. clip_output: whether to clip the output image with range of [0, 1].
  243. Return:
  244. Adjusted image in the shape of :math:`(*, H, W)`.
  245. .. note::
  246. See a working example `here <https://kornia.github.io/tutorials/nbs/image_enhancement.html>`__.
  247. Example:
  248. >>> import torch
  249. >>> x = torch.ones(1, 1, 2, 2)
  250. >>> adjust_contrast(x, 0.5)
  251. tensor([[[[0.5000, 0.5000],
  252. [0.5000, 0.5000]]]])
  253. >>> x = torch.ones(2, 5, 3, 3)
  254. >>> y = torch.tensor([0.65, 0.50])
  255. >>> adjust_contrast(x, y).shape
  256. torch.Size([2, 5, 3, 3])
  257. """
  258. KORNIA_CHECK_IS_TENSOR(image, "Expected shape (*, H, W)")
  259. KORNIA_CHECK(isinstance(factor, (float, Tensor)), "Factor should be float or Tensor.")
  260. if isinstance(factor, float):
  261. # TODO: figure out how to create later a tensor without importing torch
  262. factor = torch.as_tensor(factor, device=image.device, dtype=image.dtype)
  263. elif isinstance(factor, Tensor):
  264. factor = factor.to(image.device, image.dtype)
  265. # make factor broadcastable
  266. while len(factor.shape) != len(image.shape):
  267. factor = factor[..., None]
  268. KORNIA_CHECK(any(factor >= 0), "Contrast factor must be positive.")
  269. # Apply contrast factor to each channel
  270. img_adjust: Tensor = image * factor
  271. # Truncate between pixel values
  272. if clip_output:
  273. img_adjust = img_adjust.clamp(min=0.0, max=1.0)
  274. return img_adjust
  275. def adjust_contrast_with_mean_subtraction(image: Tensor, factor: Union[float, Tensor]) -> Tensor:
  276. r"""Adjust the contrast of an image tensor by subtracting the mean over channels.
  277. .. note::
  278. this is just a convenience function to have compatibility with Pil. For exact
  279. definition of image contrast adjustment consider using :func:`kornia.enhance.adjust_gamma`.
  280. Args:
  281. image: Image to be adjusted in the shape of :math:`(*, H, W)`.
  282. factor: Contrast adjust factor per element
  283. in the batch. 0 generates a completely black image, 1 does not modify
  284. the input image while any other non-negative number modify the
  285. brightness by this factor.
  286. Return:
  287. Adjusted image in the shape of :math:`(*, H, W)`.
  288. Example:
  289. >>> import torch
  290. >>> x = torch.ones(1, 1, 2, 2)
  291. >>> adjust_contrast_with_mean_subtraction(x, 0.5)
  292. tensor([[[[1., 1.],
  293. [1., 1.]]]])
  294. >>> x = torch.ones(2, 5, 3, 3)
  295. >>> y = torch.tensor([0.65, 0.50])
  296. >>> adjust_contrast_with_mean_subtraction(x, y).shape
  297. torch.Size([2, 5, 3, 3])
  298. """
  299. KORNIA_CHECK_IS_TENSOR(image, "Expected shape (*, H, W)")
  300. KORNIA_CHECK(isinstance(factor, (float, Tensor)), "Factor should be float or Tensor.")
  301. if isinstance(factor, float):
  302. # TODO: figure out how to create later a tensor without importing torch
  303. factor = torch.as_tensor(factor, device=image.device, dtype=image.dtype)
  304. elif isinstance(factor, Tensor):
  305. factor = factor.to(image.device, image.dtype)
  306. # make factor broadcastable
  307. while len(factor.shape) != len(image.shape):
  308. factor = factor[..., None]
  309. # KORNIA_CHECK(any(factor >= 0), "Contrast factor must be positive.")
  310. if image.shape[-3] == 3:
  311. img_mean = rgb_to_grayscale(image).mean((-2, -1), True)
  312. else:
  313. img_mean = image.mean()
  314. # Apply contrast factor subtracting the mean
  315. img_adjust: Tensor = image * factor + img_mean * (1 - factor)
  316. img_adjust = img_adjust.clamp(min=0.0, max=1.0)
  317. return img_adjust
  318. def adjust_brightness(image: Tensor, factor: Union[float, Tensor], clip_output: bool = True) -> Tensor:
  319. r"""Adjust the brightness of an image tensor.
  320. .. image:: _static/img/adjust_brightness.png
  321. This implementation follows Szeliski's book convention, where brightness is defined as
  322. an `additive` operation directly to raw pixel and shift its values according the applied
  323. factor and range of the image values. Beware that other framework might use different
  324. conventions which can be difficult to reproduce exact results.
  325. The input image and factor is expected to be in the range of [0, 1].
  326. .. tip::
  327. By applying a large factor might prouce clipping or loss of image detail. We recommenda to
  328. apply small factors to avoid the mentioned issues. Ideally one must implement the adjustment
  329. of image intensity with other techniques suchs as :func:`kornia.enhance.adjust_gamma`. More
  330. details in the following link:
  331. https://scikit-image.org/docs/dev/auto_examples/color_exposure/plot_log_gamma.html#sphx-glr-auto-examples-color-exposure-plot-log-gamma-py
  332. Args:
  333. image: Image to be adjusted in the shape of :math:`(*, H, W)`.
  334. factor: Brightness adjust factor per element in the batch. It's recommended to
  335. bound the factor by [0, 1]. 0 does not modify the input image while any other
  336. number modify the brightness.
  337. clip_output: Whether to clip output to be in [0,1].
  338. Return:
  339. Adjusted tensor in the shape of :math:`(*, H, W)`.
  340. .. note::
  341. See a working example `here <https://kornia.github.io/tutorials/nbs/image_enhancement.html>`__.
  342. Example:
  343. >>> x = torch.ones(1, 1, 2, 2)
  344. >>> adjust_brightness(x, 1.)
  345. tensor([[[[1., 1.],
  346. [1., 1.]]]])
  347. >>> x = torch.ones(2, 5, 3, 3)
  348. >>> y = torch.tensor([0.25, 0.50])
  349. >>> adjust_brightness(x, y).shape
  350. torch.Size([2, 5, 3, 3])
  351. """
  352. KORNIA_CHECK_IS_TENSOR(image, "Expected shape (*, H, W)")
  353. KORNIA_CHECK(isinstance(factor, (float, Tensor)), "Factor should be float or Tensor.")
  354. # convert factor to a tensor
  355. if isinstance(factor, float):
  356. # TODO: figure out how to create later a tensor without importing torch
  357. factor = torch.as_tensor(factor, device=image.device, dtype=image.dtype)
  358. elif isinstance(factor, Tensor):
  359. factor = factor.to(image.device, image.dtype)
  360. # make factor broadcastable
  361. while len(factor.shape) != len(image.shape):
  362. factor = factor[..., None]
  363. # shift pixel values
  364. img_adjust: Tensor = image + factor
  365. # truncate between pixel values
  366. if clip_output:
  367. img_adjust = img_adjust.clamp(min=0.0, max=1.0)
  368. return img_adjust
  369. def adjust_brightness_accumulative(image: Tensor, factor: Union[float, Tensor], clip_output: bool = True) -> Tensor:
  370. r"""Adjust the brightness accumulatively of an image tensor.
  371. This implementation follows PIL convention.
  372. The input image and factor is expected to be in the range of [0, 1].
  373. Args:
  374. image: Image to be adjusted in the shape of :math:`(*, H, W)`.
  375. factor: Brightness adjust factor per element in the batch. It's recommended to
  376. bound the factor by [0, 1]. 0 does not modify the input image while any other
  377. number modify the brightness.
  378. clip_output: Whether to clip output to be in [0,1].
  379. Return:
  380. Adjusted tensor in the shape of :math:`(*, H, W)`.
  381. Example:
  382. >>> x = torch.ones(1, 1, 2, 2)
  383. >>> adjust_brightness_accumulative(x, 1.)
  384. tensor([[[[1., 1.],
  385. [1., 1.]]]])
  386. >>> x = torch.ones(2, 5, 3, 3)
  387. >>> y = torch.tensor([0.25, 0.50])
  388. >>> adjust_brightness_accumulative(x, y).shape
  389. torch.Size([2, 5, 3, 3])
  390. """
  391. KORNIA_CHECK_IS_TENSOR(image, "Expected shape (*, H, W)")
  392. KORNIA_CHECK(isinstance(factor, (float, Tensor)), "Factor should be float or Tensor.")
  393. # convert factor to a tensor
  394. if isinstance(factor, float):
  395. # TODO: figure out how to create later a tensor without importing torch
  396. factor = torch.as_tensor(factor, device=image.device, dtype=image.dtype)
  397. elif isinstance(factor, Tensor):
  398. factor = factor.to(image.device, image.dtype)
  399. # make factor broadcastable
  400. while len(factor.shape) != len(image.shape):
  401. factor = factor[..., None]
  402. # shift pixel values
  403. img_adjust: Tensor = image * factor
  404. # truncate between pixel values
  405. if clip_output:
  406. img_adjust = img_adjust.clamp(min=0.0, max=1.0)
  407. return img_adjust
  408. def adjust_sigmoid(image: Tensor, cutoff: float = 0.5, gain: float = 10, inv: bool = False) -> Tensor:
  409. """Adjust sigmoid correction on the input image tensor.
  410. The input image is expected to be in the range of [0, 1].
  411. Reference:
  412. [1]: Gustav J. Braun, "Image Lightness Rescaling Using Sigmoidal Contrast Enhancement Functions",
  413. http://markfairchild.org/PDFs/PAP07.pdf
  414. Args:
  415. image: Image to be adjusted in the shape of :math:`(*, H, W)`.
  416. cutoff: The cutoff of sigmoid function.
  417. gain: The multiplier of sigmoid function.
  418. inv: If is set to True the function will return the inverse sigmoid correction.
  419. Returns:
  420. Adjusted tensor in the shape of :math:`(*, H, W)`.
  421. Example:
  422. >>> x = torch.ones(1, 1, 2, 2)
  423. >>> adjust_sigmoid(x, gain=0)
  424. tensor([[[[0.5000, 0.5000],
  425. [0.5000, 0.5000]]]])
  426. """
  427. KORNIA_CHECK_IS_TENSOR(image, "Expected shape (*, H, W)")
  428. if inv:
  429. img_adjust = 1 - 1 / (1 + (gain * (cutoff - image)).exp())
  430. else:
  431. img_adjust = 1 / (1 + (gain * (cutoff - image)).exp())
  432. return img_adjust
  433. def adjust_log(image: Tensor, gain: float = 1, inv: bool = False, clip_output: bool = True) -> Tensor:
  434. """Adjust log correction on the input image tensor.
  435. The input image is expected to be in the range of [0, 1].
  436. Reference:
  437. [1]: http://www.ece.ucsb.edu/Faculty/Manjunath/courses/ece178W03/EnhancePart1.pdf
  438. Args:
  439. image: Image to be adjusted in the shape of :math:`(*, H, W)`.
  440. gain: The multiplier of logarithmic function.
  441. inv: If is set to True the function will return the inverse logarithmic correction.
  442. clip_output: Whether to clip the output image with range of [0, 1].
  443. Returns:
  444. Adjusted tensor in the shape of :math:`(*, H, W)`.
  445. Example:
  446. >>> x = torch.zeros(1, 1, 2, 2)
  447. >>> adjust_log(x, inv=True)
  448. tensor([[[[0., 0.],
  449. [0., 0.]]]])
  450. """
  451. KORNIA_CHECK_IS_TENSOR(image, "Expected shape (*, H, W)")
  452. if inv:
  453. img_adjust = (2**image - 1) * gain
  454. else:
  455. img_adjust = (1 + image).log2() * gain
  456. # truncate between pixel values
  457. if clip_output:
  458. img_adjust = img_adjust.clamp(min=0.0, max=1.0)
  459. return img_adjust
  460. def _solarize(input: Tensor, thresholds: Union[float, Tensor] = 0.5) -> Tensor:
  461. r"""For each pixel in the image, select the pixel if the value is less than the threshold.
  462. Otherwise, subtract 1.0 from the pixel.
  463. Args:
  464. input: image or batched images to solarize.
  465. thresholds: solarize thresholds.
  466. If int or one element tensor, input will be solarized across the whole batch.
  467. If 1-d tensor, input will be solarized element-wise, len(thresholds) == len(input).
  468. Returns:
  469. Solarized images.
  470. """
  471. if not isinstance(input, Tensor):
  472. raise TypeError(f"Input type is not a Tensor. Got {type(input)}")
  473. if not isinstance(thresholds, (float, Tensor)):
  474. raise TypeError(f"The factor should be either a float or Tensor. Got {type(thresholds)}")
  475. if isinstance(thresholds, Tensor) and len(thresholds.shape) != 0:
  476. if not (input.size(0) == len(thresholds) and len(thresholds.shape) == 1):
  477. raise AssertionError(f"thresholds must be a 1-d vector of shape ({input.size(0)},). Got {thresholds}")
  478. # TODO: I am not happy about this line, but no easy to do batch-wise operation
  479. thresholds = thresholds.to(input.device).to(input.dtype)
  480. thresholds = torch.stack([x.expand(*input.shape[-3:]) for x in thresholds])
  481. return torch.where(input < thresholds, input, 1.0 - input)
  482. def solarize(
  483. input: Tensor,
  484. thresholds: Union[float, Tensor] = 0.5,
  485. additions: Optional[Union[float, Tensor]] = None,
  486. ) -> Tensor:
  487. r"""For each pixel in the image less than threshold.
  488. .. image:: _static/img/solarize.png
  489. We add 'addition' amount to it and then clip the pixel value to be between 0 and 1.0.
  490. The value of 'addition' is between -0.5 and 0.5.
  491. Args:
  492. input: image tensor with shapes like :math:`(*, C, H, W)` to solarize.
  493. thresholds: solarize thresholds.
  494. If int or one element tensor, input will be solarized across the whole batch.
  495. If 1-d tensor, input will be solarized element-wise, len(thresholds) == len(input).
  496. additions: between -0.5 and 0.5.
  497. If None, no addition will be performed.
  498. If int or one element tensor, same addition will be added across the whole batch.
  499. If 1-d tensor, additions will be added element-wisely, len(additions) == len(input).
  500. Returns:
  501. The solarized images with shape :math:`(*, C, H, W)`.
  502. Example:
  503. >>> x = torch.rand(1, 4, 3, 3)
  504. >>> out = solarize(x, thresholds=0.5, additions=0.)
  505. >>> out.shape
  506. torch.Size([1, 4, 3, 3])
  507. >>> x = torch.rand(2, 4, 3, 3)
  508. >>> thresholds = torch.tensor([0.8, 0.5])
  509. >>> additions = torch.tensor([-0.25, 0.25])
  510. >>> solarize(x, thresholds, additions).shape
  511. torch.Size([2, 4, 3, 3])
  512. """
  513. if not isinstance(input, Tensor):
  514. raise TypeError(f"Input type is not a Tensor. Got {type(input)}")
  515. if not isinstance(thresholds, (float, Tensor)):
  516. raise TypeError(f"The factor should be either a float or Tensor. Got {type(thresholds)}")
  517. if isinstance(thresholds, float):
  518. thresholds = torch.as_tensor(thresholds)
  519. if additions is not None:
  520. if not isinstance(additions, (float, Tensor)):
  521. raise TypeError(f"The factor should be either a float or Tensor. Got {type(additions)}")
  522. if isinstance(additions, float):
  523. additions = torch.as_tensor(additions)
  524. if not torch.all((additions < 0.5) * (additions > -0.5)):
  525. raise AssertionError(f"The value of 'addition' is between -0.5 and 0.5. Got {additions}.")
  526. if isinstance(additions, Tensor) and len(additions.shape) != 0:
  527. if not (input.size(0) == len(additions) and len(additions.shape) == 1):
  528. raise AssertionError(f"additions must be a 1-d vector of shape ({input.size(0)},). Got {additions}")
  529. # TODO: I am not happy about this line, but no easy to do batch-wise operation
  530. additions = additions.to(input.device).to(input.dtype)
  531. additions = torch.stack([x.expand(*input.shape[-3:]) for x in additions])
  532. input = input + additions
  533. input = input.clamp(0.0, 1.0)
  534. return _solarize(input, thresholds)
  535. @perform_keep_shape_image
  536. def posterize(input: Tensor, bits: Union[int, Tensor]) -> Tensor:
  537. r"""Reduce the number of bits for each color channel.
  538. .. image:: _static/img/posterize.png
  539. Non-differentiable function, ``torch.uint8`` involved.
  540. Args:
  541. input: image tensor with shape :math:`(*, C, H, W)` to posterize.
  542. bits: number of high bits. Must be in range [0, 8].
  543. If int or one element tensor, input will be posterized by this bits.
  544. If 1-d tensor, input will be posterized element-wisely, len(bits) == input.shape[-3].
  545. If n-d tensor, input will be posterized element-channel-wisely, bits.shape == input.shape[:len(bits.shape)]
  546. Returns:
  547. Image with reduced color channels with shape :math:`(*, C, H, W)`.
  548. Example:
  549. >>> x = torch.rand(1, 6, 3, 3)
  550. >>> out = posterize(x, bits=8)
  551. >>> torch.testing.assert_close(x, out)
  552. >>> x = torch.rand(2, 6, 3, 3)
  553. >>> bits = torch.tensor([4, 2])
  554. >>> posterize(x, bits).shape
  555. torch.Size([2, 6, 3, 3])
  556. """
  557. if not isinstance(input, Tensor):
  558. raise TypeError(f"Input type is not a Tensor. Got {type(input)}")
  559. if not isinstance(bits, (int, Tensor)):
  560. raise TypeError(f"bits type is not an int or Tensor. Got {type(bits)}")
  561. if isinstance(bits, int):
  562. bits = torch.as_tensor(bits)
  563. # TODO: find a better way to check boundaries on tensors
  564. # if not torch.all((bits >= 0) * (bits <= 8)) and bits.dtype == torch.int:
  565. # raise ValueError(f"bits must be integers within range [0, 8]. Got {bits}.")
  566. # TODO: Make a differentiable version
  567. # Current version:
  568. # Ref: https://github.com/open-mmlab/mmcv/pull/132/files#diff-309c9320c7f71bedffe89a70ccff7f3bR19
  569. # Ref: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L222
  570. # Potential approach: implementing kornia.LUT with floating points
  571. # https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/functional.py#L472
  572. def _left_shift(input: Tensor, shift: Tensor) -> Tensor:
  573. return ((input * 255).to(torch.uint8) * (2**shift)).to(input.dtype) / 255.0
  574. def _right_shift(input: Tensor, shift: Tensor) -> Tensor:
  575. return (input * 255).to(torch.uint8) / (2**shift).to(input.dtype) / 255.0
  576. def _posterize_one(input: Tensor, bits: Tensor) -> Tensor:
  577. # Single bits value condition
  578. if bits == 0:
  579. return torch.zeros_like(input)
  580. if bits == 8:
  581. return input.clone()
  582. bits = 8 - bits
  583. return _left_shift(_right_shift(input, bits), bits)
  584. if len(bits.shape) == 0 or (len(bits.shape) == 1 and len(bits) == 1):
  585. return _posterize_one(input, bits)
  586. res = []
  587. if len(bits.shape) == 1:
  588. if bits.shape[0] != input.shape[0]:
  589. raise AssertionError(
  590. f"Batch size must be equal between bits and input. Got {bits.shape[0]}, {input.shape[0]}."
  591. )
  592. for i in range(input.shape[0]):
  593. res.append(_posterize_one(input[i], bits[i]))
  594. return torch.stack(res, dim=0)
  595. if bits.shape != input.shape[: len(bits.shape)]:
  596. raise AssertionError(
  597. "Batch and channel must be equal between bits and input. "
  598. f"Got {bits.shape}, {input.shape[: len(bits.shape)]}."
  599. )
  600. _input = input.view(-1, *input.shape[len(bits.shape) :])
  601. _bits = bits.flatten()
  602. for i in range(input.shape[0]):
  603. res.append(_posterize_one(_input[i], _bits[i]))
  604. return torch.stack(res, dim=0).reshape(*input.shape)
  605. @perform_keep_shape_image
  606. def sharpness(input: Tensor, factor: Union[float, Tensor]) -> Tensor:
  607. r"""Apply sharpness to the input tensor.
  608. .. image:: _static/img/sharpness.png
  609. Implemented Sharpness function from PIL using torch ops. This implementation refers to:
  610. https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L326
  611. Args:
  612. input: image tensor with shape :math:`(*, C, H, W)` to sharpen.
  613. factor: factor of sharpness strength. Must be above 0.
  614. If float or one element tensor, input will be sharpened by the same factor across the whole batch.
  615. If 1-d tensor, input will be sharpened element-wisely, len(factor) == len(input).
  616. Returns:
  617. Sharpened image or images with shape :math:`(*, C, H, W)`.
  618. Example:
  619. >>> x = torch.rand(1, 1, 5, 5)
  620. >>> sharpness(x, 0.5).shape
  621. torch.Size([1, 1, 5, 5])
  622. """
  623. if not isinstance(factor, Tensor):
  624. factor = torch.as_tensor(factor, device=input.device, dtype=input.dtype)
  625. if len(factor.size()) != 0 and factor.shape != torch.Size([input.size(0)]):
  626. raise AssertionError(
  627. "Input batch size shall match with factor size if factor is not a 0-dim tensor. "
  628. f"Got {input.size(0)} and {factor.shape}"
  629. )
  630. kernel = (
  631. torch.as_tensor([[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=input.dtype, device=input.device)
  632. .view(1, 1, 3, 3)
  633. .repeat(input.size(1), 1, 1, 1)
  634. / 13
  635. )
  636. # This shall be equivalent to depthwise conv2d:
  637. # Ref: https://discuss.pytorch.org/t/depthwise-and-separable-convolutions-in-pytorch/7315/2
  638. degenerate = torch.nn.functional.conv2d(input, kernel, bias=None, stride=1, groups=input.size(1))
  639. degenerate = torch.clamp(degenerate, 0.0, 1.0)
  640. # For the borders of the resulting image, fill in the values of the original image.
  641. mask = torch.ones_like(degenerate)
  642. padded_mask = torch.nn.functional.pad(mask, [1, 1, 1, 1])
  643. padded_degenerate = torch.nn.functional.pad(degenerate, [1, 1, 1, 1])
  644. result = torch.where(padded_mask == 1, padded_degenerate, input)
  645. if len(factor.size()) == 0:
  646. return _blend_one(result, input, factor)
  647. return torch.stack([_blend_one(result[i], input[i], factor[i]) for i in range(len(factor))])
  648. def _blend_one(input1: Tensor, input2: Tensor, factor: Tensor) -> Tensor:
  649. r"""Blend two images into one.
  650. Args:
  651. input1: image tensor with shapes like :math:`(H, W)` or :math:`(D, H, W)`.
  652. input2: image tensor with shapes like :math:`(H, W)` or :math:`(D, H, W)`.
  653. factor: factor 0-dim tensor.
  654. Returns:
  655. : image tensor with the batch in the zero position.
  656. """
  657. if not isinstance(input1, Tensor):
  658. raise AssertionError(f"`input1` must be a tensor. Got {input1}.")
  659. if not isinstance(input2, Tensor):
  660. raise AssertionError(f"`input1` must be a tensor. Got {input2}.")
  661. if isinstance(factor, Tensor) and len(factor.size()) != 0:
  662. raise AssertionError(f"Factor shall be a float or single element tensor. Got {factor}.")
  663. if factor == 0.0:
  664. return input1
  665. if factor == 1.0:
  666. return input2
  667. diff = (input2 - input1) * factor
  668. res = input1 + diff
  669. if factor > 0.0 and factor < 1.0:
  670. return res
  671. return torch.clamp(res, 0, 1)
  672. def _build_lut(histo: Tensor, step: Tensor) -> Tensor:
  673. # Compute the cumulative sum, shifting by step // 2
  674. # and then normalization by step.
  675. step_trunc = torch.div(step, 2, rounding_mode="trunc")
  676. lut = torch.div(torch.cumsum(histo, 0) + step_trunc, step, rounding_mode="trunc")
  677. # Shift lut, prepending with 0.
  678. lut = torch.cat([torch.zeros(1, device=lut.device, dtype=lut.dtype), lut[:-1]])
  679. # Clip the counts to be in range. This is done
  680. # in the C code for image.point.
  681. return torch.clamp(lut, 0, 255)
  682. # Code taken from: https://github.com/pytorch/vision/pull/796
  683. def _scale_channel(im: Tensor) -> Tensor:
  684. r"""Scale the data in the channel to implement equalize.
  685. Args:
  686. im: image tensor with shapes like :math:`(H, W)` or :math:`(D, H, W)`.
  687. Returns:
  688. image tensor with the batch in the zero position.
  689. """
  690. min_ = im.min()
  691. max_ = im.max()
  692. if min_.item() < 0.0 and not torch.isclose(min_, torch.as_tensor(0.0, dtype=min_.dtype)):
  693. raise ValueError(f"Values in the input tensor must greater or equal to 0.0. Found {min_.item()}.")
  694. if max_.item() > 1.0 and not torch.isclose(max_, torch.as_tensor(1.0, dtype=max_.dtype)):
  695. raise ValueError(f"Values in the input tensor must lower or equal to 1.0. Found {max_.item()}.")
  696. ndims = len(im.shape)
  697. if ndims not in (2, 3):
  698. raise TypeError(f"Input tensor must have 2 or 3 dimensions. Found {ndims}.")
  699. im = im * 255.0
  700. # Compute the histogram of the image channel.
  701. histo = _torch_histc_cast(im, bins=256, min=0, max=255)
  702. # For the purposes of computing the step, filter out the nonzeros.
  703. nonzero_histo = torch.reshape(histo[histo != 0], [-1])
  704. step = torch.div(torch.sum(nonzero_histo) - nonzero_histo[-1], 255, rounding_mode="trunc")
  705. # If step is zero, return the original image. Otherwise, build
  706. # lut from the full histogram and step and then index from it.
  707. if step == 0:
  708. result = im
  709. else:
  710. # can't index using 2d index. Have to flatten and then reshape
  711. result = torch.gather(_build_lut(histo, step), 0, im.flatten().long())
  712. result = result.reshape_as(im)
  713. return result / 255.0
  714. @perform_keep_shape_image
  715. def equalize(input: Tensor) -> Tensor:
  716. r"""Apply equalize on the input tensor.
  717. .. image:: _static/img/equalize.png
  718. Implements Equalize function from PIL using PyTorch ops based on uint8 format:
  719. https://github.com/tensorflow/tpu/blob/5f71c12a020403f863434e96982a840578fdd127/models/official/efficientnet/autoaugment.py#L355
  720. Args:
  721. input: image tensor to equalize with shape :math:`(*, C, H, W)`.
  722. Returns:
  723. Equalized image tensor with shape :math:`(*, C, H, W)`.
  724. Example:
  725. >>> x = torch.rand(1, 2, 3, 3)
  726. >>> equalize(x).shape
  727. torch.Size([1, 2, 3, 3])
  728. """
  729. res = []
  730. for image in input:
  731. # Assumes RGB for now. Scales each channel independently
  732. # and then stacks the result.
  733. scaled_image = torch.stack([_scale_channel(image[i, :, :]) for i in range(len(image))])
  734. res.append(scaled_image)
  735. return torch.stack(res)
  736. @perform_keep_shape_video
  737. def equalize3d(input: Tensor) -> Tensor:
  738. r"""Equalize the values for a 3D volumetric tensor.
  739. Implements Equalize function for a sequence of images using PyTorch ops based on uint8 format:
  740. https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L352
  741. Args:
  742. input: image tensor with shape :math:`(*, C, D, H, W)` to equalize.
  743. Returns:
  744. Equalized volume with shape :math:`(B, C, D, H, W)`.
  745. """
  746. res = []
  747. for volume in input:
  748. # Assumes RGB for now. Scales each channel independently
  749. # and then stacks the result.
  750. scaled_input = torch.stack([_scale_channel(volume[i, :, :, :]) for i in range(len(volume))])
  751. res.append(scaled_input)
  752. return torch.stack(res)
  753. def invert(image: Tensor, max_val: Optional[Tensor] = None) -> Tensor:
  754. r"""Invert the values of an input image tensor by its maximum value.
  755. .. image:: _static/img/invert.png
  756. Args:
  757. image: The input tensor to invert with an arbitatry shape.
  758. max_val: The expected maximum value in the input tensor. The shape has to
  759. according to the input tensor shape, or at least has to work with broadcasting.
  760. Example:
  761. >>> img = torch.rand(1, 2, 4, 4)
  762. >>> invert(img).shape
  763. torch.Size([1, 2, 4, 4])
  764. >>> img = 255. * torch.rand(1, 2, 3, 4, 4)
  765. >>> invert(img, torch.as_tensor(255.)).shape
  766. torch.Size([1, 2, 3, 4, 4])
  767. >>> img = torch.rand(1, 3, 4, 4)
  768. >>> invert(img, torch.as_tensor([[[[1.]]]])).shape
  769. torch.Size([1, 3, 4, 4])
  770. """
  771. if not isinstance(image, Tensor):
  772. raise AssertionError(f"Input is not a Tensor. Got: {type(input)}")
  773. if max_val is None:
  774. _max_val = tensor([1.0])
  775. else:
  776. _max_val = max_val
  777. if not isinstance(_max_val, Tensor):
  778. raise AssertionError(f"max_val is not a Tensor. Got: {type(_max_val)}")
  779. return _max_val.to(image) - image
  780. class AdjustSaturation(Module):
  781. r"""Adjust color saturation of an image.
  782. The input image is expected to be an RGB image in the range of [0, 1].
  783. Args:
  784. saturation_factor: How much to adjust the saturation. 0 will give a black
  785. and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2.
  786. saturation_mode: The mode to adjust saturation.
  787. Shape:
  788. - Input: Image/Tensor to be adjusted in the shape of :math:`(*, 3, H, W)`.
  789. - Output: Adjusted image in the shape of :math:`(*, 3, H, W)`.
  790. Example:
  791. >>> x = torch.ones(1, 3, 3, 3)
  792. >>> AdjustSaturation(2.)(x)
  793. tensor([[[[1., 1., 1.],
  794. [1., 1., 1.],
  795. [1., 1., 1.]],
  796. <BLANKLINE>
  797. [[1., 1., 1.],
  798. [1., 1., 1.],
  799. [1., 1., 1.]],
  800. <BLANKLINE>
  801. [[1., 1., 1.],
  802. [1., 1., 1.],
  803. [1., 1., 1.]]]])
  804. >>> x = torch.ones(2, 3, 3, 3)
  805. >>> y = torch.ones(2)
  806. >>> out = AdjustSaturation(y)(x)
  807. >>> torch.nn.functional.mse_loss(x, out)
  808. tensor(0.)
  809. """
  810. ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
  811. ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
  812. def __init__(self, saturation_factor: Union[float, Tensor]) -> None:
  813. super().__init__()
  814. self.saturation_factor: Union[float, Tensor] = saturation_factor
  815. def forward(self, input: Tensor) -> Tensor:
  816. return adjust_saturation(input, self.saturation_factor)
  817. class AdjustSaturationWithGraySubtraction(Module):
  818. r"""Adjust color saturation of an image.
  819. This implementation aligns PIL. Hence, the output is close to TorchVision.
  820. The input image is expected to be in the range of [0, 1].
  821. The input image is expected to be an RGB or gray image in the range of [0, 1].
  822. Args:
  823. saturation_factor: How much to adjust the saturation. 0 will give a black
  824. and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2.
  825. saturation_mode: The mode to adjust saturation.
  826. Shape:
  827. - Input: Image/Tensor to be adjusted in the shape of :math:`(*, 3, H, W)`.
  828. - Output: Adjusted image in the shape of :math:`(*, 3, H, W)`.
  829. Example:
  830. >>> x = torch.ones(1, 3, 3, 3)
  831. >>> AdjustSaturationWithGraySubtraction(2.)(x)
  832. tensor([[[[1., 1., 1.],
  833. [1., 1., 1.],
  834. [1., 1., 1.]],
  835. <BLANKLINE>
  836. [[1., 1., 1.],
  837. [1., 1., 1.],
  838. [1., 1., 1.]],
  839. <BLANKLINE>
  840. [[1., 1., 1.],
  841. [1., 1., 1.],
  842. [1., 1., 1.]]]])
  843. >>> x = torch.ones(2, 3, 3, 3)
  844. >>> y = torch.ones(2)
  845. >>> out = AdjustSaturationWithGraySubtraction(y)(x)
  846. >>> torch.nn.functional.mse_loss(x, out)
  847. tensor(0.)
  848. """
  849. ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
  850. ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
  851. def __init__(self, saturation_factor: Union[float, Tensor]) -> None:
  852. super().__init__()
  853. self.saturation_factor: Union[float, Tensor] = saturation_factor
  854. def forward(self, input: Tensor) -> Tensor:
  855. return adjust_saturation_with_gray_subtraction(input, self.saturation_factor)
  856. class AdjustHue(Module):
  857. r"""Adjust hue of an image.
  858. This implementation aligns PIL. Hence, the output is close to TorchVision.
  859. The input image is expected to be in the range of [0, 1].
  860. The input image is expected to be an RGB image in the range of [0, 1].
  861. Args:
  862. hue_factor: How much to shift the hue channel. Should be in [-PI, PI]. PI
  863. and -PI give complete reversal of hue channel in HSV space in positive and negative
  864. direction respectively. 0 means no shift. Therefore, both -PI and PI will give an
  865. image with complementary colors while 0 gives the original image.
  866. Shape:
  867. - Input: Image/Tensor to be adjusted in the shape of :math:`(*, 3, H, W)`.
  868. - Output: Adjusted image in the shape of :math:`(*, 3, H, W)`.
  869. Example:
  870. >>> x = torch.ones(1, 3, 3, 3)
  871. >>> AdjustHue(3.141516)(x)
  872. tensor([[[[1., 1., 1.],
  873. [1., 1., 1.],
  874. [1., 1., 1.]],
  875. <BLANKLINE>
  876. [[1., 1., 1.],
  877. [1., 1., 1.],
  878. [1., 1., 1.]],
  879. <BLANKLINE>
  880. [[1., 1., 1.],
  881. [1., 1., 1.],
  882. [1., 1., 1.]]]])
  883. >>> x = torch.ones(2, 3, 3, 3)
  884. >>> y = torch.ones(2) * 3.141516
  885. >>> AdjustHue(y)(x).shape
  886. torch.Size([2, 3, 3, 3])
  887. """
  888. ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
  889. ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
  890. def __init__(self, hue_factor: Union[float, Tensor]) -> None:
  891. super().__init__()
  892. self.hue_factor: Union[float, Tensor] = hue_factor
  893. def forward(self, input: Tensor) -> Tensor:
  894. return adjust_hue(input, self.hue_factor)
  895. class AdjustGamma(Module):
  896. r"""Perform gamma correction on an image.
  897. The input image is expected to be in the range of [0, 1].
  898. Args:
  899. gamma: Non negative real number, same as y\gammay in the equation.
  900. gamma larger than 1 make the shadows darker, while gamma smaller than 1 make
  901. dark regions lighter.
  902. gain: The constant multiplier.
  903. Shape:
  904. - Input: Image to be adjusted in the shape of :math:`(*, N)`.
  905. - Output: Adjusted image in the shape of :math:`(*, N)`.
  906. Example:
  907. >>> x = torch.ones(1, 1, 3, 3)
  908. >>> AdjustGamma(1.0, 2.0)(x)
  909. tensor([[[[1., 1., 1.],
  910. [1., 1., 1.],
  911. [1., 1., 1.]]]])
  912. >>> x = torch.ones(2, 5, 3, 3)
  913. >>> y1 = torch.ones(2) * 1.0
  914. >>> y2 = torch.ones(2) * 2.0
  915. >>> AdjustGamma(y1, y2)(x).shape
  916. torch.Size([2, 5, 3, 3])
  917. """
  918. def __init__(self, gamma: Union[float, Tensor], gain: Union[float, Tensor] = 1.0) -> None:
  919. super().__init__()
  920. self.gamma: Union[float, Tensor] = gamma
  921. self.gain: Union[float, Tensor] = gain
  922. def forward(self, input: Tensor) -> Tensor:
  923. return adjust_gamma(input, self.gamma, self.gain)
  924. class AdjustContrast(Module):
  925. r"""Adjust Contrast of an image.
  926. This implementation aligns OpenCV, not PIL. Hence, the output differs from TorchVision.
  927. The input image is expected to be in the range of [0, 1].
  928. Args:
  929. contrast_factor: Contrast adjust factor per element
  930. in the batch. 0 generates a completely black image, 1 does not modify
  931. the input image while any other non-negative number modify the
  932. brightness by this factor.
  933. Shape:
  934. - Input: Image/Input to be adjusted in the shape of :math:`(*, N)`.
  935. - Output: Adjusted image in the shape of :math:`(*, N)`.
  936. Example:
  937. >>> x = torch.ones(1, 1, 3, 3)
  938. >>> AdjustContrast(0.5)(x)
  939. tensor([[[[0.5000, 0.5000, 0.5000],
  940. [0.5000, 0.5000, 0.5000],
  941. [0.5000, 0.5000, 0.5000]]]])
  942. >>> x = torch.ones(2, 5, 3, 3)
  943. >>> y = torch.ones(2)
  944. >>> AdjustContrast(y)(x).shape
  945. torch.Size([2, 5, 3, 3])
  946. """
  947. def __init__(self, contrast_factor: Union[float, Tensor]) -> None:
  948. super().__init__()
  949. self.contrast_factor: Union[float, Tensor] = contrast_factor
  950. def forward(self, input: Tensor) -> Tensor:
  951. return adjust_contrast(input, self.contrast_factor)
  952. class AdjustContrastWithMeanSubtraction(Module):
  953. r"""Adjust Contrast of an image.
  954. This implementation aligns PIL. Hence, the output is close to TorchVision.
  955. The input image is expected to be in the range of [0, 1].
  956. Args:
  957. contrast_factor: Contrast adjust factor per element
  958. in the batch by subtracting its mean grayscaled version.
  959. 0 generates a completely black image, 1 does not modify
  960. the input image while any other non-negative number modify the
  961. brightness by this factor.
  962. Shape:
  963. - Input: Image/Input to be adjusted in the shape of :math:`(*, N)`.
  964. - Output: Adjusted image in the shape of :math:`(*, N)`.
  965. Example:
  966. >>> x = torch.ones(1, 1, 3, 3)
  967. >>> AdjustContrastWithMeanSubtraction(0.5)(x)
  968. tensor([[[[1., 1., 1.],
  969. [1., 1., 1.],
  970. [1., 1., 1.]]]])
  971. >>> x = torch.ones(2, 5, 3, 3)
  972. >>> y = torch.ones(2)
  973. >>> AdjustContrastWithMeanSubtraction(y)(x).shape
  974. torch.Size([2, 5, 3, 3])
  975. """
  976. def __init__(self, contrast_factor: Union[float, Tensor]) -> None:
  977. super().__init__()
  978. self.contrast_factor: Union[float, Tensor] = contrast_factor
  979. def forward(self, input: Tensor) -> Tensor:
  980. return adjust_contrast_with_mean_subtraction(input, self.contrast_factor)
  981. class AdjustBrightness(Module):
  982. r"""Adjust Brightness of an image.
  983. This implementation aligns OpenCV, not PIL. Hence, the output differs from TorchVision.
  984. The input image is expected to be in the range of [0, 1].
  985. Args:
  986. brightness_factor: Brightness adjust factor per element
  987. in the batch. 0 does not modify the input image while any other number modify the
  988. brightness.
  989. Shape:
  990. - Input: Image/Input to be adjusted in the shape of :math:`(*, N)`.
  991. - Output: Adjusted image in the shape of :math:`(*, N)`.
  992. Example:
  993. >>> x = torch.ones(1, 1, 3, 3)
  994. >>> AdjustBrightness(1.)(x)
  995. tensor([[[[1., 1., 1.],
  996. [1., 1., 1.],
  997. [1., 1., 1.]]]])
  998. >>> x = torch.ones(2, 5, 3, 3)
  999. >>> y = torch.ones(2)
  1000. >>> AdjustBrightness(y)(x).shape
  1001. torch.Size([2, 5, 3, 3])
  1002. """
  1003. def __init__(self, brightness_factor: Union[float, Tensor]) -> None:
  1004. super().__init__()
  1005. self.brightness_factor: Union[float, Tensor] = brightness_factor
  1006. def forward(self, input: Tensor) -> Tensor:
  1007. return adjust_brightness(input, self.brightness_factor)
  1008. class AdjustSigmoid(Module):
  1009. r"""Adjust the contrast of an image tensor or performs sigmoid correction on the input image tensor.
  1010. The input image is expected to be in the range of [0, 1].
  1011. Reference:
  1012. [1]: Gustav J. Braun, "Image Lightness Rescaling Using Sigmoidal Contrast Enhancement Functions",
  1013. http://markfairchild.org/PDFs/PAP07.pdf
  1014. Args:
  1015. image: Image to be adjusted in the shape of :math:`(*, H, W)`.
  1016. cutoff: The cutoff of sigmoid function.
  1017. gain: The multiplier of sigmoid function.
  1018. inv: If is set to True the function will return the negative sigmoid correction.
  1019. Example:
  1020. >>> x = torch.ones(1, 1, 2, 2)
  1021. >>> AdjustSigmoid(gain=0)(x)
  1022. tensor([[[[0.5000, 0.5000],
  1023. [0.5000, 0.5000]]]])
  1024. """
  1025. def __init__(self, cutoff: float = 0.5, gain: float = 10, inv: bool = False) -> None:
  1026. super().__init__()
  1027. self.cutoff: float = cutoff
  1028. self.gain: float = gain
  1029. self.inv: bool = inv
  1030. def forward(self, image: Tensor) -> Tensor:
  1031. return adjust_sigmoid(image, cutoff=self.cutoff, gain=self.gain, inv=self.inv)
  1032. class AdjustLog(Module):
  1033. """Adjust log correction on the input image tensor.
  1034. The input image is expected to be in the range of [0, 1].
  1035. Reference:
  1036. [1]: http://www.ece.ucsb.edu/Faculty/Manjunath/courses/ece178W03/EnhancePart1.pdf
  1037. Args:
  1038. image: Image to be adjusted in the shape of :math:`(*, H, W)`.
  1039. gain: The multiplier of logarithmic function.
  1040. inv: If is set to True the function will return the inverse logarithmic correction.
  1041. clip_output: Whether to clip the output image with range of [0, 1].
  1042. Example:
  1043. >>> x = torch.zeros(1, 1, 2, 2)
  1044. >>> AdjustLog(inv=True)(x)
  1045. tensor([[[[0., 0.],
  1046. [0., 0.]]]])
  1047. """
  1048. def __init__(self, gain: float = 1, inv: bool = False, clip_output: bool = True) -> None:
  1049. super().__init__()
  1050. self.gain: float = gain
  1051. self.inv: bool = inv
  1052. self.clip_output: bool = clip_output
  1053. def forward(self, image: Tensor) -> Tensor:
  1054. return adjust_log(image, gain=self.gain, inv=self.inv, clip_output=self.clip_output)
  1055. class AdjustBrightnessAccumulative(Module):
  1056. r"""Adjust Brightness of an image accumulatively.
  1057. This implementation aligns PIL. Hence, the output is close to TorchVision.
  1058. The input image is expected to be in the range of [0, 1].
  1059. Args:
  1060. brightness_factor: Brightness adjust factor per element
  1061. in the batch. 0 does not modify the input image while any other number modify the
  1062. brightness.
  1063. Shape:
  1064. - Input: Image/Input to be adjusted in the shape of :math:`(*, N)`.
  1065. - Output: Adjusted image in the shape of :math:`(*, N)`.
  1066. Example:
  1067. >>> x = torch.ones(1, 1, 3, 3)
  1068. >>> AdjustBrightnessAccumulative(1.)(x)
  1069. tensor([[[[1., 1., 1.],
  1070. [1., 1., 1.],
  1071. [1., 1., 1.]]]])
  1072. >>> x = torch.ones(2, 5, 3, 3)
  1073. >>> y = torch.ones(2)
  1074. >>> AdjustBrightnessAccumulative(y)(x).shape
  1075. torch.Size([2, 5, 3, 3])
  1076. """
  1077. def __init__(self, brightness_factor: Union[float, Tensor]) -> None:
  1078. super().__init__()
  1079. self.brightness_factor: Union[float, Tensor] = brightness_factor
  1080. def forward(self, input: Tensor) -> Tensor:
  1081. return adjust_brightness_accumulative(input, self.brightness_factor)
  1082. class Invert(Module):
  1083. r"""Invert the values of an input tensor by its maximum value.
  1084. Args:
  1085. input: The input tensor to invert with an arbitatry shape.
  1086. max_val: The expected maximum value in the input tensor. The shape has to
  1087. according to the input tensor shape, or at least has to work with broadcasting. Default: 1.0.
  1088. Example:
  1089. >>> img = torch.rand(1, 2, 4, 4)
  1090. >>> Invert()(img).shape
  1091. torch.Size([1, 2, 4, 4])
  1092. >>> img = 255. * torch.rand(1, 2, 3, 4, 4)
  1093. >>> Invert(torch.as_tensor(255.))(img).shape
  1094. torch.Size([1, 2, 3, 4, 4])
  1095. >>> img = torch.rand(1, 3, 4, 4)
  1096. >>> Invert(torch.as_tensor([[[[1.]]]]))(img).shape
  1097. torch.Size([1, 3, 4, 4])
  1098. """
  1099. def __init__(self, max_val: Optional[Tensor] = None) -> None:
  1100. super().__init__()
  1101. if max_val is None:
  1102. max_val = torch.tensor(1.0)
  1103. if not isinstance(max_val, Parameter):
  1104. self.register_buffer("max_val", max_val)
  1105. else:
  1106. self.max_val = max_val
  1107. def forward(self, input: Tensor) -> Tensor:
  1108. return invert(input, self.max_val)