common.h 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. // Set of global constants that could be shareable between CPU and Metal code
  4. #ifdef __METAL__
  5. #include <metal_array>
  6. #define C10_METAL_CONSTEXPR constant constexpr
  7. #else
  8. #include <array>
  9. #define C10_METAL_CONSTEXPR constexpr
  10. #endif
  11. #define C10_METAL_ALL_TYPES_FUNCTOR(_) \
  12. _(Byte, 0) \
  13. _(Char, 1) \
  14. _(Short, 2) \
  15. _(Int, 3) \
  16. _(Long, 4) \
  17. _(Half, 5) \
  18. _(Float, 6) \
  19. _(ComplexHalf, 8) \
  20. _(ComplexFloat, 9) \
  21. _(Bool, 11) \
  22. _(BFloat16, 15)
  23. namespace c10 {
  24. namespace metal {
  25. C10_METAL_CONSTEXPR unsigned max_ndim = 16;
  26. C10_METAL_CONSTEXPR unsigned simdgroup_size = 32;
  27. #ifdef __METAL__
  28. template <typename T, unsigned N>
  29. using array = ::metal::array<T, N>;
  30. #else
  31. template <typename T, unsigned N>
  32. using array = std::array<T, N>;
  33. #endif
  34. enum class ScalarType {
  35. #define _DEFINE_ENUM_VAL_(_v, _n) _v = _n,
  36. C10_METAL_ALL_TYPES_FUNCTOR(_DEFINE_ENUM_VAL_)
  37. #undef _DEFINE_ENUM_VAL_
  38. };
  39. } // namespace metal
  40. } // namespace c10
  41. #else
  42. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  43. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)