TopKImpl.h 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/TensorAccessor.h>
  4. #include <ATen/NumericUtils.h>
  5. namespace at::native {
  6. #ifdef CPU_CAPABILITY
  7. inline namespace CPU_CAPABILITY {
  8. #else
  9. inline namespace DEFAULT {
  10. #endif
  11. // Core topk loop, shared between CPU and QuantizedCPU
  12. template <typename scalar_t, typename accscalar_t>
  13. void topk_impl_loop(
  14. const int64_t mode_values_stride,
  15. const int64_t mode_indices_stride,
  16. const int64_t tmp_values_stride,
  17. const int64_t k,
  18. const int64_t dim_size,
  19. const bool largest,
  20. const bool sorted,
  21. char** data, const int64_t* strides, const int64_t n) {
  22. // If k is zero, then output values and indices are empty tensors
  23. // So iterating over other dims is pointless
  24. if (k == 0) {
  25. return;
  26. }
  27. using elem_t = std::pair<accscalar_t, int64_t>;
  28. std::vector<elem_t> queue(dim_size);
  29. for (const auto i : c10::irange(n)) {
  30. TensorAccessor<scalar_t, 1> mode_values(
  31. reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
  32. &k, &mode_values_stride);
  33. TensorAccessor<int64_t, 1> mode_indices(
  34. reinterpret_cast<int64_t*>(data[1] + i * strides[1]),
  35. &k, &mode_indices_stride);
  36. TensorAccessor<const scalar_t, 1> tmp_values(
  37. reinterpret_cast<scalar_t*>(data[2] + i * strides[2]),
  38. &dim_size, &tmp_values_stride);
  39. auto n_2 = dim_size;
  40. auto use_partial_sort = k * 64 <= n_2;
  41. for (const auto j : c10::irange(n_2)) {
  42. queue[j].first = tmp_values[j];
  43. queue[j].second = j;
  44. }
  45. // we want nan to be sorted as top for numpy compatibility
  46. if (use_partial_sort) {
  47. if (largest) {
  48. std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
  49. [](const elem_t& x, const elem_t& y) -> bool {
  50. return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
  51. });
  52. } else {
  53. std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
  54. [](const elem_t& x, const elem_t& y) -> bool {
  55. return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
  56. });
  57. }
  58. } else {
  59. if (largest) {
  60. std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
  61. [](const elem_t& x, const elem_t& y) -> bool {
  62. return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
  63. });
  64. if (sorted) {
  65. std::sort(queue.begin(), queue.begin() + k - 1,
  66. [](const elem_t& x, const elem_t& y) -> bool {
  67. return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
  68. });
  69. }
  70. } else {
  71. std::nth_element(queue.begin(), queue.begin() + k -1, queue.end(),
  72. [](const elem_t& x, const elem_t& y) -> bool {
  73. return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
  74. });
  75. if (sorted) {
  76. std::sort(queue.begin(), queue.begin() + k -1,
  77. [](const elem_t& x, const elem_t& y) -> bool {
  78. return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
  79. });
  80. }
  81. }
  82. }
  83. for (const auto j : c10::irange(k)) {
  84. mode_values[j] = queue[j].first;
  85. mode_indices[j] = queue[j].second;
  86. }
  87. }
  88. }
  89. } // namespace CPU_CAPABILITY
  90. } // namespace at::native
  91. #else
  92. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  93. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)