| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323 |
- from sympy.core.containers import Tuple
- from sympy.core.basic import Basic
- from sympy.core.expr import Expr
- from sympy.core.function import AppliedUndef
- from sympy.core.relational import Relational
- from sympy.core.symbol import Dummy
- from sympy.core.sympify import sympify
- from sympy.logic.boolalg import BooleanFunction
- from sympy.sets.fancysets import ImageSet
- from sympy.sets.sets import FiniteSet
- from sympy.tensor.indexed import Indexed
- def _get_free_symbols(exprs):
- """Returns the free symbols of a symbolic expression.
- If the expression contains any of these elements, assume that they are
- the "free symbols" of the expression:
- * indexed objects
- * applied undefined function (useful for sympy.physics.mechanics module)
- """
- if not isinstance(exprs, (list, tuple, set)):
- exprs = [exprs]
- if all(callable(e) for e in exprs):
- return set()
- free = set().union(*[e.atoms(Indexed) for e in exprs])
- free = free.union(*[e.atoms(AppliedUndef) for e in exprs])
- return free or set().union(*[e.free_symbols for e in exprs])
def extract_solution(set_sol, n=10):
    """Extract numerical solutions from a set solution (computed by solveset,
    linsolve, nonlinsolve). Often, it is not trivial to get something useful
    out of them.

    Parameters
    ==========

    set_sol : Set
        The set solution. Every ``ImageSet`` found inside it is replaced by
        a ``FiniteSet`` of sample elements.
    n : int, optional
        In order to replace ImageSet with FiniteSet, an iterator is created
        for each ImageSet contained in `set_sol`, and at most its first `n`
        elements are taken (fewer if the ImageSet is exhausted earlier).
        Default value: 10.

    Returns
    =======

    The result of substituting, in ``set_sol``, each ImageSet with the
    FiniteSet of its sampled elements.
    """
    from itertools import islice

    for image in set_sol.find(ImageSet):
        # islice stops quietly when ``image`` has fewer than ``n`` elements,
        # whereas repeated next() calls would leak a StopIteration. A
        # negative ``n`` still produces an empty sample.
        sample = FiniteSet(*islice(image, max(n, 0)))
        set_sol = set_sol.subs(image, sample)
    return set_sol
def _plot_sympify(args):
    """Recursively sympify the arguments given to a plotting function.

    Users typically call a plot function with
    ``expr, range1 [tuple, opt], ..., label [str, opt], rendering_kw [dict, opt]``.
    Expressions and ranges are sympified. Labels (``str``) and rendering
    keywords (``dict``) are left untouched, because a label containing
    characters such as $, {, } would make ``sympify`` raise an error.
    Callables are passed through unchanged as well. Nested lists/tuples are
    converted into SymPy ``Tuple`` objects whose entries were processed in
    the same way.

    A single ``Expr`` is returned as is; any other input is returned as a
    new list.
    """
    if isinstance(args, Expr):
        return args

    def _keep_as_is(obj):
        # labels, rendering_kw and user-supplied callables
        if isinstance(obj, (str, dict)) or callable(obj):
            return True
        # NOTE: check if it is a vector from sympy.physics.vector module
        # without importing the module (because it slows down SymPy's
        # import process and triggers SymPy's optional-dependencies
        # tests to fail).
        return obj.__class__.__name__ == "Vector" and not isinstance(obj, Basic)

    processed = []
    for item in args:
        if isinstance(item, (list, tuple)):
            processed.append(Tuple(*_plot_sympify(item), sympify=False))
        elif _keep_as_is(item):
            processed.append(item)
        else:
            processed.append(sympify(item))
    return processed
def _create_ranges(exprs, ranges, npar, label="", params=None):
    """This function does two things:

    1. Check if the number of free symbols is in agreement with the type of
       plot chosen. For example, plot() requires 1 free symbol;
       plot3d() requires 2 free symbols.
    2. Sometimes users create plots without providing ranges for the
       variables. Here we create the necessary ranges.

    Parameters
    ==========

    exprs : iterable
        The expressions from which to extract the free symbols
    ranges : list
        The limiting ranges provided by the user. Missing ranges are
        appended to this list in place, and the same list is returned.
    npar : int
        The number of free symbols required by the plot functions.
        For example,
        npar=1 for plot, npar=2 for plot3d, ...
    label : str, optional
        Not used by this function.
    params : dict, optional
        A dictionary mapping symbols to parameters for interactive plot.
        Its keys are not counted as free symbols.

    Returns
    =======

    ranges : list
        Exactly ``npar`` ranges. Missing ones default to
        ``(symbol, -10, 10)``, using a ``Dummy`` when no free symbol is left.

    Raises
    ======

    ValueError
        If there are more free symbols or more ranges than ``npar``, if two
        ranges share the same symbol, or if the free symbols of the
        expressions are not covered by the ranges.
    """
    def _default_range(symbol):
        return Tuple(symbol, -10, 10)

    def _incompatible(range_symbols):
        return ValueError(
            "Incompatible free symbols of the expressions with "
            "the ranges.\n"
            + "Free symbols in the expressions: {}\n".format(free_symbols)
            + "Free symbols in the ranges: {}".format(range_symbols)
        )

    free_symbols = _get_free_symbols(exprs)
    if params is not None:
        free_symbols = free_symbols.difference(params.keys())

    if len(free_symbols) > npar:
        raise ValueError(
            "Too many free symbols.\n"
            + "Expected {} free symbols.\n".format(npar)
            + "Received {}: {}".format(len(free_symbols), free_symbols)
        )

    if len(ranges) > npar:
        raise ValueError(
            "Too many ranges. Received %s, expected %s" % (len(ranges), npar))

    # free symbols in the ranges provided by the user
    rfs = {r[0] for r in ranges}
    if len(rfs) != len(ranges):
        raise ValueError("Multiple ranges with the same symbol")

    if len(ranges) < npar:
        missing = free_symbols.difference(rfs)
        if len(ranges) + len(missing) > npar:
            # The user's ranges use symbols that do not occur in the
            # expressions, leaving no room for the symbols that do. Adding a
            # default range for each of them would exceed npar ranges.
            raise _incompatible(rfs)
        # Add a range for each missing free symbol. Sort them so the axis
        # assignment does not depend on set iteration order, which for
        # SymPy symbols may vary between runs (hash randomization).
        for s in sorted(missing, key=str):
            ranges.append(_default_range(s))
        # if there is still room, fill them with dummys
        for _ in range(npar - len(ranges)):
            ranges.append(_default_range(Dummy()))

    if len(free_symbols) == npar:
        # there could be times when this condition is not met, for example
        # plotting the function f(x, y) = x (which is a plane); in this case,
        # free_symbols = {x} whereas rfs = {x, y} (or x and Dummy)
        rfs = {r[0] for r in ranges}
        if free_symbols.difference(rfs):
            raise _incompatible(rfs)
    return ranges
def _is_range(r):
    """Tell whether ``r`` is a plotting range.

    A plotting range is a SymPy ``Tuple`` with exactly three entries,
    ``(symbol, start, end)``, whose ``start`` and ``end`` are numbers.
    String bounds are rejected before ``is_number`` is consulted.
    """
    # TODO: prange check goes here
    if not isinstance(r, Tuple) or len(r) != 3:
        return False
    bounds = r.args[1:]
    return all(
        not isinstance(bound, str) and bound.is_number for bound in bounds)
- def _unpack_args(*args):
- """Given a list/tuple of arguments previously processed by _plot_sympify()
- and/or _check_arguments(), separates and returns its components:
- expressions, ranges, label and rendering keywords.
- Examples
- ========
- >>> from sympy import cos, sin, symbols
- >>> from sympy.plotting.utils import _plot_sympify, _unpack_args
- >>> x, y = symbols('x, y')
- >>> args = (sin(x), (x, -10, 10), "f1")
- >>> args = _plot_sympify(args)
- >>> _unpack_args(*args)
- ([sin(x)], [(x, -10, 10)], 'f1', None)
- >>> args = (sin(x**2 + y**2), (x, -2, 2), (y, -3, 3), "f2")
- >>> args = _plot_sympify(args)
- >>> _unpack_args(*args)
- ([sin(x**2 + y**2)], [(x, -2, 2), (y, -3, 3)], 'f2', None)
- >>> args = (sin(x + y), cos(x - y), x + y, (x, -2, 2), (y, -3, 3), "f3")
- >>> args = _plot_sympify(args)
- >>> _unpack_args(*args)
- ([sin(x + y), cos(x - y), x + y], [(x, -2, 2), (y, -3, 3)], 'f3', None)
- """
- ranges = [t for t in args if _is_range(t)]
- labels = [t for t in args if isinstance(t, str)]
- label = None if not labels else labels[0]
- rendering_kw = [t for t in args if isinstance(t, dict)]
- rendering_kw = None if not rendering_kw else rendering_kw[0]
- # NOTE: why None? because args might have been preprocessed by
- # _check_arguments, so None might represent the rendering_kw
- results = [not (_is_range(a) or isinstance(a, (str, dict)) or (a is None)) for a in args]
- exprs = [a for a, b in zip(args, results) if b]
- return exprs, ranges, label, rendering_kw
- def _check_arguments(args, nexpr, npar, **kwargs):
- """Checks the arguments and converts into tuples of the
- form (exprs, ranges, label, rendering_kw).
- Parameters
- ==========
- args
- The arguments provided to the plot functions
- nexpr
- The number of sub-expression forming an expression to be plotted.
- For example:
- nexpr=1 for plot.
- nexpr=2 for plot_parametric: a curve is represented by a tuple of two
- elements.
- nexpr=1 for plot3d.
- nexpr=3 for plot3d_parametric_line: a curve is represented by a tuple
- of three elements.
- npar
- The number of free symbols required by the plot functions. For example,
- npar=1 for plot, npar=2 for plot3d, ...
- **kwargs :
- keyword arguments passed to the plotting function. It will be used to
- verify if ``params`` has ben provided.
- Examples
- ========
- .. plot::
- :context: reset
- :format: doctest
- :include-source: True
- >>> from sympy import cos, sin, symbols
- >>> from sympy.plotting.plot import _check_arguments
- >>> x = symbols('x')
- >>> _check_arguments([cos(x), sin(x)], 2, 1)
- [(cos(x), sin(x), (x, -10, 10), None, None)]
- >>> _check_arguments([cos(x), sin(x), "test"], 2, 1)
- [(cos(x), sin(x), (x, -10, 10), 'test', None)]
- >>> _check_arguments([cos(x), sin(x), "test", {"a": 0, "b": 1}], 2, 1)
- [(cos(x), sin(x), (x, -10, 10), 'test', {'a': 0, 'b': 1})]
- >>> _check_arguments([x, x**2], 1, 1)
- [(x, (x, -10, 10), None, None), (x**2, (x, -10, 10), None, None)]
- """
- if not args:
- return []
- output = []
- params = kwargs.get("params", None)
- if all(isinstance(a, (Expr, Relational, BooleanFunction)) for a in args[:nexpr]):
- # In this case, with a single plot command, we are plotting either:
- # 1. one expression
- # 2. multiple expressions over the same range
- exprs, ranges, label, rendering_kw = _unpack_args(*args)
- free_symbols = set().union(*[e.free_symbols for e in exprs])
- ranges = _create_ranges(exprs, ranges, npar, label, params)
- if nexpr > 1:
- # in case of plot_parametric or plot3d_parametric_line, there will
- # be 2 or 3 expressions defining a curve. Group them together.
- if len(exprs) == nexpr:
- exprs = (tuple(exprs),)
- for expr in exprs:
- # need this if-else to deal with both plot/plot3d and
- # plot_parametric/plot3d_parametric_line
- is_expr = isinstance(expr, (Expr, Relational, BooleanFunction))
- e = (expr,) if is_expr else expr
- output.append((*e, *ranges, label, rendering_kw))
- else:
- # In this case, we are plotting multiple expressions, each one with its
- # range. Each "expression" to be plotted has the following form:
- # (expr, range, label) where label is optional
- _, ranges, labels, rendering_kw = _unpack_args(*args)
- labels = [labels] if labels else []
- # number of expressions
- n = (len(ranges) + len(labels) +
- (len(rendering_kw) if rendering_kw is not None else 0))
- new_args = args[:-n] if n > 0 else args
- # at this point, new_args might just be [expr]. But I need it to be
- # [[expr]] in order to be able to loop over
- # [expr, range [opt], label [opt]]
- if not isinstance(new_args[0], (list, tuple, Tuple)):
- new_args = [new_args]
- # Each arg has the form (expr1, expr2, ..., range1 [optional], ...,
- # label [optional], rendering_kw [optional])
- for arg in new_args:
- # look for "local" range and label. If there is not, use "global".
- l = [a for a in arg if isinstance(a, str)]
- if not l:
- l = labels
- r = [a for a in arg if _is_range(a)]
- if not r:
- r = ranges.copy()
- rend_kw = [a for a in arg if isinstance(a, dict)]
- rend_kw = rendering_kw if len(rend_kw) == 0 else rend_kw[0]
- # NOTE: arg = arg[:nexpr] may raise an exception if lambda
- # functions are used. Execute the following instead:
- arg = [arg[i] for i in range(nexpr)]
- free_symbols = set()
- if all(not callable(a) for a in arg):
- free_symbols = free_symbols.union(*[a.free_symbols for a in arg])
- if len(r) != npar:
- r = _create_ranges(arg, r, npar, "", params)
- label = None if not l else l[0]
- output.append((*arg, *r, label, rend_kw))
- return output
|