// IndexKernel.h (1.9 KB)
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/native/DispatchStub.h>
  4. #include <c10/util/ArrayRef.h>
  // Forward declarations of ATen types. The kernel signatures in at::native
  // refer to TensorIteratorBase, TensorIterator and TensorBase only through
  // references, so their full definitions are not required in this header.
  // NOTE(review): at::Tensor is not used by any declaration visible here; it is
  // kept in case includers rely on this forward declaration.
  namespace at {
  class Tensor;
  class TensorBase;
  struct TensorIterator;
  struct TensorIteratorBase;
  } // namespace at
  // Forward declaration of c10::Scalar; the fill kernel signatures below take a
  // Scalar only by const reference, so the full definition is not needed here.
  // NOTE(review): those signatures spell `Scalar` unqualified inside
  // at::native; this presumably resolves to c10::Scalar because c10 names are
  // made visible in namespace at elsewhere — confirm in the c10 headers.
  namespace c10 {
  class Scalar;
  } // namespace c10
  14. namespace at::native {
  15. using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
  16. using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
  17. using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
  18. using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
  19. using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate);
  20. using take_fn = void(*)(TensorIterator & iter, const TensorBase& input);
  21. using flip_fn = void(*)(TensorIterator &, const bool);
  22. using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar);
  23. using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride);
  24. using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &);
  25. DECLARE_DISPATCH(index_fn, index_stub)
  26. DECLARE_DISPATCH(index_fill_fn, index_fill_stub)
  27. DECLARE_DISPATCH(index_copy_fn, index_copy_stub)
  28. DECLARE_DISPATCH(index_put_fn, index_put_stub)
  29. DECLARE_DISPATCH(put_fn, put_stub)
  30. DECLARE_DISPATCH(take_fn, take_stub)
  31. DECLARE_DISPATCH(flip_fn, flip_stub)
  32. DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub)
  33. DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub)
  34. DECLARE_DISPATCH(masked_select_fn, masked_select_stub)
  35. DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub)
  36. } // namespace at::native
  37. #else
  38. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  39. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)