DeviceGuard.h 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/IListRef.h>
  4. #include <ATen/core/Tensor.h>
  5. #include <c10/core/DeviceGuard.h>
  6. #include <c10/core/ScalarType.h> // TensorList whyyyyy
  7. namespace at {
  8. // Are you here because you're wondering why DeviceGuard(tensor) no
  9. // longer works? For code organization reasons, we have temporarily(?)
  10. // removed this constructor from DeviceGuard. The new way to
  11. // spell it is:
  12. //
  13. // OptionalDeviceGuard guard(device_of(tensor));
  14. /// Return the Device of a Tensor, if the Tensor is defined.
  15. inline std::optional<Device> device_of(const Tensor& t) {
  16. if (t.defined()) {
  17. return t.device();
  18. } else {
  19. return std::nullopt;
  20. }
  21. }
  22. inline std::optional<Device> device_of(const std::optional<Tensor>& t) {
  23. return t.has_value() ? device_of(t.value()) : std::nullopt;
  24. }
  25. /// Return the Device of a TensorList, if the list is non-empty and
  26. /// the first Tensor is defined. (This function implicitly assumes
  27. /// that all tensors in the list have the same device.)
  28. inline std::optional<Device> device_of(ITensorListRef t) {
  29. if (!t.empty()) {
  30. return device_of(t.front());
  31. } else {
  32. return std::nullopt;
  33. }
  34. }
  35. } // namespace at
  36. #else
  37. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  38. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)