mpl_renderer.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. from __future__ import annotations
  2. import io
  3. from itertools import pairwise
  4. from typing import TYPE_CHECKING, Any, cast
  5. import matplotlib.collections as mcollections
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. from contourpy import FillType, LineType
  9. from contourpy.convert import convert_filled, convert_lines
  10. from contourpy.enum_util import as_fill_type, as_line_type
  11. from contourpy.util.mpl_util import filled_to_mpl_paths, lines_to_mpl_paths
  12. from contourpy.util.renderer import Renderer
  13. if TYPE_CHECKING:
  14. from collections.abc import Sequence
  15. from matplotlib.axes import Axes
  16. from matplotlib.figure import Figure
  17. from numpy.typing import ArrayLike
  18. import contourpy._contourpy as cpy
  19. class MplRenderer(Renderer):
  20. """Utility renderer using Matplotlib to render a grid of plots over the same (x, y) range.
  21. Args:
  22. nrows (int, optional): Number of rows of plots, default ``1``.
  23. ncols (int, optional): Number of columns of plots, default ``1``.
  24. figsize (tuple(float, float), optional): Figure size in inches, default ``(9, 9)``.
  25. show_frame (bool, optional): Whether to show frame and axes ticks, default ``True``.
  26. backend (str, optional): Matplotlib backend to use or ``None`` for default backend.
  27. Default ``None``.
  28. gridspec_kw (dict, optional): Gridspec keyword arguments to pass to ``plt.subplots``,
  29. default None.
  30. """
  31. _axes: Sequence[Axes]
  32. _fig: Figure
  33. _want_tight: bool
  34. def __init__(
  35. self,
  36. nrows: int = 1,
  37. ncols: int = 1,
  38. figsize: tuple[float, float] = (9, 9),
  39. show_frame: bool = True,
  40. backend: str | None = None,
  41. gridspec_kw: dict[str, Any] | None = None,
  42. ) -> None:
  43. if backend is not None:
  44. import matplotlib as mpl
  45. mpl.use(backend)
  46. kwargs: dict[str, Any] = {"figsize": figsize, "squeeze": False,
  47. "sharex": True, "sharey": True}
  48. if gridspec_kw is not None:
  49. kwargs["gridspec_kw"] = gridspec_kw
  50. else:
  51. kwargs["subplot_kw"] = {"aspect": "equal"}
  52. self._fig, axes = plt.subplots(nrows, ncols, **kwargs)
  53. self._axes = axes.flatten()
  54. if not show_frame:
  55. for ax in self._axes:
  56. ax.axis("off")
  57. self._want_tight = True
  58. def __del__(self) -> None:
  59. if hasattr(self, "_fig"):
  60. plt.close(self._fig)
  61. def _autoscale(self) -> None:
  62. # Using axes._need_autoscale attribute if need to autoscale before rendering after adding
  63. # lines/filled. Only want to autoscale once per axes regardless of how many lines/filled
  64. # added.
  65. for ax in self._axes:
  66. if getattr(ax, "_need_autoscale", False):
  67. ax.autoscale_view(tight=True)
  68. ax._need_autoscale = False # type: ignore[attr-defined]
  69. if self._want_tight and len(self._axes) > 1:
  70. self._fig.tight_layout()
  71. def _get_ax(self, ax: Axes | int) -> Axes:
  72. if isinstance(ax, int):
  73. ax = self._axes[ax]
  74. return ax
  75. def filled(
  76. self,
  77. filled: cpy.FillReturn,
  78. fill_type: FillType | str,
  79. ax: Axes | int = 0,
  80. color: str = "C0",
  81. alpha: float = 0.7,
  82. ) -> None:
  83. """Plot filled contours on a single Axes.
  84. Args:
  85. filled (sequence of arrays): Filled contour data as returned by
  86. :meth:`~.ContourGenerator.filled`.
  87. fill_type (FillType or str): Type of :meth:`~.ContourGenerator.filled` data as returned
  88. by :attr:`~.ContourGenerator.fill_type`, or string equivalent
  89. ax (int or Maplotlib Axes, optional): Which axes to plot on, default ``0``.
  90. color (str, optional): Color to plot with. May be a string color or the letter ``"C"``
  91. followed by an integer in the range ``"C0"`` to ``"C9"`` to use a color from the
  92. ``tab10`` colormap. Default ``"C0"``.
  93. alpha (float, optional): Opacity to plot with, default ``0.7``.
  94. """
  95. fill_type = as_fill_type(fill_type)
  96. ax = self._get_ax(ax)
  97. paths = filled_to_mpl_paths(filled, fill_type)
  98. collection = mcollections.PathCollection(
  99. paths, facecolors=color, edgecolors="none", lw=0, alpha=alpha)
  100. ax.add_collection(collection)
  101. ax._need_autoscale = True # type: ignore[attr-defined]
  102. def grid(
  103. self,
  104. x: ArrayLike,
  105. y: ArrayLike,
  106. ax: Axes | int = 0,
  107. color: str = "black",
  108. alpha: float = 0.1,
  109. point_color: str | None = None,
  110. quad_as_tri_alpha: float = 0,
  111. ) -> None:
  112. """Plot quad grid lines on a single Axes.
  113. Args:
  114. x (array-like of shape (ny, nx) or (nx,)): The x-coordinates of the grid points.
  115. y (array-like of shape (ny, nx) or (ny,)): The y-coordinates of the grid points.
  116. ax (int or Matplotlib Axes, optional): Which Axes to plot on, default ``0``.
  117. color (str, optional): Color to plot grid lines, default ``"black"``.
  118. alpha (float, optional): Opacity to plot lines with, default ``0.1``.
  119. point_color (str, optional): Color to plot grid points or ``None`` if grid points
  120. should not be plotted, default ``None``.
  121. quad_as_tri_alpha (float, optional): Opacity to plot ``quad_as_tri`` grid, default 0.
  122. Colors may be a string color or the letter ``"C"`` followed by an integer in the range
  123. ``"C0"`` to ``"C9"`` to use a color from the ``tab10`` colormap.
  124. Warning:
  125. ``quad_as_tri_alpha > 0`` plots all quads as though they are unmasked.
  126. """
  127. ax = self._get_ax(ax)
  128. x, y = self._grid_as_2d(x, y)
  129. kwargs: dict[str, Any] = {"color": color, "alpha": alpha}
  130. ax.plot(x, y, x.T, y.T, **kwargs)
  131. if quad_as_tri_alpha > 0:
  132. # Assumes no quad mask.
  133. xmid = 0.25*(x[:-1, :-1] + x[1:, :-1] + x[:-1, 1:] + x[1:, 1:])
  134. ymid = 0.25*(y[:-1, :-1] + y[1:, :-1] + y[:-1, 1:] + y[1:, 1:])
  135. kwargs["alpha"] = quad_as_tri_alpha
  136. ax.plot(
  137. np.stack((x[:-1, :-1], xmid, x[1:, 1:])).reshape((3, -1)),
  138. np.stack((y[:-1, :-1], ymid, y[1:, 1:])).reshape((3, -1)),
  139. np.stack((x[1:, :-1], xmid, x[:-1, 1:])).reshape((3, -1)),
  140. np.stack((y[1:, :-1], ymid, y[:-1, 1:])).reshape((3, -1)),
  141. **kwargs)
  142. if point_color is not None:
  143. ax.plot(x, y, color=point_color, alpha=alpha, marker="o", lw=0)
  144. ax._need_autoscale = True # type: ignore[attr-defined]
  145. def lines(
  146. self,
  147. lines: cpy.LineReturn,
  148. line_type: LineType | str,
  149. ax: Axes | int = 0,
  150. color: str = "C0",
  151. alpha: float = 1.0,
  152. linewidth: float = 1,
  153. ) -> None:
  154. """Plot contour lines on a single Axes.
  155. Args:
  156. lines (sequence of arrays): Contour line data as returned by
  157. :meth:`~.ContourGenerator.lines`.
  158. line_type (LineType or str): Type of :meth:`~.ContourGenerator.lines` data as returned
  159. by :attr:`~.ContourGenerator.line_type`, or string equivalent.
  160. ax (int or Matplotlib Axes, optional): Which Axes to plot on, default ``0``.
  161. color (str, optional): Color to plot lines. May be a string color or the letter ``"C"``
  162. followed by an integer in the range ``"C0"`` to ``"C9"`` to use a color from the
  163. ``tab10`` colormap. Default ``"C0"``.
  164. alpha (float, optional): Opacity to plot lines with, default ``1.0``.
  165. linewidth (float, optional): Width of lines, default ``1``.
  166. """
  167. line_type = as_line_type(line_type)
  168. ax = self._get_ax(ax)
  169. paths = lines_to_mpl_paths(lines, line_type)
  170. collection = mcollections.PathCollection(
  171. paths, facecolors="none", edgecolors=color, lw=linewidth, alpha=alpha)
  172. ax.add_collection(collection)
  173. ax._need_autoscale = True # type: ignore[attr-defined]
  174. def mask(
  175. self,
  176. x: ArrayLike,
  177. y: ArrayLike,
  178. z: ArrayLike | np.ma.MaskedArray[Any, Any],
  179. ax: Axes | int = 0,
  180. color: str = "black",
  181. ) -> None:
  182. """Plot masked out grid points as circles on a single Axes.
  183. Args:
  184. x (array-like of shape (ny, nx) or (nx,)): The x-coordinates of the grid points.
  185. y (array-like of shape (ny, nx) or (ny,)): The y-coordinates of the grid points.
  186. z (masked array of shape (ny, nx): z-values.
  187. ax (int or Matplotlib Axes, optional): Which Axes to plot on, default ``0``.
  188. color (str, optional): Circle color, default ``"black"``.
  189. """
  190. mask = np.ma.getmask(z)
  191. if mask is np.ma.nomask:
  192. return
  193. ax = self._get_ax(ax)
  194. x, y = self._grid_as_2d(x, y)
  195. ax.plot(x[mask], y[mask], "o", c=color)
  196. def save(self, filename: str, transparent: bool = False) -> None:
  197. """Save plots to SVG or PNG file.
  198. Args:
  199. filename (str): Filename to save to.
  200. transparent (bool, optional): Whether background should be transparent, default
  201. ``False``.
  202. """
  203. self._autoscale()
  204. self._fig.savefig(filename, transparent=transparent)
  205. def save_to_buffer(self) -> io.BytesIO:
  206. """Save plots to an ``io.BytesIO`` buffer.
  207. Return:
  208. BytesIO: PNG image buffer.
  209. """
  210. self._autoscale()
  211. buf = io.BytesIO()
  212. self._fig.savefig(buf, format="png")
  213. buf.seek(0)
  214. return buf
  215. def show(self) -> None:
  216. """Show plots in an interactive window, in the usual Matplotlib manner.
  217. """
  218. self._autoscale()
  219. plt.show()
  220. def title(self, title: str, ax: Axes | int = 0, color: str | None = None) -> None:
  221. """Set the title of a single Axes.
  222. Args:
  223. title (str): Title text.
  224. ax (int or Matplotlib Axes, optional): Which Axes to set the title of, default ``0``.
  225. color (str, optional): Color to set title. May be a string color or the letter ``"C"``
  226. followed by an integer in the range ``"C0"`` to ``"C9"`` to use a color from the
  227. ``tab10`` colormap. Default is ``None`` which uses Matplotlib's default title color
  228. that depends on the stylesheet in use.
  229. """
  230. if color:
  231. self._get_ax(ax).set_title(title, color=color)
  232. else:
  233. self._get_ax(ax).set_title(title)
  234. def z_values(
  235. self,
  236. x: ArrayLike,
  237. y: ArrayLike,
  238. z: ArrayLike,
  239. ax: Axes | int = 0,
  240. color: str = "green",
  241. fmt: str = ".1f",
  242. quad_as_tri: bool = False,
  243. ) -> None:
  244. """Show ``z`` values on a single Axes.
  245. Args:
  246. x (array-like of shape (ny, nx) or (nx,)): The x-coordinates of the grid points.
  247. y (array-like of shape (ny, nx) or (ny,)): The y-coordinates of the grid points.
  248. z (array-like of shape (ny, nx): z-values.
  249. ax (int or Matplotlib Axes, optional): Which Axes to plot on, default ``0``.
  250. color (str, optional): Color of added text. May be a string color or the letter ``"C"``
  251. followed by an integer in the range ``"C0"`` to ``"C9"`` to use a color from the
  252. ``tab10`` colormap. Default ``"green"``.
  253. fmt (str, optional): Format to display z-values, default ``".1f"``.
  254. quad_as_tri (bool, optional): Whether to show z-values at the ``quad_as_tri`` centers
  255. of quads.
  256. Warning:
  257. ``quad_as_tri=True`` shows z-values for all quads, even if masked.
  258. """
  259. ax = self._get_ax(ax)
  260. x, y = self._grid_as_2d(x, y)
  261. z = np.asarray(z)
  262. ny, nx = z.shape
  263. for j in range(ny):
  264. for i in range(nx):
  265. ax.text(x[j, i], y[j, i], f"{z[j, i]:{fmt}}", ha="center", va="center",
  266. color=color, clip_on=True)
  267. if quad_as_tri:
  268. for j in range(ny-1):
  269. for i in range(nx-1):
  270. xx = np.mean(x[j:j+2, i:i+2], dtype=np.float64)
  271. yy = np.mean(y[j:j+2, i:i+2], dtype=np.float64)
  272. zz = np.mean(z[j:j+2, i:i+2])
  273. ax.text(xx, yy, f"{zz:{fmt}}", ha="center", va="center", color=color,
  274. clip_on=True)
  275. class MplTestRenderer(MplRenderer):
  276. """Test renderer implemented using Matplotlib.
  277. No whitespace around plots and no spines/ticks displayed.
  278. Uses Agg backend, so can only save to file/buffer, cannot call ``show()``.
  279. """
  280. def __init__(
  281. self,
  282. nrows: int = 1,
  283. ncols: int = 1,
  284. figsize: tuple[float, float] = (9, 9),
  285. ) -> None:
  286. gridspec = {
  287. "left": 0.01,
  288. "right": 0.99,
  289. "top": 0.99,
  290. "bottom": 0.01,
  291. "wspace": 0.01,
  292. "hspace": 0.01,
  293. }
  294. super().__init__(
  295. nrows, ncols, figsize, show_frame=True, backend="Agg", gridspec_kw=gridspec,
  296. )
  297. for ax in self._axes:
  298. ax.set_xmargin(0.0)
  299. ax.set_ymargin(0.0)
  300. ax.set_xticks([])
  301. ax.set_yticks([])
  302. self._want_tight = False
  303. class MplDebugRenderer(MplRenderer):
  304. """Debug renderer implemented using Matplotlib.
  305. Extends ``MplRenderer`` to add extra information to help in debugging such as markers, arrows,
  306. text, etc.
  307. """
  308. def __init__(
  309. self,
  310. nrows: int = 1,
  311. ncols: int = 1,
  312. figsize: tuple[float, float] = (9, 9),
  313. show_frame: bool = True,
  314. ) -> None:
  315. super().__init__(nrows, ncols, figsize, show_frame)
  316. def _arrow(
  317. self,
  318. ax: Axes,
  319. line_start: cpy.CoordinateArray,
  320. line_end: cpy.CoordinateArray,
  321. color: str,
  322. alpha: float,
  323. arrow_size: float,
  324. ) -> None:
  325. mid = 0.5*(line_start + line_end)
  326. along = line_end - line_start
  327. along /= np.sqrt(np.dot(along, along)) # Unit vector.
  328. right = np.asarray((along[1], -along[0]))
  329. arrow = np.stack((
  330. mid - (along*0.5 - right)*arrow_size,
  331. mid + along*0.5*arrow_size,
  332. mid - (along*0.5 + right)*arrow_size,
  333. ))
  334. ax.plot(arrow[:, 0], arrow[:, 1], "-", c=color, alpha=alpha)
  335. def filled(
  336. self,
  337. filled: cpy.FillReturn,
  338. fill_type: FillType | str,
  339. ax: Axes | int = 0,
  340. color: str = "C1",
  341. alpha: float = 0.7,
  342. line_color: str = "C0",
  343. line_alpha: float = 0.7,
  344. point_color: str = "C0",
  345. start_point_color: str = "red",
  346. arrow_size: float = 0.1,
  347. ) -> None:
  348. fill_type = as_fill_type(fill_type)
  349. super().filled(filled, fill_type, ax, color, alpha)
  350. if line_color is None and point_color is None:
  351. return
  352. ax = self._get_ax(ax)
  353. filled = convert_filled(filled, fill_type, FillType.ChunkCombinedOffset)
  354. # Lines.
  355. if line_color is not None:
  356. for points, offsets in zip(*filled):
  357. if points is None:
  358. continue
  359. for start, end in pairwise(offsets):
  360. xys = points[start:end]
  361. ax.plot(xys[:, 0], xys[:, 1], c=line_color, alpha=line_alpha)
  362. if arrow_size > 0.0:
  363. n = len(xys)
  364. for i in range(n-1):
  365. self._arrow(ax, xys[i], xys[i+1], line_color, line_alpha, arrow_size)
  366. # Points.
  367. if point_color is not None:
  368. for points, offsets in zip(*filled):
  369. if points is None:
  370. continue
  371. mask = np.ones(offsets[-1], dtype=bool)
  372. mask[offsets[1:]-1] = False # Exclude end points.
  373. if start_point_color is not None:
  374. start_indices = offsets[:-1]
  375. mask[start_indices] = False # Exclude start points.
  376. ax.plot(
  377. points[:, 0][mask], points[:, 1][mask], "o", c=point_color, alpha=line_alpha)
  378. if start_point_color is not None:
  379. ax.plot(points[:, 0][start_indices], points[:, 1][start_indices], "o",
  380. c=start_point_color, alpha=line_alpha)
  381. def lines(
  382. self,
  383. lines: cpy.LineReturn,
  384. line_type: LineType | str,
  385. ax: Axes | int = 0,
  386. color: str = "C0",
  387. alpha: float = 1.0,
  388. linewidth: float = 1,
  389. point_color: str = "C0",
  390. start_point_color: str = "red",
  391. arrow_size: float = 0.1,
  392. ) -> None:
  393. line_type = as_line_type(line_type)
  394. super().lines(lines, line_type, ax, color, alpha, linewidth)
  395. if arrow_size == 0.0 and point_color is None:
  396. return
  397. ax = self._get_ax(ax)
  398. separate_lines = convert_lines(lines, line_type, LineType.Separate)
  399. if TYPE_CHECKING:
  400. separate_lines = cast(cpy.LineReturn_Separate, separate_lines)
  401. if arrow_size > 0.0:
  402. for line in separate_lines:
  403. for i in range(len(line)-1):
  404. self._arrow(ax, line[i], line[i+1], color, alpha, arrow_size)
  405. if point_color is not None:
  406. for line in separate_lines:
  407. start_index = 0
  408. end_index = len(line)
  409. if start_point_color is not None:
  410. ax.plot(line[0, 0], line[0, 1], "o", c=start_point_color, alpha=alpha)
  411. start_index = 1
  412. if line[0][0] == line[-1][0] and line[0][1] == line[-1][1]:
  413. end_index -= 1
  414. ax.plot(line[start_index:end_index, 0], line[start_index:end_index, 1], "o",
  415. c=color, alpha=alpha)
  416. def point_numbers(
  417. self,
  418. x: ArrayLike,
  419. y: ArrayLike,
  420. z: ArrayLike,
  421. ax: Axes | int = 0,
  422. color: str = "red",
  423. ) -> None:
  424. ax = self._get_ax(ax)
  425. x, y = self._grid_as_2d(x, y)
  426. z = np.asarray(z)
  427. ny, nx = z.shape
  428. for j in range(ny):
  429. for i in range(nx):
  430. quad = i + j*nx
  431. ax.text(x[j, i], y[j, i], str(quad), ha="right", va="top", color=color,
  432. clip_on=True)
  433. def quad_numbers(
  434. self,
  435. x: ArrayLike,
  436. y: ArrayLike,
  437. z: ArrayLike,
  438. ax: Axes | int = 0,
  439. color: str = "blue",
  440. ) -> None:
  441. ax = self._get_ax(ax)
  442. x, y = self._grid_as_2d(x, y)
  443. z = np.asarray(z)
  444. ny, nx = z.shape
  445. for j in range(1, ny):
  446. for i in range(1, nx):
  447. quad = i + j*nx
  448. xmid = x[j-1:j+1, i-1:i+1].mean()
  449. ymid = y[j-1:j+1, i-1:i+1].mean()
  450. ax.text(xmid, ymid, str(quad), ha="center", va="center", color=color, clip_on=True)
  451. def z_levels(
  452. self,
  453. x: ArrayLike,
  454. y: ArrayLike,
  455. z: ArrayLike,
  456. lower_level: float,
  457. upper_level: float | None = None,
  458. ax: Axes | int = 0,
  459. color: str = "green",
  460. ) -> None:
  461. ax = self._get_ax(ax)
  462. x, y = self._grid_as_2d(x, y)
  463. z = np.asarray(z)
  464. ny, nx = z.shape
  465. for j in range(ny):
  466. for i in range(nx):
  467. zz = z[j, i]
  468. if upper_level is not None and zz > upper_level:
  469. z_level = 2
  470. elif zz > lower_level:
  471. z_level = 1
  472. else:
  473. z_level = 0
  474. ax.text(x[j, i], y[j, i], str(z_level), ha="left", va="bottom", color=color,
  475. clip_on=True)