| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859 |
- # mypy: allow-untyped-defs
- """
- Utils shared by different modes of quantization (eager/graph)
- """
- import functools
- import warnings
- from collections import OrderedDict
- from collections.abc import Callable
- from inspect import getfullargspec, signature
- from typing import Any
- from typing_extensions import TypeAliasType
- import torch
- from torch.ao.quantization.quant_type import QuantType
- from torch.fx import Node
- from torch.nn.utils.parametrize import is_parametrized
# Alias for what a matched pattern may hold: a pair of nodes, a node paired
# with a nested pair of nodes, or anything else (hence the trailing ``Any``).
NodePattern = TypeAliasType(
    "NodePattern",
    tuple[Node, Node] | tuple[Node, tuple[Node, Node]] | Any,
)

# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
# It is defined separately here to prevent circular imports.
# TODO(future PR): improve this.
# Make this public once fixed (it can't be public as is, because setting the
# module directly doesn't work).
QuantizerCls = Any

# Type for fusion patterns. Real patterns can be more complicated than what is
# spelled out here; see pattern.md for docs.
# TODO: not sure if typing supports recursive data types
Pattern = TypeAliasType(
    "Pattern",
    Callable
    | tuple[Callable, Callable]
    | tuple[Callable, tuple[Callable, Callable]]
    | Any,
)
# TODO: maybe rename this to MatchInputNode
class MatchAllNode:
    """Wildcard node pattern: it matches all nodes.

    Used when defining fusion patterns in FX Graph Mode Quantization.
    """
# Module classes that `check_node` accepts for `call_module` nodes.
module_type_list = {
    # activations
    torch.nn.ReLU,
    torch.nn.ReLU6,
    torch.nn.Hardsigmoid,
    torch.nn.Sigmoid,
    torch.nn.Tanh,
    # pooling
    torch.nn.AdaptiveAvgPool1d,
    torch.nn.AdaptiveAvgPool2d,
    torch.nn.AdaptiveAvgPool3d,
    torch.nn.AvgPool1d,
    torch.nn.AvgPool2d,
    torch.nn.AvgPool3d,
    torch.nn.MaxPool1d,
    torch.nn.MaxPool2d,
    torch.nn.MaxPool3d,
    # pass-through
    torch.nn.Identity,
}

# Callables that `check_node` accepts as targets of `call_function` nodes.
func_list = {
    # torch.nn.functional: pooling
    torch.nn.functional.adaptive_avg_pool1d,
    torch.nn.functional.adaptive_avg_pool2d,
    torch.nn.functional.adaptive_avg_pool3d,
    torch.nn.functional.max_pool1d,
    torch.nn.functional.max_pool2d,
    torch.nn.functional.max_pool3d,
    # torch.nn.functional: activations
    torch.nn.functional.elu,
    torch.nn.functional.hardswish,
    torch.nn.functional.leaky_relu,
    torch.nn.functional.silu,
    torch.nn.functional.mish,
    torch.nn.functional.relu,
    torch.nn.functional.hardtanh,
    torch.nn.functional.hardtanh_,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.sigmoid,
    # torch.nn.functional: normalization and dropout
    torch.nn.functional.instance_norm,
    torch.nn.functional.layer_norm,
    torch.nn.functional.dropout,
    # top-level torch functions
    torch.transpose,
    torch.repeat_interleave,
    torch.sigmoid,
    torch.squeeze,
    torch.stack,
    torch.sum,
    torch.tanh,
    torch.unsqueeze,
    torch.cat,
}

# Targets that `check_node` accepts for `call_method` nodes. All entries are
# method names except `torch.mean`, which is stored as the function object.
method_list = {
    torch.mean,
    # activations
    "relu",
    "relu_",
    "hardsigmoid",
    "hardsigmoid_",
    "sigmoid",
    "sigmoid_",
    "tanh",
    "tanh_",
    # shape / layout manipulation and queries
    "contiguous",
    "detach",
    "detach_",
    "permute",
    "repeat",
    "repeat_interleave",
    "reshape",
    "resize_",
    "shape",
    "size",
    "squeeze",
    "squeeze_",
    "transpose",
    "unsqueeze",
    "unsqueeze_",
    "view",
}
# TODO: not used now, remove
def check_node(node, modules):
    """Classify ``node`` against the module-level op tables.

    Args:
        node: object exposing ``op`` and ``target`` attributes
        modules: mapping indexed by ``str(node.target)``; only consulted
            when ``node.op`` is ``"call_module"``

    Returns:
        tuple ``(is_call_function, is_call_method, is_call_module)`` of bools;
        each flag is True when the node has that op kind and its target (or,
        for modules, the type of the looked-up module) is in ``func_list``,
        ``method_list`` or ``module_type_list`` respectively
    """
    # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
    op_kind = node.op
    is_call_function = op_kind == "call_function" and node.target in func_list
    is_call_method = op_kind == "call_method" and node.target in method_list
    is_call_module = False
    if op_kind == "call_module":
        # look the module up only for call_module nodes
        is_call_module = type(modules[str(node.target)]) in module_type_list
    return is_call_function, is_call_method, is_call_module
def get_combined_dict(default_dict, additional_dict):
    """
    Merge two dictionaries into a new one.

    A shallow copy of `default_dict` is taken and then updated with
    `additional_dict`, so on duplicate keys the value from
    `additional_dict` wins. Neither input is modified.

    Args:
        default_dict (dict): The base dictionary
        additional_dict (dict): Entries to add to / override in the base

    Returns:
        dict: The merged dictionary

    Example:
        >>> x = dict(a=1, b=1)
        >>> y = dict(b=2, c=3)
        >>> get_combined_dict(x, y)
        {'a': 1, 'b': 2, 'c': 3}
    """
    combined = default_dict.copy()
    combined.update(additional_dict)
    return combined
def is_per_tensor(qscheme):
    """Return True if ``qscheme`` is per-tensor affine or per-tensor symmetric."""
    per_tensor_schemes = (torch.per_tensor_affine, torch.per_tensor_symmetric)
    return any(qscheme == scheme for scheme in per_tensor_schemes)
def is_per_channel(qscheme):
    """Return True if ``qscheme`` is one of the per-channel schemes
    (affine, affine with float qparams, or symmetric).
    """
    per_channel_schemes = (
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
        torch.per_channel_symmetric,
    )
    return any(qscheme == scheme for scheme in per_channel_schemes)
def getattr_from_fqn(obj: Any, fqn: str) -> Any:
    """
    Resolve a dotted attribute path starting at ``obj``.

    For an fqn such as "foo.bar.baz" this returns ``obj.foo.bar.baz``.
    """
    current = obj
    for attr_name in fqn.split("."):
        current = getattr(current, attr_name)
    return current
- def to_underlying_dtype(qdtype):
- DTYPE_MAPPING = {
- torch.quint8: torch.uint8,
- torch.qint8: torch.int8,
- torch.qint32: torch.int32,
- torch.quint4x2: torch.uint8,
- torch.quint2x4: torch.uint8,
- torch.uint8: torch.uint8,
- torch.int8: torch.int8,
- torch.uint16: torch.uint16,
- torch.int16: torch.int16,
- torch.int32: torch.int32,
- torch.float8_e5m2: torch.float8_e5m2,
- torch.float8_e4m3fn: torch.float8_e4m3fn,
- }
- if qdtype not in DTYPE_MAPPING:
- raise AssertionError("Unsupported dtype: " + str(qdtype))
- return DTYPE_MAPPING[qdtype]
def get_qparam_dict(observer_or_fake_quant):
    """Collect the quantization parameters of an observer / fake-quant module.

    Returns a dict that always has "qscheme" and "dtype". When the module has
    no (truthy) ``qscheme`` or is a ``PlaceholderObserver``, "qscheme" is None
    and nothing else is added. Otherwise the dict also holds "scale" and
    "zero_point" from ``calculate_qparams()``, "axis" for per-channel schemes,
    and "quant_min" / "quant_max" when the module has those attributes.

    Raises:
        RuntimeError: if the qscheme is neither per-tensor nor per-channel
    """
    from torch.ao.quantization.observer import PlaceholderObserver

    obs = observer_or_fake_quant
    qscheme = getattr(obs, "qscheme", None)
    dtype = obs.dtype
    if not qscheme or isinstance(obs, PlaceholderObserver):
        return {"qscheme": None, "dtype": dtype}

    qparams = {"qscheme": qscheme, "dtype": dtype}
    # Symmetric schemes are reported as their affine counterpart since there
    # is no symmetric quantized Tensor.
    if is_per_tensor(qscheme):
        qscheme = torch.per_tensor_affine
    elif is_per_channel(qscheme):
        if qscheme == torch.per_channel_symmetric:
            qscheme = torch.per_channel_affine
        qparams["axis"] = obs.ch_axis
    else:
        raise RuntimeError(f"Unrecognized qscheme: {qscheme}")
    qparams["qscheme"] = qscheme

    qparams["scale"], qparams["zero_point"] = obs.calculate_qparams()

    for bound_name in ("quant_min", "quant_max"):
        if hasattr(obs, bound_name):
            qparams[bound_name] = getattr(obs, bound_name)
    return qparams
def get_swapped_custom_module_class(
    custom_module, custom_module_class_mapping, qconfig
):
    """Look up the observed/quantized custom module class that
    `custom_module` should be swapped to

    Input:
        custom_module: input, can be an instance of either a float or observed custom module
        custom_module_class_mapping: the float to observed or observed to quantized custom module
            class mapping, keyed first by quant type and then by module class
        qconfig: qconfig configured for the custom module

    Output:
        corresponding observed/quantized custom module class for input custom module instance

    Raises:
        AssertionError: if the type of `custom_module` has no entry for the
            quant type derived from `qconfig`
    """
    class_mapping = custom_module_class_mapping.get(get_quant_type(qconfig), {})
    module_cls = type(custom_module)
    if module_cls in class_mapping:
        return class_mapping[module_cls]
    raise AssertionError(
        "did not find corresponding observed "
        f"module class for {module_cls} in mapping: {class_mapping}"
    )
def activation_dtype(qconfig):
    """Return the ``dtype`` of the object created by ``qconfig.activation()``.

    Raises:
        AssertionError: if ``qconfig`` is None
    """
    if qconfig is not None:
        return qconfig.activation().dtype
    raise AssertionError("qconfig must be provided to determine activation dtype")
def weight_dtype(qconfig):
    """Return the ``dtype`` of the object created by ``qconfig.weight()``.

    Raises:
        AssertionError: if ``qconfig`` is None
    """
    if qconfig is not None:
        return qconfig.weight().dtype
    raise AssertionError("qconfig must be provided to determine weight dtype")
def activation_is_statically_quantized(qconfig):
    """Given a qconfig, decide if the activation needs to be statically
    quantized or not: its dtype must be one of the supported quantized,
    integer, float16 or float8 dtypes, and the activation must not be
    dynamically quantized
    """
    supported_dtypes = [
        torch.quint8,
        torch.qint8,
        torch.qint32,
        torch.float16,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    ]
    if activation_dtype(qconfig) not in supported_dtypes:
        return False
    return not activation_is_dynamically_quantized(qconfig)
def activation_is_dynamically_quantized(qconfig):
    """Given a qconfig, decide if the activation needs to be
    dynamically quantized or not; this is the `activation_is_dynamic`
    entry reported by `get_qconfig_dtypes`
    """
    return get_qconfig_dtypes(qconfig)[2]
def activation_is_int8_quantized(qconfig):
    """Given a qconfig, decide if the activation needs to be
    quantized to int8 or not, this includes quantizing to quint8, qint8,
    uint8 and int8
    """
    int8_dtypes = [torch.quint8, torch.qint8, torch.uint8, torch.int8]
    return activation_dtype(qconfig) in int8_dtypes
def activation_is_int32_quantized(qconfig):
    """Given a qconfig, decide if the activation needs to be
    quantized to int32 or not (dtype qint32 or int32)
    """
    int32_dtypes = [torch.qint32, torch.int32]
    return activation_dtype(qconfig) in int32_dtypes
def weight_is_quantized(qconfig):
    """Given a qconfig, decide if the weight needs to be
    quantized or not, based on whether its dtype is one of the supported
    quantized, integer, float16 or float8 dtypes
    """
    quantized_weight_dtypes = [
        torch.quint8,
        torch.qint8,
        torch.float16,
        torch.quint4x2,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    ]
    return weight_dtype(qconfig) in quantized_weight_dtypes
def weight_is_statically_quantized(qconfig):
    """Given a qconfig, decide if the weight needs to be statically
    quantized or not (dtype quint8, qint8, uint8 or int8)
    """
    static_weight_dtypes = [torch.quint8, torch.qint8, torch.uint8, torch.int8]
    return weight_dtype(qconfig) in static_weight_dtypes
def op_is_int8_dynamically_quantized(qconfig) -> bool:
    """Given a qconfig, returns True if this op is using int8 dynamic
    quantization: unsigned 8-bit activation, signed 8-bit weight, and a
    dynamic activation
    """
    act_dtype, wt_dtype, act_is_dynamic = get_qconfig_dtypes(qconfig)
    if act_dtype not in [torch.quint8, torch.uint8]:
        return False
    # for now, the check below assumes fbgemm or qnnpack
    if wt_dtype not in [torch.qint8, torch.int8]:
        return False
    return act_is_dynamic
def get_qconfig_dtypes(qconfig):
    r"""returns the qconfig tuple for qconfig:
    (activation_dtype, weight_dtype, activation_is_dynamic)

    `activation_is_dynamic` is the `is_dynamic` attribute of the object
    created by `qconfig.activation()`, or False when it has none.

    Raises:
        AssertionError: if `qconfig` is None
    """
    if qconfig is None:
        raise AssertionError("qconfig must be provided to extract dtypes")
    act = qconfig.activation()
    wt = qconfig.weight()
    return (act.dtype, wt.dtype, getattr(act, "is_dynamic", False))
- def get_quant_type(qconfig):
- if qconfig is None:
- raise AssertionError("qconfig must be provided to determine quant type")
- activation = qconfig.activation()
- weight = qconfig.weight()
- static_dtypes = [
- torch.quint8,
- torch.qint8,
- torch.quint4x2,
- torch.qint32,
- torch.uint8,
- torch.int8,
- torch.int16,
- torch.int32,
- torch.float8_e5m2,
- torch.float8_e4m3fn,
- ]
- if weight.dtype in static_dtypes:
- if hasattr(activation, "is_dynamic") and activation.is_dynamic:
- return QuantType.DYNAMIC
- elif activation.dtype in static_dtypes:
- return QuantType.STATIC
- else:
- return QuantType.WEIGHT_ONLY
- if weight.dtype == torch.float16:
- if hasattr(activation, "is_dynamic") and activation.is_dynamic:
- return QuantType.DYNAMIC
- elif activation.dtype == torch.float16:
- return QuantType.STATIC
- raise Exception( # noqa: TRY002
- f"Unrecognized dtype combination in get_quant_type: activation({activation.dtype}),"
- f"weight({weight.dtype})"
- )
def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
    """Checks whether the observed minimum and maximum values are usable.

    Returns False (with a warning) when no statistics were recorded yet:
    either tensor is empty, or a scalar pair still holds (+inf, -inf).
    Returns True when min <= max holds (element-wise for non-scalar inputs).

    Raises:
        AssertionError: if any min value is greater than its max value
    """
    if min_val.numel() == 0 or max_val.numel() == 0:
        warnings.warn(
            "must run observer before calling calculate_qparams. "
            + "Returning default values.",
            stacklevel=2,
        )
        return False

    has_scalar = min_val.dim() == 0 or max_val.dim() == 0
    if not has_scalar:
        if torch.any(min_val > max_val):
            raise AssertionError(f"min {min_val} should be less than max {max_val}")
        return True

    if min_val == float("inf") and max_val == float("-inf"):
        warnings.warn(
            "must run observer before calling calculate_qparams. "
            + "Returning default values.",
            stacklevel=2,
        )
        return False
    if min_val > max_val:
        raise AssertionError(f"min {min_val} should be less than max {max_val}")
    return True
- def calculate_qmin_qmax(
- quant_min: int,
- quant_max: int,
- has_customized_qrange: bool,
- dtype: torch.dtype,
- reduce_range: bool,
- ) -> tuple[int, int]:
- r"""Calculates actual qmin and qmax based on the quantization range,
- observer datatype and if range is reduced.
- """
- # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted.
- if has_customized_qrange:
- # This initialization here is to be resolve TorchScript compilation issues and allow
- # using of refinement to decouple initial_qmin and initial_qmax from quantization range.
- # The actual values of initial_qmin and initial_qmax will be reset below.
- if dtype in [torch.qint32, torch.int32]:
- initial_quant_min, initial_quant_max = 0, 2**32 - 1
- else:
- initial_quant_min, initial_quant_max = 0, 255
- # The following assignment of self.qmin and self.qmax to the local variables and the if check refine the
- # attribute from Optional valid integers for use, based on TorchScript's requirements.
- custom_quant_min, custom_quant_max = quant_min, quant_max
- if custom_quant_min is not None and custom_quant_max is not None:
- initial_quant_min, initial_quant_max = (
- custom_quant_min,
- custom_quant_max,
- )
- qrange_len = initial_quant_max - initial_quant_min + 1
- if dtype in [torch.qint8, torch.int8]:
- if not (0 < qrange_len <= 256):
- raise AssertionError(
- "quantization range should be positive and not exceed the maximum bit range (=256)."
- )
- elif dtype in [torch.qint32, torch.int32]:
- if not (0 < qrange_len <= 2**32):
- raise AssertionError(
- "quantization range should be positive and not exceed the maximum bit range (=4294967296)."
- )
- if reduce_range:
- quant_min, quant_max = quant_min // 2, quant_max // 2
- else:
- # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used.
- if dtype in [torch.qint8, torch.int8]:
- if reduce_range:
- quant_min, quant_max = -64, 63
- else:
- quant_min, quant_max = -128, 127
- elif dtype in [torch.quint8, torch.uint8]:
- if reduce_range:
- quant_min, quant_max = 0, 127
- else:
- quant_min, quant_max = 0, 255
- elif dtype in [torch.qint32, torch.int32]:
- quant_min, quant_max = -1 * (2**31), (2**31) - 1
- elif dtype == torch.uint16:
- quant_min, quant_max = 0, 2**16 - 1
- elif dtype == torch.int16:
- quant_min, quant_max = -(2**15), 2**15 - 1
- else:
- quant_min, quant_max = 0, 15
- return quant_min, quant_max
- def _parent_name(target):
- """
- Turn 'foo.bar' into ['foo', 'bar']
- """
- r = target.rsplit(".", 1)
- if len(r) == 1:
- return "", r[0]
- else:
- return r[0], r[1]
def has_no_children_ignoring_parametrizations(module):
    """
    Returns True if module._modules is empty, or if the module is
    parametrized and its only entry in module._modules is the
    'parametrizations' module. Returns False otherwise.
    """
    children = module._modules
    if len(children) == 0:
        return True
    if not is_parametrized(module):
        return False
    return len(children) == 1 and "parametrizations" in children
- def _get_path_of_module(
- root: torch.nn.Module, submodule: torch.nn.Module
- ) -> str | None:
- """Get the path (fully qualified name) of a submodule
- Example::
- >> class M(torch.nn.Module):
- def __init__(self) -> None:
- self.linear = torch.nn.Linear(5, 5)
- def forward(self, x):
- return self.linear(x)
- >> m = M()
- >> l = m.linear
- >> _get_path_of_module(m, l)
- "linear"
- """
- for n, p in root.named_modules():
- if submodule is p:
- return n
- return None
- def _get_signature_locals(f: Callable, loc: dict[str, Any]) -> dict[str, Any]:
- """Get local keyword arguments
- Example::
- >> def f(self, a, b=9):
- pass
- >> loc = {"a": 6, "c": 7}
- >> _get_signature_locals(f, loc)
- {"a": 6}
- """
- return {k: v for k, v in loc.items() if k in signature(f).parameters}
- def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]":
- """Get all default keyword arguments from function signature
- Example::
- >> def f(self, a, b=9):
- pass
- >> _get_default_kwargs(f)
- {"b": 9}
- """
- kwargs = {}
- for name, param in signature(f).parameters.items():
- if param.default is not param.empty:
- kwargs[name] = param.default
- elif param.kind is param.VAR_POSITIONAL:
- kwargs[name] = ()
- elif param.kind is param.VAR_KEYWORD:
- kwargs[name] = {}
- return OrderedDict(kwargs)
def _normalize_kwargs(func: Callable, loc: dict[str, Any]) -> "OrderedDict[str, Any]":
    """Given a function and local function arguments, normalize the keyword
    arguments by filling in default arguments from function signature

    Only parameters that have a default entry can be overridden by ``loc``;
    other entries of ``loc`` are ignored.

    Example::

    >> def f(self, key1=3, key2=3):
           pass
    >> loc = {"key2": 6}
    >> _normalize_kwargs(f, loc)
    {"key1": 3, "key2": 6}
    """
    normalized = _get_default_kwargs(func).copy()
    overrides = _get_signature_locals(func, loc)
    # override the default keyword arguments with the values given locally
    for name in overrides:
        if name in normalized:
            normalized[name] = overrides[name]
    return normalized
def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
    r"""Validates that the user-specified quantization range is properly initialized
    and within the given bound supported by the observer dtype.

    To accommodate lower-bit quantization with respect to the existing torch.qint8 and
    torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
    in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax
    values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
    fake quantization. These estimates are compared against parameters learned through backpropagation.
    The related literatures for scale and zero point via backpropagation are as follows:

    Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
    Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf

    Raises:
        AssertionError: if the range does not contain 0, or if ``quant_min``
            is not strictly smaller than ``quant_max``
    """
    # The range has to contain zero.
    contains_zero = quant_min <= 0 <= quant_max
    if not contains_zero:
        raise AssertionError("Used-specified quantization range must include 0.")

    # The range has to be non-degenerate.
    if quant_min >= quant_max:
        raise AssertionError(
            "qmin must be strictly less than qmax for user-specified quantization range."
        )
# Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable, however, and qscheme
# as far as I can tell is not allowed to be passed as a parameter in torchscript functions. This makes refactoring observer
# to use this utility a massive pain and very gross. For now I'm opting just to duplicate, as this code seems unlikely to change
# (last update over 1 year ago), and when torchscript is fully deprecated we can refactor. TODO(jakeszwe, jerryzh168)
def determine_qparams(
    min_val: torch.Tensor,
    max_val: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    eps: torch.Tensor,
    has_customized_qrange: bool,
    qscheme: torch.qscheme = torch.per_tensor_affine,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Calculates the quantization parameters, given min and max
    value tensors. Works for both per tensor and per channel cases

    Args:
        min_val: Minimum values per channel
        max_val: Maximum values per channel
        quant_min: Lower bound of the quantized range
        quant_max: Upper bound of the quantized range
        dtype: Quantized dtype; only consulted for symmetric schemes, where
            unsigned 8-bit dtypes get a mid-range zero point
        eps: Lower bound applied to the computed scale
        has_customized_qrange: Whether ``quant_min``/``quant_max`` are
            user-specified; picks the zero point for symmetric unsigned 8-bit
        qscheme: Quantization scheme to compute the parameters for

    Returns:
        scales: Scales tensor of shape (#channels,)
        zero_points: Zero points tensor of shape (#channels,)

        When ``check_min_max_valid`` rejects the inputs, the defaults
        ``[1.0]`` and ``[0]`` are returned on the device of ``min_val``.
    """
    if not check_min_max_valid(min_val, max_val):
        # Use the full device (type *and* index) so the defaults live on the
        # same device as the inputs, like the tensors computed below.
        return torch.tensor([1.0], device=min_val.device), torch.tensor(
            [0], device=min_val.device
        )

    # Extend the observed range so that it always contains zero.
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

    device = min_val_neg.device
    scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
    eps = eps.to(device)

    if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric:
        max_val_pos = torch.max(-min_val_neg, max_val_pos)
        scale = max_val_pos / (float(quant_max - quant_min) / 2)
        scale = torch.max(scale, eps)
        if dtype in [torch.uint8, torch.quint8]:
            if has_customized_qrange:
                # When customized quantization range is used, down-rounded midpoint of the range is chosen.
                zero_point = zero_point.new_full(
                    zero_point.size(), (quant_min + quant_max) // 2
                )
            else:
                zero_point = zero_point.new_full(zero_point.size(), 128)
    elif qscheme == torch.per_channel_affine_float_qparams:
        scale = (max_val - min_val) / float(quant_max - quant_min)
        scale = torch.where(scale > eps, scale, torch.ones_like(scale))
        # We use the quantize function
        # xq = Round(Xf * inv_scale + zero_point),
        # setting zero_point to (-1 * min *inv_scale) we get
        # Xq = Round((Xf - min) * inv_scale)
        zero_point = -1 * min_val / scale
    else:
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
        scale = torch.max(scale, eps)
        zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
        zero_point = torch.clamp(zero_point, quant_min, quant_max)

    # For scalar values, cast them to Tensors of size 1 to keep the shape
    # consistent with default values in FakeQuantize.
    if len(scale.shape) == 0:
        # TODO: switch to scale.item() after adding JIT support
        scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
    if len(zero_point.shape) == 0:
        # TODO: switch to zero_point.item() after adding JIT support
        zero_point = torch.tensor(
            [int(zero_point)], dtype=zero_point.dtype, device=device
        )
        if qscheme == torch.per_channel_affine_float_qparams:
            zero_point = torch.tensor(
                [float(zero_point)], dtype=zero_point.dtype, device=device
            )

    return scale.to(torch.double), zero_point.to(torch.int64)
- def _get_num_pos_args(f: Callable) -> int:
- """Get number of positional args for a function
- Example::
- >> def f(self, key1=3, key2=3):
- pass
- >> _get_num_pos_args(f)
- 3
- """
- return len(getfullargspec(f).args)
def get_fqn_to_example_inputs(
    model: torch.nn.Module, example_inputs: tuple[Any, ...]
) -> dict[str, tuple[Any, ...]]:
    """Given a model and its example inputs, return a dictionary from
    fully qualified name of submodules to example_inputs for that submodule,
    e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,),
          "sub.linear1": (tensor4,), ...}

    Used to make quantizing submodules easier now that FX Graph Mode Quantization requires
    example inputs.

    Also works for keyword arguments with default values, we would flatten keyword
    arguments as positional arguments and fill in the missing keyword args with default
    values, e.g. if we have a forward function:
    def forward(self, x, key1=3, key2=3):
        ...

    and we call it with self.submodule(x, key2=6)
    we'll get example_inputs: (x, 3, 6)

    user can also override `key1` with positional arguments as well:
    for self.submodule(x, 5, key2=6)
    we'll get: (x, 5, 6)

    variable positional arguments and variable positional keyword arguments in forward
    function are not supported currently, so please make sure no submodules is using
    them.

    The model is run once on `example_inputs` with `torch.nn.Module.__call__`
    temporarily replaced by a recording wrapper; the original `__call__` is
    restored afterwards, even if the forward pass raises.
    """
    root = model
    fqn_to_example_inputs = {}

    def _patched_module_call(self, *args, **kwargs):
        submodule_example_inputs = list(args)
        normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
        # Names of the parameters that positional arguments bind to, in
        # declaration order (`signature` already leaves out the bound `self`).
        positional_names = [
            name
            for name, param in signature(self.forward).parameters.items()
            if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD)
        ]
        # A default whose parameter was already supplied positionally must
        # not be appended a second time.
        for name in positional_names[: len(args)]:
            normalized_kwargs.pop(name, None)
        submodule_example_inputs.extend(normalized_kwargs.values())
        submodule_example_inputs_tuple = tuple(submodule_example_inputs)
        fqn = _get_path_of_module(root, self)
        if fqn is not None:
            fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple
        return orig_module_call(self, *args, **kwargs)

    orig_module_call = torch.nn.Module.__call__
    torch.nn.Module.__call__ = _patched_module_call  # type: ignore[method-assign]
    try:
        model(*example_inputs)
    finally:
        # restore the module call even if there is an exception
        torch.nn.Module.__call__ = orig_module_call  # type: ignore[method-assign]
    return fqn_to_example_inputs
- def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
- """
- Returns the unique device for a module, or None if no device is found.
- Throws an error if multiple devices are detected.
- """
- devices = {p.device for p in module.parameters()} | {
- p.device for p in module.buffers()
- }
- """
- As a temp workaround for AIMP HHC publish we added CPU check.remove it later. T163614564
- """
- if {torch.device("cpu"), torch.device("meta")} == devices:
- warnings.warn(
- "Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'.",
- stacklevel=2,
- )
- devices = {torch.device("cpu")}
- ""
- if len(devices) > 1:
- raise AssertionError(
- "prepare only works with cpu or single-device CUDA modules, "
- f"but got devices {devices}"
- )
- device = next(iter(devices)) if len(devices) > 0 else None
- return device
# Migration notice for torch.ao.quantization users. The text is built from
# adjacent string literals and must stay exactly as is.
DEPRECATION_WARNING = (
    "torch.ao.quantization is deprecated and will be removed in 2.10. \n"
    "For migrations of users: \n"
    # eager mode
    "1. Eager mode quantization (torch.ao.quantization.quantize, "
    "torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode "
    "quantize_ API instead \n"
    # FX graph mode
    "2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,"
    "torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization "
    "API instead (prepare_pt2e, convert_pt2e) \n"
    # pt2e
    "3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) \n"
    "see https://github.com/pytorch/ao/issues/2259 for more details"
)
# Public names of this module.
__all__ = [
    # pattern types
    "NodePattern",
    "Pattern",
    "MatchAllNode",
    # generic helpers
    "check_node",
    "get_combined_dict",
    "is_per_tensor",
    "is_per_channel",
    "getattr_from_fqn",
    "get_qparam_dict",
    "get_swapped_custom_module_class",
    # qconfig queries
    "activation_dtype",
    "weight_dtype",
    "activation_is_statically_quantized",
    "activation_is_dynamically_quantized",
    "activation_is_int8_quantized",
    "activation_is_int32_quantized",
    "weight_is_quantized",
    "weight_is_statically_quantized",
    "op_is_int8_dynamically_quantized",
    "get_qconfig_dtypes",
    "get_quant_type",
    # quantization range / parameter helpers and module utilities
    "check_min_max_valid",
    "calculate_qmin_qmax",
    "has_no_children_ignoring_parametrizations",
    "get_fqn_to_example_inputs",
    "to_underlying_dtype",
    "determine_qparams",
    "validate_qmin_qmax",
    "DEPRECATION_WARNING",
]
|