registry.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. """Registry of algorithm names for tune.Tuner(trainable=[..])."""
  2. import importlib
  3. import re
  4. def _import_appo():
  5. import ray.rllib.algorithms.appo as appo
  6. return appo.APPO, appo.APPO.get_default_config()
  7. def _import_bc():
  8. import ray.rllib.algorithms.bc as bc
  9. return bc.BC, bc.BC.get_default_config()
  10. def _import_cql():
  11. import ray.rllib.algorithms.cql as cql
  12. return cql.CQL, cql.CQL.get_default_config()
  13. def _import_dqn():
  14. import ray.rllib.algorithms.dqn as dqn
  15. return dqn.DQN, dqn.DQN.get_default_config()
  16. def _import_dreamerv3():
  17. import ray.rllib.algorithms.dreamerv3 as dreamerv3
  18. return dreamerv3.DreamerV3, dreamerv3.DreamerV3.get_default_config()
  19. def _import_impala():
  20. import ray.rllib.algorithms.impala as impala
  21. return impala.IMPALA, impala.IMPALA.get_default_config()
  22. def _import_iql():
  23. import ray.rllib.algorithms.iql as iql
  24. return iql.IQL, iql.IQL.get_default_config()
  25. def _import_marwil():
  26. import ray.rllib.algorithms.marwil as marwil
  27. return marwil.MARWIL, marwil.MARWIL.get_default_config()
  28. def _import_ppo():
  29. import ray.rllib.algorithms.ppo as ppo
  30. return ppo.PPO, ppo.PPO.get_default_config()
  31. def _import_sac():
  32. import ray.rllib.algorithms.sac as sac
  33. return sac.SAC, sac.SAC.get_default_config()
  34. ALGORITHMS = {
  35. "APPO": _import_appo,
  36. "BC": _import_bc,
  37. "CQL": _import_cql,
  38. "DQN": _import_dqn,
  39. "DreamerV3": _import_dreamerv3,
  40. "IMPALA": _import_impala,
  41. "IQL": _import_iql,
  42. "MARWIL": _import_marwil,
  43. "PPO": _import_ppo,
  44. "SAC": _import_sac,
  45. }
  46. ALGORITHMS_CLASS_TO_NAME = {
  47. "APPO": "APPO",
  48. "BC": "BC",
  49. "CQL": "CQL",
  50. "DQN": "DQN",
  51. "DreamerV3": "DreamerV3",
  52. "Impala": "IMPALA",
  53. "IQL": "IQL",
  54. "IMPALA": "IMPALA",
  55. "MARWIL": "MARWIL",
  56. "PPO": "PPO",
  57. "SAC": "SAC",
  58. }
  59. def _get_algorithm_class(alg: str) -> type:
  60. # This helps us get around a circular import (tune calls rllib._register_all when
  61. # checking if a rllib Trainable is registered)
  62. if alg in ALGORITHMS:
  63. return ALGORITHMS[alg]()[0]
  64. elif alg == "script":
  65. from ray.tune import script_runner
  66. return script_runner.ScriptRunner
  67. elif alg == "__fake":
  68. from ray.rllib.algorithms.mock import _MockTrainer
  69. return _MockTrainer
  70. elif alg == "__sigmoid_fake_data":
  71. from ray.rllib.algorithms.mock import _SigmoidFakeData
  72. return _SigmoidFakeData
  73. elif alg == "__parameter_tuning":
  74. from ray.rllib.algorithms.mock import _ParameterTuningTrainer
  75. return _ParameterTuningTrainer
  76. else:
  77. raise Exception("Unknown algorithm {}.".format(alg))
  78. # Dict mapping policy names to where the class is located, relative to rllib.algorithms.
  79. # TODO(jungong) : Finish migrating all the policies to PolicyV2, so we can list
  80. # all the TF eager policies here.
  81. POLICIES = {
  82. "APPOTF1Policy": "appo.appo_tf_policy",
  83. "APPOTF2Policy": "appo.appo_tf_policy",
  84. "APPOTorchPolicy": "appo.appo_torch_policy",
  85. "CQLTFPolicy": "cql.cql_tf_policy",
  86. "CQLTorchPolicy": "cql.cql_torch_policy",
  87. "DQNTFPolicy": "dqn.dqn_tf_policy",
  88. "DQNTorchPolicy": "dqn.dqn_torch_policy",
  89. "ImpalaTF1Policy": "impala.impala_tf_policy",
  90. "ImpalaTF2Policy": "impala.impala_tf_policy",
  91. "ImpalaTorchPolicy": "impala.impala_torch_policy",
  92. "MARWILTF1Policy": "marwil.marwil_tf_policy",
  93. "MARWILTF2Policy": "marwil.marwil_tf_policy",
  94. "MARWILTorchPolicy": "marwil.marwil_torch_policy",
  95. "SACTFPolicy": "sac.sac_tf_policy",
  96. "SACTorchPolicy": "sac.sac_torch_policy",
  97. "PPOTF1Policy": "ppo.ppo_tf_policy",
  98. "PPOTF2Policy": "ppo.ppo_tf_policy",
  99. "PPOTorchPolicy": "ppo.ppo_torch_policy",
  100. }
  101. def get_policy_class_name(policy_class: type):
  102. """Returns a string name for the provided policy class.
  103. Args:
  104. policy_class: RLlib policy class, e.g. A3CTorchPolicy, DQNTFPolicy, etc.
  105. Returns:
  106. A string name uniquely mapped to the given policy class.
  107. """
  108. # TF2 policy classes may get automatically converted into new class types
  109. # that have eager tracing capability.
  110. # These policy classes have the "_traced" postfix in their names.
  111. # When checkpointing these policy classes, we should save the name of the
  112. # original policy class instead. So that users have the choice of turning
  113. # on eager tracing during inference time.
  114. name = re.sub("_traced$", "", policy_class.__name__)
  115. if name in POLICIES:
  116. return name
  117. return None
  118. def get_policy_class(name: str):
  119. """Return an actual policy class given the string name.
  120. Args:
  121. name: string name of the policy class.
  122. Returns:
  123. Actual policy class for the given name.
  124. """
  125. if name not in POLICIES:
  126. return None
  127. path = POLICIES[name]
  128. module = importlib.import_module("ray.rllib.algorithms." + path)
  129. if not hasattr(module, name):
  130. return None
  131. return getattr(module, name)