| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
from typing import Dict, List, Optional, Set, Union

from ray.tune.experiment import Trial
from ray.tune.search import ConcurrencyLimiter, Searcher
from ray.tune.search.search_generator import SearchGenerator
class _MockSearcher(Searcher):
    """In-memory ``Searcher`` test double.

    ``suggest`` hands out the fixed config ``{"test_variable": 2}`` and marks
    the trial as live, unless ``stall`` has been switched on, in which case it
    returns ``None``. Every callback is counted, and the results it receives
    are kept so tests can inspect them afterwards.
    """

    def __init__(self, **kwargs):
        # Callback tallies, keyed by callback kind.
        self.counter = {"result": 0, "complete": 0}
        # Intermediate results seen by on_trial_result, in arrival order.
        self.results = []
        # Truthy final results seen by on_trial_complete, in arrival order.
        self.final_results = []
        # trial_id -> 1 for trials suggested but not yet completed.
        self.live_trials = {}
        # Flip to True to make suggest() return None.
        self.stall = False
        super().__init__(**kwargs)

    def suggest(self, trial_id: str):
        # Stalled searchers produce no configuration.
        if self.stall:
            return None
        self.live_trials[trial_id] = 1
        return {"test_variable": 2}

    def on_trial_result(self, trial_id: str, result: Dict):
        self.counter["result"] += 1
        self.results.append(result)

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        self.counter["complete"] += 1
        # Only truthy results (not None / empty dict) are recorded as final.
        if result:
            self._process_result(result)
        # Forget the trial if it was live; unknown ids are ignored.
        self.live_trials.pop(trial_id, None)

    def _process_result(self, result: Dict):
        self.final_results.append(result)
class _MockSuggestionAlgorithm(SearchGenerator):
    """``SearchGenerator`` that drives a ``_MockSearcher``, for tests.

    Args:
        max_concurrent: When truthy, the mock searcher is wrapped in a
            ``ConcurrencyLimiter`` with this limit. ``None`` (or ``0``)
            leaves the mock searcher unwrapped.
        **kwargs: Forwarded unchanged to ``_MockSearcher``.
    """

    def __init__(self, max_concurrent: Optional[int] = None, **kwargs):
        self.searcher = _MockSearcher(**kwargs)
        if max_concurrent:
            self.searcher = ConcurrencyLimiter(
                self.searcher, max_concurrent=max_concurrent
            )
        super().__init__(self.searcher)

    @property
    def live_trials(self) -> Union[Dict[str, int], Set[str]]:
        """Live trials as tracked by the outermost searcher.

        Unwrapped, this is the ``_MockSearcher`` mapping of trial id -> 1.
        With a ``ConcurrencyLimiter`` it is the limiter's own
        ``live_trials`` (a set of trial ids in Ray's implementation).
        """
        return self.searcher.live_trials

    @property
    def results(self) -> List[Dict]:
        """Intermediate results recorded by the underlying ``_MockSearcher``.

        Only ``_MockSearcher`` collects a ``results`` list. Reading it off a
        ``ConcurrencyLimiter`` wrapper would raise ``AttributeError``, so
        peel off limiter layers (each exposes the wrapped searcher as
        ``.searcher``) before reading it.
        """
        searcher = self.searcher
        while isinstance(searcher, ConcurrencyLimiter):
            searcher = searcher.searcher
        return searcher.results
|