chat_template_utils.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import json
  16. import re
  17. import types
  18. from collections.abc import Callable
  19. from contextlib import contextmanager
  20. from copy import deepcopy
  21. from datetime import datetime
  22. from functools import lru_cache
  23. from inspect import isfunction
  24. from typing import Any, Literal, Union, get_args, get_origin, get_type_hints, no_type_check
  25. from packaging import version
  26. from . import logging
  27. from .import_utils import is_jinja_available, is_torch_available, is_vision_available
# Module-level logger (used e.g. for the `{% generation %}` warning in `render_jinja_template`).
logger = logging.get_logger(__name__)

# jinja2 is optional: its names are only bound when it is installed. Otherwise `jinja2` is None and
# `_cached_compile_jinja_template` raises an ImportError when a template is first compiled.
if is_jinja_available():
    import jinja2
    import jinja2.exceptions
    import jinja2.ext
    import jinja2.meta
    import jinja2.nodes
    import jinja2.runtime
    from jinja2.ext import Extension
    from jinja2.sandbox import ImmutableSandboxedEnvironment
else:
    jinja2 = None

# PIL is optional as well; in the visible code `Image` is only referenced behind an `is_vision_available()` check.
if is_vision_available():
    from PIL.Image import Image

# One conversation: a list of message dicts (see `is_valid_message` for the keys that are checked).
ChatType = list[dict[str, Any]]

# NOTE(review): not referenced in the visible part of this module — confirm it is used elsewhere before removing.
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)

# Extracts the initial segment of the docstring, containing the function description
# (everything before the first `Args:`/`Returns:`/`Raises:` header, or the whole docstring if there is none).
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
# Extracts the Args: block from the docstring. The header must be preceded by a newline, so it only matches after
# a description line.
args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
# Splits the Args: block into individual arguments; each match is a (name, description) pair.
args_split_re = re.compile(
    r"""
(?:^|\n) # Match the start of the args block, or a newline
\s*(\w+):\s* # Capture the argument name and strip spacing
(.*?)\s* # Capture the argument description, which can span multiple lines, and strip trailing spacing
(?=\n\s*\w+:|\Z) # Stop when you hit the next argument or the end of the block
""",
    re.DOTALL | re.VERBOSE,
)
# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
  60. class TypeHintParsingException(Exception):
  61. """Exception raised for errors in parsing type hints to generate JSON schemas"""
  62. class DocstringParsingException(Exception):
  63. """Exception raised for errors in parsing docstrings to generate JSON schemas"""
  64. def _get_json_schema_type(param_type: type) -> dict[str, str]:
  65. type_mapping = {
  66. int: {"type": "integer"},
  67. float: {"type": "number"},
  68. str: {"type": "string"},
  69. bool: {"type": "boolean"},
  70. type(None): {"type": "null"},
  71. Any: {},
  72. }
  73. if is_vision_available():
  74. type_mapping[Image] = {"type": "image"}
  75. if is_torch_available():
  76. import torch
  77. type_mapping[torch.Tensor] = {"type": "audio"}
  78. return type_mapping.get(param_type, {"type": "object"})
  79. def _parse_type_hint(hint: str) -> dict:
  80. origin = get_origin(hint)
  81. args = get_args(hint)
  82. if origin is None:
  83. try:
  84. return _get_json_schema_type(hint)
  85. except KeyError:
  86. raise TypeHintParsingException(
  87. "Couldn't parse this type hint, likely due to a custom class or object: ", hint
  88. )
  89. elif origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
  90. # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
  91. subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
  92. if len(subtypes) == 1:
  93. # A single non-null type can be expressed directly
  94. return_dict = subtypes[0]
  95. elif all("type" in subtype and isinstance(subtype["type"], str) for subtype in subtypes):
  96. # A union of basic types can be expressed as a list in the schema
  97. return_dict = {"type": sorted([subtype["type"] for subtype in subtypes])}
  98. else:
  99. # A union of more complex types requires "anyOf"
  100. return_dict = {"anyOf": subtypes}
  101. if type(None) in args:
  102. return_dict["nullable"] = True
  103. return return_dict
  104. elif origin is Literal and len(args) > 0:
  105. LITERAL_TYPES = (int, float, str, bool, type(None))
  106. args_types = []
  107. for arg in args:
  108. if type(arg) not in LITERAL_TYPES:
  109. raise TypeHintParsingException("Only the valid python literals can be listed in typing.Literal.")
  110. arg_type = _get_json_schema_type(type(arg)).get("type")
  111. if arg_type is not None and arg_type not in args_types:
  112. args_types.append(arg_type)
  113. return {
  114. "type": args_types.pop() if len(args_types) == 1 else list(args_types),
  115. "enum": list(args),
  116. }
  117. elif origin is list:
  118. if not args:
  119. return {"type": "array"}
  120. else:
  121. # Lists can only have a single type argument, so recurse into it
  122. return {"type": "array", "items": _parse_type_hint(args[0])}
  123. elif origin is tuple:
  124. if not args:
  125. return {"type": "array"}
  126. if len(args) == 1:
  127. raise TypeHintParsingException(
  128. f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
  129. "we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
  130. "more than one element, we recommend "
  131. "using a list[] type instead, or if it really is a single element, remove the tuple[] wrapper and just "
  132. "pass the element directly."
  133. )
  134. if ... in args:
  135. raise TypeHintParsingException(
  136. "Conversion of '...' is not supported in Tuple type hints. "
  137. "Use list[] types for variable-length"
  138. " inputs instead."
  139. )
  140. return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}
  141. elif origin is dict:
  142. # The JSON equivalent to a dict is 'object', which mandates that all keys are strings
  143. # However, we can specify the type of the dict values with "additionalProperties"
  144. out = {"type": "object"}
  145. if len(args) == 2:
  146. out["additionalProperties"] = _parse_type_hint(args[1])
  147. return out
  148. raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
  149. def _convert_type_hints_to_json_schema(func: Callable) -> dict:
  150. type_hints = get_type_hints(func)
  151. signature = inspect.signature(func)
  152. func_name = getattr(func, "__name__", "operation")
  153. # For methods, we need to ignore the first "self" or "cls" parameter. Here we assume that if the first parameter
  154. # is named "self" or "cls" and has no type hint, it is an implicit receiver argument.
  155. first_param_name = next(iter(signature.parameters), None)
  156. if (
  157. first_param_name in {"self", "cls"}
  158. and signature.parameters[first_param_name].annotation == inspect.Parameter.empty
  159. ):
  160. implicit_arg_name = first_param_name
  161. else:
  162. implicit_arg_name = None
  163. required = []
  164. for param_name, param in signature.parameters.items():
  165. if param_name == implicit_arg_name:
  166. continue
  167. if param.annotation == inspect.Parameter.empty:
  168. raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func_name}")
  169. if param.default == inspect.Parameter.empty:
  170. required.append(param_name)
  171. properties = {}
  172. for param_name, param_type in type_hints.items():
  173. if param_name == implicit_arg_name:
  174. continue
  175. properties[param_name] = _parse_type_hint(param_type)
  176. schema = {"type": "object", "properties": properties}
  177. if required:
  178. schema["required"] = required
  179. return schema
  180. def parse_google_format_docstring(docstring: str) -> tuple[str | None, dict | None, str | None]:
  181. """
  182. Parses a Google-style docstring to extract the function description,
  183. argument descriptions, and return description.
  184. Args:
  185. docstring (str): The docstring to parse.
  186. Returns:
  187. The function description, arguments, and return description.
  188. """
  189. # Extract the sections
  190. description_match = description_re.search(docstring)
  191. args_match = args_re.search(docstring)
  192. returns_match = returns_re.search(docstring)
  193. # Clean and store the sections
  194. description = description_match.group(1).strip() if description_match else None
  195. docstring_args = args_match.group(1).strip() if args_match else None
  196. returns = returns_match.group(1).strip() if returns_match else None
  197. # Parsing the arguments into a dictionary
  198. if docstring_args is not None:
  199. docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines
  200. matches = args_split_re.findall(docstring_args)
  201. args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
  202. else:
  203. args_dict = {}
  204. return description, args_dict, returns
  205. def get_json_schema(func: Callable) -> dict:
  206. """
  207. This function generates a JSON schema for a given function, based on its docstring and type hints. This is
  208. mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
  209. the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
  210. that the function has a docstring, and that each argument has a description in the docstring, in the standard
  211. Google docstring format shown below. It also requires that all user-facing arguments have valid Python type hints.
  212. When passing methods, implicit receiver arguments (`self` or `cls`) are ignored.
  213. Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
  214. optional because most chat templates ignore the return value of the function.
  215. Args:
  216. func: The function to generate a JSON schema for.
  217. Returns:
  218. A dictionary containing the JSON schema for the function.
  219. Examples:
  220. ```python
  221. >>> def multiply(x: float, y: float):
  222. >>> '''
  223. >>> A function that multiplies two numbers
  224. >>>
  225. >>> Args:
  226. >>> x: The first number to multiply
  227. >>> y: The second number to multiply
  228. >>> '''
  229. >>> return x * y
  230. >>>
  231. >>> print(get_json_schema(multiply))
  232. {
  233. "name": "multiply",
  234. "description": "A function that multiplies two numbers",
  235. "parameters": {
  236. "type": "object",
  237. "properties": {
  238. "x": {"type": "number", "description": "The first number to multiply"},
  239. "y": {"type": "number", "description": "The second number to multiply"}
  240. },
  241. "required": ["x", "y"]
  242. }
  243. }
  244. ```
  245. The general use for these schemas is that they are used to generate tool descriptions for chat templates that
  246. support them, like so:
  247. ```python
  248. >>> from transformers import AutoTokenizer
  249. >>> from transformers.utils import get_json_schema
  250. >>>
  251. >>> def multiply(x: float, y: float):
  252. >>> '''
  253. >>> A function that multiplies two numbers
  254. >>>
  255. >>> Args:
  256. >>> x: The first number to multiply
  257. >>> y: The second number to multiply
  258. >>> return x * y
  259. >>> '''
  260. >>>
  261. >>> multiply_schema = get_json_schema(multiply)
  262. >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
  263. >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
  264. >>> formatted_chat = tokenizer.apply_chat_template(
  265. >>> messages,
  266. >>> tools=[multiply_schema],
  267. >>> chat_template="tool_use",
  268. >>> return_dict=True,
  269. >>> return_tensors="pt",
  270. >>> add_generation_prompt=True
  271. >>> )
  272. >>> # The formatted chat can now be passed to model.generate()
  273. ```
  274. Each argument description can also have an optional `(choices: ...)` block at the end, such as
  275. `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
  276. only be parsed correctly if it is at the end of the line:
  277. ```python
  278. >>> def drink_beverage(beverage: str):
  279. >>> '''
  280. >>> A function that drinks a beverage
  281. >>>
  282. >>> Args:
  283. >>> beverage: The beverage to drink (choices: ["tea", "coffee"])
  284. >>> '''
  285. >>> pass
  286. >>>
  287. >>> print(get_json_schema(drink_beverage))
  288. ```
  289. {
  290. 'name': 'drink_beverage',
  291. 'description': 'A function that drinks a beverage',
  292. 'parameters': {
  293. 'type': 'object',
  294. 'properties': {
  295. 'beverage': {
  296. 'type': 'string',
  297. 'enum': ['tea', 'coffee'],
  298. 'description': 'The beverage to drink'
  299. }
  300. },
  301. 'required': ['beverage']
  302. }
  303. }
  304. """
  305. doc = inspect.getdoc(func)
  306. func_name = getattr(func, "__name__", "operation")
  307. if not doc:
  308. raise DocstringParsingException(f"Cannot generate JSON schema for {func_name} because it has no docstring!")
  309. doc = doc.strip()
  310. main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)
  311. json_schema = _convert_type_hints_to_json_schema(func)
  312. if (return_dict := json_schema["properties"].pop("return", None)) is not None:
  313. if return_doc is not None: # We allow a missing return docstring since most templates ignore it
  314. return_dict["description"] = return_doc
  315. for arg, schema in json_schema["properties"].items():
  316. if arg not in param_descriptions:
  317. raise DocstringParsingException(
  318. f"Cannot generate JSON schema for {func_name} because the docstring has no description for the argument '{arg}'"
  319. )
  320. desc = param_descriptions[arg]
  321. enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
  322. if enum_choices:
  323. schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
  324. desc = enum_choices.string[: enum_choices.start()].strip()
  325. schema["description"] = desc
  326. output = {"name": func_name, "description": main_doc, "parameters": json_schema}
  327. if return_dict is not None:
  328. output["return"] = return_dict
  329. return {"type": "function", "function": output}
  330. @lru_cache
  331. @no_type_check
  332. def _get_template_variables(chat_template: str) -> frozenset[str]:
  333. """Return the set of undeclared variables referenced by a chat template.
  334. Uses ``jinja2.meta.find_undeclared_variables`` so that callers can
  335. automatically distinguish template-level kwargs from processor kwargs
  336. without maintaining a manual allowlist. Needed only to support BC as we
  337. allowed all `kwargs` to be merged into one in the past
  338. """
  339. compiled = _compile_jinja_template(chat_template)
  340. ast = compiled.environment.parse(chat_template)
  341. return frozenset(jinja2.meta.find_undeclared_variables(ast))
  342. def _render_with_assistant_indices(
  343. compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
  344. ):
  345. rendered_blocks = []
  346. generation_indices = []
  347. with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
  348. for block in compiled_template.generate(
  349. messages=messages,
  350. tools=tools,
  351. documents=documents,
  352. add_generation_prompt=add_generation_prompt,
  353. **template_kwargs,
  354. ):
  355. rendered_blocks.append(block)
  356. rendered_chat = "".join(rendered_blocks)
  357. return rendered_chat, generation_indices
  358. @lru_cache
  359. def _compile_jinja_template(chat_template):
  360. return _cached_compile_jinja_template(chat_template)
  361. @no_type_check
  362. def _cached_compile_jinja_template(chat_template):
  363. if not is_jinja_available():
  364. raise ImportError(
  365. "apply_chat_template requires jinja2 to be installed. Please install it using `pip install jinja2`."
  366. )
  367. class AssistantTracker(Extension):
  368. # This extension is used to track the indices of assistant-generated tokens in the rendered chat
  369. tags = {"generation"}
  370. def __init__(self, environment: ImmutableSandboxedEnvironment):
  371. # The class is only initiated by jinja.
  372. super().__init__(environment)
  373. environment.extend(activate_tracker=self.activate_tracker)
  374. self._rendered_blocks = None
  375. self._generation_indices = None
  376. def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
  377. lineno = next(parser.stream).lineno
  378. body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
  379. return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
  380. @jinja2.pass_eval_context
  381. def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
  382. rv = caller()
  383. if self.is_active():
  384. # Only track generation indices if the tracker is active
  385. start_index = len("".join(self._rendered_blocks))
  386. end_index = start_index + len(rv)
  387. self._generation_indices.append((start_index, end_index))
  388. return rv
  389. def is_active(self) -> bool:
  390. return self._rendered_blocks is not None or self._generation_indices is not None
  391. @contextmanager
  392. def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
  393. try:
  394. if self.is_active():
  395. raise ValueError("AssistantTracker should not be reused before closed")
  396. self._rendered_blocks = rendered_blocks
  397. self._generation_indices = generation_indices
  398. yield
  399. finally:
  400. self._rendered_blocks = None
  401. self._generation_indices = None
  402. if version.parse(jinja2.__version__) < version.parse("3.1.0"):
  403. raise ImportError(
  404. f"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is {jinja2.__version__}."
  405. )
  406. def raise_exception(message):
  407. raise jinja2.exceptions.TemplateError(message)
  408. def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
  409. # We override the built-in tojson filter because Jinja's default filter escapes HTML characters
  410. # We also expose some options like custom indents and separators
  411. return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
  412. def strftime_now(format):
  413. return datetime.now().strftime(format)
  414. jinja_env = ImmutableSandboxedEnvironment(
  415. trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
  416. )
  417. jinja_env.filters["tojson"] = tojson
  418. jinja_env.globals["raise_exception"] = raise_exception
  419. jinja_env.globals["strftime_now"] = strftime_now
  420. return jinja_env.from_string(chat_template)
  421. def render_jinja_template(
  422. conversations: list[ChatType],
  423. tools: list[dict | Callable] | None = None,
  424. documents: ChatType | None = None,
  425. chat_template: str | None = None,
  426. return_assistant_tokens_mask: bool = False,
  427. continue_final_message: bool = False,
  428. add_generation_prompt: bool = False,
  429. **kwargs,
  430. ) -> str:
  431. if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
  432. logger.warning_once(
  433. "return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword."
  434. )
  435. # Compilation function uses a cache to avoid recompiling the same template
  436. compiled_template = _compile_jinja_template(chat_template)
  437. # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
  438. if tools is not None:
  439. tool_schemas = []
  440. for tool in tools:
  441. if isinstance(tool, dict):
  442. tool_schemas.append(tool)
  443. elif isfunction(tool) or inspect.ismethod(tool):
  444. tool_schemas.append(get_json_schema(tool))
  445. else:
  446. raise ValueError(
  447. "Tools should either be a JSON schema, or a callable function with type hints "
  448. "and a docstring suitable for auto-conversion to a schema."
  449. )
  450. else:
  451. tool_schemas = None
  452. if documents is not None:
  453. for document in documents:
  454. if not isinstance(document, dict):
  455. raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")
  456. rendered = []
  457. all_generation_indices = []
  458. continue_final_message_tag = "CONTINUE_FINAL_MESSAGE_TAG "
  459. for chat in conversations:
  460. if hasattr(chat, "messages"):
  461. # Indicates it's a Conversation object
  462. chat = chat.messages
  463. if continue_final_message:
  464. chat = deepcopy(chat)
  465. final_message = chat[-1]["content"]
  466. if isinstance(final_message, (list, tuple)):
  467. for content_block in reversed(final_message):
  468. if "text" in content_block:
  469. # Pick the last text block in the message (the first one we hit while iterating in reverse)
  470. final_message = content_block["text"]
  471. content_block["text"] = content_block["text"] + continue_final_message_tag
  472. break
  473. else:
  474. raise ValueError(
  475. "continue_final_message is set but we could not find any text to continue in the final message!"
  476. )
  477. else:
  478. chat[-1]["content"] = chat[-1]["content"] + continue_final_message_tag
  479. if return_assistant_tokens_mask:
  480. rendered_chat, generation_indices = _render_with_assistant_indices(
  481. compiled_template=compiled_template,
  482. messages=chat,
  483. tools=tool_schemas,
  484. documents=documents,
  485. add_generation_prompt=add_generation_prompt,
  486. **kwargs,
  487. )
  488. all_generation_indices.append(generation_indices)
  489. else:
  490. rendered_chat = compiled_template.render(
  491. messages=chat,
  492. tools=tool_schemas,
  493. documents=documents,
  494. add_generation_prompt=add_generation_prompt,
  495. **kwargs,
  496. )
  497. if continue_final_message:
  498. if (final_message.strip() not in rendered_chat) or (
  499. continue_final_message_tag.strip() not in rendered_chat
  500. ):
  501. raise ValueError(
  502. "continue_final_message is set but the final message does not appear in the chat after "
  503. "applying the chat template! This can happen if the chat template deletes portions of "
  504. "the final message. Please verify the chat template and final message in your chat to "
  505. "ensure they are compatible."
  506. )
  507. tag_loc = rendered_chat.rindex(continue_final_message_tag.strip())
  508. if rendered_chat[tag_loc : tag_loc + len(continue_final_message_tag)] == continue_final_message_tag:
  509. # The template preserves spacing, so things are simple
  510. rendered_chat = rendered_chat[:tag_loc]
  511. else:
  512. # The message has trailing spacing that was trimmed, so we must be more cautious
  513. rendered_chat = rendered_chat[:tag_loc].rstrip()
  514. rendered.append(rendered_chat)
  515. return rendered, all_generation_indices
  516. def is_valid_message(message):
  517. """
  518. Check that input is a valid message in a chat, namely a dict with "role" and "content" keys.
  519. """
  520. if not isinstance(message, dict):
  521. return False
  522. if not ("role" in message and "content" in message):
  523. return False
  524. return True
class Chat:
    """This class is intended to just be used internally for pipelines and not exposed to users. We convert chats
    to this format because the rest of the pipeline code tends to assume that lists of messages are
    actually a batch of samples rather than messages in the same conversation.

    Wrapping a single conversation in a `Chat` marks it as one sample; the validated messages are kept on
    `self.messages`.
    """

    def __init__(self, messages: list[dict]):
        # Validate every message up front (each must be a dict with "role" and "content" keys).
        for message in messages:
            if not is_valid_message(message):
                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
        # Stored as-is (no copy).
        self.messages = messages