test_units.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. from datetime import datetime, timezone, timedelta
  2. import platform
  3. from unittest.mock import MagicMock
  4. import matplotlib.pyplot as plt
  5. from matplotlib.testing.decorators import check_figures_equal, image_comparison
  6. import matplotlib.patches as mpatches
  7. import matplotlib.units as munits
  8. from matplotlib.category import StrCategoryConverter, UnitData
  9. from matplotlib.dates import DateConverter
  10. import numpy as np
  11. import pytest
  12. # Basic class that wraps numpy array and has units
  13. class Quantity:
  14. def __init__(self, data, units):
  15. self.magnitude = data
  16. self.units = units
  17. def to(self, new_units):
  18. factors = {('hours', 'seconds'): 3600, ('minutes', 'hours'): 1 / 60,
  19. ('minutes', 'seconds'): 60, ('feet', 'miles'): 1 / 5280.,
  20. ('feet', 'inches'): 12, ('miles', 'inches'): 12 * 5280}
  21. if self.units != new_units:
  22. mult = factors[self.units, new_units]
  23. return Quantity(mult * self.magnitude, new_units)
  24. else:
  25. return Quantity(self.magnitude, self.units)
  26. def __copy__(self):
  27. return Quantity(self.magnitude, self.units)
  28. def __getattr__(self, attr):
  29. return getattr(self.magnitude, attr)
  30. def __getitem__(self, item):
  31. if np.iterable(self.magnitude):
  32. return Quantity(self.magnitude[item], self.units)
  33. else:
  34. return Quantity(self.magnitude, self.units)
  35. def __array__(self):
  36. return np.asarray(self.magnitude)
  37. @pytest.fixture
  38. def quantity_converter():
  39. # Create an instance of the conversion interface and
  40. # mock so we can check methods called
  41. qc = munits.ConversionInterface()
  42. def convert(value, unit, axis):
  43. if hasattr(value, 'units'):
  44. return value.to(unit).magnitude
  45. elif np.iterable(value):
  46. try:
  47. return [v.to(unit).magnitude for v in value]
  48. except AttributeError:
  49. return [Quantity(v, axis.get_units()).to(unit).magnitude
  50. for v in value]
  51. else:
  52. return Quantity(value, axis.get_units()).to(unit).magnitude
  53. def default_units(value, axis):
  54. if hasattr(value, 'units'):
  55. return value.units
  56. elif np.iterable(value):
  57. for v in value:
  58. if hasattr(v, 'units'):
  59. return v.units
  60. return None
  61. qc.convert = MagicMock(side_effect=convert)
  62. qc.axisinfo = MagicMock(side_effect=lambda u, a:
  63. munits.AxisInfo(label=u, default_limits=(0, 100)))
  64. qc.default_units = MagicMock(side_effect=default_units)
  65. return qc
  66. # Tests that the conversion machinery works properly for classes that
  67. # work as a facade over numpy arrays (like pint)
  68. @image_comparison(['plot_pint.png'], style='mpl20',
  69. tol=0 if platform.machine() == 'x86_64' else 0.03)
  70. def test_numpy_facade(quantity_converter):
  71. # use former defaults to match existing baseline image
  72. plt.rcParams['axes.formatter.limits'] = -7, 7
  73. # Register the class
  74. munits.registry[Quantity] = quantity_converter
  75. # Simple test
  76. y = Quantity(np.linspace(0, 30), 'miles')
  77. x = Quantity(np.linspace(0, 5), 'hours')
  78. fig, ax = plt.subplots()
  79. fig.subplots_adjust(left=0.15) # Make space for label
  80. ax.plot(x, y, 'tab:blue')
  81. ax.axhline(Quantity(26400, 'feet'), color='tab:red')
  82. ax.axvline(Quantity(120, 'minutes'), color='tab:green')
  83. ax.yaxis.set_units('inches')
  84. ax.xaxis.set_units('seconds')
  85. assert quantity_converter.convert.called
  86. assert quantity_converter.axisinfo.called
  87. assert quantity_converter.default_units.called
  88. # Tests gh-8908
  89. @image_comparison(['plot_masked_units.png'], remove_text=True, style='mpl20',
  90. tol=0 if platform.machine() == 'x86_64' else 0.02)
  91. def test_plot_masked_units():
  92. data = np.linspace(-5, 5)
  93. data_masked = np.ma.array(data, mask=(data > -2) & (data < 2))
  94. data_masked_units = Quantity(data_masked, 'meters')
  95. fig, ax = plt.subplots()
  96. ax.plot(data_masked_units)
  97. def test_empty_set_limits_with_units(quantity_converter):
  98. # Register the class
  99. munits.registry[Quantity] = quantity_converter
  100. fig, ax = plt.subplots()
  101. ax.set_xlim(Quantity(-1, 'meters'), Quantity(6, 'meters'))
  102. ax.set_ylim(Quantity(-1, 'hours'), Quantity(16, 'hours'))
  103. @image_comparison(['jpl_bar_units.png'],
  104. savefig_kwarg={'dpi': 120}, style='mpl20')
  105. def test_jpl_bar_units():
  106. import matplotlib.testing.jpl_units as units
  107. units.register()
  108. day = units.Duration("ET", 24.0 * 60.0 * 60.0)
  109. x = [0 * units.km, 1 * units.km, 2 * units.km]
  110. w = [1 * day, 2 * day, 3 * day]
  111. b = units.Epoch("ET", dt=datetime(2009, 4, 26))
  112. fig, ax = plt.subplots()
  113. ax.bar(x, w, bottom=b)
  114. ax.set_ylim([b - 1 * day, b + w[-1] + (1.001) * day])
  115. @image_comparison(['jpl_barh_units.png'],
  116. savefig_kwarg={'dpi': 120}, style='mpl20')
  117. def test_jpl_barh_units():
  118. import matplotlib.testing.jpl_units as units
  119. units.register()
  120. day = units.Duration("ET", 24.0 * 60.0 * 60.0)
  121. x = [0 * units.km, 1 * units.km, 2 * units.km]
  122. w = [1 * day, 2 * day, 3 * day]
  123. b = units.Epoch("ET", dt=datetime(2009, 4, 26))
  124. fig, ax = plt.subplots()
  125. ax.barh(x, w, left=b)
  126. ax.set_xlim([b - 1 * day, b + w[-1] + (1.001) * day])
  127. def test_jpl_datetime_units_consistent():
  128. import matplotlib.testing.jpl_units as units
  129. units.register()
  130. dt = datetime(2009, 4, 26)
  131. jpl = units.Epoch("ET", dt=dt)
  132. dt_conv = munits.registry.get_converter(dt).convert(dt, None, None)
  133. jpl_conv = munits.registry.get_converter(jpl).convert(jpl, None, None)
  134. assert dt_conv == jpl_conv
  135. def test_empty_arrays():
  136. # Check that plotting an empty array with a dtype works
  137. plt.scatter(np.array([], dtype='datetime64[ns]'), np.array([]))
  138. def test_scatter_element0_masked():
  139. times = np.arange('2005-02', '2005-03', dtype='datetime64[D]')
  140. y = np.arange(len(times), dtype=float)
  141. y[0] = np.nan
  142. fig, ax = plt.subplots()
  143. ax.scatter(times, y)
  144. fig.canvas.draw()
  145. def test_errorbar_mixed_units():
  146. x = np.arange(10)
  147. y = [datetime(2020, 5, i * 2 + 1) for i in x]
  148. fig, ax = plt.subplots()
  149. ax.errorbar(x, y, timedelta(days=0.5))
  150. fig.canvas.draw()
  151. @check_figures_equal(extensions=["png"])
  152. def test_subclass(fig_test, fig_ref):
  153. class subdate(datetime):
  154. pass
  155. fig_test.subplots().plot(subdate(2000, 1, 1), 0, "o")
  156. fig_ref.subplots().plot(datetime(2000, 1, 1), 0, "o")
  157. def test_shared_axis_quantity(quantity_converter):
  158. munits.registry[Quantity] = quantity_converter
  159. x = Quantity(np.linspace(0, 1, 10), "hours")
  160. y1 = Quantity(np.linspace(1, 2, 10), "feet")
  161. y2 = Quantity(np.linspace(3, 4, 10), "feet")
  162. fig, (ax1, ax2) = plt.subplots(2, 1, sharex='all', sharey='all')
  163. ax1.plot(x, y1)
  164. ax2.plot(x, y2)
  165. assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "hours"
  166. assert ax2.yaxis.get_units() == ax2.yaxis.get_units() == "feet"
  167. ax1.xaxis.set_units("seconds")
  168. ax2.yaxis.set_units("inches")
  169. assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "seconds"
  170. assert ax1.yaxis.get_units() == ax2.yaxis.get_units() == "inches"
  171. def test_shared_axis_datetime():
  172. # datetime uses dates.DateConverter
  173. y1 = [datetime(2020, i, 1, tzinfo=timezone.utc) for i in range(1, 13)]
  174. y2 = [datetime(2021, i, 1, tzinfo=timezone.utc) for i in range(1, 13)]
  175. fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
  176. ax1.plot(y1)
  177. ax2.plot(y2)
  178. ax1.yaxis.set_units(timezone(timedelta(hours=5)))
  179. assert ax2.yaxis.units == timezone(timedelta(hours=5))
  180. def test_shared_axis_categorical():
  181. # str uses category.StrCategoryConverter
  182. d1 = {"a": 1, "b": 2}
  183. d2 = {"a": 3, "b": 4}
  184. fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)
  185. ax1.plot(d1.keys(), d1.values())
  186. ax2.plot(d2.keys(), d2.values())
  187. ax1.xaxis.set_units(UnitData(["c", "d"]))
  188. assert "c" in ax2.xaxis.get_units()._mapping.keys()
  189. def test_explicit_converter():
  190. d1 = {"a": 1, "b": 2}
  191. str_cat_converter = StrCategoryConverter()
  192. str_cat_converter_2 = StrCategoryConverter()
  193. date_converter = DateConverter()
  194. # Explicit is set
  195. fig1, ax1 = plt.subplots()
  196. ax1.xaxis.set_converter(str_cat_converter)
  197. assert ax1.xaxis.get_converter() == str_cat_converter
  198. # Explicit not overridden by implicit
  199. ax1.plot(d1.keys(), d1.values())
  200. assert ax1.xaxis.get_converter() == str_cat_converter
  201. # No error when called twice with equivalent input
  202. ax1.xaxis.set_converter(str_cat_converter)
  203. # Error when explicit called twice
  204. with pytest.raises(RuntimeError):
  205. ax1.xaxis.set_converter(str_cat_converter_2)
  206. fig2, ax2 = plt.subplots()
  207. ax2.plot(d1.keys(), d1.values())
  208. # No error when equivalent type is used
  209. ax2.xaxis.set_converter(str_cat_converter)
  210. fig3, ax3 = plt.subplots()
  211. ax3.plot(d1.keys(), d1.values())
  212. # Warn when implicit overridden
  213. with pytest.warns():
  214. ax3.xaxis.set_converter(date_converter)
  215. def test_empty_default_limits(quantity_converter):
  216. munits.registry[Quantity] = quantity_converter
  217. fig, ax1 = plt.subplots()
  218. ax1.xaxis.update_units(Quantity([10], "miles"))
  219. fig.draw_without_rendering()
  220. assert ax1.get_xlim() == (0, 100)
  221. ax1.yaxis.update_units(Quantity([10], "miles"))
  222. fig.draw_without_rendering()
  223. assert ax1.get_ylim() == (0, 100)
  224. fig, ax = plt.subplots()
  225. ax.axhline(30)
  226. ax.plot(Quantity(np.arange(0, 3), "miles"),
  227. Quantity(np.arange(0, 6, 2), "feet"))
  228. fig.draw_without_rendering()
  229. assert ax.get_xlim() == (0, 2)
  230. assert ax.get_ylim() == (0, 30)
  231. fig, ax = plt.subplots()
  232. ax.axvline(30)
  233. ax.plot(Quantity(np.arange(0, 3), "miles"),
  234. Quantity(np.arange(0, 6, 2), "feet"))
  235. fig.draw_without_rendering()
  236. assert ax.get_xlim() == (0, 30)
  237. assert ax.get_ylim() == (0, 4)
  238. fig, ax = plt.subplots()
  239. ax.xaxis.update_units(Quantity([10], "miles"))
  240. ax.axhline(30)
  241. fig.draw_without_rendering()
  242. assert ax.get_xlim() == (0, 100)
  243. assert ax.get_ylim() == (28.5, 31.5)
  244. fig, ax = plt.subplots()
  245. ax.yaxis.update_units(Quantity([10], "miles"))
  246. ax.axvline(30)
  247. fig.draw_without_rendering()
  248. assert ax.get_ylim() == (0, 100)
  249. assert ax.get_xlim() == (28.5, 31.5)
  250. # test array-like objects...
  251. class Kernel:
  252. def __init__(self, array):
  253. self._array = np.asanyarray(array)
  254. def __array__(self, dtype=None, copy=None):
  255. if dtype is not None and dtype != self._array.dtype:
  256. if copy is not None and not copy:
  257. raise ValueError(
  258. f"Converting array from {self._array.dtype} to "
  259. f"{dtype} requires a copy"
  260. )
  261. arr = np.asarray(self._array, dtype=dtype)
  262. return (arr if not copy else np.copy(arr))
  263. @property
  264. def shape(self):
  265. return self._array.shape
  266. def test_plot_kernel():
  267. # just a smoketest that fail
  268. kernel = Kernel([1, 2, 3, 4, 5])
  269. plt.plot(kernel)
  270. def test_connection_patch_units(pd):
  271. # tests that this doesn't raise an error
  272. fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(10, 5))
  273. x = pd.Timestamp('2017-01-01T12')
  274. ax1.axvline(x)
  275. y = "test test"
  276. ax2.axhline(y)
  277. arr = mpatches.ConnectionPatch((x, 0), (0, y),
  278. coordsA='data', coordsB='data',
  279. axesA=ax1, axesB=ax2)
  280. fig.add_artist(arr)
  281. fig.draw_without_rendering()