from_config.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. import importlib
  2. import json
  3. import os
  4. import re
  5. from copy import deepcopy
  6. from functools import partial
  7. from typing import TYPE_CHECKING, Optional
  8. import yaml
  9. from ray.rllib.utils import force_list, merge_dicts
  10. from ray.rllib.utils.annotations import DeveloperAPI
  11. if TYPE_CHECKING:
  12. from ray.rllib.utils.typing import FromConfigSpec
  13. @DeveloperAPI
  14. def from_config(cls, config: Optional["FromConfigSpec"] = None, **kwargs):
  15. """Uses the given config to create an object.
  16. If `config` is a dict, an optional "type" key can be used as a
  17. "constructor hint" to specify a certain class of the object.
  18. If `config` is not a dict, `config`'s value is used directly as this
  19. "constructor hint".
  20. The rest of `config` (if it's a dict) will be used as kwargs for the
  21. constructor. Additional keys in **kwargs will always have precedence
  22. (overwrite keys in `config` (if a dict)).
  23. Also, if the config-dict or **kwargs contains the special key "_args",
  24. it will be popped from the dict and used as *args list to be passed
  25. separately to the constructor.
  26. The following constructor hints are valid:
  27. - None: Use `cls` as constructor.
  28. - An already instantiated object: Will be returned as is; no
  29. constructor call.
  30. - A string or an object that is a key in `cls`'s `__type_registry__`
  31. dict: The value in `__type_registry__` for that key will be used
  32. as the constructor.
  33. - A python callable: Use that very callable as constructor.
  34. - A string: Either a json/yaml filename or the name of a python
  35. module+class (e.g. "ray.rllib. [...] .[some class name]")
  36. Args:
  37. cls: The class to build an instance for (from `config`).
  38. config (Optional[dict, str]): The config dict or type-string or
  39. filename.
  40. Keyword Args:
  41. kwargs: Optional possibility to pass the constructor arguments in
  42. here and use `config` as the type-only info. Then we can call
  43. this like: from_config([type]?, [**kwargs for constructor])
  44. If `config` is already a dict, then `kwargs` will be merged
  45. with `config` (overwriting keys in `config`) after "type" has
  46. been popped out of `config`.
  47. If a constructor of a Configurable needs *args, the special
  48. key `_args` can be passed inside `kwargs` with a list value
  49. (e.g. kwargs={"_args": [arg1, arg2, arg3]}).
  50. Returns:
  51. any: The object generated from the config.
  52. """
  53. # `cls` is the config (config is None).
  54. if config is None and isinstance(cls, (dict, str)):
  55. config = cls
  56. cls = None
  57. # `config` is already a created object of this class ->
  58. # Take it as is.
  59. elif isinstance(cls, type) and isinstance(config, cls):
  60. return config
  61. # `type_`: Indicator for the Configurable's constructor.
  62. # `ctor_args`: *args arguments for the constructor.
  63. # `ctor_kwargs`: **kwargs arguments for the constructor.
  64. # Try to copy, so caller can reuse safely.
  65. try:
  66. config = deepcopy(config)
  67. except Exception:
  68. pass
  69. if isinstance(config, dict):
  70. type_ = config.pop("type", None)
  71. if type_ is None and isinstance(cls, str):
  72. type_ = cls
  73. ctor_kwargs = config
  74. # Give kwargs priority over things defined in config dict.
  75. # This way, one can pass a generic `spec` and then override single
  76. # constructor parameters via the kwargs in the call to `from_config`.
  77. ctor_kwargs.update(kwargs)
  78. else:
  79. type_ = config
  80. if type_ is None and "type" in kwargs:
  81. type_ = kwargs.pop("type")
  82. ctor_kwargs = kwargs
  83. # Special `_args` field in kwargs for *args-utilizing constructors.
  84. ctor_args = force_list(ctor_kwargs.pop("_args", []))
  85. # Figure out the actual constructor (class) from `type_`.
  86. # None: Try __default__object (if no args/kwargs), only then
  87. # constructor of cls (using args/kwargs).
  88. if type_ is None:
  89. # We have a default constructor that was defined directly by cls
  90. # (not by its children).
  91. if (
  92. cls is not None
  93. and hasattr(cls, "__default_constructor__")
  94. and cls.__default_constructor__ is not None
  95. and ctor_args == []
  96. and (
  97. not hasattr(cls.__bases__[0], "__default_constructor__")
  98. or cls.__bases__[0].__default_constructor__ is None
  99. or cls.__bases__[0].__default_constructor__
  100. is not cls.__default_constructor__
  101. )
  102. ):
  103. constructor = cls.__default_constructor__
  104. # Default constructor's keywords into ctor_kwargs.
  105. if isinstance(constructor, partial):
  106. kwargs = merge_dicts(ctor_kwargs, constructor.keywords)
  107. constructor = partial(constructor.func, **kwargs)
  108. ctor_kwargs = {} # erase to avoid duplicate kwarg error
  109. # No default constructor -> Try cls itself as constructor.
  110. else:
  111. constructor = cls
  112. # Try the __type_registry__ of this class.
  113. else:
  114. constructor = _lookup_type(cls, type_)
  115. # Found in cls.__type_registry__.
  116. if constructor is not None:
  117. pass
  118. # type_ is False or None (and this value is not registered) ->
  119. # return value of type_.
  120. elif type_ is False or type_ is None:
  121. return type_
  122. # Python callable.
  123. elif callable(type_):
  124. constructor = type_
  125. # A string: Filename or a python module+class or a json/yaml str.
  126. elif isinstance(type_, str):
  127. if re.search("\\.(yaml|yml|json)$", type_):
  128. return from_file(cls, type_, *ctor_args, **ctor_kwargs)
  129. # Try un-json/un-yaml'ing the string into a dict.
  130. obj = yaml.safe_load(type_)
  131. if isinstance(obj, dict):
  132. return from_config(cls, obj)
  133. try:
  134. obj = from_config(cls, json.loads(type_))
  135. except json.JSONDecodeError:
  136. pass
  137. else:
  138. return obj
  139. # Test for absolute module.class path specifier.
  140. if type_.find(".") != -1:
  141. module_name, function_name = type_.rsplit(".", 1)
  142. try:
  143. module = importlib.import_module(module_name)
  144. constructor = getattr(module, function_name)
  145. # Module not found.
  146. except (ModuleNotFoundError, ImportError, AttributeError):
  147. pass
  148. # If constructor still not found, try attaching cls' module,
  149. # then look for type_ in there.
  150. if constructor is None:
  151. if isinstance(cls, str):
  152. # Module found, but doesn't have the specified
  153. # c'tor/function.
  154. raise ValueError(
  155. f"Full classpath specifier ({type_}) must be a valid "
  156. "full [module].[class] string! E.g.: "
  157. "`my.cool.module.MyCoolClass`."
  158. )
  159. try:
  160. module = importlib.import_module(cls.__module__)
  161. constructor = getattr(module, type_)
  162. except (ModuleNotFoundError, ImportError, AttributeError):
  163. # Try the package as well.
  164. try:
  165. package_name = importlib.import_module(
  166. cls.__module__
  167. ).__package__
  168. module = __import__(package_name, fromlist=[type_])
  169. constructor = getattr(module, type_)
  170. except (ModuleNotFoundError, ImportError, AttributeError):
  171. pass
  172. if constructor is None:
  173. raise ValueError(
  174. f"String specifier ({type_}) must be a valid filename, "
  175. f"a [module].[class], a class within '{cls.__module__}', "
  176. f"or a key into {cls.__name__}.__type_registry__!"
  177. )
  178. if not constructor:
  179. raise TypeError("Invalid type '{}'. Cannot create `from_config`.".format(type_))
  180. # Create object with inferred constructor.
  181. try:
  182. object_ = constructor(*ctor_args, **ctor_kwargs)
  183. # Catch attempts to construct from an abstract class and return None.
  184. except TypeError as e:
  185. if re.match("Can't instantiate abstract class", e.args[0]):
  186. return None
  187. raise e # Re-raise
  188. # No sanity check for fake (lambda)-"constructors".
  189. if type(constructor).__name__ != "function":
  190. assert isinstance(
  191. object_,
  192. constructor.func if isinstance(constructor, partial) else constructor,
  193. )
  194. return object_
  195. @DeveloperAPI
  196. def from_file(cls, filename, *args, **kwargs):
  197. """
  198. Create object from config saved in filename. Expects json or yaml file.
  199. Args:
  200. filename: File containing the config (json or yaml).
  201. Returns:
  202. any: The object generated from the file.
  203. """
  204. path = os.path.join(os.getcwd(), filename)
  205. if not os.path.isfile(path):
  206. raise FileNotFoundError("File '{}' not found!".format(filename))
  207. with open(path, "rt") as fp:
  208. if path.endswith(".yaml") or path.endswith(".yml"):
  209. config = yaml.safe_load(fp)
  210. else:
  211. config = json.load(fp)
  212. # Add possible *args.
  213. config["_args"] = args
  214. return from_config(cls, config=config, **kwargs)
  215. def _lookup_type(cls, type_):
  216. if (
  217. cls is not None
  218. and hasattr(cls, "__type_registry__")
  219. and isinstance(cls.__type_registry__, dict)
  220. and (
  221. type_ in cls.__type_registry__
  222. or (
  223. isinstance(type_, str)
  224. and re.sub("[\\W_]", "", type_.lower()) in cls.__type_registry__
  225. )
  226. )
  227. ):
  228. available_class_for_type = cls.__type_registry__.get(type_)
  229. if available_class_for_type is None:
  230. available_class_for_type = cls.__type_registry__[
  231. re.sub("[\\W_]", "", type_.lower())
  232. ]
  233. return available_class_for_type
  234. return None
  235. class _NotProvided:
  236. """Singleton class to provide a "not provided" value for AlgorithmConfig signatures.
  237. Using the only instance of this class indicates that the user does NOT wish to
  238. change the value of some property.
  239. .. testcode::
  240. :skipif: True
  241. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  242. config = AlgorithmConfig()
  243. # Print out the default learning rate.
  244. print(config.lr)
  245. .. testoutput::
  246. 0.001
  247. .. testcode::
  248. :skipif: True
  249. # Print out the default `preprocessor_pref`.
  250. print(config.preprocessor_pref)
  251. .. testoutput::
  252. "deepmind"
  253. .. testcode::
  254. :skipif: True
  255. # Will only set the `preprocessor_pref` property (to None) and leave
  256. # all other properties at their default values.
  257. config.training(preprocessor_pref=None)
  258. config.preprocessor_pref is None
  259. .. testoutput::
  260. True
  261. .. testcode::
  262. :skipif: True
  263. # Still the same value (didn't touch it in the call to `.training()`.
  264. print(config.lr)
  265. .. testoutput::
  266. 0.001
  267. """
  268. class __NotProvided:
  269. pass
  270. instance = None
  271. def __init__(self):
  272. if _NotProvided.instance is None:
  273. _NotProvided.instance = _NotProvided.__NotProvided()
# Module-level sentinel: the `_NotProvided` instance created at import time.
# Use this object as default values in all method signatures of
# AlgorithmConfig, indicating that the respective property should NOT be touched
# in the call.
# NOTE(review): presumably compared by identity (`value is NotProvided`) at the
# use sites, which is why it must stay a single shared object — confirm.
NotProvided: _NotProvided = _NotProvided()