base.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. from types import GenericAlias
  2. import numpy as np
  3. def check_arguments(fun, y0, support_complex):
  4. """Helper function for checking arguments common to all solvers."""
  5. y0 = np.asarray(y0)
  6. if np.issubdtype(y0.dtype, np.complexfloating):
  7. if not support_complex:
  8. raise ValueError("`y0` is complex, but the chosen solver does "
  9. "not support integration in a complex domain.")
  10. dtype = complex
  11. else:
  12. dtype = float
  13. y0 = y0.astype(dtype, copy=False)
  14. if y0.ndim != 1:
  15. raise ValueError("`y0` must be 1-dimensional.")
  16. if not np.isfinite(y0).all():
  17. raise ValueError("All components of the initial state `y0` must be finite.")
  18. def fun_wrapped(t, y):
  19. return np.asarray(fun(t, y), dtype=dtype)
  20. return fun_wrapped, y0
  21. class OdeSolver:
  22. """Base class for ODE solvers.
  23. In order to implement a new solver you need to follow the guidelines:
  24. 1. A constructor must accept parameters presented in the base class
  25. (listed below) along with any other parameters specific to a solver.
  26. 2. A constructor must accept arbitrary extraneous arguments
  27. ``**extraneous``, but warn that these arguments are irrelevant
  28. using `common.warn_extraneous` function. Do not pass these
  29. arguments to the base class.
  30. 3. A solver must implement a private method `_step_impl(self)` which
  31. propagates a solver one step further. It must return tuple
  32. ``(success, message)``, where ``success`` is a boolean indicating
  33. whether a step was successful, and ``message`` is a string
  34. containing description of a failure if a step failed or None
  35. otherwise.
  36. 4. A solver must implement a private method `_dense_output_impl(self)`,
  37. which returns a `DenseOutput` object covering the last successful
  38. step.
  39. 5. A solver must have attributes listed below in Attributes section.
  40. Note that ``t_old`` and ``step_size`` are updated automatically.
  41. 6. Use `fun(self, t, y)` method for the system rhs evaluation, this
  42. way the number of function evaluations (`nfev`) will be tracked
  43. automatically.
  44. 7. For convenience, a base class provides `fun_single(self, t, y)` and
  45. `fun_vectorized(self, t, y)` for evaluating the rhs in
  46. non-vectorized and vectorized fashions respectively (regardless of
  47. how `fun` from the constructor is implemented). These calls don't
  48. increment `nfev`.
  49. 8. If a solver uses a Jacobian matrix and LU decompositions, it should
  50. track the number of Jacobian evaluations (`njev`) and the number of
  51. LU decompositions (`nlu`).
  52. 9. By convention, the function evaluations used to compute a finite
  53. difference approximation of the Jacobian should not be counted in
  54. `nfev`, thus use `fun_single(self, t, y)` or
  55. `fun_vectorized(self, t, y)` when computing a finite difference
  56. approximation of the Jacobian.
  57. Parameters
  58. ----------
  59. fun : callable
  60. Right-hand side of the system: the time derivative of the state ``y``
  61. at time ``t``. The calling signature is ``fun(t, y)``, where ``t`` is a
  62. scalar and ``y`` is an ndarray with ``len(y) = len(y0)``. ``fun`` must
  63. return an array of the same shape as ``y``. See `vectorized` for more
  64. information.
  65. t0 : float
  66. Initial time.
  67. y0 : array_like, shape (n,)
  68. Initial state.
  69. t_bound : float
  70. Boundary time --- the integration won't continue beyond it. It also
  71. determines the direction of the integration.
  72. vectorized : bool
  73. Whether `fun` can be called in a vectorized fashion. Default is False.
  74. If ``vectorized`` is False, `fun` will always be called with ``y`` of
  75. shape ``(n,)``, where ``n = len(y0)``.
  76. If ``vectorized`` is True, `fun` may be called with ``y`` of shape
  77. ``(n, k)``, where ``k`` is an integer. In this case, `fun` must behave
  78. such that ``fun(t, y)[:, i] == fun(t, y[:, i])`` (i.e. each column of
  79. the returned array is the time derivative of the state corresponding
  80. with a column of ``y``).
  81. Setting ``vectorized=True`` allows for faster finite difference
  82. approximation of the Jacobian by methods 'Radau' and 'BDF', but
  83. will result in slower execution for other methods. It can also
  84. result in slower overall execution for 'Radau' and 'BDF' in some
  85. circumstances (e.g. small ``len(y0)``).
  86. support_complex : bool, optional
  87. Whether integration in a complex domain should be supported.
  88. Generally determined by a derived solver class capabilities.
  89. Default is False.
  90. Attributes
  91. ----------
  92. n : int
  93. Number of equations.
  94. status : string
  95. Current status of the solver: 'running', 'finished' or 'failed'.
  96. t_bound : float
  97. Boundary time.
  98. direction : float
  99. Integration direction: +1 or -1.
  100. t : float
  101. Current time.
  102. y : ndarray
  103. Current state.
  104. t_old : float
  105. Previous time. None if no steps were made yet.
  106. step_size : float
  107. Size of the last successful step. None if no steps were made yet.
  108. nfev : int
  109. Number of the system's rhs evaluations.
  110. njev : int
  111. Number of the Jacobian evaluations.
  112. nlu : int
  113. Number of LU decompositions.
  114. """
  115. TOO_SMALL_STEP = "Required step size is less than spacing between numbers."
  116. # generic type compatibility with scipy-stubs
  117. __class_getitem__ = classmethod(GenericAlias)
  118. def __init__(self, fun, t0, y0, t_bound, vectorized,
  119. support_complex=False):
  120. self.t_old = None
  121. self.t = t0
  122. self._fun, self.y = check_arguments(fun, y0, support_complex)
  123. self.t_bound = t_bound
  124. self.vectorized = vectorized
  125. if vectorized:
  126. def fun_single(t, y):
  127. return self._fun(t, y[:, None]).ravel()
  128. fun_vectorized = self._fun
  129. else:
  130. fun_single = self._fun
  131. def fun_vectorized(t, y):
  132. f = np.empty_like(y)
  133. for i, yi in enumerate(y.T):
  134. f[:, i] = self._fun(t, yi)
  135. return f
  136. def fun(t, y):
  137. self.nfev += 1
  138. return self.fun_single(t, y)
  139. self.fun = fun
  140. self.fun_single = fun_single
  141. self.fun_vectorized = fun_vectorized
  142. self.direction = np.sign(t_bound - t0) if t_bound != t0 else 1
  143. self.n = self.y.size
  144. self.status = 'running'
  145. self.nfev = 0
  146. self.njev = 0
  147. self.nlu = 0
  148. @property
  149. def step_size(self):
  150. if self.t_old is None:
  151. return None
  152. else:
  153. return np.abs(self.t - self.t_old)
  154. def step(self):
  155. """Perform one integration step.
  156. Returns
  157. -------
  158. message : string or None
  159. Report from the solver. Typically a reason for a failure if
  160. `self.status` is 'failed' after the step was taken or None
  161. otherwise.
  162. """
  163. if self.status != 'running':
  164. raise RuntimeError("Attempt to step on a failed or finished "
  165. "solver.")
  166. if self.n == 0 or self.t == self.t_bound:
  167. # Handle corner cases of empty solver or no integration.
  168. self.t_old = self.t
  169. self.t = self.t_bound
  170. message = None
  171. self.status = 'finished'
  172. else:
  173. t = self.t
  174. success, message = self._step_impl()
  175. if not success:
  176. self.status = 'failed'
  177. else:
  178. self.t_old = t
  179. if self.direction * (self.t - self.t_bound) >= 0:
  180. self.status = 'finished'
  181. return message
  182. def dense_output(self):
  183. """Compute a local interpolant over the last successful step.
  184. Returns
  185. -------
  186. sol : `DenseOutput`
  187. Local interpolant over the last successful step.
  188. """
  189. if self.t_old is None:
  190. raise RuntimeError("Dense output is available after a successful "
  191. "step was made.")
  192. if self.n == 0 or self.t == self.t_old:
  193. # Handle corner cases of empty solver and no integration.
  194. return ConstantDenseOutput(self.t_old, self.t, self.y)
  195. else:
  196. return self._dense_output_impl()
  197. def _step_impl(self):
  198. raise NotImplementedError
  199. def _dense_output_impl(self):
  200. raise NotImplementedError
  201. class DenseOutput:
  202. """Base class for local interpolant over step made by an ODE solver.
  203. It interpolates between `t_min` and `t_max` (see Attributes below).
  204. Evaluation outside this interval is not forbidden, but the accuracy is not
  205. guaranteed.
  206. Attributes
  207. ----------
  208. t_min, t_max : float
  209. Time range of the interpolation.
  210. """
  211. # generic type compatibility with scipy-stubs
  212. __class_getitem__ = classmethod(GenericAlias)
  213. def __init__(self, t_old, t):
  214. self.t_old = t_old
  215. self.t = t
  216. self.t_min = min(t, t_old)
  217. self.t_max = max(t, t_old)
  218. def __call__(self, t):
  219. """Evaluate the interpolant.
  220. Parameters
  221. ----------
  222. t : float or array_like with shape (n_points,)
  223. Points to evaluate the solution at.
  224. Returns
  225. -------
  226. y : ndarray, shape (n,) or (n, n_points)
  227. Computed values. Shape depends on whether `t` was a scalar or a
  228. 1-D array.
  229. """
  230. t = np.asarray(t)
  231. if t.ndim > 1:
  232. raise ValueError("`t` must be a float or a 1-D array.")
  233. return self._call_impl(t)
  234. def _call_impl(self, t):
  235. raise NotImplementedError
  236. class ConstantDenseOutput(DenseOutput):
  237. """Constant value interpolator.
  238. This class used for degenerate integration cases: equal integration limits
  239. or a system with 0 equations.
  240. """
  241. def __init__(self, t_old, t, value):
  242. super().__init__(t_old, t)
  243. self.value = value
  244. def _call_impl(self, t):
  245. if t.ndim == 0:
  246. return self.value
  247. else:
  248. ret = np.empty((self.value.shape[0], t.shape[0]))
  249. ret[:] = self.value[:, None]
  250. return ret