rootoftools.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298
  1. """Implementation of RootOf class and related tools. """
  2. from sympy.core.basic import Basic
  3. from sympy.core import (S, Expr, Integer, Float, I, oo, Add, Lambda,
  4. symbols, sympify, Rational, Dummy)
  5. from sympy.core.cache import cacheit
  6. from sympy.core.relational import is_le
  7. from sympy.core.sorting import ordered
  8. from sympy.polys.domains import QQ
  9. from sympy.polys.polyerrors import (
  10. MultivariatePolynomialError,
  11. GeneratorsNeeded,
  12. PolynomialError,
  13. DomainError)
  14. from sympy.polys.polyfuncs import symmetrize, viete
  15. from sympy.polys.polyroots import (
  16. roots_linear, roots_quadratic, roots_binomial,
  17. preprocess_roots, roots)
  18. from sympy.polys.polytools import Poly, PurePoly, factor
  19. from sympy.polys.rationaltools import together
  20. from sympy.polys.rootisolation import (
  21. dup_isolate_complex_roots_sqf,
  22. dup_isolate_real_roots_sqf)
  23. from sympy.utilities import lambdify, public, sift, numbered_symbols
  24. from mpmath import mpf, mpc, findroot, workprec
  25. from mpmath.libmp.libmpf import dps_to_prec, prec_to_dps
  26. from sympy.multipledispatch import dispatch
  27. from itertools import chain
# Names exported by ``from sympy.polys.rootoftools import *``.
__all__ = ['CRootOf']
  29. class _pure_key_dict:
  30. """A minimal dictionary that makes sure that the key is a
  31. univariate PurePoly instance.
  32. Examples
  33. ========
  34. Only the following actions are guaranteed:
  35. >>> from sympy.polys.rootoftools import _pure_key_dict
  36. >>> from sympy import PurePoly
  37. >>> from sympy.abc import x, y
  38. 1) creation
  39. >>> P = _pure_key_dict()
  40. 2) assignment for a PurePoly or univariate polynomial
  41. >>> P[x] = 1
  42. >>> P[PurePoly(x - y, x)] = 2
  43. 3) retrieval based on PurePoly key comparison (use this
  44. instead of the get method)
  45. >>> P[y]
  46. 1
  47. 4) KeyError when trying to retrieve a nonexisting key
  48. >>> P[y + 1]
  49. Traceback (most recent call last):
  50. ...
  51. KeyError: PurePoly(y + 1, y, domain='ZZ')
  52. 5) ability to query with ``in``
  53. >>> x + 1 in P
  54. False
  55. NOTE: this is a *not* a dictionary. It is a very basic object
  56. for internal use that makes sure to always address its cache
  57. via PurePoly instances. It does not, for example, implement
  58. ``get`` or ``setdefault``.
  59. """
  60. def __init__(self):
  61. self._dict = {}
  62. def __getitem__(self, k):
  63. if not isinstance(k, PurePoly):
  64. if not (isinstance(k, Expr) and len(k.free_symbols) == 1):
  65. raise KeyError
  66. k = PurePoly(k, expand=False)
  67. return self._dict[k]
  68. def __setitem__(self, k, v):
  69. if not isinstance(k, PurePoly):
  70. if not (isinstance(k, Expr) and len(k.free_symbols) == 1):
  71. raise ValueError('expecting univariate expression')
  72. k = PurePoly(k, expand=False)
  73. self._dict[k] = v
  74. def __contains__(self, k):
  75. try:
  76. self[k]
  77. return True
  78. except KeyError:
  79. return False
# Module-level caches of root-isolating intervals keyed by PurePoly.
# ``_reals_cache`` stores the real-root intervals and ``_complexes_cache``
# the complex-root intervals computed for each polynomial/factor; the
# stored lists are refined in place as roots are evaluated.
# ``ComplexRootOf.clear_cache`` rebinds both names to fresh empty instances.
_reals_cache = _pure_key_dict()
_complexes_cache = _pure_key_dict()
  82. def _pure_factors(poly):
  83. _, factors = poly.factor_list()
  84. return [(PurePoly(f, expand=False), m) for f, m in factors]
  85. def _imag_count_of_factor(f):
  86. """Return the number of imaginary roots for irreducible
  87. univariate polynomial ``f``.
  88. """
  89. terms = [(i, j) for (i,), j in f.terms()]
  90. if any(i % 2 for i, j in terms):
  91. return 0
  92. # update signs
  93. even = [(i, I**i*j) for i, j in terms]
  94. even = Poly.from_dict(dict(even), Dummy('x'))
  95. return int(even.count_roots(-oo, oo))
  96. @public
  97. def rootof(f, x, index=None, radicals=True, expand=True):
  98. """An indexed root of a univariate polynomial.
  99. Returns either a :obj:`ComplexRootOf` object or an explicit
  100. expression involving radicals.
  101. Parameters
  102. ==========
  103. f : Expr
  104. Univariate polynomial.
  105. x : Symbol, optional
  106. Generator for ``f``.
  107. index : int or Integer
  108. radicals : bool
  109. Return a radical expression if possible.
  110. expand : bool
  111. Expand ``f``.
  112. """
  113. return CRootOf(f, x, index=index, radicals=radicals, expand=expand)
  114. @public
  115. class RootOf(Expr):
  116. """Represents a root of a univariate polynomial.
  117. Base class for roots of different kinds of polynomials.
  118. Only complex roots are currently supported.
  119. """
  120. __slots__ = ('poly',)
  121. def __new__(cls, f, x, index=None, radicals=True, expand=True):
  122. """Construct a new ``CRootOf`` object for ``k``-th root of ``f``."""
  123. return rootof(f, x, index=index, radicals=radicals, expand=expand)
  124. @public
  125. class ComplexRootOf(RootOf):
  126. """Represents an indexed complex root of a polynomial.
  127. Roots of a univariate polynomial separated into disjoint
  128. real or complex intervals and indexed in a fixed order:
  129. * real roots come first and are sorted in increasing order;
  130. * complex roots come next and are sorted primarily by increasing
  131. real part, secondarily by increasing imaginary part.
  132. Currently only rational coefficients are allowed.
  133. Can be imported as ``CRootOf``. To avoid confusion, the
  134. generator must be a Symbol.
  135. Examples
  136. ========
  137. >>> from sympy import CRootOf, rootof
  138. >>> from sympy.abc import x
  139. CRootOf is a way to reference a particular root of a
  140. polynomial. If there is a rational root, it will be returned:
  141. >>> CRootOf.clear_cache() # for doctest reproducibility
  142. >>> CRootOf(x**2 - 4, 0)
  143. -2
  144. Whether roots involving radicals are returned or not
  145. depends on whether the ``radicals`` flag is true (which is
  146. set to True with rootof):
  147. >>> CRootOf(x**2 - 3, 0)
  148. CRootOf(x**2 - 3, 0)
  149. >>> CRootOf(x**2 - 3, 0, radicals=True)
  150. -sqrt(3)
  151. >>> rootof(x**2 - 3, 0)
  152. -sqrt(3)
  153. The following cannot be expressed in terms of radicals:
  154. >>> r = rootof(4*x**5 + 16*x**3 + 12*x**2 + 7, 0); r
  155. CRootOf(4*x**5 + 16*x**3 + 12*x**2 + 7, 0)
  156. The root bounds can be seen, however, and they are used by the
  157. evaluation methods to get numerical approximations for the root.
  158. >>> interval = r._get_interval(); interval
  159. (-1, 0)
  160. >>> r.evalf(2)
  161. -0.98
  162. The evalf method refines the width of the root bounds until it
  163. guarantees that any decimal approximation within those bounds
  164. will satisfy the desired precision. It then stores the refined
  165. interval so subsequent requests at or below the requested
  166. precision will not have to recompute the root bounds and will
  167. return very quickly.
  168. Before evaluation above, the interval was
  169. >>> interval
  170. (-1, 0)
  171. After evaluation it is now
  172. >>> r._get_interval() # doctest: +SKIP
  173. (-165/169, -206/211)
  174. To reset all intervals for a given polynomial, the :meth:`_reset` method
  175. can be called from any CRootOf instance of the polynomial:
  176. >>> r._reset()
  177. >>> r._get_interval()
  178. (-1, 0)
  179. The :meth:`eval_approx` method will also find the root to a given
  180. precision but the interval is not modified unless the search
  181. for the root fails to converge within the root bounds. And
  182. the secant method is used to find the root. (The ``evalf``
  183. method uses bisection and will always update the interval.)
  184. >>> r.eval_approx(2)
  185. -0.98
  186. The interval needed to be slightly updated to find that root:
  187. >>> r._get_interval()
  188. (-1, -1/2)
    The ``eval_rational`` method will compute a rational approximation
  190. of the root to the desired accuracy or precision.
  191. >>> r.eval_rational(n=2)
  192. -69629/71318
  193. >>> t = CRootOf(x**3 + 10*x + 1, 1)
  194. >>> t.eval_rational(1e-1)
  195. 15/256 - 805*I/256
  196. >>> t.eval_rational(1e-1, 1e-4)
  197. 3275/65536 - 414645*I/131072
  198. >>> t.eval_rational(1e-4, 1e-4)
  199. 6545/131072 - 414645*I/131072
  200. >>> t.eval_rational(n=2)
  201. 104755/2097152 - 6634255*I/2097152
  202. Notes
  203. =====
  204. Although a PurePoly can be constructed from a non-symbol generator
  205. RootOf instances of non-symbols are disallowed to avoid confusion
  206. over what root is being represented.
  207. >>> from sympy import exp, PurePoly
  208. >>> PurePoly(x) == PurePoly(exp(x))
  209. True
  210. >>> CRootOf(x - 1, 0)
  211. 1
  212. >>> CRootOf(exp(x) - 1, 0) # would correspond to x == 0
  213. Traceback (most recent call last):
  214. ...
  215. sympy.polys.polyerrors.PolynomialError: generator must be a Symbol
  216. See Also
  217. ========
  218. eval_approx
  219. eval_rational
  220. """
    # ``poly`` is declared on RootOf; this class adds the root ``index``.
    __slots__ = ('index',)
    # Static assumptions shared by every instance: each object represents
    # a root of a polynomial, so it is a finite algebraic complex number.
    is_complex = True
    is_number = True
    is_finite = True
    is_algebraic = True
  226. def __new__(cls, f, x, index=None, radicals=False, expand=True):
  227. """ Construct an indexed complex root of a polynomial.
  228. See ``rootof`` for the parameters.
  229. The default value of ``radicals`` is ``False`` to satisfy
  230. ``eval(srepr(expr) == expr``.
  231. """
  232. x = sympify(x)
  233. if index is None and x.is_Integer:
  234. x, index = None, x
  235. else:
  236. index = sympify(index)
  237. if index is not None and index.is_Integer:
  238. index = int(index)
  239. else:
  240. raise ValueError("expected an integer root index, got %s" % index)
  241. poly = PurePoly(f, x, greedy=False, expand=expand)
  242. if not poly.is_univariate:
  243. raise PolynomialError("only univariate polynomials are allowed")
  244. if not poly.gen.is_Symbol:
  245. # PurePoly(sin(x) + 1) == PurePoly(x + 1) but the roots of
  246. # x for each are not the same: issue 8617
  247. raise PolynomialError("generator must be a Symbol")
  248. degree = poly.degree()
  249. if degree <= 0:
  250. raise PolynomialError("Cannot construct CRootOf object for %s" % f)
  251. if index < -degree or index >= degree:
  252. raise IndexError("root index out of [%d, %d] range, got %d" %
  253. (-degree, degree - 1, index))
  254. elif index < 0:
  255. index += degree
  256. dom = poly.get_domain()
  257. if not dom.is_Exact:
  258. poly = poly.to_exact()
  259. roots = cls._roots_trivial(poly, radicals)
  260. if roots is not None:
  261. return roots[index]
  262. coeff, poly = preprocess_roots(poly)
  263. dom = poly.get_domain()
  264. if not dom.is_ZZ:
  265. raise NotImplementedError("CRootOf is not supported over %s" % dom)
  266. root = cls._indexed_root(poly, index, lazy=True)
  267. return coeff * cls._postprocess_root(root, radicals)
  268. @classmethod
  269. def _new(cls, poly, index):
  270. """Construct new ``CRootOf`` object from raw data. """
  271. obj = Expr.__new__(cls)
  272. obj.poly = PurePoly(poly)
  273. obj.index = index
  274. try:
  275. _reals_cache[obj.poly] = _reals_cache[poly]
  276. _complexes_cache[obj.poly] = _complexes_cache[poly]
  277. except KeyError:
  278. pass
  279. return obj
  280. def _hashable_content(self):
  281. return (self.poly, self.index)
  282. @property
  283. def expr(self):
  284. return self.poly.as_expr()
  285. @property
  286. def args(self):
  287. return (self.expr, Integer(self.index))
  288. @property
  289. def free_symbols(self):
  290. # CRootOf currently only works with univariate expressions
  291. # whose poly attribute should be a PurePoly with no free
  292. # symbols
  293. return set()
  294. def _eval_is_real(self):
  295. """Return ``True`` if the root is real. """
  296. self._ensure_reals_init()
  297. return self.index < len(_reals_cache[self.poly])
  298. def _eval_is_imaginary(self):
  299. """Return ``True`` if the root is imaginary. """
  300. self._ensure_reals_init()
  301. if self.index >= len(_reals_cache[self.poly]):
  302. ivl = self._get_interval()
  303. return ivl.ax*ivl.bx <= 0 # all others are on one side or the other
  304. return False # XXX is this necessary?
  305. @classmethod
  306. def real_roots(cls, poly, radicals=True):
  307. """Get real roots of a polynomial. """
  308. return cls._get_roots("_real_roots", poly, radicals)
  309. @classmethod
  310. def all_roots(cls, poly, radicals=True):
  311. """Get real and complex roots of a polynomial. """
  312. return cls._get_roots("_all_roots", poly, radicals)
  313. @classmethod
  314. def _get_reals_sqf(cls, currentfactor, use_cache=True):
  315. """Get real root isolating intervals for a square-free factor."""
  316. if use_cache and currentfactor in _reals_cache:
  317. real_part = _reals_cache[currentfactor]
  318. else:
  319. _reals_cache[currentfactor] = real_part = \
  320. dup_isolate_real_roots_sqf(
  321. currentfactor.rep.to_list(), currentfactor.rep.dom, blackbox=True)
  322. return real_part
  323. @classmethod
  324. def _get_complexes_sqf(cls, currentfactor, use_cache=True):
  325. """Get complex root isolating intervals for a square-free factor."""
  326. if use_cache and currentfactor in _complexes_cache:
  327. complex_part = _complexes_cache[currentfactor]
  328. else:
  329. _complexes_cache[currentfactor] = complex_part = \
  330. dup_isolate_complex_roots_sqf(
  331. currentfactor.rep.to_list(), currentfactor.rep.dom, blackbox=True)
  332. return complex_part
  333. @classmethod
  334. def _get_reals(cls, factors, use_cache=True):
  335. """Compute real root isolating intervals for a list of factors. """
  336. reals = []
  337. for currentfactor, k in factors:
  338. try:
  339. if not use_cache:
  340. raise KeyError
  341. r = _reals_cache[currentfactor]
  342. reals.extend([(i, currentfactor, k) for i in r])
  343. except KeyError:
  344. real_part = cls._get_reals_sqf(currentfactor, use_cache)
  345. new = [(root, currentfactor, k) for root in real_part]
  346. reals.extend(new)
  347. reals = cls._reals_sorted(reals)
  348. return reals
  349. @classmethod
  350. def _get_complexes(cls, factors, use_cache=True):
  351. """Compute complex root isolating intervals for a list of factors. """
  352. complexes = []
  353. for currentfactor, k in ordered(factors):
  354. try:
  355. if not use_cache:
  356. raise KeyError
  357. c = _complexes_cache[currentfactor]
  358. complexes.extend([(i, currentfactor, k) for i in c])
  359. except KeyError:
  360. complex_part = cls._get_complexes_sqf(currentfactor, use_cache)
  361. new = [(root, currentfactor, k) for root in complex_part]
  362. complexes.extend(new)
  363. complexes = cls._complexes_sorted(complexes)
  364. return complexes
  365. @classmethod
  366. def _reals_sorted(cls, reals):
  367. """Make real isolating intervals disjoint and sort roots. """
  368. cache = {}
  369. for i, (u, f, k) in enumerate(reals):
  370. for j, (v, g, m) in enumerate(reals[i + 1:]):
  371. u, v = u.refine_disjoint(v)
  372. reals[i + j + 1] = (v, g, m)
  373. reals[i] = (u, f, k)
  374. reals = sorted(reals, key=lambda r: r[0].a)
  375. for root, currentfactor, _ in reals:
  376. if currentfactor in cache:
  377. cache[currentfactor].append(root)
  378. else:
  379. cache[currentfactor] = [root]
  380. for currentfactor, root in cache.items():
  381. _reals_cache[currentfactor] = root
  382. return reals
    @classmethod
    def _refine_imaginary(cls, complexes):
        """Refine complex isolating intervals relative to the imaginary axis.

        ``complexes`` is a list of ``(interval, factor, multiplicity)``
        triples. They are grouped by factor (visited in ``ordered`` order).
        For a factor with no imaginary roots, every interval is refined
        until its x-range is strictly on one side of the imaginary axis
        (``ax*bx > 0``). Otherwise intervals are refined until only
        ``nimag`` of them (the imaginary-root count) may still touch the
        axis. The refined triples are returned grouped by factor.
        """
        sifted = sift(complexes, lambda c: c[1])
        complexes = []
        for f in ordered(sifted):
            nimag = _imag_count_of_factor(f)
            if nimag == 0:
                # refine until xbounds are neg or pos
                for u, f, k in sifted[f]:
                    while u.ax*u.bx <= 0:
                        u = u._inner_refine()
                    complexes.append((u, f, k))
            else:
                # refine until all but nimag xbounds are neg or pos
                # (indices into sifted[f] whose box may still straddle
                # the imaginary axis)
                potential_imag = list(range(len(sifted[f])))
                while True:
                    assert len(potential_imag) > 1
                    # iterate over a copy because entries are removed
                    for i in list(potential_imag):
                        u, f, k = sifted[f][i]
                        if u.ax*u.bx > 0:
                            potential_imag.remove(i)
                        elif u.ax != u.bx:
                            # zero-width x-ranges cannot be refined further
                            u = u._inner_refine()
                            sifted[f][i] = u, f, k
                    if len(potential_imag) == nimag:
                        break
                complexes.extend(sifted[f])
        return complexes
    @classmethod
    def _refine_complexes(cls, complexes):
        """Return ``complexes`` refined so that no bounding rectangles of
        non-conjugate roots would intersect.

        In addition, ensure that neither ``ay`` nor ``by`` is 0, which
        guarantees that non-real roots are distinct from real roots in
        terms of the y-bounds.
        """
        # get the intervals pairwise-disjoint.
        # If rectangles were drawn around the coordinates of the bounding
        # rectangles, no rectangles would intersect after this procedure.
        for i, (u, f, k) in enumerate(complexes):
            for j, (v, g, m) in enumerate(complexes[i + 1:]):
                u, v = u.refine_disjoint(v)
                complexes[i + j + 1] = (v, g, m)
            complexes[i] = (u, f, k)

        # refine until the x-bounds are unambiguously positive or negative
        # for non-imaginary roots
        complexes = cls._refine_imaginary(complexes)

        # make sure that all y bounds are off the real axis
        # and on the same side of the axis
        for i, (u, f, k) in enumerate(complexes):
            while u.ay*u.by <= 0:
                u = u.refine()
            complexes[i] = u, f, k
        return complexes
  436. @classmethod
  437. def _complexes_sorted(cls, complexes):
  438. """Make complex isolating intervals disjoint and sort roots. """
  439. complexes = cls._refine_complexes(complexes)
  440. # XXX don't sort until you are sure that it is compatible
  441. # with the indexing method but assert that the desired state
  442. # is not broken
  443. C, F = 0, 1 # location of ComplexInterval and factor
  444. fs = {i[F] for i in complexes}
  445. for i in range(1, len(complexes)):
  446. if complexes[i][F] != complexes[i - 1][F]:
  447. # if this fails the factors of a root were not
  448. # contiguous because a discontinuity should only
  449. # happen once
  450. fs.remove(complexes[i - 1][F])
  451. for i, cmplx in enumerate(complexes):
  452. # negative im part (conj=True) comes before
  453. # positive im part (conj=False)
  454. assert cmplx[C].conj is (i % 2 == 0)
  455. # update cache
  456. cache = {}
  457. # -- collate
  458. for root, currentfactor, _ in complexes:
  459. cache.setdefault(currentfactor, []).append(root)
  460. # -- store
  461. for currentfactor, root in cache.items():
  462. _complexes_cache[currentfactor] = root
  463. return complexes
  464. @classmethod
  465. def _reals_index(cls, reals, index):
  466. """
  467. Map initial real root index to an index in a factor where
  468. the root belongs.
  469. """
  470. i = 0
  471. for j, (_, currentfactor, k) in enumerate(reals):
  472. if index < i + k:
  473. poly, index = currentfactor, 0
  474. for _, currentfactor, _ in reals[:j]:
  475. if currentfactor == poly:
  476. index += 1
  477. return poly, index
  478. else:
  479. i += k
  480. @classmethod
  481. def _complexes_index(cls, complexes, index):
  482. """
  483. Map initial complex root index to an index in a factor where
  484. the root belongs.
  485. """
  486. i = 0
  487. for j, (_, currentfactor, k) in enumerate(complexes):
  488. if index < i + k:
  489. poly, index = currentfactor, 0
  490. for _, currentfactor, _ in complexes[:j]:
  491. if currentfactor == poly:
  492. index += 1
  493. index += len(_reals_cache[poly])
  494. return poly, index
  495. else:
  496. i += k
  497. @classmethod
  498. def _count_roots(cls, roots):
  499. """Count the number of real or complex roots with multiplicities."""
  500. return sum(k for _, _, k in roots)
  501. @classmethod
  502. def _indexed_root(cls, poly, index, lazy=False):
  503. """Get a root of a composite polynomial by index. """
  504. factors = _pure_factors(poly)
  505. # If the given poly is already irreducible, then the index does not
  506. # need to be adjusted, and we can postpone the heavy lifting of
  507. # computing and refining isolating intervals until that is needed.
  508. # Note, however, that `_pure_factors()` extracts a negative leading
  509. # coeff if present, so `factors[0][0]` may differ from `poly`, and
  510. # is the "normalized" version of `poly` that we must return.
  511. if lazy and len(factors) == 1 and factors[0][1] == 1:
  512. return factors[0][0], index
  513. reals = cls._get_reals(factors)
  514. reals_count = cls._count_roots(reals)
  515. if index < reals_count:
  516. return cls._reals_index(reals, index)
  517. else:
  518. complexes = cls._get_complexes(factors)
  519. return cls._complexes_index(complexes, index - reals_count)
  520. def _ensure_reals_init(self):
  521. """Ensure that our poly has entries in the reals cache. """
  522. if self.poly not in _reals_cache:
  523. self._indexed_root(self.poly, self.index)
  524. def _ensure_complexes_init(self):
  525. """Ensure that our poly has entries in the complexes cache. """
  526. if self.poly not in _complexes_cache:
  527. self._indexed_root(self.poly, self.index)
  528. @classmethod
  529. def _real_roots(cls, poly):
  530. """Get real roots of a composite polynomial. """
  531. factors = _pure_factors(poly)
  532. reals = cls._get_reals(factors)
  533. reals_count = cls._count_roots(reals)
  534. roots = []
  535. for index in range(0, reals_count):
  536. roots.append(cls._reals_index(reals, index))
  537. return roots
  538. def _reset(self):
  539. """
  540. Reset all intervals
  541. """
  542. self._all_roots(self.poly, use_cache=False)
  543. @classmethod
  544. def _all_roots(cls, poly, use_cache=True):
  545. """Get real and complex roots of a composite polynomial. """
  546. factors = _pure_factors(poly)
  547. reals = cls._get_reals(factors, use_cache=use_cache)
  548. reals_count = cls._count_roots(reals)
  549. roots = []
  550. for index in range(0, reals_count):
  551. roots.append(cls._reals_index(reals, index))
  552. complexes = cls._get_complexes(factors, use_cache=use_cache)
  553. complexes_count = cls._count_roots(complexes)
  554. for index in range(0, complexes_count):
  555. roots.append(cls._complexes_index(complexes, index))
  556. return roots
  557. @classmethod
  558. @cacheit
  559. def _roots_trivial(cls, poly, radicals):
  560. """Compute roots in linear, quadratic and binomial cases. """
  561. if poly.degree() == 1:
  562. return roots_linear(poly)
  563. if not radicals:
  564. return None
  565. if poly.degree() == 2:
  566. return roots_quadratic(poly)
  567. elif poly.length() == 2 and poly.TC():
  568. return roots_binomial(poly)
  569. else:
  570. return None
  571. @classmethod
  572. def _preprocess_roots(cls, poly):
  573. """Take heroic measures to make ``poly`` compatible with ``CRootOf``."""
  574. dom = poly.get_domain()
  575. if not dom.is_Exact:
  576. poly = poly.to_exact()
  577. coeff, poly = preprocess_roots(poly)
  578. dom = poly.get_domain()
  579. if not dom.is_ZZ:
  580. raise NotImplementedError(
  581. "sorted roots not supported over %s" % dom)
  582. return coeff, poly
  583. @classmethod
  584. def _postprocess_root(cls, root, radicals):
  585. """Return the root if it is trivial or a ``CRootOf`` object. """
  586. poly, index = root
  587. roots = cls._roots_trivial(poly, radicals)
  588. if roots is not None:
  589. return roots[index]
  590. else:
  591. return cls._new(poly, index)
  592. @classmethod
  593. def _get_roots(cls, method, poly, radicals):
  594. """Return postprocessed roots of specified kind. """
  595. if not poly.is_univariate:
  596. raise PolynomialError("only univariate polynomials are allowed")
  597. dom = poly.get_domain()
  598. # get rid of gen and it's free symbol
  599. d = Dummy()
  600. poly = poly.subs(poly.gen, d)
  601. x = symbols('x')
  602. # see what others are left and select x or a numbered x
  603. # that doesn't clash
  604. free_names = {str(i) for i in poly.free_symbols}
  605. for x in chain((symbols('x'),), numbered_symbols('x')):
  606. if x.name not in free_names:
  607. poly = poly.replace(d, x)
  608. break
  609. if dom.is_QQ or dom.is_ZZ:
  610. return cls._get_roots_qq(method, poly, radicals)
  611. elif dom.is_AlgebraicField or dom.is_ZZ_I or dom.is_QQ_I:
  612. return cls._get_roots_alg(method, poly, radicals)
  613. else:
  614. # XXX: not sure how to handle ZZ[x] which appears in some tests?
  615. # this makes the tests pass alright but has to be a better way?
  616. return cls._get_roots_qq(method, poly, radicals)
  617. @classmethod
  618. def _get_roots_qq(cls, method, poly, radicals):
  619. """Return postprocessed roots of specified kind
  620. for polynomials with rational coefficients. """
  621. coeff, poly = cls._preprocess_roots(poly)
  622. roots = []
  623. for root in getattr(cls, method)(poly):
  624. roots.append(coeff*cls._postprocess_root(root, radicals))
  625. return roots
  626. @classmethod
  627. def _get_roots_alg(cls, method, poly, radicals):
  628. """Return postprocessed roots of specified kind
  629. for polynomials with algebraic coefficients. It assumes
  630. the domain is already an algebraic field. First it
  631. finds the roots using _get_roots_qq, then uses the
  632. square-free factors to filter roots and get the correct
  633. multiplicity.
  634. """
  635. # Existing QQ code can find and sort the roots
  636. roots = cls._get_roots_qq(method, poly.lift(), radicals)
  637. subroots = {}
  638. for f, m in poly.sqf_list()[1]:
  639. if method == "_real_roots":
  640. roots_filt = f.which_real_roots(roots)
  641. elif method == "_all_roots":
  642. roots_filt = f.which_all_roots(roots)
  643. for r in roots_filt:
  644. subroots[r] = m
  645. roots_seen = set()
  646. roots_flat = []
  647. for r in roots:
  648. if r in subroots and r not in roots_seen:
  649. m = subroots[r]
  650. roots_flat.extend([r] * m)
  651. roots_seen.add(r)
  652. return roots_flat
  653. @classmethod
  654. def clear_cache(cls):
  655. """Reset cache for reals and complexes.
  656. The intervals used to approximate a root instance are updated
  657. as needed. When a request is made to see the intervals, the
  658. most current values are shown. `clear_cache` will reset all
  659. CRootOf instances back to their original state.
  660. See Also
  661. ========
  662. _reset
  663. """
  664. global _reals_cache, _complexes_cache
  665. _reals_cache = _pure_key_dict()
  666. _complexes_cache = _pure_key_dict()
  667. def _get_interval(self):
  668. """Internal function for retrieving isolation interval from cache. """
  669. self._ensure_reals_init()
  670. if self.is_real:
  671. return _reals_cache[self.poly][self.index]
  672. else:
  673. reals_count = len(_reals_cache[self.poly])
  674. self._ensure_complexes_init()
  675. return _complexes_cache[self.poly][self.index - reals_count]
  676. def _set_interval(self, interval):
  677. """Internal function for updating isolation interval in cache. """
  678. self._ensure_reals_init()
  679. if self.is_real:
  680. _reals_cache[self.poly][self.index] = interval
  681. else:
  682. reals_count = len(_reals_cache[self.poly])
  683. self._ensure_complexes_init()
  684. _complexes_cache[self.poly][self.index - reals_count] = interval
  685. def _eval_subs(self, old, new):
  686. # don't allow subs to change anything
  687. return self
  688. def _eval_conjugate(self):
  689. if self.is_real:
  690. return self
  691. expr, i = self.args
  692. return self.func(expr, i + (1 if self._get_interval().conj else -1))
  693. def eval_approx(self, n, return_mpmath=False):
  694. """Evaluate this complex root to the given precision.
  695. This uses secant method and root bounds are used to both
  696. generate an initial guess and to check that the root
  697. returned is valid. If ever the method converges outside the
  698. root bounds, the bounds will be made smaller and updated.
  699. """
  700. prec = dps_to_prec(n)
  701. with workprec(prec):
  702. g = self.poly.gen
  703. if not g.is_Symbol:
  704. d = Dummy('x')
  705. if self.is_imaginary:
  706. d *= I
  707. func = lambdify(d, self.expr.subs(g, d))
  708. else:
  709. expr = self.expr
  710. if self.is_imaginary:
  711. expr = self.expr.subs(g, I*g)
  712. func = lambdify(g, expr)
  713. interval = self._get_interval()
  714. while True:
  715. if self.is_real:
  716. a = mpf(str(interval.a))
  717. b = mpf(str(interval.b))
  718. if a == b:
  719. root = a
  720. break
  721. x0 = mpf(str(interval.center))
  722. x1 = x0 + mpf(str(interval.dx))/4
  723. elif self.is_imaginary:
  724. a = mpf(str(interval.ay))
  725. b = mpf(str(interval.by))
  726. if a == b:
  727. root = mpc(mpf('0'), a)
  728. break
  729. x0 = mpf(str(interval.center[1]))
  730. x1 = x0 + mpf(str(interval.dy))/4
  731. else:
  732. ax = mpf(str(interval.ax))
  733. bx = mpf(str(interval.bx))
  734. ay = mpf(str(interval.ay))
  735. by = mpf(str(interval.by))
  736. if ax == bx and ay == by:
  737. root = mpc(ax, ay)
  738. break
  739. x0 = mpc(*map(str, interval.center))
  740. x1 = x0 + mpc(*map(str, (interval.dx, interval.dy)))/4
  741. try:
  742. # without a tolerance, this will return when (to within
  743. # the given precision) x_i == x_{i-1}
  744. root = findroot(func, (x0, x1))
  745. # If the (real or complex) root is not in the 'interval',
  746. # then keep refining the interval. This happens if findroot
  747. # accidentally finds a different root outside of this
  748. # interval because our initial estimate 'x0' was not close
  749. # enough. It is also possible that the secant method will
  750. # get trapped by a max/min in the interval; the root
  751. # verification by findroot will raise a ValueError in this
  752. # case and the interval will then be tightened -- and
  753. # eventually the root will be found.
  754. #
  755. # It is also possible that findroot will not have any
  756. # successful iterations to process (in which case it
  757. # will fail to initialize a variable that is tested
  758. # after the iterations and raise an UnboundLocalError).
  759. if self.is_real or self.is_imaginary:
  760. if not bool(root.imag) == self.is_real and (
  761. a <= root <= b):
  762. if self.is_imaginary:
  763. root = mpc(mpf('0'), root.real)
  764. break
  765. elif (ax <= root.real <= bx and ay <= root.imag <= by):
  766. break
  767. except (UnboundLocalError, ValueError):
  768. pass
  769. interval = interval.refine()
  770. # update the interval so we at least (for this precision or
  771. # less) don't have much work to do to recompute the root
  772. self._set_interval(interval)
  773. if return_mpmath:
  774. return root
  775. return (Float._new(root.real._mpf_, prec) +
  776. I*Float._new(root.imag._mpf_, prec))
  777. def _eval_evalf(self, prec, **kwargs):
  778. """Evaluate this complex root to the given precision."""
  779. # all kwargs are ignored
  780. return self.eval_rational(n=prec_to_dps(prec))._evalf(prec)
  781. def eval_rational(self, dx=None, dy=None, n=15):
  782. """
  783. Return a Rational approximation of ``self`` that has real
  784. and imaginary component approximations that are within ``dx``
  785. and ``dy`` of the true values, respectively. Alternatively,
  786. ``n`` digits of precision can be specified.
  787. The interval is refined with bisection and is sure to
  788. converge. The root bounds are updated when the refinement
  789. is complete so recalculation at the same or lesser precision
  790. will not have to repeat the refinement and should be much
  791. faster.
  792. The following example first obtains Rational approximation to
  793. 1e-8 accuracy for all roots of the 4-th order Legendre
  794. polynomial. Since the roots are all less than 1, this will
  795. ensure the decimal representation of the approximation will be
  796. correct (including rounding) to 6 digits:
  797. >>> from sympy import legendre_poly, Symbol
  798. >>> x = Symbol("x")
  799. >>> p = legendre_poly(4, x, polys=True)
  800. >>> r = p.real_roots()[-1]
  801. >>> r.eval_rational(10**-8).n(6)
  802. 0.861136
  803. It is not necessary to a two-step calculation, however: the
  804. decimal representation can be computed directly:
  805. >>> r.evalf(17)
  806. 0.86113631159405258
  807. """
  808. dy = dy or dx
  809. if dx:
  810. rtol = None
  811. dx = dx if isinstance(dx, Rational) else Rational(str(dx))
  812. dy = dy if isinstance(dy, Rational) else Rational(str(dy))
  813. else:
  814. # 5 binary (or 2 decimal) digits are needed to ensure that
  815. # a given digit is correctly rounded
  816. # prec_to_dps(dps_to_prec(n) + 5) - n <= 2 (tested for
  817. # n in range(1000000)
  818. rtol = S(10)**-(n + 2) # +2 for guard digits
  819. interval = self._get_interval()
  820. while True:
  821. if self.is_real:
  822. if rtol:
  823. dx = abs(interval.center*rtol)
  824. interval = interval.refine_size(dx=dx)
  825. c = interval.center
  826. real = Rational(c)
  827. imag = S.Zero
  828. if not rtol or interval.dx < abs(c*rtol):
  829. break
  830. elif self.is_imaginary:
  831. if rtol:
  832. dy = abs(interval.center[1]*rtol)
  833. dx = 1
  834. interval = interval.refine_size(dx=dx, dy=dy)
  835. c = interval.center[1]
  836. imag = Rational(c)
  837. real = S.Zero
  838. if not rtol or interval.dy < abs(c*rtol):
  839. break
  840. else:
  841. if rtol:
  842. dx = abs(interval.center[0]*rtol)
  843. dy = abs(interval.center[1]*rtol)
  844. interval = interval.refine_size(dx, dy)
  845. c = interval.center
  846. real, imag = map(Rational, c)
  847. if not rtol or (
  848. interval.dx < abs(c[0]*rtol) and
  849. interval.dy < abs(c[1]*rtol)):
  850. break
  851. # update the interval so we at least (for this precision or
  852. # less) don't have much work to do to recompute the root
  853. self._set_interval(interval)
  854. return real + I*imag
# ``CRootOf`` is a shorter alias bound to the same class object as
# ``ComplexRootOf``.
CRootOf = ComplexRootOf
  856. @dispatch(ComplexRootOf, ComplexRootOf)
  857. def _eval_is_eq(lhs, rhs): # noqa:F811
  858. # if we use is_eq to check here, we get infinite recursion
  859. return lhs == rhs
  860. @dispatch(ComplexRootOf, Basic) # type:ignore
  861. def _eval_is_eq(lhs, rhs): # noqa:F811
  862. # CRootOf represents a Root, so if rhs is that root, it should set
  863. # the expression to zero *and* it should be in the interval of the
  864. # CRootOf instance. It must also be a number that agrees with the
  865. # is_real value of the CRootOf instance.
  866. if not rhs.is_number:
  867. return None
  868. if not rhs.is_finite:
  869. return False
  870. z = lhs.expr.subs(lhs.expr.free_symbols.pop(), rhs).is_zero
  871. if z is False: # all roots will make z True but we don't know
  872. # whether this is the right root if z is True
  873. return False
  874. o = rhs.is_real, rhs.is_imaginary
  875. s = lhs.is_real, lhs.is_imaginary
  876. assert None not in s # this is part of initial refinement
  877. if o != s and None not in o:
  878. return False
  879. re, im = rhs.as_real_imag()
  880. if lhs.is_real:
  881. if im:
  882. return False
  883. i = lhs._get_interval()
  884. a, b = [Rational(str(_)) for _ in (i.a, i.b)]
  885. return sympify(a <= rhs and rhs <= b)
  886. i = lhs._get_interval()
  887. r1, r2, i1, i2 = [Rational(str(j)) for j in (
  888. i.ax, i.bx, i.ay, i.by)]
  889. return is_le(r1, re) and is_le(re,r2) and is_le(i1,im) and is_le(im,i2)
  890. @public
  891. class RootSum(Expr):
  892. """Represents a sum of all roots of a univariate polynomial. """
  893. __slots__ = ('poly', 'fun', 'auto')
  894. def __new__(cls, expr, func=None, x=None, auto=True, quadratic=False):
  895. """Construct a new ``RootSum`` instance of roots of a polynomial."""
  896. coeff, poly = cls._transform(expr, x)
  897. if not poly.is_univariate:
  898. raise MultivariatePolynomialError(
  899. "only univariate polynomials are allowed")
  900. if func is None:
  901. func = Lambda(poly.gen, poly.gen)
  902. else:
  903. is_func = getattr(func, 'is_Function', False)
  904. if is_func and 1 in func.nargs:
  905. if not isinstance(func, Lambda):
  906. func = Lambda(poly.gen, func(poly.gen))
  907. else:
  908. raise ValueError(
  909. "expected a univariate function, got %s" % func)
  910. var, expr = func.variables[0], func.expr
  911. if coeff is not S.One:
  912. expr = expr.subs(var, coeff*var)
  913. deg = poly.degree()
  914. if not expr.has(var):
  915. return deg*expr
  916. if expr.is_Add:
  917. add_const, expr = expr.as_independent(var)
  918. else:
  919. add_const = S.Zero
  920. if expr.is_Mul:
  921. mul_const, expr = expr.as_independent(var)
  922. else:
  923. mul_const = S.One
  924. func = Lambda(var, expr)
  925. rational = cls._is_func_rational(poly, func)
  926. factors, terms = _pure_factors(poly), []
  927. for poly, k in factors:
  928. if poly.is_linear:
  929. term = func(roots_linear(poly)[0])
  930. elif quadratic and poly.is_quadratic:
  931. term = sum(map(func, roots_quadratic(poly)))
  932. else:
  933. if not rational or not auto:
  934. term = cls._new(poly, func, auto)
  935. else:
  936. term = cls._rational_case(poly, func)
  937. terms.append(k*term)
  938. return mul_const*Add(*terms) + deg*add_const
  939. @classmethod
  940. def _new(cls, poly, func, auto=True):
  941. """Construct new raw ``RootSum`` instance. """
  942. obj = Expr.__new__(cls)
  943. obj.poly = poly
  944. obj.fun = func
  945. obj.auto = auto
  946. return obj
  947. @classmethod
  948. def new(cls, poly, func, auto=True):
  949. """Construct new ``RootSum`` instance. """
  950. if not func.expr.has(*func.variables):
  951. return func.expr
  952. rational = cls._is_func_rational(poly, func)
  953. if not rational or not auto:
  954. return cls._new(poly, func, auto)
  955. else:
  956. return cls._rational_case(poly, func)
  957. @classmethod
  958. def _transform(cls, expr, x):
  959. """Transform an expression to a polynomial. """
  960. poly = PurePoly(expr, x, greedy=False)
  961. return preprocess_roots(poly)
  962. @classmethod
  963. def _is_func_rational(cls, poly, func):
  964. """Check if a lambda is a rational function. """
  965. var, expr = func.variables[0], func.expr
  966. return expr.is_rational_function(var)
  967. @classmethod
  968. def _rational_case(cls, poly, func):
  969. """Handle the rational function case. """
  970. roots = symbols('r:%d' % poly.degree())
  971. var, expr = func.variables[0], func.expr
  972. f = sum(expr.subs(var, r) for r in roots)
  973. p, q = together(f).as_numer_denom()
  974. domain = QQ[roots]
  975. p = p.expand()
  976. q = q.expand()
  977. try:
  978. p = Poly(p, domain=domain, expand=False)
  979. except GeneratorsNeeded:
  980. p, p_coeff = None, (p,)
  981. else:
  982. p_monom, p_coeff = zip(*p.terms())
  983. try:
  984. q = Poly(q, domain=domain, expand=False)
  985. except GeneratorsNeeded:
  986. q, q_coeff = None, (q,)
  987. else:
  988. q_monom, q_coeff = zip(*q.terms())
  989. coeffs, mapping = symmetrize(p_coeff + q_coeff, formal=True)
  990. formulas, values = viete(poly, roots), []
  991. for (sym, _), (_, val) in zip(mapping, formulas):
  992. values.append((sym, val))
  993. for i, (coeff, _) in enumerate(coeffs):
  994. coeffs[i] = coeff.subs(values)
  995. n = len(p_coeff)
  996. p_coeff = coeffs[:n]
  997. q_coeff = coeffs[n:]
  998. if p is not None:
  999. p = Poly(dict(zip(p_monom, p_coeff)), *p.gens).as_expr()
  1000. else:
  1001. (p,) = p_coeff
  1002. if q is not None:
  1003. q = Poly(dict(zip(q_monom, q_coeff)), *q.gens).as_expr()
  1004. else:
  1005. (q,) = q_coeff
  1006. return factor(p/q)
  1007. def _hashable_content(self):
  1008. return (self.poly, self.fun)
  1009. @property
  1010. def expr(self):
  1011. return self.poly.as_expr()
  1012. @property
  1013. def args(self):
  1014. return (self.expr, self.fun, self.poly.gen)
  1015. @property
  1016. def free_symbols(self):
  1017. return self.poly.free_symbols | self.fun.free_symbols
  1018. @property
  1019. def is_commutative(self):
  1020. return True
  1021. def doit(self, **hints):
  1022. if not hints.get('roots', True):
  1023. return self
  1024. _roots = roots(self.poly, multiple=True)
  1025. if len(_roots) < self.poly.degree():
  1026. return self
  1027. else:
  1028. return Add(*[self.fun(r) for r in _roots])
  1029. def _eval_evalf(self, prec):
  1030. try:
  1031. _roots = self.poly.nroots(n=prec_to_dps(prec))
  1032. except (DomainError, PolynomialError):
  1033. return self
  1034. else:
  1035. return Add(*[self.fun(r) for r in _roots])
  1036. def _eval_derivative(self, x):
  1037. var, expr = self.fun.args
  1038. func = Lambda(var, expr.diff(x))
  1039. return self.new(self.poly, func, self.auto)