XPUScaledBlas.h 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  #pragma once

  #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  #include <c10/core/Scalar.h>
  #include <c10/core/ScalarType.h>
  #include <c10/util/Exception.h>
  #include <c10/util/SmallVector.h>
  #include <c10/util/typeid.h>
  #include <cstdint>
  #include <vector>
  // NOTE(review): this macro is defined in a header, so it also takes effect in
  // every translation unit that includes this file — confirm that is intended.
  #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
  #include <ATen/Dispatch.h>
  #include <ATen/ExpandUtils.h>
  #include <ATen/OpMathType.h>
  #include <ATen/TensorUtils.h>
  #include <ATen/core/NamedTensor.h>
  #include <ATen/core/Tensor.h>
  #include <ATen/native/Resize.h>
  #include <c10/util/MaybeOwned.h>
  #include <ATen/BlasBackend.h>
  #include <ATen/ceil_div.h>
  #ifdef USE_MSLK
  #include <mslk/gemm/gemm_torch.h>
  #endif
  #ifndef AT_PER_OPERATOR_HEADERS
  #include <ATen/Functions.h>
  #include <ATen/NativeFunctions.h>
  #else
  #include <ATen/ops/_addmm_activation_native.h>
  #include <ATen/ops/_efficientzerotensor.h>
  #include <ATen/ops/_scaled_mm_native.h>
  #include <ATen/ops/_unsafe_view_native.h>
  #include <ATen/ops/abs.h>
  #include <ATen/ops/addmm_native.h>
  #include <ATen/ops/addmv_native.h>
  #include <ATen/ops/baddbmm_native.h>
  #include <ATen/ops/bmm_native.h>
  #include <ATen/ops/copy_native.h>
  #include <ATen/ops/dot_native.h>
  #include <ATen/ops/empty.h>
  #include <ATen/ops/empty_strided.h>
  #include <ATen/ops/gelu.h>
  #include <ATen/ops/max.h>
  #include <ATen/ops/mm_native.h>
  #include <ATen/ops/mul.h>
  #include <ATen/ops/ones.h>
  #include <ATen/ops/relu.h>
  #include <ATen/ops/scalar_tensor_native.h>
  #include <ATen/ops/vdot_native.h>
  #endif
  48. using at::blas::ScalingType;
  49. namespace at::native::onednn::scaled {
  50. /**
  51. * Track concrete implementations available
  52. */
  53. enum class ScaledGemmImplementation {
  54. NONE = 0,
  55. TENSORWISE_TENSORWISE = 1,
  56. ROWWISE_ROWWISE = 2,
  57. };
  58. /**
  59. * Convert passed int (enum) from python back into a
  60. * strictly-typed enum
  61. */
  62. template <class EnumType, class ArrayType>
  63. std::vector<EnumType> convert_int_to_enum(ArrayType& v) {
  64. std::vector<EnumType> converted;
  65. converted.reserve(v.size());
  66. for (auto vi : v) {
  67. converted.push_back(static_cast<EnumType>(vi));
  68. }
  69. return converted;
  70. }
  71. bool check_tensorwise_recipe(
  72. c10::ScalarType,
  73. std::vector<ScalingType>&,
  74. ArrayRef<Tensor>&,
  75. c10::ScalarType,
  76. std::vector<ScalingType>&,
  77. ArrayRef<Tensor>&);
  78. bool check_rowwise_recipe(
  79. c10::ScalarType,
  80. std::vector<ScalingType>&,
  81. ArrayRef<Tensor>&,
  82. c10::ScalarType,
  83. std::vector<ScalingType>&,
  84. ArrayRef<Tensor>&);
  85. } // namespace at::native::onednn::scaled
  86. #else
  87. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  88. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)