FractionalMaxPooling.h 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/Tensor.h>
  4. #include <ATen/TensorUtils.h>
  5. #include <c10/util/irange.h>
  6. namespace at::native {
  7. template<typename scalar_t>
  8. inline std::vector<int64_t> generate_intervals(
  9. scalar_t sample,
  10. int64_t inputSize,
  11. int64_t outputSize,
  12. int64_t poolSize) {
  13. std::vector<int64_t> sequence(outputSize);
  14. if (outputSize > 1) {
  15. scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
  16. static_cast<scalar_t>(outputSize - 1);
  17. for (const auto i : c10::irange(outputSize - 1)) {
  18. sequence[i] =
  19. static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
  20. }
  21. }
  22. if (outputSize > 0) {
  23. sequence[outputSize - 1] = inputSize - poolSize;
  24. }
  25. return sequence;
  26. }
  27. template <int64_t ndim>
  28. inline void fractional_max_pool_check_shape(
  29. const Tensor& input,
  30. const Tensor& randomSamples) {
  31. TORCH_CHECK(
  32. input.scalar_type() == randomSamples.scalar_type(),
  33. "Expect _random_samples to have the same dtype as input");
  34. int64_t ndimension = randomSamples.ndimension();
  35. TORCH_CHECK(
  36. ndimension == 3,
  37. "Expect _random_samples to have 3 dimensions, got ", ndimension);
  38. int64_t N = randomSamples.size(0);
  39. int64_t C = randomSamples.size(1);
  40. int64_t D = randomSamples.size(2);
  41. int64_t input_batch = 0, input_channel = 0;
  42. if (ndim == 2) {
  43. // fractional_max_pool2d
  44. if (input.ndimension() == 3) {
  45. input_batch = 1;
  46. input_channel = input.size(0);
  47. } else {
  48. input_batch = input.size(0);
  49. input_channel = input.size(1);
  50. }
  51. } else {
  52. // factional_max_pool3d
  53. if (input.ndimension() == 4) {
  54. input_batch = 1;
  55. input_channel = input.size(0);
  56. } else {
  57. input_batch = input.size(0);
  58. input_channel = input.size(1);
  59. }
  60. }
  61. TORCH_CHECK(
  62. N >= input_batch,
  63. "Expect _random_samples.size(0) no less then input batch size.");
  64. TORCH_CHECK(
  65. C == input_channel,
  66. "Expect _random_samples.size(1) equals to input channel size.");
  67. TORCH_CHECK(
  68. D == ndim,
  69. "Expect _random_samples.size(2) equals to ", ndim, "; got ", D, ".");
  70. }
  71. } // namespace at::native
  72. #else
  73. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  74. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)