matplotlib.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. from collections.abc import Callable
  2. from sympy.core.basic import Basic
  3. from sympy.external import import_module
  4. import sympy.plotting.backends.base_backend as base_backend
  5. from sympy.printing.latex import latex
# N.B.
# When changing the minimum module version for matplotlib, please change
# the same in ``SymPyDocTestFinder`` in ``sympy/testing/runtests.py``.
  9. def _str_or_latex(label):
  10. if isinstance(label, Basic):
  11. return latex(label, mode='inline')
  12. return str(label)
  13. def _matplotlib_list(interval_list):
  14. """
  15. Returns lists for matplotlib ``fill`` command from a list of bounding
  16. rectangular intervals
  17. """
  18. xlist = []
  19. ylist = []
  20. if len(interval_list):
  21. for intervals in interval_list:
  22. intervalx = intervals[0]
  23. intervaly = intervals[1]
  24. xlist.extend([intervalx.start, intervalx.start,
  25. intervalx.end, intervalx.end, None])
  26. ylist.extend([intervaly.start, intervaly.end,
  27. intervaly.end, intervaly.start, None])
  28. else:
  29. #XXX Ugly hack. Matplotlib does not accept empty lists for ``fill``
  30. xlist.extend((None, None, None, None))
  31. ylist.extend((None, None, None, None))
  32. return xlist, ylist
# There is no need to check whether matplotlib was imported successfully in
# each method: this backend is only used when matplotlib can be imported.
  35. class MatplotlibBackend(base_backend.Plot):
  36. """ This class implements the functionalities to use Matplotlib with SymPy
  37. plotting functions.
  38. """
  39. def __init__(self, *series, **kwargs):
  40. super().__init__(*series, **kwargs)
  41. self.matplotlib = import_module('matplotlib',
  42. import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']},
  43. min_module_version='1.1.0', catch=(RuntimeError,))
  44. self.plt = self.matplotlib.pyplot
  45. self.cm = self.matplotlib.cm
  46. self.LineCollection = self.matplotlib.collections.LineCollection
  47. self.aspect = kwargs.get('aspect_ratio', 'auto')
  48. if self.aspect != 'auto':
  49. self.aspect = float(self.aspect[1]) / self.aspect[0]
  50. # PlotGrid can provide its figure and axes to be populated with
  51. # the data from the series.
  52. self._plotgrid_fig = kwargs.pop("fig", None)
  53. self._plotgrid_ax = kwargs.pop("ax", None)
  54. def _create_figure(self):
  55. def set_spines(ax):
  56. ax.spines['left'].set_position('zero')
  57. ax.spines['right'].set_color('none')
  58. ax.spines['bottom'].set_position('zero')
  59. ax.spines['top'].set_color('none')
  60. ax.xaxis.set_ticks_position('bottom')
  61. ax.yaxis.set_ticks_position('left')
  62. if self._plotgrid_fig is not None:
  63. self.fig = self._plotgrid_fig
  64. self.ax = self._plotgrid_ax
  65. if not any(s.is_3D for s in self._series):
  66. set_spines(self.ax)
  67. else:
  68. self.fig = self.plt.figure(figsize=self.size)
  69. if any(s.is_3D for s in self._series):
  70. self.ax = self.fig.add_subplot(1, 1, 1, projection="3d")
  71. else:
  72. self.ax = self.fig.add_subplot(1, 1, 1)
  73. set_spines(self.ax)
  74. @staticmethod
  75. def get_segments(x, y, z=None):
  76. """ Convert two list of coordinates to a list of segments to be used
  77. with Matplotlib's :external:class:`~matplotlib.collections.LineCollection`.
  78. Parameters
  79. ==========
  80. x : list
  81. List of x-coordinates
  82. y : list
  83. List of y-coordinates
  84. z : list
  85. List of z-coordinates for a 3D line.
  86. """
  87. np = import_module('numpy')
  88. if z is not None:
  89. dim = 3
  90. points = (x, y, z)
  91. else:
  92. dim = 2
  93. points = (x, y)
  94. points = np.ma.array(points).T.reshape(-1, 1, dim)
  95. return np.ma.concatenate([points[:-1], points[1:]], axis=1)
  96. def _process_series(self, series, ax):
  97. np = import_module('numpy')
  98. mpl_toolkits = import_module(
  99. 'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']})
  100. # XXX Workaround for matplotlib issue
  101. # https://github.com/matplotlib/matplotlib/issues/17130
  102. xlims, ylims, zlims = [], [], []
  103. for s in series:
  104. # Create the collections
  105. if s.is_2Dline:
  106. if s.is_parametric:
  107. x, y, param = s.get_data()
  108. else:
  109. x, y = s.get_data()
  110. if (isinstance(s.line_color, (int, float)) or
  111. callable(s.line_color)):
  112. segments = self.get_segments(x, y)
  113. collection = self.LineCollection(segments)
  114. collection.set_array(s.get_color_array())
  115. ax.add_collection(collection)
  116. else:
  117. lbl = _str_or_latex(s.label)
  118. line, = ax.plot(x, y, label=lbl, color=s.line_color)
  119. elif s.is_contour:
  120. ax.contour(*s.get_data())
  121. elif s.is_3Dline:
  122. x, y, z, param = s.get_data()
  123. if (isinstance(s.line_color, (int, float)) or
  124. callable(s.line_color)):
  125. art3d = mpl_toolkits.mplot3d.art3d
  126. segments = self.get_segments(x, y, z)
  127. collection = art3d.Line3DCollection(segments)
  128. collection.set_array(s.get_color_array())
  129. ax.add_collection(collection)
  130. else:
  131. lbl = _str_or_latex(s.label)
  132. ax.plot(x, y, z, label=lbl, color=s.line_color)
  133. xlims.append(s._xlim)
  134. ylims.append(s._ylim)
  135. zlims.append(s._zlim)
  136. elif s.is_3Dsurface:
  137. if s.is_parametric:
  138. x, y, z, u, v = s.get_data()
  139. else:
  140. x, y, z = s.get_data()
  141. collection = ax.plot_surface(x, y, z,
  142. cmap=getattr(self.cm, 'viridis', self.cm.jet),
  143. rstride=1, cstride=1, linewidth=0.1)
  144. if isinstance(s.surface_color, (float, int, Callable)):
  145. color_array = s.get_color_array()
  146. color_array = color_array.reshape(color_array.size)
  147. collection.set_array(color_array)
  148. else:
  149. collection.set_color(s.surface_color)
  150. xlims.append(s._xlim)
  151. ylims.append(s._ylim)
  152. zlims.append(s._zlim)
  153. elif s.is_implicit:
  154. points = s.get_data()
  155. if len(points) == 2:
  156. # interval math plotting
  157. x, y = _matplotlib_list(points[0])
  158. ax.fill(x, y, facecolor=s.line_color, edgecolor='None')
  159. else:
  160. # use contourf or contour depending on whether it is
  161. # an inequality or equality.
  162. # XXX: ``contour`` plots multiple lines. Should be fixed.
  163. ListedColormap = self.matplotlib.colors.ListedColormap
  164. colormap = ListedColormap(["white", s.line_color])
  165. xarray, yarray, zarray, plot_type = points
  166. if plot_type == 'contour':
  167. ax.contour(xarray, yarray, zarray, cmap=colormap)
  168. else:
  169. ax.contourf(xarray, yarray, zarray, cmap=colormap)
  170. elif s.is_generic:
  171. if s.type == "markers":
  172. # s.rendering_kw["color"] = s.line_color
  173. ax.plot(*s.args, **s.rendering_kw)
  174. elif s.type == "annotations":
  175. ax.annotate(*s.args, **s.rendering_kw)
  176. elif s.type == "fill":
  177. # s.rendering_kw["color"] = s.line_color
  178. ax.fill_between(*s.args, **s.rendering_kw)
  179. elif s.type == "rectangles":
  180. # s.rendering_kw["color"] = s.line_color
  181. ax.add_patch(
  182. self.matplotlib.patches.Rectangle(
  183. *s.args, **s.rendering_kw))
  184. else:
  185. raise NotImplementedError(
  186. '{} is not supported in the SymPy plotting module '
  187. 'with matplotlib backend. Please report this issue.'
  188. .format(ax))
  189. Axes3D = mpl_toolkits.mplot3d.Axes3D
  190. if not isinstance(ax, Axes3D):
  191. ax.autoscale_view(
  192. scalex=ax.get_autoscalex_on(),
  193. scaley=ax.get_autoscaley_on())
  194. else:
  195. # XXX Workaround for matplotlib issue
  196. # https://github.com/matplotlib/matplotlib/issues/17130
  197. if xlims:
  198. xlims = np.array(xlims)
  199. xlim = (np.amin(xlims[:, 0]), np.amax(xlims[:, 1]))
  200. ax.set_xlim(xlim)
  201. else:
  202. ax.set_xlim([0, 1])
  203. if ylims:
  204. ylims = np.array(ylims)
  205. ylim = (np.amin(ylims[:, 0]), np.amax(ylims[:, 1]))
  206. ax.set_ylim(ylim)
  207. else:
  208. ax.set_ylim([0, 1])
  209. if zlims:
  210. zlims = np.array(zlims)
  211. zlim = (np.amin(zlims[:, 0]), np.amax(zlims[:, 1]))
  212. ax.set_zlim(zlim)
  213. else:
  214. ax.set_zlim([0, 1])
  215. # Set global options.
  216. # TODO The 3D stuff
  217. # XXX The order of those is important.
  218. if self.xscale and not isinstance(ax, Axes3D):
  219. ax.set_xscale(self.xscale)
  220. if self.yscale and not isinstance(ax, Axes3D):
  221. ax.set_yscale(self.yscale)
  222. if not isinstance(ax, Axes3D) or self.matplotlib.__version__ >= '1.2.0': # XXX in the distant future remove this check
  223. ax.set_autoscale_on(self.autoscale)
  224. if self.axis_center:
  225. val = self.axis_center
  226. if isinstance(ax, Axes3D):
  227. pass
  228. elif val == 'center':
  229. ax.spines['left'].set_position('center')
  230. ax.spines['bottom'].set_position('center')
  231. elif val == 'auto':
  232. xl, xh = ax.get_xlim()
  233. yl, yh = ax.get_ylim()
  234. pos_left = ('data', 0) if xl*xh <= 0 else 'center'
  235. pos_bottom = ('data', 0) if yl*yh <= 0 else 'center'
  236. ax.spines['left'].set_position(pos_left)
  237. ax.spines['bottom'].set_position(pos_bottom)
  238. else:
  239. ax.spines['left'].set_position(('data', val[0]))
  240. ax.spines['bottom'].set_position(('data', val[1]))
  241. if not self.axis:
  242. ax.set_axis_off()
  243. if self.legend:
  244. if ax.legend():
  245. ax.legend_.set_visible(self.legend)
  246. if self.margin:
  247. ax.set_xmargin(self.margin)
  248. ax.set_ymargin(self.margin)
  249. if self.title:
  250. ax.set_title(self.title)
  251. if self.xlabel:
  252. xlbl = _str_or_latex(self.xlabel)
  253. ax.set_xlabel(xlbl, position=(1, 0))
  254. if self.ylabel:
  255. ylbl = _str_or_latex(self.ylabel)
  256. ax.set_ylabel(ylbl, position=(0, 1))
  257. if isinstance(ax, Axes3D) and self.zlabel:
  258. zlbl = _str_or_latex(self.zlabel)
  259. ax.set_zlabel(zlbl, position=(0, 1))
  260. # xlim and ylim should always be set at last so that plot limits
  261. # doesn't get altered during the process.
  262. if self.xlim:
  263. ax.set_xlim(self.xlim)
  264. if self.ylim:
  265. ax.set_ylim(self.ylim)
  266. self.ax.set_aspect(self.aspect)
  267. def process_series(self):
  268. """
  269. Iterates over every ``Plot`` object and further calls
  270. _process_series()
  271. """
  272. self._create_figure()
  273. self._process_series(self._series, self.ax)
  274. def show(self):
  275. self.process_series()
  276. #TODO after fixing https://github.com/ipython/ipython/issues/1255
  277. # you can uncomment the next line and remove the pyplot.show() call
  278. #self.fig.show()
  279. if base_backend._show:
  280. self.fig.tight_layout()
  281. self.plt.show()
  282. else:
  283. self.close()
  284. def save(self, path):
  285. self.process_series()
  286. self.fig.savefig(path)
  287. def close(self):
  288. self.plt.close(self.fig)