MaxPooling.h 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/Tensor.h>
  4. #include <ATen/Parallel.h>
  5. #include <ATen/native/DispatchStub.h>
  6. #include <ATen/native/Pool.h>
  7. namespace at::native {
  8. inline void check_max_pool1d(
  9. const Tensor& self,
  10. IntArrayRef kernel_size,
  11. IntArrayRef stride,
  12. IntArrayRef padding,
  13. IntArrayRef dilation,
  14. bool ceil_mode) {
  15. TORCH_CHECK(
  16. self.dim() == 2 || self.dim() == 3,
  17. "max_pool1d() Expected 2D or 3D input tensor, but got ", self.sym_sizes());
  18. TORCH_CHECK(
  19. kernel_size.size() == 1,
  20. "max_pool1d() kernel_size must be an int, list of ints or tuple of ints of size 1 but got size ",
  21. kernel_size.size());
  22. TORCH_CHECK(
  23. stride.empty() || stride.size() == 1,
  24. "max_pool1d() stride must be None, an int, list of ints, or tuple of ints of size 1 but got size ",
  25. stride.size());
  26. TORCH_CHECK(
  27. padding.size() == 1,
  28. "max_pool1d() padding must be an int, list of ints, or tuple of ints of size 1 but got size ",
  29. padding.size());
  30. TORCH_CHECK(
  31. dilation.size() == 1,
  32. "max_pool1d() dilation must be an int, list of ints or tuple of ints of size 1 but got size ",
  33. dilation.size());
  34. // If stride=None then set it to kernel_size
  35. if (stride.empty()) {
  36. stride = kernel_size;
  37. }
  38. TORCH_CHECK(
  39. kernel_size[0] > 0,
  40. "max_pool1d() kernel_size must be greater than zero, but got ",
  41. kernel_size[0]);
  42. TORCH_CHECK(
  43. stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
  44. TORCH_CHECK(
  45. padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
  46. TORCH_CHECK(
  47. padding[0] <= kernel_size[0] / 2,
  48. "max_pool1d() padding should be at most half of kernel size, but got padding=",
  49. padding[0],
  50. " and kernel_size=",
  51. kernel_size[0]);
  52. TORCH_CHECK(
  53. dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
  54. const int64_t OW = pooling_output_shape(self.sym_size(-1).guard_int(__FILE__, __LINE__), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
  55. TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
  56. }
  57. // TODO(Heitor) Template by dimension
  58. struct PoolingParams1D {
  59. int64_t NB; // Number of batches
  60. int64_t NC; // Number of channels
  61. int64_t IW; // Input width
  62. int64_t OW; // Output width
  63. int64_t KW; // Kernel width
  64. int64_t SJ; // Column stride
  65. int64_t PJ; // Column padding
  66. int64_t DJ; // Column dilation
  67. // Return index of input element for the given kernel and output index
  68. inline int64_t index(int64_t kj, int64_t oj) const {
  69. return oj * SJ + kj * DJ - PJ;
  70. }
  71. // Return index of first output within bounds for this kernel index
  72. inline int64_t valid_output_start(int64_t kj) const {
  73. int64_t ij = index(kj, 0);;
  74. return ij < 0 ? at::divup(-ij, SJ) : 0;
  75. }
  76. // Return index one past last output within bounds for this kernel index
  77. inline int64_t valid_output_end(int64_t kj) const {
  78. int64_t ij = index(kj, OW - 1);
  79. return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
  80. }
  81. };
// Function-pointer type for a max_pool1d kernel. The first Tensor is taken by
// non-const reference and the second by const reference, so presumably they
// are the output and the input respectively (TODO confirm against the kernel
// implementation); PoolingParams1D carries the pooling geometry.
using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);
// Declares the dispatch stub `max_pool1d_stub` for kernels of type pooling_fn
// (DECLARE_DISPATCH comes from ATen/native/DispatchStub.h). The stub's
// definition and any per-backend kernel registrations are not in this header.
DECLARE_DISPATCH(pooling_fn, max_pool1d_stub)
  84. } // namespace at::native
  85. #else
  86. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  87. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)