# graph_signature.py (27 KB)
# mypy: allow-untyped-defs
import dataclasses
from collections.abc import Collection, Mapping
from enum import auto, Enum
from types import MappingProxyType
from typing import TYPE_CHECKING, Union

from torch._library.fake_class_registry import FakeScriptObject
from torch._library.opaque_object import get_opaque_type_name, is_opaque_type
from torch._subclasses.fake_tensor import is_fake


if TYPE_CHECKING:
    import torch
    from torch._functorch._aot_autograd.schemas import GraphSignature
  12. __all__ = [
  13. "ConstantArgument",
  14. "CustomObjArgument",
  15. "ExportBackwardSignature",
  16. "ExportGraphSignature",
  17. "InputKind",
  18. "InputSpec",
  19. "OutputKind",
  20. "OutputSpec",
  21. "SymIntArgument",
  22. "SymFloatArgument",
  23. "SymBoolArgument",
  24. "TensorArgument",
  25. ]
  26. @dataclasses.dataclass
  27. class TensorArgument:
  28. name: str
  29. @dataclasses.dataclass
  30. class TokenArgument:
  31. name: str
  32. @dataclasses.dataclass
  33. class SymIntArgument:
  34. name: str
  35. @dataclasses.dataclass
  36. class SymFloatArgument:
  37. name: str
  38. @dataclasses.dataclass
  39. class SymBoolArgument:
  40. name: str
  41. @dataclasses.dataclass
  42. class CustomObjArgument:
  43. name: str
  44. class_fqn: str
  45. fake_val: FakeScriptObject | None = None
  46. @dataclasses.dataclass
  47. class ConstantArgument:
  48. name: str
  49. value: int | float | bool | str | None
  50. ArgumentSpec = Union[
  51. TensorArgument,
  52. SymIntArgument,
  53. SymFloatArgument,
  54. SymBoolArgument,
  55. ConstantArgument,
  56. CustomObjArgument,
  57. TokenArgument,
  58. ]
  59. class InputKind(Enum):
  60. USER_INPUT = auto()
  61. PARAMETER = auto()
  62. BUFFER = auto()
  63. CONSTANT_TENSOR = auto()
  64. CUSTOM_OBJ = auto()
  65. TOKEN = auto()
  66. @dataclasses.dataclass
  67. class InputSpec:
  68. kind: InputKind
  69. arg: ArgumentSpec
  70. target: str | None
  71. persistent: bool | None = None
  72. def __post_init__(self):
  73. if self.kind == InputKind.BUFFER:
  74. if self.persistent is None:
  75. raise AssertionError("Failed to specify persistent flag on BUFFER.")
  76. if not isinstance(
  77. self.arg,
  78. (
  79. TensorArgument,
  80. SymIntArgument,
  81. SymFloatArgument,
  82. SymBoolArgument,
  83. ConstantArgument,
  84. CustomObjArgument,
  85. TokenArgument,
  86. ),
  87. ):
  88. raise AssertionError(f"expected valid arg type, got {type(self.arg)}")
  89. def __str__(self):
  90. target = "" if self.target is None else f" target='{self.target}'"
  91. persistent = "" if self.persistent is None else f" persistent={self.persistent}"
  92. return f"{str(self.arg.name)}: {str(self.kind.name)}{target}{persistent}"
  93. class OutputKind(Enum):
  94. USER_OUTPUT = auto()
  95. LOSS_OUTPUT = auto()
  96. BUFFER_MUTATION = auto()
  97. PARAMETER_MUTATION = auto()
  98. GRADIENT_TO_PARAMETER = auto()
  99. GRADIENT_TO_USER_INPUT = auto()
  100. USER_INPUT_MUTATION = auto()
  101. TOKEN = auto()
  102. @dataclasses.dataclass
  103. class OutputSpec:
  104. kind: OutputKind
  105. arg: ArgumentSpec
  106. target: str | None
  107. def __post_init__(self):
  108. if not isinstance(
  109. self.arg,
  110. (
  111. TensorArgument,
  112. SymIntArgument,
  113. SymFloatArgument,
  114. SymBoolArgument,
  115. ConstantArgument,
  116. TokenArgument,
  117. CustomObjArgument,
  118. ),
  119. ):
  120. raise AssertionError(f"expected valid arg type, got {self.arg}")
  121. def __str__(self):
  122. target = "" if self.target is None else f" target='{self.target}'"
  123. return f"{str(self.arg.name)}: {str(self.kind.name)}{target}"
  124. @dataclasses.dataclass
  125. class ExportBackwardSignature:
  126. gradients_to_parameters: dict[str, str]
  127. gradients_to_user_inputs: dict[str, str]
  128. loss_output: str
  129. @dataclasses.dataclass
  130. class ExportGraphSignature:
  131. """
  132. :class:`ExportGraphSignature` models the input/output signature of Export Graph,
  133. which is a fx.Graph with stronger invariants guarantees.
  134. Export Graph is functional and does not access "states" like parameters
  135. or buffers within the graph via ``getattr`` nodes. Instead, :func:`export`
  136. guarantees that parameters, buffers, and constant tensors are lifted out of
  137. the graph as inputs. Similarly, any mutations to buffers are not included
  138. in the graph either, instead the updated values of mutated buffers are
  139. modeled as additional outputs of Export Graph.
  140. The ordering of all inputs and outputs are::
  141. Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
  142. Outputs = [*mutated_inputs, *flattened_user_outputs]
  143. e.g. If following module is exported::
  144. class CustomModule(nn.Module):
  145. def __init__(self) -> None:
  146. super(CustomModule, self).__init__()
  147. # Define a parameter
  148. self.my_parameter = nn.Parameter(torch.tensor(2.0))
  149. # Define two buffers
  150. self.register_buffer("my_buffer1", torch.tensor(3.0))
  151. self.register_buffer("my_buffer2", torch.tensor(4.0))
  152. def forward(self, x1, x2):
  153. # Use the parameter, buffers, and both inputs in the forward method
  154. output = (
  155. x1 + self.my_parameter
  156. ) * self.my_buffer1 + x2 * self.my_buffer2
  157. # Mutate one of the buffers (e.g., increment it by 1)
  158. self.my_buffer2.add_(1.0) # In-place addition
  159. return output
  160. mod = CustomModule()
  161. ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
  162. Resulting Graph is non-functional::
  163. graph():
  164. %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
  165. %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
  166. %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
  167. %x1 : [num_users=1] = placeholder[target=x1]
  168. %x2 : [num_users=1] = placeholder[target=x2]
  169. %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
  170. %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
  171. %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
  172. %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
  173. %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
  174. return (add_1,)
  175. Resulting ExportGraphSignature of the non-functional Graph would be::
  176. # inputs
  177. p_my_parameter: PARAMETER target='my_parameter'
  178. b_my_buffer1: BUFFER target='my_buffer1' persistent=True
  179. b_my_buffer2: BUFFER target='my_buffer2' persistent=True
  180. x1: USER_INPUT
  181. x2: USER_INPUT
  182. # outputs
  183. add_1: USER_OUTPUT
  184. To get a functional Graph, you can use :func:`run_decompositions`::
  185. mod = CustomModule()
  186. ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
  187. ep = ep.run_decompositions()
  188. Resulting Graph is functional::
  189. graph():
  190. %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
  191. %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
  192. %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
  193. %x1 : [num_users=1] = placeholder[target=x1]
  194. %x2 : [num_users=1] = placeholder[target=x2]
  195. %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
  196. %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
  197. %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
  198. %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
  199. %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
  200. return (add_2, add_1)
  201. Resulting ExportGraphSignature of the functional Graph would be::
  202. # inputs
  203. p_my_parameter: PARAMETER target='my_parameter'
  204. b_my_buffer1: BUFFER target='my_buffer1' persistent=True
  205. b_my_buffer2: BUFFER target='my_buffer2' persistent=True
  206. x1: USER_INPUT
  207. x2: USER_INPUT
  208. # outputs
  209. add_2: BUFFER_MUTATION target='my_buffer2'
  210. add_1: USER_OUTPUT
  211. """
  212. input_specs: list[InputSpec]
  213. output_specs: list[OutputSpec]
  214. # A list of parameters uniquely identified by mangled fully qualified name
  215. @property
  216. def parameters(self) -> Collection[str]:
  217. return tuple(
  218. s.target
  219. for s in self.input_specs
  220. if s.kind == InputKind.PARAMETER
  221. if isinstance(s.target, str)
  222. )
  223. # A list of buffers uniquely identified by mangled fully qualified name
  224. @property
  225. def buffers(self) -> Collection[str]:
  226. return tuple(
  227. s.target
  228. for s in self.input_specs
  229. if s.kind == InputKind.BUFFER
  230. if isinstance(s.target, str)
  231. )
  232. @property
  233. def non_persistent_buffers(self) -> Collection[str]:
  234. return tuple(
  235. s.target
  236. for s in self.input_specs
  237. if s.kind == InputKind.BUFFER
  238. if s.persistent is False
  239. if isinstance(s.target, str)
  240. )
  241. # A list of lifted constant tensors
  242. @property
  243. def lifted_tensor_constants(self) -> Collection[str]:
  244. return tuple(
  245. s.target
  246. for s in self.input_specs
  247. if s.kind == InputKind.CONSTANT_TENSOR
  248. if isinstance(s.target, str)
  249. )
  250. @property
  251. def lifted_custom_objs(self) -> Collection[str]:
  252. return tuple(
  253. s.target
  254. for s in self.input_specs
  255. if s.kind == InputKind.CUSTOM_OBJ
  256. if isinstance(s.target, str)
  257. )
  258. # Graph node names of pytree-flattened inputs of original program
  259. @property
  260. def user_inputs(self) -> Collection[int | float | bool | str | None]:
  261. user_inputs: list[int | float | bool | str | None] = []
  262. for s in self.input_specs:
  263. if s.kind != InputKind.USER_INPUT:
  264. continue
  265. if isinstance(
  266. s.arg,
  267. (
  268. TensorArgument,
  269. SymIntArgument,
  270. SymFloatArgument,
  271. SymBoolArgument,
  272. CustomObjArgument,
  273. ),
  274. ):
  275. user_inputs.append(s.arg.name)
  276. elif isinstance(s.arg, ConstantArgument):
  277. user_inputs.append(s.arg.value)
  278. else:
  279. raise RuntimeError(f"{s.arg} is not a valid user inputs")
  280. return tuple(user_inputs)
  281. # Graph node names of pytree-flattened outputs of original program
  282. # For joint-graph purposes, will include the loss output.
  283. @property
  284. def user_outputs(self) -> Collection[int | float | bool | str | None]:
  285. user_outputs: list[int | float | bool | str | None] = []
  286. for s in self.output_specs:
  287. if s.kind not in [
  288. OutputKind.USER_OUTPUT,
  289. OutputKind.LOSS_OUTPUT,
  290. ]:
  291. continue
  292. if isinstance(
  293. s.arg,
  294. (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument),
  295. ):
  296. user_outputs.append(s.arg.name)
  297. elif isinstance(s.arg, ConstantArgument):
  298. user_outputs.append(s.arg.value)
  299. elif isinstance(s.arg, CustomObjArgument):
  300. user_outputs.append(s.arg.name)
  301. else:
  302. raise RuntimeError(f"{s.arg} is not a valid user output")
  303. return tuple(user_outputs)
  304. # A dictionary mapping graph input node names to parameters. If a graph input
  305. # name is found in this dictionary, it is guaranteed to be a lifted parameter.
  306. @property
  307. def inputs_to_parameters(self) -> Mapping[str, str]:
  308. return _immutable_dict(
  309. (s.arg.name, s.target)
  310. for s in self.input_specs
  311. if s.kind == InputKind.PARAMETER
  312. and isinstance(s.arg, TensorArgument)
  313. and isinstance(s.target, str)
  314. )
  315. # A dictionary mapping graph input node names to buffers. If a graph input
  316. # name is found in this dictionary, it is guaranteed to be a lifted buffer.
  317. @property
  318. def inputs_to_buffers(self) -> Mapping[str, str]:
  319. return _immutable_dict(
  320. (s.arg.name, s.target) # type: ignore[union-attr, misc]
  321. for s in self.input_specs
  322. if s.kind == InputKind.BUFFER
  323. and isinstance(s.arg, TensorArgument)
  324. and isinstance(s.target, str)
  325. )
  326. # A dictionary mapping graph output node names to buffers that are mutated in the
  327. # original program. Buffers that are not mutated will not be found in this dictionary.
  328. @property
  329. def buffers_to_mutate(self) -> Mapping[str, str]:
  330. return _immutable_dict(
  331. (s.arg.name, s.target)
  332. for s in self.output_specs
  333. if s.kind == OutputKind.BUFFER_MUTATION
  334. and isinstance(s.arg, TensorArgument)
  335. and isinstance(s.target, str)
  336. )
  337. @property
  338. def parameters_to_mutate(self) -> Mapping[str, str]:
  339. return _immutable_dict(
  340. (s.arg.name, s.target)
  341. for s in self.output_specs
  342. if s.kind == OutputKind.PARAMETER_MUTATION
  343. and isinstance(s.arg, TensorArgument)
  344. and isinstance(s.target, str)
  345. )
  346. @property
  347. def user_inputs_to_mutate(self) -> Mapping[str, str]:
  348. return _immutable_dict(
  349. (s.arg.name, s.target)
  350. for s in self.output_specs
  351. if s.kind == OutputKind.USER_INPUT_MUTATION
  352. and isinstance(s.arg, TensorArgument)
  353. and isinstance(s.target, str)
  354. )
  355. # A dictionary mapping graph input node names to lifted tensor constants.
  356. @property
  357. def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]:
  358. return _immutable_dict(
  359. (s.arg.name, s.target)
  360. for s in self.input_specs
  361. if s.kind == InputKind.CONSTANT_TENSOR
  362. and isinstance(s.arg, TensorArgument)
  363. and isinstance(s.target, str)
  364. )
  365. @property
  366. def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]:
  367. return _immutable_dict(
  368. (s.arg.name, s.target)
  369. for s in self.input_specs
  370. if s.kind == InputKind.CUSTOM_OBJ
  371. and isinstance(s.arg, CustomObjArgument)
  372. and isinstance(s.target, str)
  373. )
  374. @property
  375. def backward_signature(self) -> ExportBackwardSignature | None:
  376. loss_output = None
  377. gradients_to_parameters: dict[str, str] = {}
  378. gradients_to_user_inputs: dict[str, str] = {}
  379. for spec in self.output_specs:
  380. if spec.kind == OutputKind.LOSS_OUTPUT:
  381. if loss_output is not None:
  382. raise AssertionError("multiple LOSS_OUTPUT specs found")
  383. if not isinstance(spec.arg, TensorArgument):
  384. raise AssertionError(
  385. f"expected TensorArgument for LOSS_OUTPUT, got {type(spec.arg)}"
  386. )
  387. loss_output = spec.arg.name
  388. elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
  389. if not isinstance(spec.target, str):
  390. raise AssertionError(
  391. f"expected str target for GRADIENT_TO_PARAMETER, got {type(spec.target)}"
  392. )
  393. if not isinstance(spec.arg, TensorArgument):
  394. raise AssertionError(
  395. f"expected TensorArgument for GRADIENT_TO_PARAMETER, got {type(spec.arg)}"
  396. )
  397. gradients_to_parameters[spec.arg.name] = spec.target
  398. elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT:
  399. if not isinstance(spec.target, str):
  400. raise AssertionError(
  401. f"expected str target for GRADIENT_TO_USER_INPUT, got {type(spec.target)}"
  402. )
  403. if not isinstance(spec.arg, TensorArgument):
  404. raise AssertionError(
  405. f"expected TensorArgument for GRADIENT_TO_USER_INPUT, got {type(spec.arg)}"
  406. )
  407. gradients_to_user_inputs[spec.arg.name] = spec.target
  408. if loss_output is None:
  409. return None
  410. return ExportBackwardSignature(
  411. loss_output=loss_output,
  412. gradients_to_parameters=gradients_to_parameters,
  413. gradients_to_user_inputs=gradients_to_user_inputs,
  414. )
  415. # Map from assertion dependency token index to assertion dep token output
  416. # name in output. The shape of output after aot_autograd will be like:
  417. # (updated_inputs, user_outputs, dep_token).
  418. @property
  419. def assertion_dep_token(self) -> Mapping[int, str] | None:
  420. return None
  421. @property
  422. def input_tokens(self) -> Collection[str]:
  423. input_tokens = []
  424. for s in self.input_specs:
  425. if s.kind == InputKind.TOKEN:
  426. if not isinstance(s.arg, TokenArgument):
  427. raise AssertionError(
  428. f"expected TokenArgument for TOKEN kind, got {type(s.arg)}"
  429. )
  430. input_tokens.append(s.arg.name)
  431. return tuple(input_tokens)
  432. @property
  433. def output_tokens(self) -> Collection[str]:
  434. output_tokens = []
  435. for s in self.output_specs:
  436. if s.kind == OutputKind.TOKEN:
  437. if not isinstance(s.arg, TokenArgument):
  438. raise AssertionError(
  439. f"expected TokenArgument for TOKEN kind, got {type(s.arg)}"
  440. )
  441. output_tokens.append(s.arg.name)
  442. return tuple(output_tokens)
  443. def __post_init__(self) -> None:
  444. assertion_dep_token = self.assertion_dep_token
  445. if assertion_dep_token is None:
  446. return
  447. if len(assertion_dep_token) != 1:
  448. raise AssertionError(
  449. f"expected exactly 1 assertion_dep_token, got {len(assertion_dep_token)}"
  450. )
  451. assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
  452. expected_index = len(self.user_outputs) + len(self.buffers_to_mutate)
  453. if expected_index != assertion_dep_token_index:
  454. raise AssertionError(
  455. f"expected assertion_dep_token_index to be {expected_index}, got {assertion_dep_token_index}"
  456. )
  457. def replace_all_uses(self, old: str, new: str):
  458. """
  459. Replace all uses of the old name with new name in the signature.
  460. """
  461. if not isinstance(old, str):
  462. raise AssertionError(f"expected old to be str, got {type(old)}")
  463. if not isinstance(new, str):
  464. raise AssertionError(f"expected new to be str, got {type(new)}")
  465. arg_types = (
  466. TensorArgument,
  467. SymIntArgument,
  468. SymFloatArgument,
  469. SymBoolArgument,
  470. CustomObjArgument,
  471. TokenArgument,
  472. )
  473. for o in self.output_specs:
  474. if isinstance(o.arg, arg_types):
  475. if o.arg.name == old:
  476. o.arg.name = new
  477. for i in self.input_specs:
  478. if isinstance(i.arg, arg_types):
  479. if i.arg.name == old:
  480. i.arg.name = new
  481. def get_replace_hook(self, replace_inputs=False):
  482. def _(old, new, user):
  483. if user.op == "output":
  484. self.replace_all_uses(old.name, new)
  485. if replace_inputs and old.op == "placeholder":
  486. self.replace_all_uses(old.name, new)
  487. return _
  488. def __str__(self):
  489. input_specs = "\n".join(str(s) for s in self.input_specs)
  490. output_specs = "\n".join(str(s) for s in self.output_specs)
  491. return f"\n# inputs\n{input_specs}\n\n# outputs\n{output_specs}\n"
  492. def _immutable_dict(items):
  493. """
  494. Creates a mapping where items cannot be added, deleted, or updated.
  495. NOTE: The immutability is shallow (like tuple is an immutable collection).
  496. """
  497. from types import MappingProxyType
  498. return MappingProxyType(dict(items))
  499. def _make_argument_spec(node, token_names) -> ArgumentSpec:
  500. from torch import ScriptObject, SymBool, SymFloat, SymInt
  501. from torch._library.fake_class_registry import FakeScriptObject
  502. if isinstance(node, (int, bool, float, type(None), str)):
  503. # For const outputs we just directly return this
  504. return ConstantArgument(name="", value=node)
  505. if "val" not in node.meta:
  506. raise AssertionError(
  507. f"{node} is not a constant or a node with a 'val' metadata field"
  508. )
  509. val = node.meta["val"]
  510. if node.name in token_names:
  511. return TokenArgument(name=node.name)
  512. elif is_fake(val):
  513. return TensorArgument(name=node.name)
  514. elif isinstance(val, SymInt):
  515. return SymIntArgument(name=node.name)
  516. elif isinstance(val, SymFloat):
  517. return SymFloatArgument(name=node.name)
  518. elif isinstance(val, SymBool):
  519. return SymBoolArgument(name=node.name)
  520. elif isinstance(val, ScriptObject):
  521. return CustomObjArgument(name=node.name, class_fqn=val._type().qualified_name()) # type: ignore[attr-defined]
  522. elif isinstance(val, FakeScriptObject):
  523. return CustomObjArgument(
  524. name=node.name, class_fqn=val.script_class_name, fake_val=val
  525. )
  526. elif is_opaque_type(type(val)):
  527. return CustomObjArgument(
  528. name=node.name, class_fqn=get_opaque_type_name(type(val)), fake_val=val
  529. )
  530. elif isinstance(val, (int, bool, str, float, type(None))):
  531. return ConstantArgument(name=node.name, value=val)
  532. else:
  533. raise AssertionError(
  534. f"Encountered an unsupported object of type {type(val)} "
  535. f"while writing the metadata for exported program"
  536. )
  537. def _convert_to_export_graph_signature(
  538. graph_signature: "GraphSignature",
  539. gm: "torch.fx.GraphModule",
  540. non_persistent_buffers: set[str],
  541. ) -> "ExportGraphSignature":
  542. from torch.utils import _pytree as pytree
  543. is_joint = graph_signature.backward_signature is not None
  544. # unpack objects
  545. user_inputs = set(graph_signature.user_inputs)
  546. inputs_to_parameters = graph_signature.inputs_to_parameters
  547. inputs_to_buffers = graph_signature.inputs_to_buffers
  548. user_outputs = set(graph_signature.user_outputs)
  549. buffer_mutations = graph_signature.buffers_to_mutate
  550. parameter_mutations = graph_signature.parameters_to_mutate
  551. user_input_mutations = graph_signature.user_inputs_to_mutate
  552. grad_params = (
  553. graph_signature.backward_signature.gradients_to_parameter # type: ignore[union-attr]
  554. if is_joint
  555. else {}
  556. )
  557. grad_user_inputs = (
  558. graph_signature.backward_signature.gradients_to_user_inputs # type: ignore[union-attr]
  559. if is_joint
  560. else {}
  561. )
  562. loss_output = (
  563. graph_signature.backward_signature.loss_output # type: ignore[union-attr]
  564. if is_joint
  565. else None
  566. )
  567. input_tokens = graph_signature.input_tokens
  568. output_tokens = graph_signature.output_tokens
  569. inputs = [
  570. _make_argument_spec(node, input_tokens)
  571. for node in gm.graph.nodes
  572. if node.op == "placeholder"
  573. ]
  574. outputs = [
  575. _make_argument_spec(node, output_tokens)
  576. for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
  577. ]
  578. def to_input_spec(inp: ArgumentSpec) -> InputSpec:
  579. if isinstance(inp, TokenArgument):
  580. return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None)
  581. if not isinstance(inp, TensorArgument):
  582. return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
  583. name = inp.name
  584. if name in user_inputs:
  585. return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
  586. elif name in inputs_to_parameters:
  587. return InputSpec(
  588. kind=InputKind.PARAMETER,
  589. arg=inp,
  590. target=inputs_to_parameters[name], # type: ignore[index]
  591. )
  592. elif name in inputs_to_buffers:
  593. return InputSpec(
  594. kind=InputKind.BUFFER,
  595. arg=inp,
  596. target=inputs_to_buffers[name], # type: ignore[index]
  597. persistent=(inputs_to_buffers[name] not in non_persistent_buffers), # type: ignore[index]
  598. )
  599. else:
  600. raise AssertionError(f"Unknown tensor input kind: {name}")
  601. def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec:
  602. if isinstance(o, TokenArgument):
  603. return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None)
  604. if not isinstance(o, TensorArgument):
  605. return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
  606. name = o.name
  607. if idx < len(buffer_mutations) + len(parameter_mutations) + len(
  608. user_input_mutations
  609. ) + len(output_tokens):
  610. if name in buffer_mutations:
  611. return OutputSpec(
  612. kind=OutputKind.BUFFER_MUTATION,
  613. arg=o,
  614. target=buffer_mutations[name], # type: ignore[index]
  615. )
  616. elif name in parameter_mutations:
  617. return OutputSpec(
  618. kind=OutputKind.PARAMETER_MUTATION,
  619. arg=o,
  620. target=parameter_mutations[name], # type: ignore[index]
  621. )
  622. elif name in user_input_mutations:
  623. return OutputSpec(
  624. kind=OutputKind.USER_INPUT_MUTATION,
  625. arg=o,
  626. target=user_input_mutations[name], # type: ignore[index]
  627. )
  628. else:
  629. raise AssertionError(f"Unknown tensor mutation kind: {name}")
  630. else:
  631. if name in user_outputs:
  632. return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
  633. elif name in grad_params:
  634. return OutputSpec(
  635. kind=OutputKind.GRADIENT_TO_PARAMETER,
  636. arg=o,
  637. target=grad_params[name],
  638. )
  639. elif name in grad_user_inputs:
  640. return OutputSpec(
  641. kind=OutputKind.GRADIENT_TO_USER_INPUT,
  642. arg=o,
  643. target=grad_user_inputs[name],
  644. )
  645. elif name == loss_output:
  646. return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)
  647. else:
  648. raise AssertionError(f"Unknown tensor output kind: {name}")
  649. input_specs = [to_input_spec(inp) for inp in inputs]
  650. output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)]
  651. return ExportGraphSignature(input_specs=input_specs, output_specs=output_specs)