sample.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. """sample."""
  2. import math
  class UniformSampleAccumulator:
      """Keep a bounded, approximately uniform sample of an unbounded stream.

      Values passed to :meth:`add` are stored in ``_buckets`` pre-allocated
      buffers used as a ring.  Logical bucket ``k`` holds the values whose
      1-based arrival count ``c`` satisfies ``2**k <= c >> _shift < 2**(k + 1)``,
      and only counts that are multiples of ``2**_shift`` are recorded at all.
      When the count outgrows the newest bucket, the oldest bucket is recycled
      as the new newest one and the recording stride doubles, so memory stays
      bounded while :meth:`get` returns values spaced evenly by arrival count.

      Args:
          min_samples: requested sample size, rounded up to a power of two.
              ``None`` or ``0`` selects the default of 64.

      Raises:
          ValueError: if ``min_samples`` is below 1 (and not ``None``/``0``).
      """

      def __init__(self, min_samples=None):
          requested = min_samples or 64
          if requested < 1:
              raise ValueError(f"min_samples must be >= 1, got {min_samples!r}")
          # Force a power-of-2 sample count.  Exact integer math is used instead of
          # math.log(x, 2), whose float rounding can push ceil() past an exact
          # power of two.
          self._samples = 1 << self._ceil_log2(math.ceil(requested))
          # target oversample by factor of 2
          self._samples2 = self._samples * 2
          # max size of each buffer (the largest logical bucket holds this many)
          self._max = self._samples2 // 2
          # only counts that are multiples of 2**_shift are recorded
          self._shift = 0
          self._mask = (1 << self._shift) - 1
          # _samples2 is a power of two, so this is its exact log2
          self._buckets = self._samples2.bit_length() - 1
          self._buckets_bits = self._buckets.bit_length() - 1
          self._buckets_mask = (1 << (self._buckets_bits + 1)) - 1
          # physical ring position of logical bucket 0 (the oldest values)
          self._buckets_index = 0
          # pre-allocate buckets
          self._bucket = [[0] * self._max for _ in range(self._buckets)]
          # number of values currently stored in each physical bucket
          self._index = [0] * self._buckets
          self._count = 0
          # integer floor(log2(i)) for i in 1..2**_buckets; entry 0 is a
          # placeholder.  add() never indexes past 2**_buckets.
          self._log2 = [0] + [
              i.bit_length() - 1 for i in range(1, 2**self._buckets + 1)
          ]

      @staticmethod
      def _ceil_log2(n):
          """Return the smallest ``k`` with ``2**k >= n``, for an int ``n >= 1``."""
          return (n - 1).bit_length()

      def _show(self):
          """Print each bucket's stored values, oldest logical bucket first (debug aid)."""
          print("=" * 20)  # noqa: T201
          for logical in range(self._buckets):
              b = (logical + self._buckets_index) % self._buckets
              vals = self._bucket[b][: self._index[b]]
              print(f"{b}: {vals}")  # noqa: T201

      def add(self, val):
          """Offer ``val`` to the accumulator.

          Every call increments the arrival count.  The value is stored only when
          that count is a multiple of the current stride ``2**_shift``.
          """
          self._count += 1
          cnt = self._count
          if cnt & self._mask:
              # not on the current recording stride
              return
          # logical bucket index: floor(log2(cnt >> shift))
          b = self._log2[cnt >> self._shift]
          if b >= self._buckets:
              # The count outgrew the newest bucket.  Recycle the oldest one
              # (logical 0) as the new newest bucket and double the stride.
              self._index[self._buckets_index] = 0
              self._buckets_index = (self._buckets_index + 1) % self._buckets
              self._shift += 1
              self._mask = (self._mask << 1) | 1
              # at the new shift this count falls in the newest logical bucket
              b = self._buckets - 1
          b = (b + self._buckets_index) % self._buckets
          self._bucket[b][self._index[b]] = val
          self._index[b] += 1

      def get(self):
          """Return the sampled values as a tuple, in arrival order.

          Older buckets were filled at a finer stride than the current one, so
          logical bucket ``k`` is thinned by taking every ``len // 2**k``-th value
          (every value when that quotient is 0).  If the thinned result has
          fewer than the target sample size, every stored value is returned
          instead.
          """
          full = []
          sampled = []
          for logical in range(self._buckets):
              b = (logical + self._buckets_index) % self._buckets
              stored = self._bucket[b][: self._index[b]]
              step = len(stored) // (1 << logical) or 1
              sampled.extend(stored[::step])
              full.extend(stored)
          if len(sampled) < self._samples:
              return tuple(full)
          return tuple(sampled)