Functions.cpp 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. #include <array>
  2. #include <ATen/Functions.h>
  3. #include <ATen/Utils.h>
  4. #include <c10/core/Allocator.h>
  5. namespace at {
  6. Tensor TensorMaker::make_tensor() {
  7. AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
  8. tracer::impl::NoTracerDispatchMode tracer_guard{};
  9. check_size_nonnegative(sizes_);
  10. TORCH_CHECK_VALUE(
  11. !deleter_ || !ctx_,
  12. "The deleter and context arguments are mutually exclusive.");
  13. if (device_ == std::nullopt) {
  14. device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type());
  15. }
  16. if (opts_.device().has_index()) {
  17. // clang-format off
  18. TORCH_CHECK_VALUE(
  19. opts_.device() == *device_,
  20. "Specified device ", opts_.device(), " does not match device of data ", *device_);
  21. // clang-format on
  22. }
  23. std::size_t size_bytes = computeStorageSize();
  24. DataPtr data_ptr{};
  25. if (deleter_) {
  26. data_ptr = makeDataPtrFromDeleter();
  27. } else {
  28. data_ptr = makeDataPtrFromContext();
  29. }
  30. TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()");
  31. Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizable=*/resizeable_};
  32. Tensor tensor = detail::make_tensor<TensorImpl>(
  33. std::move(storage), opts_.computeDispatchKey(), opts_.dtype());
  34. TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
  35. if (strides_) {
  36. tensor_impl->set_sizes_and_strides(sizes_, *strides_);
  37. } else {
  38. tensor_impl->set_sizes_contiguous(sizes_);
  39. }
  40. if (storage_offset_) {
  41. tensor_impl->set_storage_offset(*storage_offset_);
  42. }
  43. tensor_impl->set_requires_grad(opts_.requires_grad());
  44. return tensor;
  45. }
  46. std::size_t TensorMaker::computeStorageSize() const noexcept {
  47. std::size_t itemsize = opts_.dtype().itemsize();
  48. if (strides_) {
  49. auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
  50. if (storage_offset_) {
  51. storage_size += storage_offset_.value() * itemsize;
  52. }
  53. return storage_size;
  54. }
  55. std::size_t size = 1;
  56. for (std::int64_t s : sizes_) {
  57. size *= static_cast<std::size_t>(s);
  58. }
  59. auto storage_size = size * itemsize;
  60. if (storage_offset_) {
  61. storage_size += storage_offset_.value() * itemsize;
  62. }
  63. return storage_size;
  64. }
  65. inline DataPtr TensorMaker::makeDataPtrFromDeleter() noexcept {
  66. return InefficientStdFunctionContext::makeDataPtr(data_, std::move(deleter_), *device_);
  67. }
  68. inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept {
  69. return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_};
  70. }
  71. IntArrayRef TensorMaker::makeTempSizes() const noexcept {
  72. static std::int64_t zeros[5] = {0, 0, 0, 0, 0};
  73. if (opts_.has_memory_format()) {
  74. MemoryFormat format = *opts_.memory_format_opt();
  75. if (format == MemoryFormat::ChannelsLast) {
  76. return IntArrayRef(zeros, 4);
  77. }
  78. if (format == MemoryFormat::ChannelsLast3d) {
  79. return IntArrayRef(zeros, 5);
  80. }
  81. }
  82. return IntArrayRef(zeros, 1);
  83. }
  84. } // namespace at