_direct_py.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. from typing import ( # noqa: UP035
  2. Any, Callable, Iterable
  3. )
  4. import numpy as np
  5. from scipy.optimize import OptimizeResult
  6. from ._constraints import old_bound_to_new, Bounds
  7. from ._direct import direct as _direct # type: ignore
  8. __all__ = ['direct']
  9. ERROR_MESSAGES = (
  10. "Number of function evaluations done is larger than maxfun={}",
  11. "Number of iterations is larger than maxiter={}",
  12. "u[i] < l[i] for some i",
  13. "maxfun is too large",
  14. "Initialization failed",
  15. "There was an error in the creation of the sample points",
  16. "An error occurred while the function was sampled",
  17. "Maximum number of levels has been reached.",
  18. "Forced stop",
  19. "Invalid arguments",
  20. "Out of memory",
  21. )
  22. SUCCESS_MESSAGES = (
  23. ("The best function value found is within a relative error={} "
  24. "of the (known) global optimum f_min"),
  25. ("The volume of the hyperrectangle containing the lowest function value "
  26. "found is below vol_tol={}"),
  27. ("The side length measure of the hyperrectangle containing the lowest "
  28. "function value found is below len_tol={}"),
  29. )
  30. def direct(
  31. func: Callable[
  32. [np.ndarray[tuple[int], np.dtype[np.float64]]],
  33. float | np.floating[Any] | np.integer[Any] | np.bool_,
  34. ],
  35. bounds: Iterable | Bounds,
  36. *,
  37. args: tuple = (),
  38. eps: float = 1e-4,
  39. maxfun: int | None = None,
  40. maxiter: int = 1000,
  41. locally_biased: bool = True,
  42. f_min: float = -np.inf,
  43. f_min_rtol: float = 1e-4,
  44. vol_tol: float = 1e-16,
  45. len_tol: float = 1e-6,
  46. callback: Callable[
  47. [np.ndarray[tuple[int], np.dtype[np.float64]]],
  48. object,
  49. ] | None = None,
  50. ) -> OptimizeResult:
  51. """
  52. Finds the global minimum of a function using the
  53. DIRECT algorithm.
  54. Parameters
  55. ----------
  56. func : callable
  57. The objective function to be minimized.
  58. ``func(x, *args) -> float``
  59. where ``x`` is a 1-D array with shape (n,) and ``args`` is a tuple of
  60. the fixed parameters needed to completely specify the function.
  61. bounds : sequence or `Bounds`
  62. Bounds for variables. There are two ways to specify the bounds:
  63. 1. Instance of `Bounds` class.
  64. 2. ``(min, max)`` pairs for each element in ``x``.
  65. args : tuple, optional
  66. Any additional fixed parameters needed to
  67. completely specify the objective function.
  68. eps : float, optional
  69. Minimal required difference of the objective function values
  70. between the current best hyperrectangle and the next potentially
  71. optimal hyperrectangle to be divided. In consequence, `eps` serves as a
  72. tradeoff between local and global search: the smaller, the more local
  73. the search becomes. Default is 1e-4.
  74. maxfun : int or None, optional
  75. Approximate upper bound on objective function evaluations.
  76. If `None`, will be automatically set to ``1000 * N`` where ``N``
  77. represents the number of dimensions. Will be capped if necessary to
  78. limit DIRECT's RAM usage to app. 1GiB. This will only occur for very
  79. high dimensional problems and excessive `max_fun`. Default is `None`.
  80. maxiter : int, optional
  81. Maximum number of iterations. Default is 1000.
  82. locally_biased : bool, optional
  83. If `True` (default), use the locally biased variant of the
  84. algorithm known as DIRECT_L. If `False`, use the original unbiased
  85. DIRECT algorithm. For hard problems with many local minima,
  86. `False` is recommended.
  87. f_min : float, optional
  88. Function value of the global optimum. Set this value only if the
  89. global optimum is known. Default is ``-np.inf``, so that this
  90. termination criterion is deactivated.
  91. f_min_rtol : float, optional
  92. Terminate the optimization once the relative error between the
  93. current best minimum `f` and the supplied global minimum `f_min`
  94. is smaller than `f_min_rtol`. This parameter is only used if
  95. `f_min` is also set. Must lie between 0 and 1. Default is 1e-4.
  96. vol_tol : float, optional
  97. Terminate the optimization once the volume of the hyperrectangle
  98. containing the lowest function value is smaller than `vol_tol`
  99. of the complete search space. Must lie between 0 and 1.
  100. Default is 1e-16.
  101. len_tol : float, optional
  102. If ``locally_biased=True``, terminate the optimization once half of
  103. the normalized maximal side length of the hyperrectangle containing
  104. the lowest function value is smaller than `len_tol`.
  105. If ``locally_biased=False``, terminate the optimization once half of
  106. the normalized diagonal of the hyperrectangle containing the lowest
  107. function value is smaller than `len_tol`. Must lie between 0 and 1.
  108. Default is 1e-6.
  109. callback : callable, optional
  110. A callback function with signature ``callback(xk)`` where ``xk``
  111. represents the best function value found so far.
  112. Returns
  113. -------
  114. res : OptimizeResult
  115. The optimization result represented as a ``OptimizeResult`` object.
  116. Important attributes are: ``x`` the solution array, ``success`` a
  117. Boolean flag indicating if the optimizer exited successfully and
  118. ``message`` which describes the cause of the termination. See
  119. `OptimizeResult` for a description of other attributes.
  120. Notes
  121. -----
  122. DIviding RECTangles (DIRECT) is a deterministic global
  123. optimization algorithm capable of minimizing a black box function with
  124. its variables subject to lower and upper bound constraints by sampling
  125. potential solutions in the search space [1]_. The algorithm starts by
  126. normalising the search space to an n-dimensional unit hypercube.
  127. It samples the function at the center of this hypercube and at 2n
  128. (n is the number of variables) more points, 2 in each coordinate
  129. direction. Using these function values, DIRECT then divides the
  130. domain into hyperrectangles, each having exactly one of the sampling
  131. points as its center. In each iteration, DIRECT chooses, using the `eps`
  132. parameter which defaults to 1e-4, some of the existing hyperrectangles
  133. to be further divided. This division process continues until either the
  134. maximum number of iterations or maximum function evaluations allowed
  135. are exceeded, or the hyperrectangle containing the minimal value found
  136. so far becomes small enough. If `f_min` is specified, the optimization
  137. will stop once this function value is reached within a relative tolerance.
  138. The locally biased variant of DIRECT (originally called DIRECT_L) [2]_ is
  139. used by default. It makes the search more locally biased and more
  140. efficient for cases with only a few local minima.
  141. A note about termination criteria: `vol_tol` refers to the volume of the
  142. hyperrectangle containing the lowest function value found so far. This
  143. volume decreases exponentially with increasing dimensionality of the
  144. problem. Therefore `vol_tol` should be decreased to avoid premature
  145. termination of the algorithm for higher dimensions. This does not hold
  146. for `len_tol`: it refers either to half of the maximal side length
  147. (for ``locally_biased=True``) or half of the diagonal of the
  148. hyperrectangle (for ``locally_biased=False``).
  149. This code is based on the DIRECT 2.0.4 Fortran code by Gablonsky et al. at
  150. https://ctk.math.ncsu.edu/SOFTWARE/DIRECTv204.tar.gz .
  151. This original version was initially converted via f2c and then cleaned up
  152. and reorganized by Steven G. Johnson, August 2007, for the NLopt project.
  153. The `direct` function wraps the C implementation.
  154. .. versionadded:: 1.9.0
  155. References
  156. ----------
  157. .. [1] Jones, D.R., Perttunen, C.D. & Stuckman, B.E. Lipschitzian
  158. optimization without the Lipschitz constant. J Optim Theory Appl
  159. 79, 157-181 (1993).
  160. .. [2] Gablonsky, J., Kelley, C. A Locally-Biased form of the DIRECT
  161. Algorithm. Journal of Global Optimization 21, 27-37 (2001).
  162. Examples
  163. --------
  164. The following example is a 2-D problem with four local minima: minimizing
  165. the Styblinski-Tang function
  166. (https://en.wikipedia.org/wiki/Test_functions_for_optimization).
  167. >>> from scipy.optimize import direct, Bounds
  168. >>> def styblinski_tang(pos):
  169. ... x, y = pos
  170. ... return 0.5 * (x**4 - 16*x**2 + 5*x + y**4 - 16*y**2 + 5*y)
  171. >>> bounds = Bounds([-4., -4.], [4., 4.])
  172. >>> result = direct(styblinski_tang, bounds)
  173. >>> result.x, result.fun, result.nfev
  174. array([-2.90321597, -2.90321597]), -78.3323279095383, 2011
  175. The correct global minimum was found but with a huge number of function
  176. evaluations (2011). Loosening the termination tolerances `vol_tol` and
  177. `len_tol` can be used to stop DIRECT earlier.
  178. >>> result = direct(styblinski_tang, bounds, len_tol=1e-3)
  179. >>> result.x, result.fun, result.nfev
  180. array([-2.9044353, -2.9044353]), -78.33230330754142, 207
  181. """
  182. # convert bounds to new Bounds class if necessary
  183. if not isinstance(bounds, Bounds):
  184. if isinstance(bounds, list) or isinstance(bounds, tuple):
  185. lb, ub = old_bound_to_new(bounds)
  186. bounds = Bounds(lb, ub)
  187. else:
  188. message = ("bounds must be a sequence or "
  189. "instance of Bounds class")
  190. raise ValueError(message)
  191. lb = np.ascontiguousarray(bounds.lb, dtype=np.float64)
  192. ub = np.ascontiguousarray(bounds.ub, dtype=np.float64)
  193. # validate bounds
  194. # check that lower bounds are smaller than upper bounds
  195. if not np.all(lb < ub):
  196. raise ValueError('Bounds are not consistent min < max')
  197. # check for infs
  198. if (np.any(np.isinf(lb)) or np.any(np.isinf(ub))):
  199. raise ValueError("Bounds must not be inf.")
  200. # validate tolerances
  201. if (vol_tol < 0 or vol_tol > 1):
  202. raise ValueError("vol_tol must be between 0 and 1.")
  203. if (len_tol < 0 or len_tol > 1):
  204. raise ValueError("len_tol must be between 0 and 1.")
  205. if (f_min_rtol < 0 or f_min_rtol > 1):
  206. raise ValueError("f_min_rtol must be between 0 and 1.")
  207. # validate maxfun and maxiter
  208. if maxfun is None:
  209. maxfun = 1000 * lb.shape[0]
  210. if not isinstance(maxfun, int):
  211. raise ValueError("maxfun must be of type int.")
  212. if maxfun < 0:
  213. raise ValueError("maxfun must be > 0.")
  214. if not isinstance(maxiter, int):
  215. raise ValueError("maxiter must be of type int.")
  216. if maxiter < 0:
  217. raise ValueError("maxiter must be > 0.")
  218. # validate boolean parameters
  219. if not isinstance(locally_biased, bool):
  220. raise ValueError("locally_biased must be True or False.")
  221. def _func_wrap(x, args=None):
  222. x = np.asarray(x)
  223. if args is None:
  224. f = func(x)
  225. else:
  226. f = func(x, *args)
  227. # always return a float
  228. return np.asarray(f).item()
  229. # TODO: fix disp argument
  230. x, fun, ret_code, nfev, nit = _direct(
  231. _func_wrap,
  232. np.asarray(lb), np.asarray(ub),
  233. args,
  234. False, eps, maxfun, maxiter,
  235. locally_biased,
  236. f_min, f_min_rtol,
  237. vol_tol, len_tol, callback
  238. )
  239. format_val = (maxfun, maxiter, f_min_rtol, vol_tol, len_tol)
  240. if ret_code > 2:
  241. message = SUCCESS_MESSAGES[ret_code - 3].format(
  242. format_val[ret_code - 1])
  243. elif 0 < ret_code <= 2:
  244. message = ERROR_MESSAGES[ret_code - 1].format(format_val[ret_code - 1])
  245. elif 0 > ret_code > -100:
  246. message = ERROR_MESSAGES[abs(ret_code) + 1]
  247. else:
  248. message = ERROR_MESSAGES[ret_code + 99]
  249. return OptimizeResult(x=np.asarray(x), fun=fun, status=ret_code,
  250. success=ret_code > 2, message=message,
  251. nfev=nfev, nit=nit)