Repeat.h 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/Tensor.h>
  4. #include <ATen/TensorOperators.h>
  5. #ifndef AT_PER_OPERATOR_HEADERS
  6. #include <ATen/Functions.h>
  7. #else
  8. #include <ATen/ops/empty.h>
  9. #include <ATen/ops/empty_like.h>
  10. #endif
  11. namespace at::native {
  12. template <
  13. typename index_t,
  14. void compute(const index_t*, const int64_t*, index_t*, int64_t, int64_t)>
  15. static inline Tensor repeat_interleave_common(
  16. const Tensor& repeats,
  17. std::optional<int64_t> output_size) {
  18. TORCH_CHECK(
  19. repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
  20. TORCH_CHECK(
  21. repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
  22. "repeats has to be Long or Int tensor");
  23. if (repeats.size(0) == 0) {
  24. return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  25. }
  26. Tensor repeats_ = repeats.contiguous();
  27. Tensor cumsum = repeats.cumsum(0);
  28. int64_t total = 0;
  29. if (output_size.has_value()) {
  30. total = output_size.value();
  31. } else {
  32. total = cumsum[-1].item<int64_t>();
  33. TORCH_CHECK(
  34. (repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
  35. }
  36. Tensor result = at::empty({total}, repeats.options());
  37. const index_t* repeat_ptr = repeats_.const_data_ptr<index_t>();
  38. const int64_t* cumsum_ptr = cumsum.const_data_ptr<int64_t>();
  39. index_t* result_ptr = result.data_ptr<index_t>();
  40. compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
  41. return result;
  42. }
  43. } // namespace at::native
  44. #else
  45. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  46. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)