series.py 94 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591
  1. ### The base class for all series
  2. from collections.abc import Callable
  3. from sympy.calculus.util import continuous_domain
  4. from sympy.concrete import Sum, Product
  5. from sympy.core.containers import Tuple
  6. from sympy.core.expr import Expr
  7. from sympy.core.function import arity
  8. from sympy.core.sorting import default_sort_key
  9. from sympy.core.symbol import Symbol
  10. from sympy.functions import atan2, zeta, frac, ceiling, floor, im
  11. from sympy.core.relational import (Equality, GreaterThan,
  12. LessThan, Relational, Ne)
  13. from sympy.core.sympify import sympify
  14. from sympy.external import import_module
  15. from sympy.logic.boolalg import BooleanFunction
  16. from sympy.plotting.utils import _get_free_symbols, extract_solution
  17. from sympy.printing.latex import latex
  18. from sympy.printing.pycode import PythonCodePrinter
  19. from sympy.printing.precedence import precedence
  20. from sympy.sets.sets import Set, Interval, Union
  21. from sympy.simplify.simplify import nsimplify
  22. from sympy.utilities.exceptions import sympy_deprecation_warning
  23. from sympy.utilities.lambdify import lambdify
  24. from .intervalmath import interval
  25. import warnings
  26. class IntervalMathPrinter(PythonCodePrinter):
  27. """A printer to be used inside `plot_implicit` when `adaptive=True`,
  28. in which case the interval arithmetic module is going to be used, which
  29. requires the following edits.
  30. """
  31. def _print_And(self, expr):
  32. PREC = precedence(expr)
  33. return " & ".join(self.parenthesize(a, PREC)
  34. for a in sorted(expr.args, key=default_sort_key))
  35. def _print_Or(self, expr):
  36. PREC = precedence(expr)
  37. return " | ".join(self.parenthesize(a, PREC)
  38. for a in sorted(expr.args, key=default_sort_key))
  39. def _uniform_eval(f1, f2, *args, modules=None,
  40. force_real_eval=False, has_sum=False):
  41. """
  42. Note: this is an experimental function, as such it is prone to changes.
  43. Please, do not use it in your code.
  44. """
  45. np = import_module('numpy')
  46. def wrapper_func(func, *args):
  47. try:
  48. return complex(func(*args))
  49. except (ZeroDivisionError, OverflowError):
  50. return complex(np.nan, np.nan)
  51. # NOTE: np.vectorize is much slower than numpy vectorized operations.
  52. # However, this modules must be able to evaluate functions also with
  53. # mpmath or sympy.
  54. wrapper_func = np.vectorize(wrapper_func, otypes=[complex])
  55. def _eval_with_sympy(err=None):
  56. if f2 is None:
  57. msg = "Impossible to evaluate the provided numerical function"
  58. if err is None:
  59. msg += "."
  60. else:
  61. msg += "because the following exception was raised:\n"
  62. "{}: {}".format(type(err).__name__, err)
  63. raise RuntimeError(msg)
  64. if err:
  65. warnings.warn(
  66. "The evaluation with %s failed.\n" % (
  67. "NumPy/SciPy" if not modules else modules) +
  68. "{}: {}\n".format(type(err).__name__, err) +
  69. "Trying to evaluate the expression with Sympy, but it might "
  70. "be a slow operation."
  71. )
  72. return wrapper_func(f2, *args)
  73. if modules == "sympy":
  74. return _eval_with_sympy()
  75. try:
  76. return wrapper_func(f1, *args)
  77. except Exception as err:
  78. return _eval_with_sympy(err)
  79. def _adaptive_eval(f, x):
  80. """Evaluate f(x) with an adaptive algorithm. Post-process the result.
  81. If a symbolic expression is evaluated with SymPy, it might returns
  82. another symbolic expression, containing additions, ...
  83. Force evaluation to a float.
  84. Parameters
  85. ==========
  86. f : callable
  87. x : float
  88. """
  89. np = import_module('numpy')
  90. y = f(x)
  91. if isinstance(y, Expr) and (not y.is_Number):
  92. y = y.evalf()
  93. y = complex(y)
  94. if y.imag > 1e-08:
  95. return np.nan
  96. return y.real
  97. def _get_wrapper_for_expr(ret):
  98. wrapper = "%s"
  99. if ret == "real":
  100. wrapper = "re(%s)"
  101. elif ret == "imag":
  102. wrapper = "im(%s)"
  103. elif ret == "abs":
  104. wrapper = "abs(%s)"
  105. elif ret == "arg":
  106. wrapper = "arg(%s)"
  107. return wrapper
  108. class BaseSeries:
  109. """Base class for the data objects containing stuff to be plotted.
  110. Notes
  111. =====
  112. The backend should check if it supports the data series that is given.
  113. (e.g. TextBackend supports only LineOver1DRangeSeries).
  114. It is the backend responsibility to know how to use the class of
  115. data series that is given.
  116. Some data series classes are grouped (using a class attribute like is_2Dline)
  117. according to the api they present (based only on convention). The backend is
  118. not obliged to use that api (e.g. LineOver1DRangeSeries belongs to the
  119. is_2Dline group and presents the get_points method, but the
  120. TextBackend does not use the get_points method).
  121. BaseSeries
  122. """
  123. # Some flags follow. The rationale for using flags instead of checking base
  124. # classes is that setting multiple flags is simpler than multiple
  125. # inheritance.
  126. is_2Dline = False
  127. # Some of the backends expect:
  128. # - get_points returning 1D np.arrays list_x, list_y
  129. # - get_color_array returning 1D np.array (done in Line2DBaseSeries)
  130. # with the colors calculated at the points from get_points
  131. is_3Dline = False
  132. # Some of the backends expect:
  133. # - get_points returning 1D np.arrays list_x, list_y, list_y
  134. # - get_color_array returning 1D np.array (done in Line2DBaseSeries)
  135. # with the colors calculated at the points from get_points
  136. is_3Dsurface = False
  137. # Some of the backends expect:
  138. # - get_meshes returning mesh_x, mesh_y, mesh_z (2D np.arrays)
  139. # - get_points an alias for get_meshes
  140. is_contour = False
  141. # Some of the backends expect:
  142. # - get_meshes returning mesh_x, mesh_y, mesh_z (2D np.arrays)
  143. # - get_points an alias for get_meshes
  144. is_implicit = False
  145. # Some of the backends expect:
  146. # - get_meshes returning mesh_x (1D array), mesh_y(1D array,
  147. # mesh_z (2D np.arrays)
  148. # - get_points an alias for get_meshes
  149. # Different from is_contour as the colormap in backend will be
  150. # different
  151. is_interactive = False
  152. # An interactive series can update its data.
  153. is_parametric = False
  154. # The calculation of aesthetics expects:
  155. # - get_parameter_points returning one or two np.arrays (1D or 2D)
  156. # used for calculation aesthetics
  157. is_generic = False
  158. # Represent generic user-provided numerical data
  159. is_vector = False
  160. is_2Dvector = False
  161. is_3Dvector = False
  162. # Represents a 2D or 3D vector data series
  163. _N = 100
  164. # default number of discretization points for uniform sampling. Each
  165. # subclass can set its number.
  166. def __init__(self, *args, **kwargs):
  167. kwargs = _set_discretization_points(kwargs.copy(), type(self))
  168. # discretize the domain using only integer numbers
  169. self.only_integers = kwargs.get("only_integers", False)
  170. # represents the evaluation modules to be used by lambdify
  171. self.modules = kwargs.get("modules", None)
  172. # plot functions might create data series that might not be useful to
  173. # be shown on the legend, for example wireframe lines on 3D plots.
  174. self.show_in_legend = kwargs.get("show_in_legend", True)
  175. # line and surface series can show data with a colormap, hence a
  176. # colorbar is essential to understand the data. However, sometime it
  177. # is useful to hide it on series-by-series base. The following keyword
  178. # controls whether the series should show a colorbar or not.
  179. self.colorbar = kwargs.get("colorbar", True)
  180. # Some series might use a colormap as default coloring. Setting this
  181. # attribute to False will inform the backends to use solid color.
  182. self.use_cm = kwargs.get("use_cm", False)
  183. # If True, the backend will attempt to render it on a polar-projection
  184. # axis, or using a polar discretization if a 3D plot is requested
  185. self.is_polar = kwargs.get("is_polar", kwargs.get("polar", False))
  186. # If True, the rendering will use points, not lines.
  187. self.is_point = kwargs.get("is_point", kwargs.get("point", False))
  188. # some backend is able to render latex, other needs standard text
  189. self._label = self._latex_label = ""
  190. self._ranges = []
  191. self._n = [
  192. int(kwargs.get("n1", self._N)),
  193. int(kwargs.get("n2", self._N)),
  194. int(kwargs.get("n3", self._N))
  195. ]
  196. self._scales = [
  197. kwargs.get("xscale", "linear"),
  198. kwargs.get("yscale", "linear"),
  199. kwargs.get("zscale", "linear")
  200. ]
  201. # enable interactive widget plots
  202. self._params = kwargs.get("params", {})
  203. if not isinstance(self._params, dict):
  204. raise TypeError("`params` must be a dictionary mapping symbols "
  205. "to numeric values.")
  206. if len(self._params) > 0:
  207. self.is_interactive = True
  208. # contains keyword arguments that will be passed to the rendering
  209. # function of the chosen plotting library
  210. self.rendering_kw = kwargs.get("rendering_kw", {})
  211. # numerical transformation functions to be applied to the output data:
  212. # x, y, z (coordinates), p (parameter on parametric plots)
  213. self._tx = kwargs.get("tx", None)
  214. self._ty = kwargs.get("ty", None)
  215. self._tz = kwargs.get("tz", None)
  216. self._tp = kwargs.get("tp", None)
  217. if not all(callable(t) or (t is None) for t in
  218. [self._tx, self._ty, self._tz, self._tp]):
  219. raise TypeError("`tx`, `ty`, `tz`, `tp` must be functions.")
  220. # list of numerical functions representing the expressions to evaluate
  221. self._functions = []
  222. # signature for the numerical functions
  223. self._signature = []
  224. # some expressions don't like to be evaluated over complex data.
  225. # if that's the case, set this to True
  226. self._force_real_eval = kwargs.get("force_real_eval", None)
  227. # this attribute will eventually contain a dictionary with the
  228. # discretized ranges
  229. self._discretized_domain = None
  230. # whether the series contains any interactive range, which is a range
  231. # where the minimum and maximum values can be changed with an
  232. # interactive widget
  233. self._interactive_ranges = False
  234. # NOTE: consider a generic summation, for example:
  235. # s = Sum(cos(pi * x), (x, 1, y))
  236. # This gets lambdified to something:
  237. # sum(cos(pi*x) for x in range(1, y+1))
  238. # Hence, y needs to be an integer, otherwise it raises:
  239. # TypeError: 'complex' object cannot be interpreted as an integer
  240. # This list will contains symbols that are upper bound to summations
  241. # or products
  242. self._needs_to_be_int = []
  243. # a color function will be responsible to set the line/surface color
  244. # according to some logic. Each data series will et an appropriate
  245. # default value.
  246. self.color_func = None
  247. # NOTE: color_func usually receives numerical functions that are going
  248. # to be evaluated over the coordinates of the computed points (or the
  249. # discretized meshes).
  250. # However, if an expression is given to color_func, then it will be
  251. # lambdified with symbols in self._signature, and it will be evaluated
  252. # with the same data used to evaluate the plotted expression.
  253. self._eval_color_func_with_signature = False
  254. def _block_lambda_functions(self, *exprs):
  255. """Some data series can be used to plot numerical functions, others
  256. cannot. Execute this method inside the `__init__` to prevent the
  257. processing of numerical functions.
  258. """
  259. if any(callable(e) for e in exprs):
  260. raise TypeError(type(self).__name__ + " requires a symbolic "
  261. "expression.")
  262. def _check_fs(self):
  263. """ Checks if there are enough parameters and free symbols.
  264. """
  265. exprs, ranges = self.expr, self.ranges
  266. params, label = self.params, self.label
  267. exprs = exprs if hasattr(exprs, "__iter__") else [exprs]
  268. if any(callable(e) for e in exprs):
  269. return
  270. # from the expression's free symbols, remove the ones used in
  271. # the parameters and the ranges
  272. fs = _get_free_symbols(exprs)
  273. fs = fs.difference(params.keys())
  274. if ranges is not None:
  275. fs = fs.difference([r[0] for r in ranges])
  276. if len(fs) > 0:
  277. raise ValueError(
  278. "Incompatible expression and parameters.\n"
  279. + "Expression: {}\n".format(
  280. (exprs, ranges, label) if ranges is not None else (exprs, label))
  281. + "params: {}\n".format(params)
  282. + "Specify what these symbols represent: {}\n".format(fs)
  283. + "Are they ranges or parameters?"
  284. )
  285. # verify that all symbols are known (they either represent plotting
  286. # ranges or parameters)
  287. range_symbols = [r[0] for r in ranges]
  288. for r in ranges:
  289. fs = set().union(*[e.free_symbols for e in r[1:]])
  290. if any(t in fs for t in range_symbols):
  291. # ranges can't depend on each other, for example this are
  292. # not allowed:
  293. # (x, 0, y), (y, 0, 3)
  294. # (x, 0, y), (y, x + 2, 3)
  295. raise ValueError("Range symbols can't be included into "
  296. "minimum and maximum of a range. "
  297. "Received range: %s" % str(r))
  298. if len(fs) > 0:
  299. self._interactive_ranges = True
  300. remaining_fs = fs.difference(params.keys())
  301. if len(remaining_fs) > 0:
  302. raise ValueError(
  303. "Unknown symbols found in plotting range: %s. " % (r,) +
  304. "Are the following parameters? %s" % remaining_fs)
  305. def _create_lambda_func(self):
  306. """Create the lambda functions to be used by the uniform meshing
  307. strategy.
  308. Notes
  309. =====
  310. The old sympy.plotting used experimental_lambdify. It created one
  311. lambda function each time an evaluation was requested. If that failed,
  312. it went on to create a different lambda function and evaluated it,
  313. and so on.
  314. This new module changes strategy: it creates right away the default
  315. lambda function as well as the backup one. The reason is that the
  316. series could be interactive, hence the numerical function will be
  317. evaluated multiple times. So, let's create the functions just once.
  318. This approach works fine for the majority of cases, in which the
  319. symbolic expression is relatively short, hence the lambdification
  320. is fast. If the expression is very long, this approach takes twice
  321. the time to create the lambda functions. Be aware of that!
  322. """
  323. exprs = self.expr if hasattr(self.expr, "__iter__") else [self.expr]
  324. if not any(callable(e) for e in exprs):
  325. fs = _get_free_symbols(exprs)
  326. self._signature = sorted(fs, key=lambda t: t.name)
  327. # Generate a list of lambda functions, two for each expression:
  328. # 1. the default one.
  329. # 2. the backup one, in case of failures with the default one.
  330. self._functions = []
  331. for e in exprs:
  332. # TODO: set cse=True once this issue is solved:
  333. # https://github.com/sympy/sympy/issues/24246
  334. self._functions.append([
  335. lambdify(self._signature, e, modules=self.modules),
  336. lambdify(self._signature, e, modules="sympy", dummify=True),
  337. ])
  338. else:
  339. self._signature = sorted([r[0] for r in self.ranges], key=lambda t: t.name)
  340. self._functions = [(e, None) for e in exprs]
  341. # deal with symbolic color_func
  342. if isinstance(self.color_func, Expr):
  343. self.color_func = lambdify(self._signature, self.color_func)
  344. self._eval_color_func_with_signature = True
  345. def _update_range_value(self, t):
  346. """If the value of a plotting range is a symbolic expression,
  347. substitute the parameters in order to get a numerical value.
  348. """
  349. if not self._interactive_ranges:
  350. return complex(t)
  351. return complex(t.subs(self.params))
  352. def _create_discretized_domain(self):
  353. """Discretize the ranges for uniform meshing strategy.
  354. """
  355. # NOTE: the goal is to create a dictionary stored in
  356. # self._discretized_domain, mapping symbols to a numpy array
  357. # representing the discretization
  358. discr_symbols = []
  359. discretizations = []
  360. # create a 1D discretization
  361. for i, r in enumerate(self.ranges):
  362. discr_symbols.append(r[0])
  363. c_start = self._update_range_value(r[1])
  364. c_end = self._update_range_value(r[2])
  365. start = c_start.real if c_start.imag == c_end.imag == 0 else c_start
  366. end = c_end.real if c_start.imag == c_end.imag == 0 else c_end
  367. needs_integer_discr = self.only_integers or (r[0] in self._needs_to_be_int)
  368. d = BaseSeries._discretize(start, end, self.n[i],
  369. scale=self.scales[i],
  370. only_integers=needs_integer_discr)
  371. if ((not self._force_real_eval) and (not needs_integer_discr) and
  372. (d.dtype != "complex")):
  373. d = d + 1j * c_start.imag
  374. if needs_integer_discr:
  375. d = d.astype(int)
  376. discretizations.append(d)
  377. # create 2D or 3D
  378. self._create_discretized_domain_helper(discr_symbols, discretizations)
  379. def _create_discretized_domain_helper(self, discr_symbols, discretizations):
  380. """Create 2D or 3D discretized grids.
  381. Subclasses should override this method in order to implement a
  382. different behaviour.
  383. """
  384. np = import_module('numpy')
  385. # discretization suitable for 2D line plots, 3D surface plots,
  386. # contours plots, vector plots
  387. # NOTE: why indexing='ij'? Because it produces consistent results with
  388. # np.mgrid. This is important as Mayavi requires this indexing
  389. # to correctly compute 3D streamlines. While VTK is able to compute
  390. # streamlines regardless of the indexing, with indexing='xy' it
  391. # produces "strange" results with "voids" into the
  392. # discretization volume. indexing='ij' solves the problem.
  393. # Also note that matplotlib 2D streamlines requires indexing='xy'.
  394. indexing = "xy"
  395. if self.is_3Dvector or (self.is_3Dsurface and self.is_implicit):
  396. indexing = "ij"
  397. meshes = np.meshgrid(*discretizations, indexing=indexing)
  398. self._discretized_domain = dict(zip(discr_symbols, meshes))
  399. def _evaluate(self, cast_to_real=True):
  400. """Evaluation of the symbolic expression (or expressions) with the
  401. uniform meshing strategy, based on current values of the parameters.
  402. """
  403. np = import_module('numpy')
  404. # create lambda functions
  405. if not self._functions:
  406. self._create_lambda_func()
  407. # create (or update) the discretized domain
  408. if (not self._discretized_domain) or self._interactive_ranges:
  409. self._create_discretized_domain()
  410. # ensure that discretized domains are returned with the proper order
  411. discr = [self._discretized_domain[s[0]] for s in self.ranges]
  412. args = self._aggregate_args()
  413. results = []
  414. for f in self._functions:
  415. r = _uniform_eval(*f, *args)
  416. # the evaluation might produce an int/float. Need this correction.
  417. r = self._correct_shape(np.array(r), discr[0])
  418. # sometime the evaluation is performed over arrays of type object.
  419. # hence, `result` might be of type object, which don't work well
  420. # with numpy real and imag functions.
  421. r = r.astype(complex)
  422. results.append(r)
  423. if cast_to_real:
  424. discr = [np.real(d.astype(complex)) for d in discr]
  425. return [*discr, *results]
  426. def _aggregate_args(self):
  427. """Create a list of arguments to be passed to the lambda function,
  428. sorted according to self._signature.
  429. """
  430. args = []
  431. for s in self._signature:
  432. if s in self._params.keys():
  433. args.append(
  434. int(self._params[s]) if s in self._needs_to_be_int else
  435. self._params[s] if self._force_real_eval
  436. else complex(self._params[s]))
  437. else:
  438. args.append(self._discretized_domain[s])
  439. return args
  440. @property
  441. def expr(self):
  442. """Return the expression (or expressions) of the series."""
  443. return self._expr
  444. @expr.setter
  445. def expr(self, e):
  446. """Set the expression (or expressions) of the series."""
  447. is_iter = hasattr(e, "__iter__")
  448. is_callable = callable(e) if not is_iter else any(callable(t) for t in e)
  449. if is_callable:
  450. self._expr = e
  451. else:
  452. self._expr = sympify(e) if not is_iter else Tuple(*e)
  453. # look for the upper bound of summations and products
  454. s = set()
  455. for e in self._expr.atoms(Sum, Product):
  456. for a in e.args[1:]:
  457. if isinstance(a[-1], Symbol):
  458. s.add(a[-1])
  459. self._needs_to_be_int = list(s)
  460. # list of sympy functions that when lambdified, the corresponding
  461. # numpy functions don't like complex-type arguments
  462. pf = [ceiling, floor, atan2, frac, zeta]
  463. if self._force_real_eval is not True:
  464. check_res = [self._expr.has(f) for f in pf]
  465. self._force_real_eval = any(check_res)
  466. if self._force_real_eval and ((self.modules is None) or
  467. (isinstance(self.modules, str) and "numpy" in self.modules)):
  468. funcs = [f for f, c in zip(pf, check_res) if c]
  469. warnings.warn("NumPy is unable to evaluate with complex "
  470. "numbers some of the functions included in this "
  471. "symbolic expression: %s. " % funcs +
  472. "Hence, the evaluation will use real numbers. "
  473. "If you believe the resulting plot is incorrect, "
  474. "change the evaluation module by setting the "
  475. "`modules` keyword argument.")
  476. if self._functions:
  477. # update lambda functions
  478. self._create_lambda_func()
  479. @property
  480. def is_3D(self):
  481. flags3D = [self.is_3Dline, self.is_3Dsurface, self.is_3Dvector]
  482. return any(flags3D)
  483. @property
  484. def is_line(self):
  485. flagslines = [self.is_2Dline, self.is_3Dline]
  486. return any(flagslines)
  487. def _line_surface_color(self, prop, val):
  488. """This method enables back-compatibility with old sympy.plotting"""
  489. # NOTE: color_func is set inside the init method of the series.
  490. # If line_color/surface_color is not a callable, then color_func will
  491. # be set to None.
  492. setattr(self, prop, val)
  493. if callable(val) or isinstance(val, Expr):
  494. self.color_func = val
  495. setattr(self, prop, None)
  496. elif val is not None:
  497. self.color_func = None
  498. @property
  499. def line_color(self):
  500. return self._line_color
  501. @line_color.setter
  502. def line_color(self, val):
  503. self._line_surface_color("_line_color", val)
  504. @property
  505. def n(self):
  506. """Returns a list [n1, n2, n3] of numbers of discratization points.
  507. """
  508. return self._n
  509. @n.setter
  510. def n(self, v):
  511. """Set the numbers of discretization points. ``v`` must be an int or
  512. a list.
  513. Let ``s`` be a series. Then:
  514. * to set the number of discretization points along the x direction (or
  515. first parameter): ``s.n = 10``
  516. * to set the number of discretization points along the x and y
  517. directions (or first and second parameters): ``s.n = [10, 15]``
  518. * to set the number of discretization points along the x, y and z
  519. directions: ``s.n = [10, 15, 20]``
  520. The following is highly unreccomended, because it prevents
  521. the execution of necessary code in order to keep updated data:
  522. ``s.n[1] = 15``
  523. """
  524. if not hasattr(v, "__iter__"):
  525. self._n[0] = v
  526. else:
  527. self._n[:len(v)] = v
  528. if self._discretized_domain:
  529. # update the discretized domain
  530. self._create_discretized_domain()
  531. @property
  532. def params(self):
  533. """Get or set the current parameters dictionary.
  534. Parameters
  535. ==========
  536. p : dict
  537. * key: symbol associated to the parameter
  538. * val: the numeric value
  539. """
  540. return self._params
  541. @params.setter
  542. def params(self, p):
  543. self._params = p
  544. def _post_init(self):
  545. exprs = self.expr if hasattr(self.expr, "__iter__") else [self.expr]
  546. if any(callable(e) for e in exprs) and self.params:
  547. raise TypeError("`params` was provided, hence an interactive plot "
  548. "is expected. However, interactive plots do not support "
  549. "user-provided numerical functions.")
  550. # if the expressions is a lambda function and no label has been
  551. # provided, then its better to do the following in order to avoid
  552. # surprises on the backend
  553. if any(callable(e) for e in exprs):
  554. if self._label == str(self.expr):
  555. self.label = ""
  556. self._check_fs()
  557. if hasattr(self, "adaptive") and self.adaptive and self.params:
  558. warnings.warn("`params` was provided, hence an interactive plot "
  559. "is expected. However, interactive plots do not support "
  560. "adaptive evaluation. Automatically switched to "
  561. "adaptive=False.")
  562. self.adaptive = False
  563. @property
  564. def scales(self):
  565. return self._scales
  566. @scales.setter
  567. def scales(self, v):
  568. if isinstance(v, str):
  569. self._scales[0] = v
  570. else:
  571. self._scales[:len(v)] = v
  572. @property
  573. def surface_color(self):
  574. return self._surface_color
  575. @surface_color.setter
  576. def surface_color(self, val):
  577. self._line_surface_color("_surface_color", val)
  578. @property
  579. def rendering_kw(self):
  580. return self._rendering_kw
  581. @rendering_kw.setter
  582. def rendering_kw(self, kwargs):
  583. if isinstance(kwargs, dict):
  584. self._rendering_kw = kwargs
  585. else:
  586. self._rendering_kw = {}
  587. if kwargs is not None:
  588. warnings.warn(
  589. "`rendering_kw` must be a dictionary, instead an "
  590. "object of type %s was received. " % type(kwargs) +
  591. "Automatically setting `rendering_kw` to an empty "
  592. "dictionary")
  593. @staticmethod
  594. def _discretize(start, end, N, scale="linear", only_integers=False):
  595. """Discretize a 1D domain.
  596. Returns
  597. =======
  598. domain : np.ndarray with dtype=float or complex
  599. The domain's dtype will be float or complex (depending on the
  600. type of start/end) even if only_integers=True. It is left for
  601. the downstream code to perform further casting, if necessary.
  602. """
  603. np = import_module('numpy')
  604. if only_integers is True:
  605. start, end = int(start), int(end)
  606. N = end - start + 1
  607. if scale == "linear":
  608. return np.linspace(start, end, N)
  609. return np.geomspace(start, end, N)
  610. @staticmethod
  611. def _correct_shape(a, b):
  612. """Convert ``a`` to a np.ndarray of the same shape of ``b``.
  613. Parameters
  614. ==========
  615. a : int, float, complex, np.ndarray
  616. Usually, this is the result of a numerical evaluation of a
  617. symbolic expression. Even if a discretized domain was used to
  618. evaluate the function, the result can be a scalar (int, float,
  619. complex). Think for example to ``expr = Float(2)`` and
  620. ``f = lambdify(x, expr)``. No matter the shape of the numerical
  621. array representing x, the result of the evaluation will be
  622. a single value.
  623. b : np.ndarray
  624. It represents the correct shape that ``a`` should have.
  625. Returns
  626. =======
  627. new_a : np.ndarray
  628. An array with the correct shape.
  629. """
  630. np = import_module('numpy')
  631. if not isinstance(a, np.ndarray):
  632. a = np.array(a)
  633. if a.shape != b.shape:
  634. if a.shape == ():
  635. a = a * np.ones_like(b)
  636. else:
  637. a = a.reshape(b.shape)
  638. return a
  639. def eval_color_func(self, *args):
  640. """Evaluate the color function.
  641. Parameters
  642. ==========
  643. args : tuple
  644. Arguments to be passed to the coloring function. Can be coordinates
  645. or parameters or both.
  646. Notes
  647. =====
  648. The backend will request the data series to generate the numerical
  649. data. Depending on the data series, either the data series itself or
  650. the backend will eventually execute this function to generate the
  651. appropriate coloring value.
  652. """
  653. np = import_module('numpy')
  654. if self.color_func is None:
  655. # NOTE: with the line_color and surface_color attributes
  656. # (back-compatibility with the old sympy.plotting module) it is
  657. # possible to create a plot with a callable line_color (or
  658. # surface_color). For example:
  659. # p = plot(sin(x), line_color=lambda x, y: -y)
  660. # This creates a ColoredLineOver1DRangeSeries with line_color=None
  661. # and color_func=lambda x, y: -y, which effectively is a
  662. # parametric series. Later we could change it to a string value:
  663. # p[0].line_color = "red"
  664. # However, this sets ine_color="red" and color_func=None, but the
  665. # series is still ColoredLineOver1DRangeSeries (a parametric
  666. # series), which will render using a color_func...
  667. warnings.warn("This is likely not the result you were "
  668. "looking for. Please, re-execute the plot command, this time "
  669. "with the appropriate an appropriate value to line_color "
  670. "or surface_color.")
  671. return np.ones_like(args[0])
  672. if self._eval_color_func_with_signature:
  673. args = self._aggregate_args()
  674. color = self.color_func(*args)
  675. _re, _im = np.real(color), np.imag(color)
  676. _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan
  677. return _re
  678. nargs = arity(self.color_func)
  679. if nargs == 1:
  680. if self.is_2Dline and self.is_parametric:
  681. if len(args) == 2:
  682. # ColoredLineOver1DRangeSeries
  683. return self._correct_shape(self.color_func(args[0]), args[0])
  684. # Parametric2DLineSeries
  685. return self._correct_shape(self.color_func(args[2]), args[2])
  686. elif self.is_3Dline and self.is_parametric:
  687. return self._correct_shape(self.color_func(args[3]), args[3])
  688. elif self.is_3Dsurface and self.is_parametric:
  689. return self._correct_shape(self.color_func(args[3]), args[3])
  690. return self._correct_shape(self.color_func(args[0]), args[0])
  691. elif nargs == 2:
  692. if self.is_3Dsurface and self.is_parametric:
  693. return self._correct_shape(self.color_func(*args[3:]), args[3])
  694. return self._correct_shape(self.color_func(*args[:2]), args[0])
  695. return self._correct_shape(self.color_func(*args[:nargs]), args[0])
  696. def get_data(self):
  697. """Compute and returns the numerical data.
  698. The number of parameters returned by this method depends on the
  699. specific instance. If ``s`` is the series, make sure to read
  700. ``help(s.get_data)`` to understand what it returns.
  701. """
  702. raise NotImplementedError
  703. def _get_wrapped_label(self, label, wrapper):
  704. """Given a latex representation of an expression, wrap it inside
  705. some characters. Matplotlib needs "$%s%$", K3D-Jupyter needs "%s".
  706. """
  707. return wrapper % label
  708. def get_label(self, use_latex=False, wrapper="$%s$"):
  709. """Return the label to be used to display the expression.
  710. Parameters
  711. ==========
  712. use_latex : bool
  713. If False, the string representation of the expression is returned.
  714. If True, the latex representation is returned.
  715. wrapper : str
  716. The backend might need the latex representation to be wrapped by
  717. some characters. Default to ``"$%s$"``.
  718. Returns
  719. =======
  720. label : str
  721. """
  722. if use_latex is False:
  723. return self._label
  724. if self._label == str(self.expr):
  725. # when the backend requests a latex label and user didn't provide
  726. # any label
  727. return self._get_wrapped_label(self._latex_label, wrapper)
  728. return self._latex_label
  729. @property
  730. def label(self):
  731. return self.get_label()
  732. @label.setter
  733. def label(self, val):
  734. """Set the labels associated to this series."""
  735. # NOTE: the init method of any series requires a label. If the user do
  736. # not provide it, the preprocessing function will set label=None, which
  737. # informs the series to initialize two attributes:
  738. # _label contains the string representation of the expression.
  739. # _latex_label contains the latex representation of the expression.
  740. self._label = self._latex_label = val
  741. @property
  742. def ranges(self):
  743. return self._ranges
  744. @ranges.setter
  745. def ranges(self, val):
  746. new_vals = []
  747. for v in val:
  748. if v is not None:
  749. new_vals.append(tuple([sympify(t) for t in v]))
  750. self._ranges = new_vals
  751. def _apply_transform(self, *args):
  752. """Apply transformations to the results of numerical evaluation.
  753. Parameters
  754. ==========
  755. args : tuple
  756. Results of numerical evaluation.
  757. Returns
  758. =======
  759. transformed_args : tuple
  760. Tuple containing the transformed results.
  761. """
  762. t = lambda x, transform: x if transform is None else transform(x)
  763. x, y, z = None, None, None
  764. if len(args) == 2:
  765. x, y = args
  766. return t(x, self._tx), t(y, self._ty)
  767. elif (len(args) == 3) and isinstance(self, Parametric2DLineSeries):
  768. x, y, u = args
  769. return (t(x, self._tx), t(y, self._ty), t(u, self._tp))
  770. elif len(args) == 3:
  771. x, y, z = args
  772. return t(x, self._tx), t(y, self._ty), t(z, self._tz)
  773. elif (len(args) == 4) and isinstance(self, Parametric3DLineSeries):
  774. x, y, z, u = args
  775. return (t(x, self._tx), t(y, self._ty), t(z, self._tz), t(u, self._tp))
  776. elif len(args) == 4: # 2D vector plot
  777. x, y, u, v = args
  778. return (
  779. t(x, self._tx), t(y, self._ty),
  780. t(u, self._tx), t(v, self._ty)
  781. )
  782. elif (len(args) == 5) and isinstance(self, ParametricSurfaceSeries):
  783. x, y, z, u, v = args
  784. return (t(x, self._tx), t(y, self._ty), t(z, self._tz), u, v)
  785. elif (len(args) == 6) and self.is_3Dvector: # 3D vector plot
  786. x, y, z, u, v, w = args
  787. return (
  788. t(x, self._tx), t(y, self._ty), t(z, self._tz),
  789. t(u, self._tx), t(v, self._ty), t(w, self._tz)
  790. )
  791. elif len(args) == 6: # complex plot
  792. x, y, _abs, _arg, img, colors = args
  793. return (
  794. x, y, t(_abs, self._tz), _arg, img, colors)
  795. return args
  796. def _str_helper(self, s):
  797. pre, post = "", ""
  798. if self.is_interactive:
  799. pre = "interactive "
  800. post = " and parameters " + str(tuple(self.params.keys()))
  801. return pre + s + post
  802. def _detect_poles_numerical_helper(x, y, eps=0.01, expr=None, symb=None, symbolic=False):
  803. """Compute the steepness of each segment. If it's greater than a
  804. threshold, set the right-point y-value non NaN and record the
  805. corresponding x-location for further processing.
  806. Returns
  807. =======
  808. x : np.ndarray
  809. Unchanged x-data.
  810. yy : np.ndarray
  811. Modified y-data with NaN values.
  812. """
  813. np = import_module('numpy')
  814. yy = y.copy()
  815. threshold = np.pi / 2 - eps
  816. for i in range(len(x) - 1):
  817. dx = x[i + 1] - x[i]
  818. dy = abs(y[i + 1] - y[i])
  819. angle = np.arctan(dy / dx)
  820. if abs(angle) >= threshold:
  821. yy[i + 1] = np.nan
  822. return x, yy
  823. def _detect_poles_symbolic_helper(expr, symb, start, end):
  824. """Attempts to compute symbolic discontinuities.
  825. Returns
  826. =======
  827. pole : list
  828. List of symbolic poles, possibly empty.
  829. """
  830. poles = []
  831. interval = Interval(nsimplify(start), nsimplify(end))
  832. res = continuous_domain(expr, symb, interval)
  833. res = res.simplify()
  834. if res == interval:
  835. pass
  836. elif (isinstance(res, Union) and
  837. all(isinstance(t, Interval) for t in res.args)):
  838. poles = []
  839. for s in res.args:
  840. if s.left_open:
  841. poles.append(s.left)
  842. if s.right_open:
  843. poles.append(s.right)
  844. poles = list(set(poles))
  845. else:
  846. raise ValueError(
  847. f"Could not parse the following object: {res} .\n"
  848. "Please, submit this as a bug. Consider also to set "
  849. "`detect_poles=True`."
  850. )
  851. return poles
  852. ### 2D lines
  853. class Line2DBaseSeries(BaseSeries):
  854. """A base class for 2D lines.
  855. - adding the label, steps and only_integers options
  856. - making is_2Dline true
  857. - defining get_segments and get_color_array
  858. """
  859. is_2Dline = True
  860. _dim = 2
  861. _N = 1000
  862. def __init__(self, **kwargs):
  863. super().__init__(**kwargs)
  864. self.steps = kwargs.get("steps", False)
  865. self.is_point = kwargs.get("is_point", kwargs.get("point", False))
  866. self.is_filled = kwargs.get("is_filled", kwargs.get("fill", True))
  867. self.adaptive = kwargs.get("adaptive", False)
  868. self.depth = kwargs.get('depth', 12)
  869. self.use_cm = kwargs.get("use_cm", False)
  870. self.color_func = kwargs.get("color_func", None)
  871. self.line_color = kwargs.get("line_color", None)
  872. self.detect_poles = kwargs.get("detect_poles", False)
  873. self.eps = kwargs.get("eps", 0.01)
  874. self.is_polar = kwargs.get("is_polar", kwargs.get("polar", False))
  875. self.unwrap = kwargs.get("unwrap", False)
  876. # when detect_poles="symbolic", stores the location of poles so that
  877. # they can be appropriately rendered
  878. self.poles_locations = []
  879. exclude = kwargs.get("exclude", [])
  880. if isinstance(exclude, Set):
  881. exclude = list(extract_solution(exclude, n=100))
  882. if not hasattr(exclude, "__iter__"):
  883. exclude = [exclude]
  884. exclude = [float(e) for e in exclude]
  885. self.exclude = sorted(exclude)
  886. def get_data(self):
  887. """Return coordinates for plotting the line.
  888. Returns
  889. =======
  890. x: np.ndarray
  891. x-coordinates
  892. y: np.ndarray
  893. y-coordinates
  894. z: np.ndarray (optional)
  895. z-coordinates in case of Parametric3DLineSeries,
  896. Parametric3DLineInteractiveSeries
  897. param : np.ndarray (optional)
  898. The parameter in case of Parametric2DLineSeries,
  899. Parametric3DLineSeries or AbsArgLineSeries (and their
  900. corresponding interactive series).
  901. """
  902. np = import_module('numpy')
  903. points = self._get_data_helper()
  904. if (isinstance(self, LineOver1DRangeSeries) and
  905. (self.detect_poles == "symbolic")):
  906. poles = _detect_poles_symbolic_helper(
  907. self.expr.subs(self.params), *self.ranges[0])
  908. poles = np.array([float(t) for t in poles])
  909. t = lambda x, transform: x if transform is None else transform(x)
  910. self.poles_locations = t(np.array(poles), self._tx)
  911. # postprocessing
  912. points = self._apply_transform(*points)
  913. if self.is_2Dline and self.detect_poles:
  914. if len(points) == 2:
  915. x, y = points
  916. x, y = _detect_poles_numerical_helper(
  917. x, y, self.eps)
  918. points = (x, y)
  919. else:
  920. x, y, p = points
  921. x, y = _detect_poles_numerical_helper(x, y, self.eps)
  922. points = (x, y, p)
  923. if self.unwrap:
  924. kw = {}
  925. if self.unwrap is not True:
  926. kw = self.unwrap
  927. if self.is_2Dline:
  928. if len(points) == 2:
  929. x, y = points
  930. y = np.unwrap(y, **kw)
  931. points = (x, y)
  932. else:
  933. x, y, p = points
  934. y = np.unwrap(y, **kw)
  935. points = (x, y, p)
  936. if self.steps is True:
  937. if self.is_2Dline:
  938. x, y = points[0], points[1]
  939. x = np.array((x, x)).T.flatten()[1:]
  940. y = np.array((y, y)).T.flatten()[:-1]
  941. if self.is_parametric:
  942. points = (x, y, points[2])
  943. else:
  944. points = (x, y)
  945. elif self.is_3Dline:
  946. x = np.repeat(points[0], 3)[2:]
  947. y = np.repeat(points[1], 3)[:-2]
  948. z = np.repeat(points[2], 3)[1:-1]
  949. if len(points) > 3:
  950. points = (x, y, z, points[3])
  951. else:
  952. points = (x, y, z)
  953. if len(self.exclude) > 0:
  954. points = self._insert_exclusions(points)
  955. return points
  956. def get_segments(self):
  957. sympy_deprecation_warning(
  958. """
  959. The Line2DBaseSeries.get_segments() method is deprecated.
  960. Instead, use the MatplotlibBackend.get_segments() method, or use
  961. The get_points() or get_data() methods.
  962. """,
  963. deprecated_since_version="1.9",
  964. active_deprecations_target="deprecated-get-segments")
  965. np = import_module('numpy')
  966. points = type(self).get_data(self)
  967. points = np.ma.array(points).T.reshape(-1, 1, self._dim)
  968. return np.ma.concatenate([points[:-1], points[1:]], axis=1)
  969. def _insert_exclusions(self, points):
  970. """Add NaN to each of the exclusion point. Practically, this adds a
  971. NaN to the exclusion point, plus two other nearby points evaluated with
  972. the numerical functions associated to this data series.
  973. These nearby points are important when the number of discretization
  974. points is low, or the scale is logarithm.
  975. NOTE: it would be easier to just add exclusion points to the
  976. discretized domain before evaluation, then after evaluation add NaN
  977. to the exclusion points. But that's only work with adaptive=False.
  978. The following approach work even with adaptive=True.
  979. """
  980. np = import_module("numpy")
  981. points = list(points)
  982. n = len(points)
  983. # index of the x-coordinate (for 2d plots) or parameter (for 2d/3d
  984. # parametric plots)
  985. k = n - 1
  986. if n == 2:
  987. k = 0
  988. # indices of the other coordinates
  989. j_indeces = sorted(set(range(n)).difference([k]))
  990. # TODO: for now, I assume that numpy functions are going to succeed
  991. funcs = [f[0] for f in self._functions]
  992. for e in self.exclude:
  993. res = points[k] - e >= 0
  994. # if res contains both True and False, ie, if e is found
  995. if any(res) and any(~res):
  996. idx = np.nanargmax(res)
  997. # select the previous point with respect to e
  998. idx -= 1
  999. # TODO: what if points[k][idx]==e or points[k][idx+1]==e?
  1000. if idx > 0 and idx < len(points[k]) - 1:
  1001. delta_prev = abs(e - points[k][idx])
  1002. delta_post = abs(e - points[k][idx + 1])
  1003. delta = min(delta_prev, delta_post) / 100
  1004. prev = e - delta
  1005. post = e + delta
  1006. # add points to the x-coord or the parameter
  1007. points[k] = np.concatenate(
  1008. (points[k][:idx], [prev, e, post], points[k][idx+1:]))
  1009. # add points to the other coordinates
  1010. c = 0
  1011. for j in j_indeces:
  1012. values = funcs[c](np.array([prev, post]))
  1013. c += 1
  1014. points[j] = np.concatenate(
  1015. (points[j][:idx], [values[0], np.nan, values[1]], points[j][idx+1:]))
  1016. return points
  1017. @property
  1018. def var(self):
  1019. return None if not self.ranges else self.ranges[0][0]
  1020. @property
  1021. def start(self):
  1022. if not self.ranges:
  1023. return None
  1024. try:
  1025. return self._cast(self.ranges[0][1])
  1026. except TypeError:
  1027. return self.ranges[0][1]
  1028. @property
  1029. def end(self):
  1030. if not self.ranges:
  1031. return None
  1032. try:
  1033. return self._cast(self.ranges[0][2])
  1034. except TypeError:
  1035. return self.ranges[0][2]
  1036. @property
  1037. def xscale(self):
  1038. return self._scales[0]
  1039. @xscale.setter
  1040. def xscale(self, v):
  1041. self.scales = v
  1042. def get_color_array(self):
  1043. np = import_module('numpy')
  1044. c = self.line_color
  1045. if hasattr(c, '__call__'):
  1046. f = np.vectorize(c)
  1047. nargs = arity(c)
  1048. if nargs == 1 and self.is_parametric:
  1049. x = self.get_parameter_points()
  1050. return f(centers_of_segments(x))
  1051. else:
  1052. variables = list(map(centers_of_segments, self.get_points()))
  1053. if nargs == 1:
  1054. return f(variables[0])
  1055. elif nargs == 2:
  1056. return f(*variables[:2])
  1057. else: # only if the line is 3D (otherwise raises an error)
  1058. return f(*variables)
  1059. else:
  1060. return c*np.ones(self.nb_of_points)
  1061. class List2DSeries(Line2DBaseSeries):
  1062. """Representation for a line consisting of list of points."""
  1063. def __init__(self, list_x, list_y, label="", **kwargs):
  1064. super().__init__(**kwargs)
  1065. np = import_module('numpy')
  1066. if len(list_x) != len(list_y):
  1067. raise ValueError(
  1068. "The two lists of coordinates must have the same "
  1069. "number of elements.\n"
  1070. "Received: len(list_x) = {} ".format(len(list_x)) +
  1071. "and len(list_y) = {}".format(len(list_y))
  1072. )
  1073. self._block_lambda_functions(list_x, list_y)
  1074. check = lambda l: [isinstance(t, Expr) and (not t.is_number) for t in l]
  1075. if any(check(list_x) + check(list_y)) or self.params:
  1076. if not self.params:
  1077. raise ValueError("Some or all elements of the provided lists "
  1078. "are symbolic expressions, but the ``params`` dictionary "
  1079. "was not provided: those elements can't be evaluated.")
  1080. self.list_x = Tuple(*list_x)
  1081. self.list_y = Tuple(*list_y)
  1082. else:
  1083. self.list_x = np.array(list_x, dtype=np.float64)
  1084. self.list_y = np.array(list_y, dtype=np.float64)
  1085. self._expr = (self.list_x, self.list_y)
  1086. if not any(isinstance(t, np.ndarray) for t in [self.list_x, self.list_y]):
  1087. self._check_fs()
  1088. self.is_polar = kwargs.get("is_polar", kwargs.get("polar", False))
  1089. self.label = label
  1090. self.rendering_kw = kwargs.get("rendering_kw", {})
  1091. if self.use_cm and self.color_func:
  1092. self.is_parametric = True
  1093. if isinstance(self.color_func, Expr):
  1094. raise TypeError(
  1095. "%s don't support symbolic " % self.__class__.__name__ +
  1096. "expression for `color_func`.")
  1097. def __str__(self):
  1098. return "2D list plot"
  1099. def _get_data_helper(self):
  1100. """Returns coordinates that needs to be postprocessed."""
  1101. lx, ly = self.list_x, self.list_y
  1102. if not self.is_interactive:
  1103. return self._eval_color_func_and_return(lx, ly)
  1104. np = import_module('numpy')
  1105. lx = np.array([t.evalf(subs=self.params) for t in lx], dtype=float)
  1106. ly = np.array([t.evalf(subs=self.params) for t in ly], dtype=float)
  1107. return self._eval_color_func_and_return(lx, ly)
  1108. def _eval_color_func_and_return(self, *data):
  1109. if self.use_cm and callable(self.color_func):
  1110. return [*data, self.eval_color_func(*data)]
  1111. return data
  1112. class LineOver1DRangeSeries(Line2DBaseSeries):
  1113. """Representation for a line consisting of a SymPy expression over a range."""
  1114. def __init__(self, expr, var_start_end, label="", **kwargs):
  1115. super().__init__(**kwargs)
  1116. self.expr = expr if callable(expr) else sympify(expr)
  1117. self._label = str(self.expr) if label is None else label
  1118. self._latex_label = latex(self.expr) if label is None else label
  1119. self.ranges = [var_start_end]
  1120. self._cast = complex
  1121. # for complex-related data series, this determines what data to return
  1122. # on the y-axis
  1123. self._return = kwargs.get("return", None)
  1124. self._post_init()
  1125. if not self._interactive_ranges:
  1126. # NOTE: the following check is only possible when the minimum and
  1127. # maximum values of a plotting range are numeric
  1128. start, end = [complex(t) for t in self.ranges[0][1:]]
  1129. if im(start) != im(end):
  1130. raise ValueError(
  1131. "%s requires the imaginary " % self.__class__.__name__ +
  1132. "part of the start and end values of the range "
  1133. "to be the same.")
  1134. if self.adaptive and self._return:
  1135. warnings.warn("The adaptive algorithm is unable to deal with "
  1136. "complex numbers. Automatically switching to uniform meshing.")
  1137. self.adaptive = False
  1138. @property
  1139. def nb_of_points(self):
  1140. return self.n[0]
  1141. @nb_of_points.setter
  1142. def nb_of_points(self, v):
  1143. self.n = v
  1144. def __str__(self):
  1145. def f(t):
  1146. if isinstance(t, complex):
  1147. if t.imag != 0:
  1148. return t
  1149. return t.real
  1150. return t
  1151. pre = "interactive " if self.is_interactive else ""
  1152. post = ""
  1153. if self.is_interactive:
  1154. post = " and parameters " + str(tuple(self.params.keys()))
  1155. wrapper = _get_wrapper_for_expr(self._return)
  1156. return pre + "cartesian line: %s for %s over %s" % (
  1157. wrapper % self.expr,
  1158. str(self.var),
  1159. str((f(self.start), f(self.end))),
  1160. ) + post
  1161. def get_points(self):
  1162. """Return lists of coordinates for plotting. Depending on the
  1163. ``adaptive`` option, this function will either use an adaptive algorithm
  1164. or it will uniformly sample the expression over the provided range.
  1165. This function is available for back-compatibility purposes. Consider
  1166. using ``get_data()`` instead.
  1167. Returns
  1168. =======
  1169. x : list
  1170. List of x-coordinates
  1171. y : list
  1172. List of y-coordinates
  1173. """
  1174. return self._get_data_helper()
  1175. def _adaptive_sampling(self):
  1176. try:
  1177. if callable(self.expr):
  1178. f = self.expr
  1179. else:
  1180. f = lambdify([self.var], self.expr, self.modules)
  1181. x, y = self._adaptive_sampling_helper(f)
  1182. except Exception as err:
  1183. warnings.warn(
  1184. "The evaluation with %s failed.\n" % (
  1185. "NumPy/SciPy" if not self.modules else self.modules) +
  1186. "{}: {}\n".format(type(err).__name__, err) +
  1187. "Trying to evaluate the expression with Sympy, but it might "
  1188. "be a slow operation."
  1189. )
  1190. f = lambdify([self.var], self.expr, "sympy")
  1191. x, y = self._adaptive_sampling_helper(f)
  1192. return x, y
  1193. def _adaptive_sampling_helper(self, f):
  1194. """The adaptive sampling is done by recursively checking if three
  1195. points are almost collinear. If they are not collinear, then more
  1196. points are added between those points.
  1197. References
  1198. ==========
  1199. .. [1] Adaptive polygonal approximation of parametric curves,
  1200. Luiz Henrique de Figueiredo.
  1201. """
  1202. np = import_module('numpy')
  1203. x_coords = []
  1204. y_coords = []
  1205. def sample(p, q, depth):
  1206. """ Samples recursively if three points are almost collinear.
  1207. For depth < 6, points are added irrespective of whether they
  1208. satisfy the collinearity condition or not. The maximum depth
  1209. allowed is 12.
  1210. """
  1211. # Randomly sample to avoid aliasing.
  1212. random = 0.45 + np.random.rand() * 0.1
  1213. if self.xscale == 'log':
  1214. xnew = 10**(np.log10(p[0]) + random * (np.log10(q[0]) -
  1215. np.log10(p[0])))
  1216. else:
  1217. xnew = p[0] + random * (q[0] - p[0])
  1218. ynew = _adaptive_eval(f, xnew)
  1219. new_point = np.array([xnew, ynew])
  1220. # Maximum depth
  1221. if depth > self.depth:
  1222. x_coords.append(q[0])
  1223. y_coords.append(q[1])
  1224. # Sample to depth of 6 (whether the line is flat or not)
  1225. # without using linspace (to avoid aliasing).
  1226. elif depth < 6:
  1227. sample(p, new_point, depth + 1)
  1228. sample(new_point, q, depth + 1)
  1229. # Sample ten points if complex values are encountered
  1230. # at both ends. If there is a real value in between, then
  1231. # sample those points further.
  1232. elif p[1] is None and q[1] is None:
  1233. if self.xscale == 'log':
  1234. xarray = np.logspace(p[0], q[0], 10)
  1235. else:
  1236. xarray = np.linspace(p[0], q[0], 10)
  1237. yarray = list(map(f, xarray))
  1238. if not all(y is None for y in yarray):
  1239. for i in range(len(yarray) - 1):
  1240. if not (yarray[i] is None and yarray[i + 1] is None):
  1241. sample([xarray[i], yarray[i]],
  1242. [xarray[i + 1], yarray[i + 1]], depth + 1)
  1243. # Sample further if one of the end points in None (i.e. a
  1244. # complex value) or the three points are not almost collinear.
  1245. elif (p[1] is None or q[1] is None or new_point[1] is None
  1246. or not flat(p, new_point, q)):
  1247. sample(p, new_point, depth + 1)
  1248. sample(new_point, q, depth + 1)
  1249. else:
  1250. x_coords.append(q[0])
  1251. y_coords.append(q[1])
  1252. f_start = _adaptive_eval(f, self.start.real)
  1253. f_end = _adaptive_eval(f, self.end.real)
  1254. x_coords.append(self.start.real)
  1255. y_coords.append(f_start)
  1256. sample(np.array([self.start.real, f_start]),
  1257. np.array([self.end.real, f_end]), 0)
  1258. return (x_coords, y_coords)
  1259. def _uniform_sampling(self):
  1260. np = import_module('numpy')
  1261. x, result = self._evaluate()
  1262. _re, _im = np.real(result), np.imag(result)
  1263. _re = self._correct_shape(_re, x)
  1264. _im = self._correct_shape(_im, x)
  1265. return x, _re, _im
  1266. def _get_data_helper(self):
  1267. """Returns coordinates that needs to be postprocessed.
  1268. """
  1269. np = import_module('numpy')
  1270. if self.adaptive and (not self.only_integers):
  1271. x, y = self._adaptive_sampling()
  1272. return [np.array(t) for t in [x, y]]
  1273. x, _re, _im = self._uniform_sampling()
  1274. if self._return is None:
  1275. # The evaluation could produce complex numbers. Set real elements
  1276. # to NaN where there are non-zero imaginary elements
  1277. _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan
  1278. elif self._return == "real":
  1279. pass
  1280. elif self._return == "imag":
  1281. _re = _im
  1282. elif self._return == "abs":
  1283. _re = np.sqrt(_re**2 + _im**2)
  1284. elif self._return == "arg":
  1285. _re = np.arctan2(_im, _re)
  1286. else:
  1287. raise ValueError("`_return` not recognized. "
  1288. "Received: %s" % self._return)
  1289. return x, _re
  1290. class ParametricLineBaseSeries(Line2DBaseSeries):
  1291. is_parametric = True
  1292. def _set_parametric_line_label(self, label):
  1293. """Logic to set the correct label to be shown on the plot.
  1294. If `use_cm=True` there will be a colorbar, so we show the parameter.
  1295. If `use_cm=False`, there might be a legend, so we show the expressions.
  1296. Parameters
  1297. ==========
  1298. label : str
  1299. label passed in by the pre-processor or the user
  1300. """
  1301. self._label = str(self.var) if label is None else label
  1302. self._latex_label = latex(self.var) if label is None else label
  1303. if (self.use_cm is False) and (self._label == str(self.var)):
  1304. self._label = str(self.expr)
  1305. self._latex_label = latex(self.expr)
  1306. # if the expressions is a lambda function and use_cm=False and no label
  1307. # has been provided, then its better to do the following in order to
  1308. # avoid surprises on the backend
  1309. if any(callable(e) for e in self.expr) and (not self.use_cm):
  1310. if self._label == str(self.expr):
  1311. self._label = ""
  1312. def get_label(self, use_latex=False, wrapper="$%s$"):
  1313. # parametric lines returns the representation of the parameter to be
  1314. # shown on the colorbar if `use_cm=True`, otherwise it returns the
  1315. # representation of the expression to be placed on the legend.
  1316. if self.use_cm:
  1317. if str(self.var) == self._label:
  1318. if use_latex:
  1319. return self._get_wrapped_label(latex(self.var), wrapper)
  1320. return str(self.var)
  1321. # here the user has provided a custom label
  1322. return self._label
  1323. if use_latex:
  1324. if self._label != str(self.expr):
  1325. return self._latex_label
  1326. return self._get_wrapped_label(self._latex_label, wrapper)
  1327. return self._label
  1328. def _get_data_helper(self):
  1329. """Returns coordinates that needs to be postprocessed.
  1330. Depending on the `adaptive` option, this function will either use an
  1331. adaptive algorithm or it will uniformly sample the expression over the
  1332. provided range.
  1333. """
  1334. if self.adaptive:
  1335. np = import_module("numpy")
  1336. coords = self._adaptive_sampling()
  1337. coords = [np.array(t) for t in coords]
  1338. else:
  1339. coords = self._uniform_sampling()
  1340. if self.is_2Dline and self.is_polar:
  1341. # when plot_polar is executed with polar_axis=True
  1342. np = import_module('numpy')
  1343. x, y, _ = coords
  1344. r = np.sqrt(x**2 + y**2)
  1345. t = np.arctan2(y, x)
  1346. coords = [t, r, coords[-1]]
  1347. if callable(self.color_func):
  1348. coords = list(coords)
  1349. coords[-1] = self.eval_color_func(*coords)
  1350. return coords
  1351. def _uniform_sampling(self):
  1352. """Returns coordinates that needs to be postprocessed."""
  1353. np = import_module('numpy')
  1354. results = self._evaluate()
  1355. for i, r in enumerate(results):
  1356. _re, _im = np.real(r), np.imag(r)
  1357. _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan
  1358. results[i] = _re
  1359. return [*results[1:], results[0]]
  1360. def get_parameter_points(self):
  1361. return self.get_data()[-1]
  1362. def get_points(self):
  1363. """ Return lists of coordinates for plotting. Depending on the
  1364. ``adaptive`` option, this function will either use an adaptive algorithm
  1365. or it will uniformly sample the expression over the provided range.
  1366. This function is available for back-compatibility purposes. Consider
  1367. using ``get_data()`` instead.
  1368. Returns
  1369. =======
  1370. x : list
  1371. List of x-coordinates
  1372. y : list
  1373. List of y-coordinates
  1374. z : list
  1375. List of z-coordinates, only for 3D parametric line plot.
  1376. """
  1377. return self._get_data_helper()[:-1]
  1378. @property
  1379. def nb_of_points(self):
  1380. return self.n[0]
  1381. @nb_of_points.setter
  1382. def nb_of_points(self, v):
  1383. self.n = v
  1384. class Parametric2DLineSeries(ParametricLineBaseSeries):
  1385. """Representation for a line consisting of two parametric SymPy expressions
  1386. over a range."""
  1387. is_2Dline = True
  1388. def __init__(self, expr_x, expr_y, var_start_end, label="", **kwargs):
  1389. super().__init__(**kwargs)
  1390. self.expr_x = expr_x if callable(expr_x) else sympify(expr_x)
  1391. self.expr_y = expr_y if callable(expr_y) else sympify(expr_y)
  1392. self.expr = (self.expr_x, self.expr_y)
  1393. self.ranges = [var_start_end]
  1394. self._cast = float
  1395. self.use_cm = kwargs.get("use_cm", True)
  1396. self._set_parametric_line_label(label)
  1397. self._post_init()
  1398. def __str__(self):
  1399. return self._str_helper(
  1400. "parametric cartesian line: (%s, %s) for %s over %s" % (
  1401. str(self.expr_x),
  1402. str(self.expr_y),
  1403. str(self.var),
  1404. str((self.start, self.end))
  1405. ))
  1406. def _adaptive_sampling(self):
  1407. try:
  1408. if callable(self.expr_x) and callable(self.expr_y):
  1409. f_x = self.expr_x
  1410. f_y = self.expr_y
  1411. else:
  1412. f_x = lambdify([self.var], self.expr_x)
  1413. f_y = lambdify([self.var], self.expr_y)
  1414. x, y, p = self._adaptive_sampling_helper(f_x, f_y)
  1415. except Exception as err:
  1416. warnings.warn(
  1417. "The evaluation with %s failed.\n" % (
  1418. "NumPy/SciPy" if not self.modules else self.modules) +
  1419. "{}: {}\n".format(type(err).__name__, err) +
  1420. "Trying to evaluate the expression with Sympy, but it might "
  1421. "be a slow operation."
  1422. )
  1423. f_x = lambdify([self.var], self.expr_x, "sympy")
  1424. f_y = lambdify([self.var], self.expr_y, "sympy")
  1425. x, y, p = self._adaptive_sampling_helper(f_x, f_y)
  1426. return x, y, p
  1427. def _adaptive_sampling_helper(self, f_x, f_y):
  1428. """The adaptive sampling is done by recursively checking if three
  1429. points are almost collinear. If they are not collinear, then more
  1430. points are added between those points.
  1431. References
  1432. ==========
  1433. .. [1] Adaptive polygonal approximation of parametric curves,
  1434. Luiz Henrique de Figueiredo.
  1435. """
  1436. x_coords = []
  1437. y_coords = []
  1438. param = []
  1439. def sample(param_p, param_q, p, q, depth):
  1440. """ Samples recursively if three points are almost collinear.
  1441. For depth < 6, points are added irrespective of whether they
  1442. satisfy the collinearity condition or not. The maximum depth
  1443. allowed is 12.
  1444. """
  1445. # Randomly sample to avoid aliasing.
  1446. np = import_module('numpy')
  1447. random = 0.45 + np.random.rand() * 0.1
  1448. param_new = param_p + random * (param_q - param_p)
  1449. xnew = _adaptive_eval(f_x, param_new)
  1450. ynew = _adaptive_eval(f_y, param_new)
  1451. new_point = np.array([xnew, ynew])
  1452. # Maximum depth
  1453. if depth > self.depth:
  1454. x_coords.append(q[0])
  1455. y_coords.append(q[1])
  1456. param.append(param_p)
  1457. # Sample irrespective of whether the line is flat till the
  1458. # depth of 6. We are not using linspace to avoid aliasing.
  1459. elif depth < 6:
  1460. sample(param_p, param_new, p, new_point, depth + 1)
  1461. sample(param_new, param_q, new_point, q, depth + 1)
  1462. # Sample ten points if complex values are encountered
  1463. # at both ends. If there is a real value in between, then
  1464. # sample those points further.
  1465. elif ((p[0] is None and q[1] is None) or
  1466. (p[1] is None and q[1] is None)):
  1467. param_array = np.linspace(param_p, param_q, 10)
  1468. x_array = [_adaptive_eval(f_x, t) for t in param_array]
  1469. y_array = [_adaptive_eval(f_y, t) for t in param_array]
  1470. if not all(x is None and y is None
  1471. for x, y in zip(x_array, y_array)):
  1472. for i in range(len(y_array) - 1):
  1473. if ((x_array[i] is not None and y_array[i] is not None) or
  1474. (x_array[i + 1] is not None and y_array[i + 1] is not None)):
  1475. point_a = [x_array[i], y_array[i]]
  1476. point_b = [x_array[i + 1], y_array[i + 1]]
  1477. sample(param_array[i], param_array[i], point_a,
  1478. point_b, depth + 1)
  1479. # Sample further if one of the end points in None (i.e. a complex
  1480. # value) or the three points are not almost collinear.
  1481. elif (p[0] is None or p[1] is None
  1482. or q[1] is None or q[0] is None
  1483. or not flat(p, new_point, q)):
  1484. sample(param_p, param_new, p, new_point, depth + 1)
  1485. sample(param_new, param_q, new_point, q, depth + 1)
  1486. else:
  1487. x_coords.append(q[0])
  1488. y_coords.append(q[1])
  1489. param.append(param_p)
  1490. f_start_x = _adaptive_eval(f_x, self.start)
  1491. f_start_y = _adaptive_eval(f_y, self.start)
  1492. start = [f_start_x, f_start_y]
  1493. f_end_x = _adaptive_eval(f_x, self.end)
  1494. f_end_y = _adaptive_eval(f_y, self.end)
  1495. end = [f_end_x, f_end_y]
  1496. x_coords.append(f_start_x)
  1497. y_coords.append(f_start_y)
  1498. param.append(self.start)
  1499. sample(self.start, self.end, start, end, 0)
  1500. return x_coords, y_coords, param
  1501. ### 3D lines
  1502. class Line3DBaseSeries(Line2DBaseSeries):
  1503. """A base class for 3D lines.
  1504. Most of the stuff is derived from Line2DBaseSeries."""
  1505. is_2Dline = False
  1506. is_3Dline = True
  1507. _dim = 3
  1508. def __init__(self):
  1509. super().__init__()
  1510. class Parametric3DLineSeries(ParametricLineBaseSeries):
  1511. """Representation for a 3D line consisting of three parametric SymPy
  1512. expressions and a range."""
  1513. is_2Dline = False
  1514. is_3Dline = True
  1515. def __init__(self, expr_x, expr_y, expr_z, var_start_end, label="", **kwargs):
  1516. super().__init__(**kwargs)
  1517. self.expr_x = expr_x if callable(expr_x) else sympify(expr_x)
  1518. self.expr_y = expr_y if callable(expr_y) else sympify(expr_y)
  1519. self.expr_z = expr_z if callable(expr_z) else sympify(expr_z)
  1520. self.expr = (self.expr_x, self.expr_y, self.expr_z)
  1521. self.ranges = [var_start_end]
  1522. self._cast = float
  1523. self.adaptive = False
  1524. self.use_cm = kwargs.get("use_cm", True)
  1525. self._set_parametric_line_label(label)
  1526. self._post_init()
  1527. # TODO: remove this
  1528. self._xlim = None
  1529. self._ylim = None
  1530. self._zlim = None
  1531. def __str__(self):
  1532. return self._str_helper(
  1533. "3D parametric cartesian line: (%s, %s, %s) for %s over %s" % (
  1534. str(self.expr_x),
  1535. str(self.expr_y),
  1536. str(self.expr_z),
  1537. str(self.var),
  1538. str((self.start, self.end))
  1539. ))
  1540. def get_data(self):
  1541. # TODO: remove this
  1542. np = import_module("numpy")
  1543. x, y, z, p = super().get_data()
  1544. self._xlim = (np.amin(x), np.amax(x))
  1545. self._ylim = (np.amin(y), np.amax(y))
  1546. self._zlim = (np.amin(z), np.amax(z))
  1547. return x, y, z, p
  1548. ### Surfaces
  1549. class SurfaceBaseSeries(BaseSeries):
  1550. """A base class for 3D surfaces."""
  1551. is_3Dsurface = True
  1552. def __init__(self, *args, **kwargs):
  1553. super().__init__(**kwargs)
  1554. self.use_cm = kwargs.get("use_cm", False)
  1555. # NOTE: why should SurfaceOver2DRangeSeries support is polar?
  1556. # After all, the same result can be achieve with
  1557. # ParametricSurfaceSeries. For example:
  1558. # sin(r) for (r, 0, 2 * pi) and (theta, 0, pi/2) can be parameterized
  1559. # as (r * cos(theta), r * sin(theta), sin(t)) for (r, 0, 2 * pi) and
  1560. # (theta, 0, pi/2).
  1561. # Because it is faster to evaluate (important for interactive plots).
  1562. self.is_polar = kwargs.get("is_polar", kwargs.get("polar", False))
  1563. self.surface_color = kwargs.get("surface_color", None)
  1564. self.color_func = kwargs.get("color_func", lambda x, y, z: z)
  1565. if callable(self.surface_color):
  1566. self.color_func = self.surface_color
  1567. self.surface_color = None
  1568. def _set_surface_label(self, label):
  1569. exprs = self.expr
  1570. self._label = str(exprs) if label is None else label
  1571. self._latex_label = latex(exprs) if label is None else label
  1572. # if the expressions is a lambda function and no label
  1573. # has been provided, then its better to do the following to avoid
  1574. # surprises on the backend
  1575. is_lambda = (callable(exprs) if not hasattr(exprs, "__iter__")
  1576. else any(callable(e) for e in exprs))
  1577. if is_lambda and (self._label == str(exprs)):
  1578. self._label = ""
  1579. self._latex_label = ""
  1580. def get_color_array(self):
  1581. np = import_module('numpy')
  1582. c = self.surface_color
  1583. if isinstance(c, Callable):
  1584. f = np.vectorize(c)
  1585. nargs = arity(c)
  1586. if self.is_parametric:
  1587. variables = list(map(centers_of_faces, self.get_parameter_meshes()))
  1588. if nargs == 1:
  1589. return f(variables[0])
  1590. elif nargs == 2:
  1591. return f(*variables)
  1592. variables = list(map(centers_of_faces, self.get_meshes()))
  1593. if nargs == 1:
  1594. return f(variables[0])
  1595. elif nargs == 2:
  1596. return f(*variables[:2])
  1597. else:
  1598. return f(*variables)
  1599. else:
  1600. if isinstance(self, SurfaceOver2DRangeSeries):
  1601. return c*np.ones(min(self.nb_of_points_x, self.nb_of_points_y))
  1602. else:
  1603. return c*np.ones(min(self.nb_of_points_u, self.nb_of_points_v))
  1604. class SurfaceOver2DRangeSeries(SurfaceBaseSeries):
  1605. """Representation for a 3D surface consisting of a SymPy expression and 2D
  1606. range."""
  1607. def __init__(self, expr, var_start_end_x, var_start_end_y, label="", **kwargs):
  1608. super().__init__(**kwargs)
  1609. self.expr = expr if callable(expr) else sympify(expr)
  1610. self.ranges = [var_start_end_x, var_start_end_y]
  1611. self._set_surface_label(label)
  1612. self._post_init()
  1613. # TODO: remove this
  1614. self._xlim = (self.start_x, self.end_x)
  1615. self._ylim = (self.start_y, self.end_y)
  1616. @property
  1617. def var_x(self):
  1618. return self.ranges[0][0]
  1619. @property
  1620. def var_y(self):
  1621. return self.ranges[1][0]
  1622. @property
  1623. def start_x(self):
  1624. try:
  1625. return float(self.ranges[0][1])
  1626. except TypeError:
  1627. return self.ranges[0][1]
  1628. @property
  1629. def end_x(self):
  1630. try:
  1631. return float(self.ranges[0][2])
  1632. except TypeError:
  1633. return self.ranges[0][2]
  1634. @property
  1635. def start_y(self):
  1636. try:
  1637. return float(self.ranges[1][1])
  1638. except TypeError:
  1639. return self.ranges[1][1]
  1640. @property
  1641. def end_y(self):
  1642. try:
  1643. return float(self.ranges[1][2])
  1644. except TypeError:
  1645. return self.ranges[1][2]
  1646. @property
  1647. def nb_of_points_x(self):
  1648. return self.n[0]
  1649. @nb_of_points_x.setter
  1650. def nb_of_points_x(self, v):
  1651. n = self.n
  1652. self.n = [v, n[1:]]
  1653. @property
  1654. def nb_of_points_y(self):
  1655. return self.n[1]
  1656. @nb_of_points_y.setter
  1657. def nb_of_points_y(self, v):
  1658. n = self.n
  1659. self.n = [n[0], v, n[2]]
  1660. def __str__(self):
  1661. series_type = "cartesian surface" if self.is_3Dsurface else "contour"
  1662. return self._str_helper(
  1663. series_type + ": %s for" " %s over %s and %s over %s" % (
  1664. str(self.expr),
  1665. str(self.var_x), str((self.start_x, self.end_x)),
  1666. str(self.var_y), str((self.start_y, self.end_y)),
  1667. ))
  1668. def get_meshes(self):
  1669. """Return the x,y,z coordinates for plotting the surface.
  1670. This function is available for back-compatibility purposes. Consider
  1671. using ``get_data()`` instead.
  1672. """
  1673. return self.get_data()
  1674. def get_data(self):
  1675. """Return arrays of coordinates for plotting.
  1676. Returns
  1677. =======
  1678. mesh_x : np.ndarray
  1679. Discretized x-domain.
  1680. mesh_y : np.ndarray
  1681. Discretized y-domain.
  1682. mesh_z : np.ndarray
  1683. Results of the evaluation.
  1684. """
  1685. np = import_module('numpy')
  1686. results = self._evaluate()
  1687. # mask out complex values
  1688. for i, r in enumerate(results):
  1689. _re, _im = np.real(r), np.imag(r)
  1690. _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan
  1691. results[i] = _re
  1692. x, y, z = results
  1693. if self.is_polar and self.is_3Dsurface:
  1694. r = x.copy()
  1695. x = r * np.cos(y)
  1696. y = r * np.sin(y)
  1697. # TODO: remove this
  1698. self._zlim = (np.amin(z), np.amax(z))
  1699. return self._apply_transform(x, y, z)
  1700. class ParametricSurfaceSeries(SurfaceBaseSeries):
  1701. """Representation for a 3D surface consisting of three parametric SymPy
  1702. expressions and a range."""
  1703. is_parametric = True
  1704. def __init__(self, expr_x, expr_y, expr_z,
  1705. var_start_end_u, var_start_end_v, label="", **kwargs):
  1706. super().__init__(**kwargs)
  1707. self.expr_x = expr_x if callable(expr_x) else sympify(expr_x)
  1708. self.expr_y = expr_y if callable(expr_y) else sympify(expr_y)
  1709. self.expr_z = expr_z if callable(expr_z) else sympify(expr_z)
  1710. self.expr = (self.expr_x, self.expr_y, self.expr_z)
  1711. self.ranges = [var_start_end_u, var_start_end_v]
  1712. self.color_func = kwargs.get("color_func", lambda x, y, z, u, v: z)
  1713. self._set_surface_label(label)
  1714. self._post_init()
  1715. @property
  1716. def var_u(self):
  1717. return self.ranges[0][0]
  1718. @property
  1719. def var_v(self):
  1720. return self.ranges[1][0]
  1721. @property
  1722. def start_u(self):
  1723. try:
  1724. return float(self.ranges[0][1])
  1725. except TypeError:
  1726. return self.ranges[0][1]
  1727. @property
  1728. def end_u(self):
  1729. try:
  1730. return float(self.ranges[0][2])
  1731. except TypeError:
  1732. return self.ranges[0][2]
  1733. @property
  1734. def start_v(self):
  1735. try:
  1736. return float(self.ranges[1][1])
  1737. except TypeError:
  1738. return self.ranges[1][1]
  1739. @property
  1740. def end_v(self):
  1741. try:
  1742. return float(self.ranges[1][2])
  1743. except TypeError:
  1744. return self.ranges[1][2]
  1745. @property
  1746. def nb_of_points_u(self):
  1747. return self.n[0]
  1748. @nb_of_points_u.setter
  1749. def nb_of_points_u(self, v):
  1750. n = self.n
  1751. self.n = [v, n[1:]]
  1752. @property
  1753. def nb_of_points_v(self):
  1754. return self.n[1]
  1755. @nb_of_points_v.setter
  1756. def nb_of_points_v(self, v):
  1757. n = self.n
  1758. self.n = [n[0], v, n[2]]
  1759. def __str__(self):
  1760. return self._str_helper(
  1761. "parametric cartesian surface: (%s, %s, %s) for"
  1762. " %s over %s and %s over %s" % (
  1763. str(self.expr_x), str(self.expr_y), str(self.expr_z),
  1764. str(self.var_u), str((self.start_u, self.end_u)),
  1765. str(self.var_v), str((self.start_v, self.end_v)),
  1766. ))
  1767. def get_parameter_meshes(self):
  1768. return self.get_data()[3:]
  1769. def get_meshes(self):
  1770. """Return the x,y,z coordinates for plotting the surface.
  1771. This function is available for back-compatibility purposes. Consider
  1772. using ``get_data()`` instead.
  1773. """
  1774. return self.get_data()[:3]
  1775. def get_data(self):
  1776. """Return arrays of coordinates for plotting.
  1777. Returns
  1778. =======
  1779. x : np.ndarray [n2 x n1]
  1780. x-coordinates.
  1781. y : np.ndarray [n2 x n1]
  1782. y-coordinates.
  1783. z : np.ndarray [n2 x n1]
  1784. z-coordinates.
  1785. mesh_u : np.ndarray [n2 x n1]
  1786. Discretized u range.
  1787. mesh_v : np.ndarray [n2 x n1]
  1788. Discretized v range.
  1789. """
  1790. np = import_module('numpy')
  1791. results = self._evaluate()
  1792. # mask out complex values
  1793. for i, r in enumerate(results):
  1794. _re, _im = np.real(r), np.imag(r)
  1795. _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan
  1796. results[i] = _re
  1797. # TODO: remove this
  1798. x, y, z = results[2:]
  1799. self._xlim = (np.amin(x), np.amax(x))
  1800. self._ylim = (np.amin(y), np.amax(y))
  1801. self._zlim = (np.amin(z), np.amax(z))
  1802. return self._apply_transform(*results[2:], *results[:2])
  1803. ### Contours
  1804. class ContourSeries(SurfaceOver2DRangeSeries):
  1805. """Representation for a contour plot."""
  1806. is_3Dsurface = False
  1807. is_contour = True
  1808. def __init__(self, *args, **kwargs):
  1809. super().__init__(*args, **kwargs)
  1810. self.is_filled = kwargs.get("is_filled", kwargs.get("fill", True))
  1811. self.show_clabels = kwargs.get("clabels", True)
  1812. # NOTE: contour plots are used by plot_contour, plot_vector and
  1813. # plot_complex_vector. By implementing contour_kw we are able to
  1814. # quickly target the contour plot.
  1815. self.rendering_kw = kwargs.get("contour_kw",
  1816. kwargs.get("rendering_kw", {}))
  1817. class GenericDataSeries(BaseSeries):
  1818. """Represents generic numerical data.
  1819. Notes
  1820. =====
  1821. This class serves the purpose of back-compatibility with the "markers,
  1822. annotations, fill, rectangles" keyword arguments that represent
  1823. user-provided numerical data. In particular, it solves the problem of
  1824. combining together two or more plot-objects with the ``extend`` or
  1825. ``append`` methods: user-provided numerical data is also taken into
  1826. consideration because it is stored in this series class.
  1827. Also note that the current implementation is far from optimal, as each
  1828. keyword argument is stored into an attribute in the ``Plot`` class, which
  1829. requires a hard-coded if-statement in the ``MatplotlibBackend`` class.
  1830. The implementation suggests that it is ok to add attributes and
  1831. if-statements to provide more and more functionalities for user-provided
  1832. numerical data (e.g. adding horizontal lines, or vertical lines, or bar
  1833. plots, etc). However, in doing so one would reinvent the wheel: plotting
  1834. libraries (like Matplotlib) already implements the necessary API.
  1835. Instead of adding more keyword arguments and attributes, users interested
  1836. in adding custom numerical data to a plot should retrieve the figure
  1837. created by this plotting module. For example, this code:
  1838. .. plot::
  1839. :context: close-figs
  1840. :include-source: True
  1841. from sympy import Symbol, plot, cos
  1842. x = Symbol("x")
  1843. p = plot(cos(x), markers=[{"args": [[0, 1, 2], [0, 1, -1], "*"]}])
  1844. Becomes:
  1845. .. plot::
  1846. :context: close-figs
  1847. :include-source: True
  1848. p = plot(cos(x), backend="matplotlib")
  1849. fig, ax = p._backend.fig, p._backend.ax
  1850. ax.plot([0, 1, 2], [0, 1, -1], "*")
  1851. fig
  1852. Which is far better in terms of readability. Also, it gives access to the
  1853. full plotting library capabilities, without the need to reinvent the wheel.
  1854. """
  1855. is_generic = True
  1856. def __init__(self, tp, *args, **kwargs):
  1857. self.type = tp
  1858. self.args = args
  1859. self.rendering_kw = kwargs
  1860. def get_data(self):
  1861. return self.args
  1862. class ImplicitSeries(BaseSeries):
  1863. """Representation for 2D Implicit plot."""
  1864. is_implicit = True
  1865. use_cm = False
  1866. _N = 100
  1867. def __init__(self, expr, var_start_end_x, var_start_end_y, label="", **kwargs):
  1868. super().__init__(**kwargs)
  1869. self.adaptive = kwargs.get("adaptive", False)
  1870. self.expr = expr
  1871. self._label = str(expr) if label is None else label
  1872. self._latex_label = latex(expr) if label is None else label
  1873. self.ranges = [var_start_end_x, var_start_end_y]
  1874. self.var_x, self.start_x, self.end_x = self.ranges[0]
  1875. self.var_y, self.start_y, self.end_y = self.ranges[1]
  1876. self._color = kwargs.get("color", kwargs.get("line_color", None))
  1877. if self.is_interactive and self.adaptive:
  1878. raise NotImplementedError("Interactive plot with `adaptive=True` "
  1879. "is not supported.")
  1880. # Check whether the depth is greater than 4 or less than 0.
  1881. depth = kwargs.get("depth", 0)
  1882. if depth > 4:
  1883. depth = 4
  1884. elif depth < 0:
  1885. depth = 0
  1886. self.depth = 4 + depth
  1887. self._post_init()
  1888. @property
  1889. def expr(self):
  1890. if self.adaptive:
  1891. return self._adaptive_expr
  1892. return self._non_adaptive_expr
  1893. @expr.setter
  1894. def expr(self, expr):
  1895. self._block_lambda_functions(expr)
  1896. # these are needed for adaptive evaluation
  1897. expr, has_equality = self._has_equality(sympify(expr))
  1898. self._adaptive_expr = expr
  1899. self.has_equality = has_equality
  1900. self._label = str(expr)
  1901. self._latex_label = latex(expr)
  1902. if isinstance(expr, (BooleanFunction, Ne)) and (not self.adaptive):
  1903. self.adaptive = True
  1904. msg = "contains Boolean functions. "
  1905. if isinstance(expr, Ne):
  1906. msg = "is an unequality. "
  1907. warnings.warn(
  1908. "The provided expression " + msg
  1909. + "In order to plot the expression, the algorithm "
  1910. + "automatically switched to an adaptive sampling."
  1911. )
  1912. if isinstance(expr, BooleanFunction):
  1913. self._non_adaptive_expr = None
  1914. self._is_equality = False
  1915. else:
  1916. # these are needed for uniform meshing evaluation
  1917. expr, is_equality = self._preprocess_meshgrid_expression(expr, self.adaptive)
  1918. self._non_adaptive_expr = expr
  1919. self._is_equality = is_equality
  1920. @property
  1921. def line_color(self):
  1922. return self._color
  1923. @line_color.setter
  1924. def line_color(self, v):
  1925. self._color = v
  1926. color = line_color
  1927. def _has_equality(self, expr):
  1928. # Represents whether the expression contains an Equality, GreaterThan
  1929. # or LessThan
  1930. has_equality = False
  1931. def arg_expand(bool_expr):
  1932. """Recursively expands the arguments of an Boolean Function"""
  1933. for arg in bool_expr.args:
  1934. if isinstance(arg, BooleanFunction):
  1935. arg_expand(arg)
  1936. elif isinstance(arg, Relational):
  1937. arg_list.append(arg)
  1938. arg_list = []
  1939. if isinstance(expr, BooleanFunction):
  1940. arg_expand(expr)
  1941. # Check whether there is an equality in the expression provided.
  1942. if any(isinstance(e, (Equality, GreaterThan, LessThan)) for e in arg_list):
  1943. has_equality = True
  1944. elif not isinstance(expr, Relational):
  1945. expr = Equality(expr, 0)
  1946. has_equality = True
  1947. elif isinstance(expr, (Equality, GreaterThan, LessThan)):
  1948. has_equality = True
  1949. return expr, has_equality
  1950. def __str__(self):
  1951. f = lambda t: float(t) if len(t.free_symbols) == 0 else t
  1952. return self._str_helper(
  1953. "Implicit expression: %s for %s over %s and %s over %s") % (
  1954. str(self._adaptive_expr),
  1955. str(self.var_x),
  1956. str((f(self.start_x), f(self.end_x))),
  1957. str(self.var_y),
  1958. str((f(self.start_y), f(self.end_y))),
  1959. )
  1960. def get_data(self):
  1961. """Returns numerical data.
  1962. Returns
  1963. =======
  1964. If the series is evaluated with the `adaptive=True` it returns:
  1965. interval_list : list
  1966. List of bounding rectangular intervals to be postprocessed and
  1967. eventually used with Matplotlib's ``fill`` command.
  1968. dummy : str
  1969. A string containing ``"fill"``.
  1970. Otherwise, it returns 2D numpy arrays to be used with Matplotlib's
  1971. ``contour`` or ``contourf`` commands:
  1972. x_array : np.ndarray
  1973. y_array : np.ndarray
  1974. z_array : np.ndarray
  1975. plot_type : str
  1976. A string specifying which plot command to use, ``"contour"``
  1977. or ``"contourf"``.
  1978. """
  1979. if self.adaptive:
  1980. data = self._adaptive_eval()
  1981. if data is not None:
  1982. return data
  1983. return self._get_meshes_grid()
  1984. def _adaptive_eval(self):
  1985. """
  1986. References
  1987. ==========
  1988. .. [1] Jeffrey Allen Tupper. Reliable Two-Dimensional Graphing Methods for
  1989. Mathematical Formulae with Two Free Variables.
  1990. .. [2] Jeffrey Allen Tupper. Graphing Equations with Generalized Interval
  1991. Arithmetic. Master's thesis. University of Toronto, 1996
  1992. """
  1993. import sympy.plotting.intervalmath.lib_interval as li
  1994. user_functions = {}
  1995. printer = IntervalMathPrinter({
  1996. 'fully_qualified_modules': False, 'inline': True,
  1997. 'allow_unknown_functions': True,
  1998. 'user_functions': user_functions})
  1999. keys = [t for t in dir(li) if ("__" not in t) and (t not in ["import_module", "interval"])]
  2000. vals = [getattr(li, k) for k in keys]
  2001. d = dict(zip(keys, vals))
  2002. func = lambdify((self.var_x, self.var_y), self.expr, modules=[d], printer=printer)
  2003. data = None
  2004. try:
  2005. data = self._get_raster_interval(func)
  2006. except NameError as err:
  2007. warnings.warn(
  2008. "Adaptive meshing could not be applied to the"
  2009. " expression, as some functions are not yet implemented"
  2010. " in the interval math module:\n\n"
  2011. "NameError: %s\n\n" % err +
  2012. "Proceeding with uniform meshing."
  2013. )
  2014. self.adaptive = False
  2015. except TypeError:
  2016. warnings.warn(
  2017. "Adaptive meshing could not be applied to the"
  2018. " expression. Using uniform meshing.")
  2019. self.adaptive = False
  2020. return data
  2021. def _get_raster_interval(self, func):
  2022. """Uses interval math to adaptively mesh and obtain the plot"""
  2023. np = import_module('numpy')
  2024. k = self.depth
  2025. interval_list = []
  2026. sx, sy = [float(t) for t in [self.start_x, self.start_y]]
  2027. ex, ey = [float(t) for t in [self.end_x, self.end_y]]
  2028. # Create initial 32 divisions
  2029. xsample = np.linspace(sx, ex, 33)
  2030. ysample = np.linspace(sy, ey, 33)
  2031. # Add a small jitter so that there are no false positives for equality.
  2032. # Ex: y==x becomes True for x interval(1, 2) and y interval(1, 2)
  2033. # which will draw a rectangle.
  2034. jitterx = (
  2035. (np.random.rand(len(xsample)) * 2 - 1)
  2036. * (ex - sx)
  2037. / 2 ** 20
  2038. )
  2039. jittery = (
  2040. (np.random.rand(len(ysample)) * 2 - 1)
  2041. * (ey - sy)
  2042. / 2 ** 20
  2043. )
  2044. xsample += jitterx
  2045. ysample += jittery
  2046. xinter = [interval(x1, x2) for x1, x2 in zip(xsample[:-1], xsample[1:])]
  2047. yinter = [interval(y1, y2) for y1, y2 in zip(ysample[:-1], ysample[1:])]
  2048. interval_list = [[x, y] for x in xinter for y in yinter]
  2049. plot_list = []
  2050. # recursive call refinepixels which subdivides the intervals which are
  2051. # neither True nor False according to the expression.
  2052. def refine_pixels(interval_list):
  2053. """Evaluates the intervals and subdivides the interval if the
  2054. expression is partially satisfied."""
  2055. temp_interval_list = []
  2056. plot_list = []
  2057. for intervals in interval_list:
  2058. # Convert the array indices to x and y values
  2059. intervalx = intervals[0]
  2060. intervaly = intervals[1]
  2061. func_eval = func(intervalx, intervaly)
  2062. # The expression is valid in the interval. Change the contour
  2063. # array values to 1.
  2064. if func_eval[1] is False or func_eval[0] is False:
  2065. pass
  2066. elif func_eval == (True, True):
  2067. plot_list.append([intervalx, intervaly])
  2068. elif func_eval[1] is None or func_eval[0] is None:
  2069. # Subdivide
  2070. avgx = intervalx.mid
  2071. avgy = intervaly.mid
  2072. a = interval(intervalx.start, avgx)
  2073. b = interval(avgx, intervalx.end)
  2074. c = interval(intervaly.start, avgy)
  2075. d = interval(avgy, intervaly.end)
  2076. temp_interval_list.append([a, c])
  2077. temp_interval_list.append([a, d])
  2078. temp_interval_list.append([b, c])
  2079. temp_interval_list.append([b, d])
  2080. return temp_interval_list, plot_list
  2081. while k >= 0 and len(interval_list):
  2082. interval_list, plot_list_temp = refine_pixels(interval_list)
  2083. plot_list.extend(plot_list_temp)
  2084. k = k - 1
  2085. # Check whether the expression represents an equality
  2086. # If it represents an equality, then none of the intervals
  2087. # would have satisfied the expression due to floating point
  2088. # differences. Add all the undecided values to the plot.
  2089. if self.has_equality:
  2090. for intervals in interval_list:
  2091. intervalx = intervals[0]
  2092. intervaly = intervals[1]
  2093. func_eval = func(intervalx, intervaly)
  2094. if func_eval[1] and func_eval[0] is not False:
  2095. plot_list.append([intervalx, intervaly])
  2096. return plot_list, "fill"
  2097. def _get_meshes_grid(self):
  2098. """Generates the mesh for generating a contour.
  2099. In the case of equality, ``contour`` function of matplotlib can
  2100. be used. In other cases, matplotlib's ``contourf`` is used.
  2101. """
  2102. np = import_module('numpy')
  2103. xarray, yarray, z_grid = self._evaluate()
  2104. _re, _im = np.real(z_grid), np.imag(z_grid)
  2105. _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan
  2106. if self._is_equality:
  2107. return xarray, yarray, _re, 'contour'
  2108. return xarray, yarray, _re, 'contourf'
  2109. @staticmethod
  2110. def _preprocess_meshgrid_expression(expr, adaptive):
  2111. """If the expression is a Relational, rewrite it as a single
  2112. expression.
  2113. Returns
  2114. =======
  2115. expr : Expr
  2116. The rewritten expression
  2117. equality : Boolean
  2118. Whether the original expression was an Equality or not.
  2119. """
  2120. equality = False
  2121. if isinstance(expr, Equality):
  2122. expr = expr.lhs - expr.rhs
  2123. equality = True
  2124. elif isinstance(expr, Relational):
  2125. expr = expr.gts - expr.lts
  2126. elif not adaptive:
  2127. raise NotImplementedError(
  2128. "The expression is not supported for "
  2129. "plotting in uniform meshed plot."
  2130. )
  2131. return expr, equality
  2132. def get_label(self, use_latex=False, wrapper="$%s$"):
  2133. """Return the label to be used to display the expression.
  2134. Parameters
  2135. ==========
  2136. use_latex : bool
  2137. If False, the string representation of the expression is returned.
  2138. If True, the latex representation is returned.
  2139. wrapper : str
  2140. The backend might need the latex representation to be wrapped by
  2141. some characters. Default to ``"$%s$"``.
  2142. Returns
  2143. =======
  2144. label : str
  2145. """
  2146. if use_latex is False:
  2147. return self._label
  2148. if self._label == str(self._adaptive_expr):
  2149. return self._get_wrapped_label(self._latex_label, wrapper)
  2150. return self._latex_label
  2151. ##############################################################################
  2152. # Finding the centers of line segments or mesh faces
  2153. ##############################################################################
  2154. def centers_of_segments(array):
  2155. np = import_module('numpy')
  2156. return np.mean(np.vstack((array[:-1], array[1:])), 0)
  2157. def centers_of_faces(array):
  2158. np = import_module('numpy')
  2159. return np.mean(np.dstack((array[:-1, :-1],
  2160. array[1:, :-1],
  2161. array[:-1, 1:],
  2162. array[:-1, :-1],
  2163. )), 2)
  2164. def flat(x, y, z, eps=1e-3):
  2165. """Checks whether three points are almost collinear"""
  2166. np = import_module('numpy')
  2167. # Workaround plotting piecewise (#8577)
  2168. vector_a = (x - y).astype(float)
  2169. vector_b = (z - y).astype(float)
  2170. dot_product = np.dot(vector_a, vector_b)
  2171. vector_a_norm = np.linalg.norm(vector_a)
  2172. vector_b_norm = np.linalg.norm(vector_b)
  2173. cos_theta = dot_product / (vector_a_norm * vector_b_norm)
  2174. return abs(cos_theta + 1) < eps
  2175. def _set_discretization_points(kwargs, pt):
  2176. """Allow the use of the keyword arguments ``n, n1, n2`` to
  2177. specify the number of discretization points in one and two
  2178. directions, while keeping back-compatibility with older keyword arguments
  2179. like, ``nb_of_points, nb_of_points_*, points``.
  2180. Parameters
  2181. ==========
  2182. kwargs : dict
  2183. Dictionary of keyword arguments passed into a plotting function.
  2184. pt : type
  2185. The type of the series, which indicates the kind of plot we are
  2186. trying to create.
  2187. """
  2188. replace_old_keywords = {
  2189. "nb_of_points": "n",
  2190. "nb_of_points_x": "n1",
  2191. "nb_of_points_y": "n2",
  2192. "nb_of_points_u": "n1",
  2193. "nb_of_points_v": "n2",
  2194. "points": "n"
  2195. }
  2196. for k, v in replace_old_keywords.items():
  2197. if k in kwargs.keys():
  2198. kwargs[v] = kwargs.pop(k)
  2199. if pt in [LineOver1DRangeSeries, Parametric2DLineSeries,
  2200. Parametric3DLineSeries]:
  2201. if "n" in kwargs.keys():
  2202. kwargs["n1"] = kwargs["n"]
  2203. if hasattr(kwargs["n"], "__iter__") and (len(kwargs["n"]) > 0):
  2204. kwargs["n1"] = kwargs["n"][0]
  2205. elif pt in [SurfaceOver2DRangeSeries, ContourSeries,
  2206. ParametricSurfaceSeries, ImplicitSeries]:
  2207. if "n" in kwargs.keys():
  2208. if hasattr(kwargs["n"], "__iter__") and (len(kwargs["n"]) > 1):
  2209. kwargs["n1"] = kwargs["n"][0]
  2210. kwargs["n2"] = kwargs["n"][1]
  2211. else:
  2212. kwargs["n1"] = kwargs["n2"] = kwargs["n"]
  2213. return kwargs