LegacyBatchedFallback.h 1.2 KB

123456789101112131415161718192021222324252627282930
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/ATen.h>
  4. #include <ATen/core/op_registration/op_registration.h>
  5. #include <torch/library.h>
  6. namespace at {
  7. // If an operator doesn't have a batching rule implemented then we fallback
  8. // to this implementation. The fallback only works on out-of-place operators
  9. // that return only tensors with new memory. (e.g., no in-place operators, no
  10. // view operations).
  11. //
  12. // The fallback effectively takes all of the BatchedTensors in `stack`, slices
  13. // them, and runs `op` on all of the corresponding slices to produce slices
  14. // of the outputs. The output slices then get `torch.stack`ed to create the
  15. // final returns.
  16. //
  17. // The performance of the fallback is not very good because it introduces an
  18. // extra copy from stacking the sliced outputs. Because of this, we prefer to
  19. // write batching rules for operators whenever possible.
  20. void batchedTensorForLoopFallback(
  21. const c10::OperatorHandle& op,
  22. torch::jit::Stack* stack);
  23. } // namespace at
  24. #else
  25. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  26. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)