visualizer.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. import itertools
  2. import os
  3. import re
  4. from string import Template
  5. from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
  6. from tokenizers import Encoding, Tokenizer
  # Stylesheet embedded into every generated page (it is the default `css_styles` of HTMLBody).
  # The file is looked up in the same directory as this module.
  dirname = os.path.dirname(__file__)
  css_filename = os.path.join(dirname, "visualizer-styles.css")
  # Decode explicitly as UTF-8 so the result does not depend on the platform's locale encoding.
  with open(css_filename, encoding="utf-8") as f:
      css = f.read()
  class Annotation:
      """
      A labelled span of the visualized text.

      Attributes:
          start (:obj:`int`): Index of the first character covered by the annotation
          end (:obj:`int`): Index one past the last character covered (exclusive)
          label (:obj:`str`): The label shown for this span; also used to pick its color
      """

      start: int
      end: int
      label: str

      def __init__(self, start: int, end: int, label: str):
          self.start, self.end, self.label = start, end, label


  # Type aliases used in EncodingVisualizer's signatures.
  AnnotationList = List[Annotation]
  PartialIntList = List[Optional[int]]
  class CharStateKey(NamedTuple):
      """
      Grouping key of a character: the index of the first token covering it and the index
      of the annotation covering it. Either may be None when nothing covers the char.
      """

      token_ix: Optional[int]
      anno_ix: Optional[int]
  class CharState:
      """
      Alignment state of a single character of the visualized text.

      Holds the character's index, the index of the annotation covering it (if any) and
      the indices of every token whose offsets include it.
      """

      char_ix: Optional[int]

      def __init__(self, char_ix):
          self.char_ix = char_ix
          self.anno_ix: Optional[int] = None
          self.tokens: List[int] = []

      @property
      def token_ix(self):
          """Index of the first token covering this char, or None if no token covers it."""
          if not self.tokens:
              return None
          return self.tokens[0]

      @property
      def is_multitoken(self):
          """
          Whether more than one token covers this char (BPE tokenizers can output more
          than one token for a char).
          """
          return len(self.tokens) >= 2

      def partition_key(self) -> CharStateKey:
          """Key compared between characters to decide whether they belong to the same span."""
          return CharStateKey(self.token_ix, self.anno_ix)
  class Aligned:
      """Empty placeholder class: it defines no attributes and no methods."""
  class EncodingVisualizer:
      """
      Build an EncodingVisualizer

      Args:
          tokenizer (:class:`~tokenizers.Tokenizer`):
              A tokenizer instance

          default_to_notebook (:obj:`bool`):
              Whether to render html output in a notebook by default

          annotation_converter (:obj:`Callable`, `optional`):
              An optional (lambda) function that takes an annotation in any format and returns
              an Annotation object
      """

      # Flags tokens containing "unk" or "oov" (case-insensitive), e.g. "[UNK]" or "<unk>".
      # The pattern must be a raw string: in a plain string literal "\b" is a backspace
      # character, not the regex word-boundary assertion. Both surrounding groups are optional,
      # so `search` succeeds wherever "unk"/"oov" occurs in the token.
      unk_token_regex = re.compile(r"(.{1}\b)?(unk|oov)(\b.{1})?", flags=re.IGNORECASE)

      def __init__(
          self,
          tokenizer: Tokenizer,
          default_to_notebook: bool = True,
          annotation_converter: Optional[Callable[[Any], Annotation]] = None,
      ):
          if default_to_notebook:
              # Fail fast at construction time rather than on the first render.
              try:
                  from IPython.display import HTML, display  # noqa: F401
              except ImportError as err:
                  raise Exception(
                      "We couldn't import IPython utils for html display.\n"
                      "Are you running in a notebook?\n"
                      "You can also pass `default_to_notebook=False` to get back raw HTML\n"
                  ) from err
          self.tokenizer = tokenizer
          self.default_to_notebook = default_to_notebook
          # NOTE: the attribute keeps its historical misspelling ("coverter") for backward
          # compatibility with any code that reads it.
          self.annotation_coverter = annotation_converter

      def __call__(
          self,
          text: str,
          annotations: Optional[List[Any]] = None,
          default_to_notebook: Optional[bool] = None,
      ) -> Optional[str]:
          """
          Build a visualization of the given text

          Args:
              text (:obj:`str`):
                  The text to tokenize

              annotations (:obj:`List[Annotation]`, `optional`):
                  An optional list of annotations of the text. They can either be Annotation
                  instances, or anything else if the visualizer was instantiated with a
                  converter function

              default_to_notebook (:obj:`bool`, `optional`):
                  If True, will render the html in a notebook. Otherwise returns an html string.
                  When omitted (None), the value given to the constructor is used.

          Returns:
              The HTML string if notebook rendering is disabled, otherwise returns None and
              renders the HTML in the notebook
          """
          if default_to_notebook is None:
              final_default_to_notebook = self.default_to_notebook
          else:
              final_default_to_notebook = default_to_notebook
          if final_default_to_notebook:
              try:
                  from IPython.display import HTML, display
              except ImportError as err:
                  raise Exception(
                      "We couldn't import IPython utils for html display.\n"
                      "Are you running in a notebook?"
                  ) from err
          if annotations is None:
              annotations = []
          if self.annotation_coverter is not None:
              annotations = [self.annotation_coverter(anno) for anno in annotations]
          encoding = self.tokenizer.encode(text)
          html_doc = EncodingVisualizer.__make_html(text, encoding, annotations)
          if final_default_to_notebook:
              display(HTML(html_doc))
              return None
          return html_doc

      @staticmethod
      def calculate_label_colors(annotations: AnnotationList) -> Dict[str, str]:
          """
          Generates a color palette for all the labels in a given set of annotations

          Args:
              annotations (:obj:`Annotation`):
                  A list of annotations

          Returns:
              :obj:`dict`: A dictionary mapping labels to colors in HSL format
          """
          if len(annotations) == 0:
              return {}
          labels = {anno.label for anno in annotations}
          # Spread hues over the labels, but never step the hue by less than 20.
          h_step = max(int(255 / len(labels)), 20)
          s = 32
          l = 64  # noqa: E741
          h = 10
          colors = {}
          # Sort so we always get the same colors for a given set of labels.
          for label in sorted(labels):
              colors[label] = f"hsl({h},{s}%,{l}%)"
              h += h_step
          return colors

      @staticmethod
      def consecutive_chars_to_html(
          consecutive_chars_list: List[CharState],
          text: str,
          encoding: Encoding,
      ):
          """
          Converts a list of "consecutive chars" into a single HTML element.

          Chars are consecutive if they fall under the same token and annotation.
          The CharState class has a "partition_key" method that makes it easy to
          compare if two chars are consecutive. Only the first and last chars of the list are
          inspected: the first one decides the styling, and together they bound the text.
          The text and token strings are HTML-escaped, so input containing characters such as
          ``<``, ``&`` or ``"`` cannot break (or inject into) the generated markup.

          Args:
              consecutive_chars_list (:obj:`List[CharState]`):
                  A non-empty list of CharStates that have been grouped together

              text (:obj:`str`):
                  The original text being processed

              encoding (:class:`~tokenizers.Encoding`):
                  The encoding returned from the tokenizer

          Returns:
              :obj:`str`: The HTML span for a set of consecutive chars
          """
          from html import escape

          first = consecutive_chars_list[0]
          if first.char_ix is None:
              # It's a special token: it has no characters of its own, so it is rendered as an
              # empty span and displayed through its data attribute and css.
              stoken = encoding.tokens[first.token_ix]
              return f'<span class="special-token" data-stoken="{escape(stoken)}"></span>'

          # We're not in a special token, so this group has a start and an end.
          last = consecutive_chars_list[-1]
          assert last.char_ix is not None
          span_text = text[first.char_ix : last.char_ix + 1]

          css_classes = []  # css classes applied to the resulting span
          data_items = {}  # data-* attributes applied to the resulting span
          if first.token_ix is not None:
              # We can either be in a token or not (e.g. in white space).
              token = encoding.tokens[first.token_ix]  # fetched once, reused below
              css_classes.append("token")
              if first.is_multitoken:
                  css_classes.append("multi-token")
              # Alternate colors between neighbouring tokens. A token may be split in the markup
              # by an annotation boundary; keying on the token index keeps its pieces consistent.
              css_classes.append("odd-token" if first.token_ix % 2 else "even-token")
              if EncodingVisualizer.unk_token_regex.search(token) is not None:
                  # This is a special token that is in the text, probably UNK.
                  css_classes.append("special-token")
                  # TODO is this the right name for the data attribute ?
                  data_items["stok"] = token
          else:
              # A group/single char that is not covered by any token, e.g. white space.
              css_classes.append("non-token")

          class_names = " ".join(css_classes)
          class_attr = f'class="{class_names}"'
          data_attrs = "".join(f' data-{key}="{escape(val)}"' for key, val in data_items.items())
          return f"<span {class_attr} {data_attrs} >{escape(span_text, quote=False)}</span>"

      @staticmethod
      def __make_html(text: str, encoding: Encoding, annotations: AnnotationList) -> str:
          """
          Render ``text`` as a full HTML page.

          Characters are grouped into spans of consecutive chars sharing a token and an
          annotation (see ``consecutive_chars_to_html``). Each annotation is additionally
          wrapped in a ``<span class="annotation">`` colored by its label.
          """
          from html import escape

          char_states = EncodingVisualizer.__make_char_states(text, encoding, annotations)
          if not char_states:
              # Empty text: nothing to group, and char_states[0] below would raise IndexError.
              return HTMLBody([])
          label_colors_dict = EncodingVisualizer.calculate_label_colors(annotations)

          def open_annotation_span(anno_ix: int) -> str:
              # Opening tag of the wrapper span for annotation `anno_ix`.
              label = annotations[anno_ix].label
              color = label_colors_dict[label]
              return f'<span class="annotation" style="color:{color}" data-label="{escape(label)}">'

          def group_to_html(group: List[CharState]) -> str:
              return EncodingVisualizer.consecutive_chars_to_html(group, text=text, encoding=encoding)

          spans: List[str] = []
          prev_anno_ix = char_states[0].anno_ix
          if prev_anno_ix is not None:
              # The text starts inside an annotation: open its span first.
              spans.append(open_annotation_span(prev_anno_ix))
          current_consecutive_chars = [char_states[0]]
          for cs in char_states[1:]:
              cur_anno_ix = cs.anno_ix
              if cur_anno_ix != prev_anno_ix:
                  # Entering, leaving or switching annotations always ends the current group.
                  spans.append(group_to_html(current_consecutive_chars))
                  current_consecutive_chars = [cs]
                  if prev_anno_ix is not None:
                      # We transitioned out of an annotation: close its span.
                      spans.append("</span>")
                  if cur_anno_ix is not None:
                      # We entered a new annotation: open a span for it.
                      spans.append(open_annotation_span(cur_anno_ix))
                  prev_anno_ix = cur_anno_ix
              elif cs.partition_key() == current_consecutive_chars[0].partition_key():
                  # Same token and annotation as the current group: extend it.
                  current_consecutive_chars.append(cs)
              else:
                  # Otherwise flush the previous group and start a new one.
                  spans.append(group_to_html(current_consecutive_chars))
                  current_consecutive_chars = [cs]
          # Flush the final group.
          spans.append(group_to_html(current_consecutive_chars))
          if prev_anno_ix is not None:
              # The text ends inside an annotation: close its span so the markup is well-formed.
              spans.append("</span>")
          return HTMLBody(spans)

      @staticmethod
      def __make_anno_map(text: str, annotations: AnnotationList) -> PartialIntList:
          """
          Args:
              text (:obj:`str`):
                  The raw text we want to align to

              annotations (:obj:`AnnotationList`):
                  A (possibly empty) list of annotations

          Returns:
              A list of length len(text) whose entry at index i is None if there is no annotation
              on character i, or k, the index (into `annotations`) of the annotation covering
              character i. Where annotations overlap, the later one in the list wins.

          Raises:
              IndexError: if an annotation ends past the end of `text`.
              NOTE(review): negative offsets are not rejected; they index from the end of the map.
          """
          annotation_map: PartialIntList = [None] * len(text)
          for anno_ix, anno in enumerate(annotations):
              for i in range(anno.start, anno.end):
                  annotation_map[i] = anno_ix
          return annotation_map

      @staticmethod
      def __make_char_states(text: str, encoding: Encoding, annotations: AnnotationList) -> List[CharState]:
          """
          For each character in the original text, we emit a CharState recording:
              * the indices of the tokens that cover it
              * the index of the annotation that covers it

          Args:
              text (:obj:`str`):
                  The raw text we want to align to

              annotations (:obj:`List[Annotation]`):
                  A (possibly empty) list of annotations

              encoding: (:class:`~tokenizers.Encoding`):
                  The encoding returned from the tokenizer

          Returns:
              :obj:`List[CharState]`: A list of CharStates, indicating for each char in the text
              what its state is
          """
          annotation_map = EncodingVisualizer.__make_anno_map(text, annotations)
          char_states: List[CharState] = [CharState(char_ix) for char_ix in range(len(text))]
          for token_ix, _ in enumerate(encoding.tokens):
              offsets = encoding.token_to_chars(token_ix)
              # Tokens for which token_to_chars returns None are not attached to any char.
              if offsets is not None:
                  start, end = offsets
                  for i in range(start, end):
                      char_states[i].tokens.append(token_ix)
          for char_ix, anno_ix in enumerate(annotation_map):
              char_states[char_ix].anno_ix = anno_ix
          return char_states
  def HTMLBody(children: List[str], css_styles=css) -> str:
      """
      Wrap a list of html fragments in a complete html page with an inline stylesheet.

      Args:
          children (:obj:`List[str]`):
              Strings assumed to already be html elements; they are concatenated as-is
              inside the ``tokenized-text`` div

          css_styles (:obj:`str`, `optional`):
              Stylesheet to inline instead of the module's default one

      Returns:
          :obj:`str`: The complete HTML document, starting and ending with a newline
      """
      page_lines = [
          "",
          "<html>",
          "<head>",
          "<style>",
          f"{css_styles}",
          "</style>",
          "</head>",
          "<body>",
          '<div class="tokenized-text" dir=auto>',
          "".join(children),
          "</div>",
          "</body>",
          "</html>",
          "",
      ]
      return "\n".join(page_lines)