ComplexHelper.h 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/Tensor.h>
  4. #include <c10/core/SymBool.h>
  5. #include <c10/util/irange.h>
  6. #ifndef AT_PER_OPERATOR_HEADERS
  7. #include <ATen/NativeFunctions.h>
  8. #else
  9. #include <ATen/ops/view_as_real_native.h>
  10. #include <ATen/ops/view_as_complex_native.h>
  11. #include <utility>
  12. #endif
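  // WARNING: this header contains non-inline functions (view_as_real and
  // view_as_complex) and should only be included from ONE cpp file; including
  // it from several would produce duplicate definitions of those functions.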
  15. namespace at::native {
  16. // View tensor with new dtype, storage offset, sizes and strides
  17. inline Tensor view_tensor(
  18. const Tensor &tensor, ScalarType dtype,
  19. c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) {
  20. Storage storage = tensor.storage();
  21. auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
  22. auto new_tensor = detail::make_tensor<TensorImpl>(
  23. c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
  24. auto * impl = new_tensor.unsafeGetTensorImpl();
  25. impl->set_sizes_and_strides(sizes, strides, offset);
  26. return new_tensor;
  27. }
  28. inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) {
  29. SymDimVector res(oldstride.size() + 1);
  30. for (const auto i : c10::irange(oldstride.size())) {
  31. res[i] = oldstride[i] * 2;
  32. }
  33. res.back() = 1;
  34. return res;
  35. }
  36. inline Tensor _view_as_real_physical(const Tensor& self) {
  37. TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");
  38. auto old_sizes = self.sym_sizes();
  39. SymDimVector new_sizes(old_sizes.size() + 1);
  40. std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
  41. // last dimension will always have two elements containing the real and imag vals
  42. new_sizes.back() = 2;
  43. auto new_strides = computeStrideForViewAsReal(self.sym_strides());
  44. auto new_storage_offset = self.sym_storage_offset() * 2;
  45. const auto float_type = c10::toRealValueType(self.scalar_type());
  46. auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides);
  47. return real_tensor;
  48. }
  49. // expects as input a complex tensor and returns back a tensor
  50. // with corresponding real dtype containing the complex values
  51. // in the last two dimensions
  52. Tensor view_as_real(const Tensor& self) {
  53. TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
  54. return _view_as_real_physical(self);
  55. }
  56. inline SymDimVector computeStrideForViewAsComplex(
  57. SymIntArrayRef oldstride,
  58. SymIntArrayRef oldsizes) {
  59. const auto dim = oldstride.size();
  60. TORCH_CHECK(dim > 0, "Tensor must have one or more dimensions");
  61. TORCH_SYM_CHECK(oldstride[dim - 1].sym_eq(1), "Tensor must have a last dimension with stride 1");
  62. SymDimVector res(dim - 1);
  63. for (const auto i : c10::irange(res.size())) {
  64. // Skip divisibility check for singleton dimensions
  65. TORCH_SYM_CHECK(
  66. oldsizes[i].sym_eq(1) | (oldstride[i] % 2).sym_eq(0),
  67. "Tensor must have a stride divisible by 2 for all but last dimension");
  68. res[i] = oldstride[i] / 2;
  69. }
  70. return res;
  71. }
  72. // expects as input a float or double tensor with last dimension of size 2
  73. // and returns back a tensor with corresponding complex dtype
  74. Tensor view_as_complex(const Tensor& self) {
  75. TORCH_CHECK(
  76. self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf,
  77. "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type());
  78. auto old_sizes = self.sym_sizes();
  79. TORCH_CHECK(!old_sizes.empty(), "Input tensor must have one or more dimensions");
  80. TORCH_SYM_CHECK(old_sizes[old_sizes.size()-1].sym_eq(2), "Tensor must have a last dimension of size 2");
  81. SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1);
  82. const auto new_strides = computeStrideForViewAsComplex(self.sym_strides(), self.sym_sizes());
  83. const auto complex_type = c10::toComplexType(self.scalar_type());
  84. TORCH_SYM_CHECK((self.sym_storage_offset() % 2).sym_eq(0), "Tensor must have a storage_offset divisible by 2");
  85. const auto new_storage_offset = self.sym_storage_offset() / 2;
  86. return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides);
  87. }
  88. } // namespace at::native
  89. #else
  90. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  91. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)