accumulate.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. // Copyright 2004-present Facebook. All Rights Reserved.
  3. #pragma once
  4. #include <c10/util/Exception.h>
  5. #include <cstdint>
  6. #include <functional>
  7. #include <iterator>
  8. #include <numeric>
  9. #include <type_traits>
  10. #include <utility>
  11. namespace c10 {
  /// Sum of a list of integers; accumulates into the int64_t datatype.
  ///
  /// The running total is an int64_t regardless of the container's element
  /// type, so narrow element types (e.g. int32_t, uint8_t) are widened before
  /// being added. Returns 0 for an empty container.
  template <
      typename C,
      std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
  inline int64_t sum_integers(const C& container) {
    // Seeding with a 64-bit zero (rather than letting the element type decide
    // the accumulator type) is what prevents overflow at the element width.
    int64_t total{0};
    for (const auto& element : container) {
      total += element;
    }
    return total;
  }
  /// Sum of the integer elements in the half-open range [begin, end);
  /// accumulates into the int64_t datatype. Returns 0 for an empty range.
  template <
      typename Iter,
      std::enable_if_t<
          std::is_integral_v<typename std::iterator_traits<Iter>::value_type>,
          int> = 0>
  inline int64_t sum_integers(Iter begin, Iter end) {
    // The accumulator is 64-bit on purpose: adding each element into it
    // widens the element first, so narrow types do not overflow at their
    // own width.
    int64_t total = 0;
    for (Iter it = begin; it != end; ++it) {
      total += *it;
    }
    return total;
  }
  /// Product of a list of integers; accumulates into the int64_t datatype.
  ///
  /// Each element is multiplied into an int64_t running product, so the
  /// intermediate results are not limited to the element type's width.
  /// Returns 1 (the empty product) for an empty container.
  template <
      typename C,
      std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
  inline int64_t multiply_integers(const C& container) {
    // Start from a 64-bit one so every multiplication happens in int64_t.
    int64_t product{1};
    for (const auto& factor : container) {
      product *= factor;
    }
    return product;
  }
  /// Product of the integer elements in the half-open range [begin, end);
  /// accumulates into the int64_t datatype. Returns 1 (the empty product)
  /// for an empty range.
  template <
      typename Iter,
      std::enable_if_t<
          std::is_integral_v<typename std::iterator_traits<Iter>::value_type>,
          int> = 0>
  inline int64_t multiply_integers(Iter begin, Iter end) {
    // `begin` is a by-value copy, so it can be advanced in place. The product
    // is kept in int64_t so that narrow element types are widened before each
    // multiplication.
    int64_t product = 1;
    while (begin != end) {
      product *= *begin;
      ++begin;
    }
    return product;
  }
  64. /// Return product of all dimensions starting from k
  65. /// Returns 1 if k>=dims.size()
  66. template <
  67. typename C,
  68. std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
  69. inline int64_t numelements_from_dim(const int k, const C& dims) {
  70. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0);
  71. if (k > static_cast<int>(dims.size())) {
  72. return 1;
  73. } else {
  74. auto cbegin = dims.cbegin();
  75. std::advance(cbegin, k);
  76. return multiply_integers(cbegin, dims.cend());
  77. }
  78. }
  79. /// Product of all dims up to k (not including dims[k])
  80. /// Throws an error if k>dims.size()
  81. template <
  82. typename C,
  83. std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
  84. inline int64_t numelements_to_dim(const int k, const C& dims) {
  85. TORCH_INTERNAL_ASSERT(0 <= k);
  86. TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size());
  87. auto cend = dims.cbegin();
  88. std::advance(cend, k);
  89. return multiply_integers(dims.cbegin(), cend);
  90. }
  91. /// Product of all dims between k and l (including dims[k] and excluding
  92. /// dims[l]) k and l may be supplied in either order
  93. template <
  94. typename C,
  95. std::enable_if_t<std::is_integral_v<typename C::value_type>, int> = 0>
  96. inline int64_t numelements_between_dim(int k, int l, const C& dims) {
  97. TORCH_INTERNAL_ASSERT(0 <= k);
  98. TORCH_INTERNAL_ASSERT(0 <= l);
  99. if (k > l) {
  100. std::swap(k, l);
  101. }
  102. TORCH_INTERNAL_ASSERT((unsigned)l < dims.size());
  103. auto cbegin = dims.cbegin();
  104. auto cend = dims.cbegin();
  105. std::advance(cbegin, k);
  106. std::advance(cend, l);
  107. return multiply_integers(cbegin, cend);
  108. }
  109. } // namespace c10
  110. #else
  111. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  112. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)