functional.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586
  1. import math
  2. import numbers
  3. import sys
  4. import warnings
  5. from enum import Enum
  6. from typing import Any, Optional, Union
  7. import numpy as np
  8. import torch
  9. from PIL import Image
  10. from PIL.Image import Image as PILImage
  11. from torch import Tensor
  12. try:
  13. import accimage
  14. except ImportError:
  15. accimage = None
  16. from ..utils import _Image_fromarray, _log_api_usage_once
  17. from . import _functional_pil as F_pil, _functional_tensor as F_t
class InterpolationMode(Enum):
    """Interpolation modes.

    Available interpolation methods are ``nearest``, ``nearest-exact``, ``bilinear``, ``bicubic``, ``box``, ``hamming``,
    and ``lanczos``.

    Each member's value is the lowercase mode name. ``resize`` in this module forwards ``.value`` to the
    tensor backend and translates members to Pillow integer codes through ``pil_modes_mapping``.
    """

    NEAREST = "nearest"
    NEAREST_EXACT = "nearest-exact"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    # For PIL compatibility: ``resize`` documents only NEAREST, NEAREST_EXACT, BILINEAR and BICUBIC
    # as supported for tensor inputs.
    BOX = "box"
    HAMMING = "hamming"
    LANCZOS = "lanczos"
  31. # TODO: Once torchscript supports Enums with staticmethod
  32. # this can be put into InterpolationMode as staticmethod
  33. def _interpolation_modes_from_int(i: int) -> InterpolationMode:
  34. inverse_modes_mapping = {
  35. 0: InterpolationMode.NEAREST,
  36. 2: InterpolationMode.BILINEAR,
  37. 3: InterpolationMode.BICUBIC,
  38. 4: InterpolationMode.BOX,
  39. 5: InterpolationMode.HAMMING,
  40. 1: InterpolationMode.LANCZOS,
  41. }
  42. return inverse_modes_mapping[i]
# InterpolationMode -> Pillow integer resampling code. ``resize`` uses this to translate the mode
# before calling the PIL backend. NEAREST_EXACT has no distinct code here and shares 0 with NEAREST.
pil_modes_mapping = {
    InterpolationMode.NEAREST: 0,
    InterpolationMode.BILINEAR: 2,
    InterpolationMode.BICUBIC: 3,
    InterpolationMode.NEAREST_EXACT: 0,
    InterpolationMode.BOX: 4,
    InterpolationMode.HAMMING: 5,
    InterpolationMode.LANCZOS: 1,
}
# Module-level alias of the PIL backend's image-type check.
_is_pil_image = F_pil._is_pil_image
  53. def get_dimensions(img: Tensor) -> list[int]:
  54. """Returns the dimensions of an image as [channels, height, width].
  55. Args:
  56. img (PIL Image or Tensor): The image to be checked.
  57. Returns:
  58. List[int]: The image dimensions.
  59. """
  60. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  61. _log_api_usage_once(get_dimensions)
  62. if isinstance(img, torch.Tensor):
  63. return F_t.get_dimensions(img)
  64. return F_pil.get_dimensions(img)
  65. def get_image_size(img: Tensor) -> list[int]:
  66. """Returns the size of an image as [width, height].
  67. Args:
  68. img (PIL Image or Tensor): The image to be checked.
  69. Returns:
  70. List[int]: The image size.
  71. """
  72. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  73. _log_api_usage_once(get_image_size)
  74. if isinstance(img, torch.Tensor):
  75. return F_t.get_image_size(img)
  76. return F_pil.get_image_size(img)
  77. def get_image_num_channels(img: Tensor) -> int:
  78. """Returns the number of channels of an image.
  79. Args:
  80. img (PIL Image or Tensor): The image to be checked.
  81. Returns:
  82. int: The number of channels.
  83. """
  84. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  85. _log_api_usage_once(get_image_num_channels)
  86. if isinstance(img, torch.Tensor):
  87. return F_t.get_image_num_channels(img)
  88. return F_pil.get_image_num_channels(img)
  89. @torch.jit.unused
  90. def _is_numpy(img: Any) -> bool:
  91. return isinstance(img, np.ndarray)
  92. @torch.jit.unused
  93. def _is_numpy_image(img: Any) -> bool:
  94. return img.ndim in {2, 3}
  95. def to_tensor(pic: Union[PILImage, np.ndarray]) -> Tensor:
  96. """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
  97. This function does not support torchscript.
  98. See :class:`~torchvision.transforms.ToTensor` for more details.
  99. Args:
  100. pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
  101. Returns:
  102. Tensor: Converted image.
  103. """
  104. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  105. _log_api_usage_once(to_tensor)
  106. if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
  107. raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
  108. if _is_numpy(pic) and not _is_numpy_image(pic):
  109. raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
  110. default_float_dtype = torch.get_default_dtype()
  111. if isinstance(pic, np.ndarray):
  112. # handle numpy array
  113. if pic.ndim == 2:
  114. pic = pic[:, :, None]
  115. img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
  116. # backward compatibility
  117. if isinstance(img, torch.ByteTensor):
  118. return img.to(dtype=default_float_dtype).div(255)
  119. else:
  120. return img
  121. if accimage is not None and isinstance(pic, accimage.Image):
  122. nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
  123. pic.copyto(nppic)
  124. return torch.from_numpy(nppic).to(dtype=default_float_dtype)
  125. # handle PIL Image
  126. mode_to_nptype = {"I": np.int32, "I;16" if sys.byteorder == "little" else "I;16B": np.int16, "F": np.float32}
  127. img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))
  128. if pic.mode == "1":
  129. img = 255 * img
  130. img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
  131. # put it from HWC to CHW format
  132. img = img.permute((2, 0, 1)).contiguous()
  133. if isinstance(img, torch.ByteTensor):
  134. return img.to(dtype=default_float_dtype).div(255)
  135. else:
  136. return img
  137. def pil_to_tensor(pic: Any) -> Tensor:
  138. """Convert a ``PIL Image`` to a tensor of the same type.
  139. This function does not support torchscript.
  140. See :class:`~torchvision.transforms.PILToTensor` for more details.
  141. .. note::
  142. A deep copy of the underlying array is performed.
  143. Args:
  144. pic (PIL Image): Image to be converted to tensor.
  145. Returns:
  146. Tensor: Converted image.
  147. """
  148. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  149. _log_api_usage_once(pil_to_tensor)
  150. if not F_pil._is_pil_image(pic):
  151. raise TypeError(f"pic should be PIL Image. Got {type(pic)}")
  152. if accimage is not None and isinstance(pic, accimage.Image):
  153. # accimage format is always uint8 internally, so always return uint8 here
  154. nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
  155. pic.copyto(nppic)
  156. return torch.as_tensor(nppic)
  157. # handle PIL Image
  158. img = torch.as_tensor(np.array(pic, copy=True))
  159. img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
  160. # put it from HWC to CHW format
  161. img = img.permute((2, 0, 1))
  162. return img
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly.

    This function does not support PIL Image. The conversion itself is delegated to the tensor backend.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        Tensor: Converted image

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        TypeError: When ``image`` is not a ``torch.Tensor``.
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(convert_image_dtype)
    # NOTE(review): kept as a negated isinstance guard; under TorchScript ``image`` is statically a
    # Tensor, so this branch is presumably dropped at compile time.
    if not isinstance(image, torch.Tensor):
        raise TypeError("Input img should be Tensor Image")
    return F_t.convert_image_dtype(image, dtype)
  185. def to_pil_image(pic, mode=None):
  186. """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.
  187. See :class:`~torchvision.transforms.ToPILImage` for more details.
  188. Args:
  189. pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
  190. mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
  191. .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
  192. Returns:
  193. PIL Image: Image converted to PIL Image.
  194. """
  195. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  196. _log_api_usage_once(to_pil_image)
  197. if isinstance(pic, torch.Tensor):
  198. if pic.ndim == 3:
  199. pic = pic.permute((1, 2, 0))
  200. pic = pic.numpy(force=True)
  201. elif not isinstance(pic, np.ndarray):
  202. raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")
  203. if pic.ndim == 2:
  204. # if 2D image, add channel dimension (HWC)
  205. pic = np.expand_dims(pic, 2)
  206. if pic.ndim != 3:
  207. raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
  208. if pic.shape[-1] > 4:
  209. raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")
  210. npimg = pic
  211. if np.issubdtype(npimg.dtype, np.floating) and mode != "F":
  212. npimg = (npimg * 255).astype(np.uint8)
  213. if npimg.shape[2] == 1:
  214. expected_mode = None
  215. npimg = npimg[:, :, 0]
  216. if npimg.dtype == np.uint8:
  217. expected_mode = "L"
  218. elif npimg.dtype == np.int16:
  219. expected_mode = "I;16" if sys.byteorder == "little" else "I;16B"
  220. elif npimg.dtype == np.int32:
  221. expected_mode = "I"
  222. elif npimg.dtype == np.float32:
  223. expected_mode = "F"
  224. if mode is not None and mode != expected_mode:
  225. raise ValueError(f"Incorrect mode ({mode}) supplied for input type {np.dtype}. Should be {expected_mode}")
  226. mode = expected_mode
  227. elif npimg.shape[2] == 2:
  228. permitted_2_channel_modes = ["LA"]
  229. if mode is not None and mode not in permitted_2_channel_modes:
  230. raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs")
  231. if mode is None and npimg.dtype == np.uint8:
  232. mode = "LA"
  233. elif npimg.shape[2] == 4:
  234. permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
  235. if mode is not None and mode not in permitted_4_channel_modes:
  236. raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs")
  237. if mode is None and npimg.dtype == np.uint8:
  238. mode = "RGBA"
  239. else:
  240. permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
  241. if mode is not None and mode not in permitted_3_channel_modes:
  242. raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs")
  243. if mode is None and npimg.dtype == np.uint8:
  244. mode = "RGB"
  245. if mode is None:
  246. raise TypeError(f"Input type {npimg.dtype} is not supported")
  247. return _Image_fromarray(npimg, mode=mode)
def normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool = False) -> Tensor:
    """Normalize a float tensor image with mean and standard deviation.

    This transform does not support PIL Image.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.

    Raises:
        TypeError: If ``tensor`` is not a ``torch.Tensor``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(normalize)
    # NOTE(review): keep this as a negated isinstance guard. Under TorchScript ``tensor`` is statically a
    # Tensor, so the branch (and its ``type(...)`` call) is presumably never compiled.
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")

    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
  267. def _compute_resized_output_size(
  268. image_size: tuple[int, int],
  269. size: Optional[list[int]],
  270. max_size: Optional[int] = None,
  271. allow_size_none: bool = False, # only True in v2
  272. ) -> list[int]:
  273. h, w = image_size
  274. short, long = (w, h) if w <= h else (h, w)
  275. if size is None:
  276. if not allow_size_none:
  277. raise ValueError("This should never happen!!")
  278. if not isinstance(max_size, int):
  279. raise ValueError(f"max_size must be an integer when size is None, but got {max_size} instead.")
  280. new_short, new_long = int(max_size * short / long), max_size
  281. new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
  282. elif len(size) == 1: # specified size only for the smallest edge
  283. requested_new_short = size if isinstance(size, int) else size[0]
  284. new_short, new_long = requested_new_short, int(requested_new_short * long / short)
  285. if max_size is not None:
  286. if max_size <= requested_new_short:
  287. raise ValueError(
  288. f"max_size = {max_size} must be strictly greater than the requested "
  289. f"size for the smaller edge size = {size}"
  290. )
  291. if new_long > max_size:
  292. new_short, new_long = int(max_size * new_short / new_long), max_size
  293. new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
  294. else: # specified both h and w
  295. new_w, new_h = size[1], size[0]
  296. return [new_h, new_w]
def resize(
    img: Tensor,
    size: list[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> Tensor:
    r"""Resize the input image to the given size.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        img (PIL Image or Tensor): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e, if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
            ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
            supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image. If the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``,
            ``size`` will be overruled so that the longer edge is equal to
            ``max_size``.
            As a result, the smaller edge may be shorter than ``size``. This
            is only supported if ``size`` is an int (or a sequence of length
            1 in torchscript mode).
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The default value changed from ``None`` to ``True`` in
            v0.17, for the PIL and Tensor backends to be consistent.

    Returns:
        PIL Image or Tensor: Resized image. If the computed output size equals the
        input size, ``img`` itself is returned.

    Raises:
        TypeError: If ``interpolation`` is neither an ``InterpolationMode`` nor an int.
        KeyError: If ``interpolation`` is an int that is not a known Pillow code.
        ValueError: If ``size`` is a sequence whose length is not 1 or 2, or if ``max_size``
            is combined with a 2-element ``size``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resize)

    # Accept Pillow integer constants for backward compatibility.
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if isinstance(size, (list, tuple)):
        if len(size) not in [1, 2]:
            raise ValueError(
                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
            )
        if max_size is not None and len(size) != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )

    _, image_height, image_width = get_dimensions(img)
    # Normalise an int size to a one-element list so the helper always receives a sequence.
    if isinstance(size, int):
        size = [size]
    output_size = _compute_resized_output_size((image_height, image_width), size, max_size)

    # No-op resize: hand back the input object unchanged.
    if [image_height, image_width] == output_size:
        return img

    if not isinstance(img, torch.Tensor):
        if antialias is False:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)

    return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)
def pad(img: Tensor, padding: list[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad the given image on all sides with the given "pad" value.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Args:
        img (PIL Image or Tensor): Image to be padded.
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If the input is a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        PIL Image or Tensor: Padded image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pad)
    # NOTE(review): PIL dispatch stays inside the negated isinstance branch so TorchScript,
    # which sees ``img`` as a Tensor, presumably never compiles the PIL call.
    if not isinstance(img, torch.Tensor):
        return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)

    return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop the given image at specified location and output size.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(crop)
    # NOTE(review): PIL dispatch stays inside the negated isinstance branch (TorchScript-friendly).
    if not isinstance(img, torch.Tensor):
        return F_pil.crop(img, top, left, height, width)

    return F_t.crop(img, top, left, height, width)
def center_crop(img: Tensor, output_size: list[int]) -> Tensor:
    """Crops the given image at the center.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
            it is used for both directions.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(center_crop)
    # Normalise a scalar or one-element sequence to a (height, width) pair.
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    _, image_height, image_width = get_dimensions(img)
    crop_height, crop_width = output_size

    if crop_width > image_width or crop_height > image_height:
        # Pad up to the crop size first. Padding order is [left, top, right, bottom]; when the
        # deficit is odd, the right/bottom side receives the extra pixel.
        padding_ltrb = [
            (crop_width - image_width) // 2 if crop_width > image_width else 0,
            (crop_height - image_height) // 2 if crop_height > image_height else 0,
            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
        ]
        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
        _, image_height, image_width = get_dimensions(img)
        # The padded image already has exactly the requested size.
        if crop_width == image_width and crop_height == image_height:
            return img

    # Offsets use round(); in eager Python this rounds exact halves to the nearest even integer.
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop(img, crop_top, crop_left, crop_height, crop_width)
  471. def resized_crop(
  472. img: Tensor,
  473. top: int,
  474. left: int,
  475. height: int,
  476. width: int,
  477. size: list[int],
  478. interpolation: InterpolationMode = InterpolationMode.BILINEAR,
  479. antialias: Optional[bool] = True,
  480. ) -> Tensor:
  481. """Crop the given image and resize it to desired size.
  482. If the image is torch Tensor, it is expected
  483. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
  484. Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.
  485. Args:
  486. img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
  487. top (int): Vertical component of the top left corner of the crop box.
  488. left (int): Horizontal component of the top left corner of the crop box.
  489. height (int): Height of the crop box.
  490. width (int): Width of the crop box.
  491. size (sequence or int): Desired output size. Same semantics as ``resize``.
  492. interpolation (InterpolationMode): Desired interpolation enum defined by
  493. :class:`torchvision.transforms.InterpolationMode`.
  494. Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
  495. ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
  496. supported.
  497. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
  498. antialias (bool, optional): Whether to apply antialiasing.
  499. It only affects **tensors** with bilinear or bicubic modes and it is
  500. ignored otherwise: on PIL images, antialiasing is always applied on
  501. bilinear or bicubic modes; on other modes (for PIL images and
  502. tensors), antialiasing makes no sense and this parameter is ignored.
  503. Possible values are:
  504. - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
  505. Other mode aren't affected. This is probably what you want to use.
  506. - ``False``: will not apply antialiasing for tensors on any mode. PIL
  507. images are still antialiased on bilinear or bicubic modes, because
  508. PIL doesn't support no antialias.
  509. - ``None``: equivalent to ``False`` for tensors and ``True`` for
  510. PIL images. This value exists for legacy reasons and you probably
  511. don't want to use it unless you really know what you are doing.
  512. The default value changed from ``None`` to ``True`` in
  513. v0.17, for the PIL and Tensor backends to be consistent.
  514. Returns:
  515. PIL Image or Tensor: Cropped image.
  516. """
  517. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  518. _log_api_usage_once(resized_crop)
  519. img = crop(img, top, left, height, width)
  520. img = resize(img, size, interpolation, antialias=antialias)
  521. return img
def hflip(img: Tensor) -> Tensor:
    """Horizontally flip the given image.

    Tensors are handled by ``F_t.hflip``; any other input is passed to ``F_pil.hflip``.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor: Horizontally flipped image.
    """
    # API-usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(hflip)
    # Anything that is not a tensor is treated as a PIL image.
    if not isinstance(img, torch.Tensor):
        return F_pil.hflip(img)
    return F_t.hflip(img)
  537. def _get_perspective_coeffs(startpoints: list[list[int]], endpoints: list[list[int]]) -> list[float]:
  538. """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
  539. In Perspective Transform each pixel (x, y) in the original image gets transformed as,
  540. (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
  541. Args:
  542. startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
  543. ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
  544. endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
  545. ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
  546. Returns:
  547. octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
  548. """
  549. if len(startpoints) != 4 or len(endpoints) != 4:
  550. raise ValueError(
  551. f"Please provide exactly four corners, got {len(startpoints)} startpoints and {len(endpoints)} endpoints."
  552. )
  553. a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float64)
  554. for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
  555. a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
  556. a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
  557. b_matrix = torch.tensor(startpoints, dtype=torch.float64).view(8)
  558. # do least squares in double precision to prevent numerical issues
  559. res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution.to(torch.float32)
  560. output: list[float] = res.tolist()
  561. return output
  562. def perspective(
  563. img: Tensor,
  564. startpoints: list[list[int]],
  565. endpoints: list[list[int]],
  566. interpolation: InterpolationMode = InterpolationMode.BILINEAR,
  567. fill: Optional[list[float]] = None,
  568. ) -> Tensor:
  569. """Perform perspective transform of the given image.
  570. If the image is torch Tensor, it is expected
  571. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
  572. Args:
  573. img (PIL Image or Tensor): Image to be transformed.
  574. startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
  575. ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
  576. endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
  577. ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
  578. interpolation (InterpolationMode): Desired interpolation enum defined by
  579. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
  580. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  581. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
  582. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  583. image. If given a number, the value is used for all bands respectively.
  584. .. note::
  585. In torchscript mode single int/float value is not supported, please use a sequence
  586. of length 1: ``[value, ]``.
  587. Returns:
  588. PIL Image or Tensor: transformed Image.
  589. """
  590. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  591. _log_api_usage_once(perspective)
  592. coeffs = _get_perspective_coeffs(startpoints, endpoints)
  593. if isinstance(interpolation, int):
  594. interpolation = _interpolation_modes_from_int(interpolation)
  595. elif not isinstance(interpolation, InterpolationMode):
  596. raise TypeError(
  597. "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
  598. )
  599. if not isinstance(img, torch.Tensor):
  600. pil_interpolation = pil_modes_mapping[interpolation]
  601. return F_pil.perspective(img, coeffs, interpolation=pil_interpolation, fill=fill)
  602. return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill)
def vflip(img: Tensor) -> Tensor:
    """Vertically flip the given image.

    Tensors are handled by ``F_t.vflip``; any other input is passed to ``F_pil.vflip``.

    Args:
        img (PIL Image or Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor: Vertically flipped image.
    """
    # API-usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(vflip)
    # Anything that is not a tensor is treated as a PIL image.
    if not isinstance(img, torch.Tensor):
        return F_pil.vflip(img)
    return F_t.vflip(img)
  618. def five_crop(img: Tensor, size: list[int]) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
  619. """Crop the given image into four corners and the central crop.
  620. If the image is torch Tensor, it is expected
  621. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
  622. .. Note::
  623. This transform returns a tuple of images and there may be a
  624. mismatch in the number of inputs and targets your ``Dataset`` returns.
  625. Args:
  626. img (PIL Image or Tensor): Image to be cropped.
  627. size (sequence or int): Desired output size of the crop. If size is an
  628. int instead of sequence like (h, w), a square crop (size, size) is
  629. made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
  630. Returns:
  631. tuple: tuple (tl, tr, bl, br, center)
  632. Corresponding top left, top right, bottom left, bottom right and center crop.
  633. """
  634. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  635. _log_api_usage_once(five_crop)
  636. if isinstance(size, numbers.Number):
  637. size = (int(size), int(size))
  638. elif isinstance(size, (tuple, list)) and len(size) == 1:
  639. size = (size[0], size[0])
  640. if len(size) != 2:
  641. raise ValueError("Please provide only two dimensions (h, w) for size.")
  642. _, image_height, image_width = get_dimensions(img)
  643. crop_height, crop_width = size
  644. if crop_width > image_width or crop_height > image_height:
  645. msg = "Requested crop size {} is bigger than input size {}"
  646. raise ValueError(msg.format(size, (image_height, image_width)))
  647. tl = crop(img, 0, 0, crop_height, crop_width)
  648. tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
  649. bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
  650. br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
  651. center = center_crop(img, [crop_height, crop_width])
  652. return tl, tr, bl, br, center
  653. def ten_crop(
  654. img: Tensor, size: list[int], vertical_flip: bool = False
  655. ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
  656. """Generate ten cropped images from the given image.
  657. Crop the given image into four corners and the central crop plus the
  658. flipped version of these (horizontal flipping is used by default).
  659. If the image is torch Tensor, it is expected
  660. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
  661. .. Note::
  662. This transform returns a tuple of images and there may be a
  663. mismatch in the number of inputs and targets your ``Dataset`` returns.
  664. Args:
  665. img (PIL Image or Tensor): Image to be cropped.
  666. size (sequence or int): Desired output size of the crop. If size is an
  667. int instead of sequence like (h, w), a square crop (size, size) is
  668. made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
  669. vertical_flip (bool): Use vertical flipping instead of horizontal
  670. Returns:
  671. tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
  672. Corresponding top left, top right, bottom left, bottom right and
  673. center crop and same for the flipped image.
  674. """
  675. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  676. _log_api_usage_once(ten_crop)
  677. if isinstance(size, numbers.Number):
  678. size = (int(size), int(size))
  679. elif isinstance(size, (tuple, list)) and len(size) == 1:
  680. size = (size[0], size[0])
  681. if len(size) != 2:
  682. raise ValueError("Please provide only two dimensions (h, w) for size.")
  683. first_five = five_crop(img, size)
  684. if vertical_flip:
  685. img = vflip(img)
  686. else:
  687. img = hflip(img)
  688. second_five = five_crop(img, size)
  689. return first_five + second_five
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Adjust brightness of an image.

    Tensors are handled by ``F_t.adjust_brightness``; any other input is passed to
    ``F_pil.adjust_brightness``. This wrapper does not validate ``brightness_factor`` itself.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non-negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image or Tensor: Brightness adjusted image.
    """
    # API-usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_brightness)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_brightness(img, brightness_factor)
    return F_t.adjust_brightness(img, brightness_factor)
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust contrast of an image.

    Tensors are handled by ``F_t.adjust_contrast``; any other input is passed to
    ``F_pil.adjust_contrast``. This wrapper does not validate ``contrast_factor`` itself.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non-negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image or Tensor: Contrast adjusted image.
    """
    # API-usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_contrast)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_contrast(img, contrast_factor)
    return F_t.adjust_contrast(img, contrast_factor)
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Adjust color saturation of an image.

    Tensors are handled by ``F_t.adjust_saturation``; any other input is passed to
    ``F_pil.adjust_saturation``. This wrapper does not validate ``saturation_factor`` itself.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image or Tensor: Saturation adjusted image.
    """
    # API-usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_saturation)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_saturation(img, saturation_factor)
    return F_t.adjust_saturation(img, saturation_factor)
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Tensors are handled by ``F_t.adjust_hue``; any other input is passed to ``F_pil.adjust_hue``.
    The ``[-0.5, 0.5]`` range is not checked in this wrapper.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
            Note: the pixel values of the input image have to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image or Tensor: Hue adjusted image.
    """
    # API-usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_hue)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_hue(img, hue_factor)
    return F_t.adjust_hue(img, hue_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Tensors are handled by ``F_t.adjust_gamma``; any other input is passed to ``F_pil.adjust_gamma``.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, modes with transparency (alpha channel) are not supported.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 make the shadows darker,
            while gamma smaller than 1 make dark regions lighter.
        gain (float): The constant multiplier.

    Returns:
        PIL Image or Tensor: Gamma correction adjusted image.
    """
    # API-usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_gamma)
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_gamma(img, gamma, gain)
    return F_t.adjust_gamma(img, gamma, gain)
def _get_inverse_affine_matrix(
    center: list[float], angle: float, translate: list[float], scale: float, shear: list[float], inverted: bool = True
) -> list[float]:
    """Build the 2x3 affine matrix (row-major, 6 floats) for rotation, scale, shear and translation.

    Args:
        center: ``[cx, cy]`` point the rotation/scale/shear is applied around.
        angle: Rotation angle in degrees (converted with ``math.radians``).
        translate: ``[tx, ty]`` translation applied after rotation/scale/shear.
        scale: Uniform scale factor. The inverse divides by it, so it must be non-zero.
        shear: ``[sx, sy]`` shear angles in degrees.
        inverted: If True (default), return the inverse matrix ``M^-1``; otherwise the forward matrix ``M``.

    Returns:
        list[float]: ``[m0, m1, m2, m3, m4, m5]`` with ``m2``/``m5`` holding the translation terms.
    """
    # Helper method to compute inverse matrix for affine transformation

    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    #       RotateScaleShear(a, s, (sx, sy)) =
    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #         [ 0                    , 0                                        , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
    #          [0, 1      ]              [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1

    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # RSS without scaling
    a = math.cos(rot - sy) / math.cos(sy)
    b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
    c = math.sin(rot - sy) / math.cos(sy)
    d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)

    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        matrix = [d, -b, 0.0, -c, a, 0.0]
        matrix = [x / scale for x in matrix]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
        matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
        # Apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx
        matrix[5] += cy
    else:
        matrix = [a, b, 0.0, c, d, 0.0]
        matrix = [x * scale for x in matrix]
        # Apply inverse of center translation: RSS * C^-1
        matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
        matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
        # Apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx + tx
        matrix[5] += cy + ty

    return matrix
  848. def rotate(
  849. img: Tensor,
  850. angle: float,
  851. interpolation: InterpolationMode = InterpolationMode.NEAREST,
  852. expand: bool = False,
  853. center: Optional[list[int]] = None,
  854. fill: Optional[list[float]] = None,
  855. ) -> Tensor:
  856. """Rotate the image by angle.
  857. If the image is torch Tensor, it is expected
  858. to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
  859. Args:
  860. img (PIL Image or Tensor): image to be rotated.
  861. angle (number): rotation angle value in degrees, counter-clockwise.
  862. interpolation (InterpolationMode): Desired interpolation enum defined by
  863. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
  864. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  865. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
  866. expand (bool, optional): Optional expansion flag.
  867. If true, expands the output image to make it large enough to hold the entire rotated image.
  868. If false or omitted, make the output image the same size as the input image.
  869. Note that the expand flag assumes rotation around the center and no translation.
  870. center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
  871. Default is the center of the image.
  872. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  873. image. If given a number, the value is used for all bands respectively.
  874. .. note::
  875. In torchscript mode single int/float value is not supported, please use a sequence
  876. of length 1: ``[value, ]``.
  877. Returns:
  878. PIL Image or Tensor: Rotated image.
  879. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
  880. """
  881. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  882. _log_api_usage_once(rotate)
  883. if isinstance(interpolation, int):
  884. interpolation = _interpolation_modes_from_int(interpolation)
  885. elif not isinstance(interpolation, InterpolationMode):
  886. raise TypeError(
  887. "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
  888. )
  889. if not isinstance(angle, (int, float)):
  890. raise TypeError("Argument angle should be int or float")
  891. if center is not None and not isinstance(center, (list, tuple)):
  892. raise TypeError("Argument center should be a sequence")
  893. if not isinstance(img, torch.Tensor):
  894. pil_interpolation = pil_modes_mapping[interpolation]
  895. return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)
  896. center_f = [0.0, 0.0]
  897. if center is not None:
  898. _, height, width = get_dimensions(img)
  899. # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
  900. center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
  901. # due to current incoherence of rotation angle direction between affine and rotate implementations
  902. # we need to set -angle.
  903. matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
  904. return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
def affine(
    img: Tensor,
    angle: float,
    translate: list[int],
    scale: float,
    shear: list[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: Optional[list[float]] = None,
    center: Optional[list[int]] = None,
) -> Tensor:
    """Apply affine transformation on the image keeping image center invariant.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to transform.
        angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
        translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
        scale (float): overall scale
        shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
            If a sequence is specified, the first value corresponds to a shear parallel to the x-axis, while
            the second value corresponds to a shear parallel to the y-axis.
            A single number shears along the x-axis only; a one-element sequence is used for both axes.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.
        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.

    Returns:
        PIL Image or Tensor: Transformed image.

    Raises:
        TypeError: If ``interpolation`` is neither an ``InterpolationMode`` nor an int, if ``angle``
            is not an int or float, if ``translate`` or a given ``center`` is not a list/tuple, or if
            ``shear`` is neither a number nor a list/tuple.
        ValueError: If ``translate`` does not have length 2, if ``scale`` is not positive, or if
            ``shear`` does not resolve to exactly two values.
    """
    # API-usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(affine)

    # Accept Pillow integer constants as well as InterpolationMode members.
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    # Normalize argument types: float angle, list translate, two-element list shear.
    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        # A single value shears along the x-axis only.
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        # A one-element sequence is applied to both axes.
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    _, height, width = get_dimensions(img)
    if not isinstance(img, torch.Tensor):
        # center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
        # it is visually better to estimate the center without 0.5 offset
        # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
        if center is None:
            center = [width * 0.5, height * 0.5]
        # The PIL branch uses an absolute-pixel center.
        matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)

    # The tensor branch uses a center relative to the image middle: (0, 0) is the image center.
    center_f = [0.0, 0.0]
    if center is not None:
        # NOTE(review): height/width were already obtained above; this second call looks redundant.
        _, height, width = get_dimensions(img)
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [1.0 * t for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
    return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
# to_grayscale() appears to be a standalone functional that the transform classes never call;
# it is presumably kept for backward compatibility.
  993. @torch.jit.unused
  994. def to_grayscale(img, num_output_channels=1):
  995. """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
  996. This transform does not support torch Tensor.
  997. Args:
  998. img (PIL Image): PIL Image to be converted to grayscale.
  999. num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.
  1000. Returns:
  1001. PIL Image: Grayscale version of the image.
  1002. - if num_output_channels = 1 : returned image is single channel
  1003. - if num_output_channels = 3 : returned image is 3 channel with r = g = b
  1004. """
  1005. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1006. _log_api_usage_once(to_grayscale)
  1007. if isinstance(img, Image.Image):
  1008. return F_pil.to_grayscale(img, num_output_channels)
  1009. raise TypeError("Input should be PIL Image")
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert RGB image to grayscale version of image.

    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Tensors are handled by ``F_t.rgb_to_grayscale``; any other input is passed to
    ``F_pil.to_grayscale``.

    Note:
        Please, note that this method supports only RGB images as input. For inputs in other color spaces,
        please, consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.

    Args:
        img (PIL Image or Tensor): RGB Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.

    Returns:
        PIL Image or Tensor: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
    # API-usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rgb_to_grayscale)
    if not isinstance(img, torch.Tensor):
        return F_pil.to_grayscale(img, num_output_channels)
    return F_t.rgb_to_grayscale(img, num_output_channels)
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Erase the input Tensor Image with given value.

    This transform does not support PIL Image; the work is delegated to ``F_t.erase``.

    Args:
        img (Tensor Image): Tensor image of size (C, H, W) to be erased
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the erased region.
        w (int): Width of the erased region.
        v (Tensor): Erasing value.
        inplace (bool, optional): For in-place operations. By default, is set False.

    Returns:
        Tensor Image: Erased image.

    Raises:
        TypeError: If ``img`` is not a ``torch.Tensor``.
    """
    # API-usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(erase)
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor Image. Got {type(img)}")
    return F_t.erase(img, i, j, h, w, v, inplace=inplace)
  1049. def gaussian_blur(img: Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> Tensor:
  1050. """Performs Gaussian blurring on the image by given kernel
  1051. The convolution will be using reflection padding corresponding to the kernel size, to maintain the input shape.
  1052. If the image is torch Tensor, it is expected
  1053. to have [..., H, W] shape, where ... means at most one leading dimension.
  1054. Args:
  1055. img (PIL Image or Tensor): Image to be blurred
  1056. kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
  1057. like ``(kx, ky)`` or a single integer for square kernels.
  1058. .. note::
  1059. In torchscript mode kernel_size as single int is not supported, use a sequence of
  1060. length 1: ``[ksize, ]``.
  1061. sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
  1062. sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
  1063. same sigma in both X/Y directions. If None, then it is computed using
  1064. ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
  1065. Default, None.
  1066. .. note::
  1067. In torchscript mode sigma as single float is
  1068. not supported, use a sequence of length 1: ``[sigma, ]``.
  1069. Returns:
  1070. PIL Image or Tensor: Gaussian Blurred version of the image.
  1071. """
  1072. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1073. _log_api_usage_once(gaussian_blur)
  1074. if not isinstance(kernel_size, (int, list, tuple)):
  1075. raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
  1076. if isinstance(kernel_size, int):
  1077. kernel_size = [kernel_size, kernel_size]
  1078. if len(kernel_size) != 2:
  1079. raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
  1080. for ksize in kernel_size:
  1081. if ksize % 2 == 0 or ksize < 0:
  1082. raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")
  1083. if sigma is None:
  1084. sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
  1085. if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
  1086. raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
  1087. if isinstance(sigma, (int, float)):
  1088. sigma = [float(sigma), float(sigma)]
  1089. if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
  1090. sigma = [sigma[0], sigma[0]]
  1091. if len(sigma) != 2:
  1092. raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
  1093. for s in sigma:
  1094. if s <= 0.0:
  1095. raise ValueError(f"sigma should have positive values. Got {sigma}")
  1096. t_img = img
  1097. if not isinstance(img, torch.Tensor):
  1098. if not F_pil._is_pil_image(img):
  1099. raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
  1100. t_img = pil_to_tensor(img)
  1101. output = F_t.gaussian_blur(t_img, kernel_size, sigma)
  1102. if not isinstance(img, torch.Tensor):
  1103. output = to_pil_image(output, mode=img.mode)
  1104. return output
  1105. def invert(img: Tensor) -> Tensor:
  1106. """Invert the colors of an RGB/grayscale image.
  1107. Args:
  1108. img (PIL Image or Tensor): Image to have its colors inverted.
  1109. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  1110. where ... means it can have an arbitrary number of leading dimensions.
  1111. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  1112. Returns:
  1113. PIL Image or Tensor: Color inverted image.
  1114. """
  1115. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1116. _log_api_usage_once(invert)
  1117. if not isinstance(img, torch.Tensor):
  1118. return F_pil.invert(img)
  1119. return F_t.invert(img)
  1120. def posterize(img: Tensor, bits: int) -> Tensor:
  1121. """Posterize an image by reducing the number of bits for each color channel.
  1122. Args:
  1123. img (PIL Image or Tensor): Image to have its colors posterized.
  1124. If img is torch Tensor, it should be of type torch.uint8, and
  1125. it is expected to be in [..., 1 or 3, H, W] format, where ... means
  1126. it can have an arbitrary number of leading dimensions.
  1127. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  1128. bits (int): The number of bits to keep for each channel (0-8).
  1129. Returns:
  1130. PIL Image or Tensor: Posterized image.
  1131. """
  1132. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1133. _log_api_usage_once(posterize)
  1134. if not (0 <= bits <= 8):
  1135. raise ValueError(f"The number if bits should be between 0 and 8. Got {bits}")
  1136. if not isinstance(img, torch.Tensor):
  1137. return F_pil.posterize(img, bits)
  1138. return F_t.posterize(img, bits)
  1139. def solarize(img: Tensor, threshold: float) -> Tensor:
  1140. """Solarize an RGB/grayscale image by inverting all pixel values above a threshold.
  1141. Args:
  1142. img (PIL Image or Tensor): Image to have its colors inverted.
  1143. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  1144. where ... means it can have an arbitrary number of leading dimensions.
  1145. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  1146. threshold (float): All pixels equal or above this value are inverted.
  1147. Returns:
  1148. PIL Image or Tensor: Solarized image.
  1149. """
  1150. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1151. _log_api_usage_once(solarize)
  1152. if not isinstance(img, torch.Tensor):
  1153. return F_pil.solarize(img, threshold)
  1154. return F_t.solarize(img, threshold)
  1155. def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
  1156. """Adjust the sharpness of an image.
  1157. Args:
  1158. img (PIL Image or Tensor): Image to be adjusted.
  1159. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  1160. where ... means it can have an arbitrary number of leading dimensions.
  1161. sharpness_factor (float): How much to adjust the sharpness. Can be
  1162. any non-negative number. 0 gives a blurred image, 1 gives the
  1163. original image while 2 increases the sharpness by a factor of 2.
  1164. Returns:
  1165. PIL Image or Tensor: Sharpness adjusted image.
  1166. """
  1167. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1168. _log_api_usage_once(adjust_sharpness)
  1169. if not isinstance(img, torch.Tensor):
  1170. return F_pil.adjust_sharpness(img, sharpness_factor)
  1171. return F_t.adjust_sharpness(img, sharpness_factor)
  1172. def autocontrast(img: Tensor) -> Tensor:
  1173. """Maximize contrast of an image by remapping its
  1174. pixels per channel so that the lowest becomes black and the lightest
  1175. becomes white.
  1176. Args:
  1177. img (PIL Image or Tensor): Image on which autocontrast is applied.
  1178. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  1179. where ... means it can have an arbitrary number of leading dimensions.
  1180. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  1181. Returns:
  1182. PIL Image or Tensor: An image that was autocontrasted.
  1183. """
  1184. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1185. _log_api_usage_once(autocontrast)
  1186. if not isinstance(img, torch.Tensor):
  1187. return F_pil.autocontrast(img)
  1188. return F_t.autocontrast(img)
  1189. def equalize(img: Tensor) -> Tensor:
  1190. """Equalize the histogram of an image by applying
  1191. a non-linear mapping to the input in order to create a uniform
  1192. distribution of grayscale values in the output.
  1193. Args:
  1194. img (PIL Image or Tensor): Image on which equalize is applied.
  1195. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  1196. where ... means it can have an arbitrary number of leading dimensions.
  1197. The tensor dtype must be ``torch.uint8`` and values are expected to be in ``[0, 255]``.
  1198. If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
  1199. Returns:
  1200. PIL Image or Tensor: An image that was equalized.
  1201. """
  1202. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1203. _log_api_usage_once(equalize)
  1204. if not isinstance(img, torch.Tensor):
  1205. return F_pil.equalize(img)
  1206. return F_t.equalize(img)
  1207. def elastic_transform(
  1208. img: Tensor,
  1209. displacement: Tensor,
  1210. interpolation: InterpolationMode = InterpolationMode.BILINEAR,
  1211. fill: Optional[list[float]] = None,
  1212. ) -> Tensor:
  1213. """Transform a tensor image with elastic transformations.
  1214. Given alpha and sigma, it will generate displacement
  1215. vectors for all pixels based on random offsets. Alpha controls the strength
  1216. and sigma controls the smoothness of the displacements.
  1217. The displacements are added to an identity grid and the resulting grid is
  1218. used to grid_sample from the image.
  1219. Applications:
  1220. Randomly transforms the morphology of objects in images and produces a
  1221. see-through-water-like effect.
  1222. Args:
  1223. img (PIL Image or Tensor): Image on which elastic_transform is applied.
  1224. If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
  1225. where ... means it can have an arbitrary number of leading dimensions.
  1226. If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
  1227. displacement (Tensor): The displacement field. Expected shape is [1, H, W, 2].
  1228. interpolation (InterpolationMode): Desired interpolation enum defined by
  1229. :class:`torchvision.transforms.InterpolationMode`.
  1230. Default is ``InterpolationMode.BILINEAR``.
  1231. The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
  1232. fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
  1233. If a tuple of length 3, it is used to fill R, G, B channels respectively.
  1234. This value is only used when the padding_mode is constant.
  1235. """
  1236. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  1237. _log_api_usage_once(elastic_transform)
  1238. # Backward compatibility with integer value
  1239. if isinstance(interpolation, int):
  1240. warnings.warn(
  1241. "Argument interpolation should be of type InterpolationMode instead of int. "
  1242. "Please, use InterpolationMode enum."
  1243. )
  1244. interpolation = _interpolation_modes_from_int(interpolation)
  1245. if not isinstance(displacement, torch.Tensor):
  1246. raise TypeError("Argument displacement should be a Tensor")
  1247. t_img = img
  1248. if not isinstance(img, torch.Tensor):
  1249. if not F_pil._is_pil_image(img):
  1250. raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
  1251. t_img = pil_to_tensor(img)
  1252. shape = t_img.shape
  1253. shape = (1,) + shape[-2:] + (2,)
  1254. if shape != displacement.shape:
  1255. raise ValueError(f"Argument displacement shape should be {shape}, but given {displacement.shape}")
  1256. # TODO: if image shape is [N1, N2, ..., C, H, W] and
  1257. # displacement is [1, H, W, 2] we need to reshape input image
  1258. # such grid_sampler takes internal code for 4D input
  1259. output = F_t.elastic_transform(
  1260. t_img,
  1261. displacement,
  1262. interpolation=interpolation.value,
  1263. fill=fill,
  1264. )
  1265. if not isinstance(img, torch.Tensor):
  1266. output = to_pil_image(output, mode=img.mode)
  1267. return output