SobolEngineOpsUtils.h 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. /// This file contains some tensor-agnostic operations to be used in the
  3. /// core functions of the `SobolEngine`
  4. #include <ATen/core/Tensor.h>
  5. #ifndef AT_PER_OPERATOR_HEADERS
  6. #include <ATen/Functions.h>
  7. #else
  8. #include <ATen/ops/arange.h>
  9. #include <ATen/ops/mul.h>
  10. #include <ATen/ops/pow.h>
  11. #endif
  12. namespace at::native::sobol_utils {
  /// Returns how many binary digits are needed to write `n`, i.e. the index of its highest
  /// set bit plus one (1 -> 1, 2 -> 2, 255 -> 8, 256 -> 9).
  /// Non-positive inputs yield 0.
  inline int64_t bit_length(const int64_t n) {
    int64_t width = 0;
    // Drop one low bit per step until nothing is left; only reached for positive values.
    for (int64_t remaining = n; remaining > 0; remaining >>= 1) {
      ++width;
    }
    return width;
  }
  /// Function to get the position of the rightmost zero in the bit representation of an integer
  /// This value is the zero-indexed position, e.g. 0b1011 -> 2 and 0b0110 -> 0.
  ///
  /// The scan runs over the 64-bit two's-complement pattern of `n`, so negative inputs are
  /// handled too (-3 = ...11101 -> 1). When every bit is set (n == -1) there is no zero within
  /// the 64 bits and the function returns 64.
  inline int64_t rightmost_zero(const int64_t n) {
    // Work on the unsigned bit pattern: with a signed value, `n % 2` is -1 for odd negative
    // numbers, so a `% 2 == 1` test would stop immediately and wrongly report position 0.
    auto bits = static_cast<uint64_t>(n);
    int64_t position = 0;
    while ((bits & uint64_t{1}) != 0) {
      bits >>= 1;
      ++position;
    }
    return position;
  }
  /// Function to get a subsequence of bits in the representation of an integer starting from
  /// `pos` and of length `length`
  ///
  /// Returns bits [pos, pos + length) of `n`, shifted down to bit 0; e.g.
  /// bitsubseq(0b101101, 2, 3) == 0b011.
  /// `pos` must lie in [0, 63] because it is used directly as a shift count. A non-positive
  /// `length` yields 0, and a `length` of 64 or more keeps every bit of `n >> pos`.
  inline int64_t bitsubseq(const int64_t n, const int64_t pos, const int64_t length) {
    if (length <= 0) {
      return 0;
    }
    // Build the mask in 64-bit unsigned arithmetic. The former `(1 << length) - 1` used a
    // 32-bit `int`, so any `length` >= 31 overflowed (undefined behavior) and wide fields of
    // the 64-bit `n` could not be extracted.
    const uint64_t mask = length >= 64 ? ~uint64_t{0} : (uint64_t{1} << length) - 1;
    return static_cast<int64_t>(static_cast<uint64_t>(n >> pos) & mask);
  }
  31. /// Function to perform the inner product between a batched square matrix and a power of 2 vector
  32. inline at::Tensor cdot_pow2(const at::Tensor& bmat) {
  33. at::Tensor inter = at::arange(bmat.size(-1) - 1, -1, -1, bmat.options());
  34. inter = at::pow(2, inter).expand_as(bmat);
  35. return at::mul(inter, bmat).sum(-1);
  36. }
  /// All definitions below this point are data. These are constant, and should not be modified
  /// without notice.
  ///
  /// MAXDIM and MAXDEG are the bounds of the `poly` and `initsobolstate` tables declared below.
  /// Those tables are defined in another translation unit, which must use the same bounds.
  /// LARGEST_NUMBER is 2^MAXBIT and RECIPD is its reciprocal.
  constexpr int64_t MAXDIM = 21201;
  constexpr int64_t MAXDEG = 18;
  constexpr int64_t MAXBIT = 30;
  // Shift a 64-bit one: the former `1 << MAXBIT` was evaluated in `int` and would overflow
  // (undefined behavior) if MAXBIT were ever raised to 31 or more.
  static_assert(MAXBIT > 0 && MAXBIT < 63,
                "MAXBIT must leave room for a positive int64_t power of two");
  constexpr int64_t LARGEST_NUMBER = int64_t{1} << MAXBIT;
  constexpr float RECIPD = 1.0 / LARGEST_NUMBER;
  // Both values are powers of two, so scaling by RECIPD is an exact division by LARGEST_NUMBER.
  static_assert(RECIPD * LARGEST_NUMBER == 1.0f,
                "RECIPD must be the exact reciprocal of LARGEST_NUMBER");
  extern const int64_t poly[MAXDIM];
  extern const int64_t initsobolstate[MAXDIM][MAXDEG];
  46. } // namespace at::native::sobol_utils
  47. #else
  48. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  49. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)