// SegmentReduce.h
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/native/DispatchStub.h>
  4. #include <ATen/native/ReductionType.h>
  5. #include <c10/core/Scalar.h>
  6. #include <optional>
  7. namespace at {
  8. class Tensor;
  9. namespace native {
  10. using segment_reduce_lengths_fn = Tensor (*)(
  11. ReductionType,
  12. const Tensor&,
  13. const Tensor&,
  14. int64_t,
  15. const std::optional<Scalar>&);
  16. DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub)
  17. using segment_reduce_offsets_fn = Tensor (*)(
  18. ReductionType,
  19. const Tensor&,
  20. const Tensor&,
  21. int64_t,
  22. const std::optional<Scalar>&);
  23. DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub)
  24. using segment_reduce_lengths_backward_fn = Tensor (*)(
  25. const Tensor&,
  26. const Tensor&,
  27. const Tensor&,
  28. ReductionType,
  29. const Tensor&,
  30. int64_t,
  31. const std::optional<Scalar>&);
  32. DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub)
  33. using segment_reduce_offsets_backward_fn = Tensor (*)(
  34. const Tensor&,
  35. const Tensor&,
  36. const Tensor&,
  37. ReductionType,
  38. const Tensor&,
  39. int64_t,
  40. const std::optional<Scalar>&);
  41. DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub)
  42. } // namespace native
  43. } // namespace at
  44. #else
  45. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  46. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)