conv.py 77 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909
  1. # mypy: allow-untyped-defs
  2. import math
  3. from typing import Literal, Optional
  4. from typing_extensions import deprecated
  5. import torch
  6. from torch import Tensor
  7. from torch._torch_docs import reproducibility_notes
  8. from torch.nn import functional as F, init
  9. from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
  10. from torch.nn.parameter import Parameter, UninitializedParameter
  11. from .lazy import LazyModuleMixin
  12. from .module import Module
  13. from .utils import _pair, _reverse_repeat_tuple, _single, _triple
# Names exported by ``from <this module> import *``: the 1d/2d/3d convolution
# layers, their ``ConvTranspose*`` counterparts, and the ``Lazy*`` variants of
# both. Keep it a plain literal list so static analyzers can read the public API.
__all__ = [
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "ConvTranspose1d",
    "ConvTranspose2d",
    "ConvTranspose3d",
    "LazyConv1d",
    "LazyConv2d",
    "LazyConv3d",
    "LazyConvTranspose1d",
    "LazyConvTranspose2d",
    "LazyConvTranspose3d",
]
  28. convolution_notes = {
  29. "groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs.
  30. :attr:`in_channels` and :attr:`out_channels` must both be divisible by
  31. :attr:`groups`. For example,
  32. * At groups=1, all inputs are convolved to all outputs.
  33. * At groups=2, the operation becomes equivalent to having two conv
  34. layers side by side, each seeing half the input channels
  35. and producing half the output channels, and both subsequently
  36. concatenated.
  37. * At groups= :attr:`in_channels`, each input channel is convolved with
  38. its own set of filters (of size
  39. :math:`\frac{\text{out\_channels}}{\text{in\_channels}}`).""",
  40. "depthwise_separable_note": r"""When `groups == in_channels` and `out_channels == K * in_channels`,
  41. where `K` is a positive integer, this operation is also known as a "depthwise convolution".
  42. In other words, for an input of size :math:`(N, C_{in}, L_{in})`,
  43. a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments
  44. :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`.""",
  45. } # noqa: B950
  46. class _ConvNd(Module):
  47. __constants__ = [
  48. "stride",
  49. "padding",
  50. "dilation",
  51. "groups",
  52. "padding_mode",
  53. "output_padding",
  54. "in_channels",
  55. "out_channels",
  56. "kernel_size",
  57. ]
  58. __annotations__ = {"bias": Optional[torch.Tensor]}
  59. def _conv_forward( # type: ignore[empty-body]
  60. self, input: Tensor, weight: Tensor, bias: Tensor | None
  61. ) -> Tensor: ...
  62. in_channels: int
  63. _reversed_padding_repeated_twice: list[int]
  64. out_channels: int
  65. kernel_size: tuple[int, ...]
  66. stride: tuple[int, ...]
  67. padding: str | tuple[int, ...]
  68. dilation: tuple[int, ...]
  69. transposed: bool
  70. output_padding: tuple[int, ...]
  71. groups: int
  72. padding_mode: Literal["zeros", "reflect", "replicate", "circular"]
  73. weight: Tensor
  74. bias: Tensor | None
  75. def __init__(
  76. self,
  77. in_channels: int,
  78. out_channels: int,
  79. kernel_size: tuple[int, ...],
  80. stride: tuple[int, ...],
  81. padding: str | tuple[int, ...],
  82. dilation: tuple[int, ...],
  83. transposed: bool,
  84. output_padding: tuple[int, ...],
  85. groups: int,
  86. bias: bool,
  87. padding_mode: Literal["zeros", "reflect", "replicate", "circular"],
  88. device=None,
  89. dtype=None,
  90. ) -> None:
  91. factory_kwargs = {"device": device, "dtype": dtype}
  92. super().__init__()
  93. if groups <= 0:
  94. raise ValueError("groups must be a positive integer")
  95. if in_channels % groups != 0:
  96. raise ValueError("in_channels must be divisible by groups")
  97. if out_channels % groups != 0:
  98. raise ValueError("out_channels must be divisible by groups")
  99. valid_padding_strings = {"same", "valid"}
  100. if isinstance(padding, str):
  101. if padding not in valid_padding_strings:
  102. raise ValueError(
  103. f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}"
  104. )
  105. if padding == "same" and any(s != 1 for s in stride):
  106. raise ValueError(
  107. "padding='same' is not supported for strided convolutions"
  108. )
  109. valid_padding_modes = {"zeros", "reflect", "replicate", "circular"}
  110. if padding_mode not in valid_padding_modes:
  111. raise ValueError(
  112. f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'"
  113. )
  114. self.in_channels = in_channels
  115. self.out_channels = out_channels
  116. self.kernel_size = kernel_size
  117. self.stride = stride
  118. self.padding = padding
  119. self.dilation = dilation
  120. self.transposed = transposed
  121. self.output_padding = output_padding
  122. self.groups = groups
  123. self.padding_mode = padding_mode
  124. # `_reversed_padding_repeated_twice` is the padding to be passed to
  125. # `F.pad` if needed (e.g., for non-zero padding types that are
  126. # implemented as two ops: padding + conv). `F.pad` accepts paddings in
  127. # reverse order than the dimension.
  128. if isinstance(self.padding, str):
  129. self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size)
  130. if padding == "same":
  131. for d, k, i in zip(
  132. dilation,
  133. kernel_size,
  134. range(len(kernel_size) - 1, -1, -1),
  135. strict=False,
  136. ):
  137. total_padding = d * (k - 1)
  138. left_pad = total_padding // 2
  139. self._reversed_padding_repeated_twice[2 * i] = left_pad
  140. self._reversed_padding_repeated_twice[2 * i + 1] = (
  141. total_padding - left_pad
  142. )
  143. else:
  144. self._reversed_padding_repeated_twice = _reverse_repeat_tuple(
  145. self.padding, 2
  146. )
  147. if transposed:
  148. self.weight = Parameter(
  149. torch.empty(
  150. (in_channels, out_channels // groups, *kernel_size),
  151. **factory_kwargs,
  152. )
  153. )
  154. else:
  155. self.weight = Parameter(
  156. torch.empty(
  157. (out_channels, in_channels // groups, *kernel_size),
  158. **factory_kwargs,
  159. )
  160. )
  161. if bias:
  162. self.bias = Parameter(torch.empty(out_channels, **factory_kwargs))
  163. else:
  164. self.register_parameter("bias", None)
  165. self.reset_parameters()
  166. def reset_parameters(self) -> None:
  167. # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
  168. # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size)
  169. # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573
  170. init.kaiming_uniform_(self.weight, a=math.sqrt(5))
  171. if self.bias is not None:
  172. fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
  173. if fan_in != 0:
  174. bound = 1 / math.sqrt(fan_in)
  175. init.uniform_(self.bias, -bound, bound)
  176. def extra_repr(self):
  177. s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}"
  178. if self.padding != (0,) * len(self.padding):
  179. s += ", padding={padding}"
  180. if self.dilation != (1,) * len(self.dilation):
  181. s += ", dilation={dilation}"
  182. if self.output_padding != (0,) * len(self.output_padding):
  183. s += ", output_padding={output_padding}"
  184. if self.groups != 1:
  185. s += ", groups={groups}"
  186. if self.bias is None:
  187. s += ", bias=False"
  188. if self.padding_mode != "zeros":
  189. s += ", padding_mode={padding_mode}"
  190. return s.format(**self.__dict__)
  191. def __setstate__(self, state):
  192. super().__setstate__(state)
  193. if not hasattr(self, "padding_mode"):
  194. self.padding_mode = "zeros"
  195. class Conv1d(_ConvNd):
  196. __doc__ = (
  197. r"""Applies a 1D convolution over an input signal composed of several input
  198. planes.
  199. In the simplest case, the output value of the layer with input size
  200. :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be
  201. precisely described as:
  202. .. math::
  203. \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
  204. \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k)
  205. \star \text{input}(N_i, k)
  206. where :math:`\star` is the valid `cross-correlation`_ operator,
  207. :math:`N` is a batch size, :math:`C` denotes a number of channels,
  208. :math:`L` is a length of signal sequence.
  209. """
  210. + r"""
  211. This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
  212. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  213. * :attr:`stride` controls the stride for the cross-correlation, a single
  214. number or a one-element tuple.
  215. * :attr:`padding` controls the amount of padding applied to the input. It
  216. can be either a string {{'valid', 'same'}} or a tuple of ints giving the
  217. amount of implicit padding applied on both sides.
  218. """
  219. """
  220. * :attr:`dilation` controls the spacing between the kernel points; also
  221. known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_
  222. has a nice visualization of what :attr:`dilation` does.
  223. """
  224. r"""
  225. {groups_note}
  226. Note:
  227. {depthwise_separable_note}
  228. Note:
  229. {cudnn_reproducibility_note}
  230. Note:
  231. ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
  232. the input so the output has the shape as the input. However, this mode
  233. doesn't support any stride values other than 1.
  234. Note:
  235. This module supports complex data types i.e. ``complex32, complex64, complex128``.
  236. Args:
  237. in_channels (int): Number of channels in the input image
  238. out_channels (int): Number of channels produced by the convolution
  239. kernel_size (int or tuple): Size of the convolving kernel
  240. stride (int or tuple, optional): Stride of the convolution. Default: 1
  241. padding (int, tuple or str, optional): Padding added to both sides of
  242. the input. Default: 0
  243. dilation (int or tuple, optional): Spacing between kernel
  244. elements. Default: 1
  245. groups (int, optional): Number of blocked connections from input
  246. channels to output channels. Default: 1
  247. bias (bool, optional): If ``True``, adds a learnable bias to the
  248. output. Default: ``True``
  249. padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
  250. ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
  251. """.format(**reproducibility_notes, **convolution_notes)
  252. + r"""
  253. Shape:
  254. - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
  255. - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where
  256. .. math::
  257. L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation}
  258. \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor
  259. Attributes:
  260. weight (Tensor): the learnable weights of the module of shape
  261. :math:`(\text{out\_channels},
  262. \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`.
  263. The values of these weights are sampled from
  264. :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  265. :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}`
  266. bias (Tensor): the learnable bias of the module of shape
  267. (out_channels). If :attr:`bias` is ``True``, then the values of these weights are
  268. sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  269. :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}`
  270. Examples::
  271. >>> m = nn.Conv1d(16, 33, 3, stride=2)
  272. >>> input = torch.randn(20, 16, 50)
  273. >>> output = m(input)
  274. .. _cross-correlation:
  275. https://en.wikipedia.org/wiki/Cross-correlation
  276. .. _link:
  277. https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
  278. """
  279. )
  280. def __init__(
  281. self,
  282. in_channels: int,
  283. out_channels: int,
  284. kernel_size: _size_1_t,
  285. stride: _size_1_t = 1,
  286. padding: str | _size_1_t = 0,
  287. dilation: _size_1_t = 1,
  288. groups: int = 1,
  289. bias: bool = True,
  290. padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
  291. device=None,
  292. dtype=None,
  293. ) -> None:
  294. factory_kwargs = {"device": device, "dtype": dtype}
  295. # we create new variables below to make mypy happy since kernel_size has
  296. # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int]
  297. kernel_size_ = _single(kernel_size)
  298. stride_ = _single(stride)
  299. padding_ = padding if isinstance(padding, str) else _single(padding)
  300. dilation_ = _single(dilation)
  301. super().__init__(
  302. in_channels,
  303. out_channels,
  304. kernel_size_,
  305. stride_,
  306. padding_,
  307. dilation_,
  308. False,
  309. _single(0),
  310. groups,
  311. bias,
  312. padding_mode,
  313. **factory_kwargs,
  314. )
  315. def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None):
  316. if self.padding_mode != "zeros":
  317. return F.conv1d(
  318. F.pad(
  319. input, self._reversed_padding_repeated_twice, mode=self.padding_mode
  320. ),
  321. weight,
  322. bias,
  323. self.stride,
  324. _single(0),
  325. self.dilation,
  326. self.groups,
  327. )
  328. return F.conv1d(
  329. input, weight, bias, self.stride, self.padding, self.dilation, self.groups
  330. )
  331. def forward(self, input: Tensor) -> Tensor:
  332. return self._conv_forward(input, self.weight, self.bias)
  333. class Conv2d(_ConvNd):
  334. __doc__ = (
  335. r"""Applies a 2D convolution over an input signal composed of several input
  336. planes.
  337. In the simplest case, the output value of the layer with input size
  338. :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
  339. can be precisely described as:
  340. .. math::
  341. \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
  342. \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
  343. where :math:`\star` is the valid 2D `cross-correlation`_ operator,
  344. :math:`N` is a batch size, :math:`C` denotes a number of channels,
  345. :math:`H` is a height of input planes in pixels, and :math:`W` is
  346. width in pixels.
  347. """
  348. + r"""
  349. This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
  350. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  351. * :attr:`stride` controls the stride for the cross-correlation, a single
  352. number or a tuple.
  353. * :attr:`padding` controls the amount of padding applied to the input. It
  354. can be either a string {{'valid', 'same'}} or an int / a tuple of ints giving the
  355. amount of implicit padding applied on both sides.
  356. """
  357. """
  358. * :attr:`dilation` controls the spacing between the kernel points; also
  359. known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_
  360. has a nice visualization of what :attr:`dilation` does.
  361. """
  362. r"""
  363. {groups_note}
  364. The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
  365. - a single ``int`` -- in which case the same value is used for the height and width dimension
  366. - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
  367. and the second `int` for the width dimension
  368. Note:
  369. {depthwise_separable_note}
  370. Note:
  371. {cudnn_reproducibility_note}
  372. Note:
  373. ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
  374. the input so the output has the shape as the input. However, this mode
  375. doesn't support any stride values other than 1.
  376. Note:
  377. This module supports complex data types i.e. ``complex32, complex64, complex128``.
  378. Args:
  379. in_channels (int): Number of channels in the input image
  380. out_channels (int): Number of channels produced by the convolution
  381. kernel_size (int or tuple): Size of the convolving kernel
  382. stride (int or tuple, optional): Stride of the convolution. Default: 1
  383. padding (int, tuple or str, optional): Padding added to all four sides of
  384. the input. Default: 0
  385. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  386. groups (int, optional): Number of blocked connections from input
  387. channels to output channels. Default: 1
  388. bias (bool, optional): If ``True``, adds a learnable bias to the
  389. output. Default: ``True``
  390. padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
  391. ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
  392. """.format(**reproducibility_notes, **convolution_notes)
  393. + r"""
  394. Shape:
  395. - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
  396. - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where
  397. .. math::
  398. H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
  399. \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
  400. .. math::
  401. W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
  402. \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
  403. Attributes:
  404. weight (Tensor): the learnable weights of the module of shape
  405. :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
  406. :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
  407. The values of these weights are sampled from
  408. :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  409. :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
  410. bias (Tensor): the learnable bias of the module of shape
  411. (out_channels). If :attr:`bias` is ``True``,
  412. then the values of these weights are
  413. sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  414. :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
  415. Examples:
  416. >>> # With square kernels and equal stride
  417. >>> m = nn.Conv2d(16, 33, 3, stride=2)
  418. >>> # non-square kernels and unequal stride and with padding
  419. >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
  420. >>> # non-square kernels and unequal stride and with padding and dilation
  421. >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
  422. >>> input = torch.randn(20, 16, 50, 100)
  423. >>> output = m(input)
  424. .. _cross-correlation:
  425. https://en.wikipedia.org/wiki/Cross-correlation
  426. .. _link:
  427. https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
  428. """
  429. )
  430. def __init__(
  431. self,
  432. in_channels: int,
  433. out_channels: int,
  434. kernel_size: _size_2_t,
  435. stride: _size_2_t = 1,
  436. padding: str | _size_2_t = 0,
  437. dilation: _size_2_t = 1,
  438. groups: int = 1,
  439. bias: bool = True,
  440. padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
  441. device=None,
  442. dtype=None,
  443. ) -> None:
  444. factory_kwargs = {"device": device, "dtype": dtype}
  445. kernel_size_ = _pair(kernel_size)
  446. stride_ = _pair(stride)
  447. padding_ = padding if isinstance(padding, str) else _pair(padding)
  448. dilation_ = _pair(dilation)
  449. super().__init__(
  450. in_channels,
  451. out_channels,
  452. kernel_size_,
  453. stride_,
  454. padding_,
  455. dilation_,
  456. False,
  457. _pair(0),
  458. groups,
  459. bias,
  460. padding_mode,
  461. **factory_kwargs,
  462. )
  463. def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None):
  464. if self.padding_mode != "zeros":
  465. return F.conv2d(
  466. F.pad(
  467. input, self._reversed_padding_repeated_twice, mode=self.padding_mode
  468. ),
  469. weight,
  470. bias,
  471. self.stride,
  472. _pair(0),
  473. self.dilation,
  474. self.groups,
  475. )
  476. return F.conv2d(
  477. input, weight, bias, self.stride, self.padding, self.dilation, self.groups
  478. )
  479. def forward(self, input: Tensor) -> Tensor:
  480. return self._conv_forward(input, self.weight, self.bias)
  481. class Conv3d(_ConvNd):
  482. __doc__ = (
  483. r"""Applies a 3D convolution over an input signal composed of several input
  484. planes.
  485. In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)`
  486. and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as:
  487. .. math::
  488. out(N_i, C_{out_j}) = bias(C_{out_j}) +
  489. \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k)
  490. where :math:`\star` is the valid 3D `cross-correlation`_ operator
  491. """
  492. + r"""
  493. This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
  494. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  495. * :attr:`stride` controls the stride for the cross-correlation.
  496. * :attr:`padding` controls the amount of padding applied to the input. It
  497. can be either a string {{'valid', 'same'}} or a tuple of ints giving the
  498. amount of implicit padding applied on both sides.
  499. """
  500. """
  501. * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
  502. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
  503. """
  504. r"""
  505. {groups_note}
  506. The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
  507. - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
  508. - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
  509. the second `int` for the height dimension and the third `int` for the width dimension
  510. Note:
  511. {depthwise_separable_note}
  512. Note:
  513. {cudnn_reproducibility_note}
  514. Note:
  515. ``padding='valid'`` is the same as no padding. ``padding='same'`` pads
  516. the input so the output has the shape as the input. However, this mode
  517. doesn't support any stride values other than 1.
  518. Note:
  519. This module supports complex data types i.e. ``complex32, complex64, complex128``.
  520. Args:
  521. in_channels (int): Number of channels in the input image
  522. out_channels (int): Number of channels produced by the convolution
  523. kernel_size (int or tuple): Size of the convolving kernel
  524. stride (int or tuple, optional): Stride of the convolution. Default: 1
  525. padding (int, tuple or str, optional): Padding added to all six sides of
  526. the input. Default: 0
  527. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  528. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  529. bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
  530. padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
  531. """.format(**reproducibility_notes, **convolution_notes)
  532. + r"""
  533. Shape:
  534. - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
  535. - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or :math:`(C_{out}, D_{out}, H_{out}, W_{out})`,
  536. where
  537. .. math::
  538. D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
  539. \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
  540. .. math::
  541. H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
  542. \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
  543. .. math::
  544. W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2]
  545. \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
  546. Attributes:
  547. weight (Tensor): the learnable weights of the module of shape
  548. :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
  549. :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
  550. The values of these weights are sampled from
  551. :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  552. :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
  553. bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
  554. then the values of these weights are
  555. sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  556. :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
  557. Examples::
  558. >>> # With square kernels and equal stride
  559. >>> m = nn.Conv3d(16, 33, 3, stride=2)
  560. >>> # non-square kernels and unequal stride and with padding
  561. >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
  562. >>> input = torch.randn(20, 16, 10, 50, 100)
  563. >>> output = m(input)
  564. .. _cross-correlation:
  565. https://en.wikipedia.org/wiki/Cross-correlation
  566. .. _link:
  567. https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
  568. """
  569. )
  570. def __init__(
  571. self,
  572. in_channels: int,
  573. out_channels: int,
  574. kernel_size: _size_3_t,
  575. stride: _size_3_t = 1,
  576. padding: str | _size_3_t = 0,
  577. dilation: _size_3_t = 1,
  578. groups: int = 1,
  579. bias: bool = True,
  580. padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
  581. device=None,
  582. dtype=None,
  583. ) -> None:
  584. factory_kwargs = {"device": device, "dtype": dtype}
  585. kernel_size_ = _triple(kernel_size)
  586. stride_ = _triple(stride)
  587. padding_ = padding if isinstance(padding, str) else _triple(padding)
  588. dilation_ = _triple(dilation)
  589. super().__init__(
  590. in_channels,
  591. out_channels,
  592. kernel_size_,
  593. stride_,
  594. padding_,
  595. dilation_,
  596. False,
  597. _triple(0),
  598. groups,
  599. bias,
  600. padding_mode,
  601. **factory_kwargs,
  602. )
  603. def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None):
  604. if self.padding_mode != "zeros":
  605. return F.conv3d(
  606. F.pad(
  607. input, self._reversed_padding_repeated_twice, mode=self.padding_mode
  608. ),
  609. weight,
  610. bias,
  611. self.stride,
  612. _triple(0),
  613. self.dilation,
  614. self.groups,
  615. )
  616. return F.conv3d(
  617. input, weight, bias, self.stride, self.padding, self.dilation, self.groups
  618. )
  619. def forward(self, input: Tensor) -> Tensor:
  620. return self._conv_forward(input, self.weight, self.bias)
  621. class _ConvTransposeNd(_ConvNd):
  622. def __init__(
  623. self,
  624. in_channels,
  625. out_channels,
  626. kernel_size,
  627. stride,
  628. padding,
  629. dilation,
  630. transposed,
  631. output_padding,
  632. groups,
  633. bias,
  634. padding_mode,
  635. device=None,
  636. dtype=None,
  637. ) -> None:
  638. if padding_mode != "zeros":
  639. raise ValueError(
  640. f'Only "zeros" padding mode is supported for {self.__class__.__name__}'
  641. )
  642. factory_kwargs = {"device": device, "dtype": dtype}
  643. super().__init__(
  644. in_channels,
  645. out_channels,
  646. kernel_size,
  647. stride,
  648. padding,
  649. dilation,
  650. transposed,
  651. output_padding,
  652. groups,
  653. bias,
  654. padding_mode,
  655. **factory_kwargs,
  656. )
  657. # dilation being an optional parameter is for backwards
  658. # compatibility
  659. def _output_padding(
  660. self,
  661. input: Tensor,
  662. output_size: list[int] | None,
  663. stride: list[int],
  664. padding: list[int],
  665. kernel_size: list[int],
  666. num_spatial_dims: int,
  667. dilation: list[int] | None = None,
  668. ) -> list[int]:
  669. if output_size is None:
  670. ret = _single(self.output_padding) # converting to list if was not already
  671. else:
  672. has_batch_dim = input.dim() == num_spatial_dims + 2
  673. num_non_spatial_dims = 2 if has_batch_dim else 1
  674. if len(output_size) == num_non_spatial_dims + num_spatial_dims:
  675. output_size = output_size[num_non_spatial_dims:]
  676. if len(output_size) != num_spatial_dims:
  677. raise ValueError(
  678. f"ConvTranspose{num_spatial_dims}D: for {input.dim()}D input, output_size must have {num_spatial_dims} "
  679. f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})"
  680. )
  681. min_sizes = torch.jit.annotate(list[int], [])
  682. max_sizes = torch.jit.annotate(list[int], [])
  683. for d in range(num_spatial_dims):
  684. dim_size = (
  685. (input.size(d + num_non_spatial_dims) - 1) * stride[d]
  686. - 2 * padding[d]
  687. + (dilation[d] if dilation is not None else 1)
  688. * (kernel_size[d] - 1)
  689. + 1
  690. )
  691. min_sizes.append(dim_size)
  692. max_sizes.append(min_sizes[d] + stride[d] - 1)
  693. for i in range(len(output_size)):
  694. size = output_size[i]
  695. min_size = min_sizes[i]
  696. max_size = max_sizes[i]
  697. if size < min_size or size > max_size:
  698. raise ValueError(
  699. f"requested an output size of {output_size}, but valid sizes range "
  700. f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})"
  701. )
  702. res = torch.jit.annotate(list[int], [])
  703. for d in range(num_spatial_dims):
  704. res.append(output_size[d] - min_sizes[d])
  705. ret = res
  706. return ret
  707. class ConvTranspose1d(_ConvTransposeNd):
  708. __doc__ = (
  709. r"""Applies a 1D transposed convolution operator over an input image
  710. composed of several input planes.
  711. This module can be seen as the gradient of Conv1d with respect to its input.
  712. It is also known as a fractionally-strided convolution or
  713. a deconvolution (although it is not an actual deconvolution operation as it does
  714. not compute a true inverse of convolution). For more information, see the visualizations
  715. `here`_ and the `Deconvolutional Networks`_ paper.
  716. This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
  717. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  718. * :attr:`stride` controls the stride for the cross-correlation.
  719. * :attr:`padding` controls the amount of implicit zero padding on both
  720. sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
  721. below for details.
  722. * :attr:`output_padding` controls the additional size added to one side
  723. of the output shape. See note below for details.
  724. """
  725. """
  726. * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
  727. It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
  728. """
  729. r"""
  730. {groups_note}
  731. Note:
  732. The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
  733. amount of zero padding to both sizes of the input. This is set so that
  734. when a :class:`~torch.nn.Conv1d` and a :class:`~torch.nn.ConvTranspose1d`
  735. are initialized with same parameters, they are inverses of each other in
  736. regard to the input and output shapes. However, when ``stride > 1``,
  737. :class:`~torch.nn.Conv1d` maps multiple input shapes to the same output
  738. shape. :attr:`output_padding` is provided to resolve this ambiguity by
  739. effectively increasing the calculated output shape on one side. Note
  740. that :attr:`output_padding` is only used to find output shape, but does
  741. not actually add zero-padding to output.
  742. Note:
  743. In some circumstances when using the CUDA backend with CuDNN, this operator
  744. may select a nondeterministic algorithm to increase performance. If this is
  745. undesirable, you can try to make the operation deterministic (potentially at
  746. a performance cost) by setting ``torch.backends.cudnn.deterministic =
  747. True``.
  748. Please see the notes on :doc:`/notes/randomness` for background.
  749. Args:
  750. in_channels (int): Number of channels in the input image
  751. out_channels (int): Number of channels produced by the convolution
  752. kernel_size (int or tuple): Size of the convolving kernel
  753. stride (int or tuple, optional): Stride of the convolution. Default: 1
  754. padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
  755. will be added to both sides of the input. Default: 0
  756. output_padding (int or tuple, optional): Additional size added to one side
  757. of the output shape. Default: 0
  758. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  759. bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
  760. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  761. """.format(**reproducibility_notes, **convolution_notes)
  762. + r"""
  763. Shape:
  764. - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})`
  765. - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where
  766. .. math::
  767. L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation}
  768. \times (\text{kernel\_size} - 1) + \text{output\_padding} + 1
  769. Attributes:
  770. weight (Tensor): the learnable weights of the module of shape
  771. :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
  772. :math:`\text{kernel\_size})`.
  773. The values of these weights are sampled from
  774. :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  775. :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}`
  776. bias (Tensor): the learnable bias of the module of shape (out_channels).
  777. If :attr:`bias` is ``True``, then the values of these weights are
  778. sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  779. :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}`
  780. Examples::
  781. >>> # With square kernels and equal stride
  782. >>> m = nn.ConvTranspose1d(16, 33, 3, stride=2)
  783. >>> input = torch.randn(20, 16, 50)
  784. >>> output = m(input)
  785. >>> # exact output size can be also specified as an argument
  786. >>> input = torch.randn(1, 16, 12)
  787. >>> downsample = nn.Conv1d(16, 16, 3, stride=2, padding=1)
  788. >>> upsample = nn.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
  789. >>> h = downsample(input)
  790. >>> h.size()
  791. torch.Size([1, 16, 6])
  792. >>> output = upsample(h, output_size=input.size())
  793. >>> output.size()
  794. torch.Size([1, 16, 12])
  795. .. _`here`:
  796. https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
  797. .. _`Deconvolutional Networks`:
  798. https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
  799. """
  800. )
  801. def __init__(
  802. self,
  803. in_channels: int,
  804. out_channels: int,
  805. kernel_size: _size_1_t,
  806. stride: _size_1_t = 1,
  807. padding: _size_1_t = 0,
  808. output_padding: _size_1_t = 0,
  809. groups: int = 1,
  810. bias: bool = True,
  811. dilation: _size_1_t = 1,
  812. padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
  813. device=None,
  814. dtype=None,
  815. ) -> None:
  816. factory_kwargs = {"device": device, "dtype": dtype}
  817. kernel_size = _single(kernel_size)
  818. stride = _single(stride)
  819. padding = _single(padding)
  820. dilation = _single(dilation)
  821. output_padding = _single(output_padding)
  822. super().__init__(
  823. in_channels,
  824. out_channels,
  825. kernel_size,
  826. stride,
  827. padding,
  828. dilation,
  829. True,
  830. output_padding,
  831. groups,
  832. bias,
  833. padding_mode,
  834. **factory_kwargs,
  835. )
  836. def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor:
  837. if self.padding_mode != "zeros":
  838. raise ValueError(
  839. "Only `zeros` padding mode is supported for ConvTranspose1d"
  840. )
  841. if not isinstance(self.padding, tuple):
  842. raise AssertionError("self.padding must be a tuple")
  843. # One cannot replace List by Tuple or Sequence in "_output_padding" because
  844. # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
  845. num_spatial_dims = 1
  846. output_padding = self._output_padding(
  847. input,
  848. output_size,
  849. self.stride, # type: ignore[arg-type]
  850. self.padding, # type: ignore[arg-type]
  851. self.kernel_size, # type: ignore[arg-type]
  852. num_spatial_dims,
  853. self.dilation, # type: ignore[arg-type]
  854. )
  855. return F.conv_transpose1d(
  856. input,
  857. self.weight,
  858. self.bias,
  859. self.stride,
  860. self.padding,
  861. output_padding,
  862. self.groups,
  863. self.dilation,
  864. )
  865. class ConvTranspose2d(_ConvTransposeNd):
  866. __doc__ = (
  867. r"""Applies a 2D transposed convolution operator over an input image
  868. composed of several input planes.
  869. This module can be seen as the gradient of Conv2d with respect to its input.
  870. It is also known as a fractionally-strided convolution or
  871. a deconvolution (although it is not an actual deconvolution operation as it does
  872. not compute a true inverse of convolution). For more information, see the visualizations
  873. `here`_ and the `Deconvolutional Networks`_ paper.
  874. This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
  875. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  876. * :attr:`stride` controls the stride for the cross-correlation. When stride > 1, ConvTranspose2d inserts zeros between input
  877. elements along the spatial dimensions before applying the convolution kernel. This zero-insertion operation is the standard
  878. behavior of transposed convolutions, which can increase the spatial resolution and is equivalent to a learnable
  879. upsampling operation.
  880. * :attr:`padding` controls the amount of implicit zero padding on both
  881. sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
  882. below for details.
  883. * :attr:`output_padding` controls the additional size added to one side
  884. of the output shape. See note below for details.
  885. """
  886. """
  887. * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
  888. It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
  889. """
  890. r"""
  891. {groups_note}
  892. The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
  893. can either be:
  894. - a single ``int`` -- in which case the same value is used for the height and width dimensions
  895. - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
  896. and the second `int` for the width dimension
  897. Note:
  898. The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
  899. amount of zero padding to both sizes of the input. This is set so that
  900. when a :class:`~torch.nn.Conv2d` and a :class:`~torch.nn.ConvTranspose2d`
  901. are initialized with same parameters, they are inverses of each other in
  902. regard to the input and output shapes. However, when ``stride > 1``,
  903. :class:`~torch.nn.Conv2d` maps multiple input shapes to the same output
  904. shape. :attr:`output_padding` is provided to resolve this ambiguity by
  905. effectively increasing the calculated output shape on one side. Note
  906. that :attr:`output_padding` is only used to find output shape, but does
  907. not actually add zero-padding to output.
  908. Note:
  909. {cudnn_reproducibility_note}
  910. Args:
  911. in_channels (int): Number of channels in the input image
  912. out_channels (int): Number of channels produced by the convolution
  913. kernel_size (int or tuple): Size of the convolving kernel
  914. stride (int or tuple, optional): Stride of the convolution. Default: 1
  915. padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
  916. will be added to both sides of each dimension in the input. Default: 0
  917. output_padding (int or tuple, optional): Additional size added to one side
  918. of each dimension in the output shape. Default: 0
  919. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  920. bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
  921. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  922. """.format(**reproducibility_notes, **convolution_notes)
  923. + r"""
  924. Shape:
  925. - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})`
  926. - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where
  927. .. math::
  928. H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
  929. \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
  930. .. math::
  931. W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
  932. \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1
  933. Attributes:
  934. weight (Tensor): the learnable weights of the module of shape
  935. :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
  936. :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
  937. The values of these weights are sampled from
  938. :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  939. :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
  940. bias (Tensor): the learnable bias of the module of shape (out_channels)
  941. If :attr:`bias` is ``True``, then the values of these weights are
  942. sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  943. :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
  944. Examples::
  945. >>> # With square kernels and equal stride
  946. >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2)
  947. >>> # non-square kernels and unequal stride and with padding
  948. >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
  949. >>> input = torch.randn(20, 16, 50, 100)
  950. >>> output = m(input)
  951. >>> # exact output size can be also specified as an argument
  952. >>> input = torch.randn(1, 16, 12, 12)
  953. >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
  954. >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
  955. >>> h = downsample(input)
  956. >>> h.size()
  957. torch.Size([1, 16, 6, 6])
  958. >>> output = upsample(h, output_size=input.size())
  959. >>> output.size()
  960. torch.Size([1, 16, 12, 12])
  961. .. _`here`:
  962. https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
  963. .. _`Deconvolutional Networks`:
  964. https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
  965. """
  966. )
  967. def __init__(
  968. self,
  969. in_channels: int,
  970. out_channels: int,
  971. kernel_size: _size_2_t,
  972. stride: _size_2_t = 1,
  973. padding: _size_2_t = 0,
  974. output_padding: _size_2_t = 0,
  975. groups: int = 1,
  976. bias: bool = True,
  977. dilation: _size_2_t = 1,
  978. padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
  979. device=None,
  980. dtype=None,
  981. ) -> None:
  982. factory_kwargs = {"device": device, "dtype": dtype}
  983. kernel_size = _pair(kernel_size)
  984. stride = _pair(stride)
  985. padding = _pair(padding)
  986. dilation = _pair(dilation)
  987. output_padding = _pair(output_padding)
  988. super().__init__(
  989. in_channels,
  990. out_channels,
  991. kernel_size,
  992. stride,
  993. padding,
  994. dilation,
  995. True,
  996. output_padding,
  997. groups,
  998. bias,
  999. padding_mode,
  1000. **factory_kwargs,
  1001. )
  1002. def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor:
  1003. """
  1004. Performs the forward pass.
  1005. Attributes:
  1006. input (Tensor): The input tensor.
  1007. output_size (list[int], optional): A list of integers representing
  1008. the size of the output tensor. Default is None.
  1009. """
  1010. if self.padding_mode != "zeros":
  1011. raise ValueError(
  1012. "Only `zeros` padding mode is supported for ConvTranspose2d"
  1013. )
  1014. if not isinstance(self.padding, tuple):
  1015. raise AssertionError("self.padding must be a tuple")
  1016. # One cannot replace List by Tuple or Sequence in "_output_padding" because
  1017. # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
  1018. num_spatial_dims = 2
  1019. output_padding = self._output_padding(
  1020. input,
  1021. output_size,
  1022. self.stride, # type: ignore[arg-type]
  1023. self.padding, # type: ignore[arg-type]
  1024. self.kernel_size, # type: ignore[arg-type]
  1025. num_spatial_dims,
  1026. self.dilation, # type: ignore[arg-type]
  1027. )
  1028. return F.conv_transpose2d(
  1029. input,
  1030. self.weight,
  1031. self.bias,
  1032. self.stride,
  1033. self.padding,
  1034. output_padding,
  1035. self.groups,
  1036. self.dilation,
  1037. )
  1038. class ConvTranspose3d(_ConvTransposeNd):
  1039. __doc__ = (
  1040. r"""Applies a 3D transposed convolution operator over an input image composed of several input
  1041. planes.
  1042. The transposed convolution operator multiplies each input value element-wise by a learnable kernel,
  1043. and sums over the outputs from all input feature planes.
  1044. This module can be seen as the gradient of Conv3d with respect to its input.
  1045. It is also known as a fractionally-strided convolution or
  1046. a deconvolution (although it is not an actual deconvolution operation as it does
  1047. not compute a true inverse of convolution). For more information, see the visualizations
  1048. `here`_ and the `Deconvolutional Networks`_ paper.
  1049. This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
  1050. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  1051. * :attr:`stride` controls the stride for the cross-correlation.
  1052. * :attr:`padding` controls the amount of implicit zero padding on both
  1053. sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
  1054. below for details.
  1055. * :attr:`output_padding` controls the additional size added to one side
  1056. of the output shape. See note below for details.
  1057. """
  1058. """
  1059. * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
  1060. It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does.
  1061. """
  1062. r"""
  1063. {groups_note}
  1064. The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
  1065. can either be:
  1066. - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
  1067. - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
  1068. the second `int` for the height dimension and the third `int` for the width dimension
  1069. Note:
  1070. The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
  1071. amount of zero padding to both sizes of the input. This is set so that
  1072. when a :class:`~torch.nn.Conv3d` and a :class:`~torch.nn.ConvTranspose3d`
  1073. are initialized with same parameters, they are inverses of each other in
  1074. regard to the input and output shapes. However, when ``stride > 1``,
  1075. :class:`~torch.nn.Conv3d` maps multiple input shapes to the same output
  1076. shape. :attr:`output_padding` is provided to resolve this ambiguity by
  1077. effectively increasing the calculated output shape on one side. Note
  1078. that :attr:`output_padding` is only used to find output shape, but does
  1079. not actually add zero-padding to output.
  1080. Note:
  1081. {cudnn_reproducibility_note}
  1082. Args:
  1083. in_channels (int): Number of channels in the input image
  1084. out_channels (int): Number of channels produced by the convolution
  1085. kernel_size (int or tuple): Size of the convolving kernel
  1086. stride (int or tuple, optional): Stride of the convolution. Default: 1
  1087. padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
  1088. will be added to both sides of each dimension in the input. Default: 0
  1089. output_padding (int or tuple, optional): Additional size added to one side
  1090. of each dimension in the output shape. Default: 0
  1091. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  1092. bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
  1093. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  1094. """.format(**reproducibility_notes, **convolution_notes)
  1095. + r"""
  1096. Shape:
  1097. - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})`
  1098. - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or
  1099. :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, where
  1100. .. math::
  1101. D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
  1102. \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
  1103. .. math::
  1104. H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
  1105. \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1
  1106. .. math::
  1107. W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2]
  1108. \times (\text{kernel\_size}[2] - 1) + \text{output\_padding}[2] + 1
  1109. Attributes:
  1110. weight (Tensor): the learnable weights of the module of shape
  1111. :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
  1112. :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
  1113. The values of these weights are sampled from
  1114. :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  1115. :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
  1116. bias (Tensor): the learnable bias of the module of shape (out_channels)
  1117. If :attr:`bias` is ``True``, then the values of these weights are
  1118. sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  1119. :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
  1120. Examples::
  1121. >>> # With square kernels and equal stride
  1122. >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2)
  1123. >>> # non-square kernels and unequal stride and with padding
  1124. >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2))
  1125. >>> input = torch.randn(20, 16, 10, 50, 100)
  1126. >>> output = m(input)
  1127. .. _`here`:
  1128. https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
  1129. .. _`Deconvolutional Networks`:
  1130. https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf
  1131. """
  1132. )
  1133. def __init__(
  1134. self,
  1135. in_channels: int,
  1136. out_channels: int,
  1137. kernel_size: _size_3_t,
  1138. stride: _size_3_t = 1,
  1139. padding: _size_3_t = 0,
  1140. output_padding: _size_3_t = 0,
  1141. groups: int = 1,
  1142. bias: bool = True,
  1143. dilation: _size_3_t = 1,
  1144. padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
  1145. device=None,
  1146. dtype=None,
  1147. ) -> None:
  1148. factory_kwargs = {"device": device, "dtype": dtype}
  1149. kernel_size = _triple(kernel_size)
  1150. stride = _triple(stride)
  1151. padding = _triple(padding)
  1152. dilation = _triple(dilation)
  1153. output_padding = _triple(output_padding)
  1154. super().__init__(
  1155. in_channels,
  1156. out_channels,
  1157. kernel_size,
  1158. stride,
  1159. padding,
  1160. dilation,
  1161. True,
  1162. output_padding,
  1163. groups,
  1164. bias,
  1165. padding_mode,
  1166. **factory_kwargs,
  1167. )
  1168. def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor:
  1169. if self.padding_mode != "zeros":
  1170. raise ValueError(
  1171. "Only `zeros` padding mode is supported for ConvTranspose3d"
  1172. )
  1173. if not isinstance(self.padding, tuple):
  1174. raise AssertionError("self.padding must be a tuple")
  1175. # One cannot replace List by Tuple or Sequence in "_output_padding" because
  1176. # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
  1177. num_spatial_dims = 3
  1178. output_padding = self._output_padding(
  1179. input,
  1180. output_size,
  1181. self.stride, # type: ignore[arg-type]
  1182. self.padding, # type: ignore[arg-type]
  1183. self.kernel_size, # type: ignore[arg-type]
  1184. num_spatial_dims,
  1185. self.dilation, # type: ignore[arg-type]
  1186. )
  1187. return F.conv_transpose3d(
  1188. input,
  1189. self.weight,
  1190. self.bias,
  1191. self.stride,
  1192. self.padding,
  1193. output_padding,
  1194. self.groups,
  1195. self.dilation,
  1196. )
# TODO: Deprecate and remove the following alias `_ConvTransposeMixin`.
#
# `_ConvTransposeMixin` was a mixin that was removed. It was meant to be used
# with `_ConvNd` to construct actual module classes that implement conv
# transpose ops:
#
#   class MyConvTranspose(_ConvNd, _ConvTransposeMixin):
#       ...
#
# In PyTorch, it has been replaced by `_ConvTransposeNd`, which is a proper
# subclass of `_ConvNd`. However, some user code in the wild still (incorrectly)
# uses the internal class `_ConvTransposeMixin`. Hence, we provide this alias
# for BC, because it is cheap and easy for us to do so, even though
# `_ConvTransposeNd` is really not a mixin anymore (but multiple inheritance as
# above would still work).
  1212. class _ConvTransposeMixin(_ConvTransposeNd):
  1213. @deprecated(
  1214. "`_ConvTransposeMixin` is a deprecated internal class. "
  1215. "Please consider using public APIs.",
  1216. category=FutureWarning,
  1217. )
  1218. def __init__(self, *args, **kwargs) -> None:
  1219. super().__init__(*args, **kwargs)
  1220. # TODO: Conv2dLocal
  1221. # TODO: Conv2dMap
  1222. # TODO: ConvTranspose2dMap
  1223. class _LazyConvXdMixin(LazyModuleMixin):
  1224. groups: int
  1225. transposed: bool
  1226. in_channels: int
  1227. out_channels: int
  1228. kernel_size: tuple[int, ...]
  1229. weight: UninitializedParameter
  1230. bias: UninitializedParameter
  1231. def reset_parameters(self) -> None:
  1232. # has_uninitialized_params is defined in parent class and it is using a protocol on self
  1233. if not self.has_uninitialized_params() and self.in_channels != 0: # type: ignore[misc]
  1234. # "type:ignore[..]" is required because mypy thinks that "reset_parameters" is undefined
  1235. # in super class. Turns out that it is defined in _ConvND which is inherited by any class
  1236. # that also inherits _LazyConvXdMixin
  1237. super().reset_parameters() # type: ignore[misc]
  1238. # Signature of "initialize_parameters" is incompatible with the definition in supertype LazyModuleMixin
  1239. def initialize_parameters(self, input: Tensor, *args, **kwargs) -> None: # type: ignore[override]
  1240. # defined by parent class but using a protocol
  1241. if self.has_uninitialized_params(): # type: ignore[misc]
  1242. self.in_channels = self._get_in_channels(input)
  1243. if self.in_channels % self.groups != 0:
  1244. raise ValueError("in_channels must be divisible by groups")
  1245. if not isinstance(self.weight, UninitializedParameter):
  1246. raise AssertionError("self.weight must be an UninitializedParameter")
  1247. if self.transposed:
  1248. self.weight.materialize(
  1249. (
  1250. self.in_channels,
  1251. self.out_channels // self.groups,
  1252. *self.kernel_size,
  1253. )
  1254. )
  1255. else:
  1256. self.weight.materialize(
  1257. (
  1258. self.out_channels,
  1259. self.in_channels // self.groups,
  1260. *self.kernel_size,
  1261. )
  1262. )
  1263. if self.bias is not None:
  1264. if not isinstance(self.bias, UninitializedParameter):
  1265. raise AssertionError("self.bias must be an UninitializedParameter")
  1266. self.bias.materialize((self.out_channels,))
  1267. self.reset_parameters()
  1268. # Function to extract in_channels from first input.
  1269. def _get_in_channels(self, input: Tensor) -> int:
  1270. num_spatial_dims = self._get_num_spatial_dims()
  1271. num_dims_no_batch = num_spatial_dims + 1 # +1 for channels dim
  1272. num_dims_batch = num_dims_no_batch + 1
  1273. if input.dim() not in (num_dims_no_batch, num_dims_batch):
  1274. raise RuntimeError(
  1275. f"Expected {num_dims_no_batch}D (unbatched) or {num_dims_batch}D (batched) input "
  1276. f"to {self.__class__.__name__}, but "
  1277. f"got input of size: {input.shape}"
  1278. )
  1279. return input.shape[1] if input.dim() == num_dims_batch else input.shape[0]
  1280. # Function to return the number of spatial dims expected for inputs to the module.
  1281. # This is expected to be implemented by subclasses.
  1282. def _get_num_spatial_dims(self) -> int:
  1283. raise NotImplementedError
  1284. # LazyConv1d defines weight as a Tensor but derived class defines it as UninitializeParameter
  1285. class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
  1286. r"""A :class:`torch.nn.Conv1d` module with lazy initialization of the ``in_channels`` argument.
  1287. The ``in_channels`` argument of the :class:`Conv1d` is inferred from the ``input.size(1)``.
  1288. The attributes that will be lazily initialized are `weight` and `bias`.
  1289. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  1290. on lazy modules and their limitations.
  1291. Args:
  1292. out_channels (int): Number of channels produced by the convolution
  1293. kernel_size (int or tuple): Size of the convolving kernel
  1294. stride (int or tuple, optional): Stride of the convolution. Default: 1
  1295. padding (int or tuple, optional): Zero-padding added to both sides of
  1296. the input. Default: 0
  1297. dilation (int or tuple, optional): Spacing between kernel
  1298. elements. Default: 1
  1299. groups (int, optional): Number of blocked connections from input
  1300. channels to output channels. Default: 1
  1301. bias (bool, optional): If ``True``, adds a learnable bias to the
  1302. output. Default: ``True``
  1303. padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
  1304. ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
  1305. .. seealso:: :class:`torch.nn.Conv1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
  1306. """
  1307. # super class define this variable as None. "type: ignore[..] is required
  1308. # since we are redefining the variable.
  1309. cls_to_become = Conv1d # type: ignore[assignment]
  1310. def __init__(
  1311. self,
  1312. out_channels: int,
  1313. kernel_size: _size_1_t,
  1314. stride: _size_1_t = 1,
  1315. padding: _size_1_t = 0,
  1316. dilation: _size_1_t = 1,
  1317. groups: int = 1,
  1318. bias: bool = True,
  1319. padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
  1320. device=None,
  1321. dtype=None,
  1322. ) -> None:
  1323. factory_kwargs = {"device": device, "dtype": dtype}
  1324. # pyrefly: ignore [bad-argument-type]
  1325. super().__init__(
  1326. 0,
  1327. 0,
  1328. kernel_size,
  1329. stride,
  1330. padding,
  1331. dilation,
  1332. groups,
  1333. # bias is hardcoded to False to avoid creating tensor
  1334. # that will soon be overwritten.
  1335. False,
  1336. padding_mode,
  1337. **factory_kwargs,
  1338. )
  1339. # pyrefly: ignore [bad-argument-type, bad-override, unexpected-keyword]
  1340. self.weight = UninitializedParameter(**factory_kwargs)
  1341. self.out_channels = out_channels
  1342. if bias:
  1343. # pyrefly: ignore [bad-argument-type, bad-override, unexpected-keyword]
  1344. self.bias = UninitializedParameter(**factory_kwargs)
  1345. def _get_num_spatial_dims(self) -> int:
  1346. return 1
  1347. # LazyConv2d defines weight as a Tensor but derived class defines it as UninitializeParameter
  1348. class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
  1349. r"""A :class:`torch.nn.Conv2d` module with lazy initialization of the ``in_channels`` argument.
  1350. The ``in_channels`` argument of the :class:`Conv2d` that is inferred from the ``input.size(1)``.
  1351. The attributes that will be lazily initialized are `weight` and `bias`.
  1352. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  1353. on lazy modules and their limitations.
  1354. Args:
  1355. out_channels (int): Number of channels produced by the convolution
  1356. kernel_size (int or tuple): Size of the convolving kernel
  1357. stride (int or tuple, optional): Stride of the convolution. Default: 1
  1358. padding (int or tuple, optional): Zero-padding added to both sides of
  1359. the input. Default: 0
  1360. dilation (int or tuple, optional): Spacing between kernel
  1361. elements. Default: 1
  1362. groups (int, optional): Number of blocked connections from input
  1363. channels to output channels. Default: 1
  1364. bias (bool, optional): If ``True``, adds a learnable bias to the
  1365. output. Default: ``True``
  1366. padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
  1367. ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
  1368. .. seealso:: :class:`torch.nn.Conv2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
  1369. """
  1370. # super class define this variable as None. "type: ignore[..] is required
  1371. # since we are redefining the variable.
  1372. cls_to_become = Conv2d # type: ignore[assignment]
  1373. def __init__(
  1374. self,
  1375. out_channels: int,
  1376. kernel_size: _size_2_t,
  1377. stride: _size_2_t = 1,
  1378. padding: _size_2_t = 0,
  1379. dilation: _size_2_t = 1,
  1380. groups: int = 1,
  1381. bias: bool = True,
  1382. padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
  1383. device=None,
  1384. dtype=None,
  1385. ) -> None:
  1386. factory_kwargs = {"device": device, "dtype": dtype}
  1387. # pyrefly: ignore [bad-argument-type]
  1388. super().__init__(
  1389. 0,
  1390. 0,
  1391. kernel_size,
  1392. stride,
  1393. padding,
  1394. dilation,
  1395. groups,
  1396. # bias is hardcoded to False to avoid creating tensor
  1397. # that will soon be overwritten.
  1398. False,
  1399. padding_mode,
  1400. **factory_kwargs,
  1401. )
  1402. # pyrefly: ignore [bad-argument-type, bad-override, unexpected-keyword]
  1403. self.weight = UninitializedParameter(**factory_kwargs)
  1404. self.out_channels = out_channels
  1405. if bias:
  1406. # pyrefly: ignore [bad-argument-type, bad-override, unexpected-keyword]
  1407. self.bias = UninitializedParameter(**factory_kwargs)
  1408. def _get_num_spatial_dims(self) -> int:
  1409. return 2
  1410. # LazyConv3d defines weight as a Tensor but derived class defines it as UninitializeParameter
  1411. class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
  1412. r"""A :class:`torch.nn.Conv3d` module with lazy initialization of the ``in_channels`` argument.
  1413. The ``in_channels`` argument of the :class:`Conv3d` that is inferred from
  1414. the ``input.size(1)``.
  1415. The attributes that will be lazily initialized are `weight` and `bias`.
  1416. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  1417. on lazy modules and their limitations.
  1418. Args:
  1419. out_channels (int): Number of channels produced by the convolution
  1420. kernel_size (int or tuple): Size of the convolving kernel
  1421. stride (int or tuple, optional): Stride of the convolution. Default: 1
  1422. padding (int or tuple, optional): Zero-padding added to both sides of
  1423. the input. Default: 0
  1424. dilation (int or tuple, optional): Spacing between kernel
  1425. elements. Default: 1
  1426. groups (int, optional): Number of blocked connections from input
  1427. channels to output channels. Default: 1
  1428. bias (bool, optional): If ``True``, adds a learnable bias to the
  1429. output. Default: ``True``
  1430. padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
  1431. ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
  1432. .. seealso:: :class:`torch.nn.Conv3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
  1433. """
  1434. # super class define this variable as None. "type: ignore[..] is required
  1435. # since we are redefining the variable.
  1436. cls_to_become = Conv3d # type: ignore[assignment]
  1437. def __init__(
  1438. self,
  1439. out_channels: int,
  1440. kernel_size: _size_3_t,
  1441. stride: _size_3_t = 1,
  1442. padding: _size_3_t = 0,
  1443. dilation: _size_3_t = 1,
  1444. groups: int = 1,
  1445. bias: bool = True,
  1446. padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
  1447. device=None,
  1448. dtype=None,
  1449. ) -> None:
  1450. factory_kwargs = {"device": device, "dtype": dtype}
  1451. # pyrefly: ignore [bad-argument-type]
  1452. super().__init__(
  1453. 0,
  1454. 0,
  1455. kernel_size,
  1456. stride,
  1457. padding,
  1458. dilation,
  1459. groups,
  1460. # bias is hardcoded to False to avoid creating tensor
  1461. # that will soon be overwritten.
  1462. False,
  1463. padding_mode,
  1464. **factory_kwargs,
  1465. )
  1466. # pyrefly: ignore [bad-argument-type, bad-override, unexpected-keyword]
  1467. self.weight = UninitializedParameter(**factory_kwargs)
  1468. self.out_channels = out_channels
  1469. if bias:
  1470. # pyrefly: ignore [bad-argument-type, bad-override, unexpected-keyword]
  1471. self.bias = UninitializedParameter(**factory_kwargs)
  1472. def _get_num_spatial_dims(self) -> int:
  1473. return 3
  1474. # LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UninitializeParameter
  1475. class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[misc]
  1476. r"""A :class:`torch.nn.ConvTranspose1d` module with lazy initialization of the ``in_channels`` argument.
  1477. The ``in_channels`` argument of the :class:`ConvTranspose1d` that is inferred from
  1478. the ``input.size(1)``.
  1479. The attributes that will be lazily initialized are `weight` and `bias`.
  1480. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  1481. on lazy modules and their limitations.
  1482. Args:
  1483. out_channels (int): Number of channels produced by the convolution
  1484. kernel_size (int or tuple): Size of the convolving kernel
  1485. stride (int or tuple, optional): Stride of the convolution. Default: 1
  1486. padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
  1487. will be added to both sides of the input. Default: 0
  1488. output_padding (int or tuple, optional): Additional size added to one side
  1489. of the output shape. Default: 0
  1490. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  1491. bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
  1492. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  1493. .. seealso:: :class:`torch.nn.ConvTranspose1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
  1494. """
  1495. # super class define this variable as None. "type: ignore[..] is required
  1496. # since we are redefining the variable.
  1497. cls_to_become = ConvTranspose1d # type: ignore[assignment]
  1498. def __init__(
  1499. self,
  1500. out_channels: int,
  1501. kernel_size: _size_1_t,
  1502. stride: _size_1_t = 1,
  1503. padding: _size_1_t = 0,
  1504. output_padding: _size_1_t = 0,
  1505. groups: int = 1,
  1506. bias: bool = True,
  1507. dilation: _size_1_t = 1,
  1508. padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
  1509. device=None,
  1510. dtype=None,
  1511. ) -> None:
  1512. factory_kwargs = {"device": device, "dtype": dtype}
  1513. # pyrefly: ignore [bad-argument-type]
  1514. super().__init__(
  1515. 0,
  1516. 0,
  1517. kernel_size,
  1518. stride,
  1519. padding,
  1520. output_padding,
  1521. groups,
  1522. # bias is hardcoded to False to avoid creating tensor
  1523. # that will soon be overwritten.
  1524. False,
  1525. dilation,
  1526. padding_mode,
  1527. **factory_kwargs,
  1528. )
  1529. # pyrefly: ignore [bad-argument-type, bad-override, unexpected-keyword]
  1530. self.weight = UninitializedParameter(**factory_kwargs)
  1531. self.out_channels = out_channels
  1532. if bias:
  1533. # pyrefly: ignore [bad-argument-type, bad-override, unexpected-keyword]
  1534. self.bias = UninitializedParameter(**factory_kwargs)
  1535. def _get_num_spatial_dims(self) -> int:
  1536. return 1
  1537. # LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UninitializeParameter
  1538. class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[misc]
  1539. r"""A :class:`torch.nn.ConvTranspose2d` module with lazy initialization of the ``in_channels`` argument.
  1540. The ``in_channels`` argument of the :class:`ConvTranspose2d` is inferred from
  1541. the ``input.size(1)``.
  1542. The attributes that will be lazily initialized are `weight` and `bias`.
  1543. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  1544. on lazy modules and their limitations.
  1545. Args:
  1546. out_channels (int): Number of channels produced by the convolution
  1547. kernel_size (int or tuple): Size of the convolving kernel
  1548. stride (int or tuple, optional): Stride of the convolution. Default: 1
  1549. padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
  1550. will be added to both sides of each dimension in the input. Default: 0
  1551. output_padding (int or tuple, optional): Additional size added to one side
  1552. of each dimension in the output shape. Default: 0
  1553. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  1554. bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
  1555. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  1556. .. seealso:: :class:`torch.nn.ConvTranspose2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
  1557. """
  1558. # super class define this variable as None. "type: ignore[..] is required
  1559. # since we are redefining the variable.
  1560. cls_to_become = ConvTranspose2d # type: ignore[assignment]
  1561. def __init__(
  1562. self,
  1563. out_channels: int,
  1564. kernel_size: _size_2_t,
  1565. stride: _size_2_t = 1,
  1566. padding: _size_2_t = 0,
  1567. output_padding: _size_2_t = 0,
  1568. groups: int = 1,
  1569. bias: bool = True,
  1570. dilation: int = 1,
  1571. padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
  1572. device=None,
  1573. dtype=None,
  1574. ) -> None:
  1575. factory_kwargs = {"device": device, "dtype": dtype}
  1576. # pyrefly: ignore [bad-argument-type]
  1577. super().__init__(
  1578. 0,
  1579. 0,
  1580. kernel_size,
  1581. stride,
  1582. padding,
  1583. output_padding,
  1584. groups,
  1585. # bias is hardcoded to False to avoid creating tensor
  1586. # that will soon be overwritten.
  1587. False,
  1588. dilation,
  1589. padding_mode,
  1590. **factory_kwargs,
  1591. )
  1592. # pyrefly: ignore [bad-argument-type, bad-override, unexpected-keyword]
  1593. self.weight = UninitializedParameter(**factory_kwargs)
  1594. self.out_channels = out_channels
  1595. if bias:
  1596. # pyrefly: ignore [bad-argument-type, bad-override, unexpected-keyword]
  1597. self.bias = UninitializedParameter(**factory_kwargs)
  1598. def _get_num_spatial_dims(self) -> int:
  1599. return 2
  1600. # LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UninitializeParameter
  1601. class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[misc]
  1602. r"""A :class:`torch.nn.ConvTranspose3d` module with lazy initialization of the ``in_channels`` argument.
  1603. The ``in_channels`` argument of the :class:`ConvTranspose3d` is inferred from
  1604. the ``input.size(1)``.
  1605. The attributes that will be lazily initialized are `weight` and `bias`.
  1606. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  1607. on lazy modules and their limitations.
  1608. Args:
  1609. out_channels (int): Number of channels produced by the convolution
  1610. kernel_size (int or tuple): Size of the convolving kernel
  1611. stride (int or tuple, optional): Stride of the convolution. Default: 1
  1612. padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
  1613. will be added to both sides of each dimension in the input. Default: 0
  1614. output_padding (int or tuple, optional): Additional size added to one side
  1615. of each dimension in the output shape. Default: 0
  1616. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  1617. bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
  1618. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  1619. .. seealso:: :class:`torch.nn.ConvTranspose3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin`
  1620. """
  1621. # super class define this variable as None. "type: ignore[..] is required
  1622. # since we are redefining the variable.
  1623. cls_to_become = ConvTranspose3d # type: ignore[assignment]
  1624. def __init__(
  1625. self,
  1626. out_channels: int,
  1627. kernel_size: _size_3_t,
  1628. stride: _size_3_t = 1,
  1629. padding: _size_3_t = 0,
  1630. output_padding: _size_3_t = 0,
  1631. groups: int = 1,
  1632. bias: bool = True,
  1633. dilation: _size_3_t = 1,
  1634. padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
  1635. device=None,
  1636. dtype=None,
  1637. ) -> None:
  1638. factory_kwargs = {"device": device, "dtype": dtype}
  1639. # pyrefly: ignore [bad-argument-type]
  1640. super().__init__(
  1641. 0,
  1642. 0,
  1643. kernel_size,
  1644. stride,
  1645. padding,
  1646. output_padding,
  1647. groups,
  1648. # bias is hardcoded to False to avoid creating tensor
  1649. # that will soon be overwritten.
  1650. False,
  1651. dilation,
  1652. padding_mode,
  1653. **factory_kwargs,
  1654. )
  1655. # pyrefly: ignore [bad-argument-type, bad-override, unexpected-keyword]
  1656. self.weight = UninitializedParameter(**factory_kwargs)
  1657. self.out_channels = out_channels
  1658. if bias:
  1659. # pyrefly: ignore [bad-argument-type, bad-override, unexpected-keyword]
  1660. self.bias = UninitializedParameter(**factory_kwargs)
  1661. def _get_num_spatial_dims(self) -> int:
  1662. return 3