import argparse
import json

from ray.tune import CheckpointConfig
from ray.tune.error import TuneError
from ray.tune.experiment import Trial
from ray.tune.resources import json_to_resources

# For compatibility under py2 to consider unicode as str
from ray.tune.utils.serialization import TuneFunctionEncoder
from ray.tune.utils.util import SafeFallbackEncoder


def _make_parser(parser_creator=None, **kwargs):
    """Returns a base argument parser for the ray.tune tool.

    Args:
        parser_creator: A constructor for the parser class. When falsy,
            ``argparse.ArgumentParser`` is used.
        kwargs: Non-positional args to be passed into the
            parser class constructor.

    Returns:
        The parser with all ray.tune command line flags registered.
    """

    if parser_creator:
        parser = parser_creator(**kwargs)
    else:
        parser = argparse.ArgumentParser(**kwargs)

    # Note: keep this in sync with rllib/train.py
    parser.add_argument(
        "--run",
        default=None,
        type=str,
        help="The algorithm or model to train. This may refer to the name "
        "of a built-on algorithm (e.g. RLlib's DQN or PPO), or a "
        "user-defined trainable function or class registered in the "
        "tune registry.",
    )
    parser.add_argument(
        "--stop",
        default="{}",
        type=json.loads,
        help="The stopping criteria, specified in JSON. The keys may be any "
        "field returned by 'train()' e.g. "
        '\'{"time_total_s": 600, "training_iteration": 100000}\' to stop '
        "after 600 seconds or 100k iterations, whichever is reached first.",
    )
    parser.add_argument(
        "--config",
        default="{}",
        type=json.loads,
        help="Algorithm-specific configuration (e.g. env, hyperparams), "
        "specified in JSON.",
    )
    parser.add_argument(
        "--resources-per-trial",
        default=None,
        type=json_to_resources,
        help="Override the machine resources to allocate per trial, e.g. "
        '\'{"cpu": 64, "gpu": 8}\'. Note that GPUs will not be assigned '
        "unless you specify them here. For RLlib, you probably want to "
        "leave this alone and use RLlib configs to control parallelism.",
    )
    parser.add_argument(
        "--num-samples",
        default=1,
        type=int,
        help="Number of times to repeat each trial.",
    )
    parser.add_argument(
        "--checkpoint-freq",
        default=0,
        type=int,
        help="How many training iterations between checkpoints. "
        "A value of 0 (default) disables checkpointing.",
    )
    parser.add_argument(
        "--checkpoint-at-end",
        action="store_true",
        help="Whether to checkpoint at the end of the experiment. Default is False.",
    )
    parser.add_argument(
        "--keep-checkpoints-num",
        default=None,
        type=int,
        help="Number of best checkpoints to keep. Others get "
        "deleted. Default (None) keeps all checkpoints.",
    )
    parser.add_argument(
        "--checkpoint-score-attr",
        default="training_iteration",
        type=str,
        help="Specifies by which attribute to rank the best checkpoint. "
        "Default is increasing order. If attribute starts with min- it "
        "will rank attribute in decreasing order. Example: "
        "min-validation_loss",
    )
    parser.add_argument(
        "--export-formats",
        default=None,
        help="List of formats that exported at the end of the experiment. "
        "Default is None. For RLlib, 'checkpoint' and 'model' are "
        "supported for TensorFlow policy graphs.",
    )
    parser.add_argument(
        "--max-failures",
        default=3,
        type=int,
        help="Try to recover a trial from its last checkpoint at least this "
        "many times. Only applies if checkpointing is enabled.",
    )
    parser.add_argument(
        "--scheduler",
        default="FIFO",
        type=str,
        help="FIFO (default), MedianStopping, AsyncHyperBand, "
        "HyperBand, or HyperOpt.",
    )
    parser.add_argument(
        "--scheduler-config",
        default="{}",
        type=json.loads,
        help="Config options to pass to the scheduler.",
    )

    # Note: this currently only makes sense when running a single trial
    parser.add_argument(
        "--restore",
        default=None,
        type=str,
        help="If specified, restore from this checkpoint.",
    )

    return parser


def _to_argv(config):
    """Converts configuration to a command line argument format.

    Each key becomes a ``--flag`` with underscores replaced by dashes.
    Values are rendered as follows:

    * ``None``: the entry is skipped entirely.
    * ``bool``: ``True`` emits the bare flag, ``False`` emits nothing.
    * ``str``: appended verbatim after the flag.
    * callables: JSON-encoded with ``TuneFunctionEncoder``.
    * anything else: JSON-encoded with ``SafeFallbackEncoder``.

    Args:
        config: Mapping of option names (using ``_``, not ``-``) to values.

    Returns:
        A list of argv-style strings.

    Raises:
        ValueError: If a key contains ``-``.
    """
    argv = []
    for k, v in config.items():
        if "-" in k:
            raise ValueError(f"Use '_' instead of '-' in `{k}`")
        if v is None:
            continue

        flag = "--" + k.replace("_", "-")
        if isinstance(v, bool):
            # argparse store_true flags carry no value: present or absent.
            if v:
                argv.append(flag)
            continue

        argv.append(flag)
        if isinstance(v, str):
            argv.append(v)
        elif callable(v):
            argv.append(json.dumps(v, cls=TuneFunctionEncoder))
        else:
            argv.append(json.dumps(v, cls=SafeFallbackEncoder))

    return argv


_cached_pgf = {}


def _create_trial_from_spec(
    spec: dict, parser: argparse.ArgumentParser, **trial_kwargs
):
    """Creates a Trial object from parsing the spec.

    Args:
        spec: A resolved experiment specification. Its keys should
            correspond to the command line flags registered by
            ``_make_parser``. The passed dict is not modified; a shallow
            copy is used internally.
        parser: An argument parser object from make_parser.
        trial_kwargs: Extra keyword arguments used in instantiating
            the Trial.

    Returns:
        A trial object with corresponding parameters to the specification.

    Raises:
        TuneError: If the parser exits while parsing the spec.
    """
    spec = spec.copy()
    # Taken out before the argv conversion so the value is forwarded to the
    # Trial as-is instead of being serialized and re-parsed.
    resources = spec.pop("resources_per_trial", None)

    try:
        # Unknown flags are tolerated; only recognized ones land in `args`.
        args, _ = parser.parse_known_args(_to_argv(spec))
    except SystemExit as err:
        # argparse reports its own error before exiting; surface a
        # TuneError instead while keeping the original exit as the cause.
        raise TuneError("Error parsing args, see above message", spec) from err

    if resources:
        trial_kwargs["placement_group_factory"] = resources

    checkpoint_config = spec.get("checkpoint_config", CheckpointConfig())

    # Only `max_failures` is read from the parsed args (so the parser default
    # applies); every other field comes straight from the spec.
    return Trial(
        trainable_name=spec["run"],
        config=spec.get("config", {}),
        stopping_criterion=spec.get("stop", {}),
        checkpoint_config=checkpoint_config,
        export_formats=spec.get("export_formats", []),
        restore_path=spec.get("restore"),
        trial_name_creator=spec.get("trial_name_creator"),
        trial_dirname_creator=spec.get("trial_dirname_creator"),
        log_to_file=spec.get("log_to_file"),
        max_failures=args.max_failures,
        storage=spec.get("storage"),
        **trial_kwargs,
    )