| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
- #include <ATen/core/Tensor.h>
- #include <ATen/TensorUtils.h>
- #include <c10/util/irange.h>
- namespace at::native {
- template<typename scalar_t>
- inline std::vector<int64_t> generate_intervals(
- scalar_t sample,
- int64_t inputSize,
- int64_t outputSize,
- int64_t poolSize) {
- std::vector<int64_t> sequence(outputSize);
- if (outputSize > 1) {
- scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
- static_cast<scalar_t>(outputSize - 1);
- for (const auto i : c10::irange(outputSize - 1)) {
- sequence[i] =
- static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
- }
- }
- if (outputSize > 0) {
- sequence[outputSize - 1] = inputSize - poolSize;
- }
- return sequence;
- }
- template <int64_t ndim>
- inline void fractional_max_pool_check_shape(
- const Tensor& input,
- const Tensor& randomSamples) {
- TORCH_CHECK(
- input.scalar_type() == randomSamples.scalar_type(),
- "Expect _random_samples to have the same dtype as input");
- int64_t ndimension = randomSamples.ndimension();
- TORCH_CHECK(
- ndimension == 3,
- "Expect _random_samples to have 3 dimensions, got ", ndimension);
- int64_t N = randomSamples.size(0);
- int64_t C = randomSamples.size(1);
- int64_t D = randomSamples.size(2);
- int64_t input_batch = 0, input_channel = 0;
- if (ndim == 2) {
- // fractional_max_pool2d
- if (input.ndimension() == 3) {
- input_batch = 1;
- input_channel = input.size(0);
- } else {
- input_batch = input.size(0);
- input_channel = input.size(1);
- }
- } else {
- // factional_max_pool3d
- if (input.ndimension() == 4) {
- input_batch = 1;
- input_channel = input.size(0);
- } else {
- input_batch = input.size(0);
- input_channel = input.size(1);
- }
- }
- TORCH_CHECK(
- N >= input_batch,
- "Expect _random_samples.size(0) no less then input batch size.");
- TORCH_CHECK(
- C == input_channel,
- "Expect _random_samples.size(1) equals to input channel size.");
- TORCH_CHECK(
- D == ndim,
- "Expect _random_samples.size(2) equals to ", ndim, "; got ", D, ".");
- }
- } // namespace at::native
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|