chat_parsing_utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. # Copyright 2026 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. import json
  16. import re
  17. from transformers.utils import is_jmespath_available
  18. if is_jmespath_available():
  19. import jmespath
  20. else:
  21. jmespath = None
  22. def _gemma4_json_to_json(text: str) -> str:
  23. """Convert Gemma4 tool call format (unquoted keys, ``<|"|>`` string delimiters) to valid JSON."""
  24. strings = []
  25. def _capture(m):
  26. strings.append(m.group(1))
  27. return f"\x00{len(strings) - 1}\x00"
  28. # Grab the inside of gemma-quotes and store them for later
  29. text = re.sub(r'<\|"\|>(.*?)<\|"\|>', _capture, text, flags=re.DOTALL)
  30. # Add quotes to the bare keys elsewhere
  31. text = re.sub(r"(?<=[{,])(\w+):", r'"\1":', text)
  32. # Put the inside of the quotes back afterwards
  33. for i, s in enumerate(strings):
  34. text = text.replace(f"\x00{i}\x00", json.dumps(s))
  35. return text
  36. def _parse_re_match(node_match: re.Match) -> dict | str:
  37. # If the regex has named groups, return a dict of those groups
  38. if node_match.groupdict():
  39. return {key: val for key, val in node_match.groupdict().items() if val is not None}
  40. # Otherwise the regex must have exactly one unnamed group, and we return that
  41. else:
  42. groups = list(node_match.groups())
  43. if len(groups) > 1:
  44. raise ValueError(f"Regex has multiple unnamed groups!\nGroups: {groups}\n")
  45. elif len(groups) == 0:
  46. raise ValueError(f"Regex has no capture groups:\n\n{node_match.group(0)}")
  47. return groups[0]
  48. def recursive_parse(
  49. node_content: str | list | dict,
  50. node_schema: dict,
  51. ):
  52. """
  53. This function takes content and a JSON schema which includes
  54. regex extractors, and recursively parses the content. The output
  55. should be a data structure matching the schema.
  56. Args:
  57. node_content: The content corresponding to this node. Usually a string, but can be something else
  58. if the parent node has multiple capture groups or named groups. In that case,
  59. we generally pass the capture groups straight through to the children of this node
  60. and don't do any parsing at this level.
  61. node_schema: The schema node controlling the parsing.
  62. Returns:
  63. The parsed data structure for the current node.
  64. """
  65. # If the schema has a const, we just return that value and do absolutely nothing else
  66. if "const" in node_schema:
  67. return node_schema["const"]
  68. # If the node content is None, we return None. EZ.
  69. if node_content is None:
  70. return None
  71. # If not, we have to do a little parsing. First, set some vars and do basic validation
  72. node_type = node_schema.get("type")
  73. has_regex = (
  74. "x-regex" in node_schema
  75. or "x-regex-iterator" in node_schema
  76. or "x-regex-key-value" in node_schema
  77. or "x-regex-substitutions" in node_schema
  78. )
  79. if has_regex and not isinstance(node_content, str):
  80. raise TypeError(
  81. "Schema node got a non-string input, but has a regex for parsing or substitution.\n"
  82. f"Input: {node_content}\n"
  83. f"Schema: {node_schema}"
  84. )
  85. node_subs = node_schema.get("x-regex-substitutions", [])
  86. for node_sub in node_subs:
  87. node_content = re.sub(node_sub[0], node_sub[1], node_content, flags=re.DOTALL)
  88. node_regex = node_schema.get("x-regex")
  89. node_regex_iterator = node_schema.get("x-regex-iterator")
  90. node_regex_to_dict = node_schema.get("x-regex-key-value")
  91. if node_regex is not None:
  92. node_match = re.search(node_regex, node_content, flags=re.DOTALL)
  93. if not node_match:
  94. return None
  95. node_content = _parse_re_match(node_match)
  96. if node_regex_iterator is not None:
  97. if node_type != "array":
  98. raise TypeError(f"Schema node with type {node_type} cannot use x-regex-iterator.\nSchema: {node_schema}")
  99. # Note that this can be applied after a standard node-regex search
  100. node_content = [
  101. _parse_re_match(node_match)
  102. for node_match in re.finditer(node_regex_iterator, node_content, flags=re.DOTALL)
  103. ]
  104. if not node_content:
  105. return None
  106. if node_regex_to_dict is not None:
  107. if node_type != "object":
  108. raise TypeError(f"Schema node with type {node_type} cannot use x-regex-key-value.\nSchema: {node_schema}")
  109. # Note that this can be applied after a standard node-regex search
  110. output_content = {}
  111. for node_match in re.finditer(node_regex_to_dict, node_content, flags=re.DOTALL):
  112. match_groups = _parse_re_match(node_match)
  113. if not isinstance(match_groups, dict) or "key" not in match_groups or "value" not in match_groups:
  114. raise ValueError(
  115. f"Regex for x-regex-key-value must have named groups 'key' and 'value'.\n"
  116. f"Match groups: {match_groups}\n"
  117. f"Schema: {node_schema}"
  118. )
  119. output_content[match_groups["key"]] = match_groups["value"]
  120. node_content = output_content
  121. if not node_content:
  122. return None
  123. # Next, if the node has a parser, apply it. We do this after regexes so that the regex can extract
  124. # a substring to parse, if needed.
  125. if "x-parser" in node_schema:
  126. parser = node_schema["x-parser"]
  127. if parser == "gemma4-tool-call":
  128. if not isinstance(node_content, str):
  129. raise TypeError(
  130. f"Node has Gemma4 tool call parser but got non-string input: {node_content}\nSchema: {node_schema}"
  131. )
  132. node_content = _gemma4_json_to_json(node_content)
  133. parser = "json" # fall through to the JSON parser below - don't add an elif!
  134. if parser == "json":
  135. if not isinstance(node_content, str):
  136. raise TypeError(
  137. f"Node has JSON parser but got non-string input: {node_content}\nSchema: {node_schema}"
  138. )
  139. parser_args = node_schema.get("x-parser-args", {})
  140. transform = parser_args.get("transform")
  141. allow_non_json = parser_args.get("allow_non_json", False)
  142. try:
  143. parsed_json = json.loads(node_content)
  144. except json.JSONDecodeError as e:
  145. if allow_non_json:
  146. parsed_json = node_content
  147. else:
  148. raise ValueError(
  149. f"Node has JSON parser but could not parse its contents as JSON. You can use the `allow_non_json` parser arg for nodes which may contain JSON or string content.\n\nContent: {node_content}\n\nError: {e}"
  150. )
  151. if transform is not None:
  152. if jmespath is None:
  153. raise ImportError(
  154. "Chat response schema includes a jmespath transformation, but jmespath is not installed. You can install it with `pip install jmespath`."
  155. )
  156. parsed_json = jmespath.search(parser_args["transform"], parsed_json)
  157. node_content = parsed_json
  158. else:
  159. raise ValueError(f"Unknown parser {parser} for schema node: {node_schema}")
  160. # Finally, handle parsed content based on schema type and recurse if required
  161. if node_type == "object":
  162. parsed_schema = {}
  163. if isinstance(node_content, str):
  164. # This means we don't have a regex at this level, so all of our child nodes need to parse the whole
  165. # string themselves to extract their value.
  166. if "properties" not in node_schema:
  167. raise ValueError(
  168. f"Object node received string content but has no regex or parser to handle it.\n"
  169. f"Content: {node_content}\n"
  170. f"Schema: {node_schema}"
  171. )
  172. for key, child_node in node_schema["properties"].items():
  173. child_node_content = recursive_parse(node_content, node_schema["properties"][key])
  174. if child_node_content is not None:
  175. parsed_schema[key] = child_node_content
  176. elif isinstance(node_content, dict):
  177. for key, child_node in node_schema.get("properties", {}).items():
  178. if "const" in child_node:
  179. parsed_schema[key] = child_node["const"]
  180. elif key in node_content:
  181. parsed_schema[key] = recursive_parse(node_content[key], child_node)
  182. elif "default" in child_node:
  183. parsed_schema[key] = child_node["default"]
  184. additional_schema = node_schema.get("additionalProperties", True)
  185. # We want to check only for False values; {} is "falsy" but should pass through
  186. if additional_schema is not False:
  187. additional_schema = additional_schema if isinstance(additional_schema, dict) else {}
  188. for key, value in node_content.items():
  189. if key not in node_schema.get("properties", {}):
  190. parsed_schema[key] = recursive_parse(value, additional_schema)
  191. else:
  192. raise TypeError(f"Expected a dict or str for schema node with type object, got {node_content}")
  193. required = node_schema.get("required", [])
  194. missing = [key for key in required if key not in parsed_schema]
  195. if missing:
  196. input_preview = repr(node_content[:500]) if isinstance(node_content, str) else repr(node_content)
  197. raise ValueError(
  198. f"Required fields {missing} are missing from parsed output.\n"
  199. f"Parsed: {parsed_schema}\n"
  200. f"Input: {input_preview}"
  201. )
  202. return parsed_schema
  203. elif node_type == "array":
  204. if not node_content:
  205. return []
  206. parsed_schema = []
  207. if "items" in node_schema:
  208. if not isinstance(node_content, list):
  209. raise TypeError(f"Expected a list or regex for schema node with type array, got {node_content}")
  210. for item in node_content:
  211. parsed_schema.append(recursive_parse(item, node_schema["items"]))
  212. return parsed_schema
  213. elif "prefixItems" in node_schema:
  214. if not isinstance(node_content, list):
  215. if len(node_schema["prefixItems"]) == 1:
  216. # If there's only one prefix item, this is a single item array, we can just wrap the string
  217. node_content = [node_content]
  218. else:
  219. raise TypeError(f"Expected a list or regex for schema node with type array, got {node_content}")
  220. if len(node_content) != len(node_schema["prefixItems"]):
  221. raise ValueError(
  222. f"Array node has {len(node_content)} items, but schema only has "
  223. f"{len(node_schema['prefixItems'])} prefixItems defined.\n"
  224. f"Content: {node_content}\n"
  225. f"Schema: {node_schema}"
  226. )
  227. for item, item_schema in zip(node_content, node_schema["prefixItems"]):
  228. parsed_schema.append(recursive_parse(item, item_schema))
  229. return parsed_schema
  230. else:
  231. raise ValueError(f"Array node has no items or prefixItems schema defined.\nSchema: {node_schema}")
  232. elif node_type in ("string", "integer", "number", "boolean"):
  233. if node_type == "integer":
  234. if isinstance(node_content, int):
  235. return node_content
  236. if not isinstance(node_content, str):
  237. raise TypeError(
  238. f"Expected a string or int for schema node with type integer, got {type(node_content).__name__}: {node_content}"
  239. )
  240. try:
  241. return int(node_content)
  242. except ValueError:
  243. raise ValueError(
  244. f"Schema node has type 'integer', but the parsed string content is not a valid integer: {node_content!r}"
  245. )
  246. elif node_type == "number":
  247. if isinstance(node_content, (int, float)):
  248. return float(node_content)
  249. if not isinstance(node_content, str):
  250. raise TypeError(
  251. f"Expected a string or number for schema node with type number, got {type(node_content).__name__}: {node_content}"
  252. )
  253. try:
  254. return float(node_content)
  255. except ValueError:
  256. raise ValueError(
  257. f"Schema node has type 'number', but the parsed string content is not a valid number: {node_content!r}"
  258. )
  259. elif node_type == "boolean":
  260. if isinstance(node_content, bool):
  261. return node_content
  262. if not isinstance(node_content, str):
  263. raise TypeError(
  264. f"Expected a string or bool for schema node with type boolean, got {type(node_content).__name__}: {node_content}"
  265. )
  266. if node_content.lower() in ("true", "1"):
  267. return True
  268. elif node_content.lower() in ("false", "0"):
  269. return False
  270. else:
  271. raise ValueError(f"Invalid boolean value: {node_content}")
  272. else:
  273. # String type
  274. if not isinstance(node_content, str):
  275. raise TypeError(
  276. f"Expected a string for schema node with type string, got {type(node_content).__name__}: {node_content}"
  277. )
  278. return node_content
  279. elif node_type is None or node_type == "any":
  280. return node_content # Don't touch it
  281. else:
  282. raise TypeError(f"Unsupported schema type {node_type} for node: {node_content}")