category.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. """
  2. Plotting of string "category" data: ``plot(['d', 'f', 'a'], [1, 2, 3])`` will
  3. plot three points with x-axis values of 'd', 'f', 'a'.
  4. See :doc:`/gallery/lines_bars_and_markers/categorical_variables` for an
  5. example.
  6. The module uses Matplotlib's `matplotlib.units` mechanism to convert from
  7. strings to integers and provides a tick locator, a tick formatter, and the
  8. `.UnitData` class that creates and stores the string-to-integer mapping.
  9. """
  10. from collections import OrderedDict
  11. import dateutil.parser
  12. import itertools
  13. import logging
  14. import numpy as np
  15. from matplotlib import _api, cbook, ticker, units
  16. _log = logging.getLogger(__name__)
  17. class StrCategoryConverter(units.ConversionInterface):
  18. @staticmethod
  19. def convert(value, unit, axis):
  20. """
  21. Convert strings in *value* to floats using mapping information stored
  22. in the *unit* object.
  23. Parameters
  24. ----------
  25. value : str or iterable
  26. Value or list of values to be converted.
  27. unit : `.UnitData`
  28. An object mapping strings to integers.
  29. axis : `~matplotlib.axis.Axis`
  30. The axis on which the converted value is plotted.
  31. .. note:: *axis* is unused.
  32. Returns
  33. -------
  34. float or `~numpy.ndarray` of float
  35. """
  36. if unit is None:
  37. raise ValueError(
  38. 'Missing category information for StrCategoryConverter; '
  39. 'this might be caused by unintendedly mixing categorical and '
  40. 'numeric data')
  41. StrCategoryConverter._validate_unit(unit)
  42. # dtype = object preserves numerical pass throughs
  43. values = np.atleast_1d(np.array(value, dtype=object))
  44. # force an update so it also does type checking
  45. unit.update(values)
  46. s = np.vectorize(unit._mapping.__getitem__, otypes=[float])(values)
  47. return s if not cbook.is_scalar_or_string(value) else s[0]
  48. @staticmethod
  49. def axisinfo(unit, axis):
  50. """
  51. Set the default axis ticks and labels.
  52. Parameters
  53. ----------
  54. unit : `.UnitData`
  55. object string unit information for value
  56. axis : `~matplotlib.axis.Axis`
  57. axis for which information is being set
  58. .. note:: *axis* is not used
  59. Returns
  60. -------
  61. `~matplotlib.units.AxisInfo`
  62. Information to support default tick labeling
  63. """
  64. StrCategoryConverter._validate_unit(unit)
  65. # locator and formatter take mapping dict because
  66. # args need to be pass by reference for updates
  67. majloc = StrCategoryLocator(unit._mapping)
  68. majfmt = StrCategoryFormatter(unit._mapping)
  69. return units.AxisInfo(majloc=majloc, majfmt=majfmt)
  70. @staticmethod
  71. def default_units(data, axis):
  72. """
  73. Set and update the `~matplotlib.axis.Axis` units.
  74. Parameters
  75. ----------
  76. data : str or iterable of str
  77. axis : `~matplotlib.axis.Axis`
  78. axis on which the data is plotted
  79. Returns
  80. -------
  81. `.UnitData`
  82. object storing string to integer mapping
  83. """
  84. # the conversion call stack is default_units -> axis_info -> convert
  85. if axis.units is None:
  86. axis.set_units(UnitData(data))
  87. else:
  88. axis.units.update(data)
  89. return axis.units
  90. @staticmethod
  91. def _validate_unit(unit):
  92. if not hasattr(unit, '_mapping'):
  93. raise ValueError(
  94. f'Provided unit "{unit}" is not valid for a categorical '
  95. 'converter, as it does not have a _mapping attribute.')
  96. class StrCategoryLocator(ticker.Locator):
  97. """Tick at every integer mapping of the string data."""
  98. def __init__(self, units_mapping):
  99. """
  100. Parameters
  101. ----------
  102. units_mapping : dict
  103. Mapping of category names (str) to indices (int).
  104. """
  105. self._units = units_mapping
  106. def __call__(self):
  107. # docstring inherited
  108. return list(self._units.values())
  109. def tick_values(self, vmin, vmax):
  110. # docstring inherited
  111. return self()
  112. class StrCategoryFormatter(ticker.Formatter):
  113. """String representation of the data at every tick."""
  114. def __init__(self, units_mapping):
  115. """
  116. Parameters
  117. ----------
  118. units_mapping : dict
  119. Mapping of category names (str) to indices (int).
  120. """
  121. self._units = units_mapping
  122. def __call__(self, x, pos=None):
  123. # docstring inherited
  124. return self.format_ticks([x])[0]
  125. def format_ticks(self, values):
  126. # docstring inherited
  127. r_mapping = {v: self._text(k) for k, v in self._units.items()}
  128. return [r_mapping.get(round(val), '') for val in values]
  129. @staticmethod
  130. def _text(value):
  131. """Convert text values into utf-8 or ascii strings."""
  132. if isinstance(value, bytes):
  133. value = value.decode(encoding='utf-8')
  134. elif not isinstance(value, str):
  135. value = str(value)
  136. return value
  137. class UnitData:
  138. def __init__(self, data=None):
  139. """
  140. Create mapping between unique categorical values and integer ids.
  141. Parameters
  142. ----------
  143. data : iterable
  144. sequence of string values
  145. """
  146. self._mapping = OrderedDict()
  147. self._counter = itertools.count()
  148. if data is not None:
  149. self.update(data)
  150. @staticmethod
  151. def _str_is_convertible(val):
  152. """
  153. Helper method to check whether a string can be parsed as float or date.
  154. """
  155. try:
  156. float(val)
  157. except ValueError:
  158. try:
  159. dateutil.parser.parse(val)
  160. except (ValueError, TypeError):
  161. # TypeError if dateutil >= 2.8.1 else ValueError
  162. return False
  163. return True
  164. def update(self, data):
  165. """
  166. Map new values to integer identifiers.
  167. Parameters
  168. ----------
  169. data : iterable of str or bytes
  170. Raises
  171. ------
  172. TypeError
  173. If elements in *data* are neither str nor bytes.
  174. """
  175. data = np.atleast_1d(np.array(data, dtype=object))
  176. # check if convertible to number:
  177. convertible = True
  178. for val in OrderedDict.fromkeys(data):
  179. # OrderedDict just iterates over unique values in data.
  180. _api.check_isinstance((str, bytes), value=val)
  181. if convertible:
  182. # this will only be called so long as convertible is True.
  183. convertible = self._str_is_convertible(val)
  184. if val not in self._mapping:
  185. self._mapping[val] = next(self._counter)
  186. if data.size and convertible:
  187. _log.info('Using categorical units to plot a list of strings '
  188. 'that are all parsable as floats or dates. If these '
  189. 'strings should be plotted as numbers, cast to the '
  190. 'appropriate data type before plotting.')
  191. # Register the converter with Matplotlib's unit framework
  192. # Intentionally set to a single instance
  193. units.registry[str] = \
  194. units.registry[np.str_] = \
  195. units.registry[bytes] = \
  196. units.registry[np.bytes_] = StrCategoryConverter()