_recursive.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import functools
  4. import inspect
  5. import textwrap
  6. import types
  7. import warnings
  8. import torch
  9. import torch._jit_internal as _jit_internal
  10. from torch._sources import fake_range
  11. from torch.jit._builtins import _find_builtin
  12. from torch.jit._check import AttributeTypeIsSupportedChecker
  13. from torch.jit._state import _add_script_class, _get_script_class, _python_cu
  14. from torch.jit.frontend import (
  15. get_class_properties,
  16. get_default_args,
  17. get_jit_class_def,
  18. get_jit_def,
  19. )
  20. from torch.nn import Module
  21. ScriptMethodStub = collections.namedtuple(
  22. "ScriptMethodStub", ("resolution_callback", "def_", "original_method")
  23. )
  24. PropertyStub = collections.namedtuple("PropertyStub", ("resolution_callback", "def_"))
  25. # TODO: there should be a more principled way of doing this.
  26. ignored_attributes = [
  27. "_version",
  28. "_parameters",
  29. "_buffers",
  30. "_non_persistent_buffers_set",
  31. "_backward_hooks",
  32. "_backward_pre_hooks",
  33. "_forward_hooks",
  34. "_forward_hooks_with_kwargs",
  35. "_forward_pre_hooks",
  36. "_forward_pre_hooks_with_kwargs",
  37. "_forward_hooks_always_called",
  38. "_state_dict_hooks",
  39. "_state_dict_pre_hooks",
  40. "_load_state_dict_pre_hooks",
  41. "_load_state_dict_post_hooks",
  42. "_modules",
  43. "_initializing",
  44. "dump_patches",
  45. ]
  46. def _compile_and_register_class(obj, rcb, qualified_name):
  47. script_class = _get_script_class(obj)
  48. if not script_class:
  49. ast = get_jit_class_def(obj, obj.__name__)
  50. defaults = torch.jit.frontend.get_default_args_for_class(obj)
  51. script_class = torch._C._jit_script_class_compile(
  52. qualified_name, ast, defaults, rcb
  53. )
  54. _add_script_class(obj, script_class)
  55. return script_class
  56. def make_stub(func, name):
  57. rcb = _jit_internal.createResolutionCallbackFromClosure(func)
  58. ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
  59. return ScriptMethodStub(rcb, ast, func)
  60. def make_stub_from_method(nn_module, method_name):
  61. func = getattr(nn_module, method_name)
  62. if isinstance(func, ScriptMethodStub):
  63. return func
  64. # Make sure the name present in the resulting AST will match the name
  65. # requested here. The only time they don't match is if you do something
  66. # like:
  67. # def _forward(self):
  68. # pass
  69. # forward = _forward
  70. # In this case, the actual function object will have the name `_forward`,
  71. # even though we requested a stub for `forward`.
  72. return make_stub(func, method_name)
  73. def make_stubs_from_exported_methods(mod):
  74. stubs = []
  75. for name in dir(mod):
  76. item = getattr(mod, name, None)
  77. if (
  78. _jit_internal.get_torchscript_modifier(item)
  79. is _jit_internal.FunctionModifiers.EXPORT
  80. ):
  81. stubs.append(make_stub_from_method(mod, name))
  82. return stubs
  83. def jit_ignored_properties(module):
  84. user_annotated_ignored_attributes = getattr(
  85. module, "__jit_ignored_attributes__", []
  86. )
  87. def get_properties_names(module):
  88. return {k for k, v in vars(module).items() if isinstance(v, property)}
  89. properties = get_properties_names(type(module))
  90. user_annoted_ignored_properties = set()
  91. for ignored_attr in user_annotated_ignored_attributes:
  92. if ignored_attr in properties:
  93. user_annoted_ignored_properties.add(ignored_attr)
  94. return user_annoted_ignored_properties
  95. # base types that can be constants
  96. # in addition, tuples and lists of these base types are also considered constants
  97. # If you edit this list, then you also need to edit the handlers in
  98. # ConstantValue in jit/script/init.cpp
  99. _constant_types = (
  100. bool,
  101. float,
  102. int,
  103. str,
  104. type(None),
  105. torch.device,
  106. torch.layout,
  107. torch.dtype,
  108. torch.qscheme,
  109. )
  110. def _get_valid_constant(attr, v, owner_type):
  111. if isinstance(v, _constant_types):
  112. return v
  113. elif isinstance(v, (tuple, list)):
  114. return tuple(_get_valid_constant(attr, x, owner_type) for x in v)
  115. constants = ", ".join(torch.typename(typ) for typ in _constant_types)
  116. raise TypeError(
  117. textwrap.dedent(
  118. f"""
  119. '{torch.typename(type(v))}' object in attribute '{owner_type}.{attr}' is not a valid constant.
  120. Valid constants are:
  121. 1. a nn.ModuleList
  122. 2. a value of type {{{constants}}}
  123. 3. a list or tuple of (2)
  124. """
  125. )
  126. )
  127. class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
  128. pass
  129. def get_annotations(obj):
  130. # In Python-3.10+ it is recommended to use inspect.get_annotations
  131. # See https://docs.python.org/3.10/howto/annotations.html
  132. # But also, in 3.10 annotations from base class are not inherited
  133. # by unannotated derived one, so they must be manually extracted
  134. annotations = inspect.get_annotations(obj)
  135. if annotations:
  136. return annotations
  137. def get_cls_annotations(cls):
  138. cls_annotations = inspect.get_annotations(cls)
  139. if cls_annotations:
  140. return cls_annotations
  141. for base in cls.__bases__:
  142. cls_annotations = get_cls_annotations(base)
  143. if cls_annotations:
  144. return cls_annotations
  145. return {}
  146. cls = obj if isinstance(obj, type) else type(obj)
  147. return get_cls_annotations(cls)
  148. def infer_concrete_type_builder(nn_module, share_types=True):
  149. """
  150. Build a ConcreteModuleTypeBuilder from an nn.Module.
  151. This ConcreteModuleType doesn't have a JIT type associated with it yet, it
  152. must be filled in by the caller.
  153. """
  154. concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
  155. if isinstance(nn_module, (torch.nn.ModuleDict)):
  156. concrete_type_builder.set_module_dict()
  157. if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
  158. concrete_type_builder.set_module_list()
  159. if isinstance(nn_module, (torch.nn.ParameterList)):
  160. concrete_type_builder.set_parameter_list()
  161. if isinstance(nn_module, (torch.nn.ParameterDict)):
  162. concrete_type_builder.set_parameter_dict()
  163. class_annotations = get_annotations(nn_module)
  164. if isinstance(nn_module, (torch.ao.quantization.QuantWrapper)):
  165. class_annotations = {}
  166. # Get user-annotated ignored attributes.
  167. user_annotated_ignored_attributes = getattr(
  168. nn_module, "__jit_ignored_attributes__", []
  169. )
  170. concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes)
  171. ignored_properties = jit_ignored_properties(nn_module)
  172. # try to infer the type from type annotation or from the object itself
  173. def infer_type(name, item):
  174. # The forward function from Module is special; never use this annotations; we
  175. # need to infer type directly using JIT. I originally wanted to write
  176. # this test as isinstance(class_annotations[name], Callable) but
  177. # isinstance on typing things doesn't seem to work: isinstance(list, Callable)
  178. # is also true!
  179. inferred = False
  180. try:
  181. if (
  182. name in class_annotations
  183. and class_annotations[name]
  184. != torch.nn.Module.__annotations__["forward"]
  185. ):
  186. ann_to_type = torch.jit.annotations.ann_to_type(
  187. class_annotations[name], fake_range()
  188. )
  189. attr_type = torch._C.InferredType(ann_to_type)
  190. elif isinstance(item, torch.jit.Attribute):
  191. ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
  192. attr_type = torch._C.InferredType(ann_to_type)
  193. else:
  194. attr_type = torch._C._jit_try_infer_type(item)
  195. inferred = True
  196. except RuntimeError as re:
  197. raise RuntimeError(f"Error inferring type for {name}: {item}: {re}") from re
  198. return attr_type, inferred
  199. added_names = set()
  200. for name, item in nn_module._parameters.items():
  201. if name in user_annotated_ignored_attributes:
  202. continue
  203. if not (item is None or isinstance(item, torch.Tensor)):
  204. raise AssertionError(
  205. f"Expected parameter '{name}' to be None or Tensor, got {type(item)}"
  206. )
  207. attr_type, _ = infer_type(name, item)
  208. # We currently have the invariant in various places in our code
  209. # that parameters must be Tensors. However, the nn.Module API also
  210. # allows NoneType parameters. These parameters are not returned as
  211. # part of `parameters()` and its variants, but are available
  212. # through direct attribute access.
  213. concrete_type_builder.add_attribute(name, attr_type.type(), True, False)
  214. added_names.add(name)
  215. for name, item in nn_module._buffers.items():
  216. if name in user_annotated_ignored_attributes:
  217. continue
  218. if not (item is None or isinstance(item, torch.Tensor)):
  219. raise AssertionError(
  220. f"Expected buffer '{name}' to be None or Tensor, got {type(item)}"
  221. )
  222. attr_type, _ = infer_type(name, item)
  223. concrete_type_builder.add_attribute(name, attr_type.type(), False, True)
  224. added_names.add(name)
  225. for name, item in nn_module._modules.items():
  226. if name in user_annotated_ignored_attributes:
  227. continue
  228. attr_type, _ = infer_type(name, item)
  229. if item is None:
  230. # Modules can be None. We don't have direct support for optional
  231. # Modules, so the register it as an NoneType attribute instead.
  232. concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
  233. continue
  234. if attr_type.success():
  235. if not attr_type.type().is_interface_type():
  236. raise AssertionError(
  237. f"Expected inferred type to be interface type for '{name}'"
  238. )
  239. # if the type can be inferred, it should be a module interface type
  240. sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(
  241. attr_type.type()
  242. )
  243. else:
  244. # otherwise we get the concrete module type for item and add it to concrete_type
  245. sub_concrete_type = get_module_concrete_type(item, share_types)
  246. concrete_type_builder.add_module(name, sub_concrete_type)
  247. added_names.add(name)
  248. # populate constants_set
  249. constants_set = set(getattr(nn_module, "__constants__", ()))
  250. # Constants annotated via `Final[T]` rather than being added to `__constants__`
  251. for name, ann in class_annotations.items():
  252. if torch._jit_internal.is_final(ann):
  253. constants_set.add(name)
  254. for name in constants_set:
  255. if name in added_names:
  256. # TODO: We should really error in this case, but its bc-breaking so
  257. # we need to warn for at least one release
  258. if name in nn_module._modules:
  259. hint = "submodule"
  260. elif name in nn_module._buffers:
  261. hint = "buffer"
  262. elif name in nn_module._parameters:
  263. hint = "parameter"
  264. else:
  265. raise AssertionError(
  266. "added_names must be submodule, parameter, or buffer"
  267. )
  268. warnings.warn(
  269. f"'{name}' was found in ScriptModule constants, "
  270. f" but it is a non-constant {hint}. Consider removing it.",
  271. stacklevel=2,
  272. )
  273. continue
  274. if not hasattr(nn_module, name):
  275. # TODO: We should really error in this case, but its bc-breaking so
  276. # we need to warn for at least one release
  277. warnings.warn(
  278. f"'{name}' was found in ScriptModule constants, "
  279. "but was not actually set in __init__. "
  280. "Consider removing it.",
  281. stacklevel=2,
  282. )
  283. continue
  284. value = getattr(nn_module, name)
  285. concrete_type_builder.add_constant(
  286. name, _get_valid_constant(name, value, type(nn_module).__name__)
  287. )
  288. added_names.add(name)
  289. # populate overloads
  290. overloads = getattr(nn_module, "__overloads__", {})
  291. # update with any annotated overloads
  292. overloads.update(
  293. get_overload_name_mapping(
  294. get_overload_annotations(nn_module, ignored_properties)
  295. )
  296. )
  297. for name, overloaded_names in overloads.items():
  298. concrete_type_builder.add_overload(name, overloaded_names)
  299. for name, value in nn_module.__dict__.items():
  300. if name in ignored_attributes or name.startswith("__"):
  301. # Python objects have lots of random attributes attached to them;
  302. # PyTorch adds a few more. Prevent these from getting compiled.
  303. continue
  304. if name in user_annotated_ignored_attributes:
  305. continue
  306. if name in added_names:
  307. # Don't re-add anything we already added
  308. continue
  309. isoverloadpacket = isinstance(value, torch._ops.OpOverloadPacket)
  310. if isoverloadpacket:
  311. value = value.op
  312. # Handle Python function attributes
  313. if inspect.isfunction(value):
  314. try:
  315. scripted_fn = torch.jit.script(value)
  316. concrete_type_builder.add_function_attribute(
  317. name, torch._C._jit_try_infer_type(scripted_fn).type(), value
  318. )
  319. except Exception as e:
  320. # If we fail to script the function, it isn't a hard error.
  321. # Instead, we will add it to the list of attributes we failed
  322. # to convert, with the compilation error.
  323. hint = (
  324. "(This function exists as an attribute on the Python module, "
  325. "but we failed to compile it to a TorchScript function. "
  326. f"\nThe error stack is reproduced here:\n{e})"
  327. )
  328. concrete_type_builder.add_failed_attribute(name, hint)
  329. continue
  330. # Handle calls to builtin functions (either bespoke builtins from torch.jit._builtins or
  331. # a call to an aten function like torch.add)
  332. builtin_symbol_name = _find_builtin(value)
  333. if builtin_symbol_name:
  334. concrete_type_builder.add_builtin_function(name, builtin_symbol_name)
  335. continue
  336. # Handle Script function attributes
  337. if isinstance(value, torch.jit.ScriptFunction):
  338. concrete_type_builder.add_function_attribute(
  339. name, torch._C._jit_try_infer_type(value).type(), value
  340. )
  341. continue
  342. # If we got here, this is a regular "data" attribute, add it to the concrete type
  343. attr_type, inferred = infer_type(name, value)
  344. if attr_type.success():
  345. concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
  346. else:
  347. # TODO: could add more detail here. For example, what the user should do
  348. # when the pytype is `list` or `NoneType`
  349. inferred_msg = (
  350. "Its type was inferred; try adding a type annotation for the attribute."
  351. if inferred
  352. else ""
  353. )
  354. additional_info = f"{attr_type.reason()}. {inferred_msg}"
  355. hint = (
  356. "(This attribute exists on the Python module, "
  357. f"but we failed to convert Python type: '{torch.typename(type(value))}' "
  358. f"to a TorchScript type. {additional_info})"
  359. )
  360. concrete_type_builder.add_failed_attribute(name, hint)
  361. # add hooks to concrete type
  362. for hook in nn_module._forward_hooks.values():
  363. concrete_type_builder.add_forward_hook(hook)
  364. for pre_hook in nn_module._forward_pre_hooks.values():
  365. concrete_type_builder.add_forward_pre_hook(pre_hook)
  366. return concrete_type_builder
  367. class ConcreteTypeStore:
  368. type_store: dict[type[Module], list[torch._C.ConcreteModuleType]]
  369. methods_compiled: set[torch._C.ConcreteModuleType]
  370. def __init__(self) -> None:
  371. # Python module type => List[ConcreteModuleType)]
  372. self.type_store = {}
  373. # ConcreteTypes that have had their methods already compiled
  374. self.methods_compiled = set()
  375. def get_or_create_concrete_type(self, nn_module):
  376. """Infer a ConcreteType from this `nn.Module` instance. Underlying JIT types are reused if possible."""
  377. concrete_type_builder = infer_concrete_type_builder(nn_module)
  378. nn_module_type = type(nn_module)
  379. if nn_module_type not in self.type_store:
  380. self.type_store[nn_module_type] = []
  381. # Search the type store for an already-available JIT type
  382. known_types = self.type_store[nn_module_type]
  383. for known_type in known_types:
  384. if known_type.equals(concrete_type_builder):
  385. return known_type
  386. # We didn't find anything; generate a new JIT type from this concrete type
  387. concrete_type = concrete_type_builder.build()
  388. self.type_store[nn_module_type].append(concrete_type)
  389. return concrete_type
  390. concrete_type_store = ConcreteTypeStore()
  391. def create_methods_and_properties_from_stubs(
  392. concrete_type, method_stubs, property_stubs
  393. ) -> None:
  394. method_defs = [m.def_ for m in method_stubs]
  395. method_rcbs = [m.resolution_callback for m in method_stubs]
  396. method_defaults = [get_default_args(m.original_method) for m in method_stubs]
  397. property_defs = [p.def_ for p in property_stubs]
  398. property_rcbs = [p.resolution_callback for p in property_stubs]
  399. concrete_type._create_methods_and_properties(
  400. property_defs, property_rcbs, method_defs, method_rcbs, method_defaults
  401. )
  402. def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs) -> None:
  403. hook_defs = [h.def_ for h in hook_stubs]
  404. hook_rcbs = [h.resolution_callback for h in hook_stubs]
  405. pre_hook_defs = [h.def_ for h in pre_hook_stubs]
  406. pre_hook_rcbs = [h.resolution_callback for h in pre_hook_stubs]
  407. concrete_type._create_hooks(hook_defs, hook_rcbs, pre_hook_defs, pre_hook_rcbs)
  408. def get_module_concrete_type(nn_module, share_types=True):
  409. """
  410. Get a concrete type for nn_modules.
  411. If share_types is True, the concrete type is fetched from concrete_type_store.
  412. If it is False, a new concrete type is created without first searching concrete_type_store.
  413. Args:
  414. nn_module: The original Python nn.Module that we are creating a ScriptModule for.
  415. share_types = Whether to share underlying JIT types between modules (if possible).
  416. Returns:
  417. A concrete type for nn_module.
  418. """
  419. if not isinstance(nn_module, Module):
  420. raise AssertionError(f"Expected Module, got {type(nn_module)}")
  421. if isinstance(nn_module, torch.jit.ScriptModule) and hasattr(
  422. nn_module, "_concrete_type"
  423. ):
  424. return nn_module._concrete_type
  425. if share_types:
  426. # Look into the store of cached JIT types
  427. concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
  428. else:
  429. # Get a concrete type directly, without trying to reuse an existing JIT
  430. # type from the type store.
  431. concrete_type_builder = infer_concrete_type_builder(nn_module, share_types)
  432. concrete_type_builder.set_poisoned()
  433. concrete_type = concrete_type_builder.build()
  434. return concrete_type
  435. def create_script_class(obj):
  436. """
  437. Create and return a RecursiveScriptClass instance from a Python object.
  438. Arguments:
  439. obj: A Python object.
  440. """
  441. qualified_class_name = _jit_internal._qualified_name(type(obj))
  442. rcb = _jit_internal.createResolutionCallbackForClassMethods(type(obj))
  443. # Script the type of obj if it hasn't already been scripted.
  444. _compile_and_register_class(type(obj), rcb, qualified_class_name)
  445. class_ty = _python_cu.get_class(qualified_class_name)
  446. # Create an empty torch._C.ScriptObject with the scripted type.
  447. cpp_object = torch._C._create_object_with_type(class_ty)
  448. # Copy all of the attributes over to the torch._C.ScriptObject.
  449. for name, value in obj.__dict__.items():
  450. cpp_object.setattr(name, value)
  451. # Wrap the torch._C.ScriptObject in a RecursiveScriptClass instance.
  452. return wrap_cpp_class(cpp_object)
  453. def create_script_module(nn_module, stubs_fn, share_types=True, is_tracing=False):
  454. """
  455. Create a new ScriptModule from an nn.Module.
  456. Args:
  457. nn_module: The original Python nn.Module that we are creating a ScriptModule for.
  458. stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
  459. share_types: Whether to share underlying JIT types between modules (if possible).
  460. NOTE: Only set to False this when we cannot guarantee type sharing will work
  461. correctly. This only happens today for traced modules, where the same
  462. module can produce different traced methods depending on the inputs.
  463. is_tracing: Whether this function is called during tracing or scripting. If tracing,
  464. we don't need to do AttributeTypeIsSupportedChecker because all the unsupported
  465. attributes will be baked as constant in the tracing graph. In addition,
  466. this check significantly slows down the traced modules when the module size is big.
  467. """
  468. if isinstance(nn_module, torch.jit.RecursiveScriptModule):
  469. raise AssertionError("Cannot script a RecursiveScriptModule (already compiled)")
  470. check_module_initialized(nn_module)
  471. concrete_type = get_module_concrete_type(nn_module, share_types)
  472. if not is_tracing:
  473. AttributeTypeIsSupportedChecker().check(nn_module)
  474. return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  475. def create_script_module_impl(nn_module, concrete_type, stubs_fn):
  476. """
  477. Convert an nn.Module to a RecursiveScriptModule.
  478. Args:
  479. nn_module: The original Python nn.Module that we are creating a ScriptModule for.
  480. concrete_type: The fully initialized ConcreteType of the module.
  481. stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
  482. """
  483. cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
  484. method_stubs = stubs_fn(nn_module)
  485. property_stubs = get_property_stubs(nn_module)
  486. hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)
  487. ignored_properties = jit_ignored_properties(nn_module)
  488. def init_fn(script_module) -> None:
  489. # Initialize the ScriptModule:
  490. # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule.
  491. for name in concrete_type.get_attributes():
  492. orig_value = getattr(nn_module, name)
  493. orig_value = (
  494. orig_value.value
  495. if isinstance(orig_value, torch.jit.Attribute)
  496. else orig_value
  497. )
  498. cpp_module.setattr(name, orig_value)
  499. # 2. Copy the submodules from the original `nn_module` to the new ScriptModule,
  500. # recursively scripting them.
  501. for name, sub_concrete_type in concrete_type.get_modules():
  502. orig_value = getattr(nn_module, name)
  503. if not isinstance(orig_value, Module):
  504. raise AssertionError(f"Expected Module but got {type(orig_value)}")
  505. module_type = sub_concrete_type.jit_type
  506. if isinstance(module_type, torch._C.InterfaceType):
  507. # use the interface inference rule to compile the module
  508. scripted = interface_script(module_type, orig_value)
  509. elif isinstance(orig_value, torch.jit.ScriptModule):
  510. scripted = orig_value
  511. else:
  512. # always reuse the provided stubs_fn to infer the methods to compile
  513. scripted = create_script_module_impl(
  514. orig_value, sub_concrete_type, stubs_fn
  515. )
  516. cpp_module.setattr(name, scripted)
  517. script_module._modules[name] = scripted
  518. # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule.
  519. # This ensures we can access these Python methods on the ScriptModule.
  520. for name in dir(nn_module):
  521. if name in ignored_properties:
  522. continue
  523. item = getattr(nn_module, name, None)
  524. if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
  525. unbound_function = getattr(nn_module, name).__func__
  526. bound_method = unbound_function.__get__(script_module)
  527. setattr(script_module, name, bound_method)
  528. elif concrete_type.is_ignored_attribute(name):
  529. setattr(script_module, name, item)
  530. # For convenience, attach the concrete type to the new ScriptModule
  531. script_module._concrete_type = concrete_type
  532. # Actually create the ScriptModule, initializing it with the function we just defined
  533. script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  534. # Compile methods if necessary
  535. if concrete_type not in concrete_type_store.methods_compiled:
  536. create_methods_and_properties_from_stubs(
  537. concrete_type, method_stubs, property_stubs
  538. )
  539. # Create hooks after methods to ensure no name collisions between hooks and methods.
  540. # If done before, hooks can overshadow methods that aren't exported.
  541. create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
  542. torch._C._run_emit_module_hook(cpp_module)
  543. concrete_type_store.methods_compiled.add(concrete_type)
  544. # Copy the forward hooks and pre-hooks to the new ScriptModule
  545. # to allow the hooks to be run from eager as ScriptFunctions
  546. for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
  547. script_module._forward_pre_hooks[idx] = fn
  548. for idx, fn in enumerate(script_module._c._get_forward_hooks()):
  549. script_module._forward_hooks[idx] = fn
  550. # Special handling so methods like __len__ work in script methods on classes derived from containers
  551. if (
  552. isinstance(
  553. nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
  554. )
  555. and "__len__" not in cpp_module._method_names()
  556. ):
  557. script_module.define(f"def __len__(self):\n return {len(nn_module)}\n")
  558. if (
  559. isinstance(nn_module, torch.nn.ModuleDict)
  560. and "__contains__" not in cpp_module._method_names()
  561. ):
  562. if len(nn_module.keys()):
  563. keys = repr(list(nn_module.keys()))
  564. script_module.define(
  565. f"def __contains__(self, key: str):\n return key in {keys}\n"
  566. )
  567. else:
  568. script_module.define("def __contains__(self, key: str):\n return False\n")
  569. # Make the compiled methods available to the Python ScriptModule class.
  570. for method_stub in method_stubs:
  571. if method_stub.original_method is None:
  572. # define()'d methods don't have an Python original_method, so we
  573. # don't need to do any Python re-wrapping stuff
  574. continue
  575. name = method_stub.original_method.__name__
  576. if name != method_stub.def_.name().name:
  577. # TODO: Why skip this? Because @torch.jit._overload_method will
  578. # mangle the name of the function.
  579. continue
  580. script_method = cpp_module._get_method(name)
  581. # Wrap the original to propagate docstrings and such.
  582. # TODO: we don't currently do this functions that are recursively
  583. # compiled, we should.
  584. wrapped_script_method = functools.wraps(method_stub.original_method)(
  585. script_method
  586. )
  587. # Add the methods to the script_module directly. This ensures they will
  588. # be found first when `name` is looked up (as opposed to the stubs or
  589. # nn.Module.forward)
  590. script_module.__dict__[name] = wrapped_script_method
  591. # Make module properties available on the Python ScriptModule class.
  592. for property_stub in property_stubs:
  593. property_name = property_stub.def_.name().name
  594. fget = cpp_module._get_method(property_stub.def_.getter_name().name)
  595. # Setter is optional, so it may not exist.
  596. setter_name = property_stub.def_.setter_name()
  597. fset = cpp_module._get_method(setter_name.name) if setter_name else None
  598. script_module.__dict__[property_name] = property(property_name, fget, fset) # type: ignore[arg-type]
  599. # copy over python methods to script module if they aren't defined on the script module
  600. # this is currently an internal api used only on module containers
  601. for name in dir(nn_module):
  602. if name in ignored_properties:
  603. continue
  604. item = getattr(nn_module, name, None)
  605. if (
  606. _jit_internal.get_torchscript_modifier(item)
  607. is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
  608. ):
  609. add_python_attr_to_scripted_model(script_module, nn_module, name)
  610. return script_module
  611. # We define shims of certain attributes on the RecursiveScriptModule to support
  612. # magic methods. To check if a script model defines an attribute we need
  613. # to also check that the attribute is not the shim
  614. def script_model_defines_attr(script_model, attr):
  615. script_attr = getattr(script_model, attr, None)
  616. if script_attr is None:
  617. return False
  618. default_attr = getattr(torch.jit.RecursiveScriptModule, attr, None)
  619. if default_attr is None:
  620. return False
  621. return script_attr != default_attr
  622. def add_python_attr_to_scripted_model(script_model, orig, attr) -> None:
  623. if hasattr(orig, attr) and script_model_defines_attr(script_model, attr):
  624. setattr(script_model, attr, getattr(orig, attr))
  625. def get_overload_annotations(mod, jit_ignored_properties):
  626. # original function => [(mangled overload name, overload function)]
  627. overloads = {}
  628. for name in dir(type(mod)):
  629. if name in jit_ignored_properties:
  630. continue
  631. item = getattr(mod, name, None)
  632. if not callable(item):
  633. continue
  634. # builtin functions like repr() in python 2 do not have __module__ defined
  635. if hasattr(item, "__module__") and item.__module__ is not None:
  636. method_overloads = _jit_internal._get_overloaded_methods(
  637. item, mod.__class__
  638. )
  639. if method_overloads is None:
  640. continue
  641. # pyrefly: ignore [missing-attribute]
  642. if item.__func__ in method_overloads:
  643. raise RuntimeError(
  644. _jit_internal.get_overload_no_implementation_error_message(
  645. "method", item.__func__
  646. )
  647. )
  648. names = [name + "__" + str(i) for i in range(len(method_overloads))]
  649. overloads[item] = list(zip(names, method_overloads))
  650. return overloads
  651. def get_overload_name_mapping(overload_info):
  652. # Same format as __overloads__
  653. # original function => [overload names]
  654. overload_name_mappings: dict[str, list[str]] = {}
  655. for orig_fn, overloads in overload_info.items():
  656. original_name = orig_fn.__name__
  657. if original_name not in overload_name_mappings:
  658. overload_name_mappings[original_name] = []
  659. for overload_name, _ in overloads:
  660. overload_name_mappings[original_name].append(overload_name)
  661. return overload_name_mappings
  662. def _check_no_signature(func) -> None:
  663. signature = torch.jit.annotations.get_signature(
  664. func, None, fake_range(), inspect.ismethod(func)
  665. )
  666. if signature is None:
  667. qual_name = _jit_internal._qualified_name(func)
  668. raise RuntimeError(
  669. f"Must explicitly add type annotations to overloaded functions: {qual_name}"
  670. )
  671. def make_stubs_for_overloads(overload_info):
  672. overload_stubs = []
  673. for orig_fn, overloads in overload_info.items():
  674. orig_ast = get_jit_def(
  675. orig_fn, orig_fn.__name__, self_name="RecursiveScriptModule"
  676. )
  677. for overload_name, overload_fn in overloads:
  678. _check_no_signature(overload_fn)
  679. over_ast = get_jit_def(
  680. overload_fn, overload_fn.__name__, self_name="RecursiveScriptModule"
  681. )
  682. new_ast = torch._C._replace_overloaded_method_decl(
  683. over_ast.decl(), orig_ast, overload_name
  684. )
  685. _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn)
  686. overload_stubs.append(ScriptMethodStub(_rcb, new_ast, overload_fn))
  687. return overload_stubs
  688. def check_module_initialized(mod) -> None:
  689. if not isinstance(mod, torch.nn.Module):
  690. raise AssertionError(f"Expected torch.nn.Module, got {type(mod)}")
  691. if not hasattr(mod, "_parameters"):
  692. raise RuntimeError(
  693. f"'{torch.typename(type(mod))}' has not been initialized, did you forget to call 'super()'?"
  694. )
  695. # This is to avoid importing torch.distributed.nn
  696. if not hasattr(mod, "remote_parameters"):
  697. for name, param in mod._parameters.items():
  698. if param is not None and torch.nn.parameter.is_lazy(param):
  699. raise RuntimeError(
  700. f"'{torch.typename(type(mod))}' has uninitialized parameters {name}. Did you forget to run a forward pass?"
  701. )
  702. for name, buf in mod._buffers.items():
  703. if buf is not None and torch.nn.parameter.is_lazy(buf):
  704. raise RuntimeError(
  705. f"'{torch.typename(type(mod))}' has uninitialized buffers {name}. Did you forget to run a forward pass?"
  706. )
  707. def infer_methods_to_compile(nn_module):
  708. """Implement the default rules for which methods should act as starting points for compilation.
  709. (TODO add a link when the rules are published).
  710. """
  711. check_module_initialized(nn_module)
  712. ignored_properties = jit_ignored_properties(nn_module)
  713. methods: list[str] = []
  714. if hasattr(nn_module, "forward") and not _jit_internal.is_ignored_fn(
  715. nn_module.forward
  716. ):
  717. forward_func = getattr(nn_module.forward, "__func__", None)
  718. module_forward = getattr(torch.nn.Module, "forward", None)
  719. if forward_func != module_forward:
  720. methods = ["forward"]
  721. exported = []
  722. for name in dir(nn_module):
  723. if name in ignored_properties:
  724. continue
  725. item = getattr(nn_module, name, None)
  726. if (
  727. _jit_internal.get_torchscript_modifier(item)
  728. is _jit_internal.FunctionModifiers.EXPORT
  729. ):
  730. exported.append(name)
  731. methods = methods + exported
  732. overload_name_mappings = dict(getattr(nn_module, "__overloads__", {}))
  733. overload_info = get_overload_annotations(nn_module, ignored_properties)
  734. overload_name_mappings.update(get_overload_name_mapping(overload_info))
  735. overload_stubs = make_stubs_for_overloads(overload_info)
  736. nn_module.__overloads__ = overload_name_mappings
  737. # we shouldn't directly compile overloaded methods, just its overloads
  738. def ignore_overloaded(method_name):
  739. return method_name not in overload_name_mappings
  740. filtered_methods = filter(ignore_overloaded, methods)
  741. # Unique the methods. We don't want to use a set to store the methods because it
  742. # introduces non-determinism to compile order.
  743. uniquer: set[str] = set()
  744. uniqued_methods = []
  745. for name in filtered_methods:
  746. if name in uniquer:
  747. continue
  748. uniqued_methods.append(name)
  749. uniquer.add(name)
  750. stubs = [make_stub_from_method(nn_module, method) for method in uniqued_methods]
  751. return overload_stubs + stubs
  752. def get_hook_stubs(nn_module):
  753. """Return forward hook and pre_hook ScriptModuleStubs."""
  754. check_module_initialized(nn_module)
  755. hook_map: dict = {}
  756. hook_stubs = []
  757. for hook in nn_module._forward_hooks.values():
  758. if hook.__name__ in hook_map:
  759. if id(hook) != id(hook_map[hook.__name__]):
  760. raise RuntimeError(
  761. f"Hook '{hook.__name__}' on {type(nn_module).__name__} "
  762. "has at least two different python definitions."
  763. " Please use unique names for all hooks."
  764. )
  765. else:
  766. hook_map[hook.__name__] = hook
  767. hook_stubs.append(make_stub(hook, hook.__name__))
  768. pre_hook_stubs = []
  769. for pre_hook in nn_module._forward_pre_hooks.values():
  770. if pre_hook.__name__ in hook_map:
  771. if id(pre_hook) != id(hook_map[pre_hook.__name__]):
  772. raise RuntimeError(
  773. f"Pre-hook '{pre_hook.__name__}' on {type(nn_module).__name__} "
  774. "has at least two different python definitions."
  775. " Please use unique names for all hooks."
  776. )
  777. else:
  778. hook_map[pre_hook.__name__] = pre_hook
  779. pre_hook_stubs.append(make_stub(pre_hook, pre_hook.__name__))
  780. return hook_stubs, pre_hook_stubs
  781. def get_property_stubs(nn_module):
  782. """Create property stubs for the properties of the module by creating method stubs for the getter and setter."""
  783. module_ty = type(nn_module)
  784. properties_asts = get_class_properties(module_ty, self_name="RecursiveScriptModule")
  785. rcbs = {}
  786. for name in dir(module_ty):
  787. item = getattr(module_ty, name, None)
  788. if isinstance(item, property):
  789. if not item.fget:
  790. raise RuntimeError(
  791. f"Property {name} of {nn_module.__name__} must have a getter"
  792. )
  793. rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget)
  794. stubs = [PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts]
  795. return stubs
  796. def interface_script(mod_interface, nn_module):
  797. """
  798. Make a ScriptModule from an nn.Module, using the interface methods rule for determining which methods to compile.
  799. Args:
  800. mod_interface: the interface type that the module have
  801. nn_module: The original Python nn.Module that we are creating a ScriptModule for.
  802. """
  803. if isinstance(nn_module, torch.jit.ScriptModule):
  804. return nn_module
  805. check_module_initialized(nn_module)
  806. def infer_interface_methods_to_compile(nn_module):
  807. """Rule to infer the methods from the interface type.
  808. It is used to know which methods need to act as starting points for compilation.
  809. """
  810. stubs = [
  811. make_stub_from_method(nn_module, method)
  812. for method in mod_interface.getMethodNames()
  813. ]
  814. return stubs
  815. return create_script_module(nn_module, infer_interface_methods_to_compile)
  816. def try_compile_fn(fn, loc):
  817. if _jit_internal.is_ignored_fn(fn):
  818. # Don't do anything for @ignore'd functions
  819. return None
  820. if isinstance(fn, torch.nn.Module):
  821. # Since modules are callable pybind recognizes them as functions, but
  822. # don't do anything for them
  823. return None
  824. if not inspect.isfunction(fn) and not inspect.ismethod(fn):
  825. raise RuntimeError(
  826. f"`{fn}` is not a function. Recursive scripting only supports "
  827. "Python functions or methods currently.\n"
  828. f"Consider manually annotating `{fn}` with @torch.jit.script."
  829. )
  830. # The object returned by __prepare_scriptable__ might have a different closure.
  831. # Resolve it here to get the right resolution callback.
  832. fn = fn.__prepare_scriptable__() if hasattr(fn, "__prepare_scriptable__") else fn # type: ignore[operator]
  833. # We don't have the actual scope where the function was defined, but we can
  834. # extract the necessary info from the closed over variables on the function
  835. # object
  836. rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
  837. return torch.jit.script(fn, _rcb=rcb)
  838. def wrap_cpp_class(cpp_class):
  839. """Wrap this torch._C.Object in a Python RecursiveScriptClass."""
  840. return torch.jit.RecursiveScriptClass(cpp_class)
  841. def wrap_cpp_module(cpp_module):
  842. """Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules."""
  843. def init_fn(script_module) -> None:
  844. for name, cpp_module in torch._C.ModuleDict(script_module._c).items():
  845. setattr(script_module, name, wrap_cpp_module(cpp_module))
  846. script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(
  847. script_module._c._type()
  848. )
  849. for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
  850. script_module._forward_pre_hooks[idx] = fn
  851. for idx, fn in enumerate(script_module._c._get_forward_hooks()):
  852. script_module._forward_hooks[idx] = fn
  853. return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  854. def compile_unbound_method(concrete_type, fn):
  855. if _jit_internal.is_ignored_fn(fn):
  856. return None
  857. stub = make_stub(fn, fn.__name__)
  858. with torch._jit_internal._disable_emit_hooks():
  859. # We don't want to call the hooks here since the graph that is calling
  860. # this function is not yet complete
  861. create_methods_and_properties_from_stubs(concrete_type, (stub,), ())
  862. return stub
  863. def lazy_bind(concrete_type, unbound_method):
  864. """
  865. Return a function that lazily binds `unbound_method` to a provided Module IValue, then invokes the method.
  866. We do this so that any Python shenanigans that
  867. will poison type sharing are impossible at compile time.
  868. """
  869. def lazy_binding_method(cpp_module, *args):
  870. def init_fn(script_module) -> None:
  871. orig_class = concrete_type.py_class
  872. # Copy @ignored/@unused methods from the original module to the new one.
  873. # This ensures they are available during execution.
  874. for name in dir(orig_class):
  875. item = getattr(orig_class, name, None)
  876. if _jit_internal.is_ignored_fn(item):
  877. setattr(script_module, name, item)
  878. # Copy constants over so they are available during execution.
  879. for name, value in concrete_type.get_constants().items():
  880. setattr(script_module, name, value)
  881. script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  882. method = types.MethodType(unbound_method, script_module)
  883. return method(*args)
  884. # make the lazy binding method "look like" the original method
  885. lazy_binding_method.original_fn = unbound_method # type: ignore[attr-defined]
  886. lazy_binding_method.__name__ = unbound_method.__name__
  887. torch._jit_internal.copy_torchscript_modifier(unbound_method, lazy_binding_method)
  888. return lazy_binding_method