grid_helper_curvelinear.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. """
  2. An experimental support for curvilinear grid.
  3. """
  4. import functools
  5. import numpy as np
  6. import matplotlib as mpl
  7. from matplotlib import _api
  8. from matplotlib.path import Path
  9. from matplotlib.transforms import Affine2D, IdentityTransform
  10. from .axislines import (
  11. _FixedAxisArtistHelperBase, _FloatingAxisArtistHelperBase, GridHelperBase)
  12. from .axis_artist import AxisArtist
  13. from .grid_finder import GridFinder
  14. def _value_and_jacobian(func, xs, ys, xlims, ylims):
  15. """
  16. Compute *func* and its derivatives along x and y at positions *xs*, *ys*,
  17. while ensuring that finite difference calculations don't try to evaluate
  18. values outside of *xlims*, *ylims*.
  19. """
  20. eps = np.finfo(float).eps ** (1/2) # see e.g. scipy.optimize.approx_fprime
  21. val = func(xs, ys)
  22. # Take the finite difference step in the direction where the bound is the
  23. # furthest; the step size is min of epsilon and distance to that bound.
  24. xlo, xhi = sorted(xlims)
  25. dxlo = xs - xlo
  26. dxhi = xhi - xs
  27. xeps = (np.take([-1, 1], dxhi >= dxlo)
  28. * np.minimum(eps, np.maximum(dxlo, dxhi)))
  29. val_dx = func(xs + xeps, ys)
  30. ylo, yhi = sorted(ylims)
  31. dylo = ys - ylo
  32. dyhi = yhi - ys
  33. yeps = (np.take([-1, 1], dyhi >= dylo)
  34. * np.minimum(eps, np.maximum(dylo, dyhi)))
  35. val_dy = func(xs, ys + yeps)
  36. return (val, (val_dx - val) / xeps, (val_dy - val) / yeps)
  37. class FixedAxisArtistHelper(_FixedAxisArtistHelperBase):
  38. """
  39. Helper class for a fixed axis.
  40. """
  41. def __init__(self, grid_helper, side, nth_coord_ticks=None):
  42. """
  43. nth_coord = along which coordinate value varies.
  44. nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
  45. """
  46. super().__init__(loc=side)
  47. self.grid_helper = grid_helper
  48. if nth_coord_ticks is None:
  49. nth_coord_ticks = self.nth_coord
  50. self.nth_coord_ticks = nth_coord_ticks
  51. self.side = side
  52. def update_lim(self, axes):
  53. self.grid_helper.update_lim(axes)
  54. def get_tick_transform(self, axes):
  55. return axes.transData
  56. def get_tick_iterators(self, axes):
  57. """tick_loc, tick_angle, tick_label"""
  58. v1, v2 = axes.get_ylim() if self.nth_coord == 0 else axes.get_xlim()
  59. if v1 > v2: # Inverted limits.
  60. side = {"left": "right", "right": "left",
  61. "top": "bottom", "bottom": "top"}[self.side]
  62. else:
  63. side = self.side
  64. angle_tangent = dict(left=90, right=90, bottom=0, top=0)[side]
  65. def iter_major():
  66. for nth_coord, show_labels in [
  67. (self.nth_coord_ticks, True), (1 - self.nth_coord_ticks, False)]:
  68. gi = self.grid_helper._grid_info[["lon", "lat"][nth_coord]]
  69. for tick in gi["ticks"][side]:
  70. yield (*tick["loc"], angle_tangent,
  71. (tick["label"] if show_labels else ""))
  72. return iter_major(), iter([])
  73. class FloatingAxisArtistHelper(_FloatingAxisArtistHelperBase):
  74. def __init__(self, grid_helper, nth_coord, value, axis_direction=None):
  75. """
  76. nth_coord = along which coordinate value varies.
  77. nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
  78. """
  79. super().__init__(nth_coord, value)
  80. self.value = value
  81. self.grid_helper = grid_helper
  82. self._extremes = -np.inf, np.inf
  83. self._line_num_points = 100 # number of points to create a line
  84. def set_extremes(self, e1, e2):
  85. if e1 is None:
  86. e1 = -np.inf
  87. if e2 is None:
  88. e2 = np.inf
  89. self._extremes = e1, e2
  90. def update_lim(self, axes):
  91. self.grid_helper.update_lim(axes)
  92. x1, x2 = axes.get_xlim()
  93. y1, y2 = axes.get_ylim()
  94. grid_finder = self.grid_helper.grid_finder
  95. extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
  96. x1, y1, x2, y2)
  97. lon_min, lon_max, lat_min, lat_max = extremes
  98. e_min, e_max = self._extremes # ranges of other coordinates
  99. if self.nth_coord == 0:
  100. lat_min = max(e_min, lat_min)
  101. lat_max = min(e_max, lat_max)
  102. elif self.nth_coord == 1:
  103. lon_min = max(e_min, lon_min)
  104. lon_max = min(e_max, lon_max)
  105. lon_levs, lon_n, lon_factor = \
  106. grid_finder.grid_locator1(lon_min, lon_max)
  107. lat_levs, lat_n, lat_factor = \
  108. grid_finder.grid_locator2(lat_min, lat_max)
  109. if self.nth_coord == 0:
  110. xx0 = np.full(self._line_num_points, self.value)
  111. yy0 = np.linspace(lat_min, lat_max, self._line_num_points)
  112. xx, yy = grid_finder.transform_xy(xx0, yy0)
  113. elif self.nth_coord == 1:
  114. xx0 = np.linspace(lon_min, lon_max, self._line_num_points)
  115. yy0 = np.full(self._line_num_points, self.value)
  116. xx, yy = grid_finder.transform_xy(xx0, yy0)
  117. self._grid_info = {
  118. "extremes": (lon_min, lon_max, lat_min, lat_max),
  119. "lon_info": (lon_levs, lon_n, np.asarray(lon_factor)),
  120. "lat_info": (lat_levs, lat_n, np.asarray(lat_factor)),
  121. "lon_labels": grid_finder._format_ticks(
  122. 1, "bottom", lon_factor, lon_levs),
  123. "lat_labels": grid_finder._format_ticks(
  124. 2, "bottom", lat_factor, lat_levs),
  125. "line_xy": (xx, yy),
  126. }
  127. def get_axislabel_transform(self, axes):
  128. return Affine2D() # axes.transData
  129. def get_axislabel_pos_angle(self, axes):
  130. def trf_xy(x, y):
  131. trf = self.grid_helper.grid_finder.get_transform() + axes.transData
  132. return trf.transform([x, y]).T
  133. xmin, xmax, ymin, ymax = self._grid_info["extremes"]
  134. if self.nth_coord == 0:
  135. xx0 = self.value
  136. yy0 = (ymin + ymax) / 2
  137. elif self.nth_coord == 1:
  138. xx0 = (xmin + xmax) / 2
  139. yy0 = self.value
  140. xy1, dxy1_dx, dxy1_dy = _value_and_jacobian(
  141. trf_xy, xx0, yy0, (xmin, xmax), (ymin, ymax))
  142. p = axes.transAxes.inverted().transform(xy1)
  143. if 0 <= p[0] <= 1 and 0 <= p[1] <= 1:
  144. d = [dxy1_dy, dxy1_dx][self.nth_coord]
  145. return xy1, np.rad2deg(np.arctan2(*d[::-1]))
  146. else:
  147. return None, None
  148. def get_tick_transform(self, axes):
  149. return IdentityTransform() # axes.transData
  150. def get_tick_iterators(self, axes):
  151. """tick_loc, tick_angle, tick_label, (optionally) tick_label"""
  152. lat_levs, lat_n, lat_factor = self._grid_info["lat_info"]
  153. yy0 = lat_levs / lat_factor
  154. lon_levs, lon_n, lon_factor = self._grid_info["lon_info"]
  155. xx0 = lon_levs / lon_factor
  156. e0, e1 = self._extremes
  157. def trf_xy(x, y):
  158. trf = self.grid_helper.grid_finder.get_transform() + axes.transData
  159. return trf.transform(np.column_stack(np.broadcast_arrays(x, y))).T
  160. # find angles
  161. if self.nth_coord == 0:
  162. mask = (e0 <= yy0) & (yy0 <= e1)
  163. (xx1, yy1), (dxx1, dyy1), (dxx2, dyy2) = _value_and_jacobian(
  164. trf_xy, self.value, yy0[mask], (-np.inf, np.inf), (e0, e1))
  165. labels = self._grid_info["lat_labels"]
  166. elif self.nth_coord == 1:
  167. mask = (e0 <= xx0) & (xx0 <= e1)
  168. (xx1, yy1), (dxx2, dyy2), (dxx1, dyy1) = _value_and_jacobian(
  169. trf_xy, xx0[mask], self.value, (-np.inf, np.inf), (e0, e1))
  170. labels = self._grid_info["lon_labels"]
  171. labels = [l for l, m in zip(labels, mask) if m]
  172. angle_normal = np.arctan2(dyy1, dxx1)
  173. angle_tangent = np.arctan2(dyy2, dxx2)
  174. mm = (dyy1 == 0) & (dxx1 == 0) # points with degenerate normal
  175. angle_normal[mm] = angle_tangent[mm] + np.pi / 2
  176. tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
  177. in_01 = functools.partial(
  178. mpl.transforms._interval_contains_close, (0, 1))
  179. def iter_major():
  180. for x, y, normal, tangent, lab \
  181. in zip(xx1, yy1, angle_normal, angle_tangent, labels):
  182. c2 = tick_to_axes.transform((x, y))
  183. if in_01(c2[0]) and in_01(c2[1]):
  184. yield [x, y], *np.rad2deg([normal, tangent]), lab
  185. return iter_major(), iter([])
  186. def get_line_transform(self, axes):
  187. return axes.transData
  188. def get_line(self, axes):
  189. self.update_lim(axes)
  190. x, y = self._grid_info["line_xy"]
  191. return Path(np.column_stack([x, y]))
  192. class GridHelperCurveLinear(GridHelperBase):
  193. def __init__(self, aux_trans,
  194. extreme_finder=None,
  195. grid_locator1=None,
  196. grid_locator2=None,
  197. tick_formatter1=None,
  198. tick_formatter2=None):
  199. """
  200. Parameters
  201. ----------
  202. aux_trans : `.Transform` or tuple[Callable, Callable]
  203. The transform from curved coordinates to rectilinear coordinate:
  204. either a `.Transform` instance (which provides also its inverse),
  205. or a pair of callables ``(trans, inv_trans)`` that define the
  206. transform and its inverse. The callables should have signature::
  207. x_rect, y_rect = trans(x_curved, y_curved)
  208. x_curved, y_curved = inv_trans(x_rect, y_rect)
  209. extreme_finder
  210. grid_locator1, grid_locator2
  211. Grid locators for each axis.
  212. tick_formatter1, tick_formatter2
  213. Tick formatters for each axis.
  214. """
  215. super().__init__()
  216. self._grid_info = None
  217. self.grid_finder = GridFinder(aux_trans,
  218. extreme_finder,
  219. grid_locator1,
  220. grid_locator2,
  221. tick_formatter1,
  222. tick_formatter2)
  223. def update_grid_finder(self, aux_trans=None, **kwargs):
  224. if aux_trans is not None:
  225. self.grid_finder.update_transform(aux_trans)
  226. self.grid_finder.update(**kwargs)
  227. self._old_limits = None # Force revalidation.
  228. @_api.make_keyword_only("3.9", "nth_coord")
  229. def new_fixed_axis(
  230. self, loc, nth_coord=None, axis_direction=None, offset=None, axes=None):
  231. if axes is None:
  232. axes = self.axes
  233. if axis_direction is None:
  234. axis_direction = loc
  235. helper = FixedAxisArtistHelper(self, loc, nth_coord_ticks=nth_coord)
  236. axisline = AxisArtist(axes, helper, axis_direction=axis_direction)
  237. # Why is clip not set on axisline, unlike in new_floating_axis or in
  238. # the floating_axig.GridHelperCurveLinear subclass?
  239. return axisline
  240. def new_floating_axis(self, nth_coord, value, axes=None, axis_direction="bottom"):
  241. if axes is None:
  242. axes = self.axes
  243. helper = FloatingAxisArtistHelper(
  244. self, nth_coord, value, axis_direction)
  245. axisline = AxisArtist(axes, helper)
  246. axisline.line.set_clip_on(True)
  247. axisline.line.set_clip_box(axisline.axes.bbox)
  248. # axisline.major_ticklabels.set_visible(True)
  249. # axisline.minor_ticklabels.set_visible(False)
  250. return axisline
  251. def _update_grid(self, x1, y1, x2, y2):
  252. self._grid_info = self.grid_finder.get_grid_info(x1, y1, x2, y2)
  253. def get_gridlines(self, which="major", axis="both"):
  254. grid_lines = []
  255. if axis in ["both", "x"]:
  256. for gl in self._grid_info["lon"]["lines"]:
  257. grid_lines.extend(gl)
  258. if axis in ["both", "y"]:
  259. for gl in self._grid_info["lat"]["lines"]:
  260. grid_lines.extend(gl)
  261. return grid_lines
  262. @_api.deprecated("3.9")
  263. def get_tick_iterator(self, nth_coord, axis_side, minor=False):
  264. angle_tangent = dict(left=90, right=90, bottom=0, top=0)[axis_side]
  265. lon_or_lat = ["lon", "lat"][nth_coord]
  266. if not minor: # major ticks
  267. for tick in self._grid_info[lon_or_lat]["ticks"][axis_side]:
  268. yield *tick["loc"], angle_tangent, tick["label"]
  269. else:
  270. for tick in self._grid_info[lon_or_lat]["ticks"][axis_side]:
  271. yield *tick["loc"], angle_tangent, ""