grad.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # mypy: allow-untyped-defs
  2. """Gradient interface."""
  3. import torch
  4. from torch.nn.modules.utils import _pair, _single, _triple
  5. def conv1d_input(
  6. input_size,
  7. weight,
  8. grad_output,
  9. stride=1,
  10. padding=0,
  11. dilation=1,
  12. groups=1,
  13. ):
  14. r"""Compute the gradient of conv1d with respect to the input of the convolution.
  15. This is same as the 1D transposed convolution operator under the hood but requires
  16. the shape of the gradient w.r.t. input to be specified explicitly.
  17. Args:
  18. input_size : Shape of the input gradient tensor
  19. weight: weight tensor (out_channels x in_channels/groups x kW)
  20. grad_output : output gradient tensor (minibatch x out_channels x oW)
  21. stride (int or tuple, optional): Stride of the convolution. Default: 1
  22. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  23. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  24. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  25. Examples::
  26. >>> input = torch.randn(1, 1, 3, requires_grad=True)
  27. >>> weight = torch.randn(1, 1, 1, requires_grad=True)
  28. >>> output = F.conv1d(input, weight)
  29. >>> grad_output = torch.randn(output.shape)
  30. >>> grad_input = torch.autograd.grad(output, input, grad_output)
  31. >>> F.grad.conv1d_input(input.shape, weight, grad_output)
  32. """
  33. input = grad_output.new_empty(1).expand(input_size)
  34. return torch.ops.aten.convolution_backward(
  35. grad_output,
  36. input,
  37. weight,
  38. None,
  39. _single(stride),
  40. _single(padding),
  41. _single(dilation),
  42. False,
  43. [0],
  44. groups,
  45. (True, False, False),
  46. )[0]
  47. def conv1d_weight(
  48. input,
  49. weight_size,
  50. grad_output,
  51. stride=1,
  52. padding=0,
  53. dilation=1,
  54. groups=1,
  55. ):
  56. r"""Compute the gradient of conv1d with respect to the weight of the convolution.
  57. Args:
  58. input: input tensor of shape (minibatch x in_channels x iW)
  59. weight_size : Shape of the weight gradient tensor
  60. grad_output : output gradient tensor (minibatch x out_channels x oW)
  61. stride (int or tuple, optional): Stride of the convolution. Default: 1
  62. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  63. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  64. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  65. Examples::
  66. >>> input = torch.randn(1, 1, 3, requires_grad=True)
  67. >>> weight = torch.randn(1, 1, 1, requires_grad=True)
  68. >>> output = F.conv1d(input, weight)
  69. >>> grad_output = torch.randn(output.shape)
  70. >>> # xdoctest: +SKIP
  71. >>> grad_weight = torch.autograd.grad(output, filter, grad_output)
  72. >>> F.grad.conv1d_weight(input, weight.shape, grad_output)
  73. """
  74. weight = grad_output.new_empty(1).expand(weight_size)
  75. return torch.ops.aten.convolution_backward(
  76. grad_output,
  77. input,
  78. weight,
  79. None,
  80. _single(stride),
  81. _single(padding),
  82. _single(dilation),
  83. False,
  84. [0],
  85. groups,
  86. (False, True, False),
  87. )[1]
  88. def conv2d_input(
  89. input_size,
  90. weight,
  91. grad_output,
  92. stride=1,
  93. padding=0,
  94. dilation=1,
  95. groups=1,
  96. ):
  97. r"""Compute the gradient of conv2d with respect to the input of the convolution.
  98. This is same as the 2D transposed convolution operator under the hood but requires
  99. the shape of the gradient w.r.t. input to be specified explicitly.
  100. Args:
  101. input_size : Shape of the input gradient tensor
  102. weight: weight tensor (out_channels x in_channels/groups x kH x kW)
  103. grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
  104. stride (int or tuple, optional): Stride of the convolution. Default: 1
  105. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  106. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  107. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  108. Examples::
  109. >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
  110. >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
  111. >>> output = F.conv2d(input, weight)
  112. >>> grad_output = torch.randn(output.shape)
  113. >>> grad_input = torch.autograd.grad(output, input, grad_output)
  114. >>> F.grad.conv2d_input(input.shape, weight, grad_output)
  115. """
  116. input = grad_output.new_empty(1).expand(input_size)
  117. return torch.ops.aten.convolution_backward(
  118. grad_output,
  119. input,
  120. weight,
  121. None,
  122. _pair(stride),
  123. _pair(padding),
  124. _pair(dilation),
  125. False,
  126. [0],
  127. groups,
  128. (True, False, False),
  129. )[0]
  130. def conv2d_weight(
  131. input,
  132. weight_size,
  133. grad_output,
  134. stride=1,
  135. padding=0,
  136. dilation=1,
  137. groups=1,
  138. ):
  139. r"""Compute the gradient of conv2d with respect to the weight of the convolution.
  140. Args:
  141. input: input tensor of shape (minibatch x in_channels x iH x iW)
  142. weight_size : Shape of the weight gradient tensor
  143. grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
  144. stride (int or tuple, optional): Stride of the convolution. Default: 1
  145. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  146. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  147. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  148. Examples::
  149. >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
  150. >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
  151. >>> output = F.conv2d(input, weight)
  152. >>> grad_output = torch.randn(output.shape)
  153. >>> # xdoctest: +SKIP
  154. >>> grad_weight = torch.autograd.grad(output, filter, grad_output)
  155. >>> F.grad.conv2d_weight(input, weight.shape, grad_output)
  156. """
  157. weight = grad_output.new_empty(1).expand(weight_size)
  158. return torch.ops.aten.convolution_backward(
  159. grad_output,
  160. input,
  161. weight,
  162. None,
  163. _pair(stride),
  164. _pair(padding),
  165. _pair(dilation),
  166. False,
  167. [0],
  168. groups,
  169. (False, True, False),
  170. )[1]
  171. def conv3d_input(
  172. input_size,
  173. weight,
  174. grad_output,
  175. stride=1,
  176. padding=0,
  177. dilation=1,
  178. groups=1,
  179. ):
  180. r"""Compute the gradient of conv3d with respect to the input of the convolution.
  181. This is same as the 3D transposed convolution operator under the hood but requires
  182. the shape of the gradient w.r.t. input to be specified explicitly.
  183. Args:
  184. input_size : Shape of the input gradient tensor
  185. weight: weights tensor (out_channels x in_channels/groups x kT x kH x kW)
  186. grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
  187. stride (int or tuple, optional): Stride of the convolution. Default: 1
  188. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  189. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  190. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  191. Examples::
  192. >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
  193. >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
  194. >>> output = F.conv3d(input, weight)
  195. >>> grad_output = torch.randn(output.shape)
  196. >>> grad_input = torch.autograd.grad(output, input, grad_output)
  197. >>> F.grad.conv3d_input(input.shape, weight, grad_output)
  198. """
  199. input = grad_output.new_empty(1).expand(input_size)
  200. return torch.ops.aten.convolution_backward(
  201. grad_output,
  202. input,
  203. weight,
  204. None,
  205. _triple(stride),
  206. _triple(padding),
  207. _triple(dilation),
  208. False,
  209. [0],
  210. groups,
  211. (True, False, False),
  212. )[0]
  213. def conv3d_weight(
  214. input,
  215. weight_size,
  216. grad_output,
  217. stride=1,
  218. padding=0,
  219. dilation=1,
  220. groups=1,
  221. ):
  222. r"""Compute the gradient of conv3d with respect to the weight of the convolution.
  223. Args:
  224. input: input tensor of shape (minibatch x in_channels x iT x iH x iW)
  225. weight_size : Shape of the weight gradient tensor
  226. grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
  227. stride (int or tuple, optional): Stride of the convolution. Default: 1
  228. padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
  229. dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
  230. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
  231. Examples::
  232. >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
  233. >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
  234. >>> output = F.conv3d(input, weight)
  235. >>> grad_output = torch.randn(output.shape)
  236. >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
  237. >>> F.grad.conv3d_weight(input, weight.shape, grad_output)
  238. """
  239. weight = grad_output.new_empty(1).expand(weight_size)
  240. return torch.ops.aten.convolution_backward(
  241. grad_output,
  242. input,
  243. weight,
  244. None,
  245. _triple(stride),
  246. _triple(padding),
  247. _triple(dilation),
  248. False,
  249. [0],
  250. groups,
  251. (False, True, False),
  252. )[1]