# mixed_input.py (2.0 KB)
from types import FunctionType
from typing import Callable, Dict, Union

import numpy as np

from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.typing import SampleBatchType
from ray.tune.registry import registry_contains_input, registry_get_input
  10. @DeveloperAPI
  11. class MixedInput(InputReader):
  12. """Mixes input from a number of other input sources.
  13. .. testcode::
  14. :skipif: True
  15. from ray.rllib.offline.io_context import IOContext
  16. from ray.rllib.offline.mixed_input import MixedInput
  17. ioctx = IOContext(...)
  18. MixedInput({
  19. "sampler": 0.4,
  20. "/tmp/experiences/*.json": 0.4,
  21. "s3://bucket/expert.json": 0.2,
  22. }, ioctx)
  23. """
  24. @DeveloperAPI
  25. def __init__(self, dist: Dict[JsonReader, float], ioctx: IOContext):
  26. """Initialize a MixedInput.
  27. Args:
  28. dist: dict mapping JSONReader paths or "sampler" to
  29. probabilities. The probabilities must sum to 1.0.
  30. ioctx: current IO context object.
  31. """
  32. if sum(dist.values()) != 1.0:
  33. raise ValueError("Values must sum to 1.0: {}".format(dist))
  34. self.choices = []
  35. self.p = []
  36. for k, v in dist.items():
  37. if k == "sampler":
  38. self.choices.append(ioctx.default_sampler_input())
  39. elif isinstance(k, FunctionType):
  40. self.choices.append(k(ioctx))
  41. elif isinstance(k, str) and registry_contains_input(k):
  42. input_creator = registry_get_input(k)
  43. self.choices.append(input_creator(ioctx))
  44. else:
  45. self.choices.append(JsonReader(k, ioctx))
  46. self.p.append(v)
  47. @override(InputReader)
  48. def next(self) -> SampleBatchType:
  49. source = np.random.choice(self.choices, p=self.p)
  50. return source.next()