sample.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767
  1. import logging
  2. import warnings
  3. from copy import copy
  4. from inspect import signature
  5. from math import isclose
  6. from typing import Any, Callable, Dict, List, Optional, Sequence, Union
  7. import numpy as np
  8. # Backwards compatibility
  9. from ray.util.annotations import DeveloperAPI, PublicAPI, RayDeprecationWarning
  10. try:
  11. # Added in numpy>=1.17 but we require numpy>=1.16
  12. np_random_generator = np.random.Generator
  13. LEGACY_RNG = False
  14. except AttributeError:
  15. class np_random_generator:
  16. pass
  17. LEGACY_RNG = True
  18. logger = logging.getLogger(__name__)
  19. _MISSING = object() # Sentinel for missing parameters.
  20. def _warn_for_base() -> None:
  21. warnings.warn(
  22. (
  23. "The `base` argument is deprecated. "
  24. "Please remove it as it is not actually needed in this method."
  25. ),
  26. RayDeprecationWarning,
  27. stacklevel=2,
  28. )
  29. class _BackwardsCompatibleNumpyRng:
  30. """Thin wrapper to ensure backwards compatibility between
  31. new and old numpy randomness generators.
  32. """
  33. _rng = None
  34. def __init__(
  35. self,
  36. generator_or_seed: Optional[
  37. Union["np_random_generator", np.random.RandomState, int]
  38. ] = None,
  39. ):
  40. if generator_or_seed is None or isinstance(
  41. generator_or_seed, (np.random.RandomState, np_random_generator)
  42. ):
  43. self._rng = generator_or_seed
  44. elif LEGACY_RNG:
  45. self._rng = np.random.RandomState(generator_or_seed)
  46. else:
  47. self._rng = np.random.default_rng(generator_or_seed)
  48. @property
  49. def legacy_rng(self) -> bool:
  50. return not isinstance(self._rng, np_random_generator)
  51. @property
  52. def rng(self):
  53. # don't set self._rng to np.random to avoid picking issues
  54. return self._rng if self._rng is not None else np.random
  55. def __getattr__(self, name: str) -> Any:
  56. # https://numpy.org/doc/stable/reference/random/new-or-different.html
  57. if self.legacy_rng:
  58. if name == "integers":
  59. name = "randint"
  60. elif name == "random":
  61. name = "rand"
  62. return getattr(self.rng, name)
  63. RandomState = Union[
  64. None, _BackwardsCompatibleNumpyRng, np_random_generator, np.random.RandomState, int
  65. ]
  66. @DeveloperAPI
  67. class Domain:
  68. """Base class to specify a type and valid range to sample parameters from.
  69. This base class is implemented by parameter spaces, like float ranges
  70. (``Float``), integer ranges (``Integer``), or categorical variables
  71. (``Categorical``). The ``Domain`` object contains information about
  72. valid values (e.g. minimum and maximum values), and exposes methods that
  73. allow specification of specific samplers (e.g. ``uniform()`` or
  74. ``loguniform()``).
  75. """
  76. sampler = None
  77. default_sampler_cls = None
  78. def cast(self, value):
  79. """Cast value to domain type"""
  80. return value
  81. def set_sampler(self, sampler, allow_override=False):
  82. if self.sampler and not allow_override:
  83. raise ValueError(
  84. "You can only choose one sampler for parameter "
  85. "domains. Existing sampler for parameter {}: "
  86. "{}. Tried to add {}".format(
  87. self.__class__.__name__, self.sampler, sampler
  88. )
  89. )
  90. self.sampler = sampler
  91. def get_sampler(self):
  92. sampler = self.sampler
  93. if not sampler:
  94. sampler = self.default_sampler_cls()
  95. return sampler
  96. def sample(
  97. self,
  98. config: Optional[Union[List[Dict], Dict]] = None,
  99. size: int = 1,
  100. random_state: "RandomState" = None,
  101. ):
  102. if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
  103. random_state = _BackwardsCompatibleNumpyRng(random_state)
  104. sampler = self.get_sampler()
  105. return sampler.sample(self, config=config, size=size, random_state=random_state)
  106. def is_grid(self):
  107. return isinstance(self.sampler, Grid)
  108. def is_function(self):
  109. return False
  110. def is_valid(self, value: Any):
  111. """Returns True if `value` is a valid value in this domain."""
  112. raise NotImplementedError
  113. @property
  114. def domain_str(self):
  115. return "(unknown)"
  116. @DeveloperAPI
  117. class Sampler:
  118. def sample(
  119. self,
  120. domain: Domain,
  121. config: Optional[Union[List[Dict], Dict]] = None,
  122. size: int = 1,
  123. random_state: "RandomState" = None,
  124. ):
  125. raise NotImplementedError
  126. @DeveloperAPI
  127. class BaseSampler(Sampler):
  128. def __str__(self):
  129. return "Base"
  130. @DeveloperAPI
  131. class Uniform(Sampler):
  132. def __str__(self):
  133. return "Uniform"
  134. @DeveloperAPI
  135. class LogUniform(Sampler):
  136. def __init__(self, base: object = _MISSING):
  137. if base is not _MISSING:
  138. _warn_for_base()
  139. def __str__(self):
  140. return "LogUniform"
  141. @DeveloperAPI
  142. class Normal(Sampler):
  143. def __init__(self, mean: float = 0.0, sd: float = 0.0):
  144. self.mean = mean
  145. self.sd = sd
  146. assert self.sd > 0, "SD has to be strictly greater than 0"
  147. def __str__(self):
  148. return "Normal"
  149. @DeveloperAPI
  150. class Grid(Sampler):
  151. """Dummy sampler used for grid search"""
  152. def sample(
  153. self,
  154. domain: Domain,
  155. config: Optional[Union[List[Dict], Dict]] = None,
  156. size: int = 1,
  157. random_state: "RandomState" = None,
  158. ):
  159. return RuntimeError("Do not call `sample()` on grid.")
  160. @DeveloperAPI
  161. class Float(Domain):
  162. class _Uniform(Uniform):
  163. def sample(
  164. self,
  165. domain: "Float",
  166. config: Optional[Union[List[Dict], Dict]] = None,
  167. size: int = 1,
  168. random_state: "RandomState" = None,
  169. ):
  170. if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
  171. random_state = _BackwardsCompatibleNumpyRng(random_state)
  172. assert domain.lower > float("-inf"), "Uniform needs a lower bound"
  173. assert domain.upper < float("inf"), "Uniform needs a upper bound"
  174. items = random_state.uniform(domain.lower, domain.upper, size=size)
  175. return items if len(items) > 1 else domain.cast(items[0])
  176. class _LogUniform(LogUniform):
  177. def sample(
  178. self,
  179. domain: "Float",
  180. config: Optional[Union[List[Dict], Dict]] = None,
  181. size: int = 1,
  182. random_state: "RandomState" = None,
  183. ):
  184. if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
  185. random_state = _BackwardsCompatibleNumpyRng(random_state)
  186. assert domain.lower > 0, "LogUniform needs a lower bound greater than 0"
  187. assert (
  188. 0 < domain.upper < float("inf")
  189. ), "LogUniform needs a upper bound greater than 0"
  190. logmin = np.log(domain.lower)
  191. logmax = np.log(domain.upper)
  192. items = np.exp(random_state.uniform(logmin, logmax, size=size))
  193. return items if len(items) > 1 else domain.cast(items[0])
  194. class _Normal(Normal):
  195. def sample(
  196. self,
  197. domain: "Float",
  198. config: Optional[Union[List[Dict], Dict]] = None,
  199. size: int = 1,
  200. random_state: "RandomState" = None,
  201. ):
  202. if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
  203. random_state = _BackwardsCompatibleNumpyRng(random_state)
  204. assert not domain.lower or domain.lower == float(
  205. "-inf"
  206. ), "Normal sampling does not allow a lower value bound."
  207. assert not domain.upper or domain.upper == float(
  208. "inf"
  209. ), "Normal sampling does not allow a upper value bound."
  210. items = random_state.normal(self.mean, self.sd, size=size)
  211. return items if len(items) > 1 else domain.cast(items[0])
  212. default_sampler_cls = _Uniform
  213. def __init__(self, lower: Optional[float], upper: Optional[float]):
  214. # Need to explicitly check for None
  215. self.lower = lower if lower is not None else float("-inf")
  216. self.upper = upper if upper is not None else float("inf")
  217. def cast(self, value):
  218. return float(value)
  219. def uniform(self):
  220. if not self.lower > float("-inf"):
  221. raise ValueError(
  222. "Uniform requires a lower bound. Make sure to set the "
  223. "`lower` parameter of `Float()`."
  224. )
  225. if not self.upper < float("inf"):
  226. raise ValueError(
  227. "Uniform requires a upper bound. Make sure to set the "
  228. "`upper` parameter of `Float()`."
  229. )
  230. new = copy(self)
  231. new.set_sampler(self._Uniform())
  232. return new
  233. def loguniform(self, base: object = _MISSING):
  234. if base is not _MISSING:
  235. _warn_for_base()
  236. if not self.lower > 0:
  237. raise ValueError(
  238. "LogUniform requires a lower bound greater than 0."
  239. f"Got: {self.lower}. Did you pass a variable that has "
  240. "been log-transformed? If so, pass the non-transformed value "
  241. "instead."
  242. )
  243. if not 0 < self.upper < float("inf"):
  244. raise ValueError(
  245. "LogUniform requires a upper bound greater than 0. "
  246. f"Got: {self.lower}. Did you pass a variable that has "
  247. "been log-transformed? If so, pass the non-transformed value "
  248. "instead."
  249. )
  250. new = copy(self)
  251. new.set_sampler(self._LogUniform())
  252. return new
  253. def normal(self, mean=0.0, sd=1.0):
  254. new = copy(self)
  255. new.set_sampler(self._Normal(mean, sd))
  256. return new
  257. def quantized(self, q: float):
  258. if self.lower > float("-inf") and not isclose(
  259. self.lower / q, round(self.lower / q)
  260. ):
  261. raise ValueError(
  262. f"Your lower variable bound {self.lower} is not divisible by "
  263. f"quantization factor {q}."
  264. )
  265. if self.upper < float("inf") and not isclose(
  266. self.upper / q, round(self.upper / q)
  267. ):
  268. raise ValueError(
  269. f"Your upper variable bound {self.upper} is not divisible by "
  270. f"quantization factor {q}."
  271. )
  272. new = copy(self)
  273. new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
  274. return new
  275. def is_valid(self, value: float):
  276. return self.lower <= value <= self.upper
  277. @property
  278. def domain_str(self):
  279. return f"({self.lower}, {self.upper})"
  280. @DeveloperAPI
  281. class Integer(Domain):
  282. class _Uniform(Uniform):
  283. def sample(
  284. self,
  285. domain: "Integer",
  286. config: Optional[Union[List[Dict], Dict]] = None,
  287. size: int = 1,
  288. random_state: "RandomState" = None,
  289. ):
  290. if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
  291. random_state = _BackwardsCompatibleNumpyRng(random_state)
  292. items = random_state.integers(domain.lower, domain.upper, size=size)
  293. return items if len(items) > 1 else domain.cast(items[0])
  294. class _LogUniform(LogUniform):
  295. def sample(
  296. self,
  297. domain: "Integer",
  298. config: Optional[Union[List[Dict], Dict]] = None,
  299. size: int = 1,
  300. random_state: "RandomState" = None,
  301. ):
  302. if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
  303. random_state = _BackwardsCompatibleNumpyRng(random_state)
  304. assert domain.lower > 0, "LogUniform needs a lower bound greater than 0"
  305. assert (
  306. 0 < domain.upper < float("inf")
  307. ), "LogUniform needs a upper bound greater than 0"
  308. logmin = np.log(domain.lower)
  309. logmax = np.log(domain.upper)
  310. items = np.exp(random_state.uniform(logmin, logmax, size=size))
  311. items = np.floor(items).astype(int)
  312. return items if len(items) > 1 else domain.cast(items[0])
  313. default_sampler_cls = _Uniform
  314. def __init__(self, lower, upper):
  315. self.lower = lower
  316. self.upper = upper
  317. def cast(self, value):
  318. return int(value)
  319. def quantized(self, q: int):
  320. new = copy(self)
  321. new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
  322. return new
  323. def uniform(self):
  324. new = copy(self)
  325. new.set_sampler(self._Uniform())
  326. return new
  327. def loguniform(self, base: object = _MISSING):
  328. if base is not _MISSING:
  329. _warn_for_base()
  330. if not self.lower > 0:
  331. raise ValueError(
  332. "LogUniform requires a lower bound greater than 0."
  333. f"Got: {self.lower}. Did you pass a variable that has "
  334. "been log-transformed? If so, pass the non-transformed value "
  335. "instead."
  336. )
  337. if not 0 < self.upper < float("inf"):
  338. raise ValueError(
  339. "LogUniform requires a upper bound greater than 0. "
  340. f"Got: {self.lower}. Did you pass a variable that has "
  341. "been log-transformed? If so, pass the non-transformed value "
  342. "instead."
  343. )
  344. new = copy(self)
  345. new.set_sampler(self._LogUniform())
  346. return new
  347. def is_valid(self, value: int):
  348. return self.lower <= value <= self.upper
  349. @property
  350. def domain_str(self):
  351. return f"({self.lower}, {self.upper})"
  352. @DeveloperAPI
  353. class Categorical(Domain):
  354. class _Uniform(Uniform):
  355. def sample(
  356. self,
  357. domain: "Categorical",
  358. config: Optional[Union[List[Dict], Dict]] = None,
  359. size: int = 1,
  360. random_state: "RandomState" = None,
  361. ):
  362. if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
  363. random_state = _BackwardsCompatibleNumpyRng(random_state)
  364. # do not use .choice() directly on domain.categories
  365. # as that will coerce them to a single dtype
  366. indices = random_state.choice(
  367. np.arange(0, len(domain.categories)), size=size
  368. )
  369. items = [domain.categories[index] for index in indices]
  370. return items if len(items) > 1 else domain.cast(items[0])
  371. default_sampler_cls = _Uniform
  372. def __init__(self, categories: Sequence):
  373. self.categories = list(categories)
  374. def uniform(self):
  375. new = copy(self)
  376. new.set_sampler(self._Uniform())
  377. return new
  378. def grid(self):
  379. new = copy(self)
  380. new.set_sampler(Grid())
  381. return new
  382. def __len__(self):
  383. return len(self.categories)
  384. def __getitem__(self, item):
  385. return self.categories[item]
  386. def is_valid(self, value: Any):
  387. return value in self.categories
  388. @property
  389. def domain_str(self):
  390. return f"{self.categories}"
  391. @DeveloperAPI
  392. class Function(Domain):
  393. class _CallSampler(BaseSampler):
  394. def __try_fn(self, domain: "Function", config: Dict[str, Any]):
  395. try:
  396. return domain.func(config)
  397. except (AttributeError, KeyError):
  398. from ray.tune.search.variant_generator import _UnresolvedAccessGuard
  399. r = domain.func(_UnresolvedAccessGuard({"config": config}))
  400. logger.warning(
  401. "sample_from functions that take a spec dict are "
  402. "deprecated. Please update your function to work with "
  403. "the config dict directly."
  404. )
  405. return r
  406. def sample(
  407. self,
  408. domain: "Function",
  409. config: Optional[Union[List[Dict], Dict]] = None,
  410. size: int = 1,
  411. random_state: "RandomState" = None,
  412. ):
  413. if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
  414. random_state = _BackwardsCompatibleNumpyRng(random_state)
  415. if domain.pass_config:
  416. items = [
  417. (
  418. self.__try_fn(domain, config[i])
  419. if isinstance(config, list)
  420. else self.__try_fn(domain, config)
  421. )
  422. for i in range(size)
  423. ]
  424. else:
  425. items = [domain.func() for i in range(size)]
  426. return items if len(items) > 1 else domain.cast(items[0])
  427. default_sampler_cls = _CallSampler
  428. def __init__(self, func: Callable):
  429. sig = signature(func)
  430. pass_config = True # whether we should pass `config` when calling `func`
  431. try:
  432. sig.bind({})
  433. except TypeError:
  434. pass_config = False
  435. if not pass_config:
  436. try:
  437. sig.bind()
  438. except TypeError as exc:
  439. raise ValueError(
  440. "The function passed to a `Function` parameter must be "
  441. "callable with either 0 or 1 parameters."
  442. ) from exc
  443. self.pass_config = pass_config
  444. self.func = func
  445. def is_function(self):
  446. return True
  447. def is_valid(self, value: Any):
  448. return True # This is user-defined, so lets not assume anything
  449. @property
  450. def domain_str(self):
  451. return f"{self.func}()"
  452. @DeveloperAPI
  453. class Quantized(Sampler):
  454. def __init__(self, sampler: Sampler, q: Union[float, int]):
  455. self.sampler = sampler
  456. self.q = q
  457. assert self.sampler, "Quantized() expects a sampler instance"
  458. def get_sampler(self):
  459. return self.sampler
  460. def sample(
  461. self,
  462. domain: Domain,
  463. config: Optional[Union[List[Dict], Dict]] = None,
  464. size: int = 1,
  465. random_state: "RandomState" = None,
  466. ):
  467. if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
  468. random_state = _BackwardsCompatibleNumpyRng(random_state)
  469. if self.q == 1:
  470. return self.sampler.sample(domain, config, size, random_state=random_state)
  471. quantized_domain = copy(domain)
  472. quantized_domain.lower = np.ceil(domain.lower / self.q) * self.q
  473. quantized_domain.upper = np.floor(domain.upper / self.q) * self.q
  474. values = self.sampler.sample(
  475. quantized_domain, config, size, random_state=random_state
  476. )
  477. quantized = np.round(np.divide(values, self.q)) * self.q
  478. if not isinstance(quantized, np.ndarray):
  479. return domain.cast(quantized)
  480. return list(quantized)
  481. @PublicAPI
  482. def sample_from(func: Callable[[Dict], Any]):
  483. """Specify that tune should sample configuration values from this function.
  484. Arguments:
  485. func: An callable function to draw a sample from.
  486. """
  487. return Function(func)
  488. @PublicAPI
  489. def uniform(lower: float, upper: float):
  490. """Sample a float value uniformly between ``lower`` and ``upper``.
  491. Sampling from ``tune.uniform(1, 10)`` is equivalent to sampling from
  492. ``np.random.uniform(1, 10))``
  493. """
  494. return Float(lower, upper).uniform()
  495. @PublicAPI
  496. def quniform(lower: float, upper: float, q: float):
  497. """Sample a quantized float value uniformly between ``lower`` and ``upper``.
  498. Sampling from ``tune.uniform(1, 10)`` is equivalent to sampling from
  499. ``np.random.uniform(1, 10))``
  500. The value will be quantized, i.e. rounded to an integer increment of ``q``.
  501. Quantization makes the upper bound inclusive.
  502. """
  503. return Float(lower, upper).uniform().quantized(q)
  504. @PublicAPI
  505. def loguniform(lower: float, upper: float, base: object = _MISSING):
  506. """Sugar for sampling in different orders of magnitude.
  507. Args:
  508. lower: Lower boundary of the output interval (e.g. 1e-4)
  509. upper: Upper boundary of the output interval (e.g. 1e-2)
  510. """
  511. if base is not _MISSING:
  512. _warn_for_base()
  513. return Float(lower, upper).loguniform()
  514. @PublicAPI
  515. def qloguniform(lower: float, upper: float, q: float, base: object = _MISSING):
  516. """Sugar for sampling in different orders of magnitude.
  517. The value will be quantized, i.e. rounded to an integer increment of ``q``.
  518. Quantization makes the upper bound inclusive.
  519. Args:
  520. lower: Lower boundary of the output interval (e.g. 1e-4)
  521. upper: Upper boundary of the output interval (e.g. 1e-2)
  522. q: Quantization number. The result will be rounded to an
  523. integer increment of this value.
  524. """
  525. if base is not _MISSING:
  526. _warn_for_base()
  527. return Float(lower, upper).loguniform().quantized(q)
  528. @PublicAPI
  529. def choice(categories: Sequence):
  530. """Sample a categorical value.
  531. Sampling from ``tune.choice([1, 2])`` is equivalent to sampling from
  532. ``np.random.choice([1, 2])``
  533. """
  534. return Categorical(categories).uniform()
  535. @PublicAPI
  536. def randint(lower: int, upper: int):
  537. """Sample an integer value uniformly between ``lower`` and ``upper``.
  538. ``lower`` is inclusive, ``upper`` is exclusive.
  539. Sampling from ``tune.randint(10)`` is equivalent to sampling from
  540. ``np.random.randint(10)``
  541. .. versionchanged:: 1.5.0
  542. When converting Ray Tune configs to searcher-specific search spaces,
  543. the lower and upper limits are adjusted to keep compatibility with
  544. the bounds stated in the docstring above.
  545. """
  546. return Integer(lower, upper).uniform()
  547. @PublicAPI
  548. def lograndint(lower: int, upper: int, base: object = _MISSING):
  549. """Sample an integer value log-uniformly between ``lower`` and ``upper``.
  550. ``lower`` is inclusive, ``upper`` is exclusive.
  551. .. versionchanged:: 1.5.0
  552. When converting Ray Tune configs to searcher-specific search spaces,
  553. the lower and upper limits are adjusted to keep compatibility with
  554. the bounds stated in the docstring above.
  555. """
  556. if base is not _MISSING:
  557. _warn_for_base()
  558. return Integer(lower, upper).loguniform()
  559. @PublicAPI
  560. def qrandint(lower: int, upper: int, q: int = 1):
  561. """Sample an integer value uniformly between ``lower`` and ``upper``.
  562. ``lower`` is inclusive, ``upper`` is also inclusive (!).
  563. The value will be quantized, i.e. rounded to an integer increment of ``q``.
  564. Quantization makes the upper bound inclusive.
  565. .. versionchanged:: 1.5.0
  566. When converting Ray Tune configs to searcher-specific search spaces,
  567. the lower and upper limits are adjusted to keep compatibility with
  568. the bounds stated in the docstring above.
  569. """
  570. return Integer(lower, upper).uniform().quantized(q)
  571. @PublicAPI
  572. def qlograndint(lower: int, upper: int, q: int, base: object = _MISSING):
  573. """Sample an integer value log-uniformly between ``lower`` and ``upper``.
  574. ``lower`` is inclusive, ``upper`` is also inclusive (!).
  575. The value will be quantized, i.e. rounded to an integer increment of ``q``.
  576. Quantization makes the upper bound inclusive.
  577. .. versionchanged:: 1.5.0
  578. When converting Ray Tune configs to searcher-specific search spaces,
  579. the lower and upper limits are adjusted to keep compatibility with
  580. the bounds stated in the docstring above.
  581. """
  582. if base is not _MISSING:
  583. _warn_for_base()
  584. return Integer(lower, upper).loguniform().quantized(q)
  585. @PublicAPI
  586. def randn(mean: float = 0.0, sd: float = 1.0):
  587. """Sample a float value normally with ``mean`` and ``sd``.
  588. Args:
  589. mean: Mean of the normal distribution. Defaults to 0.
  590. sd: SD of the normal distribution. Defaults to 1.
  591. """
  592. return Float(None, None).normal(mean, sd)
  593. @PublicAPI
  594. def qrandn(mean: float, sd: float, q: float):
  595. """Sample a float value normally with ``mean`` and ``sd``.
  596. The value will be quantized, i.e. rounded to an integer increment of ``q``.
  597. Args:
  598. mean: Mean of the normal distribution.
  599. sd: SD of the normal distribution.
  600. q: Quantization number. The result will be rounded to an
  601. integer increment of this value.
  602. """
  603. return Float(None, None).normal(mean, sd).quantized(q)