window_stat.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import numpy as np
  2. from ray.rllib.utils.annotations import OldAPIStack
  3. @OldAPIStack
  4. class WindowStat:
  5. """Handles/stores incoming dataset and provides window-based statistics.
  6. .. testcode::
  7. :skipif: True
  8. win_stats = WindowStat("level", 3)
  9. win_stats.push(5.0)
  10. win_stats.push(7.0)
  11. win_stats.push(7.0)
  12. win_stats.push(10.0)
  13. # Expect 8.0 as the mean of the last 3 values: (7+7+10)/3=8.0
  14. print(win_stats.mean())
  15. .. testoutput::
  16. 8.0
  17. """
  18. def __init__(self, name: str, n: int):
  19. """Initializes a WindowStat instance.
  20. Args:
  21. name: The name of the stats to collect and return stats for.
  22. n: The window size. Statistics will be computed for the last n
  23. items received from the stream.
  24. """
  25. # The window-size.
  26. self.window_size = n
  27. # The name of the data (used for `self.stats()`).
  28. self.name = name
  29. # List of items to do calculations over (len=self.n).
  30. self.items = [None] * self.window_size
  31. # The current index to insert the next item into `self.items`.
  32. self.idx = 0
  33. # How many items have been added over the lifetime of this object.
  34. self.count = 0
  35. def push(self, obj) -> None:
  36. """Pushes a new value/object into the data buffer."""
  37. # Insert object at current index.
  38. self.items[self.idx] = obj
  39. # Increase insertion index by 1.
  40. self.idx += 1
  41. # Increase lifetime count by 1.
  42. self.count += 1
  43. # Fix index in case of rollover.
  44. self.idx %= len(self.items)
  45. def mean(self) -> float:
  46. """Returns the (NaN-)mean of the last `self.window_size` items."""
  47. return float(np.nanmean(self.items[: self.count]))
  48. def std(self) -> float:
  49. """Returns the (NaN)-stddev of the last `self.window_size` items."""
  50. return float(np.nanstd(self.items[: self.count]))
  51. def quantiles(self) -> np.ndarray:
  52. """Returns ndarray with 0, 10, 50, 90, and 100 percentiles."""
  53. if not self.count:
  54. return np.ndarray([], dtype=np.float32)
  55. else:
  56. return np.nanpercentile(
  57. self.items[: self.count], [0, 10, 50, 90, 100]
  58. ).tolist()
  59. def stats(self):
  60. return {
  61. self.name + "_count": int(self.count),
  62. self.name + "_mean": self.mean(),
  63. self.name + "_std": self.std(),
  64. self.name + "_quantiles": self.quantiles(),
  65. }