UpSample.h 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <cmath>
  4. #include <ATen/OpMathType.h>
  5. #include <ATen/TensorUtils.h>
  6. #include <ATen/core/Tensor.h>
  7. #include <ATen/cpu/vec/functional.h>
  8. #include <ATen/cpu/vec/vec.h>
  9. #include <ATen/native/DispatchStub.h>
  10. #include <ATen/native/cpu/utils.h>
  11. /**
  12. * Note [compute_scales_value]
  13. * Note [area_pixel_compute_scale]
  14. * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  15. * Interpolate with scale_factor can have different behaviors
  16. * depending on the value of recompute_scale_factor:
  17. *
 * - With recompute_scale_factor = True (current default behavior):
 * the scale_factor, when provided by the user, is used to calculate
 * the output size. The input size and the computed output_size
 * are then used to infer new values for the scales which are
 * used in the interpolation. Because floating-point math is not exact,
 * these may differ from the user-supplied scales.
  24. *
  25. * - With recompute_scale_factor = False (which will be the default
  26. * behavior starting 1.5.0):
  27. * the behavior follows opencv logic, and the scales provided by
  28. * the user are the ones used in the interpolation calculations.
  29. *
  30. * If the scales are not provided or if they are provided but
  31. * recompute_scale_factor is set to True (default behavior), the scales
  32. * are computed from the input and the output size;
  33. *
  34. *
  35. * When the scales are inferred from the input and output sizes,
  36. * we view each pixel as an area, idx + 0.5 as its center index.
  37. * Here is an example formula in 1D case.
  38. * if align_corners: center of two corner pixel areas are preserved,
  39. * (0.5, 0.5) -> (0.5, 0.5),
 * (input_size - 0.5, 0.5) -> (output_size - 0.5, 0.5)
  41. * scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5)
  42. * src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
  43. * if not align_corners: the whole range is scaled accordingly
  44. * scale = input_size / output_size
  45. * src_idx + 0.5 = scale * (dst_index + 0.5)
  46. */
  47. namespace at::native {
  48. namespace upsample {
  49. TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
  50. c10::IntArrayRef input_size, // Full input tensor size.
  51. at::OptionalIntArrayRef output_size,
  52. std::optional<c10::ArrayRef<double>> scale_factors);
  53. inline std::optional<double> get_scale_value(std::optional<c10::ArrayRef<double>> scales, int idx) {
  54. if (!scales) {
  55. return std::nullopt;
  56. }
  57. return scales->at(idx);
  58. }
  59. } // namespace upsample
  60. using scale_t = std::optional<double>;
  61. using upsampling_nearest1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
  62. using _upsampling_nearest_exact1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
  63. using upsampling_nearest2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
  64. using _upsampling_nearest_exact2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
  65. using upsampling_nearest3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
  66. using _upsampling_nearest_exact3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
  67. using upsampling_linear1d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_w);
  68. using upsampling_bilinear2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
  69. using _upsampling_bilinear2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
  70. using upsampling_trilinear3d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_d, scale_t scales_h, scale_t scales_w);
  71. using upsampling_bicubic2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
  72. using _upsampling_bicubic2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
  73. DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel)
  74. DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_kernel)
  75. DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_kernel)
  76. DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_kernel)
  77. DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_kernel)
  78. DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_kernel)
  79. DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_backward_kernel)
  80. DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_backward_kernel)
  81. DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_backward_kernel)
  82. DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_backward_kernel)
  83. DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_backward_kernel)
  84. DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_backward_kernel)
  85. DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_kernel)
  86. DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_kernel)
  87. DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_kernel)
  88. DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_kernel)
  89. DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_backward_kernel)
  90. DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_backward_kernel)
  91. DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_backward_kernel)
  92. DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel)
  93. DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel)
  94. DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel)
  95. DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel)
  96. [[maybe_unused]] inline std::array<int64_t, 3> upsample_1d_common_check(
  97. IntArrayRef input_size,
  98. IntArrayRef output_size) {
  99. TORCH_CHECK(
  100. output_size.size() == 1,
  101. "It is expected output_size equals to 1, but got size ",
  102. output_size.size());
  103. TORCH_CHECK(
  104. input_size.size() == 3,
  105. "It is expected input_size equals to 3, but got size ",
  106. input_size.size());
  107. int64_t output_width = output_size[0];
  108. int64_t nbatch = input_size[0];
  109. int64_t channels = input_size[1];
  110. int64_t input_width = input_size[2];
  111. TORCH_CHECK(
  112. input_width > 0 && output_width > 0,
  113. "Input and output sizes should be greater than 0, but got input (W: ",
  114. input_width,
  115. ") and output (W: ",
  116. output_width,
  117. ")");
  118. return {nbatch, channels, output_width};
  119. }
  120. [[maybe_unused]] inline std::array<int64_t, 4> upsample_2d_common_check(
  121. IntArrayRef input_size,
  122. IntArrayRef output_size) {
  123. TORCH_CHECK(
  124. output_size.size() == 2,
  125. "It is expected output_size equals to 2, but got size ",
  126. output_size.size());
  127. TORCH_CHECK(
  128. input_size.size() == 4,
  129. "It is expected input_size equals to 4, but got size ",
  130. input_size.size());
  131. int64_t output_height = output_size[0];
  132. int64_t output_width = output_size[1];
  133. int64_t nbatch = input_size[0];
  134. int64_t channels = input_size[1];
  135. int64_t input_height = input_size[2];
  136. int64_t input_width = input_size[3];
  137. TORCH_CHECK(
  138. input_height > 0 && input_width > 0 && output_height > 0 &&
  139. output_width > 0,
  140. "Input and output sizes should be greater than 0,"
  141. " but got input (H: ",
  142. input_height,
  143. ", W: ",
  144. input_width,
  145. ") output (H: ",
  146. output_height,
  147. ", W: ",
  148. output_width,
  149. ")");
  150. return {nbatch, channels, output_height, output_width};
  151. }
  152. [[maybe_unused]] inline std::array<int64_t, 5> upsample_3d_common_check(
  153. IntArrayRef input_size,
  154. IntArrayRef output_size) {
  155. TORCH_CHECK(
  156. output_size.size() == 3,
  157. "It is expected output_size equals to 3, but got size ",
  158. output_size.size());
  159. TORCH_CHECK(
  160. input_size.size() == 5,
  161. "It is expected input_size equals to 5, but got size ",
  162. input_size.size());
  163. int64_t output_depth = output_size[0];
  164. int64_t output_height = output_size[1];
  165. int64_t output_width = output_size[2];
  166. int64_t nbatch = input_size[0];
  167. int64_t channels = input_size[1];
  168. int64_t input_depth = input_size[2];
  169. int64_t input_height = input_size[3];
  170. int64_t input_width = input_size[4];
  171. TORCH_CHECK(
  172. input_depth > 0 && input_height > 0 && input_width > 0 &&
  173. output_depth > 0 && output_height > 0 && output_width > 0,
  174. "Input and output sizes should be greater than 0, but got input (D: ",
  175. input_depth,
  176. ", H: ",
  177. input_height,
  178. ", W: ",
  179. input_width,
  180. ") output (D: ",
  181. output_depth,
  182. ", H: ",
  183. output_height,
  184. ", W: ",
  185. output_width,
  186. ")");
  187. return {nbatch, channels, output_depth, output_height, output_width};
  188. }
  189. inline void upsample_2d_shape_check(
  190. const Tensor& input,
  191. const Tensor& grad_output,
  192. int64_t nbatch,
  193. int64_t nchannels,
  194. int64_t input_height,
  195. int64_t input_width,
  196. int64_t output_height,
  197. int64_t output_width) {
  198. TORCH_CHECK(
  199. input_height > 0 && input_width > 0 && output_height > 0 &&
  200. output_width > 0,
  201. "Input and output sizes should be greater than 0,"
  202. " but got input (H: ",
  203. input_height,
  204. ", W: ",
  205. input_width,
  206. ") output (H: ",
  207. output_height,
  208. ", W: ",
  209. output_width,
  210. ")");
  211. if (input.defined()) {
  212. // Allow for empty batch size but not other dimensions
  213. TORCH_CHECK(
  214. (input.numel() != 0 ||
  215. (input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0)
  216. ) &&
  217. input.dim() == 4,
  218. "Non-empty 4D data tensor expected but got a tensor with sizes ",
  219. input.sizes());
  220. } else if (grad_output.defined()) {
  221. check_dim_size(grad_output, 4, 0, nbatch);
  222. check_dim_size(grad_output, 4, 1, nchannels);
  223. check_dim_size(grad_output, 4, 2, output_height);
  224. check_dim_size(grad_output, 4, 3, output_width);
  225. }
  226. }
  227. template <typename scalar_t>
  228. inline scalar_t compute_scales_value(
  229. const std::optional<double> scale,
  230. int64_t input_size,
  231. int64_t output_size) {
  232. // see Note [compute_scales_value]
  233. // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
  234. return (scale.has_value() && scale.value() > 0.)
  235. ? static_cast<scalar_t>(1.0 / scale.value())
  236. : (static_cast<scalar_t>(input_size) / output_size);
  237. }
  238. template <typename scalar_t>
  239. inline scalar_t area_pixel_compute_scale(
  240. int64_t input_size,
  241. int64_t output_size,
  242. bool align_corners,
  243. const std::optional<double> scale) {
  244. // see Note [area_pixel_compute_scale]
  245. if(align_corners) {
  246. if(output_size > 1) {
  247. return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
  248. } else {
  249. return static_cast<scalar_t>(0);
  250. }
  251. } else {
  252. return compute_scales_value<scalar_t>(scale, input_size, output_size);
  253. }
  254. }
  255. template <typename scalar_t>
  256. inline scalar_t area_pixel_compute_source_index(
  257. scalar_t scale,
  258. int64_t dst_index,
  259. bool align_corners,
  260. bool cubic) {
  261. if (align_corners) {
  262. return scale * dst_index;
  263. } else {
  264. scalar_t src_idx = scale * (dst_index + static_cast<scalar_t>(0.5)) -
  265. static_cast<scalar_t>(0.5);
  266. // [Note] Follow Opencv resize logic:
  267. // We allow negative src_idx here and later will use
  268. // dx = src_idx - floorf(src_idx)
  269. // to compute the "distance"(which affects weights).
  270. // For linear modes, weight distribution doesn't matter
  271. // for negative indices as they use 2 pixels to interpolate.
  272. // For example, [-1, 0], they both use pixel 0 value so it
  273. // doesn't affect if we bound the src_idx to 0 or not.
  274. // TODO: Our current linear mode impls use unbound indices
  275. // where we should and then remove this cubic flag.
  276. // This matters in cubic mode, as we might need [-1, 0, 1, 2]
  277. // to interpolate and the weights can be affected.
  278. return (!cubic && src_idx < static_cast<scalar_t>(0)) ? scalar_t(0)
  279. : src_idx;
  280. }
  281. }
  282. inline int64_t nearest_neighbor_compute_source_index(
  283. const float scale,
  284. int64_t dst_index,
  285. int64_t input_size) {
  286. // Index computation matching OpenCV INTER_NEAREST
  287. // which is buggy and kept for BC
  288. const int64_t src_index =
  289. std::min(static_cast<int64_t>(floorf(dst_index * scale)), input_size - 1);
  290. return src_index;
  291. }
  292. inline int64_t nearest_neighbor_exact_compute_source_index(
  293. const float scale,
  294. int64_t dst_index,
  295. int64_t input_size) {
  296. // index_f32 = (output_index + 0.5) * scale - 0.5
  297. // input_index = round(index_f32)
  298. // Same as Pillow and Scikit-Image/Scipy ndi.zoom
  299. const int64_t src_index =
  300. std::min(static_cast<int64_t>(floorf((dst_index + 0.5) * scale)), input_size - 1);
  301. return src_index;
  302. }
  303. inline int64_t nearest_idx(
  304. int64_t output_index,
  305. int64_t input_size,
  306. int64_t output_size,
  307. std::optional<double> scales) {
  308. // This method specifically treats cases: output_size == input_size or
  309. // output_size == 2 * input_size, that we would like to get rid of
  310. // We keep this method for BC and consider as deprecated.
  311. // See nearest_exact_idx as replacement
  312. if (output_size == input_size) {
  313. // scale_factor = 1, simply copy
  314. return output_index;
  315. } else if (output_size == 2 * input_size) {
  316. // scale_factor = 2, shift input index
  317. return output_index >> 1;
  318. } else {
  319. float scale = compute_scales_value<float>(scales, input_size, output_size);
  320. return nearest_neighbor_compute_source_index(scale, output_index, input_size);
  321. }
  322. }
  323. inline int64_t nearest_exact_idx(
  324. int64_t output_index,
  325. int64_t input_size,
  326. int64_t output_size,
  327. std::optional<double> scales) {
  328. float scale = compute_scales_value<float>(scales, input_size, output_size);
  329. return nearest_neighbor_exact_compute_source_index(scale, output_index, input_size);
  330. }
  331. // Define a typedef to dispatch to nearest_idx or nearest_exact_idx
  332. typedef int64_t (*nearest_idx_fn_t)(int64_t, int64_t, int64_t, std::optional<double>);
  333. template <typename scalar_t>
  334. scalar_t upsample_get_value_bounded(
  335. scalar_t* data,
  336. int64_t width,
  337. int64_t height,
  338. int64_t x,
  339. int64_t y) {
  340. int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
  341. int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
  342. return data[access_y * width + access_x];
  343. }
  344. template <typename scalar_t>
  345. void upsample_increment_value_bounded(
  346. scalar_t* data,
  347. int64_t width,
  348. int64_t height,
  349. int64_t x,
  350. int64_t y,
  351. scalar_t value) {
  352. int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
  353. int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
  354. data[access_y * width + access_x] += value;
  355. }
  356. // Based on
  357. // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
  358. template <typename scalar_t>
  359. scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
  360. return ((A + 2) * x - (A + 3)) * x * x + 1;
  361. }
  362. template <typename scalar_t>
  363. scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
  364. return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
  365. }
  366. template <typename scalar_t>
  367. static inline void get_cubic_upsample_coefficients(
  368. scalar_t coeffs[4],
  369. scalar_t t) {
  370. scalar_t A = -0.75;
  371. scalar_t x1 = t;
  372. coeffs[0] = cubic_convolution2<scalar_t>(x1 + 1.0, A);
  373. coeffs[1] = cubic_convolution1<scalar_t>(x1, A);
  374. // opposite coefficients
  375. scalar_t x2 = 1.0 - t;
  376. coeffs[2] = cubic_convolution1<scalar_t>(x2, A);
  377. coeffs[3] = cubic_convolution2<scalar_t>(x2 + 1.0, A);
  378. }
  379. template <typename scalar_t>
  380. inline scalar_t cubic_interp1d(
  381. scalar_t x0,
  382. scalar_t x1,
  383. scalar_t x2,
  384. scalar_t x3,
  385. scalar_t t) {
  386. scalar_t coeffs[4];
  387. get_cubic_upsample_coefficients<scalar_t>(coeffs, t);
  388. return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
  389. }
  390. // when `real_input_index` becomes larger than the range the floating point
  391. // type can accurately represent, the type casting to `int64_t` might exceed
  392. // `input_size`, causing overflow. So we guard it with `std::min` below.
  393. template<typename scalar_t, typename opmath_t>
  394. inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) {
  395. input_index = std::min(static_cast<int64_t>(floorf(real_input_index)), input_size - 1);
  396. lambda = std::min(
  397. std::max(real_input_index - input_index, static_cast<opmath_t>(0)),
  398. static_cast<opmath_t>(1)
  399. );
  400. }
  401. template<typename scalar_t, typename opmath_t>
  402. inline void compute_source_index_and_lambda(
  403. int64_t& input_index0,
  404. int64_t& input_index1,
  405. scalar_t& lambda0,
  406. scalar_t& lambda1,
  407. opmath_t ratio,
  408. int64_t output_index,
  409. int64_t input_size,
  410. int64_t output_size,
  411. bool align_corners) {
  412. if (output_size == input_size) {
  413. // scale_factor = 1, simply copy
  414. input_index0 = output_index;
  415. input_index1 = output_index;
  416. lambda0 = static_cast<scalar_t>(1);
  417. lambda1 = static_cast<scalar_t>(0);
  418. } else {
  419. const auto real_input_index =
  420. area_pixel_compute_source_index<opmath_t>(
  421. ratio, output_index, align_corners, /*cubic=*/false);
  422. guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1);
  423. int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
  424. input_index1 = input_index0 + offset;
  425. lambda0 = static_cast<scalar_t>(1.) - lambda1;
  426. }
  427. }
  428. // It will not be used by data types other than BFloat16 and Half.
  429. template <typename scalar_in, typename scalar_out,
  430. typename std::enable_if_t<!is_reduced_floating_point_v<scalar_out> || !std::is_same_v<scalar_in, float>, int> = 0>
  431. void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
  432. TORCH_CHECK((is_reduced_floating_point_v<scalar_out>),
  433. "Upsample backward only support BFloat16 and Half in the lower precision data types on CPU.")
  434. TORCH_CHECK((std::is_same_v<scalar_in, float>),
  435. "Upsample backward should use float as acc buffer for BFloat16 and Half grad input on CPU.")
  436. return;
  437. }
  438. template <typename scalar_in, typename scalar_out,
  439. typename std::enable_if_t<is_reduced_floating_point_v<scalar_out> && std::is_same_v<scalar_in, float>, int> = 0>
  440. void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) {
  441. using bVec = Vectorized<scalar_out>;
  442. using fVec = Vectorized<float>;
  443. int64_t d = 0;
  444. for (; d < size - (size % bVec::size()); d += bVec::size()) {
  445. bVec gin_bvec = bVec::loadu(gin + d);
  446. auto [gin_fvec0, gin_fvec1] = convert_to_float<scalar_out>(gin_bvec);
  447. gin_fvec0 += fVec::loadu(buffer_ptr + d);
  448. gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size());
  449. fVec(0).store(buffer_ptr + d);
  450. fVec(0).store(buffer_ptr + d + fVec::size());
  451. convert_from_float<scalar_out>(gin_fvec0, gin_fvec1).store(gin + d);
  452. }
  453. for (; d < size; d++) {
  454. gin[d] += buffer_ptr[d];
  455. buffer_ptr[d] = 0;
  456. }
  457. }
  458. } // namespace at::native
  459. #else
  460. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  461. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)