axes_grid.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. from numbers import Number
  2. import functools
  3. from types import MethodType
  4. import numpy as np
  5. from matplotlib import _api, cbook
  6. from matplotlib.gridspec import SubplotSpec
  7. from .axes_divider import Size, SubplotDivider, Divider
  8. from .mpl_axes import Axes, SimpleAxisArtist
  9. class CbarAxesBase:
  10. def __init__(self, *args, orientation, **kwargs):
  11. self.orientation = orientation
  12. super().__init__(*args, **kwargs)
  13. def colorbar(self, mappable, **kwargs):
  14. return self.get_figure(root=False).colorbar(
  15. mappable, cax=self, location=self.orientation, **kwargs)
  16. _cbaraxes_class_factory = cbook._make_class_factory(CbarAxesBase, "Cbar{}")
  17. class Grid:
  18. """
  19. A grid of Axes.
  20. In Matplotlib, the Axes location (and size) is specified in normalized
  21. figure coordinates. This may not be ideal for images that needs to be
  22. displayed with a given aspect ratio; for example, it is difficult to
  23. display multiple images of a same size with some fixed padding between
  24. them. AxesGrid can be used in such case.
  25. Attributes
  26. ----------
  27. axes_all : list of Axes
  28. A flat list of Axes. Note that you can also access this directly
  29. from the grid. The following is equivalent ::
  30. grid[i] == grid.axes_all[i]
  31. len(grid) == len(grid.axes_all)
  32. axes_column : list of list of Axes
  33. A 2D list of Axes where the first index is the column. This results
  34. in the usage pattern ``grid.axes_column[col][row]``.
  35. axes_row : list of list of Axes
  36. A 2D list of Axes where the first index is the row. This results
  37. in the usage pattern ``grid.axes_row[row][col]``.
  38. axes_llc : Axes
  39. The Axes in the lower left corner.
  40. ngrids : int
  41. Number of Axes in the grid.
  42. """
  43. _defaultAxesClass = Axes
  44. def __init__(self, fig,
  45. rect,
  46. nrows_ncols,
  47. ngrids=None,
  48. direction="row",
  49. axes_pad=0.02,
  50. *,
  51. share_all=False,
  52. share_x=True,
  53. share_y=True,
  54. label_mode="L",
  55. axes_class=None,
  56. aspect=False,
  57. ):
  58. """
  59. Parameters
  60. ----------
  61. fig : `.Figure`
  62. The parent figure.
  63. rect : (float, float, float, float), (int, int, int), int, or \
  64. `~.SubplotSpec`
  65. The axes position, as a ``(left, bottom, width, height)`` tuple,
  66. as a three-digit subplot position code (e.g., ``(1, 2, 1)`` or
  67. ``121``), or as a `~.SubplotSpec`.
  68. nrows_ncols : (int, int)
  69. Number of rows and columns in the grid.
  70. ngrids : int or None, default: None
  71. If not None, only the first *ngrids* axes in the grid are created.
  72. direction : {"row", "column"}, default: "row"
  73. Whether axes are created in row-major ("row by row") or
  74. column-major order ("column by column"). This also affects the
  75. order in which axes are accessed using indexing (``grid[index]``).
  76. axes_pad : float or (float, float), default: 0.02
  77. Padding or (horizontal padding, vertical padding) between axes, in
  78. inches.
  79. share_all : bool, default: False
  80. Whether all axes share their x- and y-axis. Overrides *share_x*
  81. and *share_y*.
  82. share_x : bool, default: True
  83. Whether all axes of a column share their x-axis.
  84. share_y : bool, default: True
  85. Whether all axes of a row share their y-axis.
  86. label_mode : {"L", "1", "all", "keep"}, default: "L"
  87. Determines which axes will get tick labels:
  88. - "L": All axes on the left column get vertical tick labels;
  89. all axes on the bottom row get horizontal tick labels.
  90. - "1": Only the bottom left axes is labelled.
  91. - "all": All axes are labelled.
  92. - "keep": Do not do anything.
  93. axes_class : subclass of `matplotlib.axes.Axes`, default: `.mpl_axes.Axes`
  94. The type of Axes to create.
  95. aspect : bool, default: False
  96. Whether the axes aspect ratio follows the aspect ratio of the data
  97. limits.
  98. """
  99. self._nrows, self._ncols = nrows_ncols
  100. if ngrids is None:
  101. ngrids = self._nrows * self._ncols
  102. else:
  103. if not 0 < ngrids <= self._nrows * self._ncols:
  104. raise ValueError(
  105. "ngrids must be positive and not larger than nrows*ncols")
  106. self.ngrids = ngrids
  107. self._horiz_pad_size, self._vert_pad_size = map(
  108. Size.Fixed, np.broadcast_to(axes_pad, 2))
  109. _api.check_in_list(["column", "row"], direction=direction)
  110. self._direction = direction
  111. if axes_class is None:
  112. axes_class = self._defaultAxesClass
  113. elif isinstance(axes_class, (list, tuple)):
  114. cls, kwargs = axes_class
  115. axes_class = functools.partial(cls, **kwargs)
  116. kw = dict(horizontal=[], vertical=[], aspect=aspect)
  117. if isinstance(rect, (Number, SubplotSpec)):
  118. self._divider = SubplotDivider(fig, rect, **kw)
  119. elif len(rect) == 3:
  120. self._divider = SubplotDivider(fig, *rect, **kw)
  121. elif len(rect) == 4:
  122. self._divider = Divider(fig, rect, **kw)
  123. else:
  124. raise TypeError("Incorrect rect format")
  125. rect = self._divider.get_position()
  126. axes_array = np.full((self._nrows, self._ncols), None, dtype=object)
  127. for i in range(self.ngrids):
  128. col, row = self._get_col_row(i)
  129. if share_all:
  130. sharex = sharey = axes_array[0, 0]
  131. else:
  132. sharex = axes_array[0, col] if share_x else None
  133. sharey = axes_array[row, 0] if share_y else None
  134. axes_array[row, col] = axes_class(
  135. fig, rect, sharex=sharex, sharey=sharey)
  136. self.axes_all = axes_array.ravel(
  137. order="C" if self._direction == "row" else "F").tolist()
  138. self.axes_column = axes_array.T.tolist()
  139. self.axes_row = axes_array.tolist()
  140. self.axes_llc = self.axes_column[0][-1]
  141. self._init_locators()
  142. for ax in self.axes_all:
  143. fig.add_axes(ax)
  144. self.set_label_mode(label_mode)
  145. def _init_locators(self):
  146. self._divider.set_horizontal(
  147. [Size.Scaled(1), self._horiz_pad_size] * (self._ncols-1) + [Size.Scaled(1)])
  148. self._divider.set_vertical(
  149. [Size.Scaled(1), self._vert_pad_size] * (self._nrows-1) + [Size.Scaled(1)])
  150. for i in range(self.ngrids):
  151. col, row = self._get_col_row(i)
  152. self.axes_all[i].set_axes_locator(
  153. self._divider.new_locator(nx=2 * col, ny=2 * (self._nrows - 1 - row)))
  154. def _get_col_row(self, n):
  155. if self._direction == "column":
  156. col, row = divmod(n, self._nrows)
  157. else:
  158. row, col = divmod(n, self._ncols)
  159. return col, row
  160. # Good to propagate __len__ if we have __getitem__
  161. def __len__(self):
  162. return len(self.axes_all)
  163. def __getitem__(self, i):
  164. return self.axes_all[i]
  165. def get_geometry(self):
  166. """
  167. Return the number of rows and columns of the grid as (nrows, ncols).
  168. """
  169. return self._nrows, self._ncols
  170. def set_axes_pad(self, axes_pad):
  171. """
  172. Set the padding between the axes.
  173. Parameters
  174. ----------
  175. axes_pad : (float, float)
  176. The padding (horizontal pad, vertical pad) in inches.
  177. """
  178. self._horiz_pad_size.fixed_size = axes_pad[0]
  179. self._vert_pad_size.fixed_size = axes_pad[1]
  180. def get_axes_pad(self):
  181. """
  182. Return the axes padding.
  183. Returns
  184. -------
  185. hpad, vpad
  186. Padding (horizontal pad, vertical pad) in inches.
  187. """
  188. return (self._horiz_pad_size.fixed_size,
  189. self._vert_pad_size.fixed_size)
  190. def set_aspect(self, aspect):
  191. """Set the aspect of the SubplotDivider."""
  192. self._divider.set_aspect(aspect)
  193. def get_aspect(self):
  194. """Return the aspect of the SubplotDivider."""
  195. return self._divider.get_aspect()
  196. def set_label_mode(self, mode):
  197. """
  198. Define which axes have tick labels.
  199. Parameters
  200. ----------
  201. mode : {"L", "1", "all", "keep"}
  202. The label mode:
  203. - "L": All axes on the left column get vertical tick labels;
  204. all axes on the bottom row get horizontal tick labels.
  205. - "1": Only the bottom left axes is labelled.
  206. - "all": All axes are labelled.
  207. - "keep": Do not do anything.
  208. """
  209. _api.check_in_list(["all", "L", "1", "keep"], mode=mode)
  210. is_last_row, is_first_col = (
  211. np.mgrid[:self._nrows, :self._ncols] == [[[self._nrows - 1]], [[0]]])
  212. if mode == "all":
  213. bottom = left = np.full((self._nrows, self._ncols), True)
  214. elif mode == "L":
  215. bottom = is_last_row
  216. left = is_first_col
  217. elif mode == "1":
  218. bottom = left = is_last_row & is_first_col
  219. else:
  220. return
  221. for i in range(self._nrows):
  222. for j in range(self._ncols):
  223. ax = self.axes_row[i][j]
  224. if isinstance(ax.axis, MethodType):
  225. bottom_axis = SimpleAxisArtist(ax.xaxis, 1, ax.spines["bottom"])
  226. left_axis = SimpleAxisArtist(ax.yaxis, 1, ax.spines["left"])
  227. else:
  228. bottom_axis = ax.axis["bottom"]
  229. left_axis = ax.axis["left"]
  230. bottom_axis.toggle(ticklabels=bottom[i, j], label=bottom[i, j])
  231. left_axis.toggle(ticklabels=left[i, j], label=left[i, j])
  232. def get_divider(self):
  233. return self._divider
  234. def set_axes_locator(self, locator):
  235. self._divider.set_locator(locator)
  236. def get_axes_locator(self):
  237. return self._divider.get_locator()
  238. class ImageGrid(Grid):
  239. """
  240. A grid of Axes for Image display.
  241. This class is a specialization of `~.axes_grid1.axes_grid.Grid` for displaying a
  242. grid of images. In particular, it forces all axes in a column to share their x-axis
  243. and all axes in a row to share their y-axis. It further provides helpers to add
  244. colorbars to some or all axes.
  245. """
  246. def __init__(self, fig,
  247. rect,
  248. nrows_ncols,
  249. ngrids=None,
  250. direction="row",
  251. axes_pad=0.02,
  252. *,
  253. share_all=False,
  254. aspect=True,
  255. label_mode="L",
  256. cbar_mode=None,
  257. cbar_location="right",
  258. cbar_pad=None,
  259. cbar_size="5%",
  260. cbar_set_cax=True,
  261. axes_class=None,
  262. ):
  263. """
  264. Parameters
  265. ----------
  266. fig : `.Figure`
  267. The parent figure.
  268. rect : (float, float, float, float) or int
  269. The axes position, as a ``(left, bottom, width, height)`` tuple or
  270. as a three-digit subplot position code (e.g., "121").
  271. nrows_ncols : (int, int)
  272. Number of rows and columns in the grid.
  273. ngrids : int or None, default: None
  274. If not None, only the first *ngrids* axes in the grid are created.
  275. direction : {"row", "column"}, default: "row"
  276. Whether axes are created in row-major ("row by row") or
  277. column-major order ("column by column"). This also affects the
  278. order in which axes are accessed using indexing (``grid[index]``).
  279. axes_pad : float or (float, float), default: 0.02in
  280. Padding or (horizontal padding, vertical padding) between axes, in
  281. inches.
  282. share_all : bool, default: False
  283. Whether all axes share their x- and y-axis. Note that in any case,
  284. all axes in a column share their x-axis and all axes in a row share
  285. their y-axis.
  286. aspect : bool, default: True
  287. Whether the axes aspect ratio follows the aspect ratio of the data
  288. limits.
  289. label_mode : {"L", "1", "all"}, default: "L"
  290. Determines which axes will get tick labels:
  291. - "L": All axes on the left column get vertical tick labels;
  292. all axes on the bottom row get horizontal tick labels.
  293. - "1": Only the bottom left axes is labelled.
  294. - "all": all axes are labelled.
  295. cbar_mode : {"each", "single", "edge", None}, default: None
  296. Whether to create a colorbar for "each" axes, a "single" colorbar
  297. for the entire grid, colorbars only for axes on the "edge"
  298. determined by *cbar_location*, or no colorbars. The colorbars are
  299. stored in the :attr:`cbar_axes` attribute.
  300. cbar_location : {"left", "right", "bottom", "top"}, default: "right"
  301. cbar_pad : float, default: None
  302. Padding between the image axes and the colorbar axes.
  303. .. versionchanged:: 3.10
  304. ``cbar_mode="single"`` no longer adds *axes_pad* between the axes
  305. and the colorbar if the *cbar_location* is "left" or "bottom".
  306. cbar_size : size specification (see `.Size.from_any`), default: "5%"
  307. Colorbar size.
  308. cbar_set_cax : bool, default: True
  309. If True, each axes in the grid has a *cax* attribute that is bound
  310. to associated *cbar_axes*.
  311. axes_class : subclass of `matplotlib.axes.Axes`, default: None
  312. """
  313. _api.check_in_list(["each", "single", "edge", None],
  314. cbar_mode=cbar_mode)
  315. _api.check_in_list(["left", "right", "bottom", "top"],
  316. cbar_location=cbar_location)
  317. self._colorbar_mode = cbar_mode
  318. self._colorbar_location = cbar_location
  319. self._colorbar_pad = cbar_pad
  320. self._colorbar_size = cbar_size
  321. # The colorbar axes are created in _init_locators().
  322. super().__init__(
  323. fig, rect, nrows_ncols, ngrids,
  324. direction=direction, axes_pad=axes_pad,
  325. share_all=share_all, share_x=True, share_y=True, aspect=aspect,
  326. label_mode=label_mode, axes_class=axes_class)
  327. for ax in self.cbar_axes:
  328. fig.add_axes(ax)
  329. if cbar_set_cax:
  330. if self._colorbar_mode == "single":
  331. for ax in self.axes_all:
  332. ax.cax = self.cbar_axes[0]
  333. elif self._colorbar_mode == "edge":
  334. for index, ax in enumerate(self.axes_all):
  335. col, row = self._get_col_row(index)
  336. if self._colorbar_location in ("left", "right"):
  337. ax.cax = self.cbar_axes[row]
  338. else:
  339. ax.cax = self.cbar_axes[col]
  340. else:
  341. for ax, cax in zip(self.axes_all, self.cbar_axes):
  342. ax.cax = cax
  343. def _init_locators(self):
  344. # Slightly abusing this method to inject colorbar creation into init.
  345. if self._colorbar_pad is None:
  346. # horizontal or vertical arrangement?
  347. if self._colorbar_location in ("left", "right"):
  348. self._colorbar_pad = self._horiz_pad_size.fixed_size
  349. else:
  350. self._colorbar_pad = self._vert_pad_size.fixed_size
  351. self.cbar_axes = [
  352. _cbaraxes_class_factory(self._defaultAxesClass)(
  353. self.axes_all[0].get_figure(root=False), self._divider.get_position(),
  354. orientation=self._colorbar_location)
  355. for _ in range(self.ngrids)]
  356. cb_mode = self._colorbar_mode
  357. cb_location = self._colorbar_location
  358. h = []
  359. v = []
  360. h_ax_pos = []
  361. h_cb_pos = []
  362. if cb_mode == "single" and cb_location in ("left", "bottom"):
  363. if cb_location == "left":
  364. sz = self._nrows * Size.AxesX(self.axes_llc)
  365. h.append(Size.from_any(self._colorbar_size, sz))
  366. h.append(Size.from_any(self._colorbar_pad, sz))
  367. locator = self._divider.new_locator(nx=0, ny=0, ny1=-1)
  368. elif cb_location == "bottom":
  369. sz = self._ncols * Size.AxesY(self.axes_llc)
  370. v.append(Size.from_any(self._colorbar_size, sz))
  371. v.append(Size.from_any(self._colorbar_pad, sz))
  372. locator = self._divider.new_locator(nx=0, nx1=-1, ny=0)
  373. for i in range(self.ngrids):
  374. self.cbar_axes[i].set_visible(False)
  375. self.cbar_axes[0].set_axes_locator(locator)
  376. self.cbar_axes[0].set_visible(True)
  377. for col, ax in enumerate(self.axes_row[0]):
  378. if col != 0:
  379. h.append(self._horiz_pad_size)
  380. if ax:
  381. sz = Size.AxesX(ax, aspect="axes", ref_ax=self.axes_all[0])
  382. else:
  383. sz = Size.AxesX(self.axes_all[0],
  384. aspect="axes", ref_ax=self.axes_all[0])
  385. if (cb_location == "left"
  386. and (cb_mode == "each"
  387. or (cb_mode == "edge" and col == 0))):
  388. h_cb_pos.append(len(h))
  389. h.append(Size.from_any(self._colorbar_size, sz))
  390. h.append(Size.from_any(self._colorbar_pad, sz))
  391. h_ax_pos.append(len(h))
  392. h.append(sz)
  393. if (cb_location == "right"
  394. and (cb_mode == "each"
  395. or (cb_mode == "edge" and col == self._ncols - 1))):
  396. h.append(Size.from_any(self._colorbar_pad, sz))
  397. h_cb_pos.append(len(h))
  398. h.append(Size.from_any(self._colorbar_size, sz))
  399. v_ax_pos = []
  400. v_cb_pos = []
  401. for row, ax in enumerate(self.axes_column[0][::-1]):
  402. if row != 0:
  403. v.append(self._vert_pad_size)
  404. if ax:
  405. sz = Size.AxesY(ax, aspect="axes", ref_ax=self.axes_all[0])
  406. else:
  407. sz = Size.AxesY(self.axes_all[0],
  408. aspect="axes", ref_ax=self.axes_all[0])
  409. if (cb_location == "bottom"
  410. and (cb_mode == "each"
  411. or (cb_mode == "edge" and row == 0))):
  412. v_cb_pos.append(len(v))
  413. v.append(Size.from_any(self._colorbar_size, sz))
  414. v.append(Size.from_any(self._colorbar_pad, sz))
  415. v_ax_pos.append(len(v))
  416. v.append(sz)
  417. if (cb_location == "top"
  418. and (cb_mode == "each"
  419. or (cb_mode == "edge" and row == self._nrows - 1))):
  420. v.append(Size.from_any(self._colorbar_pad, sz))
  421. v_cb_pos.append(len(v))
  422. v.append(Size.from_any(self._colorbar_size, sz))
  423. for i in range(self.ngrids):
  424. col, row = self._get_col_row(i)
  425. locator = self._divider.new_locator(nx=h_ax_pos[col],
  426. ny=v_ax_pos[self._nrows-1-row])
  427. self.axes_all[i].set_axes_locator(locator)
  428. if cb_mode == "each":
  429. if cb_location in ("right", "left"):
  430. locator = self._divider.new_locator(
  431. nx=h_cb_pos[col], ny=v_ax_pos[self._nrows - 1 - row])
  432. elif cb_location in ("top", "bottom"):
  433. locator = self._divider.new_locator(
  434. nx=h_ax_pos[col], ny=v_cb_pos[self._nrows - 1 - row])
  435. self.cbar_axes[i].set_axes_locator(locator)
  436. elif cb_mode == "edge":
  437. if (cb_location == "left" and col == 0
  438. or cb_location == "right" and col == self._ncols - 1):
  439. locator = self._divider.new_locator(
  440. nx=h_cb_pos[0], ny=v_ax_pos[self._nrows - 1 - row])
  441. self.cbar_axes[row].set_axes_locator(locator)
  442. elif (cb_location == "bottom" and row == self._nrows - 1
  443. or cb_location == "top" and row == 0):
  444. locator = self._divider.new_locator(nx=h_ax_pos[col],
  445. ny=v_cb_pos[0])
  446. self.cbar_axes[col].set_axes_locator(locator)
  447. if cb_mode == "single":
  448. if cb_location == "right":
  449. sz = self._nrows * Size.AxesX(self.axes_llc)
  450. h.append(Size.from_any(self._colorbar_pad, sz))
  451. h.append(Size.from_any(self._colorbar_size, sz))
  452. locator = self._divider.new_locator(nx=-2, ny=0, ny1=-1)
  453. elif cb_location == "top":
  454. sz = self._ncols * Size.AxesY(self.axes_llc)
  455. v.append(Size.from_any(self._colorbar_pad, sz))
  456. v.append(Size.from_any(self._colorbar_size, sz))
  457. locator = self._divider.new_locator(nx=0, nx1=-1, ny=-2)
  458. if cb_location in ("right", "top"):
  459. for i in range(self.ngrids):
  460. self.cbar_axes[i].set_visible(False)
  461. self.cbar_axes[0].set_axes_locator(locator)
  462. self.cbar_axes[0].set_visible(True)
  463. elif cb_mode == "each":
  464. for i in range(self.ngrids):
  465. self.cbar_axes[i].set_visible(True)
  466. elif cb_mode == "edge":
  467. if cb_location in ("right", "left"):
  468. count = self._nrows
  469. else:
  470. count = self._ncols
  471. for i in range(count):
  472. self.cbar_axes[i].set_visible(True)
  473. for j in range(i + 1, self.ngrids):
  474. self.cbar_axes[j].set_visible(False)
  475. else:
  476. for i in range(self.ngrids):
  477. self.cbar_axes[i].set_visible(False)
  478. self.cbar_axes[i].set_position([1., 1., 0.001, 0.001],
  479. which="active")
  480. self._divider.set_horizontal(h)
  481. self._divider.set_vertical(v)
# ``AxesGrid`` is another name for `ImageGrid`: both names are bound to the
# same class object.
AxesGrid = ImageGrid