// PixelShuffle.h
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #include <ATen/core/Tensor.h>
  3. #include <c10/util/Exception.h>
  4. namespace at::native {
  5. inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
  6. TORCH_CHECK(self.dim() >= 3,
  7. "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
  8. self.dim(), " dimension(s)");
  9. TORCH_CHECK(upscale_factor > 0,
  10. "pixel_shuffle expects a positive upscale_factor, but got ",
  11. upscale_factor);
  12. int64_t c = self.size(-3);
  13. TORCH_CHECK_VALUE(upscale_factor <= std::numeric_limits<decltype(upscale_factor)>::max() / upscale_factor,
  14. "upscale factor is too large, (upscale_factor)^2 overflowed: upscale_factor=", upscale_factor);
  15. int64_t upscale_factor_squared = upscale_factor * upscale_factor;
  16. TORCH_CHECK(c % upscale_factor_squared == 0,
  17. "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
  18. "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
  19. }
  20. inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
  21. TORCH_CHECK(
  22. self.dim() >= 3,
  23. "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
  24. self.dim(),
  25. " dimension(s)");
  26. TORCH_CHECK(
  27. downscale_factor > 0,
  28. "pixel_unshuffle expects a positive downscale_factor, but got ",
  29. downscale_factor);
  30. int64_t h = self.size(-2);
  31. int64_t w = self.size(-1);
  32. TORCH_CHECK(
  33. h % downscale_factor == 0,
  34. "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=",
  35. h,
  36. " is not divisible by ",
  37. downscale_factor);
  38. TORCH_CHECK(
  39. w % downscale_factor == 0,
  40. "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=",
  41. w,
  42. " is not divisible by ",
  43. downscale_factor);
  44. }
  45. } // namespace at::native
  46. #else
  47. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  48. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)