verifier.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. import math
  4. import operator
  5. from collections.abc import Iterable
  6. from typing import Any, final, TYPE_CHECKING
  7. import torch
  8. from torch._library.opaque_object import is_opaque_type
  9. from torch._ops import HigherOrderOperator, OpOverload
  10. from torch._subclasses.fake_tensor import FakeTensor
  11. from torch.export.graph_signature import (
  12. CustomObjArgument,
  13. InputKind,
  14. SymBoolArgument,
  15. SymFloatArgument,
  16. SymIntArgument,
  17. TensorArgument,
  18. TokenArgument,
  19. )
  20. from torch.fx import GraphModule
  21. if TYPE_CHECKING:
  22. from torch.export.exported_program import ExportedProgram
  23. class SpecViolationError(Exception):
  24. pass
  25. def is_functional(op: OpOverload) -> bool:
  26. return not op._schema.is_mutable
  27. def _check_has_fake_tensor(node: torch.fx.Node) -> None:
  28. # TODO(angelayi): remove this in favor of _check_val
  29. return _check_val(node)
  30. def _check_val(node: torch.fx.Node) -> None:
  31. from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt
  32. def _check_correct_val(val):
  33. if val is None:
  34. return True
  35. elif isinstance(val, (int, bool, str, float)):
  36. return True
  37. elif isinstance(
  38. val, (torch.memory_format, torch.dtype, torch.device, torch.layout)
  39. ):
  40. return True
  41. elif isinstance(
  42. val, (FakeTensor, torch.Tensor)
  43. ): # TODO(zhxchen17) Remove Tensor.
  44. return True
  45. elif isinstance(val, (SymInt, SymFloat, SymBool)):
  46. return True
  47. elif isinstance(val, CustomObjArgument):
  48. return True
  49. elif isinstance(val, Iterable):
  50. return all(_check_correct_val(x) for x in val)
  51. elif is_opaque_type(type(val)):
  52. return True
  53. return False
  54. def _no_returns(op):
  55. if not isinstance(op, OpOverload):
  56. return False
  57. return len(op._schema.returns) == 0
  58. if "val" not in node.meta:
  59. if node.op == "call_function" and _no_returns(node.target):
  60. return
  61. raise SpecViolationError(f"Node.meta {node.name} is missing val field.")
  62. val = node.meta["val"]
  63. if not _check_correct_val(val):
  64. raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}")
  65. def _check_torch_fn(node: torch.fx.Node) -> None:
  66. torch_fn = node.meta.get("torch_fn")
  67. if torch_fn is None:
  68. raise SpecViolationError(
  69. f"Unable to find torch_fn metadata for node {node.name}"
  70. )
  71. if (
  72. not isinstance(torch_fn, tuple)
  73. and isinstance(torch_fn[0], str)
  74. and isinstance(torch_fn[1], str)
  75. ):
  76. raise SpecViolationError(
  77. f"Node.meta {node.name} has invalid torch_fn field {torch_fn}"
  78. )
  79. class _VerifierMeta(type):
  80. _registry: dict[str, type["Verifier"]] = {}
  81. def __new__(metacls, name, bases, attrs):
  82. if bases:
  83. if "check" in attrs or "_check_graph_module" in attrs:
  84. raise SyntaxError("Overriding method check is not allowed.")
  85. if "dialect" not in attrs or attrs["dialect"] == "ATEN":
  86. raise AssertionError(
  87. f"subclass must define dialect != 'ATEN', got {attrs.get('dialect')}"
  88. )
  89. else:
  90. if "check" not in attrs:
  91. raise AssertionError("base class must define 'check' method")
  92. if "_check_graph_module" not in attrs:
  93. raise AssertionError(
  94. "base class must define '_check_graph_module' method"
  95. )
  96. if attrs["dialect"] != "ATEN":
  97. raise AssertionError(
  98. f"base class dialect must be 'ATEN', got {attrs['dialect']}"
  99. )
  100. if not isinstance(attrs["dialect"], str):
  101. raise AssertionError(f"dialect must be str, got {type(attrs['dialect'])}")
  102. ret = type.__new__(metacls, name, bases, attrs)
  103. metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment]
  104. return ret
  105. def getattr_recursive(obj: Any, target: str) -> Any:
  106. target_atoms = target.split(".")
  107. attr_itr = obj
  108. for i, atom in enumerate(target_atoms):
  109. if not hasattr(attr_itr, atom):
  110. raise RuntimeError(
  111. f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
  112. )
  113. attr_itr = getattr(attr_itr, atom)
  114. return attr_itr
class Verifier(metaclass=_VerifierMeta):
    """Verifier for the "ATEN" dialect and base class for dialect-specific verifiers.

    ``check`` and ``_check_graph_module`` are the entry points; the metaclass
    rejects subclasses that redefine them. Subclasses customize behavior through
    the ``allowed_*`` hooks, ``check_valid_op`` and ``check_additional``.
    """

    dialect = "ATEN"

    def allowed_builtin_ops(self) -> list:
        """Return the builtin callables accepted as ``call_function`` targets."""
        return [
            operator.getitem,
            operator.add,
            operator.mul,
            operator.sub,
            operator.truediv,
            operator.ge,
            operator.le,
            operator.gt,
            operator.lt,
            operator.eq,
            operator.ne,
            operator.floordiv,
            operator.mod,
            operator.and_,
            operator.or_,
            operator.not_,
            operator.pow,
            operator.neg,
            operator.abs,
            operator.lshift,
            operator.rshift,
            math.ceil,
            math.floor,
            math.trunc,
            round,
        ]

    def allowed_op_types(self) -> tuple[type[Any], ...]:
        """Return the operator classes accepted as ``call_function`` targets."""
        return (OpOverload, HigherOrderOperator)

    def allowed_getattr_types(self) -> tuple[type[Any], ...]:
        """Return the types a ``get_attr`` node may resolve to in the top-level module."""
        return (torch.fx.GraphModule, torch.utils._pytree.TreeSpec)

    def allowed_getattr_types_for_subgm(self) -> tuple[type[Any], ...]:
        """Return the types a ``get_attr`` node may resolve to in a nested GraphModule."""
        # A sub-GraphModule passed to a HOP may contain get_attr nodes for
        # weights, so stateful types are allowed here.
        return (
            torch.fx.GraphModule,
            torch.nn.parameter.Parameter,
            torch.Tensor,  # for buffer and constant tensor
            torch.utils._pytree.TreeSpec,
        )

    def check_valid_op(self, op):
        """Hook called for every ``call_function`` target; the base version accepts all."""
        pass

    def check_additional(self, gm: GraphModule) -> None:
        """
        Additional checks that are specific to some dialects.
        """

    @final
    def check(self, ep: "ExportedProgram") -> None:
        """Verify the graph module, module call graph and signature of ``ep``.

        Raises:
            SpecViolationError: if any of the three checks finds a violation.
        """
        self._check_graph_module(ep.graph_module)
        _verify_exported_program_module_call_graph(ep)
        _verify_exported_program_signature(ep)

    @final
    def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
        """Check every GraphModule reachable from ``gm`` node by node.

        Each GraphModule is linted; ``call_module``/``call_method`` nodes are
        rejected; ``call_function`` and ``placeholder`` nodes must carry a valid
        ``val``; ``call_function`` targets must be allowed ops; ``get_attr``
        targets must resolve to an allowed type. ``check_additional(gm)`` runs
        once after all modules have been visited.

        Raises:
            SpecViolationError: on the first violation found.
        """

        def _allowed_getattr_types(is_toplevel_gm) -> tuple[type[Any], ...]:
            # Nested GraphModules use the more permissive hook.
            if is_toplevel_gm:
                ret = self.allowed_getattr_types()
            else:
                ret = self.allowed_getattr_types_for_subgm()
            # `object` would make the isinstance check below accept anything.
            if any(t is object for t in ret):
                raise AssertionError("allowed_getattr_types must not contain 'object'")
            return ret

        def _check_valid_op(op) -> None:
            def _allowed_builtin_ops() -> list:
                ret = self.allowed_builtin_ops()
                if not all(inspect.isbuiltin(op) for op in ret):
                    raise AssertionError("allowed_builtin_ops must all be builtins")
                return ret

            def _allowed_op_types() -> tuple[type[Any], ...]:
                ret = self.allowed_op_types()
                # `object` would make the isinstance check below accept anything.
                if any(t is object for t in ret):
                    raise AssertionError("allowed_op_types must not contain 'object'")
                return ret

            # TODO Remove this allowlist.
            _allowed_torch_functions = (
                torch.autograd.grad_mode.set_grad_enabled,
                torch.sym_int,
                torch.sym_float,
                torch.sym_ite,
                torch.sym_max,
                torch.sym_min,
                torch.sym_not,
                torch.sym_sqrt,
                torch.sym_sum,
                torch.export.custom_ops._call_custom_autograd_function_in_pre_dispatch,
                # TODO (tmanlaibaatar)
                # Predispatch export is able to contain autograd ops.
                # These will be modeled as HOO later
                torch._C._set_grad_enabled,
                torch.amp.autocast_mode._enter_autocast,
                torch.amp.autocast_mode._exit_autocast,
                torch.fx.experimental.symbolic_shapes.cast_symbool_to_symint_guardless,
                torch._functorch.predispatch._add_batch_dim,
                torch._functorch.predispatch._remove_batch_dim,
                torch._functorch.predispatch._vmap_increment_nesting,
                torch._functorch.predispatch._vmap_decrement_nesting,
                torch._functorch.predispatch.lazy_load_decompositions,
            )

            # A target that is not an allowed operator type must be one of the
            # allowlisted builtins or torch functions.
            if not isinstance(op, _allowed_op_types()):
                if (
                    op not in _allowed_builtin_ops()
                    and op not in _allowed_torch_functions
                ):
                    # NOTE(review): there is no separator between the
                    # "Valid builtin ops" and "Valid torch functions" parts, so
                    # they run together in the message.
                    raise SpecViolationError(
                        f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"
                        f"Valid builtin ops: {_allowed_builtin_ops()}"
                        f"Valid torch functions: {_allowed_torch_functions}"
                    )

            if isinstance(op, OpOverload):
                # All ops functional
                # TODO (tmanlaibaatar) more proper way is needed here
                if self.dialect != "TRAINING" and not is_functional(op):
                    raise SpecViolationError(f"operator '{op}' is not functional")
            self.check_valid_op(op)

        for mod in gm.modules():
            is_toplevel_gm = mod is gm
            # Only GraphModules carry a graph to inspect.
            if not isinstance(mod, torch.fx.GraphModule):
                continue

            mod.graph.lint()
            for node in mod.graph.nodes:
                # TODO(T140410192): should have fake tensor for all dialects
                if node.op in {"call_module", "call_method"}:
                    # NOTE(review): the message says "call_module" even when the
                    # node is a call_method.
                    raise SpecViolationError(
                        f"call_module is not valid: got a class '{node.target}' ",
                    )

                elif node.op == "call_function":
                    _check_val(node)

                    _check_valid_op(node.target)

                elif node.op == "get_attr":
                    if not isinstance(node.target, str):
                        raise SpecViolationError(
                            f"Expected get_attr target to be string, but got {type(node.target)}"
                        )

                    attr = getattr_recursive(mod, node.target)
                    if isinstance(attr, torch.nn.Module):

                        def _is_type(name, ty):
                            return isinstance(getattr(attr, name, None), ty)

                        # These module classes are recognized by class name
                        # rather than by isinstance.
                        if type(attr).__name__ == "LoweredBackendModule":
                            # Accepted shape 1: backend_id == "aoti" with
                            # original_module and module_name present.
                            if (
                                _is_type("backend_id", str)
                                and hasattr(attr, "original_module")
                                and hasattr(attr, "module_name")
                                and getattr(attr, "backend_id", None) == "aoti"
                            ):
                                continue
                            # Accepted shape 2: backend_id, processed_bytes and
                            # compile_specs of the expected types, plus
                            # original_module.
                            if (
                                _is_type("backend_id", str)
                                and _is_type("processed_bytes", bytes)
                                and _is_type("compile_specs", list)
                                and hasattr(attr, "original_module")
                            ):
                                continue
                            else:
                                backend_id = getattr(attr, "backend_id", None)
                                processed_bytes = getattr(attr, "processed_bytes", None)
                                compile_specs = getattr(attr, "compile_specs", None)
                                raise SpecViolationError(
                                    f"Invalid get_attr type {type(attr)}. \n"
                                    f"LoweredBackendModule fields: "
                                    f"backend_id(str) : {type(backend_id)}, "
                                    f"processed_bytes(bytes) : {type(processed_bytes)}, "
                                    f"compile_specs(list) : {type(compile_specs)}"
                                )
                        elif type(attr).__name__ == "AOTInductorEPModule":
                            continue
                        elif type(attr).__name__ == "AOTInductorRunnerWrapper":
                            continue

                    if not isinstance(attr, _allowed_getattr_types(is_toplevel_gm)):
                        raise SpecViolationError(
                            f"Invalid get_attr type {type(attr)} on target {node.target}. \n"
                            f"Valid get_attr types: {_allowed_getattr_types(is_toplevel_gm)}"
                        )

                elif node.op == "placeholder":
                    _check_val(node)
                # TODO(zhxchen17)
                # elif node.op == "output":
                #     _check_flattened_outputs()

        self.check_additional(gm)
class TrainingIRVerifier(Verifier):
    # Verifier registered under the "TRAINING" dialect. The base class skips
    # its "operator is not functional" check when the dialect is "TRAINING".
    dialect = "TRAINING"
  296. def _verify_exported_program_module_call_graph(exported_program) -> None:
  297. module_call_graph = exported_program.module_call_graph
  298. nodes = {node.name for node in exported_program.graph.nodes}
  299. for entry in module_call_graph:
  300. if entry.signature is not None:
  301. for arg in entry.signature.inputs:
  302. if arg.name and arg.name not in nodes:
  303. raise SpecViolationError(
  304. f"Input {arg.name} does not exist in the graph."
  305. )
  306. for arg in entry.signature.outputs:
  307. if arg.name and arg.name not in nodes:
  308. raise SpecViolationError(
  309. f"Output {arg.name} does not exist in the graph."
  310. )
  311. def _verify_exported_program_signature(exported_program) -> None:
  312. # Check ExportedProgram signature matches
  313. gs = exported_program.graph_signature
  314. # Check every node in the signature exists in the graph
  315. input_node_names = [
  316. node.name for node in exported_program.graph.nodes if node.op == "placeholder"
  317. ]
  318. if len(input_node_names) != len(gs.input_specs):
  319. input_spec_names = [
  320. spec.arg.name for spec in gs.input_specs if hasattr(spec.arg, "name")
  321. ]
  322. missing_in_specs = set(input_node_names) - set(input_spec_names)
  323. missing_in_graph = set(input_spec_names) - set(input_node_names)
  324. raise SpecViolationError(
  325. f"Number of graph inputs ({len(input_node_names)}) "
  326. f"does not match number of inputs in the graph signature ({len(gs.input_specs)})\n"
  327. f"Placeholders missing input_specs: {missing_in_specs}\n"
  328. f"Input_specs missing placeholders: {missing_in_graph}"
  329. )
  330. for input_spec, node in zip(gs.input_specs, input_node_names):
  331. if isinstance(
  332. input_spec.arg,
  333. (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument),
  334. ):
  335. if input_spec.arg.name != node:
  336. raise SpecViolationError(
  337. f"Input spec name {input_spec.arg.name} does not match node name {node}"
  338. )
  339. if input_spec.kind == InputKind.USER_INPUT:
  340. continue
  341. elif input_spec.kind == InputKind.PARAMETER:
  342. if not isinstance(input_spec.arg, TensorArgument):
  343. raise SpecViolationError(
  344. f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
  345. )
  346. if input_spec.target is None:
  347. raise SpecViolationError(
  348. f"InputSpec for {input_spec.name} has no target."
  349. )
  350. param = input_spec.target
  351. if param not in exported_program.state_dict:
  352. raise SpecViolationError(f"Parameter {param} is not in the state dict.")
  353. if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
  354. raise SpecViolationError(
  355. f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter."
  356. )
  357. elif input_spec.kind == InputKind.BUFFER:
  358. if not isinstance(input_spec.arg, TensorArgument):
  359. raise SpecViolationError(
  360. f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
  361. )
  362. if input_spec.target is None:
  363. raise SpecViolationError(
  364. f"InputSpec for {input_spec.name} has no target."
  365. )
  366. buffer = input_spec.target
  367. if input_spec.persistent is None:
  368. raise SpecViolationError(
  369. f"Buffer {buffer} is missing a persistence flag"
  370. )
  371. if (
  372. input_spec.persistent is True
  373. and buffer not in exported_program.state_dict
  374. ):
  375. raise SpecViolationError(f"Buffer {buffer} is not in the state dict.")
  376. if input_spec.persistent is False and buffer in exported_program.state_dict:
  377. raise SpecViolationError(
  378. f"Non-persistent buffer {buffer} is in the state dict, it should not be."
  379. )
  380. elif input_spec.kind == InputKind.CONSTANT_TENSOR:
  381. if not isinstance(input_spec.arg, TensorArgument):
  382. raise SpecViolationError(
  383. f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
  384. )
  385. if input_spec.target is None:
  386. raise SpecViolationError(
  387. f"InputSpec for {input_spec.name} has no target."
  388. )
  389. tensor_const = input_spec.target
  390. if tensor_const not in exported_program.constants:
  391. raise SpecViolationError(
  392. f"Constant tensor {tensor_const} is not in the constants dictionary."
  393. )
  394. elif input_spec.kind == InputKind.CUSTOM_OBJ:
  395. if not isinstance(input_spec.arg, CustomObjArgument):
  396. raise SpecViolationError(
  397. f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead."
  398. )
  399. if input_spec.target is None:
  400. raise SpecViolationError(
  401. f"InputSpec for {input_spec.name} has no target."
  402. )
  403. custom_obj = input_spec.target
  404. if custom_obj not in exported_program.constants:
  405. raise SpecViolationError(
  406. f"Custom object {custom_obj} is not in the constants dictionary."
  407. )
  408. elif input_spec.kind == InputKind.TOKEN:
  409. if not isinstance(input_spec.arg, TokenArgument):
  410. raise SpecViolationError(
  411. f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
  412. )
  413. else:
  414. raise SpecViolationError(f"Unknown InputKind {input_spec.kind}.")
  415. # Check outputs
  416. output_node = list(exported_program.graph.nodes)[-1]
  417. if output_node.op != "output":
  418. raise AssertionError(f"last node must be output, got {output_node.op}")
  419. output_nodes = [
  420. arg.name if isinstance(arg, torch.fx.Node) else arg
  421. for arg in output_node.args[0]
  422. ]
  423. if len(output_nodes) != len(gs.output_specs):
  424. output_spec_names = [
  425. spec.arg.name if hasattr(spec.arg, "name") else str(spec.arg)
  426. for spec in gs.output_specs
  427. ]
  428. missing_out_specs = set(output_nodes) - set(output_spec_names)
  429. missing_out_graph = set(output_spec_names) - set(output_nodes)
  430. raise SpecViolationError(
  431. f"Number of output nodes {len(output_nodes)} is different "
  432. f"Than the number of outputs specified by the graph signature: {len(gs.output_specs)}\n"
  433. f"Nodes missing output_specs: {missing_out_specs}\n"
  434. f"Output_specs missing nodes: {missing_out_graph}"
  435. )
  436. num_tokens = len(gs.output_tokens)
  437. end = (
  438. len(gs.buffers_to_mutate)
  439. + len(gs.parameters_to_mutate)
  440. + len(gs.user_inputs_to_mutate)
  441. + num_tokens
  442. )
  443. mutate_nodes: list[str] = output_nodes[num_tokens:end]
  444. user_output_nodes = output_nodes[end : end + len(gs.user_outputs)]
  445. for mutation_node in mutate_nodes:
  446. if mutation_node in gs.buffers_to_mutate:
  447. if gs.buffers_to_mutate[mutation_node] not in gs.buffers:
  448. raise SpecViolationError(
  449. f"Buffer output {mutation_node} does not point to a buffer that exists. \n"
  450. f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
  451. f"Buffer nodes available: {gs.buffers} \n"
  452. )
  453. elif mutation_node in gs.parameters_to_mutate:
  454. if gs.parameters_to_mutate[mutation_node] not in gs.parameters:
  455. raise SpecViolationError(
  456. f"Parameter output {mutation_node} does not point to a parameter that exists. \n"
  457. f"Dict of parameters that are mutated, in order: {gs.parameters_to_mutate} \n"
  458. f"Parameter nodes available: {gs.parameters} \n"
  459. )
  460. elif mutation_node in gs.user_inputs_to_mutate:
  461. if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
  462. raise SpecViolationError(
  463. f"User input output {mutation_node} does not point to a user input that exists. \n"
  464. f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
  465. f"User input nodes available: {gs.user_inputs} \n"
  466. )
  467. else:
  468. raise SpecViolationError(
  469. f"Mutation node {mutation_node} is neither a buffer nor a user input. "
  470. f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}"
  471. )
  472. for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):
  473. if user_output_node != user_output_name:
  474. raise SpecViolationError(
  475. f"User output {user_output_node} is not in the correct "
  476. "order or is not found in the "
  477. f"exported program's user_output list: {gs.user_outputs}. "
  478. )
  479. def load_verifier(dialect: str) -> type[Verifier]:
  480. if dialect == "ATEN" or dialect == "":
  481. return _VerifierMeta._registry.get(dialect, Verifier)
  482. return _VerifierMeta._registry[dialect]