converter.py 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622
  1. # mypy: allow-untyped-defs
  2. import builtins
  3. import logging
  4. import operator
  5. import typing
  6. import warnings
  7. from collections.abc import Callable, Sequence
  8. from contextlib import contextmanager
  9. from typing import Any, Optional, Union
  10. import torch
  11. import torch.export._trace
  12. from torch import _C
  13. from torch._export.passes.replace_quantized_ops_with_standard_ops_pass import (
  14. replace_quantized_ops_with_standard_ops,
  15. )
  16. from torch.export.dynamic_shapes import _tree_map_with_path, Dim
  17. from torch.export.exported_program import ExportedProgram
  18. from torch.export.graph_signature import (
  19. ConstantArgument,
  20. CustomObjArgument,
  21. InputKind,
  22. InputSpec,
  23. OutputKind,
  24. OutputSpec,
  25. TensorArgument,
  26. )
  27. from torch.fx import subgraph_rewriter
  28. log = logging.getLogger(__name__)
  29. def _get_param_count_list(method_graph, args_params):
  30. param_count_list = []
  31. for input_, arg_params_ in zip(method_graph.inputs(), args_params):
  32. if "PackedParams" in str(input_.type()):
  33. in_vars, _ = torch.jit._flatten(arg_params_)
  34. param_count_list.append(len(in_vars))
  35. else:
  36. param_count_list.append(arg_params_ is not None)
  37. return param_count_list
  38. def _trace_and_get_graph_from_model(model, args):
  39. # A basic sanity check: make sure the state_dict keys are the same
  40. # before and after running the model. Fail fast!
  41. orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()
  42. # Disable Autocast cache because it replaces kernel's weight and bias
  43. # by (undesired) constants.
  44. # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665
  45. prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
  46. torch.set_autocast_cache_enabled(False)
  47. trace_graph, torch_out, _inputs_states = torch.jit._get_trace_graph(
  48. model,
  49. args,
  50. strict=False,
  51. _force_outplace=False,
  52. _return_inputs_states=True,
  53. )
  54. torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
  55. if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
  56. raise RuntimeError(
  57. "state_dict changed after running the tracer; "
  58. "something weird is happening in your model!"
  59. )
  60. return trace_graph, torch_out
  61. def _create_jit_graph(
  62. model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any]
  63. ) -> tuple[torch.Graph, list["_C.IValue"], Any, Optional[torch.ScriptModule]]:
  64. if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
  65. flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
  66. torch_out = None
  67. if isinstance(model, torch.jit.ScriptModule):
  68. try:
  69. graph = model.forward.graph # type: ignore[attr-defined]
  70. except AttributeError as e:
  71. raise RuntimeError("'forward' method must be a script method") from e
  72. _C._jit_pass_onnx_function_substitution(graph)
  73. freezed_module = _C._freeze_module(
  74. typing.cast(_C.ScriptModule, model._c), preserveParameters=True
  75. )
  76. module, params = _C._jit_onnx_list_model_parameters(freezed_module)
  77. method_graph = module._get_method("forward").graph
  78. args_params = tuple(args) + tuple(params)
  79. param_count_list = _get_param_count_list(method_graph, args_params)
  80. in_vars, _ = torch.jit._flatten(args_params)
  81. graph = _C._propagate_and_assign_input_shapes(
  82. method_graph, tuple(in_vars), param_count_list, False, False
  83. )
  84. return graph, params, torch_out, module
  85. # torch.jit.ScriptFunction
  86. params = []
  87. graph = model.graph
  88. _C._jit_pass_onnx_function_substitution(graph)
  89. param_count_list = _get_param_count_list(graph, args)
  90. graph = _C._propagate_and_assign_input_shapes(
  91. graph, flattened_args, param_count_list, False, False
  92. )
  93. return graph, params, torch_out, None
  94. graph, torch_out = _trace_and_get_graph_from_model(model, args)
  95. _C._jit_pass_onnx_lint(graph)
  96. state_dict = torch.jit._unique_state_dict(model)
  97. params = list(state_dict.values())
  98. graph_inputs = list(graph.inputs())
  99. user_input_num = len(graph_inputs) - len(state_dict)
  100. param_names = list(state_dict.keys())
  101. for i, inp in enumerate(graph_inputs):
  102. if i >= user_input_num:
  103. inp.setDebugName(param_names[i - user_input_num])
  104. _C._jit_pass_onnx_function_substitution(graph)
  105. return graph, params, torch_out, None
  106. def list_add(a, b):
  107. return a + b
  108. def list_append(container, element):
  109. return container + [element]
  110. def execute_subgraph_from_prim_loop(
  111. subgraph, iter_idx, len_loop_local_arguments, *args, **kwargs
  112. ):
  113. """
  114. subgraph: GraphModule from sub-block.
  115. iter_idx: The index of interaction.
  116. len_loop_local_arguments: The number of loop local arguments in args.
  117. """
  118. # Loop local variables. TS graph create those as inputs because their values
  119. # are updated inside the loop.
  120. loop_local_args = args[:len_loop_local_arguments]
  121. # Global variables that are not passed in as inputs to the loop sub-blocks
  122. # but are directly used. Most of time, their values are not updated, but
  123. # the only exception is when there are some operations that perform inplace
  124. # updates.
  125. global_args = args[len_loop_local_arguments:]
  126. return subgraph(*global_args, iter_idx, *loop_local_args, **kwargs)
  127. def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
  128. def pattern(im, dim, scale):
  129. sym_size_int = torch.ops.aten.sym_size.int(im, dim)
  130. scalar_tensor = torch.ops.aten.scalar_tensor(sym_size_int)
  131. div_scalar_mode = torch.ops.aten.div.Scalar_mode(
  132. scalar_tensor, scale, rounding_mode="trunc"
  133. )
  134. int_tensor = torch.ops.aten.Int.Tensor(div_scalar_mode)
  135. return int_tensor
  136. def replacement(im, dim, scale):
  137. sym_size_int = torch.ops.aten.sym_size.int(im, dim)
  138. return sym_size_int // scale
  139. subgraph_rewriter.replace_pattern(gm, pattern, replacement)
  140. def is_valid_for_codegen(name):
  141. if len(name) == 0:
  142. raise RuntimeError("Empty argument name for codegen")
  143. if name[0].isdigit():
  144. return False
  145. return True
  146. def normalize_name(name: str, prefix: str = "rename") -> str:
  147. name = name.replace(".", "_")
  148. if is_valid_for_codegen(name):
  149. return name
  150. return f"{prefix}_{name}"
  151. def ir_name_to_func_name(name: str) -> str:
  152. """prim::If -> convert_prim_If"""
  153. name_list = name.split("::")
  154. return "convert_" + "_".join(name_list)
  155. def get_node_as_placeholder_or_get_attr(fx_graph, name, is_top_level_graph):
  156. if is_top_level_graph:
  157. return fx_graph.get_attr(name)
  158. return fx_graph.placeholder(name)
  159. _TORCH_DTYPE_TO_ENUM = {
  160. torch.uint8: 0,
  161. torch.int8: 1,
  162. torch.int16: 2,
  163. torch.int32: 3,
  164. torch.int64: 4,
  165. torch.float16: 5,
  166. torch.float32: 6,
  167. torch.float64: 7,
  168. torch.complex32: 8,
  169. torch.complex64: 9,
  170. torch.complex128: 10,
  171. torch.bool: 11,
  172. torch.qint8: 12,
  173. torch.quint8: 13,
  174. torch.bfloat16: 15,
  175. }
  176. _TORCH_ENUM_TO_DTYPE = {value: key for key, value in _TORCH_DTYPE_TO_ENUM.items()}
  177. def get_dtype_as_int(tensor):
  178. """
  179. prim::dtype has the signature "Tensor a) -> int", where it gets the dtype of
  180. the tensor and returns the integer corresponding to this dtype based on the
  181. enum in ScalarType.h
  182. """
  183. dtype = tensor.dtype
  184. if dtype not in _TORCH_DTYPE_TO_ENUM:
  185. raise RuntimeError(f"Unsupported dtype {dtype}")
  186. return _TORCH_DTYPE_TO_ENUM[dtype]
  187. # Those operators will be automatically populated to a instance method
  188. # of TS2FXGraphConverter with name convert_<namespace>_<opname>().
  189. # Please check __init__ for method population implementations.
  190. kind_to_standard_operators: dict[str, Callable[..., Any]] = {
  191. "prim::max": builtins.max,
  192. "prim::min": builtins.min,
  193. "prim::TupleIndex": operator.getitem,
  194. "aten::__is__": operator.is_,
  195. "aten::__isnot__": operator.is_not,
  196. "aten::__not__": operator.not_,
  197. "aten::__contains__": operator.contains,
  198. "prim::dtype": get_dtype_as_int,
  199. "aten::len": len,
  200. # Mapping from specialized op to its symbolic counterpart.
  201. # They currently do not have any other overrides.
  202. "aten::numel": torch.ops.aten.sym_numel,
  203. "aten::size": torch.ops.aten.sym_size,
  204. "aten::storage_offset": torch.ops.aten.sym_storage_offset,
  205. "aten::stride": torch.ops.aten.sym_stride,
  206. }
  207. def get_ir_value_parent_name_and_attr_name(node):
  208. irv_parent_name, irv_name = node.input().debugName(), node.output().debugName()
  209. attr_name = node.s("name")
  210. return irv_name, irv_parent_name, attr_name
  211. def construct_fqn(ir, ref_map, name_map):
  212. name_list = []
  213. while ir in ref_map:
  214. name_list.append(name_map[ir])
  215. ir = ref_map[ir]
  216. return ".".join(reversed(name_list))
  217. def get_block_to_lifted_attrs(
  218. graph: torch._C.Graph,
  219. ) -> tuple[dict[torch._C.Block, set[str]], dict[str, str]]:
  220. """
  221. Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes.
  222. When a graph has control flow, the graph will be divided into multiple blocks. We want to convert
  223. each block to a graph which will be passed into torch.cond. A restriction for torch.cond is that model
  224. parameters/buffers are expected to be lifted as inputs to the subgraphs. Before converting the model,
  225. we will run this pass which will:
  226. 1. Figure out which params/buffers are used within blocks through tracing the GetAttr calls.
  227. 2. Process the graph bottom up to find the lifted attributes of each block by taking the union
  228. of the attributes used in the current block, and the lifted attributes of all its child blocks.
  229. Returns:
  230. A mapping of blocks to a set of FQNs of its lifted attributes, and a
  231. mapping of node names to the FQNs of its lifted attributes.
  232. """
  233. # A map from a block to its expected to be lifted arguments.
  234. blocks_to_lifted_attrs: dict[torch._C.Block, set[str]] = {}
  235. # Reference map stores the input (i.e., src) and output (i.e., dest) IR of a
  236. # GetAttr node. By traversing this reference map, we can figure out the
  237. # full IR aliasing pass and figure out the FQN of an attribute.
  238. # E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1"
  239. node_to_parent_map: dict[str, str] = {}
  240. # Used for reconstructing the FQN of an attribute based on the reference map.
  241. # In nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR
  242. # This name map stores which attribute name is called for a src IR --> dest IR action.
  243. # E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear"
  244. node_to_attr_name: dict[str, str] = {}
  245. def _dfs_get_attr_dependency(entry):
  246. """
  247. First DFS path to construct reference map and name map.
  248. """
  249. for node in entry.nodes():
  250. if node.kind() == "prim::GetAttr":
  251. (
  252. irv_name,
  253. irv_parent_name,
  254. attr_name,
  255. ) = get_ir_value_parent_name_and_attr_name(node)
  256. node_to_parent_map[irv_name] = irv_parent_name
  257. node_to_attr_name[irv_name] = attr_name
  258. for block in node.blocks():
  259. _dfs_get_attr_dependency(block)
  260. def _map_blocks_to_lifted_attrs(entry):
  261. """
  262. Walk the graph in a bottom-up fashion to build the expected to be
  263. lifted arguments for each block.
  264. """
  265. arguments: set[str] = set()
  266. for node in entry.nodes():
  267. for block in node.blocks():
  268. # Recursively build.
  269. arguments = arguments.union(_map_blocks_to_lifted_attrs(block))
  270. if node.kind() == "prim::GetAttr":
  271. irv_name = node.output().debugName()
  272. # Skip for intermediate GetAttr, which will anyway not result a FQN.
  273. # E.g., node_to_parent_name: {"%3": "%2", "%2": "%1"}
  274. # node_to_attr_name: {"%3": "weight", "%2": "linear", "%1": "self"}
  275. # There is only one FQN %3-->%2-->%1: self.linear.weight
  276. # %2-->%1 is not a FQN: self.linear
  277. if irv_name not in set(node_to_parent_map.values()):
  278. arguments.add(
  279. construct_fqn(irv_name, node_to_parent_map, node_to_attr_name)
  280. )
  281. if not isinstance(entry, torch._C.Graph): # Skip the top level.
  282. blocks_to_lifted_attrs[entry] = arguments
  283. return arguments
  284. _dfs_get_attr_dependency(graph)
  285. _map_blocks_to_lifted_attrs(graph)
  286. return blocks_to_lifted_attrs, node_to_attr_name
  287. def get_attribute_fqn_from_ts_node(
  288. name_to_attribute_fqn: dict[str, str], node: torch._C.Node
  289. ) -> str:
  290. def get_attr(name: str):
  291. if name in name_to_attribute_fqn:
  292. return name_to_attribute_fqn[name]
  293. else:
  294. raise ValueError(f"Attribute {name} not found")
  295. if node.kind() == "prim::SetAttr":
  296. input_name = next(node.inputs()).debugName()
  297. elif node.kind() == "prim::GetAttr":
  298. input_name = node.input().debugName()
  299. else:
  300. raise RuntimeError(
  301. f"Unexpected node kind when getting attribute fqn. node: {node} "
  302. )
  303. attr_name = node.s("name")
  304. root_attr_name = get_attr(input_name)
  305. attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name
  306. return attr_fqn
  307. def get_op_overload(node: torch._C.Node):
  308. schema_str = node.schema()
  309. if schema_str == "(no schema)":
  310. raise AssertionError(f"got empty schema for {node}")
  311. schema: torch._C.FunctionSchema = torch._C.parse_schema(schema_str)
  312. ns, op_name = str(schema.name).split("::")
  313. override = schema.overload_name
  314. try:
  315. op_overload_mod = getattr(torch.ops, ns)
  316. op_overload_packet = getattr(op_overload_mod, op_name)
  317. if override:
  318. op_overload = getattr(op_overload_packet, override)
  319. else:
  320. op_overload = op_overload_packet.default
  321. except Exception as e:
  322. raise RuntimeError(
  323. f"Unable to find operator {node.kind()} with schema {node.schema()}"
  324. ) from e
  325. return op_overload
  326. class TS2FXGraphConverter:
  327. def __init__(
  328. self,
  329. ts_graph: Union[torch._C.Graph, torch._C.Block],
  330. name_to_param: dict[str, torch.Tensor],
  331. name_to_buffer: dict[str, torch.Tensor],
  332. blocks_to_lifted_attrs: dict[torch._C.Block, set[str]],
  333. name_to_non_tensor_attribute: dict[str, Any],
  334. name_to_constant: dict[str, Any],
  335. name_to_attribute_fqn: dict[str, str],
  336. ):
  337. self.ts_graph = ts_graph
  338. # Mapping of parameter FQN to actual parameter value
  339. self.name_to_param = name_to_param
  340. # Mapping of buffer FQN to actual buffer value
  341. self.name_to_buffer = name_to_buffer
  342. self.fx_graph: torch.fx.Graph = torch.fx.Graph()
  343. self.input_specs: list[InputSpec] = []
  344. self.output_specs: list[OutputSpec] = []
  345. # Mapping of TS node name to converted FX node
  346. self.name_to_node: dict[
  347. str, Union[torch.fx.Node, list[torch.fx.Node], dict[Any, torch.fx.Node]]
  348. ] = {}
  349. # Mapping of TS node name to constant value (int, str, TorchBind obj,
  350. # tensor constants ...)
  351. self.name_to_constant: dict[str, Any] = name_to_constant
  352. # Mapping from torchscript node output name to attribute fully qualified name
  353. self.name_to_attribute_fqn: dict[str, str] = name_to_attribute_fqn
  354. # Mapping from fully qualified name to real values or a fx graph node
  355. # During convert, this represents the current value of a non-tensor attribute
  356. # One use case is:
  357. # def forward(self, x):
  358. # c1 = self.count
  359. # self.count += 1
  360. # c2 = self.count
  361. # return x + c1 + c2
  362. self.name_to_non_tensor_attribute_node: dict[str, Any] = {}
  363. # Mapping from fully qualified name to initial real values inputs
  364. # We separate it from self.name_to_non_tensor_attribute_node since
  365. # we need initial real value input when we construct fx.GraphModule
  366. self.name_to_non_tensor_attribute: dict[str, Any] = name_to_non_tensor_attribute
  367. self.subgraphs: dict[str, torch.fx.GraphModule] = {}
  368. # Mapping of block to list of attributes that need to be lifted for each
  369. # block
  370. self.blocks_to_lifted_attrs = blocks_to_lifted_attrs
  371. # Populate methods for the standard operators.
  372. for k in kind_to_standard_operators:
  373. handler_func_name = ir_name_to_func_name(k)
  374. # Create an indirect function call:
  375. # convert_<namespace>_<opname> --> lambda node: _convert_standard_operator(node)
  376. setattr(
  377. self,
  378. handler_func_name,
  379. lambda node: self._convert_standard_operators(node),
  380. )
  381. # This stores a list of return results that do not appear in the original TS
  382. # graph's outputs. The reason we maintain this is because some operations in the sub-block
  383. # might have inplace updates to the variable defined in the parent fx graph. After
  384. # the execution of that sub-block, the variable defined in the parent fx graph also
  385. # needs to be updated.
  386. self.name_update_from_subblock_to_parent: set[str] = set()
  387. def _is_get_attr_node(self, fqn):
  388. return (
  389. fqn in self.name_to_buffer
  390. or fqn in self.name_to_param
  391. or (
  392. fqn in self.name_to_constant
  393. and isinstance(self.name_to_constant[fqn], torch.ScriptObject)
  394. )
  395. )
  396. def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: list[str]):
  397. subgraph_nodes, subgraph_converters = [], []
  398. for block in node.blocks():
  399. subgraph_converter = TS2FXGraphConverter(
  400. block,
  401. self.name_to_param,
  402. self.name_to_buffer,
  403. self.blocks_to_lifted_attrs,
  404. {},
  405. self.name_to_constant,
  406. self.name_to_attribute_fqn,
  407. )
  408. for block_arg in arguments:
  409. normalized_block_arg_name = normalize_name(block_arg)
  410. placeholder_node = subgraph_converter.fx_graph.placeholder(
  411. normalized_block_arg_name
  412. )
  413. subgraph_converter.name_to_node[block_arg] = placeholder_node
  414. subgraph = subgraph_converter.convert()
  415. subgraph_name = self.add_subgraph(subgraph)
  416. subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))
  417. subgraph_converters.append(subgraph_converter)
  418. return subgraph_nodes, subgraph_converters
  419. def _identify_inputs_as_arguments(self, entry):
  420. """
  421. Identify inputs from the innermost sub-block. This is needed
  422. for nested sub-blocks when the input is hidden in the nested sub-block.
  423. E.g., example IR of input is hidden in the nested sub-block.
  424. Graph[x.1]
  425. %1 = ...
  426. Block[]
  427. Block[x.1]
  428. %2 = x.1 ...
  429. """
  430. arguments: set[str] = set()
  431. for block in entry.blocks():
  432. for block_node in block.nodes():
  433. for block_node_in in block_node.inputs():
  434. if (
  435. block_node_in.debugName() in self.name_to_node
  436. and block_node_in.debugName() not in self.name_to_attribute_fqn
  437. ):
  438. arguments.add(block_node_in.debugName())
  439. arguments = arguments.union(
  440. self._identify_inputs_as_arguments(block_node)
  441. )
  442. return arguments
  443. def is_top_level_graph(self):
  444. return isinstance(self.ts_graph, torch._C.Graph)
  445. def add_subgraph(self, subgraph) -> str:
  446. name = f"subgraph_{len(self.subgraphs)}"
  447. self.subgraphs[name] = subgraph
  448. return name
  449. def get_args_kwargs(self, node: torch._C.Node, schema):
  450. args = []
  451. kwargs = {}
  452. for input, schema_arg in zip(node.inputs(), schema.arguments):
  453. if schema_arg.kwarg_only:
  454. kwargs[schema_arg.name] = self.get_fx_value_by_ir_value(input)
  455. else:
  456. args.append(self.get_fx_value_by_ir_value(input))
  457. return tuple(args), kwargs
  458. def get_fx_value_by_ir_value(self, value: torch._C.Value):
  459. value_name = value.debugName()
  460. if value_name in self.name_to_node:
  461. input_node = self.name_to_node[value_name]
  462. return input_node
  463. elif value_name in self.name_to_constant:
  464. if isinstance(self.name_to_constant[value_name], torch.ScriptObject):
  465. return self.fx_graph.get_attr(value_name)
  466. return self.name_to_constant[value_name]
  467. elif value_name in self.name_to_attribute_fqn:
  468. return self.get_fx_value_by_fqn(self.name_to_attribute_fqn[value_name])
  469. else:
  470. raise ValueError(f"Input {value_name} not found")
  471. def get_fx_value_by_fqn(self, name):
  472. if name in self.name_to_node:
  473. fx_node = self.name_to_node[name]
  474. elif name in self.name_to_constant:
  475. fx_node = self.name_to_constant[name]
  476. elif name in self.name_to_non_tensor_attribute_node:
  477. fx_node = self.name_to_non_tensor_attribute_node[name]
  478. elif name in self.name_to_non_tensor_attribute:
  479. fx_node = self.name_to_non_tensor_attribute[name]
  480. else:
  481. raise ValueError(f"Attribute {name} not found")
  482. return fx_node
  483. def convert(self) -> torch.fx.GraphModule:
  484. self.convert_graph_inputs()
  485. for node in self.ts_graph.nodes():
  486. self.convert_node(node)
  487. self.convert_graph_outputs()
  488. # Pass parameter and buffer to the root for lookup.
  489. gm = torch.fx.GraphModule(
  490. {
  491. **self.subgraphs,
  492. **self.name_to_param,
  493. **self.name_to_buffer,
  494. **self.name_to_non_tensor_attribute,
  495. **self.name_to_constant,
  496. },
  497. self.fx_graph,
  498. )
  499. inplace_optimize_sym_size_div(gm)
  500. gm.graph.lint()
  501. return gm
  502. def convert_graph_inputs(self):
  503. for graph_input in self.ts_graph.inputs():
  504. name = graph_input.debugName()
  505. if name in self.name_to_param:
  506. normalized_name = normalize_name(name)
  507. self.input_specs.append(
  508. InputSpec(
  509. InputKind.PARAMETER,
  510. arg=TensorArgument(name=normalized_name),
  511. target=name,
  512. )
  513. )
  514. fx_node = get_node_as_placeholder_or_get_attr(
  515. self.fx_graph, name, self.is_top_level_graph()
  516. )
  517. elif name in self.name_to_buffer:
  518. normalized_name = normalize_name(name)
  519. self.input_specs.append(
  520. InputSpec(
  521. InputKind.BUFFER,
  522. arg=TensorArgument(name=normalized_name),
  523. target=name,
  524. persistent=True,
  525. )
  526. )
  527. fx_node = get_node_as_placeholder_or_get_attr(
  528. self.fx_graph, name, self.is_top_level_graph()
  529. )
  530. elif name in self.name_to_constant:
  531. if not isinstance(self.name_to_constant[name], torch.ScriptObject):
  532. raise AssertionError(
  533. f"Input conversion only handles ScriptObject, got {type(self.name_to_constant[name])}"
  534. )
  535. normalized_name = normalize_name(name)
  536. self.input_specs.append(
  537. InputSpec(
  538. InputKind.CUSTOM_OBJ,
  539. arg=CustomObjArgument(
  540. name=normalized_name, class_fqn=normalized_name
  541. ),
  542. target=name,
  543. persistent=False,
  544. )
  545. )
  546. fx_node = get_node_as_placeholder_or_get_attr(
  547. self.fx_graph, name, self.is_top_level_graph()
  548. )
  549. elif isinstance(graph_input.type(), torch.ClassType):
  550. # Directly skip inputs that are ScriptObject but not used in the graph.
  551. continue
  552. else:
  553. normalized_name = normalize_name(name, prefix="input")
  554. self.input_specs.append(
  555. InputSpec(
  556. InputKind.USER_INPUT,
  557. arg=TensorArgument(name=normalized_name),
  558. target=name,
  559. )
  560. )
  561. fx_node = self.fx_graph.placeholder(normalized_name)
  562. self.name_to_node[name] = fx_node
  563. def convert_aten_Float(self, node: torch._C.Node):
  564. def to_float_tensor(t):
  565. return t.to(dtype=torch.float).item()
  566. inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] # noqa: C416
  567. fx_node = self.fx_graph.call_function(
  568. to_float_tensor,
  569. tuple(inp_list),
  570. )
  571. self.name_to_node[node.output().debugName()] = fx_node
  572. def convert_aten_tensor(self, node: torch._C.Node):
  573. """aten::tensor creates a constant tensor ad-hoc --> GetAttr"""
  574. args, kwargs = self.get_args_kwargs(node, torch.ops.aten.tensor.default._schema)
  575. for k in kwargs:
  576. if k == "requires_grad":
  577. kwargs[k] = bool(kwargs[k]) # 0 -> False, 1 -> True
  578. to_tensor = (
  579. torch.tensor
  580. if all(isinstance(a, int) for a in args)
  581. else torch._refs.tensor
  582. )
  583. def target(*args, **kwargs):
  584. if "dtype" in kwargs and kwargs["dtype"] is not None:
  585. kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]]
  586. return to_tensor(*args, **kwargs)
  587. # def to_dynamic_tensor(*args, **kwargs):
  588. # if "dtype" in kwargs and kwargs["dtype"] is not None:
  589. # kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]]
  590. # return torch._refs.tensor(*args, **kwargs)
  591. output_name = node.output().debugName()
  592. fx_node = self.fx_graph.call_function(target, args, kwargs)
  593. self.name_to_node[output_name] = fx_node
  def convert_aten_append(self, node: torch._C.Node):
      """Convert Python list append ``aten::append.t`` into ``list_append``.

      Schema: ``aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)``.

      The append mutates the list in place. Both the op's output and its list
      input (``arg[0]``) are therefore rebound to the new FX node. This makes
      the converter non-functional, a stateful interpreter: the result
      depends on the order in which append nodes are converted.
      """
      warnings.warn(
          "Converting aten::append.t, which is a inplace mutation of the list. "
          "This makes the converter non-functional: the result depends on the order of the append nodes being converter!",
          stacklevel=2,
      )
      args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs())
      fx_node = self.fx_graph.call_function(list_append, args)
      self.name_to_node[node.output().debugName()] = fx_node

      # The list (arg[0]) is mutated in place, so later reads must see fx_node.
      list_name = node.inputsAt(0).debugName()
      self.name_to_node[list_name] = fx_node

      # Variables that need to be updated to parent module: a mutated list that
      # entered this sub-block as a placeholder. The list may also be a plain
      # Python list built locally (prim::ListConstruct binds a list, not an
      # fx.Node). Such a list has no ``.op`` and needs no propagation.
      if (
          not self.is_top_level_graph()
          and isinstance(args[0], torch.fx.Node)
          and args[0].op == "placeholder"
      ):
          self.name_update_from_subblock_to_parent.add(list_name)
  def convert_prim_Constant(self, node: torch._C.Node):
      """Record the value of a ``prim::Constant`` node.

      The value is stored in ``name_to_constant``, keyed by the output's
      debug name, in three cases:

      * int, float and string constants;
      * generic IValue constants;
      * constants without a ``value`` attribute, which are stored as ``None``.

      Tensor constants are handled differently. They are exposed through a
      ``get_attr`` node whose target is ``lifted_tensor_<name>``, and the
      tensor is stored under that alias.

      Raises:
          ValueError: for any other attribute kind.
      """
      name = node.output().debugName()
      if not node.hasAttribute("value"):
          self.name_to_constant[name] = None
          return

      kind = node.kindOf("value")
      if kind == "t":
          # Follow naming convention from EP tracing.
          alias_name = f"lifted_tensor_{name}"
          self.name_to_node[name] = self.fx_graph.get_attr(alias_name)
          self.name_to_constant[alias_name] = node.t("value")
      elif kind in ("i", "f", "s", "ival"):
          # The attribute kind is also the name of the Node accessor method.
          self.name_to_constant[name] = getattr(node, kind)("value")
      else:
          raise ValueError(f"Unsupported constant type: {node.kindOf('value')}")
  def convert_prim_CallMethod(self, node: torch._C.Node):
      """Convert ``prim::CallMethod`` into an FX ``call_method`` node.

      The method name is taken from the node's ``name`` string attribute.
      All node inputs become positional arguments, in order.
      """
      method_args = tuple(map(self.get_fx_value_by_ir_value, node.inputs()))
      method_name = node.s("name")
      self.name_to_node[node.output().debugName()] = self.fx_graph.call_method(
          method_name, method_args
      )
  def convert_prim_device(self, node: torch._C.Node):
      """Resolve ``prim::device`` statically from the input's JIT tensor type.

      The device recorded on the input's ``TensorType`` is stored as a
      constant, and no FX node is emitted.

      Raises:
          ValueError: if the input is not a tensor type.
      """
      input_type = node.input().type()
      if not input_type.isSubtypeOf(torch._C.TensorType.get()):
          raise ValueError(f"Unsupported JitType ({input_type}) when get device")
      device = input_type.device()  # type: ignore[attr-defined]
      self.name_to_constant[node.output().debugName()] = device
  def convert_prim_GetAttr(self, node: torch._C.Node):
      """Convert ``prim::GetAttr`` by resolving the attribute's fully qualified name.

      The output is always recorded in ``name_to_attribute_fqn``.

      Top-level graph:

      * attributes accepted by ``_is_get_attr_node`` get a fresh FX
        ``get_attr`` node on every read;
      * other attributes are served from ``name_to_non_tensor_attribute_node``,
        which is seeded from ``name_to_non_tensor_attribute`` on first access.

      Sub-blocks: for get-attr-style attributes, the value already bound to the
      FQN in ``name_to_node`` is reused.
      """
      attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
      output_name = node.output().debugName()
      self.name_to_attribute_fqn[output_name] = attr_fqn

      is_get_attr = self._is_get_attr_node(attr_fqn)
      if not self.is_top_level_graph():
          # Special support for if blocks which do not allow SetAttr TorchScript
          # node and get_attr FX Graph Node.
          if is_get_attr:
              self.name_to_node[output_name] = self.name_to_node[attr_fqn]
          return

      if is_get_attr:
          # Emit get_attr on every read, for two reasons:
          # 1. TS graphs do not lift tensor constants as input nodes, so
          #    convert_graph_inputs() may have missed them.
          # 2. A SetAttr may have written attr_fqn, so two GetAttrs may see
          #    different values.
          self.name_to_node[output_name] = self.fx_graph.get_attr(attr_fqn)
          return

      cache = self.name_to_non_tensor_attribute_node
      if attr_fqn not in cache:
          cache[attr_fqn] = self.name_to_non_tensor_attribute[attr_fqn]
      self.name_to_node[output_name] = cache[attr_fqn]
  def convert_prim_SetAttr(self, node: torch._C.Node):
      """Convert ``prim::SetAttr``.

      For attributes accepted by ``_is_get_attr_node``, the new value is
      copied in place into a ``get_attr`` node via ``torch.Tensor.copy_``.
      For any other attribute, the tracked FX value is simply replaced.
      """
      attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
      # Input 1 carries the value being assigned.
      new_value = self.get_fx_value_by_ir_value(tuple(node.inputs())[1])
      if not self._is_get_attr_node(attr_fqn):
          self.name_to_non_tensor_attribute_node[attr_fqn] = new_value
          return
      attr_node = self.fx_graph.get_attr(attr_fqn)
      self.fx_graph.call_function(torch.Tensor.copy_, (attr_node, new_value))
  def convert_call_function_op(self, node: torch._C.Node):
      """Default handler: emit ``call_function`` on the node's op overload.

      Args and kwargs are bound against the overload's schema. A node with
      several outputs additionally gets one ``operator.getitem`` node per
      output, so each TorchScript output maps to its own FX value.
      """
      target = get_op_overload(node)
      args, kwargs = self.get_args_kwargs(node, target._schema)
      fx_node = self.fx_graph.call_function(target, args, kwargs)
      # TODO: convert sourceRange() into stack_trace
      # fx_node.meta["stack_trace"] = node.sourceRange()

      if node.outputsSize() == 1:
          self.name_to_node[node.output().debugName()] = fx_node
          return

      for index, outp in enumerate(node.outputs()):
          self.name_to_node[outp.debugName()] = self.fx_graph.call_function(
              operator.getitem, (fx_node, index)
          )
  def convert_prim_TupleConstruct(self, node: torch._C.Node):
      """Convert ``prim::TupleConstruct`` by delegating to ``_convert_prim_iterator``."""
      return self._convert_prim_iterator(node)
  def convert_prim_ListConstruct(self, node: torch._C.Node):
      """Convert ``prim::ListConstruct`` by delegating to ``_convert_prim_iterator``."""
      return self._convert_prim_iterator(node)
  def _convert_prim_iterator(self, node: torch._C.Node):
      """Bind the node's output to a plain Python list of its inputs' FX values.

      No FX node is emitted; consumers receive the list itself.
      """
      self.name_to_node[node.output().debugName()] = list(
          map(self.get_fx_value_by_ir_value, node.inputs())
      )
  def convert_prim_DictConstruct(self, node: torch._C.Node):
      """Bind the node's output to a Python dict built from alternating inputs.

      Inputs are assumed to be laid out as ``key0, value0, key1, value1, ...``.

      Raises:
          AssertionError: if a key or value converts to ``None``, or if a
              dangling (non-``None``) key has no matching value.
      """
      output_dict = {}
      ir_inputs = iter(node.inputs())
      for key_ir in ir_inputs:
          key = self.get_fx_value_by_ir_value(key_ir)
          value_ir = next(ir_inputs, None)
          if value_ir is None:
              # Odd number of inputs: only a non-None dangling key is an error.
              if key is not None:
                  raise AssertionError(
                      "DictConstruct has an odd number of elements (violating our assumption)."
                  )
              break
          value = self.get_fx_value_by_ir_value(value_ir)
          if key is None or value is None:
              raise AssertionError("DictConstruct has an empty key value pair.")
          output_dict[key] = value
      self.name_to_node[node.output().debugName()] = output_dict
  def convert_prim_ListUnpack(self, node: torch._C.Node):
      """Convert ``prim::ListUnpack`` by delegating to ``_convert_prim_unpack_iterator``."""
      return self._convert_prim_unpack_iterator(node)
  def convert_prim_TupleUnpack(self, node: torch._C.Node):
      """Convert ``prim::TupleUnpack`` by delegating to ``_convert_prim_unpack_iterator``."""
      return self._convert_prim_unpack_iterator(node)
  def _convert_prim_unpack_iterator(self, node: torch._C.Node):
      """Unpack the single container input into one ``getitem`` node per output."""
      for position, outp in enumerate(node.outputs()):
          container = self.get_fx_value_by_ir_value(node.input())
          self.name_to_node[outp.debugName()] = self.fx_graph.call_function(
              operator.getitem, (container, position)
          )
  def convert_aten_Int(self, node: torch._C.Node):
      """Convert ``aten::Int`` as ``aten._to_copy`` + ``aten._local_scalar_dense``.

      The intermediate copy uses ``torch.int64`` to match TorchScript's 64-bit
      ``int``. Casting through ``int32`` would overflow for values outside
      the 32-bit range before the scalar is extracted.
      """
      args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs())
      to_copy_node = self.fx_graph.call_function(
          torch.ops.aten._to_copy.default, args, {"dtype": torch.int64}
      )
      fx_node = self.fx_graph.call_function(
          torch.ops.aten._local_scalar_dense.default, (to_copy_node,)
      )
      # TODO: convert sourceRange() into stack_trace
      # fx_node.meta["stack_trace"] = node.sourceRange()
      self.name_to_node[node.output().debugName()] = fx_node
  def convert_prim_NumToTensor(self, node: torch._C.Node):
      """Convert ``prim::NumToTensor`` into ``aten.scalar_tensor(..., dtype=torch.long)``.

      prim::NumToTensor IRs are currently triggered by:

      * ``.size()``:
        https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L950
      * ``.numel()``:
        https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L971

      For both APIs, torch.jit.trace implicitly sets the output tensor type to
      LongTensor, hence the fixed dtype.
      """
      scalar_args = tuple(map(self.get_fx_value_by_ir_value, node.inputs()))
      self.name_to_node[node.output().debugName()] = self.fx_graph.call_function(
          torch.ops.aten.scalar_tensor, scalar_args, {"dtype": torch.long}
      )
  def convert_prim_CreateObject(self, node: torch._C.Node):
      """Register the created object as an attribute root (empty FQN prefix)."""
      created_name = node.output().debugName()
      self.name_to_attribute_fqn[created_name] = ""
  def convert_aten__convolution(self, node: torch._C.Node):
      """Map ``aten::_convolution`` onto ``aten.convolution.default``.

      ``aten::_convolution`` has no meta function, so the public overload is
      emitted instead. Arguments are bound against that overload's schema.
      """
      conv = torch.ops.aten.convolution.default
      conv_args, conv_kwargs = self.get_args_kwargs(node, conv._schema)
      self.name_to_node[node.output().debugName()] = self.fx_graph.call_function(
          conv, conv_args, conv_kwargs
      )
  def convert_aten_div(self, node: torch._C.Node):
      """Convert ``aten::div``, folding single-element tensor-constant divisors.

      ``aten::div.Tensor_mode(x, c)`` is emitted as
      ``aten.div.Scalar_mode(x, c.item())`` when ``c`` is a known tensor
      constant with exactly one element. Every other case falls back to
      ``convert_call_function_op``.
      """
      target = get_op_overload(node)
      schema = target._schema
      args, kwargs = self.get_args_kwargs(node, schema)

      if schema.overload_name == "Tensor_mode":
          divisor = args[1]
          tensor_constant = None
          if isinstance(divisor, torch.fx.Node):
              # Lifted tensor constants are get_attr nodes. Their *target* is
              # the exact name_to_constant key (e.g. "lifted_tensor_x.1"). FX
              # sanitizes the node's .name (e.g. "lifted_tensor_x_1"), so the
              # name may not match the key.
              key = divisor.target if divisor.op == "get_attr" else divisor.name
              tensor_constant = self.name_to_constant.get(key)
          if isinstance(tensor_constant, torch.Tensor) and tensor_constant.numel() == 1:
              updated_args = list(args)
              updated_args[1] = tensor_constant.item()
              fx_node = self.fx_graph.call_function(
                  torch.ops.aten.div.Scalar_mode,
                  tuple(updated_args),
                  kwargs,
              )
              # TODO: convert sourceRange() into stack_trace
              # fx_node.meta["stack_trace"] = node.sourceRange()
              self.name_to_node[node.output().debugName()] = fx_node
              return

      self.convert_call_function_op(node)
  def convert_aten___getitem__(self, node: torch._C.Node):
      """Convert ``aten::__getitem__`` into ``operator.getitem(container, index)``."""
      container, index = [self.get_fx_value_by_ir_value(v) for v in node.inputs()]
      self.name_to_node[node.output().debugName()] = self.fx_graph.call_function(
          operator.getitem, (container, index)
      )
  def convert_aten_to(self, node: torch._C.Node):
      """Convert ``aten::to``, forcing a real copy when the result is mutated in place.

      The pattern ``aten.to.dtype`` / ``aten.to.prim_dtype`` followed by an
      in-place mutating user triggers the functionalization error "cannot
      mutate tensors with frozen storage". In that case two things happen:

      * ``copy`` (positional argument 3) is overridden to ``True``, so the
        output is never an alias of the input;
      * the result is additionally cloned.

      All other cases use ``convert_call_function_op``.

      Raises:
          AssertionError: if the copy-forcing path sees fewer than 4 args.
      """
      target = get_op_overload(node)
      args, kwargs = self.get_args_kwargs(node, target._schema)

      if target is torch.ops.aten.to.dtype or target is torch.ops.aten.to.prim_dtype:
          user_targets = [
              get_op_overload(use.user)
              for use in node.output().uses()
              if use.user.schema() != "(no schema)"
          ]
          if any(user_target._schema.is_mutable for user_target in user_targets):
              if len(args) < 4:
                  raise AssertionError(f"expected at least 4 args, got {len(args)}")
              new_args = list(args)
              new_args[3] = True  # copy, override to True
              # Forward kwargs (e.g. keyword-only memory_format) instead of
              # silently dropping them.
              fx_node = self.fx_graph.call_function(
                  torch.ops.aten.to.dtype, tuple(new_args), kwargs
              )
              # temp hack to work around the issue https://github.com/pytorch/pytorch/issues/131679
              # When this issue is fixed, the clone node would be no longer needed
              clone_node = self.fx_graph.call_function(
                  torch.ops.aten.clone.default, (fx_node,)
              )
              self.name_to_node[node.output().debugName()] = clone_node
              return

      self.convert_call_function_op(node)
  def convert_aten_add(self, node: torch._C.Node):
      """Convert ``aten::add``, special-casing Python list concatenation.

      A node without a schema is accepted only when both operands are lists;
      in that case ``aten.add.t`` is assumed. ``aten.add.t``
      (``aten::add.t(t[] a, t[] b) -> t[]``) is emitted as ``list_add`` to
      avoid this error:
      RuntimeError: aten::add() Expected a value of type 'List[t]' for argument 'a' but instead found type 'immutable_list'.
      """
      if node.schema() != "(no schema)":
          target = get_op_overload(node)
      elif isinstance(node.inputsAt(0).type(), torch.ListType) and isinstance(
          node.inputsAt(1).type(), torch.ListType
      ):
          target = torch.ops.aten.add.t
      else:
          raise RuntimeError(f"unable to determined the target for {node}")

      if target is not torch.ops.aten.add.t:
          self.convert_call_function_op(node)
          return

      args, _ = self.get_args_kwargs(node, target._schema)
      output_name = node.output().debugName()
      self.name_to_node[output_name] = self.fx_graph.call_function(list_add, args)
  def _check_prim_loop_support(self, node):
      """Reject ``prim::Loop`` forms the converter cannot lower yet.

      Raises:
          RuntimeError: in either of two cases:
              * the iteration count (input 0) is not a recorded constant;
              * a node in the loop body produces the body's condition output.
      """
      loop_inputs = list(node.inputs())
      # TODO: (1/N) stage.
      if loop_inputs[0].debugName() not in self.name_to_constant:
          raise RuntimeError(
              "prim::Loop currently cannot run with dynamic value of number of iterations."
          )

      # Make sure the condition is not updated in the subblock.
      body = next(node.blocks())
      condition_name = next(body.outputs()).debugName()
      for body_node in body.nodes():
          if any(out.debugName() == condition_name for out in body_node.outputs()):
              raise RuntimeError(
                  "prim::Loop currently cannot run with dynamic value of condition."
              )
  def convert_prim_Loop(self, node: torch._C.Node):
      """Unroll a constant-trip-count ``prim::Loop`` into repeated subgraph calls.

      The loop body is converted once into a subgraph. Each iteration emits an
      ``execute_subgraph_from_prim_loop`` call with these arguments, in order:

      1. the subgraph;
      2. the iteration index;
      3. the number of loop-carried values;
      4. the loop-carried values, followed by the global arguments.

      The subgraph results are, in order:

      1. the condition;
      2. the updated loop-carried values;
      3. the globals mutated in place inside the body
         (``name_update_from_subblock_to_parent``).

      These results become the next iteration's arguments.
      """
      inputs = list(node.inputs())
      self._check_prim_loop_support(node)

      num_iterations = self.get_fx_value_by_ir_value(inputs[0])

      # Find inputs. Input 0 is the trip count; input 1 is skipped (presumably
      # the initial condition); the rest are loop-carried values.
      loop_local_arguments = [inp.debugName() for inp in inputs[2:]]
      global_arguments = self._identify_inputs_as_arguments(node)

      # Lift parameters as inputs.
      for block in node.blocks():
          global_arguments = global_arguments.union(
              self.blocks_to_lifted_attrs[block]
          )
      global_arguments = list(global_arguments)

      subgraph_nodes, subgraph_converters = self._convert_block_to_subgraph(
          node, global_arguments
      )
      if len(subgraph_nodes) != 1:
          raise AssertionError(f"expected 1 subgraph node, got {len(subgraph_nodes)}")
      subgraph_converter = subgraph_converters[0]
      if not self.is_top_level_graph():
          self.name_update_from_subblock_to_parent = (
              self.name_update_from_subblock_to_parent.union(
                  subgraph_converter.name_update_from_subblock_to_parent
              )
          )

      num_loop_locals = len(loop_local_arguments)
      num_outputs = node.outputsSize()
      # Layout: [*loop-carried values, *global arguments].
      fx_block_args = [
          self.get_fx_value_by_fqn(name)
          for name in loop_local_arguments + global_arguments
      ]
      for iter_idx in range(num_iterations):
          loop_node = self.fx_graph.call_function(
              execute_subgraph_from_prim_loop,
              # Check execute_node function for the expected arguments order.
              (subgraph_nodes[0], iter_idx, num_loop_locals, *fx_block_args),
              {},
          )

          # Update the value of loop local variables.
          # Result 0 is the condition, hence the + 1.
          for i, outp in enumerate(node.outputs()):
              output_name = outp.debugName()
              self.name_to_node[output_name] = self.fx_graph.call_function(
                  operator.getitem, (loop_node, i + 1)
              )
              fx_block_args[i] = self.name_to_node[output_name]

          # Update the value of global variables, whose values are modified
          # inplace. Their results follow the condition and the loop outputs.
          for i, name in enumerate(
              subgraph_converter.name_update_from_subblock_to_parent
          ):
              self.name_to_node[name] = self.fx_graph.call_function(
                  operator.getitem, (loop_node, i + num_outputs + 1)
              )
              # A global's argument slot depends only on its position in
              # global_arguments, not on the enumeration index i.
              global_argument_index = global_arguments.index(name)
              fx_block_args[num_loop_locals + global_argument_index] = (
                  self.name_to_node[name]
              )
  def _check_set_attr_in_if_block(self, if_node: torch._C.Node):
      """Raise if a direct node of any branch of ``if_node`` is ``prim::SetAttr``.

      Raises:
          RuntimeError: when such a node is found.
      """
      has_set_attr = any(
          inner.kind() == "prim::SetAttr"
          for block in if_node.blocks()
          for inner in block.nodes()
      )
      if has_set_attr:
          raise RuntimeError(
              "During converting prim::If to torch.cond, found prim::SetAttr op"
              " which is not supported yet. Please file an issue if you come "
              "across this error."
          )
  def convert_prim_If(self, node: torch._C.Node):
      """Lower ``prim::If`` to ``torch.cond``.

      Both branches become subgraphs that share one operand list. It is made
      of the arguments found by ``_identify_inputs_as_arguments`` plus the
      attributes lifted for the If node's blocks. The node may have zero,
      one or several outputs.

      Raises:
          AssertionError: unless there is exactly one input and exactly two
              subgraphs.
      """
      self._check_set_attr_in_if_block(node)

      if_inputs = list(node.inputs())
      if len(if_inputs) != 1:
          raise AssertionError(f"expected 1 input for prim::If, got {len(if_inputs)}")
      predicate = self.get_fx_value_by_ir_value(if_inputs[0])

      # Find inputs, then lift parameters as inputs.
      arguments = self._identify_inputs_as_arguments(node)
      for block in node.blocks():
          arguments = arguments.union(self.blocks_to_lifted_attrs[block])
      arguments = list(arguments)

      subgraph_nodes, _ = self._convert_block_to_subgraph(node, arguments)
      if len(subgraph_nodes) != 2:
          raise AssertionError(
              f"expected 2 subgraph nodes, got {len(subgraph_nodes)}"
          )
      then_graph, else_graph = subgraph_nodes

      operands = tuple(self.get_fx_value_by_fqn(name) for name in arguments)
      cond_node = self.fx_graph.call_function(
          torch.cond, (predicate, then_graph, else_graph, operands), {}
      )

      # prim::If can also have zero output.
      num_outputs = node.outputsSize()
      if num_outputs == 1:
          self.name_to_node[node.output().debugName()] = cond_node
      elif num_outputs > 1:
          for i, output in enumerate(node.outputs()):
              self.name_to_node[output.debugName()] = self.fx_graph.call_function(
                  operator.getitem, (cond_node, i)
              )
  def convert_aten_Bool(self, node: torch._C.Node):
      """Treat ``aten::Bool`` as a no-op via ``_convert_as_noop``."""
      return self._convert_as_noop(node)
  def convert_prim_Enter(self, node: torch._C.Node):
      """Handle ``prim::Enter`` as a no-op.

      export generally treats prim::Enter as a no-op. The only context manager
      export supports is aten::enable_grad, which TorchScript does not
      support yet.
      TODO: support aten::enable_grad in both TorchScript and Converter.
      """
      return None
  def convert_prim_Exit(self, node: torch._C.Node):
      """Handle ``prim::Exit`` as a no-op, as export does."""
      return None
  def _convert_as_noop(self, node: torch._C.Node):
      """Map the node's output directly to its first schema-bound argument."""
      schema = get_op_overload(node)._schema
      noop_args, _ = self.get_args_kwargs(node, schema)
      self.name_to_node[node.output().debugName()] = noop_args[0]
  def convert_profiler__record_function_exit(self, node: torch._C.Node):
      """Keep ``profiler::_record_function_exit`` in the FX graph.

      The op has a side effect, so it is emitted even though its result is
      not bound to any name. Currently, ``_record_function_enter_new`` and
      ``_record_function_exit`` are discarded during
      ``retrace_as_exported_program``.
      """
      exit_args = tuple(map(self.get_fx_value_by_ir_value, node.inputs()))
      self.fx_graph.call_function(torch.ops.profiler._record_function_exit, exit_args)
  def convert_prim_tolist(self, node: torch._C.Node):
      """Convert ``prim::tolist`` into ``call_method("tolist")`` on the first input.

      ``_convert_standard_operators`` cannot handle this node, because it
      needs ``call_method`` rather than ``call_function``.
      """
      source = self.get_fx_value_by_ir_value(next(node.inputs()))
      self.name_to_node[node.output().debugName()] = self.fx_graph.call_method(
          "tolist", (source,)
      )
  def convert_prim_Uninitialized(self, node: torch._C.Node):
      """Bind ``prim::Uninitialized`` to a dummy empty-tensor constant.

      The compiler inserts this node when it can prove the value will never
      be used. It can be introduced by exceptions, breaks, continues, and
      returns.
      """
      output_name = node.output().debugName()
      self.name_to_constant[output_name] = torch.Tensor()
  def _convert_standard_operators(self, node: torch._C.Node):
      """Lower a node whose kind is in ``kind_to_standard_operators``.

      The mapped callable is applied positionally to all converted inputs.
      """
      operator_fn = kind_to_standard_operators[node.kind()]
      operands = tuple(map(self.get_fx_value_by_ir_value, node.inputs()))
      self.name_to_node[node.output().debugName()] = self.fx_graph.call_function(
          operator_fn, operands
      )
  def convert_node(self, node: torch._C.Node):
      """Dispatch ``node`` to the handler named after its namespace and operator.

      The handler name is built with ``ir_name_to_func_name``. When no such
      method exists, ``convert_call_function_op`` is used.

      Raises:
          RuntimeError: wrapping any exception raised by the handler.
      """
      node_kind = node.kind()
      handler_func = getattr(
          self, ir_name_to_func_name(node_kind), self.convert_call_function_op
      )
      # str(node) goes through the C++ IR printer and includes sub-block IR;
      # only its first line is logged.
      first_line = str(node).split("\n", 1)[0]
      log.debug("[%s] converts [%s]", handler_func.__name__, first_line)
      try:
          handler_func(node)
      except Exception as e:
          raise RuntimeError(f"TS2EPConverter failed for node {node_kind}") from e
  def convert_graph_outputs(self):
      """Emit the FX ``output`` node and record one ``OutputSpec`` per output.

      The outputs are, in order:

      1. the TorchScript graph outputs;
      2. the names in ``name_update_from_subblock_to_parent``.

      Names bound in ``name_to_node`` are recorded with a ``TensorArgument``.
      Names bound in ``name_to_constant`` are recorded with a
      ``ConstantArgument``.

      Raises:
          ValueError: if an output name is bound in neither mapping.
      """
      args = []
      outp_name_list = [outp.debugName() for outp in self.ts_graph.outputs()] + list(
          self.name_update_from_subblock_to_parent
      )
      for output_name in outp_name_list:
          if output_name in self.name_to_node:
              fx_node = self.name_to_node[output_name]
              # TODO: Revisit this later after HigherOrderOp design changes.
              # Currently, we cannot directly return input as output.
              if (
                  not self.is_top_level_graph()
                  and isinstance(fx_node, torch.fx.Node)
                  and fx_node.op == "placeholder"
              ):
                  fx_node = self.fx_graph.call_function(torch.clone, (fx_node,))
              args.append(fx_node)
              self.output_specs.append(
                  OutputSpec(
                      OutputKind.USER_OUTPUT,
                      arg=TensorArgument(name=output_name),
                      target=output_name,
                  )
              )
          elif output_name in self.name_to_constant:
              value = self.name_to_constant[output_name]
              args.append(value)
              self.output_specs.append(
                  OutputSpec(
                      OutputKind.USER_OUTPUT,
                      arg=ConstantArgument(name=output_name, value=value),
                      target=output_name,
                  )
              )
          else:
              raise ValueError(f"Output {output_name} not found")

      if len(args) == 1:
          # Get rid of an extra list wrapped around a single final output.
          self.fx_graph.output(args[0])
      else:
          # Zero outputs (a sub-block of prim::If can have none) yield an
          # empty list. Several outputs (e.g. prim::Loop / prim::If bodies)
          # are returned as a list.
          self.fx_graph.output(args)
  class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
      """
      Run TS2FXGraphConverter in an explain mode.

      Every failed operator conversion is recorded in
      ``unsupported_node_list`` instead of aborting the run. To keep converting
      after a failure, ``name_to_node`` is replaced by a mock mapping. That
      mapping claims to contain every key and returns a dummy node for names
      that were never converted.
      """

      class _DictMock(dict):
          """dict reporting every key as present, with ``mock_value`` for absent keys."""

          def __init__(self, dict_data, mock_value):
              super().__init__(dict_data)
              self.mock_value = mock_value

          def __missing__(self, key):
              # dict.__getitem__ calls this for keys that are not stored.
              return self.mock_value

          def __contains__(self, key):
              return True

      def __init__(
          self,
          ts_graph: Union[torch._C.Graph, torch._C.Block],
          name_to_param: dict[str, torch.Tensor],
          name_to_buffer: dict[str, torch.Tensor],
          blocks_to_lifted_attrs: dict[torch._C.Block, set[str]],
          name_to_non_tensor_attribute: dict[str, Any],
          name_to_constant: dict[str, Any],
          name_to_attribute_fqn: dict[str, str],
      ):
          super().__init__(
              ts_graph,
              name_to_param,
              name_to_buffer,
              blocks_to_lifted_attrs,
              name_to_non_tensor_attribute,
              name_to_constant,
              name_to_attribute_fqn,
          )

          # Nodes whose conversion raised.
          self.unsupported_node_list: list[torch._C.Node] = []

          # Dummy node handed out for any name that was never converted.
          placeholder_node = torch.fx.Node(
              None,  # type: ignore[arg-type]
              "mock",
              "call_function",
              lambda: None,
              (),
              {},
          )
          self.name_to_node = ExplainTS2FXGraphConverter._DictMock(
              self.name_to_node, placeholder_node
          )

      def explain(self):
          """Convert graph inputs, every node, and graph outputs, recording failures."""
          self.convert_graph_inputs()
          for ts_node in self.ts_graph.nodes():
              self.convert_node(ts_node)
          self.convert_graph_outputs()

      def convert_node(self, node):
          """Convert ``node``; if conversion fails, record it instead of raising."""
          try:
              super().convert_node(node)
          except Exception:
              self.unsupported_node_list.append(node)
  @contextmanager
  def disable_logging(log):
      """Set ``log.disabled`` to True for the duration of the block.

      The previous value of the flag is restored on exit, even if the block
      raises. The returned object can also be used as a decorator, as in
      ``@disable_logging(log)``.
      """
      previously_disabled = log.disabled
      log.disabled = True
      try:
          yield
      finally:
          log.disabled = previously_disabled
  1185. class TS2EPConverter:
  1186. # TorchScript model to ExportedProgram converter
  def __init__(
      self,
      ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction],
      sample_args: tuple[Any, ...],
      sample_kwargs: Optional[dict[str, Any]] = None,
  ):
      """Prepare a TorchScript model or function for conversion.

      The constructor does three things:

      1. builds the JIT graph;
      2. splits the model's ``state_dict`` into parameters and buffers;
      3. lifts attributes read via ``prim::GetAttr`` (see ``lift_get_attr``).

      Args:
          ts_model: The scripted module or function to convert.
          sample_args: Positional example inputs.
          sample_kwargs: Keyword example inputs; stored on the instance.
      """
      self.ts_model = ts_model
      self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)

      self.sample_args = sample_args
      self.sample_kwargs = sample_kwargs

      self.name_to_param: dict[str, torch.Tensor] = {}
      self.name_to_buffer: dict[str, torch.Tensor] = {}

      if not isinstance(self.ts_model, torch._C.ScriptFunction):
          # Classify state_dict entries by name, not by value. A value check
          # such as (tensor == param).all() has two failure modes:
          # * it misclassifies buffers that happen to equal a same-shaped
          #   parameter (e.g. a fresh BatchNorm's running_mean == bias == 0
          #   and running_var == weight == 1);
          # * it misses parameters that contain NaN.
          # remove_duplicate=False keeps every alias of a shared parameter,
          # just as state_dict lists every alias.
          param_names = {
              name
              for name, _ in self.ts_model.named_parameters(  # type: ignore[union-attr]
                  remove_duplicate=False
              )
          }
          for k, tensor in self.ts_model.state_dict().items():  # type: ignore[union-attr]
              if k in param_names:
                  self.name_to_param[k] = tensor
              else:
                  self.name_to_buffer[k] = tensor

      self.name_to_non_tensor_attributes: dict[str, Any] = {}
      self.name_to_constant: dict[str, Any] = {}
      self.lift_get_attr()
  1218. def convert(self) -> ExportedProgram:
  1219. log.info(
  1220. """
  1221. TS2EPConverter logging starts from here.
  1222. INFO: (TORCH_LOGS="export" <cmd>)
  1223. * Log TorchScript IR.
  1224. DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
  1225. * Log conversion IR by IR in a format of [<conversion handler name>] converts [<IR>].
  1226. """
  1227. )
  1228. log.info("TorchScript graph\n\n%s\n", self.ts_graph)
  1229. blocks_to_lifted_attrs, name_to_attribute_fqn = get_block_to_lifted_attrs(
  1230. self.ts_graph
  1231. )
  1232. graph_converter = TS2FXGraphConverter(
  1233. self.ts_graph,
  1234. self.name_to_param,
  1235. self.name_to_buffer,
  1236. blocks_to_lifted_attrs,
  1237. self.name_to_non_tensor_attributes,
  1238. self.name_to_constant,
  1239. name_to_attribute_fqn,
  1240. )
  1241. gm = graph_converter.convert()
  1242. # Post-processing step to deal with quantized operators.
  1243. replace_quantized_ops_with_standard_ops(gm)
  1244. log.info("GraphModule: %s", gm.print_readable(print_output=False))
  1245. ep = self.retrace_as_exported_program(
  1246. gm,
  1247. graph_converter.name_to_constant,
  1248. )
  1249. log.info("%s", ep)
  1250. # Post-processing step to ensure ExportedProgram has the same state_dict as
  1251. # the original TorchScript model. Throw warnings for additionally populated
  1252. # state_dict entries.
  1253. if not isinstance(self.ts_model, torch._C.ScriptFunction):
  1254. for k, tensor in self.ts_model.state_dict().items(): # type: ignore[union-attr]
  1255. if k not in ep.state_dict:
  1256. warnings.warn(
  1257. f"Manually populate {k} into state_dict ExportedProgram, but it is never used by the ExportedProgram.",
  1258. stacklevel=2,
  1259. )
  1260. ep.state_dict[k] = tensor
  1261. return ep
  @disable_logging(log)
  def explain(self, print_output=True):
      """Report which TorchScript nodes cannot be converted.

      ``ExplainTS2FXGraphConverter`` runs with logging disabled. The report
      lists every node whose conversion failed, or reads "Success!" when
      there are none.

      Args:
          print_output: If true, also print the report.

      Returns:
          The report string.
      """
      blocks_to_lifted_attrs, name_to_attribute_fqn = get_block_to_lifted_attrs(
          self.ts_graph
      )
      graph_converter = ExplainTS2FXGraphConverter(
          self.ts_graph,
          self.name_to_param,
          self.name_to_buffer,
          blocks_to_lifted_attrs,
          self.name_to_non_tensor_attributes,
          self.name_to_constant,
          name_to_attribute_fqn,
      )
      graph_converter.explain()

      unsupported = graph_converter.unsupported_node_list
      if unsupported:
          pieces = ["Unsupported nodes are found in the following list:"]
          for i, n in enumerate(unsupported):
              # Keep only the first line of the printed IR (drops sub-blocks).
              node_str = str(n).split("\n", 1)[0]
              pieces.append(f"\n\n {i}. {n.kind()} [{node_str}]")
          explain_str = "".join(pieces)
      else:
          explain_str = "Success!"

      if print_output:
          print(explain_str)
      return explain_str
  def retrace_as_exported_program(
      self,
      gm: torch.fx.GraphModule,
      name_to_constant: dict[str, Any],
  ):
      """Re-export ``gm`` with ``torch.export`` and repair its constant bookkeeping.

      Every tensor sample argument gets ``Dim.AUTO`` for each dimension.
      Tensor constants were emitted as GetAttr during conversion, so
      retracing treats them as buffers. After tracing, this is undone:

      * tensor and ScriptObject constants are registered in ``ep._constants``;
      * every name in ``name_to_constant`` is dropped from the state_dict;
      * buffer input specs that target a constant are re-marked as
        ``CONSTANT_TENSOR``.

      Raises:
          AssertionError: if a buffer spec targets a non-tensor constant.
      """

      def _auto_dims(path, x):
          return [Dim.AUTO] * x.dim() if isinstance(x, torch.Tensor) else None

      dynamic_shapes = _tree_map_with_path(_auto_dims, self.sample_args)

      # TODO: adjust input orders to match GraphSignature convention
      ep = torch.export._trace._export(
          gm,
          self.sample_args,
          dynamic_shapes=dynamic_shapes,
          strict=False,
          pre_dispatch=True,
      )

      lifted = {}
      for k, v in name_to_constant.items():
          if isinstance(v, (torch.Tensor, torch.ScriptObject)):
              lifted[k] = v
      ep._constants.update(lifted)
      for k in name_to_constant:
          ep.state_dict.pop(k, None)

      for spec in ep.graph_signature.input_specs:
          # Mark as constant tensors for erroneously traced buffers.
          if spec.kind != InputKind.BUFFER or spec.target not in name_to_constant:
              continue
          constant = name_to_constant[spec.target]
          if not isinstance(constant, torch.Tensor):
              raise AssertionError(
                  f"{type(constant)} has been erroneously marked as buffer"
              )
          spec.kind = InputKind.CONSTANT_TENSOR
          spec.persistent = None

      ep.verifier().check(ep)
      return ep
  def lift_get_attr(self):
      """Walk every ``prim::GetAttr`` in the graph, recursing into sub-blocks, and lift its value.

      Each attribute value is handled by type:

      1. Tensor attributes (e.g. ``self.data = torch.tensor([2, 3])``) that
         are not already buffers are lifted to buffers. export errors on
         tensor constants and asks users to register them as buffers. That is
         hard to do by hand for TorchScript models (e.g. source code may be
         missing), so it is done automatically here.
      2. ScriptObjects become constants. They are later converted to getattr
         in the FX graph.
      3. Any other value is recorded in ``name_to_non_tensor_attributes``.

      This must live in TS2EPConverter rather than TS2FXGraphConverter,
      because it reads attributes from ``self.ts_model``, which the latter
      cannot access. ``name_to_param`` and ``name_to_buffer`` are collected
      here for the same reason.
      """
      name_to_attribute_fqn: dict[str, str] = {}

      def get_attr(fqn: str):
          value = self.ts_model
          for part in fqn.split("."):
              value = getattr(value, part)
          return value

      def get_fqn(node: torch._C.Node):
          attr_name = node.s("name")
          root = name_to_attribute_fqn[node.input().debugName()]
          return f"{root}.{attr_name}" if root else attr_name

      def _dfs_get_attr(block):
          for node in block.nodes():
              kind = node.kind()
              if kind == "prim::CreateObject":
                  name_to_attribute_fqn[node.output().debugName()] = ""
              elif kind == "prim::GetAttr":
                  attr_fqn = get_fqn(node)
                  value = get_attr(attr_fqn)
                  name_to_attribute_fqn[node.output().debugName()] = attr_fqn
                  if isinstance(value, torch.Tensor):
                      # Lift tensor constants to be a buffer (unless already one).
                      self.name_to_buffer.setdefault(attr_fqn, value)
                  elif isinstance(value, torch.ScriptObject):
                      self.name_to_constant.setdefault(attr_fqn, value)
                  else:
                      self.name_to_non_tensor_attributes[attr_fqn] = value

              for subblock in node.blocks():
                  _dfs_get_attr(subblock)

      _dfs_get_attr(self.ts_graph)