ScalarOps.h 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/Tensor.h>
  4. #include <c10/core/Scalar.h>
  5. #ifndef AT_PER_OPERATOR_HEADERS
  6. #include <ATen/Functions.h>
  7. #else
  8. #include <ATen/ops/scalar_tensor.h>
  9. #endif
  10. namespace at::detail {
  // When filling a number into a 1-element CPU tensor, we want to skip
  // everything else and manipulate the data pointer directly.
  // Ideally this fast path should be implemented in TensorIterator,
  // but we also want to skip compute_types, which is not avoidable
  // in TensorIterator for now.
  //
  // Takes the tensor to fill (`self`) and the scalar `value` to write, and
  // returns a Tensor reference.
  // NOTE(review): presumably the returned reference is `self` -- confirm
  // against the definition, which is not part of this header.
  Tensor& scalar_fill(Tensor& self, const Scalar& value);
  // Produces a Tensor from the scalar `s`, given an optional dtype and an
  // optional device. Only declared here (marked TORCH_API); the definition
  // lives outside this header.
  // c10::scalar_to_tensor in this file calls it with s.type() and at::kCPU as
  // its fast track for CPU scalar tensors.
  // NOTE(review): the behavior when dtype_opt or device_opt is std::nullopt is
  // not visible from this header -- confirm against the definition.
  TORCH_API Tensor scalar_tensor_static(
      const Scalar& s,
      std::optional<ScalarType> dtype_opt,
      std::optional<Device> device_opt);
  21. } // namespace at::detail
  22. // This is in the c10 namespace because we use ADL to find the functions in it.
  23. namespace c10 {
  24. // FIXME: this should be (and was) Scalar::toTensor, but there is currently no
  25. // way to implement this without going through Derived Types (which are not part
  26. // of core).
  27. inline at::Tensor scalar_to_tensor(
  28. const Scalar& s,
  29. const Device device = at::kCPU) {
  30. // This is the fast track we have for CPU scalar tensors.
  31. if (device == at::kCPU) {
  32. return at::detail::scalar_tensor_static(s, s.type(), at::kCPU);
  33. }
  34. return at::scalar_tensor(s, at::device(device).dtype(s.type()));
  35. }
  36. } // namespace c10
  37. namespace at::native {
  38. inline Tensor wrapped_scalar_tensor(
  39. const Scalar& scalar,
  40. const Device device = at::kCPU) {
  41. auto tensor = scalar_to_tensor(scalar, device);
  42. tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  43. return tensor;
  44. }
  45. } // namespace at::native
  46. #else
  47. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  48. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)