GridSampler.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <algorithm>
  4. #include <cmath>
  5. #include <cstdint>
  6. #include <utility>
  7. #include <ATen/native/GridSamplerUtils.h>
  8. namespace at::native {
  9. using detail::GridSamplerInterpolation;
  10. using detail::GridSamplerPadding;
  // Maps a normalized grid coordinate in [-1, 1] to a pixel-index coordinate,
  // treating pixel `idx` as the interval [idx - 0.5, idx + 0.5].
  //
  //   align_corners == true : -1 and +1 land on the centres of the corner
  //                           pixels, i.e. [-1, 1] -> [0, size - 1]
  //                           (scale factor (size - 1) / 2).
  //   align_corners == false: -1 and +1 land on the outer image edges,
  //                           i.e. [-1, 1] -> [-0.5, size - 0.5]
  //                           (scale factor size / 2).
  template <typename scalar_t>
  static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size,
                                                  bool align_corners) {
    const scalar_t shifted = coord + 1;  // [-1, 1] -> [0, 2]
    if (!align_corners) {
      // [0, 2] -> [-0.5, size - 0.5]
      return (shifted * size - 1) / 2;
    }
    // [0, 2] -> [0, size - 1]
    return (shifted / 2) * (size - 1);
  }
  // Same mapping as grid_sampler_unnormalize, but additionally stores the
  // (constant) derivative d(output)/d(coord) in *grad_in, for use by the
  // grid_sampler backward pass.
  //   align_corners == true : derivative is (size - 1) / 2
  //   align_corners == false: derivative is size / 2
  template <typename scalar_t>
  static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size,
                                                           bool align_corners, scalar_t *grad_in) {
    if (!align_corners) {
      // [-1, 1] -> [-0.5, size - 0.5]
      *grad_in = static_cast<scalar_t>(size) / 2;
      return ((coord + 1) * size - 1) / 2;
    }
    // [-1, 1] -> [0, size - 1]
    *grad_in = static_cast<scalar_t>(size - 1) / 2;
    return ((coord + 1) / 2) * (size - 1);
  }
  // Clamps `in` to the closed range [0, clip_limit - 1].
  // The lower bound is applied first and the upper bound second; with the
  // std::max/std::min argument order used here, a NaN input ends up as
  // clip_limit - 1.
  template<typename scalar_t>
  static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
    const scalar_t upper = static_cast<scalar_t>(clip_limit - 1);
    const scalar_t lower_clamped = std::max(in, static_cast<scalar_t>(0));
    return std::min(upper, lower_clamped);
  }
  // Clamps `in` to [0, clip_limit - 1] like clip_coordinates, and stores
  // d(output)/d(in) in *grad_in for the grid_sampler backward pass.
  // The borders themselves count as out of bounds for the gradient:
  // *grad_in is 0 when in <= 0 or in >= clip_limit - 1, and 1 strictly inside.
  template<typename scalar_t>
  static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit,
                                                   scalar_t *grad_in) {
    const scalar_t zero = static_cast<scalar_t>(0);
    const scalar_t upper = static_cast<scalar_t>(clip_limit - 1);
    if (in <= zero) {
      *grad_in = zero;
      return zero;
    }
    if (in >= upper) {
      *grad_in = zero;
      return upper;
    }
    // Strictly inside the range: identity map, unit gradient.
    *grad_in = static_cast<scalar_t>(1);
    return in;
  }
  // Reflects `in` back and forth about the bounds low and high until it falls
  // inside [low, high] (inclusive). The bounds are passed doubled
  // (twice_low = 2 * low, twice_high = 2 * high) so that half-integer bounds
  // such as -0.5 can be expressed as integers.
  // Returns 0 when the two bounds coincide.
  template<typename scalar_t>
  static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low,
                                             int64_t twice_high) {
    if (twice_low == twice_high) {
      return static_cast<scalar_t>(0);
    }
    scalar_t min = static_cast<scalar_t>(twice_low) / 2;
    scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
    in = std::fabs(in - min);
    // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
    scalar_t extra = std::fmod(in, span);
    // Only the parity of the number of reflections matters. Keep it in
    // floating point: casting floor(in / span) to int is undefined behaviour
    // when the quotient is NaN, infinite or larger than INT_MAX (e.g. for a
    // NaN/inf or very large input coordinate). fmod(flips, 2) is exact and
    // yields the same parity wherever the old int cast was defined.
    const scalar_t flips = std::floor(in / span);
    if (std::fmod(flips, static_cast<scalar_t>(2)) == static_cast<scalar_t>(0)) {
      return extra + min;
    } else {
      return span - extra + min;
    }
  }
  // Same reflection as reflect_coordinates (bounds passed doubled), and also
  // stores d(output)/d(in) in *grad_in for the grid_sampler backward pass.
  // The gradient is +1 or -1 depending on the side of `min` the input started
  // on and on the parity of the number of reflections. It is 0 when the
  // bounds coincide.
  template<typename scalar_t>
  static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low,
                                                      int64_t twice_high, scalar_t *grad_in) {
    if (twice_low == twice_high) {
      *grad_in = static_cast<scalar_t>(0);
      return static_cast<scalar_t>(0);
    }
    scalar_t min = static_cast<scalar_t>(twice_low) / 2;
    scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
    scalar_t shifted = in - min;
    // Derivative of |shifted| w.r.t. `in`: -1 below `min`, +1 otherwise.
    int sign = 1;
    if (shifted < static_cast<scalar_t>(0)) {
      sign = -1;
      shifted = -shifted;
    }
    // `fmod` returns same sign as `shifted`, which is non-negative here.
    scalar_t extra = std::fmod(shifted, span);
    // Parity of the reflection count, kept in floating point: casting
    // floor(shifted / span) to int is undefined behaviour for NaN, infinite or
    // > INT_MAX quotients. fmod(flips, 2) is exact and matches the old int
    // parity wherever that cast was defined.
    const scalar_t flips = std::floor(shifted / span);
    if (std::fmod(flips, static_cast<scalar_t>(2)) == static_cast<scalar_t>(0)) {
      *grad_in = static_cast<scalar_t>(sign);
      return extra + min;
    } else {
      *grad_in = static_cast<scalar_t>(-sign);
      return span - extra + min;
    }
  }
  129. // Mapping the out-of-boundary points back into boundary
  130. // This would only affect padding_mode=border or reflection
  131. template<typename scalar_t>
  132. static inline scalar_t compute_coordinates(scalar_t coord, int64_t size,
  133. GridSamplerPadding padding_mode,
  134. bool align_corners) {
  135. if (padding_mode == GridSamplerPadding::Border) {
  136. // clip coordinates to image borders
  137. coord = clip_coordinates(coord, size);
  138. } else if (padding_mode == GridSamplerPadding::Reflection) {
  139. // reflect coordinates by image borders
  140. if (align_corners) {
  141. coord = reflect_coordinates(coord, 0, 2*(size - 1));
  142. } else {
  143. coord = reflect_coordinates(coord, -1, 2*size - 1);
  144. }
  145. // clip coordinates to image borders
  146. coord = clip_coordinates(coord, size);
  147. }
  148. return coord;
  149. }
  150. // Computes the pixel source index value for a grid coordinate
  151. template <typename scalar_t>
  152. static inline scalar_t grid_sampler_compute_source_index(
  153. scalar_t coord,
  154. int64_t size,
  155. GridSamplerPadding padding_mode,
  156. bool align_corners) {
  157. coord = grid_sampler_unnormalize(coord, size, align_corners);
  158. coord = compute_coordinates(coord, size, padding_mode, align_corners);
  159. return coord;
  160. }
  161. // grid_sampler_compute_source_index_set_grad works similarly to
  162. // grid_sampler_compute_source_index except that it also returns the
  163. // `d output / d input` via pointer argument `grad_in`.
  164. // This is useful in the backward pass of grid_sampler.
  165. template <typename scalar_t>
  166. static inline scalar_t grid_sampler_compute_source_index_set_grad(
  167. scalar_t coord,
  168. int64_t size,
  169. GridSamplerPadding padding_mode,
  170. bool align_corners,
  171. scalar_t *grad_in) {
  172. scalar_t grad_clip, grad_refl;
  173. coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
  174. if (padding_mode == GridSamplerPadding::Border) {
  175. // clip coordinates to image borders
  176. coord = clip_coordinates_set_grad(coord, size, &grad_clip);
  177. *grad_in = (*grad_in) * grad_clip;
  178. } else if (padding_mode == GridSamplerPadding::Reflection) {
  179. // reflect coordinates by image borders
  180. if (align_corners) {
  181. coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
  182. } else {
  183. coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
  184. }
  185. // clip coordinates to image borders
  186. coord = clip_coordinates_set_grad(coord, size, &grad_clip);
  187. *grad_in = (*grad_in) * grad_refl * grad_clip;
  188. }
  189. return coord;
  190. }
  // True iff (h, w) is a valid index into an H x W extent:
  // 0 <= h < H and 0 <= w < W.
  static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) {
    const bool row_ok = 0 <= h && h < H;
    const bool col_ok = 0 <= w && w < W;
    return row_ok && col_ok;
  }
  // True iff (d, h, w) is a valid index into a D x H x W extent:
  // 0 <= d < D, 0 <= h < H and 0 <= w < W.
  static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) {
    const bool depth_ok = 0 <= d && d < D;
    const bool row_ok = 0 <= h && h < H;
    const bool col_ok = 0 <= w && w < W;
    return depth_ok && row_ok && col_ok;
  }
  197. template<typename scalar_t>
  198. static inline scalar_t get_value_bounded(
  199. const scalar_t* data,
  200. scalar_t x,
  201. scalar_t y,
  202. int64_t W,
  203. int64_t H,
  204. int64_t sW,
  205. int64_t sH,
  206. GridSamplerPadding padding_mode,
  207. bool align_corners) {
  208. x = compute_coordinates(x, W, padding_mode, align_corners);
  209. y = compute_coordinates(y, H, padding_mode, align_corners);
  210. int64_t ix = static_cast<int64_t>(x);
  211. int64_t iy = static_cast<int64_t>(y);
  212. if (within_bounds_2d(iy, ix, H, W)) {
  213. return data[iy * sH + ix * sW];
  214. }
  215. return static_cast<scalar_t>(0);
  216. }
  217. template<typename scalar_t>
  218. static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
  219. int64_t sH, int64_t sW, int64_t H, int64_t W,
  220. scalar_t delta) {
  221. if (within_bounds_2d(h, w, H, W)) {
  222. data[h * sH + w * sW] += delta;
  223. }
  224. }
  225. template<typename scalar_t>
  226. static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
  227. int64_t sD, int64_t sH, int64_t sW,
  228. int64_t D, int64_t H, int64_t W,
  229. scalar_t delta) {
  230. if (within_bounds_3d(d, h, w, D, H, W)) {
  231. data[d * sD + h * sH + w * sW] += delta;
  232. }
  233. }
  234. template<typename scalar_t>
  235. static inline void add_value_bounded(
  236. scalar_t* data,
  237. scalar_t x,
  238. scalar_t y,
  239. int64_t W,
  240. int64_t H,
  241. int64_t sW,
  242. int64_t sH,
  243. scalar_t delta,
  244. GridSamplerPadding padding_mode,
  245. bool align_corners) {
  246. x = compute_coordinates(x, W, padding_mode, align_corners);
  247. y = compute_coordinates(y, H, padding_mode, align_corners);
  248. int64_t ix = static_cast<int64_t>(x);
  249. int64_t iy = static_cast<int64_t>(y);
  250. safe_add_2d(data, iy, ix, sH, sW, H, W, delta);
  251. }
  // Fills coeffs[0..3] with the derivatives of the four cubic-convolution
  // weights (taps at distances 1 + t, t, 1 - t, 2 - t) for the grid_sampler
  // bicubic backward pass. A = -0.75 must match the forward calculation in
  // aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients.
  // Presumably t is the fractional offset in [0, 1) — TODO confirm at callers.
  // NOTE(review): if the forward weights are the usual Keys kernel evaluated
  // at those distances, these expressions equal -d(weight)/dt. Callers
  // presumably account for the sign — confirm against the backward kernel.
  template<typename scalar_t>
  static inline void get_cubic_coefficients_grad(
      scalar_t coeffs[4],
      scalar_t t) {
    scalar_t A = -0.75;

    // Outer-left tap: 1 < |x0| < 2.
    const scalar_t x0 = -1 - t;
    coeffs[0] = (-3 * A * x0 - 10 * A) * x0 - 8 * A;

    // Inner-left tap: |x1| <= 1.
    const scalar_t x1 = -t;
    coeffs[1] = (-3 * (A + 2) * x1 - 2 * (A + 3)) * x1;

    // Inner-right tap: |x2| <= 1.
    const scalar_t x2 = 1 - t;
    coeffs[2] = (3 * (A + 2) * x2 - 2 * (A + 3)) * x2;

    // Outer-right tap: 1 < |x3| < 2.
    const scalar_t x3 = 2 - t;
    coeffs[3] = (3 * A * x3 - 10 * A) * x3 + 8 * A;
  }
  270. } // namespace at::native
  271. #else
  272. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  273. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)