test_collections.py 48 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393
  1. from datetime import datetime
  2. import io
  3. import itertools
  4. import platform
  5. import re
  6. from types import SimpleNamespace
  7. import numpy as np
  8. from numpy.testing import assert_array_equal, assert_array_almost_equal
  9. import pytest
  10. import matplotlib as mpl
  11. import matplotlib.pyplot as plt
  12. import matplotlib.collections as mcollections
  13. import matplotlib.colors as mcolors
  14. import matplotlib.path as mpath
  15. import matplotlib.transforms as mtransforms
  16. from matplotlib.collections import (Collection, LineCollection,
  17. EventCollection, PolyCollection)
  18. from matplotlib.collections import FillBetweenPolyCollection
  19. from matplotlib.testing.decorators import check_figures_equal, image_comparison
  20. @pytest.fixture(params=["pcolormesh", "pcolor"])
  21. def pcfunc(request):
  22. return request.param
  23. def generate_EventCollection_plot():
  24. """Generate the initial collection and plot it."""
  25. positions = np.array([0., 1., 2., 3., 5., 8., 13., 21.])
  26. extra_positions = np.array([34., 55., 89.])
  27. orientation = 'horizontal'
  28. lineoffset = 1
  29. linelength = .5
  30. linewidth = 2
  31. color = [1, 0, 0, 1]
  32. linestyle = 'solid'
  33. antialiased = True
  34. coll = EventCollection(positions,
  35. orientation=orientation,
  36. lineoffset=lineoffset,
  37. linelength=linelength,
  38. linewidth=linewidth,
  39. color=color,
  40. linestyle=linestyle,
  41. antialiased=antialiased
  42. )
  43. fig, ax = plt.subplots()
  44. ax.add_collection(coll)
  45. ax.set_title('EventCollection: default')
  46. props = {'positions': positions,
  47. 'extra_positions': extra_positions,
  48. 'orientation': orientation,
  49. 'lineoffset': lineoffset,
  50. 'linelength': linelength,
  51. 'linewidth': linewidth,
  52. 'color': color,
  53. 'linestyle': linestyle,
  54. 'antialiased': antialiased
  55. }
  56. ax.set_xlim(-1, 22)
  57. ax.set_ylim(0, 2)
  58. return ax, coll, props
  59. @image_comparison(['EventCollection_plot__default.png'])
  60. def test__EventCollection__get_props():
  61. _, coll, props = generate_EventCollection_plot()
  62. # check that the default segments have the correct coordinates
  63. check_segments(coll,
  64. props['positions'],
  65. props['linelength'],
  66. props['lineoffset'],
  67. props['orientation'])
  68. # check that the default positions match the input positions
  69. np.testing.assert_array_equal(props['positions'], coll.get_positions())
  70. # check that the default orientation matches the input orientation
  71. assert props['orientation'] == coll.get_orientation()
  72. # check that the default orientation matches the input orientation
  73. assert coll.is_horizontal()
  74. # check that the default linelength matches the input linelength
  75. assert props['linelength'] == coll.get_linelength()
  76. # check that the default lineoffset matches the input lineoffset
  77. assert props['lineoffset'] == coll.get_lineoffset()
  78. # check that the default linestyle matches the input linestyle
  79. assert coll.get_linestyle() == [(0, None)]
  80. # check that the default color matches the input color
  81. for color in [coll.get_color(), *coll.get_colors()]:
  82. np.testing.assert_array_equal(color, props['color'])
  83. @image_comparison(['EventCollection_plot__set_positions.png'])
  84. def test__EventCollection__set_positions():
  85. splt, coll, props = generate_EventCollection_plot()
  86. new_positions = np.hstack([props['positions'], props['extra_positions']])
  87. coll.set_positions(new_positions)
  88. np.testing.assert_array_equal(new_positions, coll.get_positions())
  89. check_segments(coll, new_positions,
  90. props['linelength'],
  91. props['lineoffset'],
  92. props['orientation'])
  93. splt.set_title('EventCollection: set_positions')
  94. splt.set_xlim(-1, 90)
  95. @image_comparison(['EventCollection_plot__add_positions.png'])
  96. def test__EventCollection__add_positions():
  97. splt, coll, props = generate_EventCollection_plot()
  98. new_positions = np.hstack([props['positions'],
  99. props['extra_positions'][0]])
  100. coll.switch_orientation() # Test adding in the vertical orientation, too.
  101. coll.add_positions(props['extra_positions'][0])
  102. coll.switch_orientation()
  103. np.testing.assert_array_equal(new_positions, coll.get_positions())
  104. check_segments(coll,
  105. new_positions,
  106. props['linelength'],
  107. props['lineoffset'],
  108. props['orientation'])
  109. splt.set_title('EventCollection: add_positions')
  110. splt.set_xlim(-1, 35)
  111. @image_comparison(['EventCollection_plot__append_positions.png'])
  112. def test__EventCollection__append_positions():
  113. splt, coll, props = generate_EventCollection_plot()
  114. new_positions = np.hstack([props['positions'],
  115. props['extra_positions'][2]])
  116. coll.append_positions(props['extra_positions'][2])
  117. np.testing.assert_array_equal(new_positions, coll.get_positions())
  118. check_segments(coll,
  119. new_positions,
  120. props['linelength'],
  121. props['lineoffset'],
  122. props['orientation'])
  123. splt.set_title('EventCollection: append_positions')
  124. splt.set_xlim(-1, 90)
  125. @image_comparison(['EventCollection_plot__extend_positions.png'])
  126. def test__EventCollection__extend_positions():
  127. splt, coll, props = generate_EventCollection_plot()
  128. new_positions = np.hstack([props['positions'],
  129. props['extra_positions'][1:]])
  130. coll.extend_positions(props['extra_positions'][1:])
  131. np.testing.assert_array_equal(new_positions, coll.get_positions())
  132. check_segments(coll,
  133. new_positions,
  134. props['linelength'],
  135. props['lineoffset'],
  136. props['orientation'])
  137. splt.set_title('EventCollection: extend_positions')
  138. splt.set_xlim(-1, 90)
  139. @image_comparison(['EventCollection_plot__switch_orientation.png'])
  140. def test__EventCollection__switch_orientation():
  141. splt, coll, props = generate_EventCollection_plot()
  142. new_orientation = 'vertical'
  143. coll.switch_orientation()
  144. assert new_orientation == coll.get_orientation()
  145. assert not coll.is_horizontal()
  146. new_positions = coll.get_positions()
  147. check_segments(coll,
  148. new_positions,
  149. props['linelength'],
  150. props['lineoffset'], new_orientation)
  151. splt.set_title('EventCollection: switch_orientation')
  152. splt.set_ylim(-1, 22)
  153. splt.set_xlim(0, 2)
  154. @image_comparison(['EventCollection_plot__switch_orientation__2x.png'])
  155. def test__EventCollection__switch_orientation_2x():
  156. """
  157. Check that calling switch_orientation twice sets the orientation back to
  158. the default.
  159. """
  160. splt, coll, props = generate_EventCollection_plot()
  161. coll.switch_orientation()
  162. coll.switch_orientation()
  163. new_positions = coll.get_positions()
  164. assert props['orientation'] == coll.get_orientation()
  165. assert coll.is_horizontal()
  166. np.testing.assert_array_equal(props['positions'], new_positions)
  167. check_segments(coll,
  168. new_positions,
  169. props['linelength'],
  170. props['lineoffset'],
  171. props['orientation'])
  172. splt.set_title('EventCollection: switch_orientation 2x')
  173. @image_comparison(['EventCollection_plot__set_orientation.png'])
  174. def test__EventCollection__set_orientation():
  175. splt, coll, props = generate_EventCollection_plot()
  176. new_orientation = 'vertical'
  177. coll.set_orientation(new_orientation)
  178. assert new_orientation == coll.get_orientation()
  179. assert not coll.is_horizontal()
  180. check_segments(coll,
  181. props['positions'],
  182. props['linelength'],
  183. props['lineoffset'],
  184. new_orientation)
  185. splt.set_title('EventCollection: set_orientation')
  186. splt.set_ylim(-1, 22)
  187. splt.set_xlim(0, 2)
  188. @image_comparison(['EventCollection_plot__set_linelength.png'])
  189. def test__EventCollection__set_linelength():
  190. splt, coll, props = generate_EventCollection_plot()
  191. new_linelength = 15
  192. coll.set_linelength(new_linelength)
  193. assert new_linelength == coll.get_linelength()
  194. check_segments(coll,
  195. props['positions'],
  196. new_linelength,
  197. props['lineoffset'],
  198. props['orientation'])
  199. splt.set_title('EventCollection: set_linelength')
  200. splt.set_ylim(-20, 20)
  201. @image_comparison(['EventCollection_plot__set_lineoffset.png'])
  202. def test__EventCollection__set_lineoffset():
  203. splt, coll, props = generate_EventCollection_plot()
  204. new_lineoffset = -5.
  205. coll.set_lineoffset(new_lineoffset)
  206. assert new_lineoffset == coll.get_lineoffset()
  207. check_segments(coll,
  208. props['positions'],
  209. props['linelength'],
  210. new_lineoffset,
  211. props['orientation'])
  212. splt.set_title('EventCollection: set_lineoffset')
  213. splt.set_ylim(-6, -4)
  214. @image_comparison([
  215. 'EventCollection_plot__set_linestyle.png',
  216. 'EventCollection_plot__set_linestyle.png',
  217. 'EventCollection_plot__set_linewidth.png',
  218. ])
  219. def test__EventCollection__set_prop():
  220. for prop, value, expected in [
  221. ('linestyle', 'dashed', [(0, (6.0, 6.0))]),
  222. ('linestyle', (0, (6., 6.)), [(0, (6.0, 6.0))]),
  223. ('linewidth', 5, 5),
  224. ]:
  225. splt, coll, _ = generate_EventCollection_plot()
  226. coll.set(**{prop: value})
  227. assert plt.getp(coll, prop) == expected
  228. splt.set_title(f'EventCollection: set_{prop}')
  229. @image_comparison(['EventCollection_plot__set_color.png'])
  230. def test__EventCollection__set_color():
  231. splt, coll, _ = generate_EventCollection_plot()
  232. new_color = np.array([0, 1, 1, 1])
  233. coll.set_color(new_color)
  234. for color in [coll.get_color(), *coll.get_colors()]:
  235. np.testing.assert_array_equal(color, new_color)
  236. splt.set_title('EventCollection: set_color')
  237. def check_segments(coll, positions, linelength, lineoffset, orientation):
  238. """
  239. Test helper checking that all values in the segment are correct, given a
  240. particular set of inputs.
  241. """
  242. segments = coll.get_segments()
  243. if (orientation.lower() == 'horizontal'
  244. or orientation.lower() == 'none' or orientation is None):
  245. # if horizontal, the position in is in the y-axis
  246. pos1 = 1
  247. pos2 = 0
  248. elif orientation.lower() == 'vertical':
  249. # if vertical, the position in is in the x-axis
  250. pos1 = 0
  251. pos2 = 1
  252. else:
  253. raise ValueError("orientation must be 'horizontal' or 'vertical'")
  254. # test to make sure each segment is correct
  255. for i, segment in enumerate(segments):
  256. assert segment[0, pos1] == lineoffset + linelength / 2
  257. assert segment[1, pos1] == lineoffset - linelength / 2
  258. assert segment[0, pos2] == positions[i]
  259. assert segment[1, pos2] == positions[i]
  260. def test_collection_norm_autoscale():
  261. # norm should be autoscaled when array is set, not deferred to draw time
  262. lines = np.arange(24).reshape((4, 3, 2))
  263. coll = mcollections.LineCollection(lines, array=np.arange(4))
  264. assert coll.norm(2) == 2 / 3
  265. # setting a new array shouldn't update the already scaled limits
  266. coll.set_array(np.arange(4) + 5)
  267. assert coll.norm(2) == 2 / 3
  268. def test_null_collection_datalim():
  269. col = mcollections.PathCollection([])
  270. col_data_lim = col.get_datalim(mtransforms.IdentityTransform())
  271. assert_array_equal(col_data_lim.get_points(),
  272. mtransforms.Bbox.null().get_points())
  273. def test_no_offsets_datalim():
  274. # A collection with no offsets and a non transData
  275. # transform should return a null bbox
  276. ax = plt.axes()
  277. coll = mcollections.PathCollection([mpath.Path([(0, 0), (1, 0)])])
  278. ax.add_collection(coll)
  279. coll_data_lim = coll.get_datalim(mtransforms.IdentityTransform())
  280. assert_array_equal(coll_data_lim.get_points(),
  281. mtransforms.Bbox.null().get_points())
  282. def test_add_collection():
  283. # Test if data limits are unchanged by adding an empty collection.
  284. # GitHub issue #1490, pull #1497.
  285. plt.figure()
  286. ax = plt.axes()
  287. ax.scatter([0, 1], [0, 1])
  288. bounds = ax.dataLim.bounds
  289. ax.scatter([], [])
  290. assert ax.dataLim.bounds == bounds
  291. @mpl.style.context('mpl20')
  292. @check_figures_equal(extensions=['png'])
  293. def test_collection_log_datalim(fig_test, fig_ref):
  294. # Data limits should respect the minimum x/y when using log scale.
  295. x_vals = [4.38462e-6, 5.54929e-6, 7.02332e-6, 8.88889e-6, 1.12500e-5,
  296. 1.42383e-5, 1.80203e-5, 2.28070e-5, 2.88651e-5, 3.65324e-5,
  297. 4.62363e-5, 5.85178e-5, 7.40616e-5, 9.37342e-5, 1.18632e-4]
  298. y_vals = [0.0, 0.1, 0.182, 0.332, 0.604, 1.1, 2.0, 3.64, 6.64, 12.1, 22.0,
  299. 39.6, 71.3]
  300. x, y = np.meshgrid(x_vals, y_vals)
  301. x = x.flatten()
  302. y = y.flatten()
  303. ax_test = fig_test.subplots()
  304. ax_test.set_xscale('log')
  305. ax_test.set_yscale('log')
  306. ax_test.margins = 0
  307. ax_test.scatter(x, y)
  308. ax_ref = fig_ref.subplots()
  309. ax_ref.set_xscale('log')
  310. ax_ref.set_yscale('log')
  311. ax_ref.plot(x, y, marker="o", ls="")
  312. def test_quiver_limits():
  313. ax = plt.axes()
  314. x, y = np.arange(8), np.arange(10)
  315. u = v = np.linspace(0, 10, 80).reshape(10, 8)
  316. q = plt.quiver(x, y, u, v)
  317. assert q.get_datalim(ax.transData).bounds == (0., 0., 7., 9.)
  318. plt.figure()
  319. ax = plt.axes()
  320. x = np.linspace(-5, 10, 20)
  321. y = np.linspace(-2, 4, 10)
  322. y, x = np.meshgrid(y, x)
  323. trans = mtransforms.Affine2D().translate(25, 32) + ax.transData
  324. plt.quiver(x, y, np.sin(x), np.cos(y), transform=trans)
  325. assert ax.dataLim.bounds == (20.0, 30.0, 15.0, 6.0)
  326. def test_barb_limits():
  327. ax = plt.axes()
  328. x = np.linspace(-5, 10, 20)
  329. y = np.linspace(-2, 4, 10)
  330. y, x = np.meshgrid(y, x)
  331. trans = mtransforms.Affine2D().translate(25, 32) + ax.transData
  332. plt.barbs(x, y, np.sin(x), np.cos(y), transform=trans)
  333. # The calculated bounds are approximately the bounds of the original data,
  334. # this is because the entire path is taken into account when updating the
  335. # datalim.
  336. assert_array_almost_equal(ax.dataLim.bounds, (20, 30, 15, 6),
  337. decimal=1)
  338. @image_comparison(['EllipseCollection_test_image.png'], remove_text=True,
  339. tol=0 if platform.machine() == 'x86_64' else 0.021)
  340. def test_EllipseCollection():
  341. # Test basic functionality
  342. fig, ax = plt.subplots()
  343. x = np.arange(4)
  344. y = np.arange(3)
  345. X, Y = np.meshgrid(x, y)
  346. XY = np.vstack((X.ravel(), Y.ravel())).T
  347. ww = X / x[-1]
  348. hh = Y / y[-1]
  349. aa = np.ones_like(ww) * 20 # first axis is 20 degrees CCW from x axis
  350. ec = mcollections.EllipseCollection(
  351. ww, hh, aa, units='x', offsets=XY, offset_transform=ax.transData,
  352. facecolors='none')
  353. ax.add_collection(ec)
  354. ax.autoscale_view()
  355. def test_EllipseCollection_setter_getter():
  356. # Test widths, heights and angle setter
  357. rng = np.random.default_rng(0)
  358. widths = (2, )
  359. heights = (3, )
  360. angles = (45, )
  361. offsets = rng.random((10, 2)) * 10
  362. fig, ax = plt.subplots()
  363. ec = mcollections.EllipseCollection(
  364. widths=widths,
  365. heights=heights,
  366. angles=angles,
  367. offsets=offsets,
  368. units='x',
  369. offset_transform=ax.transData,
  370. )
  371. assert_array_almost_equal(ec._widths, np.array(widths).ravel() * 0.5)
  372. assert_array_almost_equal(ec._heights, np.array(heights).ravel() * 0.5)
  373. assert_array_almost_equal(ec._angles, np.deg2rad(angles).ravel())
  374. assert_array_almost_equal(ec.get_widths(), widths)
  375. assert_array_almost_equal(ec.get_heights(), heights)
  376. assert_array_almost_equal(ec.get_angles(), angles)
  377. ax.add_collection(ec)
  378. ax.set_xlim(-2, 12)
  379. ax.set_ylim(-2, 12)
  380. new_widths = rng.random((10, 2)) * 2
  381. new_heights = rng.random((10, 2)) * 3
  382. new_angles = rng.random((10, 2)) * 180
  383. ec.set(widths=new_widths, heights=new_heights, angles=new_angles)
  384. assert_array_almost_equal(ec.get_widths(), new_widths.ravel())
  385. assert_array_almost_equal(ec.get_heights(), new_heights.ravel())
  386. assert_array_almost_equal(ec.get_angles(), new_angles.ravel())
  387. @image_comparison(['polycollection_close.png'], remove_text=True, style='mpl20')
  388. def test_polycollection_close():
  389. from mpl_toolkits.mplot3d import Axes3D # type: ignore[import]
  390. plt.rcParams['axes3d.automargin'] = True
  391. vertsQuad = [
  392. [[0., 0.], [0., 1.], [1., 1.], [1., 0.]],
  393. [[0., 1.], [2., 3.], [2., 2.], [1., 1.]],
  394. [[2., 2.], [2., 3.], [4., 1.], [3., 1.]],
  395. [[3., 0.], [3., 1.], [4., 1.], [4., 0.]]]
  396. fig = plt.figure()
  397. ax = fig.add_axes(Axes3D(fig))
  398. colors = ['r', 'g', 'b', 'y', 'k']
  399. zpos = list(range(5))
  400. poly = mcollections.PolyCollection(
  401. vertsQuad * len(zpos), linewidth=0.25)
  402. poly.set_alpha(0.7)
  403. # need to have a z-value for *each* polygon = element!
  404. zs = []
  405. cs = []
  406. for z, c in zip(zpos, colors):
  407. zs.extend([z] * len(vertsQuad))
  408. cs.extend([c] * len(vertsQuad))
  409. poly.set_color(cs)
  410. ax.add_collection3d(poly, zs=zs, zdir='y')
  411. # axis limit settings:
  412. ax.set_xlim3d(0, 4)
  413. ax.set_zlim3d(0, 3)
  414. ax.set_ylim3d(0, 4)
  415. @image_comparison(['regularpolycollection_rotate.png'], remove_text=True)
  416. def test_regularpolycollection_rotate():
  417. xx, yy = np.mgrid[:10, :10]
  418. xy_points = np.transpose([xx.flatten(), yy.flatten()])
  419. rotations = np.linspace(0, 2*np.pi, len(xy_points))
  420. fig, ax = plt.subplots()
  421. for xy, alpha in zip(xy_points, rotations):
  422. col = mcollections.RegularPolyCollection(
  423. 4, sizes=(100,), rotation=alpha,
  424. offsets=[xy], offset_transform=ax.transData)
  425. ax.add_collection(col, autolim=True)
  426. ax.autoscale_view()
  427. @image_comparison(['regularpolycollection_scale.png'], remove_text=True)
  428. def test_regularpolycollection_scale():
  429. # See issue #3860
  430. class SquareCollection(mcollections.RegularPolyCollection):
  431. def __init__(self, **kwargs):
  432. super().__init__(4, rotation=np.pi/4., **kwargs)
  433. def get_transform(self):
  434. """Return transform scaling circle areas to data space."""
  435. ax = self.axes
  436. pts2pixels = 72.0 / ax.get_figure(root=True).dpi
  437. scale_x = pts2pixels * ax.bbox.width / ax.viewLim.width
  438. scale_y = pts2pixels * ax.bbox.height / ax.viewLim.height
  439. return mtransforms.Affine2D().scale(scale_x, scale_y)
  440. fig, ax = plt.subplots()
  441. xy = [(0, 0)]
  442. # Unit square has a half-diagonal of `1/sqrt(2)`, so `pi * r**2` equals...
  443. circle_areas = [np.pi / 2]
  444. squares = SquareCollection(
  445. sizes=circle_areas, offsets=xy, offset_transform=ax.transData)
  446. ax.add_collection(squares, autolim=True)
  447. ax.axis([-1, 1, -1, 1])
  448. def test_picking():
  449. fig, ax = plt.subplots()
  450. col = ax.scatter([0], [0], [1000], picker=True)
  451. fig.savefig(io.BytesIO(), dpi=fig.dpi)
  452. mouse_event = SimpleNamespace(x=325, y=240)
  453. found, indices = col.contains(mouse_event)
  454. assert found
  455. assert_array_equal(indices['ind'], [0])
  456. def test_quadmesh_contains():
  457. x = np.arange(4)
  458. X = x[:, None] * x[None, :]
  459. fig, ax = plt.subplots()
  460. mesh = ax.pcolormesh(X)
  461. fig.draw_without_rendering()
  462. xdata, ydata = 0.5, 0.5
  463. x, y = mesh.get_transform().transform((xdata, ydata))
  464. mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y)
  465. found, indices = mesh.contains(mouse_event)
  466. assert found
  467. assert_array_equal(indices['ind'], [0])
  468. xdata, ydata = 1.5, 1.5
  469. x, y = mesh.get_transform().transform((xdata, ydata))
  470. mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y)
  471. found, indices = mesh.contains(mouse_event)
  472. assert found
  473. assert_array_equal(indices['ind'], [5])
  474. def test_quadmesh_contains_concave():
  475. # Test a concave polygon, V-like shape
  476. x = [[0, -1], [1, 0]]
  477. y = [[0, 1], [1, -1]]
  478. fig, ax = plt.subplots()
  479. mesh = ax.pcolormesh(x, y, [[0]])
  480. fig.draw_without_rendering()
  481. # xdata, ydata, expected
  482. points = [(-0.5, 0.25, True), # left wing
  483. (0, 0.25, False), # between the two wings
  484. (0.5, 0.25, True), # right wing
  485. (0, -0.25, True), # main body
  486. ]
  487. for point in points:
  488. xdata, ydata, expected = point
  489. x, y = mesh.get_transform().transform((xdata, ydata))
  490. mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y)
  491. found, indices = mesh.contains(mouse_event)
  492. assert found is expected
  493. def test_quadmesh_cursor_data():
  494. x = np.arange(4)
  495. X = x[:, None] * x[None, :]
  496. fig, ax = plt.subplots()
  497. mesh = ax.pcolormesh(X)
  498. # Empty array data
  499. mesh._A = None
  500. fig.draw_without_rendering()
  501. xdata, ydata = 0.5, 0.5
  502. x, y = mesh.get_transform().transform((xdata, ydata))
  503. mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y)
  504. # Empty collection should return None
  505. assert mesh.get_cursor_data(mouse_event) is None
  506. # Now test adding the array data, to make sure we do get a value
  507. mesh.set_array(np.ones(X.shape))
  508. assert_array_equal(mesh.get_cursor_data(mouse_event), [1])
  509. def test_quadmesh_cursor_data_multiple_points():
  510. x = [1, 2, 1, 2]
  511. fig, ax = plt.subplots()
  512. mesh = ax.pcolormesh(x, x, np.ones((3, 3)))
  513. fig.draw_without_rendering()
  514. xdata, ydata = 1.5, 1.5
  515. x, y = mesh.get_transform().transform((xdata, ydata))
  516. mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y)
  517. # All quads are covering the same square
  518. assert_array_equal(mesh.get_cursor_data(mouse_event), np.ones(9))
  519. def test_linestyle_single_dashes():
  520. plt.scatter([0, 1, 2], [0, 1, 2], linestyle=(0., [2., 2.]))
  521. plt.draw()
  522. @image_comparison(['size_in_xy.png'], remove_text=True)
  523. def test_size_in_xy():
  524. fig, ax = plt.subplots()
  525. widths, heights, angles = (10, 10), 10, 0
  526. widths = 10, 10
  527. coords = [(10, 10), (15, 15)]
  528. e = mcollections.EllipseCollection(
  529. widths, heights, angles, units='xy',
  530. offsets=coords, offset_transform=ax.transData)
  531. ax.add_collection(e)
  532. ax.set_xlim(0, 30)
  533. ax.set_ylim(0, 30)
  534. def test_pandas_indexing(pd):
  535. # Should not fail break when faced with a
  536. # non-zero indexed series
  537. index = [11, 12, 13]
  538. ec = fc = pd.Series(['red', 'blue', 'green'], index=index)
  539. lw = pd.Series([1, 2, 3], index=index)
  540. ls = pd.Series(['solid', 'dashed', 'dashdot'], index=index)
  541. aa = pd.Series([True, False, True], index=index)
  542. Collection(edgecolors=ec)
  543. Collection(facecolors=fc)
  544. Collection(linewidths=lw)
  545. Collection(linestyles=ls)
  546. Collection(antialiaseds=aa)
  547. @mpl.style.context('default')
  548. def test_lslw_bcast():
  549. col = mcollections.PathCollection([])
  550. col.set_linestyles(['-', '-'])
  551. col.set_linewidths([1, 2, 3])
  552. assert col.get_linestyles() == [(0, None)] * 6
  553. assert col.get_linewidths() == [1, 2, 3] * 2
  554. col.set_linestyles(['-', '-', '-'])
  555. assert col.get_linestyles() == [(0, None)] * 3
  556. assert (col.get_linewidths() == [1, 2, 3]).all()
  557. def test_set_wrong_linestyle():
  558. c = Collection()
  559. with pytest.raises(ValueError, match="Do not know how to convert 'fuzzy'"):
  560. c.set_linestyle('fuzzy')
  561. @mpl.style.context('default')
  562. def test_capstyle():
  563. col = mcollections.PathCollection([])
  564. assert col.get_capstyle() is None
  565. col = mcollections.PathCollection([], capstyle='round')
  566. assert col.get_capstyle() == 'round'
  567. col.set_capstyle('butt')
  568. assert col.get_capstyle() == 'butt'
  569. @mpl.style.context('default')
  570. def test_joinstyle():
  571. col = mcollections.PathCollection([])
  572. assert col.get_joinstyle() is None
  573. col = mcollections.PathCollection([], joinstyle='round')
  574. assert col.get_joinstyle() == 'round'
  575. col.set_joinstyle('miter')
  576. assert col.get_joinstyle() == 'miter'
  577. @image_comparison(['cap_and_joinstyle.png'])
  578. def test_cap_and_joinstyle_image():
  579. fig, ax = plt.subplots()
  580. ax.set_xlim([-0.5, 1.5])
  581. ax.set_ylim([-0.5, 2.5])
  582. x = np.array([0.0, 1.0, 0.5])
  583. ys = np.array([[0.0], [0.5], [1.0]]) + np.array([[0.0, 0.0, 1.0]])
  584. segs = np.zeros((3, 3, 2))
  585. segs[:, :, 0] = x
  586. segs[:, :, 1] = ys
  587. line_segments = LineCollection(segs, linewidth=[10, 15, 20])
  588. line_segments.set_capstyle("round")
  589. line_segments.set_joinstyle("miter")
  590. ax.add_collection(line_segments)
  591. ax.set_title('Line collection with customized caps and joinstyle')
  592. @image_comparison(['scatter_post_alpha.png'],
  593. remove_text=True, style='default')
  594. def test_scatter_post_alpha():
  595. fig, ax = plt.subplots()
  596. sc = ax.scatter(range(5), range(5), c=range(5))
  597. sc.set_alpha(.1)
  598. def test_scatter_alpha_array():
  599. x = np.arange(5)
  600. alpha = x / 5
  601. # With colormapping.
  602. fig, (ax0, ax1) = plt.subplots(2)
  603. sc0 = ax0.scatter(x, x, c=x, alpha=alpha)
  604. sc1 = ax1.scatter(x, x, c=x)
  605. sc1.set_alpha(alpha)
  606. plt.draw()
  607. assert_array_equal(sc0.get_facecolors()[:, -1], alpha)
  608. assert_array_equal(sc1.get_facecolors()[:, -1], alpha)
  609. # Without colormapping.
  610. fig, (ax0, ax1) = plt.subplots(2)
  611. sc0 = ax0.scatter(x, x, color=['r', 'g', 'b', 'c', 'm'], alpha=alpha)
  612. sc1 = ax1.scatter(x, x, color='r', alpha=alpha)
  613. plt.draw()
  614. assert_array_equal(sc0.get_facecolors()[:, -1], alpha)
  615. assert_array_equal(sc1.get_facecolors()[:, -1], alpha)
  616. # Without colormapping, and set alpha afterward.
  617. fig, (ax0, ax1) = plt.subplots(2)
  618. sc0 = ax0.scatter(x, x, color=['r', 'g', 'b', 'c', 'm'])
  619. sc0.set_alpha(alpha)
  620. sc1 = ax1.scatter(x, x, color='r')
  621. sc1.set_alpha(alpha)
  622. plt.draw()
  623. assert_array_equal(sc0.get_facecolors()[:, -1], alpha)
  624. assert_array_equal(sc1.get_facecolors()[:, -1], alpha)
  625. def test_pathcollection_legend_elements():
  626. np.random.seed(19680801)
  627. x, y = np.random.rand(2, 10)
  628. y = np.random.rand(10)
  629. c = np.random.randint(0, 5, size=10)
  630. s = np.random.randint(10, 300, size=10)
  631. fig, ax = plt.subplots()
  632. sc = ax.scatter(x, y, c=c, s=s, cmap="jet", marker="o", linewidths=0)
  633. h, l = sc.legend_elements(fmt="{x:g}")
  634. assert len(h) == 5
  635. assert l == ["0", "1", "2", "3", "4"]
  636. colors = np.array([line.get_color() for line in h])
  637. colors2 = sc.cmap(np.arange(5)/4)
  638. assert_array_equal(colors, colors2)
  639. l1 = ax.legend(h, l, loc=1)
  640. h2, lab2 = sc.legend_elements(num=9)
  641. assert len(h2) == 9
  642. l2 = ax.legend(h2, lab2, loc=2)
  643. h, l = sc.legend_elements(prop="sizes", alpha=0.5, color="red")
  644. assert all(line.get_alpha() == 0.5 for line in h)
  645. assert all(line.get_markerfacecolor() == "red" for line in h)
  646. l3 = ax.legend(h, l, loc=4)
  647. h, l = sc.legend_elements(prop="sizes", num=4, fmt="{x:.2f}",
  648. func=lambda x: 2*x)
  649. actsizes = [line.get_markersize() for line in h]
  650. labeledsizes = np.sqrt(np.array(l, float) / 2)
  651. assert_array_almost_equal(actsizes, labeledsizes)
  652. l4 = ax.legend(h, l, loc=3)
  653. loc = mpl.ticker.MaxNLocator(nbins=9, min_n_ticks=9-1,
  654. steps=[1, 2, 2.5, 3, 5, 6, 8, 10])
  655. h5, lab5 = sc.legend_elements(num=loc)
  656. assert len(h2) == len(h5)
  657. levels = [-1, 0, 55.4, 260]
  658. h6, lab6 = sc.legend_elements(num=levels, prop="sizes", fmt="{x:g}")
  659. assert [float(l) for l in lab6] == levels[2:]
  660. for l in [l1, l2, l3, l4]:
  661. ax.add_artist(l)
  662. fig.canvas.draw()
  663. def test_EventCollection_nosort():
  664. # Check that EventCollection doesn't modify input in place
  665. arr = np.array([3, 2, 1, 10])
  666. coll = EventCollection(arr)
  667. np.testing.assert_array_equal(arr, np.array([3, 2, 1, 10]))
  668. def test_collection_set_verts_array():
  669. verts = np.arange(80, dtype=np.double).reshape(10, 4, 2)
  670. col_arr = PolyCollection(verts)
  671. col_list = PolyCollection(list(verts))
  672. assert len(col_arr._paths) == len(col_list._paths)
  673. for ap, lp in zip(col_arr._paths, col_list._paths):
  674. assert np.array_equal(ap._vertices, lp._vertices)
  675. assert np.array_equal(ap._codes, lp._codes)
  676. verts_tuple = np.empty(10, dtype=object)
  677. verts_tuple[:] = [tuple(tuple(y) for y in x) for x in verts]
  678. col_arr_tuple = PolyCollection(verts_tuple)
  679. assert len(col_arr._paths) == len(col_arr_tuple._paths)
  680. for ap, atp in zip(col_arr._paths, col_arr_tuple._paths):
  681. assert np.array_equal(ap._vertices, atp._vertices)
  682. assert np.array_equal(ap._codes, atp._codes)
  683. @check_figures_equal(extensions=["png"])
  684. @pytest.mark.parametrize("kwargs", [{}, {"step": "pre"}])
  685. def test_fill_between_poly_collection_set_data(fig_test, fig_ref, kwargs):
  686. t = np.linspace(0, 16)
  687. f1 = np.sin(t)
  688. f2 = f1 + 0.2
  689. fig_ref.subplots().fill_between(t, f1, f2, **kwargs)
  690. coll = fig_test.subplots().fill_between(t, -1, 1.2, **kwargs)
  691. coll.set_data(t, f1, f2)
  692. @pytest.mark.parametrize(("t_direction", "f1", "shape", "where", "msg"), [
  693. ("z", None, None, None, r"t_direction must be 'x' or 'y', got 'z'"),
  694. ("x", None, (-1, 1), None, r"'x' is not 1-dimensional"),
  695. ("x", None, None, [False] * 3, r"where size \(3\) does not match 'x' size \(\d+\)"),
  696. ("y", [1, 2], None, None, r"'y' has size \d+, but 'x1' has an unequal size of \d+"),
  697. ])
  698. def test_fill_between_poly_collection_raise(t_direction, f1, shape, where, msg):
  699. t = np.linspace(0, 16)
  700. f1 = np.sin(t) if f1 is None else np.asarray(f1)
  701. f2 = f1 + 0.2
  702. if shape:
  703. t = t.reshape(*shape)
  704. with pytest.raises(ValueError, match=msg):
  705. FillBetweenPolyCollection(t_direction, t, f1, f2, where=where)
  706. def test_collection_set_array():
  707. vals = [*range(10)]
  708. # Test set_array with list
  709. c = Collection()
  710. c.set_array(vals)
  711. # Test set_array with wrong dtype
  712. with pytest.raises(TypeError, match="^Image data of dtype"):
  713. c.set_array("wrong_input")
  714. # Test if array kwarg is copied
  715. vals[5] = 45
  716. assert np.not_equal(vals, c.get_array()).any()
  717. def test_blended_collection_autolim():
  718. f, ax = plt.subplots()
  719. # sample data to give initial data limits
  720. ax.plot([2, 3, 4], [0.4, 0.6, 0.5])
  721. np.testing.assert_allclose((ax.dataLim.xmin, ax.dataLim.xmax), (2, 4))
  722. data_ymin, data_ymax = ax.dataLim.ymin, ax.dataLim.ymax
  723. # LineCollection with vertical lines spanning the Axes vertical, using transAxes
  724. x = [1, 2, 3, 4, 5]
  725. vertical_lines = [np.array([[xi, 0], [xi, 1]]) for xi in x]
  726. trans = mtransforms.blended_transform_factory(ax.transData, ax.transAxes)
  727. ax.add_collection(LineCollection(vertical_lines, transform=trans))
  728. # check that the x data limits are updated to include the LineCollection
  729. np.testing.assert_allclose((ax.dataLim.xmin, ax.dataLim.xmax), (1, 5))
  730. # check that the y data limits are not updated (because they are not transData)
  731. np.testing.assert_allclose((ax.dataLim.ymin, ax.dataLim.ymax),
  732. (data_ymin, data_ymax))
  733. def test_singleton_autolim():
  734. fig, ax = plt.subplots()
  735. ax.scatter(0, 0)
  736. np.testing.assert_allclose(ax.get_ylim(), [-0.06, 0.06])
  737. np.testing.assert_allclose(ax.get_xlim(), [-0.06, 0.06])
  738. @pytest.mark.parametrize("transform, expected", [
  739. ("transData", (-0.5, 3.5)),
  740. ("transAxes", (2.8, 3.2)),
  741. ])
  742. def test_autolim_with_zeros(transform, expected):
  743. # 1) Test that a scatter at (0, 0) data coordinates contributes to
  744. # autoscaling even though any(offsets) would be False in that situation.
  745. # 2) Test that specifying transAxes for the transform does not contribute
  746. # to the autoscaling.
  747. fig, ax = plt.subplots()
  748. ax.scatter(0, 0, transform=getattr(ax, transform))
  749. ax.scatter(3, 3)
  750. np.testing.assert_allclose(ax.get_ylim(), expected)
  751. np.testing.assert_allclose(ax.get_xlim(), expected)
  752. def test_quadmesh_set_array_validation(pcfunc):
  753. x = np.arange(11)
  754. y = np.arange(8)
  755. z = np.random.random((7, 10))
  756. fig, ax = plt.subplots()
  757. coll = getattr(ax, pcfunc)(x, y, z)
  758. with pytest.raises(ValueError, match=re.escape(
  759. "For X (11) and Y (8) with flat shading, A should have shape "
  760. "(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (10, 7)")):
  761. coll.set_array(z.reshape(10, 7))
  762. z = np.arange(54).reshape((6, 9))
  763. with pytest.raises(ValueError, match=re.escape(
  764. "For X (11) and Y (8) with flat shading, A should have shape "
  765. "(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (6, 9)")):
  766. coll.set_array(z)
  767. with pytest.raises(ValueError, match=re.escape(
  768. "For X (11) and Y (8) with flat shading, A should have shape "
  769. "(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (54,)")):
  770. coll.set_array(z.ravel())
  771. # RGB(A) tests
  772. z = np.ones((9, 6, 3)) # RGB with wrong X/Y dims
  773. with pytest.raises(ValueError, match=re.escape(
  774. "For X (11) and Y (8) with flat shading, A should have shape "
  775. "(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (9, 6, 3)")):
  776. coll.set_array(z)
  777. z = np.ones((9, 6, 4)) # RGBA with wrong X/Y dims
  778. with pytest.raises(ValueError, match=re.escape(
  779. "For X (11) and Y (8) with flat shading, A should have shape "
  780. "(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (9, 6, 4)")):
  781. coll.set_array(z)
  782. z = np.ones((7, 10, 2)) # Right X/Y dims, bad 3rd dim
  783. with pytest.raises(ValueError, match=re.escape(
  784. "For X (11) and Y (8) with flat shading, A should have shape "
  785. "(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (7, 10, 2)")):
  786. coll.set_array(z)
  787. x = np.arange(10)
  788. y = np.arange(7)
  789. z = np.random.random((7, 10))
  790. fig, ax = plt.subplots()
  791. coll = ax.pcolormesh(x, y, z, shading='gouraud')
  792. def test_polyquadmesh_masked_vertices_array():
  793. xx, yy = np.meshgrid([0, 1, 2], [0, 1, 2, 3])
  794. # 2 x 3 mesh data
  795. zz = (xx*yy)[:-1, :-1]
  796. quadmesh = plt.pcolormesh(xx, yy, zz)
  797. quadmesh.update_scalarmappable()
  798. quadmesh_fc = quadmesh.get_facecolor()[1:, :]
  799. # Mask the origin vertex in x
  800. xx = np.ma.masked_where((xx == 0) & (yy == 0), xx)
  801. polymesh = plt.pcolor(xx, yy, zz)
  802. polymesh.update_scalarmappable()
  803. # One cell should be left out
  804. assert len(polymesh.get_paths()) == 5
  805. # Poly version should have the same facecolors as the end of the quadmesh
  806. assert_array_equal(quadmesh_fc, polymesh.get_facecolor())
  807. # Mask the origin vertex in y
  808. yy = np.ma.masked_where((xx == 0) & (yy == 0), yy)
  809. polymesh = plt.pcolor(xx, yy, zz)
  810. polymesh.update_scalarmappable()
  811. # One cell should be left out
  812. assert len(polymesh.get_paths()) == 5
  813. # Poly version should have the same facecolors as the end of the quadmesh
  814. assert_array_equal(quadmesh_fc, polymesh.get_facecolor())
  815. # Mask the origin cell data
  816. zz = np.ma.masked_where((xx[:-1, :-1] == 0) & (yy[:-1, :-1] == 0), zz)
  817. polymesh = plt.pcolor(zz)
  818. polymesh.update_scalarmappable()
  819. # One cell should be left out
  820. assert len(polymesh.get_paths()) == 5
  821. # Poly version should have the same facecolors as the end of the quadmesh
  822. assert_array_equal(quadmesh_fc, polymesh.get_facecolor())
  823. # We should also be able to call set_array with a new mask and get
  824. # updated polys
  825. # Remove mask, should add all polys back
  826. zz = np.arange(6).reshape((3, 2))
  827. polymesh.set_array(zz)
  828. polymesh.update_scalarmappable()
  829. assert len(polymesh.get_paths()) == 6
  830. # Add mask should remove polys
  831. zz = np.ma.masked_less(zz, 2)
  832. polymesh.set_array(zz)
  833. polymesh.update_scalarmappable()
  834. assert len(polymesh.get_paths()) == 4
  835. def test_quadmesh_get_coordinates(pcfunc):
  836. x = [0, 1, 2]
  837. y = [2, 4, 6]
  838. z = np.ones(shape=(2, 2))
  839. xx, yy = np.meshgrid(x, y)
  840. coll = getattr(plt, pcfunc)(xx, yy, z)
  841. # shape (3, 3, 2)
  842. coords = np.stack([xx.T, yy.T]).T
  843. assert_array_equal(coll.get_coordinates(), coords)
  844. def test_quadmesh_set_array():
  845. x = np.arange(4)
  846. y = np.arange(4)
  847. z = np.arange(9).reshape((3, 3))
  848. fig, ax = plt.subplots()
  849. coll = ax.pcolormesh(x, y, np.ones(z.shape))
  850. # Test that the collection is able to update with a 2d array
  851. coll.set_array(z)
  852. fig.canvas.draw()
  853. assert np.array_equal(coll.get_array(), z)
  854. # Check that pre-flattened arrays work too
  855. coll.set_array(np.ones(9))
  856. fig.canvas.draw()
  857. assert np.array_equal(coll.get_array(), np.ones(9))
  858. z = np.arange(16).reshape((4, 4))
  859. fig, ax = plt.subplots()
  860. coll = ax.pcolormesh(x, y, np.ones(z.shape), shading='gouraud')
  861. # Test that the collection is able to update with a 2d array
  862. coll.set_array(z)
  863. fig.canvas.draw()
  864. assert np.array_equal(coll.get_array(), z)
  865. # Check that pre-flattened arrays work too
  866. coll.set_array(np.ones(16))
  867. fig.canvas.draw()
  868. assert np.array_equal(coll.get_array(), np.ones(16))
  869. def test_quadmesh_vmin_vmax(pcfunc):
  870. # test when vmin/vmax on the norm changes, the quadmesh gets updated
  871. fig, ax = plt.subplots()
  872. cmap = mpl.colormaps['plasma']
  873. norm = mpl.colors.Normalize(vmin=0, vmax=1)
  874. coll = getattr(ax, pcfunc)([[1]], cmap=cmap, norm=norm)
  875. fig.canvas.draw()
  876. assert np.array_equal(coll.get_facecolors()[0, :], cmap(norm(1)))
  877. # Change the vmin/vmax of the norm so that the color is from
  878. # the bottom of the colormap now
  879. norm.vmin, norm.vmax = 1, 2
  880. fig.canvas.draw()
  881. assert np.array_equal(coll.get_facecolors()[0, :], cmap(norm(1)))
  882. def test_quadmesh_alpha_array(pcfunc):
  883. x = np.arange(4)
  884. y = np.arange(4)
  885. z = np.arange(9).reshape((3, 3))
  886. alpha = z / z.max()
  887. alpha_flat = alpha.ravel()
  888. # Provide 2-D alpha:
  889. fig, (ax0, ax1) = plt.subplots(2)
  890. coll1 = getattr(ax0, pcfunc)(x, y, z, alpha=alpha)
  891. coll2 = getattr(ax0, pcfunc)(x, y, z)
  892. coll2.set_alpha(alpha)
  893. plt.draw()
  894. assert_array_equal(coll1.get_facecolors()[:, -1], alpha_flat)
  895. assert_array_equal(coll2.get_facecolors()[:, -1], alpha_flat)
  896. # Or provide 1-D alpha:
  897. fig, (ax0, ax1) = plt.subplots(2)
  898. coll1 = getattr(ax0, pcfunc)(x, y, z, alpha=alpha)
  899. coll2 = getattr(ax1, pcfunc)(x, y, z)
  900. coll2.set_alpha(alpha)
  901. plt.draw()
  902. assert_array_equal(coll1.get_facecolors()[:, -1], alpha_flat)
  903. assert_array_equal(coll2.get_facecolors()[:, -1], alpha_flat)
  904. def test_alpha_validation(pcfunc):
  905. # Most of the relevant testing is in test_artist and test_colors.
  906. fig, ax = plt.subplots()
  907. pc = getattr(ax, pcfunc)(np.arange(12).reshape((3, 4)))
  908. with pytest.raises(ValueError, match="^Data array shape"):
  909. pc.set_alpha([0.5, 0.6])
  910. pc.update_scalarmappable()
  911. def test_legend_inverse_size_label_relationship():
  912. """
  913. Ensure legend markers scale appropriately when label and size are
  914. inversely related.
  915. Here label = 5 / size
  916. """
  917. np.random.seed(19680801)
  918. X = np.random.random(50)
  919. Y = np.random.random(50)
  920. C = 1 - np.random.random(50)
  921. S = 5 / C
  922. legend_sizes = [0.2, 0.4, 0.6, 0.8]
  923. fig, ax = plt.subplots()
  924. sc = ax.scatter(X, Y, s=S)
  925. handles, labels = sc.legend_elements(
  926. prop='sizes', num=legend_sizes, func=lambda s: 5 / s
  927. )
  928. # Convert markersize scale to 's' scale
  929. handle_sizes = [x.get_markersize() for x in handles]
  930. handle_sizes = [5 / x**2 for x in handle_sizes]
  931. assert_array_almost_equal(handle_sizes, legend_sizes, decimal=1)
  932. @mpl.style.context('default')
  933. def test_color_logic(pcfunc):
  934. pcfunc = getattr(plt, pcfunc)
  935. z = np.arange(12).reshape(3, 4)
  936. # Explicitly set an edgecolor.
  937. pc = pcfunc(z, edgecolors='red', facecolors='none')
  938. pc.update_scalarmappable() # This is called in draw().
  939. # Define 2 reference "colors" here for multiple use.
  940. face_default = mcolors.to_rgba_array(pc._get_default_facecolor())
  941. mapped = pc.get_cmap()(pc.norm(z.ravel()))
  942. # GitHub issue #1302:
  943. assert mcolors.same_color(pc.get_edgecolor(), 'red')
  944. # Check setting attributes after initialization:
  945. pc = pcfunc(z)
  946. pc.set_facecolor('none')
  947. pc.set_edgecolor('red')
  948. pc.update_scalarmappable()
  949. assert mcolors.same_color(pc.get_facecolor(), 'none')
  950. assert mcolors.same_color(pc.get_edgecolor(), [[1, 0, 0, 1]])
  951. pc.set_alpha(0.5)
  952. pc.update_scalarmappable()
  953. assert mcolors.same_color(pc.get_edgecolor(), [[1, 0, 0, 0.5]])
  954. pc.set_alpha(None) # restore default alpha
  955. pc.update_scalarmappable()
  956. assert mcolors.same_color(pc.get_edgecolor(), [[1, 0, 0, 1]])
  957. # Reset edgecolor to default.
  958. pc.set_edgecolor(None)
  959. pc.update_scalarmappable()
  960. assert np.array_equal(pc.get_edgecolor(), mapped)
  961. pc.set_facecolor(None) # restore default for facecolor
  962. pc.update_scalarmappable()
  963. assert np.array_equal(pc.get_facecolor(), mapped)
  964. assert mcolors.same_color(pc.get_edgecolor(), 'none')
  965. # Turn off colormapping entirely:
  966. pc.set_array(None)
  967. pc.update_scalarmappable()
  968. assert mcolors.same_color(pc.get_edgecolor(), 'none')
  969. assert mcolors.same_color(pc.get_facecolor(), face_default) # not mapped
  970. # Turn it back on by restoring the array (must be 1D!):
  971. pc.set_array(z)
  972. pc.update_scalarmappable()
  973. assert np.array_equal(pc.get_facecolor(), mapped)
  974. assert mcolors.same_color(pc.get_edgecolor(), 'none')
  975. # Give color via tuple rather than string.
  976. pc = pcfunc(z, edgecolors=(1, 0, 0), facecolors=(0, 1, 0))
  977. pc.update_scalarmappable()
  978. assert np.array_equal(pc.get_facecolor(), mapped)
  979. assert mcolors.same_color(pc.get_edgecolor(), [[1, 0, 0, 1]])
  980. # Provide an RGB array; mapping overrides it.
  981. pc = pcfunc(z, edgecolors=(1, 0, 0), facecolors=np.ones((12, 3)))
  982. pc.update_scalarmappable()
  983. assert np.array_equal(pc.get_facecolor(), mapped)
  984. assert mcolors.same_color(pc.get_edgecolor(), [[1, 0, 0, 1]])
  985. # Turn off the mapping.
  986. pc.set_array(None)
  987. pc.update_scalarmappable()
  988. assert mcolors.same_color(pc.get_facecolor(), np.ones((12, 3)))
  989. assert mcolors.same_color(pc.get_edgecolor(), [[1, 0, 0, 1]])
  990. # And an RGBA array.
  991. pc = pcfunc(z, edgecolors=(1, 0, 0), facecolors=np.ones((12, 4)))
  992. pc.update_scalarmappable()
  993. assert np.array_equal(pc.get_facecolor(), mapped)
  994. assert mcolors.same_color(pc.get_edgecolor(), [[1, 0, 0, 1]])
  995. # Turn off the mapping.
  996. pc.set_array(None)
  997. pc.update_scalarmappable()
  998. assert mcolors.same_color(pc.get_facecolor(), np.ones((12, 4)))
  999. assert mcolors.same_color(pc.get_edgecolor(), [[1, 0, 0, 1]])
  1000. def test_LineCollection_args():
  1001. lc = LineCollection(None, linewidth=2.2, edgecolor='r',
  1002. zorder=3, facecolors=[0, 1, 0, 1])
  1003. assert lc.get_linewidth()[0] == 2.2
  1004. assert mcolors.same_color(lc.get_edgecolor(), 'r')
  1005. assert lc.get_zorder() == 3
  1006. assert mcolors.same_color(lc.get_facecolor(), [[0, 1, 0, 1]])
  1007. # To avoid breaking mplot3d, LineCollection internally sets the facecolor
  1008. # kwarg if it has not been specified. Hence we need the following test
  1009. # for LineCollection._set_default().
  1010. lc = LineCollection(None, facecolor=None)
  1011. assert mcolors.same_color(lc.get_facecolor(), 'none')
  1012. def test_array_dimensions(pcfunc):
  1013. # Make sure we can set the 1D, 2D, and 3D array shapes
  1014. z = np.arange(12).reshape(3, 4)
  1015. pc = getattr(plt, pcfunc)(z)
  1016. # 1D
  1017. pc.set_array(z.ravel())
  1018. pc.update_scalarmappable()
  1019. # 2D
  1020. pc.set_array(z)
  1021. pc.update_scalarmappable()
  1022. # 3D RGB is OK as well
  1023. z = np.arange(36, dtype=np.uint8).reshape(3, 4, 3)
  1024. pc.set_array(z)
  1025. pc.update_scalarmappable()
  1026. def test_get_segments():
  1027. segments = np.tile(np.linspace(0, 1, 256), (2, 1)).T
  1028. lc = LineCollection([segments])
  1029. readback, = lc.get_segments()
  1030. # these should comeback un-changed!
  1031. assert np.all(segments == readback)
  1032. def test_set_offsets_late():
  1033. identity = mtransforms.IdentityTransform()
  1034. sizes = [2]
  1035. null = mcollections.CircleCollection(sizes=sizes)
  1036. init = mcollections.CircleCollection(sizes=sizes, offsets=(10, 10))
  1037. late = mcollections.CircleCollection(sizes=sizes)
  1038. late.set_offsets((10, 10))
  1039. # Bbox.__eq__ doesn't compare bounds
  1040. null_bounds = null.get_datalim(identity).bounds
  1041. init_bounds = init.get_datalim(identity).bounds
  1042. late_bounds = late.get_datalim(identity).bounds
  1043. # offsets and transform are applied when set after initialization
  1044. assert null_bounds != init_bounds
  1045. assert init_bounds == late_bounds
  1046. def test_set_offset_transform():
  1047. skew = mtransforms.Affine2D().skew(2, 2)
  1048. init = mcollections.Collection(offset_transform=skew)
  1049. late = mcollections.Collection()
  1050. late.set_offset_transform(skew)
  1051. assert skew == init.get_offset_transform() == late.get_offset_transform()
  1052. def test_set_offset_units():
  1053. # passing the offsets in initially (i.e. via scatter)
  1054. # should yield the same results as `set_offsets`
  1055. x = np.linspace(0, 10, 5)
  1056. y = np.sin(x)
  1057. d = x * np.timedelta64(24, 'h') + np.datetime64('2021-11-29')
  1058. sc = plt.scatter(d, y)
  1059. off0 = sc.get_offsets()
  1060. sc.set_offsets(list(zip(d, y)))
  1061. np.testing.assert_allclose(off0, sc.get_offsets())
  1062. # try the other way around
  1063. fig, ax = plt.subplots()
  1064. sc = ax.scatter(y, d)
  1065. off0 = sc.get_offsets()
  1066. sc.set_offsets(list(zip(y, d)))
  1067. np.testing.assert_allclose(off0, sc.get_offsets())
  1068. @image_comparison(baseline_images=["test_check_masked_offsets"],
  1069. extensions=["png"], remove_text=True, style="mpl20")
  1070. def test_check_masked_offsets():
  1071. # Check if masked data is respected by scatter
  1072. # Ref: Issue #24545
  1073. unmasked_x = [
  1074. datetime(2022, 12, 15, 4, 49, 52),
  1075. datetime(2022, 12, 15, 4, 49, 53),
  1076. datetime(2022, 12, 15, 4, 49, 54),
  1077. datetime(2022, 12, 15, 4, 49, 55),
  1078. datetime(2022, 12, 15, 4, 49, 56),
  1079. ]
  1080. masked_y = np.ma.array([1, 2, 3, 4, 5], mask=[0, 1, 1, 0, 0])
  1081. fig, ax = plt.subplots()
  1082. ax.scatter(unmasked_x, masked_y)
  1083. @check_figures_equal(extensions=["png"])
  1084. def test_masked_set_offsets(fig_ref, fig_test):
  1085. x = np.ma.array([1, 2, 3, 4, 5], mask=[0, 0, 1, 1, 0])
  1086. y = np.arange(1, 6)
  1087. ax_test = fig_test.add_subplot()
  1088. scat = ax_test.scatter(x, y)
  1089. scat.set_offsets(np.ma.column_stack([x, y]))
  1090. ax_test.set_xticks([])
  1091. ax_test.set_yticks([])
  1092. ax_ref = fig_ref.add_subplot()
  1093. ax_ref.scatter([1, 2, 5], [1, 2, 5])
  1094. ax_ref.set_xticks([])
  1095. ax_ref.set_yticks([])
  1096. def test_check_offsets_dtype():
  1097. # Check that setting offsets doesn't change dtype
  1098. x = np.ma.array([1, 2, 3, 4, 5], mask=[0, 0, 1, 1, 0])
  1099. y = np.arange(1, 6)
  1100. fig, ax = plt.subplots()
  1101. scat = ax.scatter(x, y)
  1102. masked_offsets = np.ma.column_stack([x, y])
  1103. scat.set_offsets(masked_offsets)
  1104. assert isinstance(scat.get_offsets(), type(masked_offsets))
  1105. unmasked_offsets = np.column_stack([x, y])
  1106. scat.set_offsets(unmasked_offsets)
  1107. assert isinstance(scat.get_offsets(), type(unmasked_offsets))
  1108. @pytest.mark.parametrize('gapcolor', ['orange', ['r', 'k']])
  1109. @check_figures_equal(extensions=['png'])
  1110. def test_striped_lines(fig_test, fig_ref, gapcolor):
  1111. ax_test = fig_test.add_subplot(111)
  1112. ax_ref = fig_ref.add_subplot(111)
  1113. for ax in [ax_test, ax_ref]:
  1114. ax.set_xlim(0, 6)
  1115. ax.set_ylim(0, 1)
  1116. x = range(1, 6)
  1117. linestyles = [':', '-', '--']
  1118. ax_test.vlines(x, 0, 1, linewidth=20, linestyle=linestyles, gapcolor=gapcolor,
  1119. alpha=0.5)
  1120. if isinstance(gapcolor, str):
  1121. gapcolor = [gapcolor]
  1122. for x, gcol, ls in zip(x, itertools.cycle(gapcolor),
  1123. itertools.cycle(linestyles)):
  1124. ax_ref.axvline(x, 0, 1, linewidth=20, linestyle=ls, gapcolor=gcol, alpha=0.5)
  1125. @check_figures_equal(extensions=['png', 'pdf', 'svg', 'eps'])
  1126. def test_hatch_linewidth(fig_test, fig_ref):
  1127. ax_test = fig_test.add_subplot()
  1128. ax_ref = fig_ref.add_subplot()
  1129. lw = 2.0
  1130. polygons = [
  1131. [(0.1, 0.1), (0.1, 0.4), (0.4, 0.4), (0.4, 0.1)],
  1132. [(0.6, 0.6), (0.6, 0.9), (0.9, 0.9), (0.9, 0.6)],
  1133. ]
  1134. ref = PolyCollection(polygons, hatch="x")
  1135. ref.set_hatch_linewidth(lw)
  1136. with mpl.rc_context({"hatch.linewidth": lw}):
  1137. test = PolyCollection(polygons, hatch="x")
  1138. ax_ref.add_collection(ref)
  1139. ax_test.add_collection(test)
  1140. assert test.get_hatch_linewidth() == ref.get_hatch_linewidth() == lw