# figureoptions.py (9.6 KB)
  1. # Copyright © 2009 Pierre Raybaut
  2. # Licensed under the terms of the MIT License
  3. # see the Matplotlib licenses directory for a copy of the license
  4. """Module that provides a GUI-based editor for Matplotlib's figure options."""
  5. from itertools import chain
  6. from matplotlib import cbook, cm, colors as mcolors, markers, image as mimage
  7. from matplotlib.backends.qt_compat import QtGui
  8. from matplotlib.backends.qt_editor import _formlayout
  9. from matplotlib.dates import DateConverter, num2date
  10. LINESTYLES = {'-': 'Solid',
  11. '--': 'Dashed',
  12. '-.': 'DashDot',
  13. ':': 'Dotted',
  14. 'None': 'None',
  15. }
  16. DRAWSTYLES = {
  17. 'default': 'Default',
  18. 'steps-pre': 'Steps (Pre)', 'steps': 'Steps (Pre)',
  19. 'steps-mid': 'Steps (Mid)',
  20. 'steps-post': 'Steps (Post)'}
# Marker shorthand -> marker name, taken directly from Matplotlib's
# MarkerStyle table. Several shorthands can share one name (e.g. the "no
# marker" synonyms), which prepare_data() in figure_edit collapses.
MARKERS = markers.MarkerStyle.markers
def figure_edit(axes, parent=None):
    """
    Edit matplotlib figure options.

    Build a tabbed form describing the editable properties of *axes* and
    open it with ``_formlayout.fedit``:

    - "Axes": title, and per axis its Min/Max limits, label and scale, plus
      a checkbox to (re-)generate the legend;
    - "Curves" (only if present): every line whose label is not
      ``'_nolegend_'``;
    - "Images, etc." (only if present): every image/collection whose label
      is not ``'_nolegend_'`` and whose ``get_array()`` is not None.

    Edited values are written back by the nested ``apply_callback``, which
    is handed to ``fedit`` as its ``apply`` argument.

    Parameters
    ----------
    axes : Axes
        The Axes whose options are edited.
    parent : optional
        Forwarded unchanged to ``_formlayout.fedit`` as its ``parent``
        (presumably the Qt parent widget, or None -- confirm in _formlayout).
    """
    sep = (None, None)  # separator entry in the form layout

    # Get / General
    def convert_limits(lim, converter):
        """
        Convert axis limits for correct input editors.

        Limits go through ``num2date`` when the axis uses a ``DateConverter``
        (so the form shows dates) and are cast to builtin floats otherwise.
        Returns a lazy ``map`` iterator.
        """
        if isinstance(converter, DateConverter):
            return map(num2date, lim)
        # Cast to builtin floats as they have nicer reprs.
        return map(float, lim)

    # Axis name -> Axis object. Each name is combined into accessor names
    # such as f"get_{name}lim" / f"set_{name}scale" below.
    axis_map = axes._axis_map
    # Current limits of each axis, converted for display in the form.
    axis_limits = {
        name: tuple(convert_limits(
            getattr(axes, f'get_{name}lim')(), axis.get_converter()
        ))
        for name, axis in axis_map.items()
    }
    # "Axes" tab layout. apply_callback reads these values back BY POSITION,
    # so the order of entries here must stay in sync with it.
    general = [
        ('Title', axes.get_title()),
        sep,
        *chain.from_iterable([
            (
                (None, f"<b>{name.title()}-Axis</b>"),
                ('Min', axis_limits[name][0]),
                ('Max', axis_limits[name][1]),
                ('Label', axis.label.get_text()),
                # Combobox spec: current value first, then the choices.
                ('Scale', [axis.get_scale(),
                           'linear', 'log', 'symlog', 'logit']),
                sep,
            )
            for name, axis in axis_map.items()
        ]),
        ('(Re-)Generate automatic legend', False),
    ]

    # Save the converter and unit data so apply_callback can re-apply them
    # after changing scales and limits.
    axis_converter = {
        name: axis.get_converter()
        for name, axis in axis_map.items()
    }
    axis_units = {
        name: axis.get_units()
        for name, axis in axis_map.items()
    }

    # Get / Curves
    # (label, line) pairs for all lines not labelled '_nolegend_'. The index
    # in this list is how apply_callback maps form entries back to lines.
    labeled_lines = []
    for line in axes.get_lines():
        label = line.get_label()
        if label == '_nolegend_':
            continue
        labeled_lines.append((label, line))
    curves = []

    def prepare_data(d, init):
        """
        Prepare entry for FormLayout.

        *d* is a mapping of shorthands to style names (a single style may
        have multiple shorthands, in particular the shorthands `None`,
        `"None"`, `"none"` and `""` are synonyms); *init* is one shorthand
        of the initial style. If *init* is not a key of *d*, it is added
        with ``str(init)`` as its style name.

        This function returns a list suitable for initializing a
        FormLayout combobox, namely `[initial_name, (shorthand,
        style_name), (shorthand, style_name), ...]`, where only the last
        shorthand of each style (in *d*'s iteration order) is kept and the
        pairs are sorted by style name.
        """
        if init not in d:
            d = {**d, init: str(init)}
        # Drop duplicate shorthands from dict (by overwriting them during
        # the dict comprehension).
        name2short = {name: short for short, name in d.items()}
        # Convert back to {shorthand: name}.
        short2name = {short: name for name, short in name2short.items()}
        # Find the kept shorthand for the style specified by init.
        canonical_init = name2short[d[init]]
        # Sort by representation and prepend the initial value.
        return ([canonical_init] +
                sorted(short2name.items(),
                       key=lambda short_and_name: short_and_name[1]))

    for label, line in labeled_lines:
        # Fold the line-level alpha into each color and encode it as a hex
        # RGBA string for the color editors; apply_callback later clears the
        # line alpha and relies on the alpha carried in these colors.
        color = mcolors.to_hex(
            mcolors.to_rgba(line.get_color(), line.get_alpha()),
            keep_alpha=True)
        ec = mcolors.to_hex(
            mcolors.to_rgba(line.get_markeredgecolor(), line.get_alpha()),
            keep_alpha=True)
        fc = mcolors.to_hex(
            mcolors.to_rgba(line.get_markerfacecolor(), line.get_alpha()),
            keep_alpha=True)
        # Field order here must match the 9-value unpack in apply_callback.
        curvedata = [
            ('Label', label),
            sep,
            (None, '<b>Line</b>'),
            ('Line style', prepare_data(LINESTYLES, line.get_linestyle())),
            ('Draw style', prepare_data(DRAWSTYLES, line.get_drawstyle())),
            ('Width', line.get_linewidth()),
            ('Color (RGBA)', color),
            sep,
            (None, '<b>Marker</b>'),
            ('Style', prepare_data(MARKERS, line.get_marker())),
            ('Size', line.get_markersize()),
            ('Face color (RGBA)', fc),
            ('Edge color (RGBA)', ec)]
        curves.append([curvedata, label, ""])
    # Is there a curve displayed?
    has_curve = bool(curves)

    # Get ScalarMappables.
    # Images and collections qualify when their label is not '_nolegend_'
    # and they carry data (get_array() is not None).
    labeled_mappables = []
    for mappable in [*axes.images, *axes.collections]:
        label = mappable.get_label()
        if label == '_nolegend_' or mappable.get_array() is None:
            continue
        labeled_mappables.append((label, mappable))
    mappables = []
    # (Colormap, name) choices from the colormap registry, sorted by name.
    cmaps = [(cmap, name) for name, cmap in sorted(cm._colormaps.items())]
    for label, mappable in labeled_mappables:
        cmap = mappable.get_cmap()
        # Unregistered colormaps are prepended so they are selectable too.
        # `cmaps` is rebound, so the entry also appears in the lists of all
        # later mappables (possibly more than once).
        if cmap.name not in cm._colormaps:
            cmaps = [(cmap, cmap.name), *cmaps]
        low, high = mappable.get_clim()
        mappabledata = [
            ('Label', label),
            ('Colormap', [cmap.name] + cmaps),
            ('Min. value', low),
            ('Max. value', high),
        ]
        # Objects exposing get_interpolation (images) get two extra fields,
        # which is why apply_callback accepts both 4 and 6 values.
        if hasattr(mappable, "get_interpolation"):  # Images.
            interpolations = [
                (name, name) for name in sorted(mimage.interpolations_names)]
            mappabledata.append((
                'Interpolation',
                [mappable.get_interpolation(), *interpolations]))
            interpolation_stages = ['data', 'rgba', 'auto']
            mappabledata.append((
                'Interpolation stage',
                [mappable.get_interpolation_stage(), *interpolation_stages]))
        mappables.append([mappabledata, label, ""])
    # Is there a scalarmappable displayed?
    has_sm = bool(mappables)

    # Tabs passed to fedit as (fields, title, "") tuples; empty tabs are
    # omitted, which is why apply_callback checks has_curve / has_sm.
    datalist = [(general, "Axes", "")]
    if curves:
        datalist.append((curves, "Curves", ""))
    if mappables:
        datalist.append((mappables, "Images, etc.", ""))

    def apply_callback(data):
        """
        A callback to apply changes.

        *data* holds the values returned by the form, one list per tab in
        ``datalist`` order. It is consumed (mutated) in place. Values are
        written back to *axes*, its lines and its mappables; then the canvas
        is redrawn and, if any limit changed, the toolbar's current view is
        pushed.

        Raises
        ------
        ValueError
            If *data* contains more tabs than were built.
        """
        # Snapshot the limits to detect whether this apply changed them.
        orig_limits = {
            name: getattr(axes, f"get_{name}lim")()
            for name in axis_map
        }

        general = data.pop(0)
        curves = data.pop(0) if has_curve else []
        mappables = data.pop(0) if has_sm else []
        if data:
            raise ValueError("Unexpected field")

        title = general.pop(0)
        axes.set_title(title)
        generate_legend = general.pop()

        # NOTE(review): the 4*i stride assumes the returned values omit the
        # separator and header entries, leaving exactly Min, Max, Label and
        # Scale per axis -- confirm against _formlayout.
        for i, (name, axis) in enumerate(axis_map.items()):
            axis_min = general[4*i]
            axis_max = general[4*i + 1]
            axis_label = general[4*i + 2]
            axis_scale = general[4*i + 3]
            if axis.get_scale() != axis_scale:
                getattr(axes, f"set_{name}scale")(axis_scale)

            axis._set_lim(axis_min, axis_max, auto=False)
            axis.set_label_text(axis_label)

            # Restore the unit data
            axis._set_converter(axis_converter[name])
            axis.set_units(axis_units[name])

        # Set / Curves
        for index, curve in enumerate(curves):
            line = labeled_lines[index][1]
            # Values arrive in curvedata order, minus separators/headers.
            (label, linestyle, drawstyle, linewidth, color, marker, markersize,
             markerfacecolor, markeredgecolor) = curve
            line.set_label(label)
            line.set_linestyle(linestyle)
            line.set_drawstyle(drawstyle)
            line.set_linewidth(linewidth)
            rgba = mcolors.to_rgba(color)
            # Clear the line-level alpha: the edited alpha now lives in rgba.
            line.set_alpha(None)
            line.set_color(rgba)
            # Marker properties are only applied for a marker shorthand
            # other than the literal 'none'.
            if marker != 'none':
                line.set_marker(marker)
                line.set_markersize(markersize)
                line.set_markerfacecolor(markerfacecolor)
                line.set_markeredgecolor(markeredgecolor)

        # Set ScalarMappables.
        for index, mappable_settings in enumerate(mappables):
            mappable = labeled_mappables[index][1]
            # 6 values for images (with interpolation fields), 4 otherwise.
            if len(mappable_settings) == 6:
                label, cmap, low, high, interpolation, interpolation_stage = \
                    mappable_settings
                mappable.set_interpolation(interpolation)
                mappable.set_interpolation_stage(interpolation_stage)
            elif len(mappable_settings) == 4:
                label, cmap, low, high = mappable_settings
            mappable.set_label(label)
            mappable.set_cmap(cmap)
            # Sort so that a swapped Min/Max entry still gives low <= high.
            mappable.set_clim(*sorted([low, high]))

        # re-generate legend, if checkbox is checked
        if generate_legend:
            draggable = None
            ncols = 1
            # Carry over draggability and column count from an old legend.
            if axes.legend_ is not None:
                old_legend = axes.get_legend()
                draggable = old_legend._draggable is not None
                ncols = old_legend._ncols
            new_legend = axes.legend(ncols=ncols)
            if new_legend:
                new_legend.set_draggable(draggable)

        # Redraw
        figure = axes.get_figure()
        figure.canvas.draw()
        # If any limit changed, push the new view onto the toolbar
        # (presumably so navigation Back/Forward can reach it).
        for name in axis_map:
            if getattr(axes, f"get_{name}lim")() != orig_limits[name]:
                figure.canvas.toolbar.push_current()
                break

    _formlayout.fedit(
        datalist, title="Figure options", parent=parent,
        icon=QtGui.QIcon(
            str(cbook._get_data_path('images', 'qt4_editor_options.svg'))),
        apply=apply_callback)