io_context.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import os
  2. from typing import TYPE_CHECKING, Optional
  3. from ray.rllib.utils.annotations import PublicAPI
  4. if TYPE_CHECKING:
  5. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  6. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  7. from ray.rllib.evaluation.sampler import SamplerInput
  8. @PublicAPI
  9. class IOContext:
  10. """Class containing attributes to pass to input/output class constructors.
  11. RLlib auto-sets these attributes when constructing input/output classes,
  12. such as InputReaders and OutputWriters.
  13. """
  14. @PublicAPI
  15. def __init__(
  16. self,
  17. log_dir: Optional[str] = None,
  18. config: Optional["AlgorithmConfig"] = None,
  19. worker_index: int = 0,
  20. worker: Optional["RolloutWorker"] = None,
  21. ):
  22. """Initializes a IOContext object.
  23. Args:
  24. log_dir: The logging directory to read from/write to.
  25. config: The (main) AlgorithmConfig object.
  26. worker_index: When there are multiple workers created, this
  27. uniquely identifies the current worker. 0 for the local
  28. worker, >0 for any of the remote workers.
  29. worker: The RolloutWorker object reference.
  30. """
  31. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  32. self.log_dir = log_dir or os.getcwd()
  33. # In case no config is provided, use the default one, but set
  34. # `actions_in_input_normalized=True` if we don't have a worker.
  35. # Not having a worker and/or a config should only be the case in some test
  36. # cases, though.
  37. self.config = config or AlgorithmConfig().offline_data(
  38. actions_in_input_normalized=worker is None
  39. ).training(train_batch_size=1)
  40. self.worker_index = worker_index
  41. self.worker = worker
  42. @PublicAPI
  43. def default_sampler_input(self) -> Optional["SamplerInput"]:
  44. """Returns the RolloutWorker's SamplerInput object, if any.
  45. Returns None if the RolloutWorker has no SamplerInput. Note that local
  46. workers in case there are also one or more remote workers by default
  47. do not create a SamplerInput object.
  48. Returns:
  49. The RolloutWorkers' SamplerInput object or None if none exists.
  50. """
  51. return self.worker.sampler
  52. @property
  53. @PublicAPI
  54. def input_config(self):
  55. return self.config.get("input_config", {})
  56. @property
  57. @PublicAPI
  58. def output_config(self):
  59. return self.config.get("output_config", {})