// Padding.h
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/irange.h>
  5. namespace at::native {
  6. using padding_fn = void (*)(const Tensor&, const Tensor&, IntArrayRef);
  7. // reflection padding
  8. DECLARE_DISPATCH(padding_fn, reflection_pad1d_kernel)
  9. DECLARE_DISPATCH(padding_fn, reflection_pad1d_backward_kernel)
  10. DECLARE_DISPATCH(padding_fn, reflection_pad2d_kernel)
  11. DECLARE_DISPATCH(padding_fn, reflection_pad2d_backward_kernel)
  12. DECLARE_DISPATCH(padding_fn, reflection_pad3d_kernel)
  13. DECLARE_DISPATCH(padding_fn, reflection_pad3d_backward_kernel)
  14. // replication padding
  15. DECLARE_DISPATCH(padding_fn, replication_pad1d_kernel)
  16. DECLARE_DISPATCH(padding_fn, replication_pad1d_backward_kernel)
  17. DECLARE_DISPATCH(padding_fn, replication_pad2d_kernel)
  18. DECLARE_DISPATCH(padding_fn, replication_pad2d_backward_kernel)
  19. DECLARE_DISPATCH(padding_fn, replication_pad3d_kernel)
  20. DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel)
  21. namespace padding {
  22. template <int dim>
  23. inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
  24. TORCH_CHECK(padding.size() == 2 * dim,
  25. "padding size is expected to be ", 2 * dim,
  26. ", but got: ", padding.size());
  27. int input_dim = input.dim();
  28. bool is_batch_mode = input_dim == (dim + 2);
  29. bool is_non_batch_mode = input_dim == (dim + 1);
  30. bool valid_batch_mode = is_batch_mode;
  31. bool valid_non_batch_mode = is_non_batch_mode;
  32. if (is_batch_mode) {
  33. // allow batch size of 0-dim.
  34. for (const auto d : c10::irange(1, input_dim)) {
  35. valid_batch_mode = valid_batch_mode && input.size(d) != 0;
  36. }
  37. } else {
  38. for (const auto d : c10::irange(0, input_dim)) {
  39. valid_non_batch_mode = valid_non_batch_mode && input.size(d) != 0;
  40. }
  41. }
  42. // allow empty batch size but not other dimensions.
  43. TORCH_CHECK(valid_batch_mode || valid_non_batch_mode,
  44. "Expected ", dim + 1, "D or ", dim + 2,
  45. "D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
  46. input.sizes());
  47. }
  48. } // namespace padding
  49. } // at::native
  50. #else
  51. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  52. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)