bc.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  2. from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
  3. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  4. from ray.rllib.utils.annotations import override
  5. from ray.rllib.utils.typing import RLModuleSpecType
  6. class BCConfig(MARWILConfig):
  7. """Defines a configuration class from which a new BC Algorithm can be built
  8. .. testcode::
  9. :skipif: True
  10. from ray.rllib.algorithms.bc import BCConfig
  11. # Run this from the ray directory root.
  12. config = BCConfig().training(lr=0.00001, gamma=0.99)
  13. config = config.offline_data(
  14. input_="./rllib/offline/tests/data/cartpole/large.json")
  15. # Build an Algorithm object from the config and run 1 training iteration.
  16. algo = config.build()
  17. algo.train()
  18. .. testcode::
  19. :skipif: True
  20. from ray.rllib.algorithms.bc import BCConfig
  21. from ray import tune
  22. config = BCConfig()
  23. # Print out some default values.
  24. print(config.beta)
  25. # Update the config object.
  26. config.training(
  27. lr=tune.grid_search([0.001, 0.0001]), beta=0.75
  28. )
  29. # Set the config object's data path.
  30. # Run this from the ray directory root.
  31. config.offline_data(
  32. input_="./rllib/offline/tests/data/cartpole/large.json"
  33. )
  34. # Set the config object's env, used for evaluation.
  35. config.environment(env="CartPole-v1")
  36. # Use to_dict() to get the old-style python config dict
  37. # when running with tune.
  38. tune.Tuner(
  39. "BC",
  40. param_space=config.to_dict(),
  41. ).fit()
  42. """
  43. def __init__(self, algo_class=None):
  44. super().__init__(algo_class=algo_class or BC)
  45. # fmt: off
  46. # __sphinx_doc_begin__
  47. # No need to calculate advantages (or do anything else with the rewards).
  48. self.beta = 0.0
  49. # Advantages (calculated during postprocessing)
  50. # not important for behavioral cloning.
  51. self.postprocess_inputs = False
  52. # Materialize only the mapped data. This is optimal as long
  53. # as no connector in the connector pipeline holds a state.
  54. self.materialize_data = False
  55. self.materialize_mapped_data = True
  56. # __sphinx_doc_end__
  57. # fmt: on
  58. @override(AlgorithmConfig)
  59. def get_default_rl_module_spec(self) -> RLModuleSpecType:
  60. if self.framework_str == "torch":
  61. from ray.rllib.algorithms.bc.torch.default_bc_torch_rl_module import (
  62. DefaultBCTorchRLModule,
  63. )
  64. return RLModuleSpec(module_class=DefaultBCTorchRLModule)
  65. else:
  66. raise ValueError(
  67. f"The framework {self.framework_str} is not supported. "
  68. "Use `torch` instead."
  69. )
  70. @override(AlgorithmConfig)
  71. def build_learner_connector(
  72. self,
  73. input_observation_space,
  74. input_action_space,
  75. device=None,
  76. ):
  77. pipeline = super().build_learner_connector(
  78. input_observation_space=input_observation_space,
  79. input_action_space=input_action_space,
  80. device=device,
  81. )
  82. # Remove unneeded connectors from the MARWIL connector pipeline.
  83. pipeline.remove("AddOneTsToEpisodesAndTruncate")
  84. pipeline.remove("GeneralAdvantageEstimation")
  85. return pipeline
  86. @override(MARWILConfig)
  87. def validate(self) -> None:
  88. # Call super's validation method.
  89. super().validate()
  90. if self.beta != 0.0:
  91. self._value_error("For behavioral cloning, `beta` parameter must be 0.0!")
  92. class BC(MARWIL):
  93. """Behavioral Cloning (derived from MARWIL).
  94. Uses MARWIL with beta force-set to 0.0.
  95. """
  96. @classmethod
  97. @override(MARWIL)
  98. def get_default_config(cls) -> BCConfig:
  99. return BCConfig()