| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318 |
- from collections.abc import Callable
- from sympy.core.basic import Basic
- from sympy.external import import_module
- import sympy.plotting.backends.base_backend as base_backend
- from sympy.printing.latex import latex
# N.B.
# When changing the minimum module version for matplotlib, please change
# the same in the ``SymPyDocTestFinder`` in ``sympy/testing/runtests.py``.
def _str_or_latex(label):
    """Return ``label`` as text usable as a Matplotlib label.

    SymPy objects (instances of ``Basic``) are rendered as inline LaTeX via
    ``latex(..., mode='inline')``; any other value is converted with ``str``.
    """
    return latex(label, mode='inline') if isinstance(label, Basic) else str(label)
def _matplotlib_list(interval_list):
    """
    Build the coordinate lists for Matplotlib's ``fill`` command from a list
    of bounding rectangular intervals.

    Each entry of ``interval_list`` is indexable as ``(interval_x,
    interval_y)``, and both intervals expose ``start`` and ``end``. Every
    entry contributes the four corners of its rectangle followed by ``None``
    (presumably so Matplotlib treats each rectangle as a separate polygon).

    Returns
    =======

    (xlist, ylist) : tuple of two lists
    """
    if len(interval_list) == 0:
        # XXX Ugly hack. Matplotlib does not accept empty lists for ``fill``
        return [None] * 4, [None] * 4

    xs, ys = [], []
    for pair in interval_list:
        ix, iy = pair[0], pair[1]
        xs += [ix.start, ix.start, ix.end, ix.end, None]
        ys += [iy.start, iy.end, iy.end, iy.start, None]
    return xs, ys
# Don't have to check for the success of importing matplotlib in each case;
# we will only be using this backend if we can successfully import matplotlib.
class MatplotlibBackend(base_backend.Plot):
    """ This class implements the functionalities to use Matplotlib with SymPy
    plotting functions.

    The figure and axes are created by :meth:`process_series` (called from
    :meth:`show` and :meth:`save`). When a ``PlotGrid`` supplies a figure and
    axes through the ``fig`` / ``ax`` keyword arguments, those are populated
    instead of a new figure.
    """

    def __init__(self, *series, **kwargs):
        super().__init__(*series, **kwargs)
        self.matplotlib = import_module('matplotlib',
            import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']},
            min_module_version='1.1.0', catch=(RuntimeError,))
        self.plt = self.matplotlib.pyplot
        self.cm = self.matplotlib.cm
        self.LineCollection = self.matplotlib.collections.LineCollection
        self.aspect = self._normalize_aspect(
            kwargs.get('aspect_ratio', 'auto'))

        # PlotGrid can provide its figure and axes to be populated with
        # the data from the series.
        self._plotgrid_fig = kwargs.pop("fig", None)
        self._plotgrid_ax = kwargs.pop("ax", None)

    @staticmethod
    def _normalize_aspect(aspect):
        """Convert the ``aspect_ratio`` option into a value for
        Matplotlib's ``Axes.set_aspect``.

        * a string (the default ``'auto'``, or e.g. ``'equal'``) is passed
          through unchanged;
        * a plain number is used as-is (converted to ``float``);
        * a pair ``(a, b)`` is converted to ``b / a``.
        """
        if isinstance(aspect, str):
            return aspect
        if isinstance(aspect, (int, float)):
            return float(aspect)
        return float(aspect[1]) / aspect[0]

    def _create_figure(self):
        """Set ``self.fig`` and ``self.ax``.

        Reuse the figure/axes handed over by ``PlotGrid`` if present,
        otherwise create a new single-subplot figure (with a 3D projection
        when any series is 3D). For 2D plots the spines are moved so the
        axes cross at the origin.
        """
        def set_spines(ax):
            ax.spines['left'].set_position('zero')
            ax.spines['right'].set_color('none')
            ax.spines['bottom'].set_position('zero')
            ax.spines['top'].set_color('none')
            ax.xaxis.set_ticks_position('bottom')
            ax.yaxis.set_ticks_position('left')

        is_3D = any(s.is_3D for s in self._series)
        if self._plotgrid_fig is not None:
            self.fig = self._plotgrid_fig
            self.ax = self._plotgrid_ax
        else:
            self.fig = self.plt.figure(figsize=self.size)
            subplot_kw = {"projection": "3d"} if is_3D else {}
            self.ax = self.fig.add_subplot(1, 1, 1, **subplot_kw)
        if not is_3D:
            set_spines(self.ax)

    @staticmethod
    def get_segments(x, y, z=None):
        """ Convert two list of coordinates to a list of segments to be used
        with Matplotlib's :external:class:`~matplotlib.collections.LineCollection`.

        Parameters
        ==========

        x : list
            List of x-coordinates

        y : list
            List of y-coordinates

        z : list
            List of z-coordinates for a 3D line.

        Returns
        =======

        A NumPy masked array of shape ``(n - 1, 2, dim)`` (``dim`` is 2 or 3)
        whose entries pair each point with the next one.
        """
        np = import_module('numpy')
        if z is not None:
            dim = 3
            points = (x, y, z)
        else:
            dim = 2
            points = (x, y)
        points = np.ma.array(points).T.reshape(-1, 1, dim)
        return np.ma.concatenate([points[:-1], points[1:]], axis=1)

    @staticmethod
    def _uses_colormap(color):
        """Return True when ``color`` is a number or a callable.

        In that case the series is coloured through ``get_color_array()`` /
        ``set_array`` on a collection instead of with a single fixed colour.
        """
        return isinstance(color, (int, float)) or callable(color)

    @staticmethod
    def _set_3D_limits(ax, xlims, ylims, zlims):
        """Set the x, y and z limits of a 3D axes.

        Each axis spans the union of the ``(min, max)`` pairs collected from
        the series, or ``[0, 1]`` when no series reported limits for it.

        XXX Workaround for matplotlib issue
        https://github.com/matplotlib/matplotlib/issues/17130
        """
        np = import_module('numpy')
        for set_lim, lims in ((ax.set_xlim, xlims),
                              (ax.set_ylim, ylims),
                              (ax.set_zlim, zlims)):
            if lims:
                lims = np.array(lims)
                set_lim((np.amin(lims[:, 0]), np.amax(lims[:, 1])))
            else:
                set_lim([0, 1])

    def _process_series(self, series, ax):
        """Draw every series of ``series`` on ``ax`` and apply the global
        plot options (scales, axis center, legend, margins, title, labels,
        limits and aspect ratio).

        Parameters
        ==========

        series : iterable
            The data series to draw.

        ax : matplotlib axes
            The axes receiving the artists. All options are applied to this
            axes, which need not be ``self.ax``.

        Raises
        ======

        NotImplementedError
            If a series is of a kind this backend cannot draw.
        """
        mpl_toolkits = import_module(
            'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']})

        # XXX Workaround for matplotlib issue
        # https://github.com/matplotlib/matplotlib/issues/17130
        xlims, ylims, zlims = [], [], []

        for s in series:
            # Create the collections
            if s.is_2Dline:
                if s.is_parametric:
                    x, y, _ = s.get_data()
                else:
                    x, y = s.get_data()
                if self._uses_colormap(s.line_color):
                    segments = self.get_segments(x, y)
                    collection = self.LineCollection(segments)
                    collection.set_array(s.get_color_array())
                    ax.add_collection(collection)
                else:
                    lbl = _str_or_latex(s.label)
                    ax.plot(x, y, label=lbl, color=s.line_color)
            elif s.is_contour:
                ax.contour(*s.get_data())
            elif s.is_3Dline:
                x, y, z, _ = s.get_data()
                if self._uses_colormap(s.line_color):
                    art3d = mpl_toolkits.mplot3d.art3d
                    segments = self.get_segments(x, y, z)
                    collection = art3d.Line3DCollection(segments)
                    collection.set_array(s.get_color_array())
                    ax.add_collection(collection)
                else:
                    lbl = _str_or_latex(s.label)
                    ax.plot(x, y, z, label=lbl, color=s.line_color)

                xlims.append(s._xlim)
                ylims.append(s._ylim)
                zlims.append(s._zlim)
            elif s.is_3Dsurface:
                if s.is_parametric:
                    x, y, z, _, _ = s.get_data()
                else:
                    x, y, z = s.get_data()
                collection = ax.plot_surface(x, y, z,
                    cmap=getattr(self.cm, 'viridis', self.cm.jet),
                    rstride=1, cstride=1, linewidth=0.1)
                if self._uses_colormap(s.surface_color):
                    color_array = s.get_color_array()
                    color_array = color_array.reshape(color_array.size)
                    collection.set_array(color_array)
                else:
                    collection.set_color(s.surface_color)

                xlims.append(s._xlim)
                ylims.append(s._ylim)
                zlims.append(s._zlim)
            elif s.is_implicit:
                points = s.get_data()
                if len(points) == 2:
                    # interval math plotting
                    x, y = _matplotlib_list(points[0])
                    ax.fill(x, y, facecolor=s.line_color, edgecolor='None')
                else:
                    # use contourf or contour depending on whether it is
                    # an inequality or equality.
                    # XXX: ``contour`` plots multiple lines. Should be fixed.
                    ListedColormap = self.matplotlib.colors.ListedColormap
                    colormap = ListedColormap(["white", s.line_color])
                    xarray, yarray, zarray, plot_type = points
                    if plot_type == 'contour':
                        ax.contour(xarray, yarray, zarray, cmap=colormap)
                    else:
                        ax.contourf(xarray, yarray, zarray, cmap=colormap)
            elif s.is_generic:
                if s.type == "markers":
                    # s.rendering_kw["color"] = s.line_color
                    ax.plot(*s.args, **s.rendering_kw)
                elif s.type == "annotations":
                    ax.annotate(*s.args, **s.rendering_kw)
                elif s.type == "fill":
                    # s.rendering_kw["color"] = s.line_color
                    ax.fill_between(*s.args, **s.rendering_kw)
                elif s.type == "rectangles":
                    # s.rendering_kw["color"] = s.line_color
                    ax.add_patch(
                        self.matplotlib.patches.Rectangle(
                            *s.args, **s.rendering_kw))
            else:
                # Report the offending series (previously the axes object
                # was formatted into this message by mistake).
                raise NotImplementedError(
                    '{} is not supported in the SymPy plotting module '
                    'with matplotlib backend. Please report this issue.'
                    .format(s))

        Axes3D = mpl_toolkits.mplot3d.Axes3D
        is_3D_axes = isinstance(ax, Axes3D)
        if not is_3D_axes:
            ax.autoscale_view(
                scalex=ax.get_autoscalex_on(),
                scaley=ax.get_autoscaley_on())
        else:
            self._set_3D_limits(ax, xlims, ylims, zlims)

        # Set global options.
        # TODO The 3D stuff
        # XXX The order of those is important.
        if self.xscale and not is_3D_axes:
            ax.set_xscale(self.xscale)
        if self.yscale and not is_3D_axes:
            ax.set_yscale(self.yscale)
        # XXX in the distant future remove this check. Note: this is a plain
        # (lexicographic) string comparison of version numbers.
        if not is_3D_axes or self.matplotlib.__version__ >= '1.2.0':
            ax.set_autoscale_on(self.autoscale)
        if self.axis_center:
            val = self.axis_center
            if is_3D_axes:
                pass
            elif val == 'center':
                ax.spines['left'].set_position('center')
                ax.spines['bottom'].set_position('center')
            elif val == 'auto':
                # Put a spine through zero when zero lies inside the current
                # limits, otherwise center it.
                xl, xh = ax.get_xlim()
                yl, yh = ax.get_ylim()
                pos_left = ('data', 0) if xl*xh <= 0 else 'center'
                pos_bottom = ('data', 0) if yl*yh <= 0 else 'center'
                ax.spines['left'].set_position(pos_left)
                ax.spines['bottom'].set_position(pos_bottom)
            else:
                ax.spines['left'].set_position(('data', val[0]))
                ax.spines['bottom'].set_position(('data', val[1]))
        if not self.axis:
            ax.set_axis_off()
        if self.legend:
            if ax.legend():
                ax.legend_.set_visible(self.legend)
        if self.margin:
            ax.set_xmargin(self.margin)
            ax.set_ymargin(self.margin)
        if self.title:
            ax.set_title(self.title)
        if self.xlabel:
            xlbl = _str_or_latex(self.xlabel)
            ax.set_xlabel(xlbl, position=(1, 0))
        if self.ylabel:
            ylbl = _str_or_latex(self.ylabel)
            ax.set_ylabel(ylbl, position=(0, 1))
        if is_3D_axes and self.zlabel:
            zlbl = _str_or_latex(self.zlabel)
            ax.set_zlabel(zlbl, position=(0, 1))

        # xlim and ylim should always be set at last so that plot limits
        # doesn't get altered during the process.
        if self.xlim:
            ax.set_xlim(self.xlim)
        if self.ylim:
            ax.set_ylim(self.ylim)
        # Apply the aspect to the axes being processed (not self.ax), like
        # every other option above.
        ax.set_aspect(self.aspect)

    def process_series(self):
        """
        Create the figure and axes, then draw all series of this plot on
        them via ``_process_series()``.
        """
        self._create_figure()
        self._process_series(self._series, self.ax)

    def show(self):
        """Render the plot and display it with pyplot when showing is
        enabled (``base_backend._show``); otherwise close the figure."""
        self.process_series()
        #TODO after fixing https://github.com/ipython/ipython/issues/1255
        # you can uncomment the next line and remove the pyplot.show() call
        #self.fig.show()
        if base_backend._show:
            self.fig.tight_layout()
            self.plt.show()
        else:
            self.close()

    def save(self, path):
        """Render the plot and write the figure to ``path`` with
        ``Figure.savefig``."""
        self.process_series()
        self.fig.savefig(path)

    def close(self):
        """Close this plot's figure through pyplot."""
        self.plt.close(self.fig)
|