test_linesearch.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. """
  2. Tests for line search routines
  3. """
  4. import warnings
  5. from numpy.testing import (assert_equal, assert_array_almost_equal,
  6. assert_array_almost_equal_nulp)
  7. import scipy.optimize._linesearch as ls
  8. from scipy.optimize._linesearch import LineSearchWarning
  9. import numpy as np
  10. import pytest
  11. import threading
  def assert_wolfe(s, phi, derphi, c1=1e-4, c2=0.9, err_msg=""):
      """
      Verify that the step ``s`` fulfils the strong Wolfe conditions.

      Parameters
      ----------
      s : step length under test.
      phi : callable returning the 1-D objective value at a step.
      derphi : callable returning the derivative of ``phi`` at a step.
      c1 : constant of the sufficient-decrease test (Wolfe 1).
      c2 : constant of the curvature test (Wolfe 2).
      err_msg : text appended to the failure message.

      Raises
      ------
      AssertionError
          When either of the two conditions is violated.
      """
      # Evaluate in a fixed order: phi(s), phi(0), phi'(0), phi'(s).
      value_at_s = phi(s)
      value_at_0 = phi(0)
      slope_at_0 = derphi(0)
      slope_at_s = derphi(s)

      details = "; ".join([
          f"s = {s}",
          f"phi(0) = {value_at_0}",
          f"phi(s) = {value_at_s}",
          f"phi'(0) = {slope_at_0}",
          f"phi'(s) = {slope_at_s}",
          f"{err_msg}",
      ])

      sufficient_decrease = value_at_s <= value_at_0 + c1*s*slope_at_0
      assert sufficient_decrease, "Wolfe 1 failed: " + details

      curvature_ok = abs(slope_at_s) <= abs(c2*slope_at_0)
      assert curvature_ok, "Wolfe 2 failed: " + details
  def assert_armijo(s, phi, c1=1e-4, err_msg=""):
      """
      Verify the decrease test ``phi(s) <= (1 - c1*s) * phi(0)``.

      Parameters
      ----------
      s : step length under test.
      phi : callable returning the 1-D objective value at a step.
      c1 : decrease constant.
      err_msg : text appended to the failure message.

      Raises
      ------
      AssertionError
          When the inequality does not hold.
      """
      value_at_s = phi(s)
      value_at_0 = phi(0)
      details = "; ".join([
          f"s = {s}",
          f"phi(0) = {value_at_0}",
          f"phi(s) = {value_at_s}",
          f"{err_msg}",
      ])
      upper_bound = (1 - c1*s)*value_at_0
      assert value_at_s <= upper_bound, details
  def assert_line_wolfe(x, p, s, f, fprime, **kw):
      """
      Check the strong Wolfe conditions for step ``s`` along ``p`` from ``x``.

      The n-d objective ``f`` and its gradient ``fprime`` are restricted to
      the line ``x + p*step``; remaining keyword arguments are forwarded to
      ``assert_wolfe``.
      """
      def restricted_value(step):
          return f(x + p*step)

      def restricted_slope(step):
          # Directional derivative of f along p.
          return np.dot(fprime(x + p*step), p)

      assert_wolfe(s, phi=restricted_value, derphi=restricted_slope, **kw)
  def assert_line_armijo(x, p, s, f, **kw):
      """
      Check the Armijo-type decrease test for step ``s`` along ``p`` from ``x``.

      ``f`` is restricted to the line ``x + p*step``; remaining keyword
      arguments are forwarded to ``assert_armijo``.
      """
      def restricted_value(step):
          return f(x + p*step)

      assert_armijo(s, phi=restricted_value, **kw)
  def assert_fp_equal(x, y, err_msg="", nulp=50):
      """Assert that ``x`` and ``y`` agree to within ``nulp`` units in the last place.

      On failure the original assertion text is re-raised with ``err_msg``
      appended on a new line.
      """
      try:
          assert_array_almost_equal_nulp(x, y, nulp=nulp)
      except AssertionError as failure:
          combined = f"{failure}\n{err_msg}"
          raise AssertionError(combined) from failure
  43. class TestLineSearch:
  44. # -- scalar functions; must have dphi(0.) < 0
  45. def _scalar_func_1(self, s): # skip name check
  46. if not hasattr(self.fcount, 'c'):
  47. self.fcount.c = 0
  48. self.fcount.c += 1
  49. p = -s - s**3 + s**4
  50. dp = -1 - 3*s**2 + 4*s**3
  51. return p, dp
  52. def _scalar_func_2(self, s): # skip name check
  53. if not hasattr(self.fcount, 'c'):
  54. self.fcount.c = 0
  55. self.fcount.c += 1
  56. p = np.exp(-4*s) + s**2
  57. dp = -4*np.exp(-4*s) + 2*s
  58. return p, dp
  59. def _scalar_func_3(self, s): # skip name check
  60. if not hasattr(self.fcount, 'c'):
  61. self.fcount.c = 0
  62. self.fcount.c += 1
  63. p = -np.sin(10*s)
  64. dp = -10*np.cos(10*s)
  65. return p, dp
  66. # -- n-d functions
  67. def _line_func_1(self, x): # skip name check
  68. if not hasattr(self.fcount, 'c'):
  69. self.fcount.c = 0
  70. self.fcount.c += 1
  71. f = np.dot(x, x)
  72. df = 2*x
  73. return f, df
  74. def _line_func_2(self, x): # skip name check
  75. if not hasattr(self.fcount, 'c'):
  76. self.fcount.c = 0
  77. self.fcount.c += 1
  78. f = np.dot(x, np.dot(self.A, x)) + 1
  79. df = np.dot(self.A + self.A.T, x)
  80. return f, df
  81. # --
  82. def setup_method(self):
  83. self.scalar_funcs = []
  84. self.line_funcs = []
  85. self.N = 20
  86. self.fcount = threading.local()
  87. def bind_index(func, idx):
  88. # Remember Python's closure semantics!
  89. return lambda *a, **kw: func(*a, **kw)[idx]
  90. for name in sorted(dir(self)):
  91. if name.startswith('_scalar_func_'):
  92. value = getattr(self, name)
  93. self.scalar_funcs.append(
  94. (name, bind_index(value, 0), bind_index(value, 1)))
  95. elif name.startswith('_line_func_'):
  96. value = getattr(self, name)
  97. self.line_funcs.append(
  98. (name, bind_index(value, 0), bind_index(value, 1)))
  99. # the choice of seed affects whether the tests pass
  100. rng = np.random.default_rng(1231892908)
  101. self.A = rng.standard_normal((self.N, self.N))
  102. def scalar_iter(self):
  103. rng = np.random.default_rng(2231892908)
  104. for name, phi, derphi in self.scalar_funcs:
  105. for old_phi0 in rng.standard_normal(3):
  106. yield name, phi, derphi, old_phi0
  107. def line_iter(self):
  108. rng = np.random.default_rng(2231892908)
  109. for name, f, fprime in self.line_funcs:
  110. k = 0
  111. while k < 9:
  112. x = rng.standard_normal(self.N)
  113. p = rng.standard_normal(self.N)
  114. if np.dot(p, fprime(x)) >= 0:
  115. # always pick a descent direction
  116. continue
  117. k += 1
  118. old_fv = float(rng.standard_normal())
  119. yield name, f, fprime, x, p, old_fv
  120. # -- Generic scalar searches
  def test_scalar_search_wolfe1(self):
      """``scalar_search_wolfe1`` returns consistent values and a Wolfe step."""
      visited = 0
      for visited, case in enumerate(self.scalar_iter(), start=1):
          name, phi, derphi, old_phi0 = case
          start_value = phi(0)
          start_slope = derphi(0)
          s, phi1, phi0 = ls.scalar_search_wolfe1(
              phi, derphi, start_value, old_phi0, start_slope)
          # Reported values must match fresh evaluations.
          assert_fp_equal(phi0, phi(0), name)
          assert_fp_equal(phi1, phi(s), name)
          assert_wolfe(s, phi, derphi, err_msg=name)
      assert visited > 3  # check that the iterator really works...
  def test_scalar_search_wolfe2(self):
      """``scalar_search_wolfe2`` returns consistent values and a Wolfe step."""
      for case in self.scalar_iter():
          name, phi, derphi, old_phi0 = case
          start_value = phi(0)
          start_slope = derphi(0)
          outcome = ls.scalar_search_wolfe2(
              phi, derphi, start_value, old_phi0, start_slope)
          s, phi1, phi0, derphi1 = outcome
          # Reported values must match fresh evaluations.
          assert_fp_equal(phi0, phi(0), name)
          assert_fp_equal(phi1, phi(s), name)
          # The derivative is only compared when one was returned.
          if derphi1 is not None:
              assert_fp_equal(derphi1, derphi(s), name)
          label = f"{name} {old_phi0:g}"
          assert_wolfe(s, phi, derphi, err_msg=label)
  140. def test_scalar_search_wolfe2_with_low_amax(self):
  141. def phi(alpha):
  142. return (alpha - 5) ** 2
  143. def derphi(alpha):
  144. return 2 * (alpha - 5)
  145. alpha_star, _, _, derphi_star = ls.scalar_search_wolfe2(phi, derphi, amax=0.001)
  146. assert alpha_star is None # Not converged
  147. assert derphi_star is None # Not converged
  148. def test_scalar_search_wolfe2_regression(self):
  149. # Regression test for gh-12157
  150. # This phi has its minimum at alpha=4/3 ~ 1.333.
  151. def phi(alpha):
  152. if alpha < 1:
  153. return - 3*np.pi/2 * (alpha - 1)
  154. else:
  155. return np.cos(3*np.pi/2 * alpha - np.pi)
  156. def derphi(alpha):
  157. if alpha < 1:
  158. return - 3*np.pi/2
  159. else:
  160. return - 3*np.pi/2 * np.sin(3*np.pi/2 * alpha - np.pi)
  161. s, _, _, _ = ls.scalar_search_wolfe2(phi, derphi)
  162. # Without the fix in gh-13073, the scalar_search_wolfe2
  163. # returned s=2.0 instead.
  164. assert s < 1.5
  def test_scalar_search_armijo(self):
      """``scalar_search_armijo`` returns a consistent value and a valid step."""
      for case in self.scalar_iter():
          name, phi, derphi, old_phi0 = case
          start_value = phi(0)
          start_slope = derphi(0)
          s, phi1 = ls.scalar_search_armijo(phi, start_value, start_slope)
          assert_fp_equal(phi1, phi(s), name)
          label = f"{name} {old_phi0:g}"
          assert_armijo(s, phi, err_msg=label)
  170. # -- Generic line searches
  def test_line_search_wolfe1(self):
      """``line_search_wolfe1``: evaluation counts, returned values, Wolfe step."""
      checked = 0
      smax = 100
      for name, f, fprime, x, p, old_f in self.line_iter():
          f0 = f(x)
          g0 = fprime(x)
          # Count only the evaluations made by the line search itself.
          self.fcount.c = 0
          outcome = ls.line_search_wolfe1(f, fprime, x, p, g0, f0, old_f,
                                          amax=smax)
          s, fc, gc, fv, ofv, gv = outcome
          assert_equal(self.fcount.c, fc+gc)
          assert_fp_equal(ofv, f(x))
          if s is not None:
              endpoint = x + s*p
              assert_fp_equal(fv, f(endpoint))
              assert_array_almost_equal(gv, fprime(endpoint), decimal=14)
              if s < smax:
                  checked += 1
                  assert_line_wolfe(x, p, s, f, fprime, err_msg=name)
      assert checked > 3  # check that the iterator really works...
  def test_line_search_wolfe2(self):
      """``line_search_wolfe2``: evaluation counts, returned values, Wolfe step.

      The two non-convergence ``LineSearchWarning`` messages are silenced
      for the call. A case whose step comes back as ``None`` is skipped
      after the bookkeeping checks, the same way ``test_line_search_wolfe1``
      does; the final count assertion still requires that more than three
      cases were fully verified.
      """
      c = 0
      smax = 512
      silenced = (
          "The line search algorithm could not find a solution",
          "The line search algorithm did not converge",
      )
      for name, f, fprime, x, p, old_f in self.line_iter():
          f0 = f(x)
          g0 = fprime(x)
          # Count only the evaluations made by the line search itself.
          self.fcount.c = 0
          with warnings.catch_warnings():
              for message in silenced:
                  warnings.filterwarnings("ignore", message,
                                          LineSearchWarning)
              s, fc, gc, fv, ofv, gv = ls.line_search_wolfe2(f, fprime, x, p,
                                                             g0, f0, old_f,
                                                             amax=smax)
          assert_equal(self.fcount.c, fc+gc)
          assert_fp_equal(ofv, f(x))
          if s is None:
              # No step was returned (its warning is silenced above), so
              # ``x + s*p`` and ``s < smax`` would raise TypeError.
              continue
          assert_fp_equal(fv, f(x + s*p))
          if gv is not None:
              assert_array_almost_equal(gv, fprime(x + s*p), decimal=14)
          if s < smax:
              c += 1
              assert_line_wolfe(x, p, s, f, fprime, err_msg=name)
      assert c > 3  # check that the iterator really works...
  def test_line_search_wolfe2_bounds(self):
      """``line_search_wolfe2`` respects ``amax`` and ``maxiter`` (gh-7475)."""
      # See gh-7475

      # For this f and p, starting at a point on axis 0, the strong Wolfe
      # condition 2 is met if and only if the step length s satisfies
      # |x + s| <= c2 * |x|
      def f(x):
          return np.dot(x, x)

      def fp(x):
          return 2 * x

      p = np.array([1, 0])

      # Smallest s satisfying strong Wolfe conditions for these arguments is 30
      x = -60 * p
      c2 = 0.5

      # amax equal to the smallest admissible step: must succeed.
      step = ls.line_search_wolfe2(f, fp, x, p, amax=30, c2=c2)[0]
      assert_line_wolfe(x, p, step, f, fp)

      # amax just below it: a warning is expected and no step is returned.
      with pytest.warns(LineSearchWarning):
          outcome = ls.line_search_wolfe2(f, fp, x, p, amax=29, c2=c2)
          step, _, _, _, _, _ = outcome
          assert step is None

      # s=30 will only be tried on the 6th iteration, so this won't converge
      with pytest.warns(LineSearchWarning):
          ls.line_search_wolfe2(f, fp, x, p, c2=c2, maxiter=5)
  def test_line_search_armijo(self):
      """``line_search_armijo``: evaluation count, returned value, valid step."""
      visited = 0
      for visited, case in enumerate(self.line_iter(), start=1):
          name, f, fprime, x, p, old_f = case
          f0 = f(x)
          g0 = fprime(x)
          # Count only the evaluations made by the line search itself.
          self.fcount.c = 0
          s, fc, fv = ls.line_search_armijo(f, x, p, g0, f0)
          assert_equal(self.fcount.c, fc)
          assert_fp_equal(fv, f(x + s*p))
          assert_line_armijo(x, p, s, f, err_msg=name)
      assert visited >= 9
  253. # -- More specific tests
  def test_armijo_terminate_1(self):
      """Armijo search stops right away when the trial step is acceptable."""
      # Armijo should evaluate the function only once if the trial step
      # is already suitable
      evaluations = 0

      def phi(s):
          nonlocal evaluations
          evaluations += 1
          return -s + 0.01*s**2

      start_value = phi(0)
      s, phi1 = ls.scalar_search_armijo(phi, start_value, -1, alpha0=1)
      assert_equal(s, 1)
      # One call for phi(0) above plus one inside the search.
      assert_equal(evaluations, 2)
      assert_armijo(s, phi)
  def test_wolfe_terminate(self):
      """Wolfe searches use few evaluations when the trial step is acceptable."""
      # wolfe1 and wolfe2 should also evaluate the function only a few
      # times if the trial step is already suitable
      tally = [0]

      def phi(s):
          tally[0] += 1
          return -s + 0.05*s**2

      def derphi(s):
          tally[0] += 1
          return -1 + 0.05*2*s

      searches = [ls.scalar_search_wolfe1, ls.scalar_search_wolfe2]
      for func in searches:
          # Restart the count for every search routine.
          tally[0] = 0
          start_value = phi(0)
          start_slope = derphi(0)
          r = func(phi, derphi, start_value, None, start_slope)
          assert r[0] is not None, (r, func)
          assert tally[0] <= 2 + 2, (tally, func)
          assert_wolfe(r[0], phi, derphi, err_msg=str(func))