EmbeddingBag.h 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/Tensor.h>
  4. #include <ATen/Config.h>
  5. #include <cstdint>
  6. #ifdef USE_FBGEMM
  7. #include <fbgemm/FbgemmEmbedding.h>
  8. #endif
  9. namespace at::native {
  10. enum class EmbeddingBagMode {
  11. SUM = 0,
  12. MEAN = 1,
  13. MAX = 2,
  14. };
  15. [[maybe_unused]] static bool operator==(int64_t op1, EmbeddingBagMode op2) {
  16. return op1 == static_cast<int64_t>(op2);
  17. }
  18. [[maybe_unused]] static bool operator!=(int64_t op1, EmbeddingBagMode op2) {
  19. return !(op1 == op2);
  20. }
  21. void check_arguments(
  22. const Tensor& weight,
  23. const Tensor& indices,
  24. const Tensor& offsets,
  25. const int64_t mode,
  26. const std::optional<Tensor>& per_sample_weights,
  27. bool include_last_offset);
  28. void make_bag_size_out(
  29. Tensor& bag_size_out,
  30. const Tensor& offsets,
  31. const Tensor& indices,
  32. const int64_t mode,
  33. const bool include_last_offset,
  34. const bool requires_grad);
  35. void make_max_indices_out(
  36. Tensor& max_indices_out,
  37. const Tensor& weight,
  38. const Tensor& indices,
  39. const Tensor& offsets,
  40. const Tensor& bag_size,
  41. const int64_t mode,
  42. bool include_last_offset);
  43. void make_offset2bag_out(
  44. Tensor& offset2bag,
  45. Tensor& output,
  46. const Tensor& weight,
  47. const Tensor& indices,
  48. const Tensor& offsets,
  49. const int64_t mode,
  50. const std::optional<Tensor>& per_sample_weights,
  51. const int64_t padding_idx = -1);
  52. #ifdef USE_FBGEMM
  53. template<bool has_weight, typename TIndex, typename TData>
  54. struct _CallbackAndBlockSize {
  55. using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;
  56. int64_t blockSize = -1;
  57. TCallback callback = nullptr;
  58. static TCallback generateCallback(int64_t block_size) {
  59. return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
  60. block_size,
  61. has_weight,
  62. /* normalize_by_lengths */false,
  63. /* prefetch */16,
  64. /* is_weight_positional */false,
  65. /* use_offsets */true);
  66. }
  67. _CallbackAndBlockSize() = default;
  68. explicit _CallbackAndBlockSize(std::optional<int64_t> maybe_block_size)
  69. : blockSize(maybe_block_size.value_or(-1))
  70. , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
  71. {}
  72. };
  73. template<typename... StorageMixins>
  74. struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
  75. _EmbeddingBagKernelCacheImpl() = default;
  76. // use each of the mixins to store corresponding kernel and block size
  77. explicit _EmbeddingBagKernelCacheImpl(std::optional<int64_t> maybe_block_size)
  78. : StorageMixins(maybe_block_size)...
  79. {}
  80. // this method is thread safe (call sites may call from different threads)
  81. template<bool has_weight, typename TIndex, typename TData>
  82. typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
  83. getCallback(int64_t block_size) const {
  84. // if the cache doesn't store the kernel for the incoming block size
  85. // (so it is different from the one stored in corresponding mixin)
  86. // regenerate the kernel (not writing it into the cache so we avoid locks)
  87. if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
  88. return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
  89. }
  90. // else retrieve the cached kernel from the corresponding mixin
  91. return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
  92. }
  93. };
  94. // instantiate the cache with the list of storage mixins
  95. // for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
  96. using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
  97. _CallbackAndBlockSize<true, int32_t, float>,
  98. _CallbackAndBlockSize<false, int32_t, float>,
  99. _CallbackAndBlockSize<true, int64_t, float>,
  100. _CallbackAndBlockSize<false, int64_t, float>,
  101. _CallbackAndBlockSize<true, int32_t, unsigned short>,
  102. _CallbackAndBlockSize<false, int32_t, unsigned short>,
  103. _CallbackAndBlockSize<true, int64_t, unsigned short>,
  104. _CallbackAndBlockSize<false, int64_t, unsigned short>>;
  105. #else
  106. struct _EmbeddingBagKernelCache {
  107. explicit _EmbeddingBagKernelCache(std::optional<int64_t> /* maybe_block_size */) {}
  108. };
  109. #endif
  110. void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
  111. Tensor& bag_size, Tensor* max_indices,
  112. const Tensor &weight, const Tensor &indices,
  113. const Tensor &offsets, const int64_t mode = 0,
  114. const std::optional<Tensor>& per_sample_weights = std::nullopt,
  115. bool include_last_offset = false,
  116. int64_t padding_idx = -1,
  117. _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
  118. void _embedding_bag_cpu_out(
  119. at::Tensor& output,
  120. at::Tensor& offset2bag,
  121. at::Tensor& bag_size,
  122. at::Tensor* p_max_indices,
  123. const at::Tensor& weight,
  124. const at::Tensor& indices,
  125. const at::Tensor& offsets,
  126. const bool scale_grad_by_freq,
  127. const int64_t mode,
  128. const bool sparse,
  129. const std::optional<at::Tensor>& per_sample_weights,
  130. const bool include_last_offset,
  131. const std::optional<int64_t>& padding_idx,
  132. _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
  133. } // namespace at::native
  134. #else
  135. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  136. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)