WrapDimUtilsMulti.h 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/WrapDimUtils.h>
  4. #include <c10/core/TensorImpl.h>
  5. #include <c10/util/irange.h>
  6. #include <bitset>
  7. #include <sstream>
  8. namespace at {
  9. // This is in an extra file to work around strange interaction of
  10. // bitset on Windows with operator overloading
  // Maximum tensor rank that the dim bitsets in this header can represent.
  // Declared `inline constexpr` (C++17) so the constant has external linkage and
  // every translation unit refers to the same object. A plain namespace-scope
  // `constexpr` has internal linkage, which technically violates the ODR when an
  // inline function in this header odr-uses it. dim_list_to_bitset passes it as
  // a TORCH_CHECK message argument, which presumably binds it by const
  // reference — TODO confirm against c10's TORCH_CHECK implementation.
  inline constexpr size_t dim_bitset_size = 64;
  12. inline std::bitset<dim_bitset_size> dim_list_to_bitset(
  13. OptionalIntArrayRef opt_dims,
  14. size_t ndims) {
  15. TORCH_CHECK(
  16. ndims <= dim_bitset_size,
  17. "only tensors with up to ",
  18. dim_bitset_size,
  19. " dims are supported");
  20. std::bitset<dim_bitset_size> seen;
  21. if (opt_dims.has_value()) {
  22. auto dims = opt_dims.value();
  23. for (const auto i : c10::irange(dims.size())) {
  24. size_t dim = maybe_wrap_dim(dims[i], static_cast<int64_t>(ndims));
  25. TORCH_CHECK(
  26. !seen[dim],
  27. "dim ",
  28. dim,
  29. " appears multiple times in the list of dims");
  30. seen[dim] = true;
  31. }
  32. } else {
  33. for (size_t dim = 0; dim < ndims; dim++) {
  34. seen[dim] = true;
  35. }
  36. }
  37. return seen;
  38. }
  39. } // namespace at
  40. #else
  41. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  42. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)