test_integrate.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844
  1. # Authors: Nils Wagner, Ed Schofield, Pauli Virtanen, John Travers
  2. """
  3. Tests for numerical integration.
  4. """
  5. import numpy as np
  6. from numpy import (arange, zeros, array, dot, sqrt, cos, sin, eye, pi, exp,
  7. allclose)
  8. from numpy.testing import (
  9. assert_, assert_array_almost_equal,
  10. assert_allclose, assert_array_equal, assert_equal)
  11. import pytest
  12. from pytest import raises as assert_raises
  13. from scipy.integrate import odeint, ode, complex_ode
  14. #------------------------------------------------------------------------------
  15. # Test ODE integrators
  16. #------------------------------------------------------------------------------
  17. class TestOdeint:
  18. # Check integrate.odeint
  19. def _do_problem(self, problem):
  20. t = arange(0.0, problem.stop_t, 0.05)
  21. # Basic case
  22. z, infodict = odeint(problem.f, problem.z0, t, full_output=True)
  23. assert_(problem.verify(z, t))
  24. # Use tfirst=True
  25. z, infodict = odeint(lambda t, y: problem.f(y, t), problem.z0, t,
  26. full_output=True, tfirst=True)
  27. assert_(problem.verify(z, t))
  28. if hasattr(problem, 'jac'):
  29. # Use Dfun
  30. z, infodict = odeint(problem.f, problem.z0, t, Dfun=problem.jac,
  31. full_output=True)
  32. assert_(problem.verify(z, t))
  33. # Use Dfun and tfirst=True
  34. z, infodict = odeint(lambda t, y: problem.f(y, t), problem.z0, t,
  35. Dfun=lambda t, y: problem.jac(y, t),
  36. full_output=True, tfirst=True)
  37. assert_(problem.verify(z, t))
  38. def test_odeint(self):
  39. for problem_cls in PROBLEMS:
  40. problem = problem_cls()
  41. if problem.cmplx:
  42. continue
  43. self._do_problem(problem)
  44. class TestODEClass:
  45. ode_class = None # Set in subclass.
  46. def _do_problem(self, problem, integrator, method='adams'):
  47. # ode has callback arguments in different order than odeint
  48. def f(t, z):
  49. return problem.f(z, t)
  50. jac = None
  51. if hasattr(problem, 'jac'):
  52. def jac(t, z):
  53. return problem.jac(z, t)
  54. integrator_params = {}
  55. if problem.lband is not None or problem.uband is not None:
  56. integrator_params['uband'] = problem.uband
  57. integrator_params['lband'] = problem.lband
  58. ig = self.ode_class(f, jac)
  59. ig.set_integrator(integrator,
  60. atol=problem.atol/10,
  61. rtol=problem.rtol/10,
  62. method=method,
  63. **integrator_params)
  64. ig.set_initial_value(problem.z0, t=0.0)
  65. z = ig.integrate(problem.stop_t)
  66. assert_array_equal(z, ig.y)
  67. assert_(ig.successful(), (problem, method))
  68. assert_(ig.get_return_code() > 0, (problem, method))
  69. assert_(problem.verify(array([z]), problem.stop_t), (problem, method))
  70. class TestOde(TestODEClass):
  71. ode_class = ode
  72. def test_vode(self):
  73. # Check the vode solver
  74. for problem_cls in PROBLEMS:
  75. problem = problem_cls()
  76. if problem.cmplx:
  77. continue
  78. if not problem.stiff:
  79. self._do_problem(problem, 'vode', 'adams')
  80. self._do_problem(problem, 'vode', 'bdf')
  81. def test_zvode(self):
  82. # Check the zvode solver
  83. for problem_cls in PROBLEMS:
  84. problem = problem_cls()
  85. if not problem.stiff:
  86. self._do_problem(problem, 'zvode', 'adams')
  87. self._do_problem(problem, 'zvode', 'bdf')
  88. def test_lsoda(self):
  89. # Check the lsoda solver
  90. for problem_cls in PROBLEMS:
  91. problem = problem_cls()
  92. if problem.cmplx:
  93. continue
  94. self._do_problem(problem, 'lsoda')
  95. def test_dopri5(self):
  96. # Check the dopri5 solver
  97. for problem_cls in PROBLEMS:
  98. problem = problem_cls()
  99. if problem.cmplx:
  100. continue
  101. if problem.stiff:
  102. continue
  103. if hasattr(problem, 'jac'):
  104. continue
  105. self._do_problem(problem, 'dopri5')
  106. def test_dop853(self):
  107. # Check the dop853 solver
  108. for problem_cls in PROBLEMS:
  109. problem = problem_cls()
  110. if problem.cmplx:
  111. continue
  112. if problem.stiff:
  113. continue
  114. if hasattr(problem, 'jac'):
  115. continue
  116. self._do_problem(problem, 'dop853')
  117. def test_concurrent_fail(self):
  118. # Test concurrent usage behavior for different solvers
  119. # All solvers (vode, zvode, lsoda) now support concurrent usage
  120. # with state persistence via explicit state parameters
  121. for sol in ('vode', 'zvode', 'lsoda'):
  122. def f(t, y):
  123. return 1.0
  124. r = ode(f).set_integrator(sol)
  125. r.set_initial_value(0, 0)
  126. r2 = ode(f).set_integrator(sol)
  127. r2.set_initial_value(0, 0)
  128. r.integrate(r.t + 0.1)
  129. r2.integrate(r2.t + 0.1)
  130. # With state persistence, r should still work correctly
  131. r.integrate(r.t + 0.1)
  132. assert r.successful()
  133. def test_concurrent_ok(self, num_parallel_threads):
  134. def f(t, y):
  135. return 1.0
  136. for k in range(3):
  137. for sol in ('vode', 'zvode', 'lsoda', 'dopri5', 'dop853'):
  138. if sol in {'vode', 'zvode', 'lsoda'} and num_parallel_threads > 1:
  139. continue
  140. r = ode(f).set_integrator(sol)
  141. r.set_initial_value(0, 0)
  142. r2 = ode(f).set_integrator(sol)
  143. r2.set_initial_value(0, 0)
  144. r.integrate(r.t + 0.1)
  145. r2.integrate(r2.t + 0.1)
  146. r2.integrate(r2.t + 0.1)
  147. assert_allclose(r.y, 0.1)
  148. assert_allclose(r2.y, 0.2)
  149. for sol in ('dopri5', 'dop853'):
  150. r = ode(f).set_integrator(sol)
  151. r.set_initial_value(0, 0)
  152. r2 = ode(f).set_integrator(sol)
  153. r2.set_initial_value(0, 0)
  154. r.integrate(r.t + 0.1)
  155. r.integrate(r.t + 0.1)
  156. r2.integrate(r2.t + 0.1)
  157. r.integrate(r.t + 0.1)
  158. r2.integrate(r2.t + 0.1)
  159. assert_allclose(r.y, 0.3)
  160. assert_allclose(r2.y, 0.2)
  161. class TestComplexOde(TestODEClass):
  162. ode_class = complex_ode
  163. def test_vode(self):
  164. # Check the vode solver
  165. for problem_cls in PROBLEMS:
  166. problem = problem_cls()
  167. if not problem.stiff:
  168. self._do_problem(problem, 'vode', 'adams')
  169. else:
  170. self._do_problem(problem, 'vode', 'bdf')
  171. def test_lsoda(self):
  172. # Check the lsoda solver
  173. for problem_cls in PROBLEMS:
  174. problem = problem_cls()
  175. self._do_problem(problem, 'lsoda')
  176. def test_dopri5(self):
  177. # Check the dopri5 solver
  178. for problem_cls in PROBLEMS:
  179. problem = problem_cls()
  180. if problem.stiff:
  181. continue
  182. if hasattr(problem, 'jac'):
  183. continue
  184. self._do_problem(problem, 'dopri5')
  185. def test_dop853(self):
  186. # Check the dop853 solver
  187. for problem_cls in PROBLEMS:
  188. problem = problem_cls()
  189. if problem.stiff:
  190. continue
  191. if hasattr(problem, 'jac'):
  192. continue
  193. self._do_problem(problem, 'dop853')
  194. class TestSolout:
  195. # Check integrate.ode correctly handles solout for dopri5 and dop853
  196. def _run_solout_test(self, integrator):
  197. # Check correct usage of solout
  198. ts = []
  199. ys = []
  200. t0 = 0.0
  201. tend = 10.0
  202. y0 = [1.0, 2.0]
  203. def solout(t, y):
  204. ts.append(t)
  205. ys.append(y.copy())
  206. def rhs(t, y):
  207. return [y[0] + y[1], -y[1]**2]
  208. ig = ode(rhs).set_integrator(integrator)
  209. ig.set_solout(solout)
  210. ig.set_initial_value(y0, t0)
  211. ret = ig.integrate(tend)
  212. assert_array_equal(ys[0], y0)
  213. assert_array_equal(ys[-1], ret)
  214. assert_equal(ts[0], t0)
  215. assert_equal(ts[-1], tend)
  216. def test_solout(self):
  217. for integrator in ('dopri5', 'dop853'):
  218. self._run_solout_test(integrator)
  219. def _run_solout_after_initial_test(self, integrator):
  220. # Check if solout works even if it is set after the initial value.
  221. ts = []
  222. ys = []
  223. t0 = 0.0
  224. tend = 10.0
  225. y0 = [1.0, 2.0]
  226. def solout(t, y):
  227. ts.append(t)
  228. ys.append(y.copy())
  229. def rhs(t, y):
  230. return [y[0] + y[1], -y[1]**2]
  231. ig = ode(rhs).set_integrator(integrator)
  232. ig.set_initial_value(y0, t0)
  233. ig.set_solout(solout)
  234. ret = ig.integrate(tend)
  235. assert_array_equal(ys[0], y0)
  236. assert_array_equal(ys[-1], ret)
  237. assert_equal(ts[0], t0)
  238. assert_equal(ts[-1], tend)
  239. def test_solout_after_initial(self):
  240. for integrator in ('dopri5', 'dop853'):
  241. self._run_solout_after_initial_test(integrator)
  242. def _run_solout_break_test(self, integrator):
  243. # Check correct usage of stopping via solout
  244. ts = []
  245. ys = []
  246. t0 = 0.0
  247. tend = 10.0
  248. y0 = [1.0, 2.0]
  249. def solout(t, y):
  250. ts.append(t)
  251. ys.append(y.copy())
  252. if t > tend/2.0:
  253. return -1
  254. def rhs(t, y):
  255. return [y[0] + y[1], -y[1]**2]
  256. ig = ode(rhs).set_integrator(integrator)
  257. ig.set_solout(solout)
  258. ig.set_initial_value(y0, t0)
  259. ret = ig.integrate(tend)
  260. assert_array_equal(ys[0], y0)
  261. assert_array_equal(ys[-1], ret)
  262. assert_equal(ts[0], t0)
  263. assert_(ts[-1] > tend/2.0)
  264. assert_(ts[-1] < tend)
  265. def test_solout_break(self):
  266. for integrator in ('dopri5', 'dop853'):
  267. self._run_solout_break_test(integrator)
  268. class TestComplexSolout:
  269. # Check integrate.ode correctly handles solout for dopri5 and dop853
  270. def _run_solout_test(self, integrator):
  271. # Check correct usage of solout
  272. ts = []
  273. ys = []
  274. t0 = 0.0
  275. tend = 20.0
  276. y0 = [0.0]
  277. def solout(t, y):
  278. ts.append(t)
  279. ys.append(y.copy())
  280. def rhs(t, y):
  281. return [1.0/(t - 10.0 - 1j)]
  282. ig = complex_ode(rhs).set_integrator(integrator)
  283. ig.set_solout(solout)
  284. ig.set_initial_value(y0, t0)
  285. ret = ig.integrate(tend)
  286. assert_array_equal(ys[0], y0)
  287. assert_array_equal(ys[-1], ret)
  288. assert_equal(ts[0], t0)
  289. assert_equal(ts[-1], tend)
  290. def test_solout(self):
  291. for integrator in ('dopri5', 'dop853'):
  292. self._run_solout_test(integrator)
  293. def _run_solout_break_test(self, integrator):
  294. # Check correct usage of stopping via solout
  295. ts = []
  296. ys = []
  297. t0 = 0.0
  298. tend = 20.0
  299. y0 = [0.0]
  300. def solout(t, y):
  301. ts.append(t)
  302. ys.append(y.copy())
  303. if t > tend/2.0:
  304. return -1
  305. def rhs(t, y):
  306. return [1.0/(t - 10.0 - 1j)]
  307. ig = complex_ode(rhs).set_integrator(integrator)
  308. ig.set_solout(solout)
  309. ig.set_initial_value(y0, t0)
  310. ret = ig.integrate(tend)
  311. assert_array_equal(ys[0], y0)
  312. assert_array_equal(ys[-1], ret)
  313. assert_equal(ts[0], t0)
  314. assert_(ts[-1] > tend/2.0)
  315. assert_(ts[-1] < tend)
  316. def test_solout_break(self):
  317. for integrator in ('dopri5', 'dop853'):
  318. self._run_solout_break_test(integrator)
  319. #------------------------------------------------------------------------------
  320. # Test problems
  321. #------------------------------------------------------------------------------
  322. class ODE:
  323. """
  324. ODE problem
  325. """
  326. stiff = False
  327. cmplx = False
  328. stop_t = 1
  329. z0 = []
  330. lband = None
  331. uband = None
  332. atol = 1e-6
  333. rtol = 1e-5
  334. class SimpleOscillator(ODE):
  335. r"""
  336. Free vibration of a simple oscillator::
  337. m \ddot{u} + k u = 0, u(0) = u_0 \dot{u}(0) \dot{u}_0
  338. Solution::
  339. u(t) = u_0*cos(sqrt(k/m)*t)+\dot{u}_0*sin(sqrt(k/m)*t)/sqrt(k/m)
  340. """
  341. stop_t = 1 + 0.09
  342. z0 = array([1.0, 0.1], float)
  343. k = 4.0
  344. m = 1.0
  345. def f(self, z, t):
  346. tmp = zeros((2, 2), float)
  347. tmp[0, 1] = 1.0
  348. tmp[1, 0] = -self.k / self.m
  349. return dot(tmp, z)
  350. def verify(self, zs, t):
  351. omega = sqrt(self.k / self.m)
  352. u = self.z0[0]*cos(omega*t) + self.z0[1]*sin(omega*t)/omega
  353. return allclose(u, zs[:, 0], atol=self.atol, rtol=self.rtol)
  354. class ComplexExp(ODE):
  355. r"""The equation :lm:`\dot u = i u`"""
  356. stop_t = 1.23*pi
  357. z0 = exp([1j, 2j, 3j, 4j, 5j])
  358. cmplx = True
  359. def f(self, z, t):
  360. return 1j*z
  361. def jac(self, z, t):
  362. return 1j*eye(5)
  363. def verify(self, zs, t):
  364. u = self.z0 * exp(1j*t)
  365. return allclose(u, zs, atol=self.atol, rtol=self.rtol)
  366. class Pi(ODE):
  367. r"""Integrate 1/(t + 1j) from t=-10 to t=10"""
  368. stop_t = 20
  369. z0 = [0]
  370. cmplx = True
  371. def f(self, z, t):
  372. return array([1./(t - 10 + 1j)])
  373. def verify(self, zs, t):
  374. u = -2j * np.arctan(10)
  375. return allclose(u, zs[-1, :], atol=self.atol, rtol=self.rtol)
  376. class CoupledDecay(ODE):
  377. r"""
  378. 3 coupled decays suited for banded treatment
  379. (banded mode makes it necessary when N>>3)
  380. """
  381. stiff = True
  382. stop_t = 0.5
  383. z0 = [5.0, 7.0, 13.0]
  384. lband = 1
  385. uband = 0
  386. lmbd = [0.17, 0.23, 0.29] # fictitious decay constants
  387. def f(self, z, t):
  388. lmbd = self.lmbd
  389. return np.array([-lmbd[0]*z[0],
  390. -lmbd[1]*z[1] + lmbd[0]*z[0],
  391. -lmbd[2]*z[2] + lmbd[1]*z[1]])
  392. def jac(self, z, t):
  393. # The full Jacobian is
  394. #
  395. # [-lmbd[0] 0 0 ]
  396. # [ lmbd[0] -lmbd[1] 0 ]
  397. # [ 0 lmbd[1] -lmbd[2]]
  398. #
  399. # The lower and upper bandwidths are lband=1 and uband=0, resp.
  400. # The representation of this array in packed format is
  401. #
  402. # [-lmbd[0] -lmbd[1] -lmbd[2]]
  403. # [ lmbd[0] lmbd[1] 0 ]
  404. lmbd = self.lmbd
  405. j = np.zeros((self.lband + self.uband + 1, 3), order='F')
  406. def set_j(ri, ci, val):
  407. j[self.uband + ri - ci, ci] = val
  408. set_j(0, 0, -lmbd[0])
  409. set_j(1, 0, lmbd[0])
  410. set_j(1, 1, -lmbd[1])
  411. set_j(2, 1, lmbd[1])
  412. set_j(2, 2, -lmbd[2])
  413. return j
  414. def verify(self, zs, t):
  415. # Formulae derived by hand
  416. lmbd = np.array(self.lmbd)
  417. d10 = lmbd[1] - lmbd[0]
  418. d21 = lmbd[2] - lmbd[1]
  419. d20 = lmbd[2] - lmbd[0]
  420. e0 = np.exp(-lmbd[0] * t)
  421. e1 = np.exp(-lmbd[1] * t)
  422. e2 = np.exp(-lmbd[2] * t)
  423. u = np.vstack((
  424. self.z0[0] * e0,
  425. self.z0[1] * e1 + self.z0[0] * lmbd[0] / d10 * (e0 - e1),
  426. self.z0[2] * e2 + self.z0[1] * lmbd[1] / d21 * (e1 - e2) +
  427. lmbd[1] * lmbd[0] * self.z0[0] / d10 *
  428. (1 / d20 * (e0 - e2) - 1 / d21 * (e1 - e2)))).transpose()
  429. return allclose(u, zs, atol=self.atol, rtol=self.rtol)
  430. PROBLEMS = [SimpleOscillator, ComplexExp, Pi, CoupledDecay]
  431. #------------------------------------------------------------------------------
  432. def f(t, x):
  433. dxdt = [x[1], -x[0]]
  434. return dxdt
  435. def jac(t, x):
  436. j = array([[0.0, 1.0],
  437. [-1.0, 0.0]])
  438. return j
  439. def f1(t, x, omega):
  440. dxdt = [omega*x[1], -omega*x[0]]
  441. return dxdt
  442. def jac1(t, x, omega):
  443. j = array([[0.0, omega],
  444. [-omega, 0.0]])
  445. return j
  446. def f2(t, x, omega1, omega2):
  447. dxdt = [omega1*x[1], -omega2*x[0]]
  448. return dxdt
  449. def jac2(t, x, omega1, omega2):
  450. j = array([[0.0, omega1],
  451. [-omega2, 0.0]])
  452. return j
  453. def fv(t, x, omega):
  454. dxdt = [omega[0]*x[1], -omega[1]*x[0]]
  455. return dxdt
  456. def jacv(t, x, omega):
  457. j = array([[0.0, omega[0]],
  458. [-omega[1], 0.0]])
  459. return j
  460. class ODECheckParameterUse:
  461. """Call an ode-class solver with several cases of parameter use."""
  462. # solver_name must be set before tests can be run with this class.
  463. # Set these in subclasses.
  464. solver_name = ''
  465. solver_uses_jac = False
  466. def _get_solver(self, f, jac):
  467. solver = ode(f, jac)
  468. if self.solver_uses_jac:
  469. solver.set_integrator(self.solver_name, atol=1e-9, rtol=1e-7,
  470. with_jacobian=self.solver_uses_jac)
  471. else:
  472. # XXX Shouldn't set_integrator *always* accept the keyword arg
  473. # 'with_jacobian', and perhaps raise an exception if it is set
  474. # to True if the solver can't actually use it?
  475. solver.set_integrator(self.solver_name, atol=1e-9, rtol=1e-7)
  476. return solver
  477. def _check_solver(self, solver):
  478. ic = [1.0, 0.0]
  479. solver.set_initial_value(ic, 0.0)
  480. solver.integrate(pi)
  481. assert_array_almost_equal(solver.y, [-1.0, 0.0])
  482. def test_no_params(self):
  483. solver = self._get_solver(f, jac)
  484. self._check_solver(solver)
  485. def test_one_scalar_param(self):
  486. solver = self._get_solver(f1, jac1)
  487. omega = 1.0
  488. solver.set_f_params(omega)
  489. if self.solver_uses_jac:
  490. solver.set_jac_params(omega)
  491. self._check_solver(solver)
  492. def test_two_scalar_params(self):
  493. solver = self._get_solver(f2, jac2)
  494. omega1 = 1.0
  495. omega2 = 1.0
  496. solver.set_f_params(omega1, omega2)
  497. if self.solver_uses_jac:
  498. solver.set_jac_params(omega1, omega2)
  499. self._check_solver(solver)
  500. def test_vector_param(self):
  501. solver = self._get_solver(fv, jacv)
  502. omega = [1.0, 1.0]
  503. solver.set_f_params(omega)
  504. if self.solver_uses_jac:
  505. solver.set_jac_params(omega)
  506. self._check_solver(solver)
  507. def test_warns_on_failure(self):
  508. # Set nsteps small to ensure failure
  509. solver = self._get_solver(f, jac)
  510. solver.set_integrator(self.solver_name, nsteps=1)
  511. ic = [1.0, 0.0]
  512. solver.set_initial_value(ic, 0.0)
  513. with pytest.warns(UserWarning):
  514. solver.integrate(pi)
  515. class TestDOPRI5CheckParameterUse(ODECheckParameterUse):
  516. solver_name = 'dopri5'
  517. solver_uses_jac = False
  518. class TestDOP853CheckParameterUse(ODECheckParameterUse):
  519. solver_name = 'dop853'
  520. solver_uses_jac = False
  521. class TestVODECheckParameterUse(ODECheckParameterUse):
  522. solver_name = 'vode'
  523. solver_uses_jac = True
  524. class TestZVODECheckParameterUse(ODECheckParameterUse):
  525. solver_name = 'zvode'
  526. solver_uses_jac = True
  527. class TestLSODACheckParameterUse(ODECheckParameterUse):
  528. solver_name = 'lsoda'
  529. solver_uses_jac = True
  530. def test_odeint_trivial_time():
  531. # Test that odeint succeeds when given a single time point
  532. # and full_output=True. This is a regression test for gh-4282.
  533. y0 = 1
  534. t = [0]
  535. y, info = odeint(lambda y, t: -y, y0, t, full_output=True)
  536. assert_array_equal(y, np.array([[y0]]))
  537. def test_odeint_banded_jacobian():
  538. # Test the use of the `Dfun`, `ml` and `mu` options of odeint.
  539. def func(y, t, c):
  540. return c.dot(y)
  541. def jac(y, t, c):
  542. return c
  543. def jac_transpose(y, t, c):
  544. return c.T.copy(order='C')
  545. def bjac_rows(y, t, c):
  546. jac = np.vstack((np.r_[0, np.diag(c, 1)],
  547. np.diag(c),
  548. np.r_[np.diag(c, -1), 0],
  549. np.r_[np.diag(c, -2), 0, 0]))
  550. return jac
  551. def bjac_cols(y, t, c):
  552. return bjac_rows(y, t, c).T.copy(order='C')
  553. c = array([[-205, 0.01, 0.00, 0.0],
  554. [0.1, -2.50, 0.02, 0.0],
  555. [1e-3, 0.01, -2.0, 0.01],
  556. [0.00, 0.00, 0.1, -1.0]])
  557. y0 = np.ones(4)
  558. t = np.array([0, 5, 10, 100])
  559. # Use the full Jacobian.
  560. sol1, info1 = odeint(func, y0, t, args=(c,), full_output=True,
  561. atol=1e-13, rtol=1e-11, mxstep=10000,
  562. Dfun=jac)
  563. # Use the transposed full Jacobian, with col_deriv=True.
  564. sol2, info2 = odeint(func, y0, t, args=(c,), full_output=True,
  565. atol=1e-13, rtol=1e-11, mxstep=10000,
  566. Dfun=jac_transpose, col_deriv=True)
  567. # Use the banded Jacobian.
  568. sol3, info3 = odeint(func, y0, t, args=(c,), full_output=True,
  569. atol=1e-13, rtol=1e-11, mxstep=10000,
  570. Dfun=bjac_rows, ml=2, mu=1)
  571. # Use the transposed banded Jacobian, with col_deriv=True.
  572. sol4, info4 = odeint(func, y0, t, args=(c,), full_output=True,
  573. atol=1e-13, rtol=1e-11, mxstep=10000,
  574. Dfun=bjac_cols, ml=2, mu=1, col_deriv=True)
  575. assert_allclose(sol1, sol2, err_msg="sol1 != sol2")
  576. assert_allclose(sol1, sol3, atol=1e-12, err_msg="sol1 != sol3")
  577. assert_allclose(sol3, sol4, err_msg="sol3 != sol4")
  578. # Verify that the number of jacobian evaluations was the same for the
  579. # calls of odeint with a full jacobian and with a banded jacobian. This is
  580. # a regression test--there was a bug in the handling of banded jacobians
  581. # that resulted in an incorrect jacobian matrix being passed to the LSODA
  582. # code. That would cause errors or excessive jacobian evaluations.
  583. assert_array_equal(info1['nje'], info2['nje'])
  584. assert_array_equal(info3['nje'], info4['nje'])
  585. # Test the use of tfirst
  586. sol1ty, info1ty = odeint(lambda t, y, c: func(y, t, c), y0, t, args=(c,),
  587. full_output=True, atol=1e-13, rtol=1e-11,
  588. mxstep=10000,
  589. Dfun=lambda t, y, c: jac(y, t, c), tfirst=True)
  590. # The code should execute the exact same sequence of floating point
  591. # calculations, so these should be exactly equal. We'll be safe and use
  592. # a small tolerance.
  593. assert_allclose(sol1, sol1ty, rtol=1e-12, err_msg="sol1 != sol1ty")
  594. def test_odeint_errors():
  595. def sys1d(x, t):
  596. return -100*x
  597. def bad1(x, t):
  598. return 1.0/0
  599. def bad2(x, t):
  600. return "foo"
  601. def bad_jac1(x, t):
  602. return 1.0/0
  603. def bad_jac2(x, t):
  604. return [["foo"]]
  605. def sys2d(x, t):
  606. return [-100*x[0], -0.1*x[1]]
  607. def sys2d_bad_jac(x, t):
  608. return [[1.0/0, 0], [0, -0.1]]
  609. assert_raises(ZeroDivisionError, odeint, bad1, 1.0, [0, 1])
  610. assert_raises(ValueError, odeint, bad2, 1.0, [0, 1])
  611. assert_raises(ZeroDivisionError, odeint, sys1d, 1.0, [0, 1], Dfun=bad_jac1)
  612. assert_raises(ValueError, odeint, sys1d, 1.0, [0, 1], Dfun=bad_jac2)
  613. assert_raises(ZeroDivisionError, odeint, sys2d, [1.0, 1.0], [0, 1],
  614. Dfun=sys2d_bad_jac)
  615. def test_odeint_bad_shapes():
  616. # Tests of some errors that can occur with odeint.
  617. def badrhs(x, t):
  618. return [1, -1]
  619. def sys1(x, t):
  620. return -100*x
  621. def badjac(x, t):
  622. return [[0, 0, 0]]
  623. # y0 must be at most 1-d.
  624. bad_y0 = [[0, 0], [0, 0]]
  625. assert_raises(ValueError, odeint, sys1, bad_y0, [0, 1])
  626. # t must be at most 1-d.
  627. bad_t = [[0, 1], [2, 3]]
  628. assert_raises(ValueError, odeint, sys1, [10.0], bad_t)
  629. # y0 is 10, but badrhs(x, t) returns [1, -1].
  630. assert_raises(RuntimeError, odeint, badrhs, 10, [0, 1])
  631. # shape of array returned by badjac(x, t) is not correct.
  632. assert_raises(RuntimeError, odeint, sys1, [10, 10], [0, 1], Dfun=badjac)
  633. def test_repeated_t_values():
  634. """Regression test for gh-8217."""
  635. def func(x, t):
  636. return -0.25*x
  637. t = np.zeros(10)
  638. sol = odeint(func, [1.], t)
  639. assert_array_equal(sol, np.ones((len(t), 1)))
  640. tau = 4*np.log(2)
  641. t = [0]*9 + [tau, 2*tau, 2*tau, 3*tau]
  642. sol = odeint(func, [1, 2], t, rtol=1e-12, atol=1e-12)
  643. expected_sol = np.array([[1.0, 2.0]]*9 +
  644. [[0.5, 1.0],
  645. [0.25, 0.5],
  646. [0.25, 0.5],
  647. [0.125, 0.25]])
  648. assert_allclose(sol, expected_sol)
  649. # Edge case: empty t sequence.
  650. sol = odeint(func, [1.], [])
  651. assert_array_equal(sol, np.array([], dtype=np.float64).reshape((0, 1)))
  652. # t values are not monotonic.
  653. assert_raises(ValueError, odeint, func, [1.], [0, 1, 0.5, 0])
  654. assert_raises(ValueError, odeint, func, [1, 2, 3], [0, -1, -2, 3])