config_parser.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import argparse
  2. import json
  3. from ray.tune import CheckpointConfig
  4. from ray.tune.error import TuneError
  5. from ray.tune.experiment import Trial
  6. from ray.tune.resources import json_to_resources
  7. # For compatibility under py2 to consider unicode as str
  8. from ray.tune.utils.serialization import TuneFunctionEncoder
  9. from ray.tune.utils.util import SafeFallbackEncoder
  10. def _make_parser(parser_creator=None, **kwargs):
  11. """Returns a base argument parser for the ray.tune tool.
  12. Args:
  13. parser_creator: A constructor for the parser class.
  14. kwargs: Non-positional args to be passed into the
  15. parser class constructor.
  16. """
  17. if parser_creator:
  18. parser = parser_creator(**kwargs)
  19. else:
  20. parser = argparse.ArgumentParser(**kwargs)
  21. # Note: keep this in sync with rllib/train.py
  22. parser.add_argument(
  23. "--run",
  24. default=None,
  25. type=str,
  26. help="The algorithm or model to train. This may refer to the name "
  27. "of a built-on algorithm (e.g. RLlib's DQN or PPO), or a "
  28. "user-defined trainable function or class registered in the "
  29. "tune registry.",
  30. )
  31. parser.add_argument(
  32. "--stop",
  33. default="{}",
  34. type=json.loads,
  35. help="The stopping criteria, specified in JSON. The keys may be any "
  36. "field returned by 'train()' e.g. "
  37. '\'{"time_total_s": 600, "training_iteration": 100000}\' to stop '
  38. "after 600 seconds or 100k iterations, whichever is reached first.",
  39. )
  40. parser.add_argument(
  41. "--config",
  42. default="{}",
  43. type=json.loads,
  44. help="Algorithm-specific configuration (e.g. env, hyperparams), "
  45. "specified in JSON.",
  46. )
  47. parser.add_argument(
  48. "--resources-per-trial",
  49. default=None,
  50. type=json_to_resources,
  51. help="Override the machine resources to allocate per trial, e.g. "
  52. '\'{"cpu": 64, "gpu": 8}\'. Note that GPUs will not be assigned '
  53. "unless you specify them here. For RLlib, you probably want to "
  54. "leave this alone and use RLlib configs to control parallelism.",
  55. )
  56. parser.add_argument(
  57. "--num-samples",
  58. default=1,
  59. type=int,
  60. help="Number of times to repeat each trial.",
  61. )
  62. parser.add_argument(
  63. "--checkpoint-freq",
  64. default=0,
  65. type=int,
  66. help="How many training iterations between checkpoints. "
  67. "A value of 0 (default) disables checkpointing.",
  68. )
  69. parser.add_argument(
  70. "--checkpoint-at-end",
  71. action="store_true",
  72. help="Whether to checkpoint at the end of the experiment. Default is False.",
  73. )
  74. parser.add_argument(
  75. "--keep-checkpoints-num",
  76. default=None,
  77. type=int,
  78. help="Number of best checkpoints to keep. Others get "
  79. "deleted. Default (None) keeps all checkpoints.",
  80. )
  81. parser.add_argument(
  82. "--checkpoint-score-attr",
  83. default="training_iteration",
  84. type=str,
  85. help="Specifies by which attribute to rank the best checkpoint. "
  86. "Default is increasing order. If attribute starts with min- it "
  87. "will rank attribute in decreasing order. Example: "
  88. "min-validation_loss",
  89. )
  90. parser.add_argument(
  91. "--export-formats",
  92. default=None,
  93. help="List of formats that exported at the end of the experiment. "
  94. "Default is None. For RLlib, 'checkpoint' and 'model' are "
  95. "supported for TensorFlow policy graphs.",
  96. )
  97. parser.add_argument(
  98. "--max-failures",
  99. default=3,
  100. type=int,
  101. help="Try to recover a trial from its last checkpoint at least this "
  102. "many times. Only applies if checkpointing is enabled.",
  103. )
  104. parser.add_argument(
  105. "--scheduler",
  106. default="FIFO",
  107. type=str,
  108. help="FIFO (default), MedianStopping, AsyncHyperBand, "
  109. "HyperBand, or HyperOpt.",
  110. )
  111. parser.add_argument(
  112. "--scheduler-config",
  113. default="{}",
  114. type=json.loads,
  115. help="Config options to pass to the scheduler.",
  116. )
  117. # Note: this currently only makes sense when running a single trial
  118. parser.add_argument(
  119. "--restore",
  120. default=None,
  121. type=str,
  122. help="If specified, restore from this checkpoint.",
  123. )
  124. return parser
  125. def _to_argv(config):
  126. """Converts configuration to a command line argument format."""
  127. argv = []
  128. for k, v in config.items():
  129. if "-" in k:
  130. raise ValueError("Use '_' instead of '-' in `{}`".format(k))
  131. if v is None:
  132. continue
  133. if not isinstance(v, bool) or v: # for argparse flags
  134. argv.append("--{}".format(k.replace("_", "-")))
  135. if isinstance(v, str):
  136. argv.append(v)
  137. elif isinstance(v, bool):
  138. pass
  139. elif callable(v):
  140. argv.append(json.dumps(v, cls=TuneFunctionEncoder))
  141. else:
  142. argv.append(json.dumps(v, cls=SafeFallbackEncoder))
  143. return argv
  144. _cached_pgf = {}
  145. def _create_trial_from_spec(
  146. spec: dict, parser: argparse.ArgumentParser, **trial_kwargs
  147. ):
  148. """Creates a Trial object from parsing the spec.
  149. Args:
  150. spec: A resolved experiment specification. Arguments should
  151. The args here should correspond to the command line flags
  152. in ray.tune.experiment.config_parser.
  153. parser: An argument parser object from
  154. make_parser.
  155. trial_kwargs: Extra keyword arguments used in instantiating the Trial.
  156. Returns:
  157. A trial object with corresponding parameters to the specification.
  158. """
  159. global _cached_pgf
  160. spec = spec.copy()
  161. resources = spec.pop("resources_per_trial", None)
  162. try:
  163. args, _ = parser.parse_known_args(_to_argv(spec))
  164. except SystemExit:
  165. raise TuneError("Error parsing args, see above message", spec)
  166. if resources:
  167. trial_kwargs["placement_group_factory"] = resources
  168. checkpoint_config = spec.get("checkpoint_config", CheckpointConfig())
  169. return Trial(
  170. # Submitting trial via server in py2.7 creates Unicode, which does not
  171. # convert to string in a straightforward manner.
  172. trainable_name=spec["run"],
  173. # json.load leads to str -> unicode in py2.7
  174. config=spec.get("config", {}),
  175. # json.load leads to str -> unicode in py2.7
  176. stopping_criterion=spec.get("stop", {}),
  177. checkpoint_config=checkpoint_config,
  178. export_formats=spec.get("export_formats", []),
  179. # str(None) doesn't create None
  180. restore_path=spec.get("restore"),
  181. trial_name_creator=spec.get("trial_name_creator"),
  182. trial_dirname_creator=spec.get("trial_dirname_creator"),
  183. log_to_file=spec.get("log_to_file"),
  184. # str(None) doesn't create None
  185. max_failures=args.max_failures,
  186. storage=spec.get("storage"),
  187. **trial_kwargs,
  188. )