// DilatedConvolutionUtils.h
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <algorithm>
  4. #include <vector>
  5. #include <ATen/div_rtn.h>
  6. #include <ATen/core/Tensor.h>
  7. #include <c10/util/irange.h>
  // TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE)
  // Checks that tensor T has exactly DIM dimensions and that T.size(DIM_SIZE)
  // equals SIZE; otherwise fails through TORCH_CHECK with a message that names
  // T (via stringification), e.g. "Need bias of dimension 1 and bias.size[0]
  // == 8 but got bias to be of shape [4]".
  // Arguments are parenthesized so that expression arguments bind correctly.
  // Each argument may be evaluated more than once, so pass side-effect-free
  // expressions.
  #define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE)    \
    TORCH_CHECK(                                         \
        (T).dim() == (DIM) && (T).size(DIM_SIZE) == (SIZE), \
        "Need " #T " of dimension ",                     \
        (DIM),                                           \
        " and " #T ".size[",                             \
        (DIM_SIZE),                                      \
        "] == ",                                         \
        (SIZE),                                          \
        " but got " #T " to be of shape ",               \
        (T).sizes())
  19. namespace at::native::internal {
  20. namespace {
  21. inline bool all_positive(IntArrayRef& arr) {
  22. return std::all_of(
  23. arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
  24. }
  // Returns true iff every element of `arr` is >= 0 (vacuously true for an
  // empty vector).
  // The vector is only read, so it is taken by const reference; this lets
  // callers pass temporaries and const vectors.
  inline bool all_nonnegative(const std::vector<int64_t>& arr) {
    return std::all_of(
        arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
  }
  29. } // namespace
  30. // calculate the rear part of output tensor sizes
  31. template <int64_t dim>
  32. std::vector<int64_t> get_output_size(
  33. const Tensor& input,
  34. IntArrayRef kernel_size,
  35. IntArrayRef stride_size,
  36. IntArrayRef pad_size,
  37. IntArrayRef dilation_size) {
  38. std::vector<int64_t> sizes;
  39. for (const auto index : c10::irange(dim)) {
  40. sizes.push_back(
  41. div_rtn<int64_t>(
  42. input.size(index + input.dim() - dim) + 2 * pad_size[index] -
  43. (dilation_size[index] * (kernel_size[index] - 1) + 1),
  44. stride_size[index]) +
  45. 1);
  46. }
  47. return sizes;
  48. }
  49. // calculate the sizes of output tensor
  50. template <int64_t dim>
  51. std::vector<int64_t> get_output_size(
  52. const Tensor& input,
  53. const Tensor& weight,
  54. IntArrayRef kernel_size,
  55. IntArrayRef stride_size,
  56. IntArrayRef pad_size,
  57. IntArrayRef dilation_size) {
  58. auto output_size = get_output_size<dim>(
  59. input, kernel_size, stride_size, pad_size, dilation_size);
  60. output_size.insert(output_size.begin(), weight.size(0));
  61. if (input.dim() == dim + 2) {
  62. output_size.insert(output_size.begin(), input.size(0));
  63. }
  64. return output_size;
  65. }
  66. /*
  67. slow_conv_dilated_shape_check - check user-input to dilated convolution
  68. forward and backward functions.
  69. */
  70. template <int64_t dim>
  71. void slow_conv_dilated_shape_check(
  72. const Tensor& input,
  73. const Tensor& weight,
  74. const Tensor& bias,
  75. const Tensor& grad_output,
  76. IntArrayRef kernel_size,
  77. IntArrayRef stride_size,
  78. IntArrayRef pad_size,
  79. IntArrayRef dilation_size) {
  80. /*
  81. When the following tensors are defined:
  82. bias, grad_weight, grad_output
  83. then these are assumed to be contiguous without checking
  84. because of these tensors are made contiguous by calling
  85. .contiguous() method or by resizing of zero-sized tensors in
  86. forward/backward functions.
  87. When grad_weight is defined then it is assumed without
  88. checking to have the same shape as weight, see backward
  89. functions.
  90. */
  91. // Check size arguments
  92. TORCH_CHECK(
  93. kernel_size.size() == dim,
  94. "kernel sizes length should be ",
  95. dim,
  96. ", but got ",
  97. kernel_size.size());
  98. TORCH_CHECK(
  99. stride_size.size() == dim,
  100. "strides length should be ",
  101. dim,
  102. ", but got ",
  103. stride_size.size());
  104. TORCH_CHECK(
  105. dilation_size.size() == dim,
  106. "dilations length should be ",
  107. dim,
  108. ", but got ",
  109. dilation_size.size());
  110. TORCH_CHECK(
  111. pad_size.size() == dim,
  112. "pads length should be ",
  113. dim,
  114. ", but got ",
  115. pad_size.size());
  116. TORCH_CHECK(
  117. all_positive(kernel_size),
  118. "kernel size should be greater than zero, but got ",
  119. kernel_size);
  120. TORCH_CHECK(
  121. all_positive(stride_size),
  122. "stride should be greater than zero, but got ",
  123. stride_size);
  124. TORCH_CHECK(
  125. all_positive(dilation_size),
  126. "dilation should be greater than zero, but got ",
  127. dilation_size);
  128. // check input
  129. TORCH_CHECK(input.defined(), "input must be defined");
  130. bool is_batch = input.dim() == dim + 2;
  131. int64_t n = (is_batch ? 2 : 1);
  132. int64_t ndim = n + dim;
  133. if (!is_batch) {
  134. // input dim has to be dim + 1 if not batched
  135. TORCH_CHECK(
  136. input.dim() == dim + 1,
  137. "input must be 4D or 5D tensor but got ",
  138. input.dim(),
  139. "D tensor");
  140. }
  141. // check output sizes
  142. auto output_size = get_output_size<dim>(
  143. input, kernel_size, stride_size, pad_size, dilation_size);
  144. TORCH_CHECK(
  145. all_nonnegative(output_size),
  146. "calculated output size ",
  147. output_size,
  148. " is too small (all sizes must be non-negative)");
  149. // check weight
  150. TORCH_CHECK(weight.defined(), "weight must be defined");
  151. TORCH_CHECK(
  152. weight.dim() == dim + 2,
  153. "weight must be ",
  154. dim + 2,
  155. "D tensor but got ",
  156. weight.dim(),
  157. "D tensor dim=",
  158. dim);
  159. TORCH_CHECK(
  160. weight.sizes().slice(2) == kernel_size,
  161. "weight[2:] shape ",
  162. weight.sizes().slice(2),
  163. " must be equal to kernel_size ",
  164. kernel_size);
  165. TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));
  166. // check bias when present
  167. if (bias.defined()) {
  168. TORCH_CHECK(
  169. bias.dim() == 1,
  170. "bias must be 1D tensor but got ",
  171. bias.dim(),
  172. "D tensor");
  173. TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
  174. }
  175. // check grad_output when present
  176. if (grad_output.defined()) {
  177. TORCH_CHECK(
  178. grad_output.dim() == ndim,
  179. "grad_output must be ",
  180. ndim,
  181. "D tensor but got ",
  182. grad_output.dim(),
  183. "D tensor");
  184. if (is_batch) {
  185. TORCH_CHECK(
  186. grad_output.size(0) == input.size(0),
  187. "grad_output.size(0)=",
  188. grad_output.size(0),
  189. " must be input.size(0)=",
  190. input.size(0));
  191. }
  192. TORCH_CHECK(
  193. grad_output.size(n - 1) == weight.size(0),
  194. "grad_output.size(",
  195. n - 1,
  196. ")=",
  197. grad_output.size(n - 1),
  198. " must be weight.size(0)=",
  199. weight.size(0));
  200. TORCH_CHECK(
  201. grad_output.sizes().slice(n) == output_size,
  202. "grad_output[",
  203. n,
  204. ":] shape",
  205. grad_output.sizes().slice(n),
  206. " must be equal to output size ",
  207. output_size);
  208. }
  209. }
  210. } // namespace at::native::internal
  211. #else
  212. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  213. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)