_mock.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
from typing import Collection, Dict, List, Optional

from ray.tune.experiment import Trial
from ray.tune.search import ConcurrencyLimiter, Searcher
from ray.tune.search.search_generator import SearchGenerator
  5. class _MockSearcher(Searcher):
  6. def __init__(self, **kwargs):
  7. self.live_trials = {}
  8. self.counter = {"result": 0, "complete": 0}
  9. self.final_results = []
  10. self.stall = False
  11. self.results = []
  12. super(_MockSearcher, self).__init__(**kwargs)
  13. def suggest(self, trial_id: str):
  14. if not self.stall:
  15. self.live_trials[trial_id] = 1
  16. return {"test_variable": 2}
  17. return None
  18. def on_trial_result(self, trial_id: str, result: Dict):
  19. self.counter["result"] += 1
  20. self.results += [result]
  21. def on_trial_complete(
  22. self, trial_id: str, result: Optional[Dict] = None, error: bool = False
  23. ):
  24. self.counter["complete"] += 1
  25. if result:
  26. self._process_result(result)
  27. if trial_id in self.live_trials:
  28. del self.live_trials[trial_id]
  29. def _process_result(self, result: Dict):
  30. self.final_results += [result]
  31. class _MockSuggestionAlgorithm(SearchGenerator):
  32. def __init__(self, max_concurrent: Optional[int] = None, **kwargs):
  33. self.searcher = _MockSearcher(**kwargs)
  34. if max_concurrent:
  35. self.searcher = ConcurrencyLimiter(
  36. self.searcher, max_concurrent=max_concurrent
  37. )
  38. super(_MockSuggestionAlgorithm, self).__init__(self.searcher)
  39. @property
  40. def live_trials(self) -> List[Trial]:
  41. return self.searcher.live_trials
  42. @property
  43. def results(self) -> List[Dict]:
  44. return self.searcher.results