// CompositeViewCopyKernels.cpp (2.0 KB)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
#include <ATen/InferSize.h>
#include <ATen/Tensor.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>

#include <vector>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Operators.h>
#else
#include <ATen/ops/clone.h>
$ops_headers
#endif
  12. namespace at {
  13. namespace native {
  14. // This file contains a number of kernels for aten functions that are fully code-generated.
  15. // TODO: rename this file to something more generic.
  16. namespace {
  17. at::Tensor clone_arg(const at::Tensor& t) {
  18. return t.clone();
  19. }
  20. std::vector<at::Tensor> clone_arg(const at::TensorList& t_list) {
  21. std::vector<at::Tensor> out(t_list.size());
  22. for (const auto& i : c10::irange(t_list.size())) {
  23. out[i] = t_list[i].clone();
  24. }
  25. return out;
  26. }
  27. // duped with gen_resize_out_helper from structured kernels
  28. void copy_arg(const at::Tensor& dst, const at::Tensor& src) {
  29. TORCH_CHECK(src.dtype() == dst.dtype(),
  30. "Expected out tensor to have dtype ", src.dtype(), ", but got ", dst.dtype(), " instead");
  31. TORCH_CHECK(src.device() == dst.device(),
  32. "Expected out tensor to have device ", src.device(), ", but got ", dst.device(), " instead");
  33. dst.copy_(src);
  34. }
  35. void copy_arg(const at::TensorList& dst, const at::TensorList& src) {
  36. TORCH_INTERNAL_ASSERT(dst.size() == src.size());
  37. for (const auto& i : c10::irange(dst.size())) {
  38. copy_arg(dst[i], src[i]);
  39. }
  40. }
  41. // TODO: this doesn't handle restriding empty tensors correctly; see
  42. // gen_resize_out_helper for the correct algorithm
  43. void resize_out_helper(const at::Tensor& dst, const at::Tensor& src) {
  44. at::native::resize_output(dst, src.sizes());
  45. }
  46. void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) {
  47. TORCH_INTERNAL_ASSERT(dst.size() == src.size());
  48. for (const auto& i : c10::irange(dst.size())) {
  49. at::native::resize_output(dst[i], src[i].sizes());
  50. }
  51. }
  52. }
  53. ${CompositeViewCopyKernel_Definitions}
  54. ${GeneratedCompositeFunctional_Definitions}
  55. ${GeneratedCompositeOut_Definitions}
  56. } // namespace native
  57. } // namespace at