plotgrid.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. from sympy.external import import_module
  2. import sympy.plotting.backends.base_backend as base_backend
  3. # N.B.
  4. # When changing the minimum module version for matplotlib, please change
  5. # the same in the `SymPyDocTestFinder`` in `sympy/testing/runtests.py`
  6. __doctest_requires__ = {
  7. ("PlotGrid",): ["matplotlib"],
  8. }
  9. class PlotGrid:
  10. """This class helps to plot subplots from already created SymPy plots
  11. in a single figure.
  12. Examples
  13. ========
  14. .. plot::
  15. :context: close-figs
  16. :format: doctest
  17. :include-source: True
  18. >>> from sympy import symbols
  19. >>> from sympy.plotting import plot, plot3d, PlotGrid
  20. >>> x, y = symbols('x, y')
  21. >>> p1 = plot(x, x**2, x**3, (x, -5, 5))
  22. >>> p2 = plot((x**2, (x, -6, 6)), (x, (x, -5, 5)))
  23. >>> p3 = plot(x**3, (x, -5, 5))
  24. >>> p4 = plot3d(x*y, (x, -5, 5), (y, -5, 5))
  25. Plotting vertically in a single line:
  26. .. plot::
  27. :context: close-figs
  28. :format: doctest
  29. :include-source: True
  30. >>> PlotGrid(2, 1, p1, p2)
  31. PlotGrid object containing:
  32. Plot[0]:Plot object containing:
  33. [0]: cartesian line: x for x over (-5.0, 5.0)
  34. [1]: cartesian line: x**2 for x over (-5.0, 5.0)
  35. [2]: cartesian line: x**3 for x over (-5.0, 5.0)
  36. Plot[1]:Plot object containing:
  37. [0]: cartesian line: x**2 for x over (-6.0, 6.0)
  38. [1]: cartesian line: x for x over (-5.0, 5.0)
  39. Plotting horizontally in a single line:
  40. .. plot::
  41. :context: close-figs
  42. :format: doctest
  43. :include-source: True
  44. >>> PlotGrid(1, 3, p2, p3, p4)
  45. PlotGrid object containing:
  46. Plot[0]:Plot object containing:
  47. [0]: cartesian line: x**2 for x over (-6.0, 6.0)
  48. [1]: cartesian line: x for x over (-5.0, 5.0)
  49. Plot[1]:Plot object containing:
  50. [0]: cartesian line: x**3 for x over (-5.0, 5.0)
  51. Plot[2]:Plot object containing:
  52. [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0)
  53. Plotting in a grid form:
  54. .. plot::
  55. :context: close-figs
  56. :format: doctest
  57. :include-source: True
  58. >>> PlotGrid(2, 2, p1, p2, p3, p4)
  59. PlotGrid object containing:
  60. Plot[0]:Plot object containing:
  61. [0]: cartesian line: x for x over (-5.0, 5.0)
  62. [1]: cartesian line: x**2 for x over (-5.0, 5.0)
  63. [2]: cartesian line: x**3 for x over (-5.0, 5.0)
  64. Plot[1]:Plot object containing:
  65. [0]: cartesian line: x**2 for x over (-6.0, 6.0)
  66. [1]: cartesian line: x for x over (-5.0, 5.0)
  67. Plot[2]:Plot object containing:
  68. [0]: cartesian line: x**3 for x over (-5.0, 5.0)
  69. Plot[3]:Plot object containing:
  70. [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0)
  71. """
  72. def __init__(self, nrows, ncolumns, *args, show=True, size=None, **kwargs):
  73. """
  74. Parameters
  75. ==========
  76. nrows :
  77. The number of rows that should be in the grid of the
  78. required subplot.
  79. ncolumns :
  80. The number of columns that should be in the grid
  81. of the required subplot.
  82. nrows and ncolumns together define the required grid.
  83. Arguments
  84. =========
  85. A list of predefined plot objects entered in a row-wise sequence
  86. i.e. plot objects which are to be in the top row of the required
  87. grid are written first, then the second row objects and so on
  88. Keyword arguments
  89. =================
  90. show : Boolean
  91. The default value is set to ``True``. Set show to ``False`` and
  92. the function will not display the subplot. The returned instance
  93. of the ``PlotGrid`` class can then be used to save or display the
  94. plot by calling the ``save()`` and ``show()`` methods
  95. respectively.
  96. size : (float, float), optional
  97. A tuple in the form (width, height) in inches to specify the size of
  98. the overall figure. The default value is set to ``None``, meaning
  99. the size will be set by the default backend.
  100. """
  101. self.matplotlib = import_module('matplotlib',
  102. import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']},
  103. min_module_version='1.1.0', catch=(RuntimeError,))
  104. self.nrows = nrows
  105. self.ncolumns = ncolumns
  106. self._series = []
  107. self._fig = None
  108. self.args = args
  109. for arg in args:
  110. self._series.append(arg._series)
  111. self.size = size
  112. if show and self.matplotlib:
  113. self.show()
  114. def _create_figure(self):
  115. gs = self.matplotlib.gridspec.GridSpec(self.nrows, self.ncolumns)
  116. mapping = {}
  117. c = 0
  118. for i in range(self.nrows):
  119. for j in range(self.ncolumns):
  120. if c < len(self.args):
  121. mapping[gs[i, j]] = self.args[c]
  122. c += 1
  123. kw = {} if not self.size else {"figsize": self.size}
  124. self._fig = self.matplotlib.pyplot.figure(**kw)
  125. for spec, p in mapping.items():
  126. kw = ({"projection": "3d"} if (len(p._series) > 0 and
  127. p._series[0].is_3D) else {})
  128. cur_ax = self._fig.add_subplot(spec, **kw)
  129. p._plotgrid_fig = self._fig
  130. p._plotgrid_ax = cur_ax
  131. p.process_series()
  132. @property
  133. def fig(self):
  134. if not self._fig:
  135. self._create_figure()
  136. return self._fig
  137. @property
  138. def _backend(self):
  139. return self
  140. def close(self):
  141. self.matplotlib.pyplot.close(self.fig)
  142. def show(self):
  143. if base_backend._show:
  144. self.fig.tight_layout()
  145. self.matplotlib.pyplot.show()
  146. else:
  147. self.close()
  148. def save(self, path):
  149. self.fig.savefig(path)
  150. def __str__(self):
  151. plot_strs = [('Plot[%d]:' % i) + str(plot)
  152. for i, plot in enumerate(self.args)]
  153. return 'PlotGrid object containing:\n' + '\n'.join(plot_strs)