prepare.py 86 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import warnings
  4. from typing import Any
  5. import torch
  6. from torch._subclasses import FakeTensor
  7. from torch.ao.quantization import (
  8. ObserverBase,
  9. ObserverOrFakeQuantize,
  10. PlaceholderObserver,
  11. )
  12. from torch.ao.quantization.backend_config import (
  13. BackendConfig,
  14. DTypeConfig,
  15. get_native_backend_config,
  16. )
  17. from torch.ao.quantization.backend_config.utils import (
  18. get_fusion_pattern_to_root_node_getter,
  19. get_module_to_qat_module,
  20. get_pattern_to_dtype_configs,
  21. )
  22. from torch.ao.quantization.observer import _is_activation_post_process
  23. from torch.ao.quantization.qconfig import _is_reuse_input_qconfig, QConfigAny
  24. from torch.ao.quantization.qconfig_mapping import QConfigMapping
  25. from torch.ao.quantization.quantize import convert, propagate_qconfig_
  26. from torch.ao.quantization.utils import (
  27. _parent_name,
  28. get_qconfig_dtypes,
  29. get_swapped_custom_module_class,
  30. NodePattern,
  31. Pattern,
  32. )
  33. from torch.fx import GraphModule
  34. from torch.fx.graph import Graph, Node
  35. from torch.fx.node import Argument
  36. from ._equalize import is_equalization_observer, node_supports_equalization
  37. from .custom_config import PrepareCustomConfig, StandaloneModuleConfigEntry
  38. from .match_utils import _find_matches, _MatchResultWithQConfig
  39. from .pattern_utils import _sorted_patterns_dict
  40. from .qconfig_mapping_utils import (
  41. _generate_node_name_to_qconfig,
  42. _get_flattened_qconfig_dict,
  43. _update_qconfig_for_fusion,
  44. _update_qconfig_for_qat,
  45. )
  46. from .quantize_handler import (
  47. _default_root_node_getter,
  48. _get_pattern_to_quantize_handlers,
  49. QuantizeHandler,
  50. )
  51. from .utils import (
  52. _insert_dequant_stubs_for_custom_module_lstm_output,
  53. _is_custom_module_lstm,
  54. _maybe_get_custom_module_lstm_from_node_arg,
  55. _qconfig_satisfies_dtype_config_constraints,
  56. all_node_args_have_no_tensors,
  57. assert_and_get_unique_device,
  58. get_custom_module_class_keys,
  59. get_new_attr_name_with_prefix,
  60. get_non_observable_arg_indexes_and_types,
  61. node_arg_is_bias,
  62. node_arg_is_weight,
  63. NON_QUANTIZABLE_WEIGHT_OPS,
  64. ObservedGraphModuleAttrs,
  65. )
  66. __all__ = [
  67. "insert_observers_for_model",
  68. "prepare",
  69. "propagate_dtypes_for_known_nodes",
  70. ]
  71. # list of dtypes to not add observers to
  72. _DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]
  73. _OBS_DTYPE_LIST = [
  74. torch.quint8,
  75. torch.qint8,
  76. torch.qint32,
  77. torch.float16,
  78. torch.uint8,
  79. torch.int8,
  80. torch.int16,
  81. torch.int32,
  82. torch.float8_e5m2,
  83. torch.float8_e4m3fn,
  84. ]
  85. _DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)
  86. # note: the following default target dtype info dicts are temporary,
  87. # should be moved to the new programmable API class soon
  88. _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = {
  89. "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
  90. "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
  91. }
  92. _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = {
  93. "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
  94. "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
  95. }
  96. def _needs_obs_or_fq(
  97. prev_output_dtype: Any,
  98. prev_output_is_dynamic: bool,
  99. cur_target_dtype: Any,
  100. cur_target_is_dynamic: bool,
  101. reuse_input_obs_or_fq: bool,
  102. is_zeroth_arg: bool = False,
  103. ) -> bool:
  104. """
  105. note: we will treat "not specified" as torch.float for now
  106. utility function that checks if we should insert an observer or fake quant node
  107. base on the requested dtype for the nodes from user
  108. is_zeroth_arg: we only dynamically quantize the first arg of the node right now
  109. this should be removed when we enable configuring dynamic quantization
  110. for a specific argument, this can be removed if we deprecate fx graph mode
  111. quantization
  112. """
  113. # need to insert placeholder observer for dynamic quantization so that it can
  114. # be converted to choose_qparams -> q -> dq in convert step
  115. if cur_target_is_dynamic:
  116. if cur_target_dtype not in _OBS_DTYPE_LIST:
  117. raise AssertionError(
  118. f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}"
  119. )
  120. if prev_output_dtype in _DO_NOT_OBS_DTYPE_LIST:
  121. raise AssertionError(
  122. "prev_output_dtype must not be in _DO_NOT_OBS_DTYPE_LIST"
  123. )
  124. return is_zeroth_arg
  125. if reuse_input_obs_or_fq:
  126. return False
  127. # non dynamic quantization
  128. if cur_target_dtype in _OBS_DTYPE_LIST:
  129. return (
  130. prev_output_dtype in _OBS_DTYPE_LIST + [torch.float]
  131. and cur_target_dtype != prev_output_dtype
  132. )
  133. # lots of error checking are skipped here for now
  134. return False
  135. def _is_activation_post_process_node(
  136. node: Node, named_modules: dict[str, torch.nn.Module]
  137. ) -> bool:
  138. return (
  139. isinstance(node, torch.fx.Node)
  140. and node.op == "call_module"
  141. and _is_activation_post_process(named_modules[str(node.target)])
  142. )
  143. def _get_dtype_and_is_dynamic(
  144. obs_or_fq: ObserverOrFakeQuantize | None,
  145. ) -> tuple[torch.dtype | None, bool]:
  146. """Given a constructor for observer or fake quant module, returns
  147. a Tuple of dtype and is_dynamic
  148. """
  149. # TODO: instead of instantiating the instance, we can use inspect to get the default args
  150. if obs_or_fq is None:
  151. return None, False
  152. else:
  153. return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False) # type: ignore[return-value]
  154. def _is_input_arg_dtype_supported_by_backend(
  155. arg: Argument,
  156. node: Node,
  157. qconfig: QConfigAny,
  158. dtype_config: DTypeConfig,
  159. backend_config: BackendConfig,
  160. ) -> bool:
  161. """Check if the configured qconfig for the argument
  162. is supported by the backend or not
  163. """
  164. if isinstance(arg, (list, tuple)): # noqa: UP038
  165. return all(
  166. _is_input_arg_dtype_supported_by_backend(
  167. # pyrefly: ignore [bad-argument-type]
  168. a,
  169. node,
  170. qconfig,
  171. dtype_config,
  172. backend_config,
  173. )
  174. for a in arg
  175. )
  176. if not isinstance(arg, Node):
  177. return True
  178. # TODO: support check for standalone module
  179. is_weight = node_arg_is_weight(node, arg)
  180. is_bias = node_arg_is_bias(node, arg)
  181. is_activation = not is_weight and not is_bias
  182. if is_activation:
  183. input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
  184. "input_act_obs_or_fq_ctr"
  185. )
  186. input_act_obs_or_fq = (
  187. input_act_obs_or_fq_ctr() if input_act_obs_or_fq_ctr else None
  188. )
  189. qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic(
  190. input_act_obs_or_fq
  191. )
  192. # TODO(future PR): remove the cast to bool below after figuring
  193. # out why backend_config has is_dynamic set to None in some cases.
  194. return (dtype_config.input_dtype is None) or (
  195. dtype_config.input_dtype == qconfig_dtype
  196. and bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic)
  197. and _qconfig_satisfies_dtype_config_constraints(
  198. qconfig, dtype_config.input_dtype_with_constraints
  199. )
  200. )
  201. elif is_weight:
  202. # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
  203. weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
  204. "weight_obs_or_fq_ctr", None
  205. )
  206. weight_obs_or_fq = weight_obs_or_fq_ctr() if weight_obs_or_fq_ctr else None
  207. qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq)
  208. backend_config_weight_dtype = dtype_config.weight_dtype
  209. dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype
  210. qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
  211. qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False
  212. )
  213. return backend_config_weight_dtype is None or (
  214. dtype_matches and qconfig_satisfies_constraints
  215. )
  216. else: # bias
  217. # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
  218. bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
  219. "bias_obs_or_fq_ctr", None
  220. )
  221. bias_obs_or_fq = bias_obs_or_fq_ctr() if bias_obs_or_fq_ctr else None
  222. qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq)
  223. backend_config_bias_dtype = dtype_config.bias_dtype
  224. return (
  225. backend_config_bias_dtype is None
  226. or qconfig_bias_dtype == backend_config_bias_dtype
  227. )
  228. def _is_output_dtype_supported_by_backend(
  229. node: Node,
  230. qconfig: QConfigAny,
  231. dtype_config: DTypeConfig,
  232. ) -> bool:
  233. """Check if the configured qconfig for the output
  234. is supported by the backend or not
  235. """
  236. # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
  237. backend_config_output_dtype = dtype_config.output_dtype
  238. # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend
  239. # from input activation check can be reused here
  240. qconfig_output_dtype = None
  241. output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
  242. "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
  243. )
  244. output_act_obs_or_fq = (
  245. output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
  246. )
  247. qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic(
  248. output_act_obs_or_fq
  249. )
  250. # TODO: this is a hack because we can only specify one activation_obs_or_fq for
  251. # qconfig (qconfig.activation), and we are only supporting dynamically quantized
  252. # linear op which has fp32 output dtype, this should be removed if we generalize
  253. # the structure of qconfig in the future
  254. if qconfig_output_is_dynamic:
  255. qconfig_output_dtype = torch.float32
  256. dtype_matches = qconfig_output_dtype == backend_config_output_dtype
  257. qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
  258. qconfig, dtype_config.output_dtype_with_constraints
  259. )
  260. return backend_config_output_dtype is None or (
  261. dtype_matches and qconfig_satisfies_constraints
  262. )
  263. from typing import Annotated
  264. from torch.fx import Node
  265. EdgeOrNode = Annotated[tuple[Node, Node] | Node, None]
  266. def _is_observer_in_same_graph(
  267. node: Node,
  268. named_modules: dict[str, torch.nn.Module],
  269. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
  270. is_qat,
  271. ):
  272. """Check if observer in same graph
  273. when the node output is not fp32 and input is 'placeholder'
  274. the input is assumed to be quantized, so it is observed
  275. in a different place rather than not observed.
  276. """
  277. node_output_dtype = _get_arg_target_dtype_as_output(
  278. node, named_modules, obs_or_fq_map, is_qat
  279. )
  280. if len(node.args) > 0 and isinstance(node.args[0], Node):
  281. if (
  282. node_output_dtype in [torch.quint8, torch.uint8]
  283. and node.args[0].op == "placeholder"
  284. ):
  285. return False
  286. return True
  287. def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
  288. pattern: Pattern | None,
  289. matched_node_pattern: list[Node] | None,
  290. qconfig: QConfigAny,
  291. backend_config: BackendConfig,
  292. ) -> bool:
  293. """Check if the dtype configuration of a pattern is supported by
  294. the backend or not, and whether the qconfig satisfies constraints
  295. specified in the corresponding dtype config.
  296. """
  297. if backend_config is None or pattern is None:
  298. return True
  299. if matched_node_pattern is None or len(matched_node_pattern) < 1:
  300. raise AssertionError("matched_node_pattern must be non-empty")
  301. pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
  302. dtype_configs: list[DTypeConfig] = pattern_to_dtype_configs.get(pattern, [])
  303. pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
  304. root_node_getter = pattern_to_root_node_getter.get(
  305. pattern, _default_root_node_getter
  306. )
  307. root_node = root_node_getter(matched_node_pattern)
  308. input_node = root_node
  309. output_node = matched_node_pattern[0]
  310. for dtype_config in dtype_configs:
  311. # check if arg dtype are supported
  312. supported = True
  313. for arg in list(input_node.args) + list(input_node.kwargs.values()):
  314. supported = supported and _is_input_arg_dtype_supported_by_backend(
  315. arg, input_node, qconfig, dtype_config, backend_config
  316. )
  317. # check if output dtype is supported
  318. supported = supported and _is_output_dtype_supported_by_backend(
  319. output_node, qconfig, dtype_config
  320. )
  321. if supported:
  322. return True
  323. return False
  324. def _get_standalone_module_configs(
  325. node: Node,
  326. named_modules: dict[str, torch.nn.Module],
  327. prepare_custom_config: PrepareCustomConfig,
  328. parent_qconfig: QConfigAny,
  329. parent_backend_config: BackendConfig | None,
  330. ) -> tuple[QConfigMapping, tuple[Any, ...], PrepareCustomConfig, BackendConfig | None]:
  331. """
  332. Returns the standalone module QConfigMapping and PrepareCustomConfig
  333. for `node`, assuming that the module pointed to by `node` is
  334. a standalone modules.
  335. """
  336. module_name = str(node.target)
  337. module_type = type(named_modules[module_name]) # type: ignore[index]
  338. # name config has precedence over type config
  339. config_entry = StandaloneModuleConfigEntry(None, (), None, None)
  340. config_entry = prepare_custom_config.standalone_module_classes.get(
  341. module_type, config_entry
  342. )
  343. config_entry = prepare_custom_config.standalone_module_names.get(
  344. module_name, config_entry
  345. )
  346. # fallback to use parent module's qconfig if user didn't specify qconfig dict
  347. qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global(
  348. parent_qconfig
  349. )
  350. example_inputs = config_entry.example_inputs
  351. prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig()
  352. backend_config = config_entry.backend_config or parent_backend_config
  353. return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
  354. def _qat_swap_modules(
  355. root: torch.nn.Module, module_to_qat_module: dict[Pattern, type[torch.nn.Module]]
  356. ) -> None:
  357. convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False)
  358. def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: set[str]):
  359. if isinstance(matched_node_pattern, Node):
  360. s.add(matched_node_pattern.name)
  361. elif isinstance(matched_node_pattern, (list, tuple)): # noqa: UP038
  362. for maybe_node in matched_node_pattern:
  363. _add_matched_node_name_to_set(maybe_node, s)
  364. def _insert_obs_or_fq(
  365. node: Node,
  366. obs_or_fq: ObserverOrFakeQuantize,
  367. model: torch.nn.Module,
  368. named_modules: dict[str, torch.nn.Module],
  369. graph: Graph,
  370. model_device: torch.device | None = None,
  371. ) -> Node:
  372. """
  373. Attaches `obs_or_fq` to `model`, and creates a node which calls
  374. `obs_or_fq` on the output of `node`.
  375. obs_or_fq: an instance of Observer or FakeQuantize module
  376. """
  377. if model_device is None:
  378. model_device = assert_and_get_unique_device(model)
  379. if model_device:
  380. obs_or_fq.to(model_device)
  381. # add obs_or_fq module as attribute
  382. if is_equalization_observer(obs_or_fq):
  383. prefix = node.name + "_equalization_process_"
  384. else:
  385. prefix = "activation_post_process_"
  386. get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix)
  387. obs_or_fq_name = get_new_obs_or_fq_name(model)
  388. setattr(model, obs_or_fq_name, obs_or_fq)
  389. named_modules[obs_or_fq_name] = obs_or_fq
  390. with graph.inserting_after(node):
  391. new_obs = graph.create_node("call_module", obs_or_fq_name, (node,), {})
  392. return new_obs
  393. def _set_target_dtype_info_for_matched_node_pattern(
  394. matched_node_pattern: NodePattern,
  395. last_node: Node,
  396. qconfig: QConfigAny,
  397. qhandler: QuantizeHandler | None,
  398. backend_config: BackendConfig,
  399. named_modules: dict[str, torch.nn.Module],
  400. cache_for_no_tensor_check: dict[Node, bool],
  401. processed_nodes: set[Node],
  402. ) -> None:
  403. """Sets the target_dtype_info for each node in matched_node_pattern
  404. Note: processed_nodes is used to ensure we only process each node once
  405. """
  406. if isinstance(matched_node_pattern, (list, tuple)): # noqa: UP038
  407. for node_pattern in matched_node_pattern:
  408. _set_target_dtype_info_for_matched_node_pattern(
  409. node_pattern,
  410. last_node,
  411. qconfig,
  412. qhandler,
  413. backend_config,
  414. named_modules,
  415. cache_for_no_tensor_check,
  416. processed_nodes,
  417. )
  418. # set target_dtype_info if matched_node_pattern is a Node
  419. # other types of matched object, e.g. int, float literals, are ignored
  420. elif isinstance(matched_node_pattern, Node):
  421. # for pyre
  422. if not isinstance(matched_node_pattern, Node):
  423. raise AssertionError("matched_node_pattern must be a Node")
  424. node = matched_node_pattern
  425. if node in processed_nodes:
  426. return
  427. processed_nodes.add(node)
  428. if qconfig is None:
  429. return
  430. # TODO: refactor the following code in terms of apply a qconfig to a pattern
  431. # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1)
  432. # we set the input_obs_or_fq_ctr for the arguments of op1 to based on qconfig.input_act,
  433. # and set output_obs_or_fq_ctr based on qconfig.output_act
  434. # this also requires we extend the structure of QConfig to support more fine
  435. # grained configurations
  436. target_dtype_info: dict[str, Any] = _get_target_activation_dtype_for_node(
  437. node,
  438. qconfig,
  439. qhandler,
  440. named_modules,
  441. backend_config,
  442. cache_for_no_tensor_check,
  443. )
  444. node.meta["target_dtype_info"] = target_dtype_info
  445. def _get_target_activation_dtype_for_node(
  446. node: Node,
  447. qconfig: QConfigAny,
  448. qhandler: QuantizeHandler | None,
  449. named_modules: dict[str, torch.nn.Module],
  450. backend_config: BackendConfig,
  451. cache_for_no_tensor_check: dict[Node, bool],
  452. ) -> dict[str, Any]:
  453. """
  454. For each op attribute in the op's input activation, output activation,
  455. weight, bias - returns the settings of dtype and is_dynamic we expect
  456. for the `quantize` call in the reference model representation, or None
  457. if there is no `quantize` call needed.
  458. For example, if we have a node corresponding to `op0` in
  459. x0 -> op0 -> x1
  460. And we want a reference quantized representation to be
  461. x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1
  462. Then this function will return
  463. {
  464. "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
  465. "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
  466. }
  467. TODO(future PR, if needed): explicitly spell out the non-Tensor
  468. dtypes.
  469. """
  470. args_have_no_tensors = all_node_args_have_no_tensors(
  471. node, named_modules, cache_for_no_tensor_check
  472. )
  473. if args_have_no_tensors:
  474. return {
  475. "input_act_obs_or_fq_ctr": None,
  476. "output_act_obs_or_fq_ctr": None,
  477. }
  478. # get qconfig to determine the eventual dtype of this node
  479. if qconfig is not None:
  480. act_dtype, weight_dtype, input_act_is_dynamic = get_qconfig_dtypes(qconfig)
  481. # Currently `QConfig` only has one `activation` field.
  482. # For static quantization, it is reused for both input
  483. # and output activation. For dynamic quantization, this
  484. # field is currently only used for the input activation,
  485. # with the output activation being in fp32.
  486. # In the future this may change as we add more fields
  487. # to the `QConfig` object.
  488. bias_dtype = (
  489. torch.float16
  490. if (
  491. act_dtype == torch.float16
  492. and weight_dtype == torch.float16
  493. and (not input_act_is_dynamic)
  494. )
  495. else torch.float
  496. )
  497. is_general_tensor_value_op = (
  498. qhandler is not None and qhandler.is_general_tensor_value_op()
  499. )
  500. _is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
  501. weight_index = None
  502. if (
  503. isinstance(node, Node)
  504. and node.op == "call_function"
  505. and node.target in backend_config._pattern_complex_format_to_config
  506. ):
  507. weight_index = backend_config._pattern_complex_format_to_config[
  508. node.target
  509. ]._input_type_to_index.get("weight")
  510. bias_index = None
  511. if (
  512. isinstance(node, Node)
  513. and node.op == "call_function"
  514. and node.target in backend_config._pattern_complex_format_to_config
  515. ):
  516. bias_index = backend_config._pattern_complex_format_to_config[
  517. node.target
  518. ]._input_type_to_index.get("bias")
  519. return {
  520. "input_act_obs_or_fq_ctr": qconfig.activation,
  521. "weight_obs_or_fq_ctr": qconfig.weight,
  522. "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype),
  523. "weight_index": weight_index,
  524. "bias_index": bias_index,
  525. "output_act_obs_or_fq_ctr": qconfig.activation,
  526. "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig),
  527. "input_output_share_observers": is_general_tensor_value_op,
  528. "_is_standalone_module": _is_standalone_module,
  529. }
  530. return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)
  531. def _get_output_act_obs_or_fq(
  532. arg: Node,
  533. named_modules: dict[str, torch.nn.Module],
  534. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
  535. is_qat: bool,
  536. ) -> ObserverOrFakeQuantize | None:
  537. """Get the constructor for observer or fake quant object for
  538. the argument in the original graph as the output of previous node,
  539. skipping inserted observers
  540. We are assuming that the observers are inserted correctly, and the dtype for
  541. argument in quantized graph will match what is specified by the qconfig
  542. """
  543. if not isinstance(arg, Node):
  544. raise AssertionError("arg must be a Node")
  545. if "quantization_annotation" in arg.meta:
  546. raise NotImplementedError(
  547. "Please use torchao (https://github.com/pytorch/ao) for pt2e quantization flow"
  548. )
  549. # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
  550. # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
  551. # Since we modified the graph in this case, we must trace back from the args through
  552. # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
  553. # not be able to accurately detect whether this node is a consumer of custom module LSTM.
  554. custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(
  555. arg, named_modules
  556. )
  557. output_act_obs_or_fq_ctr = None
  558. if custom_module_lstm_node is not None:
  559. output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"][
  560. "output_act_obs_or_fq_ctr"
  561. ]
  562. output_act_obs_or_fq = (
  563. output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
  564. )
  565. elif _is_activation_post_process_node(arg, named_modules):
  566. observed_arg = arg.args[0]
  567. if not isinstance(observed_arg, Node):
  568. raise AssertionError("Currently we only support observing Node")
  569. if "quantization_annotation" in observed_arg.meta:
  570. raise NotImplementedError(
  571. "Please use torchao (https://github.com/pytorch/ao) for pt2e quantization flow"
  572. )
  573. if "target_dtype_info" not in observed_arg.meta:
  574. raise AssertionError("expected 'target_dtype_info' in observed_arg.meta")
  575. output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"][
  576. "output_act_obs_or_fq_ctr"
  577. ]
  578. output_act_obs_or_fq = (
  579. output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
  580. )
  581. else:
  582. if "target_dtype_info" in arg.meta:
  583. output_act_obs_or_fq_ctr = arg.meta["target_dtype_info"].get(
  584. "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
  585. )
  586. else:
  587. output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
  588. output_act_obs_or_fq = (
  589. output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
  590. )
  591. return output_act_obs_or_fq
  592. def _get_arg_target_dtype_as_output(
  593. arg: Node,
  594. named_modules: dict[str, torch.nn.Module],
  595. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
  596. is_qat: bool,
  597. ) -> torch.dtype | None:
  598. arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(
  599. arg, named_modules, obs_or_fq_map, is_qat
  600. )
  601. arg_as_output_target_dtype, _ = _get_dtype_and_is_dynamic(
  602. arg_as_output_act_obs_or_fq
  603. )
  604. return arg_as_output_target_dtype
  605. def _get_arg_as_input_act_obs_or_fq(
  606. arg: Node,
  607. node: Node,
  608. named_modules: dict[str, torch.nn.Module],
  609. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
  610. is_qat: bool,
  611. ) -> ObserverOrFakeQuantize | None:
  612. """Get the observer or fake quant constructor for the Argument `arg`, as input
  613. to Node `node`
  614. """
  615. if not isinstance(arg, Node):
  616. raise AssertionError("arg must be a Node")
  617. if "quantization_annotation" in node.meta:
  618. raise NotImplementedError(
  619. "Please use torchao (https://github.com/pytorch/ao) for pt2e quantization flow"
  620. )
  621. # we can remove the following path in the future if fx graph mode quantization is
  622. # no longer used
  623. is_weight = node_arg_is_weight(node, arg)
  624. is_bias = node_arg_is_bias(node, arg)
  625. is_activation = not is_weight and not is_bias
  626. obs_or_fq_ctr = None
  627. if is_activation:
  628. obs_or_fq_ctr = node.meta["target_dtype_info"].get(
  629. "input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
  630. )
  631. elif is_weight:
  632. if node.target not in NON_QUANTIZABLE_WEIGHT_OPS:
  633. obs_or_fq_ctr = node.meta["target_dtype_info"].get(
  634. "weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
  635. )
  636. else:
  637. obs_or_fq_ctr = node.meta["target_dtype_info"].get(
  638. "bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
  639. )
  640. return obs_or_fq_ctr() if obs_or_fq_ctr else None
  641. def _maybe_insert_input_observer_for_arg_or_kwarg(
  642. node: Node | Any,
  643. arg: Argument,
  644. qconfig: QConfigAny,
  645. model: torch.nn.Module,
  646. named_modules: dict[str, torch.nn.Module],
  647. graph: Graph,
  648. qhandler: QuantizeHandler | None,
  649. prepare_custom_config: PrepareCustomConfig,
  650. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
  651. is_qat: bool,
  652. backend_config: BackendConfig | None = None,
  653. model_device: torch.device | None = None,
  654. ) -> Argument:
  655. """
  656. Given a `node` and an `arg`, inserts an input observer between
  657. `node` and `arg` if necessary.
  658. """
  659. # for ops such as torch.cat([x0, x1]),
  660. # traverse through the list
  661. if isinstance(arg, (list, tuple)): # noqa: UP038
  662. new_arg_to_return = []
  663. for inner_arg in arg:
  664. new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
  665. node,
  666. # pyrefly: ignore [bad-argument-type]
  667. inner_arg,
  668. qconfig,
  669. model,
  670. named_modules,
  671. graph,
  672. qhandler,
  673. prepare_custom_config,
  674. obs_or_fq_map,
  675. is_qat,
  676. backend_config,
  677. model_device,
  678. )
  679. new_arg_to_return.append(new_inner_arg)
  680. return type(arg)(new_arg_to_return)
  681. if not isinstance(arg, Node):
  682. return arg
  683. if not isinstance(arg, Node):
  684. raise AssertionError("arg must be a Node")
  685. # default (no observer)
  686. new_arg = arg
  687. is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
  688. # TODO: move this to a separate function
  689. if not is_standalone_module:
  690. # Note: qconfig can be None in this branch this we are getting act/fq from
  691. # node.meta now
  692. # regular flow for most nodes, except standalone modules
  693. if "quantization_annotation" in node.meta:
  694. raise NotImplementedError(
  695. "Please use torchao (https://github.com/pytorch/ao) for pt2e quantization flow"
  696. )
  697. if "target_dtype_info" not in node.meta:
  698. raise AssertionError("expected 'target_dtype_info' in node.meta")
  699. # TODO: we are assuming "target_dtype_info" exists here, maybe
  700. # a default value also need to be provided here
  701. target_dtype_info = node.meta["target_dtype_info"]
  702. # for nodes that doesn't have `reuse_input_obs_or_fq` configured,
  703. # we'll default to False, this makes configuring this field optional for users
  704. reuse_input_obs_or_fq = target_dtype_info.get("reuse_input_obs_or_fq", False)
  705. arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(
  706. arg, node, named_modules, obs_or_fq_map, is_qat
  707. )
  708. (
  709. arg_as_input_target_dtype,
  710. arg_as_input_target_is_dynamic,
  711. ) = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq)
  712. arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(
  713. arg, named_modules, obs_or_fq_map, is_qat
  714. )
  715. (
  716. arg_as_output_target_dtype,
  717. arg_as_output_target_is_dynamic,
  718. ) = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)
  719. needs_obs_or_fq = _needs_obs_or_fq(
  720. arg_as_output_target_dtype,
  721. arg_as_output_target_is_dynamic,
  722. arg_as_input_target_dtype,
  723. arg_as_input_target_is_dynamic,
  724. reuse_input_obs_or_fq,
  725. is_zeroth_arg=len(node.args) > 0 and arg is node.args[0],
  726. )
  727. else:
  728. if qconfig is None:
  729. raise AssertionError("qconfig must not be None")
  730. # custom flow for standalone modules
  731. _, _, sm_prepare_custom_config, _ = _get_standalone_module_configs(
  732. node, named_modules, prepare_custom_config, qconfig, backend_config
  733. )
  734. sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes
  735. # for args, this is set to the index of the current arg
  736. # for kwargs, this is left at None
  737. cur_input_idx = None
  738. for arg_idx, arg_to_check in enumerate(node.args):
  739. if arg_to_check is arg:
  740. cur_input_idx = arg_idx
  741. break
  742. if cur_input_idx is None:
  743. needs_obs_or_fq = False
  744. else:
  745. arg_as_output_target_dtype = _get_arg_target_dtype_as_output(
  746. arg, named_modules, obs_or_fq_map, is_qat
  747. )
  748. arg_as_input_target_dtype = (
  749. torch.quint8
  750. if cur_input_idx in sm_input_quantized_idxs
  751. else torch.float
  752. )
  753. needs_obs_or_fq = (
  754. arg_as_output_target_dtype != arg_as_input_target_dtype
  755. ) and (arg_as_input_target_dtype != torch.float)
  756. act_post_process_ctr = qconfig.activation
  757. arg_as_input_act_obs_or_fq = (
  758. act_post_process_ctr() if act_post_process_ctr else None
  759. )
  760. if needs_obs_or_fq:
  761. existing_obs_node = None
  762. # Before using the new observer, check if an observer
  763. # of the correct type already exists. If it does, use it.
  764. # This prevents duplicate observer insertions if a node is
  765. # used by multiple nodes.
  766. # TODO: this is looking into how the value is used in the future
  767. # we should remove this
  768. # removing this means we insert one observer for each use, even if they
  769. # have the same dtype, we can have an extra pass that removes the extra observers
  770. for maybe_obs_node in arg.users:
  771. if maybe_obs_node.op == "call_module":
  772. maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index]
  773. if (
  774. type(maybe_obs_mod) is type(arg_as_input_act_obs_or_fq)
  775. and maybe_obs_mod.dtype == arg_as_input_target_dtype # type: ignore[possibly-undefined]
  776. ):
  777. arg_as_input_act_obs_or_fq = maybe_obs_mod # type: ignore[assignment]
  778. existing_obs_node = maybe_obs_node
  779. break
  780. if arg_as_input_act_obs_or_fq is None:
  781. raise AssertionError("arg_as_input_act_obs_or_fq must not be None")
  782. obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
  783. if existing_obs_node is None:
  784. new_obs_node = _insert_obs_or_fq(
  785. arg,
  786. arg_as_input_act_obs_or_fq,
  787. model,
  788. named_modules,
  789. graph,
  790. model_device,
  791. )
  792. # override this arg to be the observed arg
  793. new_arg = new_obs_node
  794. else:
  795. new_arg = existing_obs_node
  796. return new_arg
  797. def _maybe_insert_input_observers_for_node(
  798. node: Node,
  799. qconfig: QConfigAny,
  800. model: torch.nn.Module,
  801. named_modules: dict[str, torch.nn.Module],
  802. graph: Graph,
  803. qhandler: QuantizeHandler | None,
  804. prepare_custom_config: PrepareCustomConfig,
  805. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
  806. is_qat: bool,
  807. backend_config: BackendConfig | None = None,
  808. model_device: torch.device | None = None,
  809. ) -> None:
  810. """
  811. If needed, inserts observers to the input args and kwargs of `node`.
  812. Note: modifies `node` inplace.
  813. For example, if cur_node needs an observer after prev_node, we change from
  814. prev_node -> cur_node
  815. To
  816. prev_node -> obs -> cur_node
  817. Note: backend_config only needed for standalone_module node
  818. """
  819. # Look through every input arg. If that arg's target dtype does not
  820. # match the current node's target dtype, insert an observer.
  821. new_args = []
  822. for arg in node.args:
  823. new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
  824. node,
  825. arg,
  826. qconfig,
  827. model,
  828. named_modules,
  829. graph,
  830. qhandler,
  831. prepare_custom_config,
  832. obs_or_fq_map,
  833. is_qat,
  834. backend_config,
  835. model_device,
  836. )
  837. new_args.append(new_arg)
  838. new_kwargs = {}
  839. for k, kwarg in node.kwargs.items():
  840. new_kwarg = _maybe_insert_input_observer_for_arg_or_kwarg(
  841. node,
  842. kwarg,
  843. qconfig,
  844. model,
  845. named_modules,
  846. graph,
  847. qhandler,
  848. prepare_custom_config,
  849. obs_or_fq_map,
  850. is_qat,
  851. backend_config,
  852. model_device,
  853. )
  854. new_kwargs[k] = new_kwarg
  855. # assign the new args and kwargs to the node, inplace
  856. node.args = tuple(new_args)
  857. node.kwargs = new_kwargs
  858. def _maybe_insert_input_equalization_observers_for_node(
  859. node: Node,
  860. equalization_qconfig: Any,
  861. model: torch.nn.Module,
  862. named_modules: dict[str, torch.nn.Module],
  863. graph: Graph,
  864. is_branch: bool,
  865. ) -> None:
  866. """
  867. If `node` needs to be equalized, find the input/weight observers it needs in
  868. `equalization_qconfig`, creates them, and inserts it into `graph`.
  869. If `node` does not need an equalization observer, returns None.
  870. """
  871. if equalization_qconfig is None or not node_supports_equalization(
  872. node, named_modules
  873. ):
  874. return
  875. if is_branch:
  876. warnings.warn(
  877. f"Cannot equalize {node} because it is part of a branch.", stacklevel=2
  878. )
  879. return
  880. new_args = []
  881. for arg in node.args:
  882. if not isinstance(arg, Node) or node_arg_is_bias(node, arg):
  883. new_args.append(arg)
  884. continue
  885. is_weight = node_arg_is_weight(node, arg)
  886. act_eq_process_ctr = (
  887. equalization_qconfig.weight
  888. if is_weight
  889. else equalization_qconfig.input_activation
  890. )
  891. new_eq_obs_mod = act_eq_process_ctr()
  892. new_eq_obs_node = _insert_obs_or_fq(
  893. arg, new_eq_obs_mod, model, named_modules, graph
  894. )
  895. new_args.append(new_eq_obs_node)
  896. # assign the new args and kwargs to the node, inplace
  897. node.args = tuple(new_args)
  898. def _maybe_insert_output_observer_for_node(
  899. node: Node,
  900. model: torch.nn.Module,
  901. named_modules: dict[str, torch.nn.Module],
  902. graph: Graph,
  903. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
  904. is_qat: bool,
  905. ) -> Node | None:
  906. """
  907. If `node` needs an output observer, creates it, inserts it into `graph`
  908. and returns it.
  909. If `node` does not need an output observer, returns None.
  910. Note: inserting dynamic quantization ops for output is not supported in fx graph mode
  911. quantization code path right now
  912. """
  913. if node.op == "output":
  914. raise AssertionError("observer insertion for outputs is handled elsewhere")
  915. is_standalone_module = False
  916. if "quantization_annotation" in node.meta:
  917. raise NotImplementedError(
  918. "Please use torchao (https://github.com/pytorch/ao) for pt2e quantization flow"
  919. )
  920. if "target_dtype_info" not in node.meta:
  921. raise AssertionError("expected 'target_dtype_info' in node.meta")
  922. is_standalone_module = node.meta["target_dtype_info"].get(
  923. "_is_standalone_module", False
  924. )
  925. output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
  926. "output_act_obs_or_fq_ctr"
  927. )
  928. output_act_obs_or_fq = (
  929. output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
  930. )
  931. target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq)
  932. # uncomment after we support reuse_input_obs_or_fq properly by having separate
  933. # implementations for this key instead of reusing the input_output_share_observers
  934. # code
  935. # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False)
  936. # for now we set this to False since reuse_input_obs_or_fq for
  937. # the output of a node is implementation in the same code path as observer sharing,
  938. # we should refactor this part to make it clearer in the future
  939. # and we would be able to read this from config directly
  940. reuse_input_obs_or_fq = False
  941. # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False
  942. # because the prev_output is the output of an fp32 op, although technically
  943. # we should get the dtype of the output from node.meta["val"] in the future
  944. # if we deprecate fx graph mode quantization
  945. needs_obs_or_fq = _needs_obs_or_fq(
  946. torch.float, False, target_dtype, target_is_dynamic, reuse_input_obs_or_fq
  947. )
  948. # currently the activation in QConfig(activation=...,) is for both input
  949. # and output, and when the activation is configured to be dynamic quantization
  950. # e.g. PlaceholderObserver(dtype=torch.quint8, is_dynamic=True, ...), it means
  951. # the input should by dynamically quantized, but output should not be quantized
  952. #
  953. # there is no way we can specify different observer/fq for input and output
  954. # activation through QConfig today, this limitation is lifted in the
  955. # quantizer/annotation API in pytorch 2.0 export quantization code path,
  956. # but since this code is reused, annotating output to be dynamically quantized
  957. # would not work either for that.
  958. # we can change QConfig to support input/output activation if we want
  959. # to remove the following check, or if we can deprecate fx graph mode quantization
  960. if target_is_dynamic:
  961. needs_obs_or_fq = False
  962. # we never insert observers to output of standalone module, we assume
  963. # if needed, they are inserted inside the standalone module
  964. needs_obs_or_fq = needs_obs_or_fq and (not is_standalone_module)
  965. if needs_obs_or_fq:
  966. obs_or_fq_map[node] = output_act_obs_or_fq
  967. return _insert_obs_or_fq(
  968. node, output_act_obs_or_fq, model, named_modules, graph
  969. )
  970. else:
  971. return None
  972. def _maybe_insert_observers_before_graph_output(
  973. graph_output_node: Node,
  974. model: torch.nn.Module,
  975. named_modules: dict[str, torch.nn.Module],
  976. graph: Graph,
  977. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
  978. is_qat: bool,
  979. ) -> None:
  980. """
  981. If the output needs to be quantized and there are any nodes
  982. in the output which are not already observed, inserts observers
  983. for those nodes.
  984. """
  985. def _recursive_maybe_replace_node_with_obs(
  986. maybe_node: Argument,
  987. model: torch.nn.Module,
  988. named_modules: dict[str, torch.nn.Module],
  989. graph: Graph,
  990. ) -> Argument:
  991. """
  992. Navigate an arbitrary data structure of lists, tuples, dicts.
  993. For each container type, recurse on all inputs. Once any Node
  994. is found, insert an observer if needed and do not recurse further.
  995. For example, given a structure of
  996. {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}}
  997. we recurse down to bar1 and bar3, observe them if necessary,
  998. and if we inserted an observer then replace the original node
  999. with its observer.
  1000. Returns the data structure with all nodes needing observation being
  1001. replaced by their observers.
  1002. """
  1003. if isinstance(maybe_node, Node):
  1004. # check dtype of this node
  1005. arg_as_output_target_dtype = _get_arg_target_dtype_as_output(
  1006. maybe_node, named_modules, obs_or_fq_map, is_qat
  1007. )
  1008. observer_mod = None
  1009. arg_as_input_target_dtype = torch.float
  1010. if "target_dtype_info" in maybe_node.meta:
  1011. observer_cls = maybe_node.meta["target_dtype_info"].get(
  1012. "input_act_obs_or_fq_ctr", None
  1013. )
  1014. if observer_cls is not None:
  1015. observer_mod = observer_cls()
  1016. arg_as_input_target_dtype = observer_mod.dtype
  1017. # TODO: this does not handle dynamic quantization yet
  1018. need_obs = (
  1019. arg_as_output_target_dtype != arg_as_input_target_dtype
  1020. and arg_as_input_target_dtype != torch.float
  1021. )
  1022. if need_obs:
  1023. if observer_mod is None:
  1024. raise AssertionError(
  1025. "observer_mod must not be None when need_obs is True"
  1026. )
  1027. # insert observer
  1028. observer_node = _insert_obs_or_fq(
  1029. maybe_node, observer_mod, model, named_modules, graph
  1030. )
  1031. return observer_node
  1032. else:
  1033. return maybe_node
  1034. elif isinstance(maybe_node, (list, tuple)): # noqa: UP038
  1035. results = [
  1036. _recursive_maybe_replace_node_with_obs(
  1037. # pyrefly: ignore [bad-argument-type]
  1038. inner_node,
  1039. model,
  1040. named_modules,
  1041. graph,
  1042. )
  1043. for inner_node in maybe_node
  1044. ]
  1045. if isinstance(maybe_node, list):
  1046. return results
  1047. else:
  1048. return tuple(results)
  1049. elif isinstance(maybe_node, dict):
  1050. results_dict = {}
  1051. for k, inner_v in maybe_node.items():
  1052. results_dict[k] = _recursive_maybe_replace_node_with_obs(
  1053. inner_v, model, named_modules, graph
  1054. )
  1055. return results_dict
  1056. elif maybe_node is None:
  1057. return None
  1058. else:
  1059. raise Exception( # noqa: TRY002
  1060. "Unhandled type for returned node:", maybe_node
  1061. )
  1062. new_args = [
  1063. _recursive_maybe_replace_node_with_obs(old_arg, model, named_modules, graph)
  1064. for old_arg in graph_output_node.args
  1065. ]
  1066. graph_output_node.args = tuple(new_args) # type: ignore[assignment]
  1067. def _maybe_propagate_dtype_for_node(
  1068. node: Node,
  1069. target_dtype: torch.dtype | type,
  1070. node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig],
  1071. ) -> None:
  1072. """
  1073. Assigns `target_dtype` to `node`, setting `is_dynamic` to False. If `node`
  1074. is a general tensor shape op, also call this function recursively on
  1075. the first argument, to propagate the dtype to the caller.
  1076. """
  1077. node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] = None
  1078. node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] = None
  1079. # if this is a copy node, propagate to first arg
  1080. (
  1081. _root_node,
  1082. _,
  1083. _pattern,
  1084. qhandler,
  1085. _qconfig,
  1086. ) = node_name_to_match_result_with_qconfig.get(
  1087. node.name, (None, None, None, None, None)
  1088. )
  1089. # TODO: probably need to remove `is_general_tensor_value_op`
  1090. if qhandler is not None and qhandler.is_general_tensor_value_op():
  1091. prev_node = node.args[0]
  1092. if isinstance(prev_node, Node):
  1093. _maybe_propagate_dtype_for_node(
  1094. prev_node, target_dtype, node_name_to_match_result_with_qconfig
  1095. )
  1096. def propagate_dtypes_for_known_nodes(
  1097. graph: Graph,
  1098. node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig],
  1099. ) -> None:
  1100. """
  1101. Currently we assume that inputs to the graph are either `torch.float` or
  1102. `torch.quint8`, which is not always correct. For ops such as
  1103. `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a
  1104. `BoolTensor`. Propagate this information throughout the graph.
  1105. Note: not all dtypes in the graph will be correct after this pass, but a
  1106. higher percentage of them will be correct. Hopefully in the future we can
  1107. replace this with a better way to reason about dtypes of tensors.
  1108. """
  1109. for node in graph.nodes:
  1110. non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node)
  1111. for arg_type in non_observable_arg_dict:
  1112. non_observable_indices = non_observable_arg_dict[arg_type](node)
  1113. for index in non_observable_indices:
  1114. arg = node.args[index]
  1115. # when an argument is a tuple, it does not show up as another node so we need to go through
  1116. # all elements of the tuple manually
  1117. if isinstance(arg, (tuple, list)): # noqa: UP038
  1118. arg_list = list(arg)
  1119. else:
  1120. arg_list = [arg]
  1121. for cur_arg in arg_list:
  1122. # hard coded arguments show up but aren't `Node` typed and do not need dtype propagated
  1123. if isinstance(cur_arg, torch.fx.node.Node):
  1124. _maybe_propagate_dtype_for_node(
  1125. cur_arg, arg_type, node_name_to_match_result_with_qconfig
  1126. )
  1127. def _maybe_make_input_output_share_observers(
  1128. node: Node,
  1129. model: torch.nn.Module,
  1130. named_modules: dict[str, torch.nn.Module],
  1131. ) -> bool:
  1132. """
  1133. Ensures that we share an observer
  1134. for all input arguments as well as the output argument. In detail, given
  1135. a graph of
  1136. x0 -> obs0 -> op -> x2
  1137. /
  1138. x1 -> obs1 /
  1139. where node obs0 points to observer instance observer0,
  1140. obs1 points to observer1 and obs2 points to observer2, we make nodes obs1
  1141. and ob2 point to observer0.
  1142. Returns: whether the operation succeeded or not
  1143. """
  1144. first_arg = None
  1145. # find the first non-Tensor arg
  1146. for i in range(len(node.args)):
  1147. if isinstance(node.args[i], (Node, list, tuple)): # noqa: UP038
  1148. first_arg = node.args[i]
  1149. break
  1150. # if there is no non-Tensor arg, return directly
  1151. if first_arg is None:
  1152. return False
  1153. if isinstance(first_arg, (list, tuple)): # noqa: UP038
  1154. first_arg_arg = first_arg[0]
  1155. elif isinstance(first_arg, Node):
  1156. first_arg_arg = first_arg
  1157. else:
  1158. return False
  1159. # if we have a graph such as
  1160. # observed_node -> non_observed_node -> cat
  1161. # we need to navigate up to the first observer
  1162. iteration_guard = 0
  1163. # pyrefly: ignore [bad-argument-type]
  1164. while not _is_activation_post_process_node(first_arg_arg, named_modules):
  1165. if not isinstance(first_arg_arg, Node):
  1166. return False
  1167. # did not find an activation_post_process for the op
  1168. if first_arg_arg.op == "placeholder":
  1169. return False
  1170. # trace back the args until we found the first Tensor/Node
  1171. trace_back_node = None
  1172. for i in range(len(first_arg_arg.args)):
  1173. trace_back_node = first_arg_arg.args[i]
  1174. if isinstance(trace_back_node, Node):
  1175. break
  1176. if trace_back_node is None:
  1177. return False
  1178. first_arg_arg = trace_back_node
  1179. iteration_guard += 1
  1180. if iteration_guard > 10000:
  1181. raise AssertionError("Unable to find observer of previous node")
  1182. if not isinstance(first_arg_arg, Node):
  1183. raise AssertionError("first_arg_arg must be a Node")
  1184. target_to_use = first_arg_arg.target
  1185. if not isinstance(target_to_use, str):
  1186. raise AssertionError("target_to_use must be a string")
  1187. obs_mod_to_use = named_modules[target_to_use]
  1188. if isinstance(first_arg, (list, tuple)): # noqa: UP038
  1189. # set all other input observer nodes to use that module
  1190. for input_idx, input_arg in enumerate(first_arg):
  1191. if input_idx == 0:
  1192. continue
  1193. iteration_guard = 0
  1194. # pyrefly: ignore [bad-argument-type]
  1195. while not _is_activation_post_process_node(input_arg, named_modules):
  1196. # failed to trace back since no input arg for the current node
  1197. # pyrefly: ignore [missing-attribute]
  1198. if len(input_arg.args) < 1:
  1199. return False
  1200. # pyrefly: ignore [bad-index, unsupported-operation]
  1201. input_arg = input_arg.args[0]
  1202. iteration_guard += 1
  1203. if iteration_guard > 10000:
  1204. raise AssertionError("Unable to find observer of previous node")
  1205. # pyrefly: ignore [missing-attribute]
  1206. parent_name, name = _parent_name(input_arg.target)
  1207. setattr(named_modules[parent_name], name, obs_mod_to_use)
  1208. # set the output observer node to use that module
  1209. for output_obs_node in node.users:
  1210. if not _is_activation_post_process_node(output_obs_node, named_modules):
  1211. raise AssertionError(
  1212. "output_obs_node must be an activation post process node"
  1213. )
  1214. parent_name, name = _parent_name(output_obs_node.target)
  1215. setattr(named_modules[parent_name], name, obs_mod_to_use)
  1216. # TODO(future PR): delete the orphaned observer modules
  1217. return True
  1218. def _remove_output_observer(
  1219. node: Node, model: torch.nn.Module, named_modules: dict[str, torch.nn.Module]
  1220. ):
  1221. items = list(node.users.items())
  1222. for output_obs_node, _ in items:
  1223. if not _is_activation_post_process_node(output_obs_node, named_modules):
  1224. raise AssertionError(
  1225. "output_obs_node must be an activation post process node"
  1226. )
  1227. output_obs_node.replace_all_uses_with(node)
  1228. model.graph.erase_node(output_obs_node) # type: ignore[union-attr, operator]
  1229. def _swap_custom_module_to_observed(
  1230. node: Node,
  1231. qconfig: QConfigAny,
  1232. named_modules: dict[str, torch.nn.Module],
  1233. prepare_custom_config: PrepareCustomConfig,
  1234. ):
  1235. custom_module = named_modules[node.target] # type: ignore[index]
  1236. custom_module_class_mapping = prepare_custom_config.float_to_observed_mapping
  1237. observed_custom_module_class = get_swapped_custom_module_class(
  1238. custom_module, custom_module_class_mapping, qconfig
  1239. )
  1240. observed_custom_module = observed_custom_module_class.from_float(custom_module)
  1241. parent_name, name = _parent_name(node.target)
  1242. setattr(named_modules[parent_name], name, observed_custom_module)
  1243. def insert_observers_for_model(
  1244. model: GraphModule,
  1245. node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig],
  1246. node_name_to_qconfig: dict[str, QConfigAny],
  1247. prepare_custom_config: PrepareCustomConfig,
  1248. equalization_config_map: dict[str, Any],
  1249. backend_config: BackendConfig,
  1250. observed_node_names: set[str],
  1251. is_qat: bool,
  1252. ) -> Node | None:
  1253. """
  1254. Inserts observers, using the following high level algorithm:
  1255. For each node in the graph:
  1256. 1. determine the target dtype of this node in the quantized graph, and save
  1257. it for future steps
  1258. 2. determine the target dtype or all args and kwargs of this node
  1259. 3. if any arg or kwarg's target dtype does not match the current node's
  1260. dtype, insert an observer
  1261. 4. if the current node needs an output observer, insert it
  1262. For example:
  1263. - starting graph:
  1264. x0 -> linear -> x1
  1265. - observed graph after processing x0:
  1266. x0(fp32)
  1267. - observed graph after processing linear:
  1268. x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)
  1269. - observed graph after processing x1:
  1270. x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1
  1271. After a node is processed, the naive observer placement is guaranteed to be
  1272. complete for that node and all of its predecessors. There can be future
  1273. passes which optimize the graph by deduplicating observers, etc.
  1274. """
  1275. # node.meta["target_dtype_info"] stores the target dtype information
  1276. # that's derived from qconfig for the Node, for example, if we have
  1277. # a conv2d node that has a qconfig
  1278. # qconfig = QConfig(activation=..., weight=...)
  1279. # # information for input and bias node omitted
  1280. # # for getattr node
  1281. # # weight = getattr(self, 'weight')
  1282. # weight.meta["target_dtype_info"] = {
  1283. # 'output_act_obs_or_fq_ctr': qconfig.weight,
  1284. # }
  1285. # # for conv2d node
  1286. # # conv2d = call_function[target=torch.nn.functional.conv2d](
  1287. # # args=(input, weight, bias))
  1288. # conv2d.meta["target_dtype_info"] = {
  1289. # 'input_act_obs_or_fq_ctr': qconfig.activation
  1290. # 'weight_obs_or_fq_ctr': qconfig.weight,
  1291. # 'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32),
  1292. # 'output_act_obs_or_fq_ctr': qconfig.activation,
  1293. # }
  1294. #
  1295. cache_for_no_tensor_check: dict[Node, bool] = {}
  1296. # first, populate the dtype map based only on qconfig and qhandler
  1297. # this assumes:
  1298. # graph inputs are fp32 by default, and int8 where overridden
  1299. # other nodes output dtype is specified by the qconfig
  1300. named_modules = dict(model.named_modules(remove_duplicate=False))
  1301. input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes
  1302. output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes
  1303. processed_nodes: set[Node] = set()
  1304. # initialize target_dtype_info
  1305. for node in model.graph.nodes:
  1306. node.meta["target_dtype_info"] = copy.copy(
  1307. _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO
  1308. )
  1309. inputs_seen_counter = 0
  1310. outputs_seen_counter = 0
  1311. placeholder_node_to_input_index: dict[Node, int] = {}
  1312. # TODO: we probably don't need this counter since each graph will only have
  1313. # one output node?
  1314. output_node_to_output_index: dict[Node, int] = {}
  1315. for node in model.graph.nodes:
  1316. if node.op == "placeholder":
  1317. placeholder_node_to_input_index[node] = inputs_seen_counter
  1318. inputs_seen_counter += 1
  1319. if node.op == "output":
  1320. output_node_to_output_index[node] = outputs_seen_counter
  1321. outputs_seen_counter += 1
  1322. # Step 1, set the observer or fake quantize module constructor for each node in the
  1323. # matched_node_pattern
  1324. for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
  1325. (
  1326. last_node,
  1327. matched_node_pattern,
  1328. pattern,
  1329. qhandler,
  1330. qconfig,
  1331. ) = match_res_with_qconfig
  1332. if qhandler is None:
  1333. raise AssertionError("qhandler must not be None")
  1334. _set_target_dtype_info_for_matched_node_pattern(
  1335. matched_node_pattern,
  1336. last_node,
  1337. qconfig,
  1338. qhandler,
  1339. backend_config,
  1340. named_modules,
  1341. cache_for_no_tensor_check,
  1342. processed_nodes,
  1343. )
  1344. # Step 2. Special cases for some operators, we might be able to remove them
  1345. # in the future if we know dtype information of each node better
  1346. # Step 2.1. some settings are not based on patterns, we need to process each node
  1347. # instead
  1348. for node in model.graph.nodes:
  1349. if (
  1350. node.op == "placeholder"
  1351. and placeholder_node_to_input_index[node] in input_quantized_idxs
  1352. ):
  1353. # users are not supposed to call calculate_qparams on PlaceholderObserver, and
  1354. # this is OK because we are using this as a way to encode the dtypes of input
  1355. # tensor, we won't actually insert these observers in the graph and won't
  1356. # actually call calculate_qparams
  1357. node.meta["target_dtype_info"] = copy.copy(
  1358. _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO
  1359. )
  1360. elif node.op in ("call_module", "call_method", "call_function"):
  1361. args_have_no_tensors = all_node_args_have_no_tensors(
  1362. node, named_modules, cache_for_no_tensor_check
  1363. )
  1364. if args_have_no_tensors:
  1365. node.meta["target_dtype_info"] = {
  1366. "input_act_obs_or_fq_ctr": None,
  1367. "output_act_obs_or_fq_ctr": None,
  1368. }
  1369. elif (
  1370. node.op == "output"
  1371. and output_node_to_output_index[node] in output_quantized_idxs
  1372. ):
  1373. # TODO(future PR): update the output_quantized_idxs API to match
  1374. # arbitrary data structures. There is always a single output, and
  1375. # that output can have arbitrary nesting of values. List[int] is
  1376. # not the right data type for this.
  1377. # TODO(future PR): support more dtypes in model outputs, if necessary
  1378. node.meta["target_dtype_info"] = copy.copy(
  1379. _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO
  1380. )
  1381. # Step 2.2, for nodes with known input dtypes, propagate them throughout the
  1382. # graph. For example, if there is a call such as
  1383. # x1 = x0.masked_fill(mask, 1)
  1384. # we propagate the type of mask to be torch.bool
  1385. propagate_dtypes_for_known_nodes(
  1386. model.graph, node_name_to_match_result_with_qconfig
  1387. )
  1388. # Step 3, check if the requested target_dtype_info is supported by backend or not
  1389. # if not, we'll reset the target_dtye_info to use the default (float Tensor)
  1390. # reset the counters and set of processed_nodes
  1391. processed_nodes: set[Node] = set()
  1392. for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
  1393. (
  1394. last_node,
  1395. matched_node_pattern,
  1396. pattern,
  1397. qhandler,
  1398. qconfig,
  1399. ) = match_res_with_qconfig
  1400. is_supported_by_backend = (
  1401. _is_pattern_dtype_config_and_qconfig_supported_by_backend(
  1402. pattern, matched_node_pattern, qconfig, backend_config
  1403. )
  1404. )
  1405. if qhandler is None:
  1406. raise AssertionError("qhandler must not be None")
  1407. # get output_act_dtype so that we don't also reset the special typed nodes
  1408. # TODO: we might want to handle these more uniformly with the default path
  1409. # this can be improved if we can use node.meta["val"]
  1410. output_act_or_fq_ctr = node.meta["target_dtype_info"][
  1411. "output_act_obs_or_fq_ctr"
  1412. ]
  1413. output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None
  1414. output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq)
  1415. if not is_supported_by_backend and output_act_dtype not in [
  1416. None,
  1417. int,
  1418. float,
  1419. torch.bool,
  1420. ]:
  1421. # restore target_dtype_info to default if it is not supported by backend
  1422. _set_target_dtype_info_for_matched_node_pattern(
  1423. matched_node_pattern,
  1424. last_node,
  1425. torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig,
  1426. None,
  1427. backend_config,
  1428. named_modules,
  1429. cache_for_no_tensor_check,
  1430. processed_nodes,
  1431. )
  1432. # After this point, the current node and all of its arguments
  1433. # have a target_dtype_info assigned. Now, we insert observers for inputs
  1434. # of this node (if needed for this node), and the output of this node
  1435. # (if needed for this node).
  1436. # Since we are mutating the graph as we go, we iterate over the original
  1437. # nodes before observer insertion, instead of model.graph.nodes.
  1438. nodes_before_observation = list(model.graph.nodes)
  1439. # Avoid duplicates custom module swaps for multiple nodes with same target.
  1440. custom_module_names_already_swapped: set[str] = set()
  1441. # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index
  1442. # reset inputs/outputs counters
  1443. inputs_seen_counter = 0
  1444. outputs_seen_counter = 0
  1445. results_node = None
  1446. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
  1447. model_device = assert_and_get_unique_device(model)
  1448. # TODO: change this to insert obs/fq by pattern instead of by node
  1449. for node in nodes_before_observation:
  1450. if node.op == "placeholder":
  1451. # if a graph input is in fp32, it does not need observation
  1452. # if a graph input is in int8, we assume the observation happens
  1453. # outside of the graph, and no additional observation is needed
  1454. pass
  1455. elif node.op in ("call_module", "call_method", "call_function", "output"):
  1456. # check for matches
  1457. (
  1458. last_node,
  1459. matched_node_pattern,
  1460. pattern,
  1461. qhandler,
  1462. qconfig,
  1463. ) = node_name_to_match_result_with_qconfig.get( # type: ignore[assignment]
  1464. node.name, (None, None, None, None, None)
  1465. )
  1466. equalization_qconfig = equalization_config_map.get(node.name)
  1467. this_node_dtype_info = node.meta["target_dtype_info"]
  1468. if "val" in node.meta:
  1469. output_is_a_tensor = this_node_dtype_info is not None and isinstance(
  1470. node.meta["val"], FakeTensor
  1471. )
  1472. else:
  1473. output_is_a_tensor = this_node_dtype_info is not None
  1474. skip_inserting_observers = (
  1475. (qconfig is None) or not output_is_a_tensor
  1476. ) and (node.op != "output")
  1477. # TODO: take a closer look to see if we can remove this check
  1478. # right now it is here because of `observed_node_names`, we are using
  1479. # it as an indicator for swapping the modules to reference modules in
  1480. # convert
  1481. is_supported_by_backend = (
  1482. _is_pattern_dtype_config_and_qconfig_supported_by_backend(
  1483. pattern, matched_node_pattern, qconfig, backend_config
  1484. )
  1485. )
  1486. if not skip_inserting_observers and is_supported_by_backend:
  1487. named_modules = dict(model.named_modules(remove_duplicate=False))
  1488. if node.op != "output":
  1489. if matched_node_pattern is None:
  1490. raise AssertionError("matched_node_pattern must not be None")
  1491. # add matched nodes to the observed node name set
  1492. _add_matched_node_name_to_set(
  1493. matched_node_pattern, observed_node_names
  1494. )
  1495. # This is currently only used for equalization.
  1496. # Checks if the current node is in a branch in which the two
  1497. # first layers are both being quantized.
  1498. #
  1499. # ex. conv2
  1500. # /
  1501. # x -> conv1
  1502. #
  1503. # If this is the case, we will not apply equalization to the
  1504. # initial two layers.
  1505. is_quantized_branch = False
  1506. if (
  1507. len(node.args) > 0
  1508. and isinstance(node.args[0], Node)
  1509. and len(node.args[0].users) > 1
  1510. ):
  1511. for user in node.args[0].users:
  1512. # Checks if there exists another user being quantized
  1513. is_user_quantized = node_name_to_qconfig.get(
  1514. user.name
  1515. ) is not None or (
  1516. user.op == "call_module"
  1517. and isinstance(
  1518. named_modules[str(user.target)], ObserverBase
  1519. )
  1520. )
  1521. if user != node and is_user_quantized:
  1522. is_quantized_branch = True
  1523. pattern_to_root_node_getter = (
  1524. get_fusion_pattern_to_root_node_getter(backend_config)
  1525. )
  1526. root_node_getter = pattern_to_root_node_getter.get(
  1527. pattern, _default_root_node_getter
  1528. )
  1529. root_node = root_node_getter(matched_node_pattern)
  1530. is_input_node_of_the_pattern = node is root_node
  1531. if is_input_node_of_the_pattern:
  1532. # this modifies node inplace
  1533. _maybe_insert_input_observers_for_node(
  1534. node,
  1535. qconfig,
  1536. model,
  1537. named_modules,
  1538. model.graph,
  1539. qhandler,
  1540. prepare_custom_config,
  1541. obs_or_fq_map,
  1542. is_qat,
  1543. backend_config,
  1544. model_device,
  1545. )
  1546. # insert equalization input observers if needed
  1547. _maybe_insert_input_equalization_observers_for_node(
  1548. node,
  1549. equalization_qconfig,
  1550. model,
  1551. named_modules,
  1552. model.graph,
  1553. is_quantized_branch,
  1554. )
  1555. is_last_node_of_pattern = node is last_node
  1556. input_output_share_observers = node.meta["target_dtype_info"].get(
  1557. "input_output_share_observers", False
  1558. )
  1559. reuse_input_obs_or_fq = node.meta["target_dtype_info"].get(
  1560. "reuse_input_obs_or_fq", False
  1561. )
  1562. if is_last_node_of_pattern:
  1563. if _is_custom_module_lstm(
  1564. # pyrefly: ignore [bad-argument-type]
  1565. node,
  1566. named_modules,
  1567. qconfig,
  1568. qhandler,
  1569. ):
  1570. # Currently custom module outputs are assumed to be already quantized,
  1571. # so we need to insert a DeQuantStub after the output. For custom module
  1572. # LSTM specifically, the outputs are also a nested tuple, so we must first
  1573. # break down the tuple to insert DeQuantStubs after the internal nodes.
  1574. # TODO: This currently diverges from how custom modules are handled today,
  1575. # where we insert observers after the output instead of DeQuantStubs, and
  1576. # replace these observers with "dequantize" nodes during convert. Conceptually,
  1577. # these output observers are the same as DeQuantStubs. In the future, we
  1578. # should resolve this inconsistency by inserting DeQuantStubs for all custom
  1579. # modules, not just for LSTM.
  1580. _insert_dequant_stubs_for_custom_module_lstm_output(
  1581. # pyrefly: ignore [bad-argument-type]
  1582. node,
  1583. model,
  1584. named_modules,
  1585. model.graph,
  1586. )
  1587. # pyrefly: ignore [missing-attribute]
  1588. if node.target not in custom_module_names_already_swapped:
  1589. # pyrefly: ignore [bad-argument-type]
  1590. custom_module_names_already_swapped.add(node.target)
  1591. _swap_custom_module_to_observed(
  1592. # pyrefly: ignore [bad-argument-type]
  1593. node,
  1594. qconfig,
  1595. named_modules,
  1596. prepare_custom_config,
  1597. )
  1598. else:
  1599. # this returns the new observer node if it was needed
  1600. maybe_output_obs_node = (
  1601. _maybe_insert_output_observer_for_node(
  1602. # pyrefly: ignore [bad-argument-type]
  1603. node,
  1604. model,
  1605. named_modules,
  1606. model.graph,
  1607. obs_or_fq_map,
  1608. is_qat,
  1609. )
  1610. )
  1611. if maybe_output_obs_node is not None:
  1612. # Update users of original node to use the output observer
  1613. # instead. For example, change
  1614. #
  1615. # next_node
  1616. # /
  1617. # cur_node -> obs
  1618. #
  1619. # to
  1620. #
  1621. # next_node
  1622. # /
  1623. # cur_node -> obs
  1624. #
  1625. # We need to save orig users before updating uses because
  1626. # the list of users will change as we update uses
  1627. # pyrefly: ignore [missing-attribute]
  1628. orig_users = list(node.users.keys())
  1629. for user_node in orig_users:
  1630. if user_node is maybe_output_obs_node:
  1631. continue
  1632. user_node.replace_input_with(
  1633. node, maybe_output_obs_node
  1634. )
  1635. _is_observer_in_same_graph_ = (
  1636. _is_observer_in_same_graph(
  1637. # pyrefly: ignore [bad-argument-type]
  1638. node,
  1639. named_modules,
  1640. obs_or_fq_map,
  1641. is_qat,
  1642. )
  1643. )
  1644. # for ops whose inputs and outputs share observer/fqs, we modify the graph
  1645. # to make all inputs and outputs use the first input's
  1646. # observer/fq
  1647. if (
  1648. input_output_share_observers
  1649. and _is_observer_in_same_graph_
  1650. ) or reuse_input_obs_or_fq:
  1651. if not _maybe_make_input_output_share_observers(
  1652. # pyrefly: ignore [bad-argument-type]
  1653. node,
  1654. model,
  1655. named_modules,
  1656. ):
  1657. _remove_output_observer(
  1658. # pyrefly: ignore [bad-argument-type]
  1659. node,
  1660. model,
  1661. named_modules,
  1662. )
  1663. if qhandler is not None and qhandler.is_custom_module():
  1664. if (
  1665. # pyrefly: ignore [missing-attribute]
  1666. node.target
  1667. not in custom_module_names_already_swapped
  1668. ):
  1669. custom_module_names_already_swapped.add(
  1670. # pyrefly: ignore [bad-argument-type]
  1671. node.target
  1672. )
  1673. _swap_custom_module_to_observed(
  1674. # pyrefly: ignore [bad-argument-type]
  1675. node,
  1676. qconfig,
  1677. named_modules,
  1678. prepare_custom_config,
  1679. )
  1680. else: # output
  1681. _maybe_insert_observers_before_graph_output(
  1682. node, model, named_modules, model.graph, obs_or_fq_map, is_qat
  1683. )
  1684. #
  1685. # After this point, the current node has input and output observers
  1686. # that it needs for itself inserted.
  1687. #
  1688. # increment the counters, so future inputs and outputs are assigned
  1689. # correct dtypes
  1690. if node.op == "placeholder":
  1691. inputs_seen_counter += 1
  1692. elif node.op == "output":
  1693. outputs_seen_counter += 1
  1694. results_node = node
  1695. return results_node
  1696. def _run_prepare_fx_on_standalone_modules(
  1697. model: torch.nn.Module,
  1698. is_qat: bool,
  1699. named_modules: dict[str, torch.nn.Module],
  1700. node_name_to_match_result_with_qconfig: Any,
  1701. prepare_custom_config: PrepareCustomConfig,
  1702. backend_config: BackendConfig,
  1703. ) -> None:
  1704. """
  1705. Runs prepare_fx on each standalone module. Note: this does
  1706. not modify the graph, it just replaces the unobserved modules with
  1707. their observed versions.
  1708. """
  1709. for (
  1710. root_node,
  1711. _,
  1712. _pattern,
  1713. qhandler,
  1714. qconfig,
  1715. ) in node_name_to_match_result_with_qconfig.values():
  1716. if qhandler is None:
  1717. continue
  1718. elif not qhandler.is_standalone_module():
  1719. continue
  1720. (
  1721. sm_qconfig_mapping,
  1722. sm_example_inputs,
  1723. sm_prepare_custom_config,
  1724. sm_backend_config,
  1725. ) = _get_standalone_module_configs(
  1726. root_node, named_modules, prepare_custom_config, qconfig, backend_config
  1727. )
  1728. standalone_module = named_modules[root_node.target]
  1729. prepare = torch.ao.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore[attr-defined]
  1730. observed_standalone_module = prepare(
  1731. standalone_module,
  1732. sm_qconfig_mapping,
  1733. is_qat,
  1734. example_inputs=sm_example_inputs,
  1735. prepare_custom_config=sm_prepare_custom_config,
  1736. backend_config=sm_backend_config,
  1737. )
  1738. parent_name, name = _parent_name(root_node.target)
  1739. setattr(named_modules[parent_name], name, observed_standalone_module)
  1740. named_modules[root_node.target] = observed_standalone_module
  1741. def _save_state(
  1742. observed: GraphModule,
  1743. node_name_to_qconfig: dict[str, QConfigAny],
  1744. node_name_to_scope: dict[str, tuple[str, type]],
  1745. prepare_custom_config: PrepareCustomConfig,
  1746. equalization_node_name_to_qconfig: dict[str, Any],
  1747. qconfig_mapping: QConfigMapping,
  1748. is_qat: bool,
  1749. observed_node_names: set[str],
  1750. ) -> None:
  1751. observed.meta["_observed_graph_module_attrs"] = ObservedGraphModuleAttrs(
  1752. node_name_to_qconfig=node_name_to_qconfig,
  1753. node_name_to_scope=node_name_to_scope,
  1754. prepare_custom_config=prepare_custom_config,
  1755. equalization_node_name_to_qconfig=equalization_node_name_to_qconfig,
  1756. qconfig_mapping=qconfig_mapping,
  1757. is_qat=is_qat,
  1758. observed_node_names=observed_node_names,
  1759. )
  1760. def prepare(
  1761. model: GraphModule,
  1762. qconfig_mapping: QConfigMapping | dict[str, Any],
  1763. is_qat: bool,
  1764. node_name_to_scope: dict[str, tuple[str, type]],
  1765. example_inputs: tuple[Any, ...],
  1766. prepare_custom_config: PrepareCustomConfig | dict[str, Any] | None = None,
  1767. _equalization_config: QConfigMapping | dict[str, Any] | None = None,
  1768. backend_config: BackendConfig | dict[str, Any] | None = None,
  1769. is_standalone_module: bool = False,
  1770. ) -> GraphModule:
  1771. """standalone_module means it a submodule that is not inlined in
  1772. parent module, and will be quantized separately as one unit.
  1773. How the standalone module is observed is specified by `input_quantized_idxs` and
  1774. `output_quantized_idxs` in the prepare_custom_config for the standalone module
  1775. Args:
  1776. node_name_to_scope: mapping from node name to the scope of the module which contains the node.
  1777. The scope is a tuple of fully qualified path of the module and the type of the module
  1778. Returns:
  1779. model(GraphModule): prepared standalone module
  1780. attributes related to standalone module
  1781. in model.meta["_observed_graph_module_attrs"]:
  1782. is_observed_standalone_module (bool): boolean value that shows whether the
  1783. current model is a observed standalone module or not
  1784. standalone_module_input_quantized_idxs(List[Int]): a list of
  1785. indexes for the graph input that is expected to be quantized,
  1786. same as input_quantized_idxs configuration provided
  1787. for the standalone module
  1788. standalone_module_output_quantized_idxs(List[Int]): a list of
  1789. indices for the graph output that is quantized
  1790. same as input_quantized_idxs configuration provided
  1791. for the standalone module
  1792. """
  1793. if prepare_custom_config is None:
  1794. prepare_custom_config = PrepareCustomConfig()
  1795. if _equalization_config is None:
  1796. _equalization_config = QConfigMapping()
  1797. if isinstance(qconfig_mapping, dict):
  1798. warnings.warn(
  1799. "Passing a QConfig dictionary to prepare is deprecated and will not be supported "
  1800. "in a future version. Please pass in a QConfigMapping instead.",
  1801. FutureWarning,
  1802. stacklevel=2,
  1803. )
  1804. qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping)
  1805. if isinstance(_equalization_config, dict):
  1806. warnings.warn(
  1807. "Passing a QConfig dictionary to prepare for equalization is deprecated and will not "
  1808. "be supported in a future version. Please pass in a QConfigMapping instead.",
  1809. FutureWarning,
  1810. stacklevel=2,
  1811. )
  1812. _equalization_config = QConfigMapping.from_dict(_equalization_config)
  1813. if isinstance(prepare_custom_config, dict):
  1814. warnings.warn(
  1815. "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
  1816. "in a future version. Please pass in a PrepareCustomConfig instead.",
  1817. FutureWarning,
  1818. stacklevel=2,
  1819. )
  1820. prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)
  1821. if isinstance(backend_config, dict):
  1822. warnings.warn(
  1823. "Passing a backend_config_dict to prepare is deprecated and will not be supported "
  1824. "in a future version. Please pass in a BackendConfig instead.",
  1825. FutureWarning,
  1826. stacklevel=2,
  1827. )
  1828. backend_config = BackendConfig.from_dict(backend_config)
  1829. if not isinstance(qconfig_mapping, QConfigMapping):
  1830. raise AssertionError("qconfig_mapping must be a QConfigMapping")
  1831. if not isinstance(_equalization_config, QConfigMapping):
  1832. raise AssertionError("_equalization_config must be a QConfigMapping")
  1833. qconfig_mapping = copy.deepcopy(qconfig_mapping)
  1834. _equalization_config = copy.deepcopy(_equalization_config)
  1835. # mapping from a tuple of nodes in reverse order to uninitialized
  1836. # QuantizeHandler subclass. For example,
  1837. # {
  1838. # # match a single node
  1839. # (<class 'torch.nn.modules.conv.Conv3d'>:
  1840. # <class 'torch.ao.quantization.fx.quantize.ConvRelu'>),
  1841. # # match multiple nodes in reverse order
  1842. # ((<function relu at 0x7f766a7360d0>, <built-in function add>):
  1843. # <class 'torch.ao.quantization.fx.quantize.Add'>),
  1844. # }
  1845. pattern_to_quantize_handler: dict[Pattern, QuantizeHandler] = {}
  1846. if backend_config is None:
  1847. backend_config = get_native_backend_config()
  1848. pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config)
  1849. pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler)
  1850. root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
  1851. # pyrefly: ignore [bad-argument-type]
  1852. _update_qconfig_for_fusion(model, qconfig_mapping)
  1853. # pyrefly: ignore [bad-argument-type]
  1854. _update_qconfig_for_fusion(model, _equalization_config)
  1855. # pyrefly: ignore [bad-argument-type]
  1856. flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
  1857. # TODO: support regex as well
  1858. propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
  1859. if is_qat:
  1860. module_to_qat_module = get_module_to_qat_module(backend_config)
  1861. _qat_swap_modules(model, module_to_qat_module)
  1862. # pyrefly: ignore [bad-argument-type]
  1863. _update_qconfig_for_qat(qconfig_mapping, backend_config)
  1864. # mapping from fully qualified module name to module instance
  1865. # for example,
  1866. # {
  1867. # '': Model(...),
  1868. # 'linear': Linear(...),
  1869. # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
  1870. # }
  1871. named_modules = dict(model.named_modules(remove_duplicate=False))
  1872. # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
  1873. equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
  1874. model,
  1875. named_modules,
  1876. model.graph,
  1877. # pyrefly: ignore [bad-argument-type]
  1878. _equalization_config,
  1879. node_name_to_scope,
  1880. )
  1881. node_name_to_qconfig = _generate_node_name_to_qconfig(
  1882. model,
  1883. named_modules,
  1884. model.graph,
  1885. # pyrefly: ignore [bad-argument-type]
  1886. qconfig_mapping,
  1887. node_name_to_scope,
  1888. )
  1889. # match the patterns that will get quantized
  1890. standalone_module_names = list(prepare_custom_config.standalone_module_names.keys())
  1891. standalone_module_classes = list(
  1892. prepare_custom_config.standalone_module_classes.keys()
  1893. )
  1894. custom_module_classes = get_custom_module_class_keys(
  1895. prepare_custom_config.float_to_observed_mapping
  1896. )
  1897. matches_without_qconfig = _find_matches(
  1898. model.graph,
  1899. named_modules,
  1900. pattern_to_quantize_handler,
  1901. root_node_getter_mapping,
  1902. standalone_module_names,
  1903. standalone_module_classes,
  1904. custom_module_classes,
  1905. )
  1906. # map qconfig instances to matches
  1907. node_name_to_match_result_with_qconfig = {}
  1908. for node_name, match_without_qconfig in matches_without_qconfig.items():
  1909. match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name])
  1910. node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig
  1911. _run_prepare_fx_on_standalone_modules(
  1912. model,
  1913. is_qat,
  1914. named_modules,
  1915. node_name_to_match_result_with_qconfig,
  1916. prepare_custom_config,
  1917. backend_config,
  1918. )
  1919. # record names for the set of observed node, so that in convert step
  1920. # we know whether we need to convert a floating point module to reference
  1921. # quantized module or not
  1922. observed_node_names: set[str] = set()
  1923. result_node = insert_observers_for_model(
  1924. model,
  1925. node_name_to_match_result_with_qconfig,
  1926. node_name_to_qconfig,
  1927. prepare_custom_config,
  1928. equalization_node_name_to_qconfig,
  1929. backend_config,
  1930. observed_node_names,
  1931. is_qat,
  1932. )
  1933. model = GraphModule(model, model.graph)
  1934. _save_state(
  1935. model,
  1936. node_name_to_qconfig,
  1937. node_name_to_scope,
  1938. prepare_custom_config,
  1939. equalization_node_name_to_qconfig,
  1940. # pyrefly: ignore [bad-argument-type]
  1941. qconfig_mapping,
  1942. is_qat,
  1943. observed_node_names,
  1944. )
  1945. if is_standalone_module:
  1946. if result_node is None:
  1947. raise AssertionError("result_node must not be None for standalone modules")
  1948. if not isinstance(result_node.args[0], Node):
  1949. raise AssertionError(
  1950. "standalone module only supports returning simple value currently (not tuple, dict etc.)"
  1951. )
  1952. # these inputs are observed in parent
  1953. # converting List[int] to Tensor since module attribute is
  1954. # Union[Tensor, Module]
  1955. input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes
  1956. output_quantized_idxs: list[int] = (
  1957. prepare_custom_config.output_quantized_indexes
  1958. )
  1959. observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
  1960. # inplace modification
  1961. observed_graph_module_attrs.is_observed_standalone_module = True
  1962. observed_graph_module_attrs.standalone_module_input_quantized_idxs = (
  1963. input_quantized_idxs
  1964. )
  1965. observed_graph_module_attrs.standalone_module_output_quantized_idxs = (
  1966. output_quantized_idxs
  1967. )
  1968. return model