axislines.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
"""
Axislines includes a modified implementation of the Axes class. The
biggest difference is that the artists responsible for drawing the axis spine,
ticks, ticklabels and axis labels are separated out from Matplotlib's Axis
class. Originally, this change was motivated by the need to support
curvilinear grids. Here are a few reasons why I came up with a new axes class:

* "top" and "bottom" x-axis (or "left" and "right" y-axis) can have
  different ticks (tick locations and labels). This is not possible
  with the current Matplotlib, although some twin axes trick can help.

* Curvilinear grid.

* Angled ticks.

In the new axes class, xaxis and yaxis are set to not visible by
default, and a new set of artists (AxisArtist) is defined to draw the axis
line, ticks, ticklabels and axis label. The Axes.axis attribute serves as
a dictionary of these artists, i.e., ax.axis["left"] is an AxisArtist
instance responsible for drawing the left y-axis. The default Axes.axis
contains "bottom", "left", "top" and "right".

AxisArtist can be considered as a container artist and has the following
children artists which will draw ticks, labels, etc.

* line
* major_ticks, major_ticklabels
* minor_ticks, minor_ticklabels
* offsetText
* label

Note that these are separate artists from `matplotlib.axis.Axis`, thus most
tick-related functions in Matplotlib won't work. For example, color and
markerwidth of the ``ax.axis["bottom"].major_ticks`` will follow those of
Axes.xaxis unless explicitly specified.

In addition to AxisArtist, the Axes will have a *gridlines* attribute,
which draws the grid lines. The gridlines need to be separated
from the axis because some gridlines may never cross any axis.
"""
  33. import numpy as np
  34. import matplotlib as mpl
  35. from matplotlib import _api
  36. import matplotlib.axes as maxes
  37. from matplotlib.path import Path
  38. from mpl_toolkits.axes_grid1 import mpl_axes
  39. from .axisline_style import AxislineStyle # noqa
  40. from .axis_artist import AxisArtist, GridlinesCollection
  41. class _AxisArtistHelperBase:
  42. """
  43. Base class for axis helper.
  44. Subclasses should define the methods listed below. The *axes*
  45. argument will be the ``.axes`` attribute of the caller artist. ::
  46. # Construct the spine.
  47. def get_line_transform(self, axes):
  48. return transform
  49. def get_line(self, axes):
  50. return path
  51. # Construct the label.
  52. def get_axislabel_transform(self, axes):
  53. return transform
  54. def get_axislabel_pos_angle(self, axes):
  55. return (x, y), angle
  56. # Construct the ticks.
  57. def get_tick_transform(self, axes):
  58. return transform
  59. def get_tick_iterators(self, axes):
  60. # A pair of iterables (one for major ticks, one for minor ticks)
  61. # that yield (tick_position, tick_angle, tick_label).
  62. return iter_major, iter_minor
  63. """
  64. def __init__(self, nth_coord):
  65. self.nth_coord = nth_coord
  66. def update_lim(self, axes):
  67. pass
  68. def get_nth_coord(self):
  69. return self.nth_coord
  70. def _to_xy(self, values, const):
  71. """
  72. Create a (*values.shape, 2)-shape array representing (x, y) pairs.
  73. The other coordinate is filled with the constant *const*.
  74. Example::
  75. >>> self.nth_coord = 0
  76. >>> self._to_xy([1, 2, 3], const=0)
  77. array([[1, 0],
  78. [2, 0],
  79. [3, 0]])
  80. """
  81. if self.nth_coord == 0:
  82. return np.stack(np.broadcast_arrays(values, const), axis=-1)
  83. elif self.nth_coord == 1:
  84. return np.stack(np.broadcast_arrays(const, values), axis=-1)
  85. else:
  86. raise ValueError("Unexpected nth_coord")
  87. class _FixedAxisArtistHelperBase(_AxisArtistHelperBase):
  88. """Helper class for a fixed (in the axes coordinate) axis."""
  89. @_api.delete_parameter("3.9", "nth_coord")
  90. def __init__(self, loc, nth_coord=None):
  91. """``nth_coord = 0``: x-axis; ``nth_coord = 1``: y-axis."""
  92. super().__init__(_api.check_getitem(
  93. {"bottom": 0, "top": 0, "left": 1, "right": 1}, loc=loc))
  94. self._loc = loc
  95. self._pos = {"bottom": 0, "top": 1, "left": 0, "right": 1}[loc]
  96. # axis line in transAxes
  97. self._path = Path(self._to_xy((0, 1), const=self._pos))
  98. # LINE
  99. def get_line(self, axes):
  100. return self._path
  101. def get_line_transform(self, axes):
  102. return axes.transAxes
  103. # LABEL
  104. def get_axislabel_transform(self, axes):
  105. return axes.transAxes
  106. def get_axislabel_pos_angle(self, axes):
  107. """
  108. Return the label reference position in transAxes.
  109. get_label_transform() returns a transform of (transAxes+offset)
  110. """
  111. return dict(left=((0., 0.5), 90), # (position, angle_tangent)
  112. right=((1., 0.5), 90),
  113. bottom=((0.5, 0.), 0),
  114. top=((0.5, 1.), 0))[self._loc]
  115. # TICK
  116. def get_tick_transform(self, axes):
  117. return [axes.get_xaxis_transform(), axes.get_yaxis_transform()][self.nth_coord]
  118. class _FloatingAxisArtistHelperBase(_AxisArtistHelperBase):
  119. def __init__(self, nth_coord, value):
  120. self._value = value
  121. super().__init__(nth_coord)
  122. def get_line(self, axes):
  123. raise RuntimeError("get_line method should be defined by the derived class")
  124. class FixedAxisArtistHelperRectilinear(_FixedAxisArtistHelperBase):
  125. @_api.delete_parameter("3.9", "nth_coord")
  126. def __init__(self, axes, loc, nth_coord=None):
  127. """
  128. nth_coord = along which coordinate value varies
  129. in 2D, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
  130. """
  131. super().__init__(loc)
  132. self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
  133. # TICK
  134. def get_tick_iterators(self, axes):
  135. """tick_loc, tick_angle, tick_label"""
  136. angle_normal, angle_tangent = {0: (90, 0), 1: (0, 90)}[self.nth_coord]
  137. major = self.axis.major
  138. major_locs = major.locator()
  139. major_labels = major.formatter.format_ticks(major_locs)
  140. minor = self.axis.minor
  141. minor_locs = minor.locator()
  142. minor_labels = minor.formatter.format_ticks(minor_locs)
  143. tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
  144. def _f(locs, labels):
  145. for loc, label in zip(locs, labels):
  146. c = self._to_xy(loc, const=self._pos)
  147. # check if the tick point is inside axes
  148. c2 = tick_to_axes.transform(c)
  149. if mpl.transforms._interval_contains_close((0, 1), c2[self.nth_coord]):
  150. yield c, angle_normal, angle_tangent, label
  151. return _f(major_locs, major_labels), _f(minor_locs, minor_labels)
  152. class FloatingAxisArtistHelperRectilinear(_FloatingAxisArtistHelperBase):
  153. def __init__(self, axes, nth_coord,
  154. passingthrough_point, axis_direction="bottom"):
  155. super().__init__(nth_coord, passingthrough_point)
  156. self._axis_direction = axis_direction
  157. self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
  158. def get_line(self, axes):
  159. fixed_coord = 1 - self.nth_coord
  160. data_to_axes = axes.transData - axes.transAxes
  161. p = data_to_axes.transform([self._value, self._value])
  162. return Path(self._to_xy((0, 1), const=p[fixed_coord]))
  163. def get_line_transform(self, axes):
  164. return axes.transAxes
  165. def get_axislabel_transform(self, axes):
  166. return axes.transAxes
  167. def get_axislabel_pos_angle(self, axes):
  168. """
  169. Return the label reference position in transAxes.
  170. get_label_transform() returns a transform of (transAxes+offset)
  171. """
  172. angle = [0, 90][self.nth_coord]
  173. fixed_coord = 1 - self.nth_coord
  174. data_to_axes = axes.transData - axes.transAxes
  175. p = data_to_axes.transform([self._value, self._value])
  176. verts = self._to_xy(0.5, const=p[fixed_coord])
  177. return (verts, angle) if 0 <= verts[fixed_coord] <= 1 else (None, None)
  178. def get_tick_transform(self, axes):
  179. return axes.transData
  180. def get_tick_iterators(self, axes):
  181. """tick_loc, tick_angle, tick_label"""
  182. angle_normal, angle_tangent = {0: (90, 0), 1: (0, 90)}[self.nth_coord]
  183. major = self.axis.major
  184. major_locs = major.locator()
  185. major_labels = major.formatter.format_ticks(major_locs)
  186. minor = self.axis.minor
  187. minor_locs = minor.locator()
  188. minor_labels = minor.formatter.format_ticks(minor_locs)
  189. data_to_axes = axes.transData - axes.transAxes
  190. def _f(locs, labels):
  191. for loc, label in zip(locs, labels):
  192. c = self._to_xy(loc, const=self._value)
  193. c1, c2 = data_to_axes.transform(c)
  194. if 0 <= c1 <= 1 and 0 <= c2 <= 1:
  195. yield c, angle_normal, angle_tangent, label
  196. return _f(major_locs, major_labels), _f(minor_locs, minor_labels)
  197. class AxisArtistHelper: # Backcompat.
  198. Fixed = _FixedAxisArtistHelperBase
  199. Floating = _FloatingAxisArtistHelperBase
  200. class AxisArtistHelperRectlinear: # Backcompat.
  201. Fixed = FixedAxisArtistHelperRectilinear
  202. Floating = FloatingAxisArtistHelperRectilinear
  203. class GridHelperBase:
  204. def __init__(self):
  205. self._old_limits = None
  206. super().__init__()
  207. def update_lim(self, axes):
  208. x1, x2 = axes.get_xlim()
  209. y1, y2 = axes.get_ylim()
  210. if self._old_limits != (x1, x2, y1, y2):
  211. self._update_grid(x1, y1, x2, y2)
  212. self._old_limits = (x1, x2, y1, y2)
  213. def _update_grid(self, x1, y1, x2, y2):
  214. """Cache relevant computations when the axes limits have changed."""
  215. def get_gridlines(self, which, axis):
  216. """
  217. Return list of grid lines as a list of paths (list of points).
  218. Parameters
  219. ----------
  220. which : {"both", "major", "minor"}
  221. axis : {"both", "x", "y"}
  222. """
  223. return []
  224. class GridHelperRectlinear(GridHelperBase):
  225. def __init__(self, axes):
  226. super().__init__()
  227. self.axes = axes
  228. @_api.delete_parameter(
  229. "3.9", "nth_coord", addendum="'nth_coord' is now inferred from 'loc'.")
  230. def new_fixed_axis(
  231. self, loc, nth_coord=None, axis_direction=None, offset=None, axes=None):
  232. if axes is None:
  233. _api.warn_external(
  234. "'new_fixed_axis' explicitly requires the axes keyword.")
  235. axes = self.axes
  236. if axis_direction is None:
  237. axis_direction = loc
  238. return AxisArtist(axes, FixedAxisArtistHelperRectilinear(axes, loc),
  239. offset=offset, axis_direction=axis_direction)
  240. def new_floating_axis(self, nth_coord, value, axis_direction="bottom", axes=None):
  241. if axes is None:
  242. _api.warn_external(
  243. "'new_floating_axis' explicitly requires the axes keyword.")
  244. axes = self.axes
  245. helper = FloatingAxisArtistHelperRectilinear(
  246. axes, nth_coord, value, axis_direction)
  247. axisline = AxisArtist(axes, helper, axis_direction=axis_direction)
  248. axisline.line.set_clip_on(True)
  249. axisline.line.set_clip_box(axisline.axes.bbox)
  250. return axisline
  251. def get_gridlines(self, which="major", axis="both"):
  252. """
  253. Return list of gridline coordinates in data coordinates.
  254. Parameters
  255. ----------
  256. which : {"both", "major", "minor"}
  257. axis : {"both", "x", "y"}
  258. """
  259. _api.check_in_list(["both", "major", "minor"], which=which)
  260. _api.check_in_list(["both", "x", "y"], axis=axis)
  261. gridlines = []
  262. if axis in ("both", "x"):
  263. locs = []
  264. y1, y2 = self.axes.get_ylim()
  265. if which in ("both", "major"):
  266. locs.extend(self.axes.xaxis.major.locator())
  267. if which in ("both", "minor"):
  268. locs.extend(self.axes.xaxis.minor.locator())
  269. gridlines.extend([[x, x], [y1, y2]] for x in locs)
  270. if axis in ("both", "y"):
  271. x1, x2 = self.axes.get_xlim()
  272. locs = []
  273. if self.axes.yaxis._major_tick_kw["gridOn"]:
  274. locs.extend(self.axes.yaxis.major.locator())
  275. if self.axes.yaxis._minor_tick_kw["gridOn"]:
  276. locs.extend(self.axes.yaxis.minor.locator())
  277. gridlines.extend([[x1, x2], [y, y]] for y in locs)
  278. return gridlines
  279. class Axes(maxes.Axes):
  280. def __init__(self, *args, grid_helper=None, **kwargs):
  281. self._axisline_on = True
  282. self._grid_helper = grid_helper if grid_helper else GridHelperRectlinear(self)
  283. super().__init__(*args, **kwargs)
  284. self.toggle_axisline(True)
  285. def toggle_axisline(self, b=None):
  286. if b is None:
  287. b = not self._axisline_on
  288. if b:
  289. self._axisline_on = True
  290. self.spines[:].set_visible(False)
  291. self.xaxis.set_visible(False)
  292. self.yaxis.set_visible(False)
  293. else:
  294. self._axisline_on = False
  295. self.spines[:].set_visible(True)
  296. self.xaxis.set_visible(True)
  297. self.yaxis.set_visible(True)
  298. @property
  299. def axis(self):
  300. return self._axislines
  301. def clear(self):
  302. # docstring inherited
  303. # Init gridlines before clear() as clear() calls grid().
  304. self.gridlines = gridlines = GridlinesCollection(
  305. [],
  306. colors=mpl.rcParams['grid.color'],
  307. linestyles=mpl.rcParams['grid.linestyle'],
  308. linewidths=mpl.rcParams['grid.linewidth'])
  309. self._set_artist_props(gridlines)
  310. gridlines.set_grid_helper(self.get_grid_helper())
  311. super().clear()
  312. # clip_path is set after Axes.clear(): that's when a patch is created.
  313. gridlines.set_clip_path(self.axes.patch)
  314. # Init axis artists.
  315. self._axislines = mpl_axes.Axes.AxisDict(self)
  316. new_fixed_axis = self.get_grid_helper().new_fixed_axis
  317. self._axislines.update({
  318. loc: new_fixed_axis(loc=loc, axes=self, axis_direction=loc)
  319. for loc in ["bottom", "top", "left", "right"]})
  320. for axisline in [self._axislines["top"], self._axislines["right"]]:
  321. axisline.label.set_visible(False)
  322. axisline.major_ticklabels.set_visible(False)
  323. axisline.minor_ticklabels.set_visible(False)
  324. def get_grid_helper(self):
  325. return self._grid_helper
  326. def grid(self, visible=None, which='major', axis="both", **kwargs):
  327. """
  328. Toggle the gridlines, and optionally set the properties of the lines.
  329. """
  330. # There are some discrepancies in the behavior of grid() between
  331. # axes_grid and Matplotlib, because axes_grid explicitly sets the
  332. # visibility of the gridlines.
  333. super().grid(visible, which=which, axis=axis, **kwargs)
  334. if not self._axisline_on:
  335. return
  336. if visible is None:
  337. visible = (self.axes.xaxis._minor_tick_kw["gridOn"]
  338. or self.axes.xaxis._major_tick_kw["gridOn"]
  339. or self.axes.yaxis._minor_tick_kw["gridOn"]
  340. or self.axes.yaxis._major_tick_kw["gridOn"])
  341. self.gridlines.set(which=which, axis=axis, visible=visible)
  342. self.gridlines.set(**kwargs)
  343. def get_children(self):
  344. if self._axisline_on:
  345. children = [*self._axislines.values(), self.gridlines]
  346. else:
  347. children = []
  348. children.extend(super().get_children())
  349. return children
  350. def new_fixed_axis(self, loc, offset=None):
  351. return self.get_grid_helper().new_fixed_axis(loc, offset=offset, axes=self)
  352. def new_floating_axis(self, nth_coord, value, axis_direction="bottom"):
  353. return self.get_grid_helper().new_floating_axis(
  354. nth_coord, value, axis_direction=axis_direction, axes=self)
  355. class AxesZero(Axes):
  356. def clear(self):
  357. super().clear()
  358. new_floating_axis = self.get_grid_helper().new_floating_axis
  359. self._axislines.update(
  360. xzero=new_floating_axis(
  361. nth_coord=0, value=0., axis_direction="bottom", axes=self),
  362. yzero=new_floating_axis(
  363. nth_coord=1, value=0., axis_direction="left", axes=self),
  364. )
  365. for k in ["xzero", "yzero"]:
  366. self._axislines[k].line.set_clip_path(self.patch)
  367. self._axislines[k].set_visible(False)
  368. Subplot = Axes
  369. SubplotZero = AxesZero