utils.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from collections import defaultdict, deque
  2. import numpy as np
  3. class _SleepTimeController:
  4. def __init__(self):
  5. self.L = 0.0
  6. self.H = 0.4
  7. self._recompute_candidates()
  8. # Defaultdict mapping.
  9. self.results = defaultdict(lambda: deque(maxlen=3))
  10. self.iteration = 0
  11. def _recompute_candidates(self):
  12. self.center = (self.L + self.H) / 2
  13. self.low = (self.L + self.center) / 2
  14. self.high = (self.H + self.center) / 2
  15. # Expand a little if range becomes too narrow to avoid
  16. # overoptimization.
  17. if self.H - self.L < 0.00001:
  18. self.L = max(self.center - 0.1, 0.0)
  19. self.H = min(self.center + 0.1, 1.0)
  20. self._recompute_candidates()
  21. # Reduce results, just in case it has grown too much.
  22. c, l, h = (
  23. self.results[self.center],
  24. self.results[self.low],
  25. self.results[self.high],
  26. )
  27. self.results = defaultdict(lambda: deque(maxlen=3))
  28. self.results[self.center] = c
  29. self.results[self.low] = l
  30. self.results[self.high] = h
  31. @property
  32. def current(self):
  33. if len(self.results[self.center]) < 3:
  34. return self.center
  35. elif len(self.results[self.low]) < 3:
  36. return self.low
  37. else:
  38. return self.high
  39. def log_result(self, performance):
  40. self.iteration += 1
  41. # Skip first 2 iterations for ignoring warm-up effect.
  42. if self.iteration < 2:
  43. return
  44. self.results[self.current].append(performance)
  45. # If all candidates have at least 3 results logged, re-evaluate
  46. # and compute new L and H.
  47. center, low, high = self.center, self.low, self.high
  48. if (
  49. len(self.results[center]) == 3
  50. and len(self.results[low]) == 3
  51. and len(self.results[high]) == 3
  52. ):
  53. perf_center = np.mean(self.results[center])
  54. perf_low = np.mean(self.results[low])
  55. perf_high = np.mean(self.results[high])
  56. # Case: `center` is best.
  57. if perf_center > perf_low and perf_center > perf_high:
  58. self.L = low
  59. self.H = high
  60. # Erase low/high results: We'll not use these again.
  61. self.results.pop(low, None)
  62. self.results.pop(high, None)
  63. # Case: `low` is best.
  64. elif perf_low > perf_center and perf_low > perf_high:
  65. self.H = center
  66. # Erase center/high results: We'll not use these again.
  67. self.results.pop(center, None)
  68. self.results.pop(high, None)
  69. # Case: `high` is best.
  70. else:
  71. self.L = center
  72. # Erase center/low results: We'll not use these again.
  73. self.results.pop(center, None)
  74. self.results.pop(low, None)
  75. self._recompute_candidates()
  76. if __name__ == "__main__":
  77. controller = _SleepTimeController()
  78. for _ in range(1000):
  79. performance = np.random.random()
  80. controller.log_result(performance)