ExpandBase.h 1.1 KB

1234567891011121314151617181920212223242526272829303132333435
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #include <ATen/core/TensorBase.h>
  3. // Broadcasting utilities for working with TensorBase
  4. namespace at {
  5. namespace internal {
  6. TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size);
  7. } // namespace internal
  8. inline c10::MaybeOwned<TensorBase> expand_size(
  9. const TensorBase& self,
  10. IntArrayRef size) {
  11. if (size.equals(self.sizes())) {
  12. return c10::MaybeOwned<TensorBase>::borrowed(self);
  13. }
  14. return c10::MaybeOwned<TensorBase>::owned(
  15. at::internal::expand_slow_path(self, size));
  16. }
  17. c10::MaybeOwned<TensorBase> expand_size(TensorBase&& self, IntArrayRef size) =
  18. delete;
  19. inline c10::MaybeOwned<TensorBase> expand_inplace(
  20. const TensorBase& tensor,
  21. const TensorBase& to_expand) {
  22. return expand_size(to_expand, tensor.sizes());
  23. }
  24. c10::MaybeOwned<TensorBase> expand_inplace(
  25. const TensorBase& tensor,
  26. TensorBase&& to_expand) = delete;
  27. } // namespace at
  28. #else
  29. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  30. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)