| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
#include <ATen/core/TensorAccessor.h>
#include <ATen/NumericUtils.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <utility>
#include <vector>
- namespace at::native {
- #ifdef CPU_CAPABILITY
- inline namespace CPU_CAPABILITY {
- #else
- inline namespace DEFAULT {
- #endif
- // Core topk loop, shared between CPU and QuantizedCPU
- template <typename scalar_t, typename accscalar_t>
- void topk_impl_loop(
- const int64_t mode_values_stride,
- const int64_t mode_indices_stride,
- const int64_t tmp_values_stride,
- const int64_t k,
- const int64_t dim_size,
- const bool largest,
- const bool sorted,
- char** data, const int64_t* strides, const int64_t n) {
- // If k is zero, then output values and indices are empty tensors
- // So iterating over other dims is pointless
- if (k == 0) {
- return;
- }
- using elem_t = std::pair<accscalar_t, int64_t>;
- std::vector<elem_t> queue(dim_size);
- for (const auto i : c10::irange(n)) {
- TensorAccessor<scalar_t, 1> mode_values(
- reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
- &k, &mode_values_stride);
- TensorAccessor<int64_t, 1> mode_indices(
- reinterpret_cast<int64_t*>(data[1] + i * strides[1]),
- &k, &mode_indices_stride);
- TensorAccessor<const scalar_t, 1> tmp_values(
- reinterpret_cast<scalar_t*>(data[2] + i * strides[2]),
- &dim_size, &tmp_values_stride);
- auto n_2 = dim_size;
- auto use_partial_sort = k * 64 <= n_2;
- for (const auto j : c10::irange(n_2)) {
- queue[j].first = tmp_values[j];
- queue[j].second = j;
- }
- // we want nan to be sorted as top for numpy compatibility
- if (use_partial_sort) {
- if (largest) {
- std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
- [](const elem_t& x, const elem_t& y) -> bool {
- return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
- });
- } else {
- std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
- [](const elem_t& x, const elem_t& y) -> bool {
- return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
- });
- }
- } else {
- if (largest) {
- std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
- [](const elem_t& x, const elem_t& y) -> bool {
- return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
- });
- if (sorted) {
- std::sort(queue.begin(), queue.begin() + k - 1,
- [](const elem_t& x, const elem_t& y) -> bool {
- return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
- });
- }
- } else {
- std::nth_element(queue.begin(), queue.begin() + k -1, queue.end(),
- [](const elem_t& x, const elem_t& y) -> bool {
- return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
- });
- if (sorted) {
- std::sort(queue.begin(), queue.begin() + k -1,
- [](const elem_t& x, const elem_t& y) -> bool {
- return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
- });
- }
- }
- }
- for (const auto j : c10::irange(k)) {
- mode_values[j] = queue[j].first;
- mode_indices[j] = queue[j].second;
- }
- }
- }
- } // namespace CPU_CAPABILITY
- } // namespace at::native
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|