test_table.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. import datetime
  2. from unittest.mock import Mock
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from matplotlib.path import Path
  6. from matplotlib.table import CustomCell, Table
  7. from matplotlib.testing.decorators import image_comparison, check_figures_equal
  8. from matplotlib.transforms import Bbox
  9. import matplotlib.units as munits
  10. def test_non_square():
  11. # Check that creating a non-square table works
  12. cellcolors = ['b', 'r']
  13. plt.table(cellColours=cellcolors)
  14. @image_comparison(['table_zorder.png'], remove_text=True)
  15. def test_zorder():
  16. data = [[66386, 174296],
  17. [58230, 381139]]
  18. colLabels = ('Freeze', 'Wind')
  19. rowLabels = ['%d year' % x for x in (100, 50)]
  20. cellText = []
  21. yoff = np.zeros(len(colLabels))
  22. for row in reversed(data):
  23. yoff += row
  24. cellText.append(['%1.1f' % (x/1000.0) for x in yoff])
  25. t = np.linspace(0, 2*np.pi, 100)
  26. plt.plot(t, np.cos(t), lw=4, zorder=2)
  27. plt.table(cellText=cellText,
  28. rowLabels=rowLabels,
  29. colLabels=colLabels,
  30. loc='center',
  31. zorder=-2,
  32. )
  33. plt.table(cellText=cellText,
  34. rowLabels=rowLabels,
  35. colLabels=colLabels,
  36. loc='upper center',
  37. zorder=4,
  38. )
  39. plt.yticks([])
  40. @image_comparison(['table_labels.png'])
  41. def test_label_colours():
  42. dim = 3
  43. c = np.linspace(0, 1, dim)
  44. colours = plt.cm.RdYlGn(c)
  45. cellText = [['1'] * dim] * dim
  46. fig = plt.figure()
  47. ax1 = fig.add_subplot(4, 1, 1)
  48. ax1.axis('off')
  49. ax1.table(cellText=cellText,
  50. rowColours=colours,
  51. loc='best')
  52. ax2 = fig.add_subplot(4, 1, 2)
  53. ax2.axis('off')
  54. ax2.table(cellText=cellText,
  55. rowColours=colours,
  56. rowLabels=['Header'] * dim,
  57. loc='best')
  58. ax3 = fig.add_subplot(4, 1, 3)
  59. ax3.axis('off')
  60. ax3.table(cellText=cellText,
  61. colColours=colours,
  62. loc='best')
  63. ax4 = fig.add_subplot(4, 1, 4)
  64. ax4.axis('off')
  65. ax4.table(cellText=cellText,
  66. colColours=colours,
  67. colLabels=['Header'] * dim,
  68. loc='best')
  69. @image_comparison(['table_cell_manipulation.png'], style='mpl20')
  70. def test_diff_cell_table(text_placeholders):
  71. cells = ('horizontal', 'vertical', 'open', 'closed', 'T', 'R', 'B', 'L')
  72. cellText = [['1'] * len(cells)] * 2
  73. colWidths = [0.1] * len(cells)
  74. _, axs = plt.subplots(nrows=len(cells), figsize=(4, len(cells)+1), layout='tight')
  75. for ax, cell in zip(axs, cells):
  76. ax.table(
  77. colWidths=colWidths,
  78. cellText=cellText,
  79. loc='center',
  80. edges=cell,
  81. )
  82. ax.axis('off')
  83. def test_customcell():
  84. types = ('horizontal', 'vertical', 'open', 'closed', 'T', 'R', 'B', 'L')
  85. codes = (
  86. (Path.MOVETO, Path.LINETO, Path.MOVETO, Path.LINETO, Path.MOVETO),
  87. (Path.MOVETO, Path.MOVETO, Path.LINETO, Path.MOVETO, Path.LINETO),
  88. (Path.MOVETO, Path.MOVETO, Path.MOVETO, Path.MOVETO, Path.MOVETO),
  89. (Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY),
  90. (Path.MOVETO, Path.MOVETO, Path.MOVETO, Path.LINETO, Path.MOVETO),
  91. (Path.MOVETO, Path.MOVETO, Path.LINETO, Path.MOVETO, Path.MOVETO),
  92. (Path.MOVETO, Path.LINETO, Path.MOVETO, Path.MOVETO, Path.MOVETO),
  93. (Path.MOVETO, Path.MOVETO, Path.MOVETO, Path.MOVETO, Path.LINETO),
  94. )
  95. for t, c in zip(types, codes):
  96. cell = CustomCell((0, 0), visible_edges=t, width=1, height=1)
  97. code = tuple(s for _, s in cell.get_path().iter_segments())
  98. assert c == code
  99. @image_comparison(['table_auto_column.png'])
  100. def test_auto_column():
  101. fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1)
  102. # iterable list input
  103. ax1.axis('off')
  104. tb1 = ax1.table(
  105. cellText=[['Fit Text', 2],
  106. ['very long long text, Longer text than default', 1]],
  107. rowLabels=["A", "B"],
  108. colLabels=["Col1", "Col2"],
  109. loc="center")
  110. tb1.auto_set_font_size(False)
  111. tb1.set_fontsize(12)
  112. tb1.auto_set_column_width([-1, 0, 1])
  113. # iterable tuple input
  114. ax2.axis('off')
  115. tb2 = ax2.table(
  116. cellText=[['Fit Text', 2],
  117. ['very long long text, Longer text than default', 1]],
  118. rowLabels=["A", "B"],
  119. colLabels=["Col1", "Col2"],
  120. loc="center")
  121. tb2.auto_set_font_size(False)
  122. tb2.set_fontsize(12)
  123. tb2.auto_set_column_width((-1, 0, 1))
  124. # 3 single inputs
  125. ax3.axis('off')
  126. tb3 = ax3.table(
  127. cellText=[['Fit Text', 2],
  128. ['very long long text, Longer text than default', 1]],
  129. rowLabels=["A", "B"],
  130. colLabels=["Col1", "Col2"],
  131. loc="center")
  132. tb3.auto_set_font_size(False)
  133. tb3.set_fontsize(12)
  134. tb3.auto_set_column_width(-1)
  135. tb3.auto_set_column_width(0)
  136. tb3.auto_set_column_width(1)
  137. # 4 this used to test non-integer iterable input, which did nothing, but only
  138. # remains to avoid re-generating the test image.
  139. ax4.axis('off')
  140. tb4 = ax4.table(
  141. cellText=[['Fit Text', 2],
  142. ['very long long text, Longer text than default', 1]],
  143. rowLabels=["A", "B"],
  144. colLabels=["Col1", "Col2"],
  145. loc="center")
  146. tb4.auto_set_font_size(False)
  147. tb4.set_fontsize(12)
  148. def test_table_cells():
  149. fig, ax = plt.subplots()
  150. table = Table(ax)
  151. cell = table.add_cell(1, 2, 1, 1)
  152. assert isinstance(cell, CustomCell)
  153. assert cell is table[1, 2]
  154. cell2 = CustomCell((0, 0), 1, 2, visible_edges=None)
  155. table[2, 1] = cell2
  156. assert table[2, 1] is cell2
  157. # make sure getitem support has not broken
  158. # properties and setp
  159. table.properties()
  160. plt.setp(table)
  161. @check_figures_equal(extensions=["png"])
  162. def test_table_bbox(fig_test, fig_ref):
  163. data = [[2, 3],
  164. [4, 5]]
  165. col_labels = ('Foo', 'Bar')
  166. row_labels = ('Ada', 'Bob')
  167. cell_text = [[f"{x}" for x in row] for row in data]
  168. ax_list = fig_test.subplots()
  169. ax_list.table(cellText=cell_text,
  170. rowLabels=row_labels,
  171. colLabels=col_labels,
  172. loc='center',
  173. bbox=[0.1, 0.2, 0.8, 0.6]
  174. )
  175. ax_bbox = fig_ref.subplots()
  176. ax_bbox.table(cellText=cell_text,
  177. rowLabels=row_labels,
  178. colLabels=col_labels,
  179. loc='center',
  180. bbox=Bbox.from_extents(0.1, 0.2, 0.9, 0.8)
  181. )
  182. @check_figures_equal(extensions=['png'])
  183. def test_table_unit(fig_test, fig_ref):
  184. # test that table doesn't participate in unit machinery, instead uses repr/str
  185. class FakeUnit:
  186. def __init__(self, thing):
  187. pass
  188. def __repr__(self):
  189. return "Hello"
  190. fake_convertor = munits.ConversionInterface()
  191. # v, u, a = value, unit, axis
  192. fake_convertor.convert = Mock(side_effect=lambda v, u, a: 0)
  193. # not used, here for completeness
  194. fake_convertor.default_units = Mock(side_effect=lambda v, a: None)
  195. fake_convertor.axisinfo = Mock(side_effect=lambda u, a: munits.AxisInfo())
  196. munits.registry[FakeUnit] = fake_convertor
  197. data = [[FakeUnit("yellow"), FakeUnit(42)],
  198. [FakeUnit(datetime.datetime(1968, 8, 1)), FakeUnit(True)]]
  199. fig_test.subplots().table(data)
  200. fig_ref.subplots().table([["Hello", "Hello"], ["Hello", "Hello"]])
  201. fig_test.canvas.draw()
  202. fake_convertor.convert.assert_not_called()
  203. munits.registry.pop(FakeUnit)
  204. assert not munits.registry.get_converter(FakeUnit)
  205. def test_table_dataframe(pd):
  206. # Test if Pandas Data Frame can be passed in cellText
  207. data = {
  208. 'Letter': ['A', 'B', 'C'],
  209. 'Number': [100, 200, 300]
  210. }
  211. df = pd.DataFrame(data)
  212. fig, ax = plt.subplots()
  213. table = ax.table(df, loc='center')
  214. for r, (index, row) in enumerate(df.iterrows()):
  215. for c, col in enumerate(df.columns if r == 0 else row.values):
  216. assert table[r if r == 0 else r+1, c].get_text().get_text() == str(col)
  217. def test_table_fontsize():
  218. # Test that the passed fontsize propagates to cells
  219. tableData = [['a', 1], ['b', 2]]
  220. fig, ax = plt.subplots()
  221. test_fontsize = 20
  222. t = ax.table(cellText=tableData, loc='top', fontsize=test_fontsize)
  223. cell_fontsize = t[(0, 0)].get_fontsize()
  224. assert cell_fontsize == test_fontsize, f"Actual:{test_fontsize},got:{cell_fontsize}"
  225. cell_fontsize = t[(1, 1)].get_fontsize()
  226. assert cell_fontsize == test_fontsize, f"Actual:{test_fontsize},got:{cell_fontsize}"