| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618 |
- # Copyright 2024 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import inspect
- import json
- import re
- import types
- from collections.abc import Callable
- from contextlib import contextmanager
- from copy import deepcopy
- from datetime import datetime
- from functools import lru_cache
- from inspect import isfunction
- from typing import Any, Literal, Union, get_args, get_origin, get_type_hints, no_type_check
- from packaging import version
- from . import logging
- from .import_utils import is_jinja_available, is_torch_available, is_vision_available
# Module-level logger; used for one-time warnings (e.g. in `render_jinja_template`).
logger = logging.get_logger(__name__)

# jinja2 is an optional dependency. When it is missing, the name `jinja2` is bound to None so this module still
# imports; compiling a chat template then raises an ImportError with installation instructions.
if is_jinja_available():
    import jinja2
    import jinja2.exceptions
    import jinja2.ext
    import jinja2.meta
    import jinja2.nodes
    import jinja2.runtime
    from jinja2.ext import Extension
    from jinja2.sandbox import ImmutableSandboxedEnvironment
else:
    jinja2 = None

# PIL is optional too; `Image` is only bound (and only used) when vision support is available.
if is_vision_available():
    from PIL.Image import Image
# A single conversation: a list of message dicts (each expected to carry "role" and "content" keys).
ChatType = list[dict[str, Any]]

# NOTE(review): not referenced in the visible part of this module; presumably imported by other modules.
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)

# Extracts the initial segment of the docstring (the function description): everything before the first
# "Args:", "Returns:" or "Raises:" header, or the whole docstring if there is none.
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)

# Extracts the body of the "Args:" block. The header must start on its own line and be followed directly by a
# newline; the block ends at "Returns:", "Raises:" or the end of the docstring.
args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)

# Splits the body of the "Args:" block into (name, description) pairs. A description runs until the next line
# that looks like "name:", so descriptions may span several lines.
args_split_re = re.compile(
    r"""
(?:^|\n) # Match the start of the args block, or a newline
\s*(\w+):\s* # Capture the argument name and strip spacing
(.*?)\s* # Capture the argument description, which can span multiple lines, and strip trailing spacing
(?=\n\s*\w+:|\Z) # Stop when you hit the next argument or the end of the block
""",
    re.DOTALL | re.VERBOSE,
)

# Extracts the "Returns:" block from the docstring, if present; it ends at "Raises:" or the end of the docstring.
# Note that most chat templates ignore the return type/doc!
returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
class TypeHintParsingException(Exception):
    """Raised when a function's type hints cannot be translated into a JSON schema."""
class DocstringParsingException(Exception):
    """Raised when a function's docstring cannot be turned into a JSON schema (e.g. it is missing)."""
def _get_json_schema_type(param_type: type) -> dict[str, str]:
    """
    Return the JSON schema fragment for a single, non-generic Python type.

    The lookup table is rebuilt on every call, so the returned dict is never shared between calls and callers may
    safely mutate it. Types that are not in the table fall back to `{"type": "object"}` rather than raising.

    Args:
        param_type: A plain Python type such as `int`, `str` or `type(None)`.

    Returns:
        The matching JSON schema fragment (`{}` for `typing.Any`).
    """
    schemas = {
        int: {"type": "integer"},
        float: {"type": "number"},
        str: {"type": "string"},
        bool: {"type": "boolean"},
        type(None): {"type": "null"},
        Any: {},
    }

    # Multimodal types are only registered when their optional libraries are installed.
    if is_vision_available():
        schemas[Image] = {"type": "image"}
    if is_torch_available():
        import torch

        # Tensors are advertised to the chat template as audio inputs.
        schemas[torch.Tensor] = {"type": "audio"}

    try:
        return schemas[param_type]
    except KeyError:
        return {"type": "object"}
def _parse_type_hint(hint: Any) -> dict:
    """
    Recursively convert a resolved Python type hint into a JSON schema fragment.

    Supported hints are the plain types handled by `_get_json_schema_type`, `Union[...]` / `X | Y` (including
    `Optional[...]`, which adds `"nullable": True`), `Literal[...]`, `list[...]`, `tuple[...]` and `dict[...]`.

    Args:
        hint: A resolved type hint, e.g. a value returned by `typing.get_type_hints`.

    Returns:
        A dict holding the JSON schema for `hint`.

    Raises:
        TypeHintParsingException: For unsupported generic types, `Literal` members that are not plain Python
            literals, single-element tuples and tuples containing `...`.
    """
    origin = get_origin(hint)
    args = get_args(hint)

    if origin is None:
        # A plain (non-generic) type. `_get_json_schema_type` falls back to {"type": "object"} for unknown classes
        # instead of raising, so no error handling is needed here.
        return _get_json_schema_type(hint)

    if origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
        # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
        subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
        if len(subtypes) == 1:
            # A single non-null type can be expressed directly
            return_dict = subtypes[0]
        elif all(set(subtype) == {"type"} and isinstance(subtype["type"], str) for subtype in subtypes):
            # A union of bare basic types can be expressed as a list of unique types in the schema. Subtypes that
            # carry extra constraints ("items", "enum", "additionalProperties", ...) must not be collapsed this
            # way, otherwise those constraints would be silently dropped.
            return_dict = {"type": sorted({subtype["type"] for subtype in subtypes})}
        else:
            # A union of more complex types requires "anyOf"
            return_dict = {"anyOf": subtypes}
        if type(None) in args:
            return_dict["nullable"] = True
        return return_dict

    if origin is Literal and len(args) > 0:
        literal_types = (int, float, str, bool, type(None))
        args_types = []
        for arg in args:
            if type(arg) not in literal_types:
                raise TypeHintParsingException("Only the valid python literals can be listed in typing.Literal.")
            arg_type = _get_json_schema_type(type(arg)).get("type")
            if arg_type is not None and arg_type not in args_types:
                args_types.append(arg_type)
        return {
            # A single JSON type is emitted as a string, several as a list.
            "type": args_types.pop() if len(args_types) == 1 else list(args_types),
            "enum": list(args),
        }

    if origin is list:
        if not args:
            return {"type": "array"}
        # Lists can only have a single type argument, so recurse into it
        return {"type": "array", "items": _parse_type_hint(args[0])}

    if origin is tuple:
        if not args:
            return {"type": "array"}
        if len(args) == 1:
            raise TypeHintParsingException(
                f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
                "we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
                "more than one element, we recommend "
                "using a list[] type instead, or if it really is a single element, remove the tuple[] wrapper and just "
                "pass the element directly."
            )
        if ... in args:
            raise TypeHintParsingException(
                "Conversion of '...' is not supported in Tuple type hints. "
                "Use list[] types for variable-length"
                " inputs instead."
            )
        # Fixed-length tuples map to positional item schemas
        return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}

    if origin is dict:
        # The JSON equivalent to a dict is 'object', which mandates that all keys are strings
        # However, we can specify the type of the dict values with "additionalProperties"
        out = {"type": "object"}
        if len(args) == 2:
            out["additionalProperties"] = _parse_type_hint(args[1])
        return out

    raise TypeHintParsingException(f"Couldn't parse this type hint, likely due to a custom class or object: {hint}")
def _convert_type_hints_to_json_schema(func: Callable) -> dict:
    """
    Build the `parameters` JSON schema of `func` from its signature and type hints.

    Args:
        func: The function or method to describe.

    Returns:
        A dict of the form `{"type": "object", "properties": {...}, "required": [...]}`. `"required"` lists the
        parameters that have no default value and is omitted when empty. If `func` has a return annotation, its
        schema is also stored under `properties["return"]` (`get_json_schema` pops it out again).

    Raises:
        TypeHintParsingException: If a user-facing parameter has no type hint, or a hint cannot be converted.
    """
    type_hints = get_type_hints(func)
    signature = inspect.signature(func)
    func_name = getattr(func, "__name__", "operation")

    # For methods, we need to ignore the first "self" or "cls" parameter. Here we assume that if the first parameter
    # is named "self" or "cls" and has no type hint, it is an implicit receiver argument.
    first_param_name = next(iter(signature.parameters), None)
    if (
        first_param_name in {"self", "cls"}
        and signature.parameters[first_param_name].annotation is inspect.Parameter.empty
    ):
        implicit_arg_name = first_param_name
    else:
        implicit_arg_name = None

    required = []
    for param_name, param in signature.parameters.items():
        if param_name == implicit_arg_name:
            continue
        # Identity checks against the `empty` sentinel: `==` would invoke the annotation's/default's own `__eq__`,
        # which for array-like defaults (e.g. numpy arrays) returns an array and makes `if` raise.
        if param.annotation is inspect.Parameter.empty:
            raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func_name}")
        if param.default is inspect.Parameter.empty:
            required.append(param_name)

    properties = {}
    for param_name, param_type in type_hints.items():
        if param_name == implicit_arg_name:
            continue
        properties[param_name] = _parse_type_hint(param_type)

    schema = {"type": "object", "properties": properties}
    if required:
        schema["required"] = required
    return schema
def parse_google_format_docstring(docstring: str) -> tuple[str | None, dict | None, str | None]:
    """
    Split a Google-style docstring into its summary, per-argument descriptions and return description.

    Args:
        docstring (str): The docstring to parse.

    Returns:
        A `(description, args, returns)` tuple. `description` and `returns` are stripped strings, or None when the
        section is absent; `args` maps each argument name to its description with internal line breaks collapsed
        into single spaces (an empty dict when there is no "Args:" block).
    """

    def _section(pattern):
        # Return the stripped first capture group of `pattern`, or None if it does not match.
        found = pattern.search(docstring)
        return found.group(1).strip() if found else None

    description = _section(description_re)
    args_block = _section(args_re)
    returns = _section(returns_re)

    args_dict = {}
    if args_block is not None:
        # Blank lines would confuse the argument splitter, so drop them first.
        non_blank = "\n".join(line for line in args_block.split("\n") if line.strip())
        for arg_name, arg_text in args_split_re.findall(non_blank):
            args_dict[arg_name] = re.sub(r"\s*\n+\s*", " ", arg_text.strip())
    return description, args_dict, returns
def get_json_schema(func: Callable) -> dict:
    """
    This function generates a JSON schema for a given function, based on its docstring and type hints. This is
    mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
    the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
    that the function has a docstring, and that each argument has a description in the docstring, in the standard
    Google docstring format shown below. It also requires that all user-facing arguments have valid Python type hints.
    When passing methods, implicit receiver arguments (`self` or `cls`) are ignored.

    Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
    optional because most chat templates ignore the return value of the function.

    Args:
        func: The function to generate a JSON schema for.

    Returns:
        A dictionary of the form `{"type": "function", "function": {...}}`. The inner dictionary holds the function's
        `name`, `description` and `parameters` schema, plus a `return` schema if the function has a return type hint.

    Raises:
        DocstringParsingException: If the function has no docstring, or an argument has no description in it.
        TypeHintParsingException: If an argument has no type hint, or a type hint cannot be converted.
        json.JSONDecodeError: If a `(choices: ...)` block is not valid JSON.

    Examples:

    ```python
    >>> def multiply(x: float, y: float):
    >>>    '''
    >>>    A function that multiplies two numbers
    >>>
    >>>    Args:
    >>>        x: The first number to multiply
    >>>        y: The second number to multiply
    >>>    '''
    >>>    return x * y
    >>>
    >>> print(get_json_schema(multiply))
    {
        "type": "function",
        "function": {
            "name": "multiply",
            "description": "A function that multiplies two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {"type": "number", "description": "The first number to multiply"},
                    "y": {"type": "number", "description": "The second number to multiply"}
                },
                "required": ["x", "y"]
            }
        }
    }
    ```

    The general use for these schemas is that they are used to generate tool descriptions for chat templates that
    support them, like so:

    ```python
    >>> from transformers import AutoTokenizer
    >>> from transformers.utils import get_json_schema
    >>>
    >>> def multiply(x: float, y: float):
    >>>    '''
    >>>    A function that multiplies two numbers
    >>>
    >>>    Args:
    >>>        x: The first number to multiply
    >>>        y: The second number to multiply
    >>>    '''
    >>>    return x * y
    >>>
    >>> multiply_schema = get_json_schema(multiply)
    >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
    >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
    >>> formatted_chat = tokenizer.apply_chat_template(
    >>>     messages,
    >>>     tools=[multiply_schema],
    >>>     chat_template="tool_use",
    >>>     return_dict=True,
    >>>     return_tensors="pt",
    >>>     add_generation_prompt=True
    >>> )
    >>> # The formatted chat can now be passed to model.generate()
    ```

    Each argument description can also have an optional `(choices: ...)` block at the end, such as
    `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. The choices must be a
    JSON list; string entries are stripped of surrounding whitespace, while numbers and booleans are kept as-is
    (e.g. `(choices: [1, 2, 3])` for an `int` argument). Note that this will only be parsed correctly if it is at the
    end of the line:

    ```python
    >>> def drink_beverage(beverage: str):
    >>>    '''
    >>>    A function that drinks a beverage
    >>>
    >>>    Args:
    >>>        beverage: The beverage to drink (choices: ["tea", "coffee"])
    >>>    '''
    >>>    pass
    >>>
    >>> print(get_json_schema(drink_beverage))
    {
        'type': 'function',
        'function': {
            'name': 'drink_beverage',
            'description': 'A function that drinks a beverage',
            'parameters': {
                'type': 'object',
                'properties': {
                    'beverage': {
                        'type': 'string',
                        'enum': ['tea', 'coffee'],
                        'description': 'The beverage to drink'
                    }
                },
                'required': ['beverage']
            }
        }
    }
    ```
    """
    doc = inspect.getdoc(func)
    func_name = getattr(func, "__name__", "operation")
    if not doc:
        raise DocstringParsingException(f"Cannot generate JSON schema for {func_name} because it has no docstring!")
    doc = doc.strip()
    main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)

    json_schema = _convert_type_hints_to_json_schema(func)
    # The return annotation is reported separately from the parameters.
    if (return_dict := json_schema["properties"].pop("return", None)) is not None:
        if return_doc is not None:  # We allow a missing return docstring since most templates ignore it
            return_dict["description"] = return_doc

    for arg, schema in json_schema["properties"].items():
        if arg not in param_descriptions:
            raise DocstringParsingException(
                f"Cannot generate JSON schema for {func_name} because the docstring has no description for the argument '{arg}'"
            )
        desc = param_descriptions[arg]
        enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
        if enum_choices:
            choices = json.loads(enum_choices.group(1))
            # Choices may be numbers or booleans (e.g. for int arguments); only strings can be stripped.
            schema["enum"] = [c.strip() if isinstance(c, str) else c for c in choices]
            # The choices block is removed from the human-readable description.
            desc = desc[: enum_choices.start()].strip()
        schema["description"] = desc

    output = {"name": func_name, "description": main_doc, "parameters": json_schema}
    if return_dict is not None:
        output["return"] = return_dict
    return {"type": "function", "function": output}
@lru_cache
@no_type_check
def _get_template_variables(chat_template: str) -> frozenset[str]:
    """
    List the variables a chat template reads but never defines itself.

    The names come from `jinja2.meta.find_undeclared_variables` and let callers tell template-level kwargs apart
    from processor kwargs without a hand-maintained allowlist. This exists only for backward compatibility, since
    all kwargs used to be merged into a single dict in the past. Results are memoized per template string.
    """
    environment = _compile_jinja_template(chat_template).environment
    parsed = environment.parse(chat_template)
    undeclared = jinja2.meta.find_undeclared_variables(parsed)
    return frozenset(undeclared)
- def _render_with_assistant_indices(
- compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
- ):
- rendered_blocks = []
- generation_indices = []
- with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
- for block in compiled_template.generate(
- messages=messages,
- tools=tools,
- documents=documents,
- add_generation_prompt=add_generation_prompt,
- **template_kwargs,
- ):
- rendered_blocks.append(block)
- rendered_chat = "".join(rendered_blocks)
- return rendered_chat, generation_indices
@lru_cache
def _compile_jinja_template(chat_template):
    """
    Memoized entry point for compiling a chat template.

    Despite its name, `_cached_compile_jinja_template` does not cache anything itself. The `lru_cache` on this
    wrapper is what lets repeated renders of the same template string reuse one compiled template (and therefore
    one Jinja environment).
    """
    compiled_template = _cached_compile_jinja_template(chat_template)
    return compiled_template
@no_type_check
def _cached_compile_jinja_template(chat_template):
    """
    Build the sandboxed Jinja environment used for chat templates and compile `chat_template` with it.

    This function performs no caching itself; `_compile_jinja_template` wraps it in `lru_cache`.

    The environment uses `trim_blocks` / `lstrip_blocks` and the `loopcontrols` extension. It also adds:
      - a `{% generation %}...{% endgeneration %}` tag (the `AssistantTracker` extension), which records the
        character span of each generation block while a tracker is active;
      - a `tojson` filter that does not HTML-escape, and exposes `indent`, `separators`, `sort_keys`;
      - `raise_exception(message)` and `strftime_now(format)` globals.

    Args:
        chat_template: The Jinja source of the chat template.

    Returns:
        The compiled template.

    Raises:
        ImportError: If jinja2 is not installed, or is older than 3.1.0.
    """
    if not is_jinja_available():
        raise ImportError(
            "apply_chat_template requires jinja2 to be installed. Please install it using `pip install jinja2`."
        )
    # Check the version before defining the extension below: it uses APIs (e.g. `jinja2.pass_eval_context`) that
    # old jinja2 releases lack, which would otherwise surface as an AttributeError instead of this message.
    if version.parse(jinja2.__version__) < version.parse("3.1.0"):
        raise ImportError(
            f"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is {jinja2.__version__}."
        )

    class AssistantTracker(Extension):
        # This extension is used to track the indices of assistant-generated tokens in the rendered chat
        tags = {"generation"}

        def __init__(self, environment: ImmutableSandboxedEnvironment):
            # The class is only initiated by jinja.
            super().__init__(environment)
            environment.extend(activate_tracker=self.activate_tracker)
            self._rendered_blocks = None
            self._generation_indices = None

        def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
            lineno = next(parser.stream).lineno
            body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
            return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

        @jinja2.pass_eval_context
        def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
            rv = caller()
            if self.is_active():
                # Only track generation indices if the tracker is active. The block starts right after everything
                # rendered so far.
                start_index = len("".join(self._rendered_blocks))
                end_index = start_index + len(rv)
                self._generation_indices.append((start_index, end_index))
            return rv

        def is_active(self) -> bool:
            return self._rendered_blocks is not None or self._generation_indices is not None

        @contextmanager
        def activate_tracker(self, rendered_blocks: list[str], generation_indices: list[tuple[int, int]]):
            # Reject a second activation *before* entering the try/finally. Otherwise the cleanup below would reset
            # the state of the tracking session that is still running (the compiled template, and with it this
            # extension, is cached and therefore shared).
            if self.is_active():
                raise ValueError("AssistantTracker should not be reused before closed")
            self._rendered_blocks = rendered_blocks
            self._generation_indices = generation_indices
            try:
                yield
            finally:
                self._rendered_blocks = None
                self._generation_indices = None

    def raise_exception(message):
        raise jinja2.exceptions.TemplateError(message)

    def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
        # We override the built-in tojson filter because Jinja's default filter escapes HTML characters
        # We also expose some options like custom indents and separators
        return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

    def strftime_now(format):
        # NOTE: uses the local (naive) time of the machine rendering the template.
        return datetime.now().strftime(format)

    jinja_env = ImmutableSandboxedEnvironment(
        trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
    )
    jinja_env.filters["tojson"] = tojson
    jinja_env.globals["raise_exception"] = raise_exception
    jinja_env.globals["strftime_now"] = strftime_now
    return jinja_env.from_string(chat_template)
def render_jinja_template(
    conversations: list[ChatType],
    tools: list[dict | Callable] | None = None,
    documents: ChatType | None = None,
    chat_template: str | None = None,
    return_assistant_tokens_mask: bool = False,
    continue_final_message: bool = False,
    add_generation_prompt: bool = False,
    **kwargs,
) -> tuple[list[str], list[list[tuple[int, int]]]]:
    """
    Render one or more conversations with a Jinja chat template.

    Args:
        conversations: The conversations to render. Each entry is a list of message dicts, or an object with a
            `messages` attribute (e.g. a `Chat`), in which case that attribute is rendered.
        tools: Optional tools, given as JSON schema dicts or as functions/methods (converted with `get_json_schema`).
        documents: Optional list of document dicts, passed to the template as `documents`.
        chat_template: The Jinja source of the chat template. Required despite the `None` default.
        return_assistant_tokens_mask: If True, also record the character spans of the `{% generation %}` blocks.
        continue_final_message: If True, the output is cut right after the final message's text, so that generation
            continues that message instead of starting a new one.
        add_generation_prompt: Forwarded to the template as `add_generation_prompt`.
        **kwargs: Extra variables forwarded to the template.

    Returns:
        A `(rendered, all_generation_indices)` tuple: one rendered string per conversation, and (only when
        `return_assistant_tokens_mask` is True, otherwise an empty list) one list of `(start, end)` spans per
        conversation.

    Raises:
        ValueError: If `chat_template` is None, a tool is neither a dict nor a function/method, or
            `continue_final_message` is set but the final message cannot be located.
        TypeError: If `documents` contains an entry that is not a dict.
    """
    if chat_template is None:
        # Fail early with a clear message instead of an opaque TypeError from `re.search` below, or None being handed
        # to Jinja's compiler.
        raise ValueError("render_jinja_template requires a `chat_template` string, but None was passed.")

    if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
        logger.warning_once(
            "return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword."
        )

    # Compilation function uses a cache to avoid recompiling the same template
    compiled_template = _compile_jinja_template(chat_template)

    # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
    if tools is not None:
        tool_schemas = []
        for tool in tools:
            if isinstance(tool, dict):
                tool_schemas.append(tool)
            elif isfunction(tool) or inspect.ismethod(tool):
                tool_schemas.append(get_json_schema(tool))
            else:
                raise ValueError(
                    "Tools should either be a JSON schema, or a callable function with type hints "
                    "and a docstring suitable for auto-conversion to a schema."
                )
    else:
        tool_schemas = None

    if documents is not None:
        for document in documents:
            if not isinstance(document, dict):
                raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")

    rendered = []
    all_generation_indices = []
    # Appended to the final message so we can find where it ends in the rendered output and cut there.
    continue_final_message_tag = "CONTINUE_FINAL_MESSAGE_TAG "
    for chat in conversations:
        if hasattr(chat, "messages"):
            # Indicates it's a Conversation object
            chat = chat.messages
        if continue_final_message:
            # Work on a copy so the caller's messages are not modified by the tag insertion.
            chat = deepcopy(chat)
            final_message = chat[-1]["content"]
            if isinstance(final_message, (list, tuple)):
                for content_block in reversed(final_message):
                    if "text" in content_block:
                        # Pick the last text block in the message (the first one we hit while iterating in reverse)
                        final_message = content_block["text"]
                        content_block["text"] = content_block["text"] + continue_final_message_tag
                        break
                else:
                    raise ValueError(
                        "continue_final_message is set but we could not find any text to continue in the final message!"
                    )
            else:
                chat[-1]["content"] = chat[-1]["content"] + continue_final_message_tag
        if return_assistant_tokens_mask:
            rendered_chat, generation_indices = _render_with_assistant_indices(
                compiled_template=compiled_template,
                messages=chat,
                tools=tool_schemas,
                documents=documents,
                add_generation_prompt=add_generation_prompt,
                **kwargs,
            )
            all_generation_indices.append(generation_indices)
        else:
            rendered_chat = compiled_template.render(
                messages=chat,
                tools=tool_schemas,
                documents=documents,
                add_generation_prompt=add_generation_prompt,
                **kwargs,
            )
        if continue_final_message:
            if (final_message.strip() not in rendered_chat) or (
                continue_final_message_tag.strip() not in rendered_chat
            ):
                raise ValueError(
                    "continue_final_message is set but the final message does not appear in the chat after "
                    "applying the chat template! This can happen if the chat template deletes portions of "
                    "the final message. Please verify the chat template and final message in your chat to "
                    "ensure they are compatible."
                )
            tag_loc = rendered_chat.rindex(continue_final_message_tag.strip())
            if rendered_chat[tag_loc : tag_loc + len(continue_final_message_tag)] == continue_final_message_tag:
                # The template preserves spacing, so things are simple
                rendered_chat = rendered_chat[:tag_loc]
            else:
                # The message has trailing spacing that was trimmed, so we must be more cautious
                rendered_chat = rendered_chat[:tag_loc].rstrip()
        rendered.append(rendered_chat)

    return rendered, all_generation_indices
def is_valid_message(message):
    """
    Return True if `message` is a dict that has both a "role" and a "content" key, False otherwise.
    """
    return isinstance(message, dict) and "role" in message and "content" in message
class Chat:
    """This class is intended to just be used internally for pipelines and not exposed to users. We convert chats
    to this format because the rest of the pipeline code tends to assume that lists of messages are
    actually a batch of samples rather than messages in the same conversation.

    Args:
        messages: The messages of a single conversation; each must be a dict with "role" and "content" keys.

    Raises:
        ValueError: If any message is not such a dict.
    """

    def __init__(self, messages: list[dict]):
        # Validate every message up front so malformed chats fail at construction time.
        for message in messages:
            if not is_valid_message(message):
                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
        # The caller's object is stored as-is (not copied).
        self.messages = messages
|