| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197 |
- # mypy: allow-untyped-defs
- import copy
- import warnings
- from typing import Any
- import torch
- from torch._subclasses import FakeTensor
- from torch.ao.quantization import (
- ObserverBase,
- ObserverOrFakeQuantize,
- PlaceholderObserver,
- )
- from torch.ao.quantization.backend_config import (
- BackendConfig,
- DTypeConfig,
- get_native_backend_config,
- )
- from torch.ao.quantization.backend_config.utils import (
- get_fusion_pattern_to_root_node_getter,
- get_module_to_qat_module,
- get_pattern_to_dtype_configs,
- )
- from torch.ao.quantization.observer import _is_activation_post_process
- from torch.ao.quantization.qconfig import _is_reuse_input_qconfig, QConfigAny
- from torch.ao.quantization.qconfig_mapping import QConfigMapping
- from torch.ao.quantization.quantize import convert, propagate_qconfig_
- from torch.ao.quantization.utils import (
- _parent_name,
- get_qconfig_dtypes,
- get_swapped_custom_module_class,
- NodePattern,
- Pattern,
- )
- from torch.fx import GraphModule
- from torch.fx.graph import Graph, Node
- from torch.fx.node import Argument
- from ._equalize import is_equalization_observer, node_supports_equalization
- from .custom_config import PrepareCustomConfig, StandaloneModuleConfigEntry
- from .match_utils import _find_matches, _MatchResultWithQConfig
- from .pattern_utils import _sorted_patterns_dict
- from .qconfig_mapping_utils import (
- _generate_node_name_to_qconfig,
- _get_flattened_qconfig_dict,
- _update_qconfig_for_fusion,
- _update_qconfig_for_qat,
- )
- from .quantize_handler import (
- _default_root_node_getter,
- _get_pattern_to_quantize_handlers,
- QuantizeHandler,
- )
- from .utils import (
- _insert_dequant_stubs_for_custom_module_lstm_output,
- _is_custom_module_lstm,
- _maybe_get_custom_module_lstm_from_node_arg,
- _qconfig_satisfies_dtype_config_constraints,
- all_node_args_have_no_tensors,
- assert_and_get_unique_device,
- get_custom_module_class_keys,
- get_new_attr_name_with_prefix,
- get_non_observable_arg_indexes_and_types,
- node_arg_is_bias,
- node_arg_is_weight,
- NON_QUANTIZABLE_WEIGHT_OPS,
- ObservedGraphModuleAttrs,
- )
# Names exported by ``from <this module> import *``.
__all__ = [
    "insert_observers_for_model",
    "prepare",
    "propagate_dtypes_for_known_nodes",
]
- # list of dtypes to not add observers to
- _DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]
- _OBS_DTYPE_LIST = [
- torch.quint8,
- torch.qint8,
- torch.qint32,
- torch.float16,
- torch.uint8,
- torch.int8,
- torch.int16,
- torch.int32,
- torch.float8_e5m2,
- torch.float8_e4m3fn,
- ]
- _DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)
- # note: the following default target dtype info dicts are temporary,
- # should be moved to the new programmable API class soon
- _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = {
- "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
- "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
- }
- _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = {
- "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
- "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
- }
def _needs_obs_or_fq(
    prev_output_dtype: Any,
    prev_output_is_dynamic: bool,
    cur_target_dtype: Any,
    cur_target_is_dynamic: bool,
    reuse_input_obs_or_fq: bool,
    is_zeroth_arg: bool = False,
) -> bool:
    """Decide whether an observer / fake-quant node should be inserted on an edge.

    The decision is based on the dtype produced by the previous node and the
    dtype requested by the user for the current node's input.
    Note: we treat "not specified" as torch.float for now.

    Args:
        prev_output_dtype: dtype produced by the upstream node.
        prev_output_is_dynamic: not used by this function; kept for
            interface compatibility.
        cur_target_dtype: dtype requested for this input of the current node.
        cur_target_is_dynamic: whether this input is dynamically quantized.
        reuse_input_obs_or_fq: when True (and not dynamic), the current node
            reuses its input's observer, so no new one is needed.
        is_zeroth_arg: we only dynamically quantize the first arg of the node
            right now. This should be removed when we enable configuring
            dynamic quantization for a specific argument, or if fx graph mode
            quantization is deprecated.

    Returns:
        True if an observer / fake-quant node should be inserted.

    Raises:
        AssertionError: in the dynamic case, if ``cur_target_dtype`` is not in
            ``_OBS_DTYPE_LIST`` or ``prev_output_dtype`` is in
            ``_DO_NOT_OBS_DTYPE_LIST``.
    """
    # need to insert placeholder observer for dynamic quantization so that it can
    # be converted to choose_qparams -> q -> dq in convert step
    if cur_target_is_dynamic:
        if cur_target_dtype not in _OBS_DTYPE_LIST:
            # The old message claimed torch.float was expected, which contradicts
            # this check (torch.float itself would fail it).
            raise AssertionError(
                f"Expected cur_target_dtype to be one of {_OBS_DTYPE_LIST} for "
                f"dynamic quantization, but got: {cur_target_dtype}"
            )
        if prev_output_dtype in _DO_NOT_OBS_DTYPE_LIST:
            raise AssertionError(
                "prev_output_dtype must not be in _DO_NOT_OBS_DTYPE_LIST, "
                f"but got: {prev_output_dtype}"
            )
        return is_zeroth_arg

    if reuse_input_obs_or_fq:
        return False

    # non dynamic quantization: observe only when moving from a float or
    # observable dtype to a *different* observable dtype.
    if cur_target_dtype in _OBS_DTYPE_LIST:
        # Equivalent to `in _OBS_DTYPE_LIST + [torch.float]` without building
        # a new list on every call.
        prev_is_observable = (
            prev_output_dtype in _OBS_DTYPE_LIST or prev_output_dtype == torch.float
        )
        return prev_is_observable and cur_target_dtype != prev_output_dtype

    # lots of error checking are skipped here for now
    return False
def _is_activation_post_process_node(
    node: Node, named_modules: dict[str, torch.nn.Module]
) -> bool:
    """Return True if ``node`` is a ``call_module`` node whose target module
    (looked up in ``named_modules``) is an observer / fake-quant module."""
    if not isinstance(node, torch.fx.Node):
        return False
    if node.op != "call_module":
        return False
    target_module = named_modules[str(node.target)]
    return _is_activation_post_process(target_module)
- def _get_dtype_and_is_dynamic(
- obs_or_fq: ObserverOrFakeQuantize | None,
- ) -> tuple[torch.dtype | None, bool]:
- """Given a constructor for observer or fake quant module, returns
- a Tuple of dtype and is_dynamic
- """
- # TODO: instead of instantiating the instance, we can use inspect to get the default args
- if obs_or_fq is None:
- return None, False
- else:
- return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False) # type: ignore[return-value]
def _is_input_arg_dtype_supported_by_backend(
    arg: Argument,
    node: Node,
    qconfig: QConfigAny,
    dtype_config: DTypeConfig,
    backend_config: BackendConfig,
) -> bool:
    """Return whether ``dtype_config`` accepts ``arg`` as an input of ``node``.

    A list/tuple ``arg`` is supported only if every element is. Non-Node
    arguments (literals) are always supported. For Node arguments, the
    configured observer constructor is read from
    ``node.meta["target_dtype_info"]`` according to whether ``arg`` is the
    weight, the bias, or an activation input of ``node``, and its dtype
    (and, for activations, is_dynamic) is compared against the backend's.
    """
    if isinstance(arg, (list, tuple)):  # noqa: UP038
        for element in arg:
            if not _is_input_arg_dtype_supported_by_backend(
                # pyrefly: ignore [bad-argument-type]
                element,
                node,
                qconfig,
                dtype_config,
                backend_config,
            ):
                return False
        return True

    if not isinstance(arg, Node):
        return True

    # TODO: support check for standalone module
    arg_is_weight = node_arg_is_weight(node, arg)
    arg_is_bias = node_arg_is_bias(node, arg)
    dtype_info = node.meta["target_dtype_info"]

    if arg_is_weight:
        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
        weight_ctr = dtype_info.get("weight_obs_or_fq_ctr", None)
        weight_dtype, _ = _get_dtype_and_is_dynamic(
            weight_ctr() if weight_ctr else None
        )
        backend_weight_dtype = dtype_config.weight_dtype
        weight_dtype_matches = weight_dtype == backend_weight_dtype
        weight_constraints_ok = _qconfig_satisfies_dtype_config_constraints(
            qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False
        )
        if backend_weight_dtype is None:
            return True
        return weight_dtype_matches and weight_constraints_ok

    if arg_is_bias:
        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
        bias_ctr = dtype_info.get("bias_obs_or_fq_ctr", None)
        bias_dtype, _ = _get_dtype_and_is_dynamic(bias_ctr() if bias_ctr else None)
        backend_bias_dtype = dtype_config.bias_dtype
        if backend_bias_dtype is None:
            return True
        return bias_dtype == backend_bias_dtype

    # Neither weight nor bias: treat as an input activation.
    act_ctr = dtype_info.get("input_act_obs_or_fq_ctr")
    act_dtype, act_is_dynamic = _get_dtype_and_is_dynamic(
        act_ctr() if act_ctr else None
    )
    if dtype_config.input_dtype is None:
        return True
    # TODO(future PR): remove the cast to bool below after figuring
    # out why backend_config has is_dynamic set to None in some cases.
    return (
        dtype_config.input_dtype == act_dtype
        and bool(dtype_config.is_dynamic) == bool(act_is_dynamic)
        and _qconfig_satisfies_dtype_config_constraints(
            qconfig, dtype_config.input_dtype_with_constraints
        )
    )
def _is_output_dtype_supported_by_backend(
    node: Node,
    qconfig: QConfigAny,
    dtype_config: DTypeConfig,
) -> bool:
    """Return whether ``dtype_config`` accepts the output of ``node``.

    The configured output dtype comes from the ``output_act_obs_or_fq_ctr``
    entry of ``node.meta["target_dtype_info"]`` (the fp32 placeholder
    constructor when absent); a dynamic constructor is treated as producing
    ``torch.float32``. Supported when the backend leaves ``output_dtype``
    unspecified, or when the dtypes match and ``qconfig`` satisfies the
    backend's output dtype constraints.
    """
    # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
    backend_output_dtype = dtype_config.output_dtype
    # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend
    # from input activation check can be reused here
    output_ctr = node.meta["target_dtype_info"].get(
        "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
    )
    output_dtype, output_is_dynamic = _get_dtype_and_is_dynamic(
        output_ctr() if output_ctr else None
    )
    # TODO: this is a hack because we can only specify one activation_obs_or_fq for
    # qconfig (qconfig.activation), and we are only supporting dynamically quantized
    # linear op which has fp32 output dtype, this should be removed if we generalize
    # the structure of qconfig in the future
    if output_is_dynamic:
        output_dtype = torch.float32
    constraints_ok = _qconfig_satisfies_dtype_config_constraints(
        qconfig, dtype_config.output_dtype_with_constraints
    )
    if backend_output_dtype is None:
        return True
    return output_dtype == backend_output_dtype and constraints_ok
- from typing import Annotated
- from torch.fx import Node
- EdgeOrNode = Annotated[tuple[Node, Node] | Node, None]
def _is_observer_in_same_graph(
    node: Node,
    named_modules: dict[str, torch.nn.Module],
    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat,
):
    """Return whether the observer for ``node``'s input belongs to this graph.

    Returns False when ``node``'s output target dtype is ``torch.quint8`` or
    ``torch.uint8`` and its first argument is a ``placeholder`` node: such an
    input is assumed to be quantized already, so it is observed in a
    different place rather than not observed. Returns True otherwise.
    """
    output_dtype = _get_arg_target_dtype_as_output(
        node, named_modules, obs_or_fq_map, is_qat
    )
    first_arg = node.args[0] if node.args else None
    fed_by_quantized_placeholder = (
        isinstance(first_arg, Node)
        and output_dtype in (torch.quint8, torch.uint8)
        and first_arg.op == "placeholder"
    )
    return not fed_by_quantized_placeholder
def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
    pattern: Pattern | None,
    matched_node_pattern: list[Node] | None,
    qconfig: QConfigAny,
    backend_config: BackendConfig,
) -> bool:
    """Return whether any backend dtype config for ``pattern`` supports it.

    A dtype config is accepted when every positional and keyword argument of
    the pattern's root node passes the input dtype check, and
    ``matched_node_pattern[0]`` passes the output dtype check (both checks
    also require ``qconfig`` to satisfy the config's constraints). With no
    ``backend_config`` or no ``pattern`` everything is considered supported.

    Raises:
        AssertionError: if ``matched_node_pattern`` is None or empty.
    """
    if backend_config is None or pattern is None:
        return True
    if matched_node_pattern is None or len(matched_node_pattern) < 1:
        raise AssertionError("matched_node_pattern must be non-empty")

    dtype_configs: list[DTypeConfig] = get_pattern_to_dtype_configs(
        backend_config
    ).get(pattern, [])
    root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config).get(
        pattern, _default_root_node_getter
    )
    input_node = root_node_getter(matched_node_pattern)
    output_node = matched_node_pattern[0]

    for dtype_config in dtype_configs:
        inputs_ok = all(
            _is_input_arg_dtype_supported_by_backend(
                arg, input_node, qconfig, dtype_config, backend_config
            )
            for arg in [*input_node.args, *input_node.kwargs.values()]
        )
        if inputs_ok and _is_output_dtype_supported_by_backend(
            output_node, qconfig, dtype_config
        ):
            return True
    return False
- def _get_standalone_module_configs(
- node: Node,
- named_modules: dict[str, torch.nn.Module],
- prepare_custom_config: PrepareCustomConfig,
- parent_qconfig: QConfigAny,
- parent_backend_config: BackendConfig | None,
- ) -> tuple[QConfigMapping, tuple[Any, ...], PrepareCustomConfig, BackendConfig | None]:
- """
- Returns the standalone module QConfigMapping and PrepareCustomConfig
- for `node`, assuming that the module pointed to by `node` is
- a standalone modules.
- """
- module_name = str(node.target)
- module_type = type(named_modules[module_name]) # type: ignore[index]
- # name config has precedence over type config
- config_entry = StandaloneModuleConfigEntry(None, (), None, None)
- config_entry = prepare_custom_config.standalone_module_classes.get(
- module_type, config_entry
- )
- config_entry = prepare_custom_config.standalone_module_names.get(
- module_name, config_entry
- )
- # fallback to use parent module's qconfig if user didn't specify qconfig dict
- qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global(
- parent_qconfig
- )
- example_inputs = config_entry.example_inputs
- prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig()
- backend_config = config_entry.backend_config or parent_backend_config
- return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
def _qat_swap_modules(
    root: torch.nn.Module, module_to_qat_module: dict[Pattern, type[torch.nn.Module]]
) -> None:
    """Swap submodules of ``root`` in place using ``module_to_qat_module``.

    Delegates to ``convert`` and keeps existing ``qconfig`` attributes
    (``remove_qconfig=False``).
    """
    convert(
        root,
        mapping=module_to_qat_module,
        remove_qconfig=False,
        inplace=True,
    )
def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: set[str]):
    """Add the ``name`` of every Node in ``matched_node_pattern`` to ``s``.

    ``matched_node_pattern`` may be a single Node or an arbitrarily nested
    list/tuple; any other leaf value is ignored. ``s`` is updated in place.
    """
    pending = [matched_node_pattern]
    while pending:
        item = pending.pop()
        if isinstance(item, Node):
            s.add(item.name)
        elif isinstance(item, (list, tuple)):  # noqa: UP038
            pending.extend(item)
- def _insert_obs_or_fq(
- node: Node,
- obs_or_fq: ObserverOrFakeQuantize,
- model: torch.nn.Module,
- named_modules: dict[str, torch.nn.Module],
- graph: Graph,
- model_device: torch.device | None = None,
- ) -> Node:
- """
- Attaches `obs_or_fq` to `model`, and creates a node which calls
- `obs_or_fq` on the output of `node`.
- obs_or_fq: an instance of Observer or FakeQuantize module
- """
- if model_device is None:
- model_device = assert_and_get_unique_device(model)
- if model_device:
- obs_or_fq.to(model_device)
- # add obs_or_fq module as attribute
- if is_equalization_observer(obs_or_fq):
- prefix = node.name + "_equalization_process_"
- else:
- prefix = "activation_post_process_"
- get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix)
- obs_or_fq_name = get_new_obs_or_fq_name(model)
- setattr(model, obs_or_fq_name, obs_or_fq)
- named_modules[obs_or_fq_name] = obs_or_fq
- with graph.inserting_after(node):
- new_obs = graph.create_node("call_module", obs_or_fq_name, (node,), {})
- return new_obs
def _set_target_dtype_info_for_matched_node_pattern(
    matched_node_pattern: NodePattern,
    last_node: Node,
    qconfig: QConfigAny,
    qhandler: QuantizeHandler | None,
    backend_config: BackendConfig,
    named_modules: dict[str, torch.nn.Module],
    cache_for_no_tensor_check: dict[Node, bool],
    processed_nodes: set[Node],
) -> None:
    """Store ``node.meta["target_dtype_info"]`` for each Node in the pattern.

    Lists/tuples are walked recursively; non-Node leaves (e.g. int or float
    literals) are ignored. ``processed_nodes`` ensures each node is handled
    at most once: a node is marked processed even when ``qconfig`` is None,
    in which case its meta is left untouched.
    """
    if isinstance(matched_node_pattern, (list, tuple)):  # noqa: UP038
        for sub_pattern in matched_node_pattern:
            _set_target_dtype_info_for_matched_node_pattern(
                sub_pattern,
                last_node,
                qconfig,
                qhandler,
                backend_config,
                named_modules,
                cache_for_no_tensor_check,
                processed_nodes,
            )
        return

    if not isinstance(matched_node_pattern, Node):
        return

    node = matched_node_pattern
    if node in processed_nodes:
        return
    processed_nodes.add(node)
    if qconfig is None:
        return

    # TODO: refactor the following code in terms of apply a qconfig to a pattern
    # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1)
    # we set the input_obs_or_fq_ctr for the arguments of op1 to based on qconfig.input_act,
    # and set output_obs_or_fq_ctr based on qconfig.output_act
    # this also requires we extend the structure of QConfig to support more fine
    # grained configurations
    node.meta["target_dtype_info"] = _get_target_activation_dtype_for_node(
        node,
        qconfig,
        qhandler,
        named_modules,
        backend_config,
        cache_for_no_tensor_check,
    )
def _get_target_activation_dtype_for_node(
    node: Node,
    qconfig: QConfigAny,
    qhandler: QuantizeHandler | None,
    named_modules: dict[str, torch.nn.Module],
    backend_config: BackendConfig,
    cache_for_no_tensor_check: dict[Node, bool],
) -> dict[str, Any]:
    """Build the ``target_dtype_info`` dict for ``node``.

    The dict maps each op attribute (input activation, output activation,
    weight, bias) to the observer / fake-quant constructor whose dtype and
    is_dynamic settings are expected for the ``quantize`` call in the
    reference model representation.

    * If none of the node's args contain tensors, both activation
      constructors are ``None``.
    * If ``qconfig`` is ``None``, a shallow copy of the default fp32
      placeholder info is returned.
    * Otherwise activation and weight constructors come from ``qconfig``; the
      bias constructor is a ``PlaceholderObserver`` with dtype fp16 when the
      activation and weight dtypes are both fp16 and the activation is not
      dynamic (fp32 otherwise). For ``call_function`` nodes whose target is
      registered in ``backend_config``, the weight and bias argument indices
      are also recorded.

    TODO(future PR, if needed): explicitly spell out the non-Tensor
    dtypes.
    """
    if all_node_args_have_no_tensors(node, named_modules, cache_for_no_tensor_check):
        return {
            "input_act_obs_or_fq_ctr": None,
            "output_act_obs_or_fq_ctr": None,
        }

    if qconfig is None:
        return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)

    act_dtype, weight_dtype, input_act_is_dynamic = get_qconfig_dtypes(qconfig)
    # Currently `QConfig` only has one `activation` field. For static
    # quantization it is reused for both input and output activation; for
    # dynamic quantization it is only used for the input activation, with the
    # output activation being in fp32. This may change as more fields are
    # added to `QConfig`.
    use_fp16_bias = (
        act_dtype == torch.float16
        and weight_dtype == torch.float16
        and not input_act_is_dynamic
    )
    bias_dtype = torch.float16 if use_fp16_bias else torch.float

    is_general_tensor_value_op = (
        qhandler is not None and qhandler.is_general_tensor_value_op()
    )
    is_standalone = qhandler is not None and qhandler.is_standalone_module()

    weight_index = None
    bias_index = None
    if (
        isinstance(node, Node)
        and node.op == "call_function"
        and node.target in backend_config._pattern_complex_format_to_config
    ):
        input_type_to_index = backend_config._pattern_complex_format_to_config[
            node.target
        ]._input_type_to_index
        weight_index = input_type_to_index.get("weight")
        bias_index = input_type_to_index.get("bias")

    return {
        "input_act_obs_or_fq_ctr": qconfig.activation,
        "weight_obs_or_fq_ctr": qconfig.weight,
        "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype),
        "weight_index": weight_index,
        "bias_index": bias_index,
        "output_act_obs_or_fq_ctr": qconfig.activation,
        "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig),
        "input_output_share_observers": is_general_tensor_value_op,
        "_is_standalone_module": is_standalone,
    }
def _get_output_act_obs_or_fq(
    arg: Node,
    named_modules: dict[str, torch.nn.Module],
    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> ObserverOrFakeQuantize | None:
    """Instantiate the output observer / fake-quant of ``arg``'s original producer.

    Inserted observers are skipped: if ``arg`` is itself an observer node,
    the constructor of the node it observes is used. The constructor is read
    from ``output_act_obs_or_fq_ctr`` in the relevant node's
    ``meta["target_dtype_info"]``, falling back to the fp32 placeholder
    constructor for a plain node lacking that info. ``obs_or_fq_map`` and
    ``is_qat`` are not used here.

    We are assuming that the observers are inserted correctly, and the dtype
    for the argument in the quantized graph will match what is specified by
    the qconfig.
    """
    if not isinstance(arg, Node):
        raise AssertionError("arg must be a Node")
    if "quantization_annotation" in arg.meta:
        raise NotImplementedError(
            "Please use torchao (https://github.com/pytorch/ao) for pt2e quantization flow"
        )

    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
    # Since we modified the graph in this case, we must trace back from the args through
    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
    lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules)

    if lstm_node is not None:
        ctr = lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
    elif _is_activation_post_process_node(arg, named_modules):
        observed = arg.args[0]
        if not isinstance(observed, Node):
            raise AssertionError("Currently we only support observing Node")
        if "quantization_annotation" in observed.meta:
            raise NotImplementedError(
                "Please use torchao (https://github.com/pytorch/ao) for pt2e quantization flow"
            )
        if "target_dtype_info" not in observed.meta:
            raise AssertionError("expected 'target_dtype_info' in observed_arg.meta")
        ctr = observed.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
    elif "target_dtype_info" in arg.meta:
        ctr = arg.meta["target_dtype_info"].get(
            "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
        )
    else:
        ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR

    return ctr() if ctr else None
def _get_arg_target_dtype_as_output(
    arg: Node,
    named_modules: dict[str, torch.nn.Module],
    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> torch.dtype | None:
    """Return the dtype of the output observer / fake-quant that
    ``_get_output_act_obs_or_fq`` resolves for ``arg`` (None if there is none)."""
    output_obs_or_fq = _get_output_act_obs_or_fq(
        arg, named_modules, obs_or_fq_map, is_qat
    )
    target_dtype, _is_dynamic = _get_dtype_and_is_dynamic(output_obs_or_fq)
    return target_dtype
def _get_arg_as_input_act_obs_or_fq(
    arg: Node,
    node: Node,
    named_modules: dict[str, torch.nn.Module],
    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> ObserverOrFakeQuantize | None:
    """Instantiate the observer / fake-quant expected for ``arg`` as an input to ``node``.

    The constructor is taken from ``node.meta["target_dtype_info"]`` under
    the weight, bias, or input-activation key, depending on the role ``arg``
    plays for ``node``. The fp32 placeholder constructor is used when the key
    is absent. Weights of ops in ``NON_QUANTIZABLE_WEIGHT_OPS`` get ``None``.
    ``named_modules``, ``obs_or_fq_map`` and ``is_qat`` are not used here.
    """
    if not isinstance(arg, Node):
        raise AssertionError("arg must be a Node")
    if "quantization_annotation" in node.meta:
        raise NotImplementedError(
            "Please use torchao (https://github.com/pytorch/ao) for pt2e quantization flow"
        )

    # we can remove the following path in the future if fx graph mode quantization is
    # no longer used
    arg_is_weight = node_arg_is_weight(node, arg)
    arg_is_bias = node_arg_is_bias(node, arg)

    if arg_is_weight:
        if node.target in NON_QUANTIZABLE_WEIGHT_OPS:
            return None
        ctr_key = "weight_obs_or_fq_ctr"
    elif arg_is_bias:
        ctr_key = "bias_obs_or_fq_ctr"
    else:
        ctr_key = "input_act_obs_or_fq_ctr"

    ctr = node.meta["target_dtype_info"].get(ctr_key, _DEFAULT_FP32_OBS_OR_FQ_CTR)
    return ctr() if ctr else None
def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Node | Any,
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: dict[str, torch.nn.Module],
    graph: Graph,
    qhandler: QuantizeHandler | None,
    prepare_custom_config: PrepareCustomConfig,
    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
    backend_config: BackendConfig | None = None,
    model_device: torch.device | None = None,
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.

    Lists and tuples are processed element by element (recursively) and
    rebuilt with ``type(arg)(...)``. Values that are not Nodes are returned
    unchanged.

    Regular nodes: the dtype / dynamism `arg` is expected to produce (from its
    output observer) is compared with what `node` expects for this input (from
    ``node.meta["target_dtype_info"]``) using ``_needs_obs_or_fq``.

    Standalone modules: only positional args are considered. The target input
    dtype is ``torch.quint8`` if the arg index is in the standalone module's
    ``input_quantized_indexes`` and ``torch.float`` otherwise. An observer built
    from ``qconfig.activation`` is needed only for a non-float target that
    differs from the dtype `arg` produces.

    If a ``call_module`` user of `arg` already holds an observer of the same
    type and dtype, that observer node is reused instead of inserting a new
    one. Whenever an observer is used, it is recorded in
    ``obs_or_fq_map[(arg, node)]``.

    Returns:
        The observer node to use in place of `arg`, or `arg` itself (or the
        rebuilt list/tuple) when no observer is needed. `graph` is mutated when
        a new observer is inserted, but `node` is not updated here; the caller
        is responsible for assigning the result back to `node`.

    Raises:
        NotImplementedError: if `node` carries a pt2e ``quantization_annotation``.
        AssertionError: if ``target_dtype_info`` is missing from `node.meta`
            (regular flow), if `qconfig` is None (standalone-module flow), or
            if no observer instance is available when one is needed.
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):  # noqa: UP038
        new_arg_to_return = []
        for inner_arg in arg:
            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
                node,
                # pyrefly: ignore [bad-argument-type]
                inner_arg,
                qconfig,
                model,
                named_modules,
                graph,
                qhandler,
                prepare_custom_config,
                obs_or_fq_map,
                is_qat,
                backend_config,
                model_device,
            )
            new_arg_to_return.append(new_inner_arg)
        return type(arg)(new_arg_to_return)

    # constants and other non-Node values never get an observer
    if not isinstance(arg, Node):
        return arg
    # always satisfied after the early return above; kept as a type-narrowing check
    if not isinstance(arg, Node):
        raise AssertionError("arg must be a Node")
    # default (no observer)
    new_arg = arg

    is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
    # TODO: move this to a separate function
    if not is_standalone_module:
        # Note: qconfig can be None in this branch since we are getting act/fq from
        # node.meta now
        # regular flow for most nodes, except standalone modules
        if "quantization_annotation" in node.meta:
            raise NotImplementedError(
                "Please use torchao (https://github.com/pytorch/ao) for pt2e quantization flow"
            )
        if "target_dtype_info" not in node.meta:
            raise AssertionError("expected 'target_dtype_info' in node.meta")
        # TODO: we are assuming "target_dtype_info" exists here, maybe
        # a default value also need to be provided here
        target_dtype_info = node.meta["target_dtype_info"]
        # for nodes that doesn't have `reuse_input_obs_or_fq` configured,
        # we'll default to False, this makes configuring this field optional for users
        reuse_input_obs_or_fq = target_dtype_info.get("reuse_input_obs_or_fq", False)

        # what `node` expects to receive for this input
        arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(
            arg, node, named_modules, obs_or_fq_map, is_qat
        )
        (
            arg_as_input_target_dtype,
            arg_as_input_target_is_dynamic,
        ) = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq)

        # what `arg` is expected to produce
        arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(
            arg, named_modules, obs_or_fq_map, is_qat
        )
        (
            arg_as_output_target_dtype,
            arg_as_output_target_is_dynamic,
        ) = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)

        needs_obs_or_fq = _needs_obs_or_fq(
            arg_as_output_target_dtype,
            arg_as_output_target_is_dynamic,
            arg_as_input_target_dtype,
            arg_as_input_target_is_dynamic,
            reuse_input_obs_or_fq,
            is_zeroth_arg=len(node.args) > 0 and arg is node.args[0],
        )

    else:
        if qconfig is None:
            raise AssertionError("qconfig must not be None")
        # custom flow for standalone modules
        _, _, sm_prepare_custom_config, _ = _get_standalone_module_configs(
            node, named_modules, prepare_custom_config, qconfig, backend_config
        )
        sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes

        # for args, this is set to the index of the current arg
        # for kwargs, this is left at None
        cur_input_idx = None
        for arg_idx, arg_to_check in enumerate(node.args):
            if arg_to_check is arg:
                cur_input_idx = arg_idx
                break

        if cur_input_idx is None:
            # kwargs of a standalone module are never observed here
            needs_obs_or_fq = False
        else:
            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(
                arg, named_modules, obs_or_fq_map, is_qat
            )
            arg_as_input_target_dtype = (
                torch.quint8
                if cur_input_idx in sm_input_quantized_idxs
                else torch.float
            )
            needs_obs_or_fq = (
                arg_as_output_target_dtype != arg_as_input_target_dtype
            ) and (arg_as_input_target_dtype != torch.float)

        act_post_process_ctr = qconfig.activation
        arg_as_input_act_obs_or_fq = (
            act_post_process_ctr() if act_post_process_ctr else None
        )

    if needs_obs_or_fq:
        existing_obs_node = None

        # Before using the new observer, check if an observer
        # of the correct type already exists. If it does, use it.
        # This prevents duplicate observer insertions if a node is
        # used by multiple nodes.
        # TODO: this is looking into how the value is used in the future
        # we should remove this
        # removing this means we insert one observer for each use, even if they
        # have the same dtype, we can have an extra pass that removes the extra observers
        for maybe_obs_node in arg.users:
            if maybe_obs_node.op == "call_module":
                maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
                # arg_as_input_target_dtype is only unbound in the standalone
                # branch when cur_input_idx is None, where needs_obs_or_fq is
                # False, so this loop is not reached in that case.
                if (
                    type(maybe_obs_mod) is type(arg_as_input_act_obs_or_fq)
                    and maybe_obs_mod.dtype == arg_as_input_target_dtype  # type: ignore[possibly-undefined]
                ):
                    arg_as_input_act_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
                    existing_obs_node = maybe_obs_node
                    break

        if arg_as_input_act_obs_or_fq is None:
            raise AssertionError("arg_as_input_act_obs_or_fq must not be None")
        # remember which observer instance guards the (arg -> node) edge
        obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
        if existing_obs_node is None:
            new_obs_node = _insert_obs_or_fq(
                arg,
                arg_as_input_act_obs_or_fq,
                model,
                named_modules,
                graph,
                model_device,
            )
            # override this arg to be the observed arg
            new_arg = new_obs_node
        else:
            new_arg = existing_obs_node

    return new_arg
def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: dict[str, torch.nn.Module],
    graph: Graph,
    qhandler: QuantizeHandler | None,
    prepare_custom_config: PrepareCustomConfig,
    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
    backend_config: BackendConfig | None = None,
    model_device: torch.device | None = None,
) -> None:
    """
    Route each input of `node` through an observer where one is required.

    Every positional arg and every kwarg value is handed to
    ``_maybe_insert_input_observer_for_arg_or_kwarg``. Its return value (an
    observer node, or the untouched input) replaces the original input.
    ``node.args`` and ``node.kwargs`` are reassigned in place, turning e.g.

        prev_node -> cur_node

    into

        prev_node -> obs -> cur_node

    `backend_config` only matters for standalone-module nodes.
    """

    def _observe_input(value: Argument) -> Argument:
        # Same per-input helper and context for args and kwargs alike.
        return _maybe_insert_input_observer_for_arg_or_kwarg(
            node,
            value,
            qconfig,
            model,
            named_modules,
            graph,
            qhandler,
            prepare_custom_config,
            obs_or_fq_map,
            is_qat,
            backend_config,
            model_device,
        )

    # Positional inputs are handled first, then keyword inputs.
    observed_args = tuple(_observe_input(arg) for arg in node.args)
    observed_kwargs = {
        key: _observe_input(value) for key, value in node.kwargs.items()
    }

    node.args = observed_args
    node.kwargs = observed_kwargs
- def _maybe_insert_input_equalization_observers_for_node(
- node: Node,
- equalization_qconfig: Any,
- model: torch.nn.Module,
- named_modules: dict[str, torch.nn.Module],
- graph: Graph,
- is_branch: bool,
- ) -> None:
- """
- If `node` needs to be equalized, find the input/weight observers it needs in
- `equalization_qconfig`, creates them, and inserts it into `graph`.
- If `node` does not need an equalization observer, returns None.
- """
- if equalization_qconfig is None or not node_supports_equalization(
- node, named_modules
- ):
- return
- if is_branch:
- warnings.warn(
- f"Cannot equalize {node} because it is part of a branch.", stacklevel=2
- )
- return
- new_args = []
- for arg in node.args:
- if not isinstance(arg, Node) or node_arg_is_bias(node, arg):
- new_args.append(arg)
- continue
- is_weight = node_arg_is_weight(node, arg)
- act_eq_process_ctr = (
- equalization_qconfig.weight
- if is_weight
- else equalization_qconfig.input_activation
- )
- new_eq_obs_mod = act_eq_process_ctr()
- new_eq_obs_node = _insert_obs_or_fq(
- arg, new_eq_obs_mod, model, named_modules, graph
- )
- new_args.append(new_eq_obs_node)
- # assign the new args and kwargs to the node, inplace
- node.args = tuple(new_args)
def _maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    named_modules: dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Node | None:
    """
    Create an output observer for `node` and insert it into `graph` if needed.

    The observer is built from ``output_act_obs_or_fq_ctr`` in
    ``node.meta["target_dtype_info"]``. The value flowing in is treated as a
    non-dynamic ``torch.float``, and ``_needs_obs_or_fq`` decides whether the
    target requires observation. No observer is inserted when the target is
    dynamic (in fx graph mode quantization a dynamic activation config means
    "dynamically quantize the input", never the output). No observer is
    inserted either when the node is a standalone module, which is assumed
    to observe its outputs internally. An inserted observer is also recorded
    in ``obs_or_fq_map[node]``.

    Returns:
        The new observer node, or None when no observer was inserted.

    Raises:
        AssertionError: if `node` is the graph output or lacks
            ``target_dtype_info``.
        NotImplementedError: if `node` carries a pt2e ``quantization_annotation``.
    """
    if node.op == "output":
        raise AssertionError("observer insertion for outputs is handled elsewhere")
    if "quantization_annotation" in node.meta:
        raise NotImplementedError(
            "Please use torchao (https://github.com/pytorch/ao) for pt2e quantization flow"
        )
    if "target_dtype_info" not in node.meta:
        raise AssertionError("expected 'target_dtype_info' in node.meta")

    dtype_info = node.meta["target_dtype_info"]
    standalone = dtype_info.get("_is_standalone_module", False)
    obs_or_fq_ctr = dtype_info.get("output_act_obs_or_fq_ctr")
    output_obs_or_fq = obs_or_fq_ctr() if obs_or_fq_ctr else None
    target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_obs_or_fq)

    # `reuse_input_obs_or_fq` is intentionally not read from dtype_info here:
    # reusing the input observer for the output is currently implemented by the
    # input/output observer-sharing code path, so it is always False for now.
    # The producer is assumed to be an fp32, non-dynamic op; ideally this would
    # come from node.meta["val"].
    needs_obs_or_fq = _needs_obs_or_fq(
        torch.float, False, target_dtype, target_is_dynamic, False
    )

    # QConfig cannot express separate input/output activation settings, so a
    # dynamic activation config must not lead to a dynamically quantized output.
    if target_is_dynamic or standalone or not needs_obs_or_fq:
        return None

    obs_or_fq_map[node] = output_obs_or_fq
    return _insert_obs_or_fq(node, output_obs_or_fq, model, named_modules, graph)
def _maybe_insert_observers_before_graph_output(
    graph_output_node: Node,
    model: torch.nn.Module,
    named_modules: dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> None:
    """
    Observe graph outputs that need it and substitute the observers in place.

    The args of `graph_output_node` may be arbitrarily nested lists, tuples
    and dicts; e.g. for ``{'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}}``
    only the Nodes bar1 and bar3 are examined. For each Node found:

    * the dtype it is expected to produce is compared with the dtype of the
      observer built from its own ``input_act_obs_or_fq_ctr`` (``torch.float``
      when there is none);
    * when the two differ and the latter is not ``torch.float``, a new
      observer is inserted after the Node and replaces it in the structure.

    Containers are rebuilt as plain list / tuple / dict, None is kept as is,
    and any other leaf type raises. ``graph_output_node.args`` is reassigned
    in place. Dynamic quantization is not handled.
    """

    def _observe_node(candidate: Node) -> Node:
        produced_dtype = _get_arg_target_dtype_as_output(
            candidate, named_modules, obs_or_fq_map, is_qat
        )
        # float unless the node's input_act_obs_or_fq_ctr says otherwise
        requested_dtype = torch.float
        observer_mod = None
        if "target_dtype_info" in candidate.meta:
            observer_ctr = candidate.meta["target_dtype_info"].get(
                "input_act_obs_or_fq_ctr", None
            )
            if observer_ctr is not None:
                observer_mod = observer_ctr()
                requested_dtype = observer_mod.dtype

        # TODO: this does not handle dynamic quantization yet
        if produced_dtype == requested_dtype or requested_dtype == torch.float:
            return candidate
        if observer_mod is None:
            raise AssertionError(
                "observer_mod must not be None when need_obs is True"
            )
        return _insert_obs_or_fq(candidate, observer_mod, model, named_modules, graph)

    def _replace(value: Argument) -> Argument:
        # Stop at the first Node on each path; recurse through containers.
        if isinstance(value, Node):
            return _observe_node(value)
        if isinstance(value, (list, tuple)):  # noqa: UP038
            # pyrefly: ignore [bad-argument-type]
            replaced = [_replace(item) for item in value]
            return replaced if isinstance(value, list) else tuple(replaced)
        if isinstance(value, dict):
            return {key: _replace(item) for key, item in value.items()}
        if value is None:
            return None
        raise Exception(  # noqa: TRY002
            "Unhandled type for returned node:", value
        )

    graph_output_node.args = tuple(  # type: ignore[assignment]
        _replace(old_arg) for old_arg in graph_output_node.args
    )
def _maybe_propagate_dtype_for_node(
    node: Node,
    target_dtype: torch.dtype | type,
    node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig],
) -> None:
    """
    Mark `node`, and the producers it merely passes through, as unobserved.

    Clears ``input_act_obs_or_fq_ctr`` and ``output_act_obs_or_fq_ctr`` in
    ``node.meta["target_dtype_info"]`` so that no observer is requested for
    the node. If `node` was matched by a quantize handler whose
    ``is_general_tensor_value_op()`` is true (a general tensor shape op), the
    same is done for its first positional argument. The walk repeats up the
    chain and stops at the first node that is not such an op, or whose first
    positional argument is missing or is not a Node.

    Args:
        node: a node whose (non-observable) dtype is known from context.
        target_dtype: the known dtype (e.g. ``torch.bool`` for a mask). It is
            not stored anywhere; the observer constructors are simply cleared.
            The parameter is kept for interface compatibility.
        node_name_to_match_result_with_qconfig: pattern-match results keyed by
            node name, used to look up each node's quantize handler.
    """
    current = node
    while True:
        dtype_info = current.meta["target_dtype_info"]
        dtype_info["input_act_obs_or_fq_ctr"] = None
        dtype_info["output_act_obs_or_fq_ctr"] = None

        # if this is a copy node, propagate to first arg
        (
            _root_node,
            _,
            _pattern,
            qhandler,
            _qconfig,
        ) = node_name_to_match_result_with_qconfig.get(
            current.name, (None, None, None, None, None)
        )
        # TODO: probably need to remove `is_general_tensor_value_op`
        if qhandler is None or not qhandler.is_general_tensor_value_op():
            return
        # The input may have been passed as a keyword argument, leaving
        # `args` empty; stop here rather than raising IndexError.
        if not current.args:
            return
        prev_node = current.args[0]
        if not isinstance(prev_node, Node):
            return
        current = prev_node
def propagate_dtypes_for_known_nodes(
    graph: Graph,
    node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig],
) -> None:
    """
    Propagate dtypes that are known from how a value is consumed.

    Graph inputs are assumed to be either ``torch.float`` or ``torch.quint8``,
    which is not always right. For ops such as ``x.masked_fill(mask, value)``,
    we know ``mask`` is a ``BoolTensor``. For every node, the non-observable
    argument positions reported by ``get_non_observable_arg_indexes_and_types``
    are looked up. Each Node found there (including Nodes inside tuple/list
    arguments) is passed to ``_maybe_propagate_dtype_for_node`` with the
    corresponding type.

    Not every dtype in the graph will be correct afterwards, but more of them
    will be; ideally this gets replaced by proper dtype reasoning.
    """
    for node in graph.nodes:
        indices_by_type = get_non_observable_arg_indexes_and_types(node)
        for arg_type, get_indices in indices_by_type.items():
            for index in get_indices(node):
                arg = node.args[index]
                # A tuple/list argument is not itself a Node, so look at each
                # of its elements.
                if isinstance(arg, (tuple, list)):  # noqa: UP038
                    candidates = list(arg)
                else:
                    candidates = [arg]
                for candidate in candidates:
                    # Hard-coded (non-Node) values have no dtype to propagate.
                    if not isinstance(candidate, torch.fx.node.Node):
                        continue
                    _maybe_propagate_dtype_for_node(
                        candidate, arg_type, node_name_to_match_result_with_qconfig
                    )
- def _maybe_make_input_output_share_observers(
- node: Node,
- model: torch.nn.Module,
- named_modules: dict[str, torch.nn.Module],
- ) -> bool:
- """
- Ensures that we share an observer
- for all input arguments as well as the output argument. In detail, given
- a graph of
- x0 -> obs0 -> op -> x2
- /
- x1 -> obs1 /
- where node obs0 points to observer instance observer0,
- obs1 points to observer1 and obs2 points to observer2, we make nodes obs1
- and ob2 point to observer0.
- Returns: whether the operation succeeded or not
- """
- first_arg = None
- # find the first non-Tensor arg
- for i in range(len(node.args)):
- if isinstance(node.args[i], (Node, list, tuple)): # noqa: UP038
- first_arg = node.args[i]
- break
- # if there is no non-Tensor arg, return directly
- if first_arg is None:
- return False
- if isinstance(first_arg, (list, tuple)): # noqa: UP038
- first_arg_arg = first_arg[0]
- elif isinstance(first_arg, Node):
- first_arg_arg = first_arg
- else:
- return False
- # if we have a graph such as
- # observed_node -> non_observed_node -> cat
- # we need to navigate up to the first observer
- iteration_guard = 0
- # pyrefly: ignore [bad-argument-type]
- while not _is_activation_post_process_node(first_arg_arg, named_modules):
- if not isinstance(first_arg_arg, Node):
- return False
- # did not find an activation_post_process for the op
- if first_arg_arg.op == "placeholder":
- return False
- # trace back the args until we found the first Tensor/Node
- trace_back_node = None
- for i in range(len(first_arg_arg.args)):
- trace_back_node = first_arg_arg.args[i]
- if isinstance(trace_back_node, Node):
- break
- if trace_back_node is None:
- return False
- first_arg_arg = trace_back_node
- iteration_guard += 1
- if iteration_guard > 10000:
- raise AssertionError("Unable to find observer of previous node")
- if not isinstance(first_arg_arg, Node):
- raise AssertionError("first_arg_arg must be a Node")
- target_to_use = first_arg_arg.target
- if not isinstance(target_to_use, str):
- raise AssertionError("target_to_use must be a string")
- obs_mod_to_use = named_modules[target_to_use]
- if isinstance(first_arg, (list, tuple)): # noqa: UP038
- # set all other input observer nodes to use that module
- for input_idx, input_arg in enumerate(first_arg):
- if input_idx == 0:
- continue
- iteration_guard = 0
- # pyrefly: ignore [bad-argument-type]
- while not _is_activation_post_process_node(input_arg, named_modules):
- # failed to trace back since no input arg for the current node
- # pyrefly: ignore [missing-attribute]
- if len(input_arg.args) < 1:
- return False
- # pyrefly: ignore [bad-index, unsupported-operation]
- input_arg = input_arg.args[0]
- iteration_guard += 1
- if iteration_guard > 10000:
- raise AssertionError("Unable to find observer of previous node")
- # pyrefly: ignore [missing-attribute]
- parent_name, name = _parent_name(input_arg.target)
- setattr(named_modules[parent_name], name, obs_mod_to_use)
- # set the output observer node to use that module
- for output_obs_node in node.users:
- if not _is_activation_post_process_node(output_obs_node, named_modules):
- raise AssertionError(
- "output_obs_node must be an activation post process node"
- )
- parent_name, name = _parent_name(output_obs_node.target)
- setattr(named_modules[parent_name], name, obs_mod_to_use)
- # TODO(future PR): delete the orphaned observer modules
- return True
- def _remove_output_observer(
- node: Node, model: torch.nn.Module, named_modules: dict[str, torch.nn.Module]
- ):
- items = list(node.users.items())
- for output_obs_node, _ in items:
- if not _is_activation_post_process_node(output_obs_node, named_modules):
- raise AssertionError(
- "output_obs_node must be an activation post process node"
- )
- output_obs_node.replace_all_uses_with(node)
- model.graph.erase_node(output_obs_node) # type: ignore[union-attr, operator]
def _swap_custom_module_to_observed(
    node: Node,
    qconfig: QConfigAny,
    named_modules: dict[str, torch.nn.Module],
    prepare_custom_config: PrepareCustomConfig,
):
    """
    Replace the float custom module called by `node` with its observed version.

    The observed class is resolved by ``get_swapped_custom_module_class`` from
    ``prepare_custom_config.float_to_observed_mapping`` and `qconfig`. It is
    instantiated via ``from_float`` and installed on the parent module under
    the same attribute name.
    """
    float_module = named_modules[node.target]  # type: ignore[index]
    observed_cls = get_swapped_custom_module_class(
        float_module, prepare_custom_config.float_to_observed_mapping, qconfig
    )
    observed_module = observed_cls.from_float(float_module)
    parent_name, attr_name = _parent_name(node.target)
    setattr(named_modules[parent_name], attr_name, observed_module)
- def insert_observers_for_model(
- model: GraphModule,
- node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig],
- node_name_to_qconfig: dict[str, QConfigAny],
- prepare_custom_config: PrepareCustomConfig,
- equalization_config_map: dict[str, Any],
- backend_config: BackendConfig,
- observed_node_names: set[str],
- is_qat: bool,
- ) -> Node | None:
- """
- Inserts observers, using the following high level algorithm:
- For each node in the graph:
- 1. determine the target dtype of this node in the quantized graph, and save
- it for future steps
- 2. determine the target dtype or all args and kwargs of this node
- 3. if any arg or kwarg's target dtype does not match the current node's
- dtype, insert an observer
- 4. if the current node needs an output observer, insert it
- For example:
- - starting graph:
- x0 -> linear -> x1
- - observed graph after processing x0:
- x0(fp32)
- - observed graph after processing linear:
- x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)
- - observed graph after processing x1:
- x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1
- After a node is processed, the naive observer placement is guaranteed to be
- complete for that node and all of its predecessors. There can be future
- passes which optimize the graph by deduplicating observers, etc.
- """
- # node.meta["target_dtype_info"] stores the target dtype information
- # that's derived from qconfig for the Node, for example, if we have
- # a conv2d node that has a qconfig
- # qconfig = QConfig(activation=..., weight=...)
- # # information for input and bias node omitted
- # # for getattr node
- # # weight = getattr(self, 'weight')
- # weight.meta["target_dtype_info"] = {
- # 'output_act_obs_or_fq_ctr': qconfig.weight,
- # }
- # # for conv2d node
- # # conv2d = call_function[target=torch.nn.functional.conv2d](
- # # args=(input, weight, bias))
- # conv2d.meta["target_dtype_info"] = {
- # 'input_act_obs_or_fq_ctr': qconfig.activation
- # 'weight_obs_or_fq_ctr': qconfig.weight,
- # 'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32),
- # 'output_act_obs_or_fq_ctr': qconfig.activation,
- # }
- #
- cache_for_no_tensor_check: dict[Node, bool] = {}
- # first, populate the dtype map based only on qconfig and qhandler
- # this assumes:
- # graph inputs are fp32 by default, and int8 where overridden
- # other nodes output dtype is specified by the qconfig
- named_modules = dict(model.named_modules(remove_duplicate=False))
- input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes
- output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes
- processed_nodes: set[Node] = set()
- # initialize target_dtype_info
- for node in model.graph.nodes:
- node.meta["target_dtype_info"] = copy.copy(
- _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO
- )
- inputs_seen_counter = 0
- outputs_seen_counter = 0
- placeholder_node_to_input_index: dict[Node, int] = {}
- # TODO: we probably don't need this counter since each graph will only have
- # one output node?
- output_node_to_output_index: dict[Node, int] = {}
- for node in model.graph.nodes:
- if node.op == "placeholder":
- placeholder_node_to_input_index[node] = inputs_seen_counter
- inputs_seen_counter += 1
- if node.op == "output":
- output_node_to_output_index[node] = outputs_seen_counter
- outputs_seen_counter += 1
- # Step 1, set the observer or fake quantize module constructor for each node in the
- # matched_node_pattern
- for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
- (
- last_node,
- matched_node_pattern,
- pattern,
- qhandler,
- qconfig,
- ) = match_res_with_qconfig
- if qhandler is None:
- raise AssertionError("qhandler must not be None")
- _set_target_dtype_info_for_matched_node_pattern(
- matched_node_pattern,
- last_node,
- qconfig,
- qhandler,
- backend_config,
- named_modules,
- cache_for_no_tensor_check,
- processed_nodes,
- )
- # Step 2. Special cases for some operators, we might be able to remove them
- # in the future if we know dtype information of each node better
- # Step 2.1. some settings are not based on patterns, we need to process each node
- # instead
- for node in model.graph.nodes:
- if (
- node.op == "placeholder"
- and placeholder_node_to_input_index[node] in input_quantized_idxs
- ):
- # users are not supposed to call calculate_qparams on PlaceholderObserver, and
- # this is OK because we are using this as a way to encode the dtypes of input
- # tensor, we won't actually insert these observers in the graph and won't
- # actually call calculate_qparams
- node.meta["target_dtype_info"] = copy.copy(
- _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO
- )
- elif node.op in ("call_module", "call_method", "call_function"):
- args_have_no_tensors = all_node_args_have_no_tensors(
- node, named_modules, cache_for_no_tensor_check
- )
- if args_have_no_tensors:
- node.meta["target_dtype_info"] = {
- "input_act_obs_or_fq_ctr": None,
- "output_act_obs_or_fq_ctr": None,
- }
- elif (
- node.op == "output"
- and output_node_to_output_index[node] in output_quantized_idxs
- ):
- # TODO(future PR): update the output_quantized_idxs API to match
- # arbitrary data structures. There is always a single output, and
- # that output can have arbitrary nesting of values. List[int] is
- # not the right data type for this.
- # TODO(future PR): support more dtypes in model outputs, if necessary
- node.meta["target_dtype_info"] = copy.copy(
- _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO
- )
- # Step 2.2, for nodes with known input dtypes, propagate them throughout the
- # graph. For example, if there is a call such as
- # x1 = x0.masked_fill(mask, 1)
- # we propagate the type of mask to be torch.bool
- propagate_dtypes_for_known_nodes(
- model.graph, node_name_to_match_result_with_qconfig
- )
- # Step 3, check if the requested target_dtype_info is supported by backend or not
- # if not, we'll reset the target_dtye_info to use the default (float Tensor)
- # reset the counters and set of processed_nodes
- processed_nodes: set[Node] = set()
- for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
- (
- last_node,
- matched_node_pattern,
- pattern,
- qhandler,
- qconfig,
- ) = match_res_with_qconfig
- is_supported_by_backend = (
- _is_pattern_dtype_config_and_qconfig_supported_by_backend(
- pattern, matched_node_pattern, qconfig, backend_config
- )
- )
- if qhandler is None:
- raise AssertionError("qhandler must not be None")
- # get output_act_dtype so that we don't also reset the special typed nodes
- # TODO: we might want to handle these more uniformly with the default path
- # this can be improved if we can use node.meta["val"]
- output_act_or_fq_ctr = node.meta["target_dtype_info"][
- "output_act_obs_or_fq_ctr"
- ]
- output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None
- output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq)
- if not is_supported_by_backend and output_act_dtype not in [
- None,
- int,
- float,
- torch.bool,
- ]:
- # restore target_dtype_info to default if it is not supported by backend
- _set_target_dtype_info_for_matched_node_pattern(
- matched_node_pattern,
- last_node,
- torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig,
- None,
- backend_config,
- named_modules,
- cache_for_no_tensor_check,
- processed_nodes,
- )
- # After this point, the current node and all of its arguments
- # have a target_dtype_info assigned. Now, we insert observers for inputs
- # of this node (if needed for this node), and the output of this node
- # (if needed for this node).
- # Since we are mutating the graph as we go, we iterate over the original
- # nodes before observer insertion, instead of model.graph.nodes.
- nodes_before_observation = list(model.graph.nodes)
- # Avoid duplicates custom module swaps for multiple nodes with same target.
- custom_module_names_already_swapped: set[str] = set()
- # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index
- # reset inputs/outputs counters
- inputs_seen_counter = 0
- outputs_seen_counter = 0
- results_node = None
- obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
- model_device = assert_and_get_unique_device(model)
- # TODO: change this to insert obs/fq by pattern instead of by node
- for node in nodes_before_observation:
- if node.op == "placeholder":
- # if a graph input is in fp32, it does not need observation
- # if a graph input is in int8, we assume the observation happens
- # outside of the graph, and no additional observation is needed
- pass
- elif node.op in ("call_module", "call_method", "call_function", "output"):
- # check for matches
- (
- last_node,
- matched_node_pattern,
- pattern,
- qhandler,
- qconfig,
- ) = node_name_to_match_result_with_qconfig.get( # type: ignore[assignment]
- node.name, (None, None, None, None, None)
- )
- equalization_qconfig = equalization_config_map.get(node.name)
- this_node_dtype_info = node.meta["target_dtype_info"]
- if "val" in node.meta:
- output_is_a_tensor = this_node_dtype_info is not None and isinstance(
- node.meta["val"], FakeTensor
- )
- else:
- output_is_a_tensor = this_node_dtype_info is not None
- skip_inserting_observers = (
- (qconfig is None) or not output_is_a_tensor
- ) and (node.op != "output")
- # TODO: take a closer look to see if we can remove this check
- # right now it is here because of `observed_node_names`, we are using
- # it as an indicator for swapping the modules to reference modules in
- # convert
- is_supported_by_backend = (
- _is_pattern_dtype_config_and_qconfig_supported_by_backend(
- pattern, matched_node_pattern, qconfig, backend_config
- )
- )
- if not skip_inserting_observers and is_supported_by_backend:
- named_modules = dict(model.named_modules(remove_duplicate=False))
- if node.op != "output":
- if matched_node_pattern is None:
- raise AssertionError("matched_node_pattern must not be None")
- # add matched nodes to the observed node name set
- _add_matched_node_name_to_set(
- matched_node_pattern, observed_node_names
- )
- # This is currently only used for equalization.
- # Checks if the current node is in a branch in which the two
- # first layers are both being quantized.
- #
- # ex. conv2
- # /
- # x -> conv1
- #
- # If this is the case, we will not apply equalization to the
- # initial two layers.
- is_quantized_branch = False
- if (
- len(node.args) > 0
- and isinstance(node.args[0], Node)
- and len(node.args[0].users) > 1
- ):
- for user in node.args[0].users:
- # Checks if there exists another user being quantized
- is_user_quantized = node_name_to_qconfig.get(
- user.name
- ) is not None or (
- user.op == "call_module"
- and isinstance(
- named_modules[str(user.target)], ObserverBase
- )
- )
- if user != node and is_user_quantized:
- is_quantized_branch = True
- pattern_to_root_node_getter = (
- get_fusion_pattern_to_root_node_getter(backend_config)
- )
- root_node_getter = pattern_to_root_node_getter.get(
- pattern, _default_root_node_getter
- )
- root_node = root_node_getter(matched_node_pattern)
- is_input_node_of_the_pattern = node is root_node
- if is_input_node_of_the_pattern:
- # this modifies node inplace
- _maybe_insert_input_observers_for_node(
- node,
- qconfig,
- model,
- named_modules,
- model.graph,
- qhandler,
- prepare_custom_config,
- obs_or_fq_map,
- is_qat,
- backend_config,
- model_device,
- )
- # insert equalization input observers if needed
- _maybe_insert_input_equalization_observers_for_node(
- node,
- equalization_qconfig,
- model,
- named_modules,
- model.graph,
- is_quantized_branch,
- )
- is_last_node_of_pattern = node is last_node
- input_output_share_observers = node.meta["target_dtype_info"].get(
- "input_output_share_observers", False
- )
- reuse_input_obs_or_fq = node.meta["target_dtype_info"].get(
- "reuse_input_obs_or_fq", False
- )
- if is_last_node_of_pattern:
- if _is_custom_module_lstm(
- # pyrefly: ignore [bad-argument-type]
- node,
- named_modules,
- qconfig,
- qhandler,
- ):
- # Currently custom module outputs are assumed to be already quantized,
- # so we need to insert a DeQuantStub after the output. For custom module
- # LSTM specifically, the outputs are also a nested tuple, so we must first
- # break down the tuple to insert DeQuantStubs after the internal nodes.
- # TODO: This currently diverges from how custom modules are handled today,
- # where we insert observers after the output instead of DeQuantStubs, and
- # replace these observers with "dequantize" nodes during convert. Conceptually,
- # these output observers are the same as DeQuantStubs. In the future, we
- # should resolve this inconsistency by inserting DeQuantStubs for all custom
- # modules, not just for LSTM.
- _insert_dequant_stubs_for_custom_module_lstm_output(
- # pyrefly: ignore [bad-argument-type]
- node,
- model,
- named_modules,
- model.graph,
- )
- # pyrefly: ignore [missing-attribute]
- if node.target not in custom_module_names_already_swapped:
- # pyrefly: ignore [bad-argument-type]
- custom_module_names_already_swapped.add(node.target)
- _swap_custom_module_to_observed(
- # pyrefly: ignore [bad-argument-type]
- node,
- qconfig,
- named_modules,
- prepare_custom_config,
- )
- else:
- # this returns the new observer node if it was needed
- maybe_output_obs_node = (
- _maybe_insert_output_observer_for_node(
- # pyrefly: ignore [bad-argument-type]
- node,
- model,
- named_modules,
- model.graph,
- obs_or_fq_map,
- is_qat,
- )
- )
- if maybe_output_obs_node is not None:
- # Update users of original node to use the output observer
- # instead. For example, change
- #
- # next_node
- # /
- # cur_node -> obs
- #
- # to
- #
- # next_node
- # /
- # cur_node -> obs
- #
- # We need to save orig users before updating uses because
- # the list of users will change as we update uses
- # pyrefly: ignore [missing-attribute]
- orig_users = list(node.users.keys())
- for user_node in orig_users:
- if user_node is maybe_output_obs_node:
- continue
- user_node.replace_input_with(
- node, maybe_output_obs_node
- )
- _is_observer_in_same_graph_ = (
- _is_observer_in_same_graph(
- # pyrefly: ignore [bad-argument-type]
- node,
- named_modules,
- obs_or_fq_map,
- is_qat,
- )
- )
- # for ops whose inputs and outputs share observer/fqs, we modify the graph
- # to make all inputs and outputs use the first input's
- # observer/fq
- if (
- input_output_share_observers
- and _is_observer_in_same_graph_
- ) or reuse_input_obs_or_fq:
- if not _maybe_make_input_output_share_observers(
- # pyrefly: ignore [bad-argument-type]
- node,
- model,
- named_modules,
- ):
- _remove_output_observer(
- # pyrefly: ignore [bad-argument-type]
- node,
- model,
- named_modules,
- )
- if qhandler is not None and qhandler.is_custom_module():
- if (
- # pyrefly: ignore [missing-attribute]
- node.target
- not in custom_module_names_already_swapped
- ):
- custom_module_names_already_swapped.add(
- # pyrefly: ignore [bad-argument-type]
- node.target
- )
- _swap_custom_module_to_observed(
- # pyrefly: ignore [bad-argument-type]
- node,
- qconfig,
- named_modules,
- prepare_custom_config,
- )
- else: # output
- _maybe_insert_observers_before_graph_output(
- node, model, named_modules, model.graph, obs_or_fq_map, is_qat
- )
- #
- # After this point, the current node has input and output observers
- # that it needs for itself inserted.
- #
- # increment the counters, so future inputs and outputs are assigned
- # correct dtypes
- if node.op == "placeholder":
- inputs_seen_counter += 1
- elif node.op == "output":
- outputs_seen_counter += 1
- results_node = node
- return results_node
- def _run_prepare_fx_on_standalone_modules(
- model: torch.nn.Module,
- is_qat: bool,
- named_modules: dict[str, torch.nn.Module],
- node_name_to_match_result_with_qconfig: Any,
- prepare_custom_config: PrepareCustomConfig,
- backend_config: BackendConfig,
- ) -> None:
- """
- Runs prepare_fx on each standalone module. Note: this does
- not modify the graph, it just replaces the unobserved modules with
- their observed versions.
- """
- for (
- root_node,
- _,
- _pattern,
- qhandler,
- qconfig,
- ) in node_name_to_match_result_with_qconfig.values():
- if qhandler is None:
- continue
- elif not qhandler.is_standalone_module():
- continue
- (
- sm_qconfig_mapping,
- sm_example_inputs,
- sm_prepare_custom_config,
- sm_backend_config,
- ) = _get_standalone_module_configs(
- root_node, named_modules, prepare_custom_config, qconfig, backend_config
- )
- standalone_module = named_modules[root_node.target]
- prepare = torch.ao.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore[attr-defined]
- observed_standalone_module = prepare(
- standalone_module,
- sm_qconfig_mapping,
- is_qat,
- example_inputs=sm_example_inputs,
- prepare_custom_config=sm_prepare_custom_config,
- backend_config=sm_backend_config,
- )
- parent_name, name = _parent_name(root_node.target)
- setattr(named_modules[parent_name], name, observed_standalone_module)
- named_modules[root_node.target] = observed_standalone_module
def _save_state(
    observed: GraphModule,
    node_name_to_qconfig: dict[str, QConfigAny],
    node_name_to_scope: dict[str, tuple[str, type]],
    prepare_custom_config: PrepareCustomConfig,
    equalization_node_name_to_qconfig: dict[str, Any],
    qconfig_mapping: QConfigMapping,
    is_qat: bool,
    observed_node_names: set[str],
) -> None:
    """Record prepare-time state on ``observed`` so later steps can read it back.

    The arguments are bundled into an ``ObservedGraphModuleAttrs`` instance
    and stored under ``observed.meta["_observed_graph_module_attrs"]``. Any
    value previously stored under that key is overwritten. ``observed`` is
    modified in place and nothing is returned.
    """
    state = ObservedGraphModuleAttrs(
        is_qat=is_qat,
        qconfig_mapping=qconfig_mapping,
        node_name_to_qconfig=node_name_to_qconfig,
        node_name_to_scope=node_name_to_scope,
        equalization_node_name_to_qconfig=equalization_node_name_to_qconfig,
        prepare_custom_config=prepare_custom_config,
        observed_node_names=observed_node_names,
    )
    observed.meta["_observed_graph_module_attrs"] = state
def prepare(
    model: GraphModule,
    qconfig_mapping: QConfigMapping | dict[str, Any],
    is_qat: bool,
    node_name_to_scope: dict[str, tuple[str, type]],
    example_inputs: tuple[Any, ...],
    prepare_custom_config: PrepareCustomConfig | dict[str, Any] | None = None,
    _equalization_config: QConfigMapping | dict[str, Any] | None = None,
    backend_config: BackendConfig | dict[str, Any] | None = None,
    is_standalone_module: bool = False,
) -> GraphModule:
    """Prepare ``model`` for quantization by inserting observers / fake-quantizes.

    A standalone module is a submodule that is not inlined into its parent
    module. It is quantized separately, as one unit. How a standalone module is
    observed is specified by ``input_quantized_idxs`` and
    ``output_quantized_idxs`` in the prepare_custom_config for that module.

    Args:
        model: the traced model to prepare. It is mutated along the way; for
            example, observer insertion edits ``model.graph`` in place.
        qconfig_mapping: qconfig configuration. A plain dict is still accepted
            but deprecated: it is converted with ``QConfigMapping.from_dict``
            and a ``FutureWarning`` is emitted. The mapping is deep-copied
            before use.
        is_qat: whether to prepare for quantization-aware training. When True,
            modules are swapped for the QAT counterparts given by
            ``backend_config``.
        node_name_to_scope: mapping from node name to the scope of the module
            that contains the node. The scope is a tuple of the module's fully
            qualified path and the module's type.
        example_inputs: not referenced directly in this function body.
        prepare_custom_config: custom prepare configuration. ``None`` means a
            default ``PrepareCustomConfig()``; a dict is deprecated and
            converted.
        _equalization_config: qconfig mapping for input equalization. Per-node
            equalization qconfigs are derived from it and passed to observer
            insertion. ``None`` means an empty ``QConfigMapping()``; a dict is
            deprecated and converted. Deep-copied before use.
        backend_config: backend configuration. ``None`` means
            ``get_native_backend_config()``; a dict is deprecated and converted.
        is_standalone_module: whether ``model`` is itself a standalone module
            being prepared on behalf of a parent module.

    Returns:
        model (GraphModule): the prepared model. Prepare-time state is stored
        in ``model.meta["_observed_graph_module_attrs"]``. When
        ``is_standalone_module`` is True, these extra attributes are set there:

            is_observed_standalone_module (bool): set to True, marking the
                model as an observed standalone module.
            standalone_module_input_quantized_idxs (list[int]): indices of
                the graph inputs that are expected to be quantized; taken from
                ``prepare_custom_config.input_quantized_indexes``.
            standalone_module_output_quantized_idxs (list[int]): indices of
                the graph outputs that are quantized; taken from
                ``prepare_custom_config.output_quantized_indexes``.

    Raises:
        AssertionError: if, after normalization, ``qconfig_mapping`` or
            ``_equalization_config`` is not a ``QConfigMapping``. Also raised
            for a standalone module when no output node was found or the output
            is not a single ``Node``.
    """
    if prepare_custom_config is None:
        prepare_custom_config = PrepareCustomConfig()
    if _equalization_config is None:
        _equalization_config = QConfigMapping()

    # Dict-based configs are deprecated but still accepted. Warn, then convert
    # them to the object form. stacklevel=2 attributes the warning to our caller.
    if isinstance(qconfig_mapping, dict):
        warnings.warn(
            "Passing a QConfig dictionary to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.",
            FutureWarning,
            stacklevel=2,
        )
        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping)

    if isinstance(_equalization_config, dict):
        warnings.warn(
            "Passing a QConfig dictionary to prepare for equalization is deprecated and will not "
            "be supported in a future version. Please pass in a QConfigMapping instead.",
            FutureWarning,
            stacklevel=2,
        )
        _equalization_config = QConfigMapping.from_dict(_equalization_config)

    if isinstance(prepare_custom_config, dict):
        warnings.warn(
            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a PrepareCustomConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)

    if isinstance(backend_config, dict):
        warnings.warn(
            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        backend_config = BackendConfig.from_dict(backend_config)

    if not isinstance(qconfig_mapping, QConfigMapping):
        raise AssertionError("qconfig_mapping must be a QConfigMapping")
    if not isinstance(_equalization_config, QConfigMapping):
        raise AssertionError("_equalization_config must be a QConfigMapping")
    # Work on private copies; the _update_qconfig_for_* helpers below are
    # handed these mappings.
    qconfig_mapping = copy.deepcopy(qconfig_mapping)
    _equalization_config = copy.deepcopy(_equalization_config)

    if backend_config is None:
        backend_config = get_native_backend_config()

    # mapping from a tuple of nodes in reverse order to uninitialized
    # QuantizeHandler subclass. For example,
    # {
    #   # match a single node
    #   (<class 'torch.nn.modules.conv.Conv3d'>:
    #     <class 'torch.ao.quantization.fx.quantize.ConvRelu'>),
    #   # match multiple nodes in reverse order
    #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
    #     <class 'torch.ao.quantization.fx.quantize.Add'>),
    # }
    pattern_to_quantize_handler: dict[Pattern, QuantizeHandler] = (
        _get_pattern_to_quantize_handlers(backend_config)
    )
    pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler)
    root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)

    # pyrefly: ignore [bad-argument-type]
    _update_qconfig_for_fusion(model, qconfig_mapping)
    # pyrefly: ignore [bad-argument-type]
    _update_qconfig_for_fusion(model, _equalization_config)
    # pyrefly: ignore [bad-argument-type]
    flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())

    if is_qat:
        module_to_qat_module = get_module_to_qat_module(backend_config)
        _qat_swap_modules(model, module_to_qat_module)
        # pyrefly: ignore [bad-argument-type]
        _update_qconfig_for_qat(qconfig_mapping, backend_config)

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # Built after the optional QAT swap so it reflects the swapped modules.
    named_modules = dict(model.named_modules(remove_duplicate=False))

    # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
    equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
        model,
        named_modules,
        model.graph,
        # pyrefly: ignore [bad-argument-type]
        _equalization_config,
        node_name_to_scope,
    )
    node_name_to_qconfig = _generate_node_name_to_qconfig(
        model,
        named_modules,
        model.graph,
        # pyrefly: ignore [bad-argument-type]
        qconfig_mapping,
        node_name_to_scope,
    )

    # match the patterns that will get quantized
    standalone_module_names = list(prepare_custom_config.standalone_module_names.keys())
    standalone_module_classes = list(
        prepare_custom_config.standalone_module_classes.keys()
    )

    custom_module_classes = get_custom_module_class_keys(
        prepare_custom_config.float_to_observed_mapping
    )
    matches_without_qconfig = _find_matches(
        model.graph,
        named_modules,
        pattern_to_quantize_handler,
        root_node_getter_mapping,
        standalone_module_names,
        standalone_module_classes,
        custom_module_classes,
    )

    # map qconfig instances to matches: each match tuple gets its node's
    # qconfig appended as the last element
    node_name_to_match_result_with_qconfig = {
        node_name: (*match_without_qconfig, node_name_to_qconfig[node_name])
        for node_name, match_without_qconfig in matches_without_qconfig.items()
    }

    _run_prepare_fx_on_standalone_modules(
        model,
        is_qat,
        named_modules,
        node_name_to_match_result_with_qconfig,
        prepare_custom_config,
        backend_config,
    )

    # record names for the set of observed node, so that in convert step
    # we know whether we need to convert a floating point module to reference
    # quantized module or not
    observed_node_names: set[str] = set()

    # Returns the graph's output node (or None if none was seen).
    result_node = insert_observers_for_model(
        model,
        node_name_to_match_result_with_qconfig,
        node_name_to_qconfig,
        prepare_custom_config,
        equalization_node_name_to_qconfig,
        backend_config,
        observed_node_names,
        is_qat,
    )
    model = GraphModule(model, model.graph)

    _save_state(
        model,
        node_name_to_qconfig,
        node_name_to_scope,
        prepare_custom_config,
        equalization_node_name_to_qconfig,
        # pyrefly: ignore [bad-argument-type]
        qconfig_mapping,
        is_qat,
        observed_node_names,
    )

    if is_standalone_module:
        if result_node is None:
            raise AssertionError("result_node must not be None for standalone modules")
        if not isinstance(result_node.args[0], Node):
            raise AssertionError(
                "standalone module only supports returning simple value currently (not tuple, dict etc.)"
            )
        # These inputs are observed in the parent module. Record which graph
        # inputs/outputs of this standalone module are expected to be
        # quantized; the lists are stored as-is on the attrs object.
        input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes
        output_quantized_idxs: list[int] = (
            prepare_custom_config.output_quantized_indexes
        )
        observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
        # inplace modification
        observed_graph_module_attrs.is_observed_standalone_module = True
        observed_graph_module_attrs.standalone_module_input_quantized_idxs = (
            input_quantized_idxs
        )
        observed_graph_module_attrs.standalone_module_output_quantized_idxs = (
            output_quantized_idxs
        )
    return model
|