source.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295
  1. """
  2. This module provides Source classes that track the origins of values in PyTorch Dynamo.
  3. Sources represent where values come from (e.g. local variables, globals, attributes) and
  4. are used for guard generation and code reconstruction during compilation.
  5. The module includes specialized sources for:
  6. - Local variables and synthetic locals
  7. - Global variables and constants
  8. - Object attributes and method calls
  9. - NN module specialization (specialized vs unspecialized)
  10. - Random values and tensor properties
  11. - Default argument handling
  12. - FSDP (Fully Sharded Data Parallel) modules
  13. Sources play a key role in Dynamo's guard system by tracking value origins for
  14. guard generation, and in code reconstruction by providing methods to rebuild
  15. the code needed to recreate values.
  16. """
  17. import dataclasses
  18. import enum
  19. import functools
  20. from collections.abc import Callable
  21. from typing import Any, Optional, TYPE_CHECKING, Union
  22. from torch import device as device_type
  23. from torch._guards import (
  24. ChainedSource,
  25. dataclass_with_cached_hash,
  26. Guard,
  27. GuardSource,
  28. Source,
  29. )
  30. from . import utils
  31. from .bytecode_transformation import (
  32. create_binary_subscr,
  33. create_build_tuple,
  34. create_call_function,
  35. )
  36. if TYPE_CHECKING:
  37. from .codegen import PyCodegen
  38. # It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
  39. # so those cases are omitted intentionally
  40. # represents nn.Modules tracked with NNModuleVariable (specialized is implicit in the variable name)
  41. _GUARD_SOURCE_SPECIALIZED_NN_MODULE = {
  42. GuardSource.LOCAL: GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
  43. GuardSource.GLOBAL: GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
  44. GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
  45. GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
  46. # Just to ensure that guard_source() works
  47. GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
  48. GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
  49. GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  50. GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  51. GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
  52. GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
  53. }
  54. # represents nn.Modules tracked with UnspecializedNNModuleVariable
  55. _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE = {
  56. GuardSource.LOCAL: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
  57. GuardSource.GLOBAL: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
  58. GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
  59. GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
  60. # this happens for an UnspecializedNNModule submodule on a NNModuleVariable
  61. GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
  62. GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
  63. # Just to ensure that guard_source() works
  64. GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  65. GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  66. GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
  67. GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
  68. }
  69. # represents nn.Modules tracked with UnspecializedBuiltinNNModuleVariable
  70. _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE = {
  71. GuardSource.LOCAL: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  72. GuardSource.GLOBAL: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  73. GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  74. GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  75. GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  76. GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  77. # Just to ensure that guard_source() works
  78. GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  79. GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
  80. GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
  81. GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
  82. }
  83. _GUARD_SOURCE_FSDP_MODULE = {
  84. GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE,
  85. GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE,
  86. GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
  87. GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
  88. GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
  89. GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
  90. GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
  91. GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
  92. GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
  93. GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
  94. }
  95. def is_constant_source(source: Source) -> bool:
  96. if isinstance(source, ConstantSource):
  97. return True
  98. try:
  99. if source.guard_source == GuardSource.CONSTANT:
  100. return True
  101. except NotImplementedError:
  102. pass
  103. return False
  104. def _get_source_debug_name(source: Optional[Source]) -> str:
  105. if source is None:
  106. return "<unknown source>"
  107. else:
  108. try:
  109. return source.name
  110. except NotImplementedError:
  111. return "<unknown source>"
  112. def _esc_str(s: Any, apply_repr: bool = False) -> str:
  113. """
  114. Escapes curly brackets for format strings.
  115. e.g. "frozenset({0})" becomes "frozenset({{0}})".
  116. This is used by _name_template for example, because it's
  117. expected to return a format string, but we may wish to include
  118. strings that should not be accidentally formatted.
  119. """
  120. if apply_repr:
  121. s = repr(s)
  122. else:
  123. s = str(s)
  124. return s.replace("{", "{{").replace("}", "}}")
  125. @dataclass_with_cached_hash(frozen=True)
  126. class LocalSource(Source):
  127. local_name: str
  128. # Whether this local is an input to the root frame.
  129. is_input: bool = False
  130. # Whether we know this input is dynamic (based on example_inputs)
  131. # For non tensors, we simply look at the first index of the tuple
  132. dynamism: Optional[frozenset[str]] = None
  133. # Whether the item at this source is the _content_ of a cell that is
  134. # dereferenced from the root frame, i.e., it's a part of the `co_cellvars`
  135. # or `co_freevars`.
  136. is_derefed_cell_contents: bool = False
  137. def reconstruct(self, codegen: "PyCodegen") -> None:
  138. if self.is_derefed_cell_contents:
  139. codegen.load_deref(self.local_name)
  140. else:
  141. codegen.append_output(codegen.create_load(self.local_name))
  142. @property
  143. def guard_source(self) -> GuardSource:
  144. return GuardSource.LOCAL
  145. @functools.cached_property
  146. def _name_template(self) -> str:
  147. return f"L[{_esc_str(self.local_name, apply_repr=True)}]"
  148. @dataclass_with_cached_hash(frozen=True)
  149. class TempLocalSource(Source):
  150. # like LocalSource, but cannot be guarded on
  151. local_name: str
  152. def reconstruct(self, codegen: "PyCodegen") -> None:
  153. codegen.append_output(codegen.create_load(self.local_name))
  154. @property
  155. def guard_source(self) -> GuardSource:
  156. return GuardSource.TEMP_LOCAL
  157. @property
  158. def _name_template(self) -> str:
  159. raise NotImplementedError(
  160. "Cannot create guard on TempLocalSource - this is an internal Dynamo bug. Please file an issue on GitHub."
  161. )
  162. @dataclass_with_cached_hash(frozen=True)
  163. class SyntheticLocalSource(Source):
  164. local_name: str
  165. def reconstruct(self, codegen: "PyCodegen") -> None:
  166. codegen.append_output(codegen.create_load(self.local_name))
  167. @property
  168. def guard_source(self) -> GuardSource:
  169. return GuardSource.SYNTHETIC_LOCAL
  170. @functools.cached_property
  171. def _name_template(self) -> str:
  172. return f"SYNTHETIC_LOCAL[{_esc_str(self.local_name, apply_repr=True)}]"
  173. @dataclass_with_cached_hash(frozen=True)
  174. class RandomValueSource(Source):
  175. random_call_index: int
  176. @property
  177. def guard_source(self) -> GuardSource:
  178. return GuardSource.RANDOM_VALUE
  179. def reconstruct(self, codegen: "PyCodegen") -> None:
  180. codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
  181. codegen.append_output(codegen.create_load_const(self.random_call_index))
  182. codegen.append_output(create_binary_subscr())
  183. @functools.cached_property
  184. def _name_template(self) -> str:
  185. return f"random_value_{_esc_str(self.random_call_index)}"
  186. @dataclass_with_cached_hash(frozen=True)
  187. class GlobalSource(Source):
  188. global_name: str
  189. def reconstruct(self, codegen: "PyCodegen") -> None:
  190. codegen.append_output(codegen.create_load_global(self.global_name, add=True))
  191. @property
  192. def guard_source(self) -> GuardSource:
  193. return GuardSource.GLOBAL
  194. @functools.cached_property
  195. def _name_template(self) -> str:
  196. return f"G[{_esc_str(self.global_name, apply_repr=True)}]"
  197. @dataclass_with_cached_hash(frozen=True)
  198. class GlobalWeakRefSource(Source):
  199. global_name: str
  200. def reconstruct(self, codegen: "PyCodegen") -> None:
  201. codegen.add_push_null(
  202. lambda: codegen.append_output(
  203. codegen.create_load_global(self.global_name, add=True)
  204. )
  205. )
  206. codegen.extend_output(create_call_function(0, False))
  207. @property
  208. def guard_source(self) -> GuardSource:
  209. return GuardSource.GLOBAL
  210. @functools.cached_property
  211. def _name_template(self) -> str:
  212. return f"G[{_esc_str(self.global_name, apply_repr=True)}]()"
  213. @dataclass_with_cached_hash(frozen=True)
  214. class WeakRefCallSource(ChainedSource):
  215. def reconstruct(self, codegen: "PyCodegen") -> None:
  216. codegen.add_push_null(lambda: codegen(self.base))
  217. codegen.extend_output(create_call_function(0, False))
  218. @property
  219. def _name_template(self) -> str:
  220. return "{0}()"
  221. @dataclass_with_cached_hash(frozen=True)
  222. class CallFunctionNoArgsSource(WeakRefCallSource):
  223. pass
  224. @dataclass_with_cached_hash(frozen=True)
  225. class AttrSource(ChainedSource):
  226. member: str
  227. def __post_init__(self) -> None:
  228. assert self.base, "Can't construct an AttrSource without a valid base source"
  229. assert "." not in self.member, (
  230. f"AttrSource member must not contain '.', got {self.member!r}. "
  231. "Use OutputGraph.get_chained_attr_source() for dotted paths."
  232. )
  233. def reconstruct(self, codegen: "PyCodegen") -> None:
  234. codegen(self.base)
  235. codegen.extend_output(codegen.create_load_attrs(self.member))
  236. @functools.cached_property
  237. def _name_template(self) -> str:
  238. if not self.member.isidentifier():
  239. return f"getattr({{0}}, {_esc_str(self.member, apply_repr=True)})"
  240. return f"{{0}}.{_esc_str(self.member)}"
  241. @dataclass_with_cached_hash(frozen=True)
  242. class CellContentsSource(AttrSource):
  243. """
  244. Source for closure cell contents that also stores the freevar name.
  245. This allows guard failure messages to show which variable the closure cell refers to.
  246. """
  247. freevar_name: str = dataclasses.field(default="")
  248. def __post_init__(self) -> None:
  249. assert self.base, (
  250. "Can't construct a CellContentsSource without a valid base source"
  251. )
  252. assert self.member == "cell_contents", (
  253. "CellContentsSource should only be used for cell_contents"
  254. )
  255. @dataclass_with_cached_hash(frozen=True)
  256. class GenericAttrSource(ChainedSource):
  257. member: str
  258. def __post_init__(self) -> None:
  259. assert self.base, (
  260. "Can't construct a GenericAttrSource without a valid base source"
  261. )
  262. assert "." not in self.member, (
  263. f"GenericAttrSource member must not contain '.', got {self.member!r}. "
  264. "Use OutputGraph.get_chained_attr_source() for dotted paths."
  265. )
  266. def reconstruct(self, codegen: "PyCodegen") -> None:
  267. codegen(self.base)
  268. codegen.extend_output(codegen.create_load_attrs(self.member))
  269. @functools.cached_property
  270. def _name_template(self) -> str:
  271. return (
  272. f"object.__getattribute__({{0}}, {_esc_str(self.member, apply_repr=True)})"
  273. )
  274. # Represents obj.__dict__ where obj is a type object
  275. @dataclass_with_cached_hash(frozen=True)
  276. class TypeDictSource(ChainedSource):
  277. def reconstruct(self, codegen: "PyCodegen") -> None:
  278. codegen(self.base)
  279. codegen.extend_output(codegen.create_load_attrs("__dict__"))
  280. @property
  281. def _name_template(self) -> str:
  282. # type(ob).__dict__ can return a proxy of the dict. But in the C++
  283. # guard accessor, we are use type->tp_dict which is a dict. So,
  284. # forcefully pass a dict object to ensure that the GuardManager
  285. # registers that its working on a dict object.
  286. return "dict({0}.__dict__)"
  287. # Represents obj.__mro__ where object is type object
  288. @dataclass_with_cached_hash(frozen=True)
  289. class TypeMROSource(ChainedSource):
  290. def reconstruct(self, codegen: "PyCodegen") -> None:
  291. codegen(self.base)
  292. codegen.extend_output(codegen.create_load_attrs("__mro__"))
  293. @property
  294. def _name_template(self) -> str:
  295. return "{0}.__mro__"
  296. @dataclass_with_cached_hash(frozen=True)
  297. class LocalCellSource(Source):
  298. """
  299. Conceptually, this class is `LocalSource` for cell objects implicitly
  300. generated by Python (e.g., captured variables).
  301. """
  302. local_name: str
  303. def reconstruct(self, codegen: "PyCodegen") -> None:
  304. # Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
  305. # Dynamo's bytecode transformation differentiates them slightly, so we
  306. # always emit `LOAD_CLOSURE` here.
  307. codegen.append_output(codegen.create_load_closure(self.local_name))
  308. # All the other methods are intentionally unimplemented because e.g., a
  309. # local cell object should never be used for guards.
  310. # Represents obj.__code__ where object is type object
  311. @dataclass_with_cached_hash(frozen=True)
  312. class CodeSource(ChainedSource):
  313. def reconstruct(self, codegen: "PyCodegen") -> None:
  314. codegen(self.base)
  315. codegen.extend_output(codegen.create_load_attrs("__code__"))
  316. @property
  317. def _name_template(self) -> str:
  318. return "{0}.__code__"
  319. # Represents obj.__closure__ where object is type object
  320. @dataclass_with_cached_hash(frozen=True)
  321. class ClosureSource(ChainedSource):
  322. def reconstruct(self, codegen: "PyCodegen") -> None:
  323. codegen(self.base)
  324. codegen.extend_output(codegen.create_load_attrs("__closure__"))
  325. @property
  326. def _name_template(self) -> str:
  327. return "{0}.__closure__"
  328. # Represents tensor.grad source. It could be represented by AttrSource as well.
  329. # But, we could access grad field on tensor directly in C++ without going
  330. # through the Python bytecodes. Therefore, we use a separate source for grad
  331. # field.
  332. @dataclass_with_cached_hash(frozen=True)
  333. class GradSource(ChainedSource):
  334. member: str = "grad"
  335. def reconstruct(self, codegen: "PyCodegen") -> None:
  336. codegen(self.base)
  337. codegen.extend_output(codegen.create_load_attrs(self.member))
  338. @functools.cached_property
  339. def _name_template(self) -> str:
  340. return f"{{0}}.{_esc_str(self.member)}"
  341. @dataclass_with_cached_hash(frozen=True)
  342. class ParamBufferSource(AttrSource):
  343. @functools.cached_property
  344. def guard_source(self) -> GuardSource:
  345. return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source]
  346. # Special AttrSource to differentiate module._buffers or module._parameters
  347. @dataclass_with_cached_hash(frozen=True)
  348. class UnspecializedParamBufferSource(AttrSource):
  349. pass
  350. # This source is intended to be used in places where a source is needed but it is expected
  351. # that the symbol will be simplified out later on. Symbols with ephemeral sources are
  352. # prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
  353. # source. Guarding on this source is an error.
  354. #
  355. # Example: During subclass view fake-ification, any close-over ViewFunc state should be
  356. # symbolicized / fake-ified to avoid invalid specialization during view replay. This source
  357. # is useful for symbols utilized in the middle of the view chain that are not expected to be
  358. # present within the final view shape metadata.
  359. @dataclass_with_cached_hash(frozen=True)
  360. class EphemeralSource(Source):
  361. desc: Optional[str] = None
  362. @property
  363. def guard_source(self) -> GuardSource:
  364. return GuardSource.EPHEMERAL
  365. @functools.cached_property
  366. def _name_template(self) -> str:
  367. desc = ": " + self.desc if self.desc is not None else ""
  368. return f"<ephemeral{_esc_str(desc)}>"
  369. def make_guard(self, fn: Callable[..., Any]) -> Guard:
  370. raise NotImplementedError
  371. def is_ephemeral(self) -> bool:
  372. return True
  373. @dataclass_with_cached_hash(frozen=True)
  374. class SkipGuardSource(ChainedSource):
  375. def reconstruct(self, codegen: "PyCodegen") -> None:
  376. self.base.reconstruct(codegen)
  377. @property
  378. def _name_template(self) -> str:
  379. return "{0}"
  380. class TensorProperty(enum.Enum):
  381. SIZE = 0
  382. STRIDE = 1
  383. STORAGE_OFFSET = 2
  384. def method_name(self) -> str:
  385. if self is TensorProperty.SIZE:
  386. return "size"
  387. elif self is TensorProperty.STRIDE:
  388. return "stride"
  389. elif self is TensorProperty.STORAGE_OFFSET:
  390. return "storage_offset"
  391. else:
  392. raise AssertionError(f"unhandled {_esc_str(self)}")
  393. @dataclass_with_cached_hash(frozen=True)
  394. class TensorPropertySource(ChainedSource):
  395. prop: TensorProperty
  396. idx: Optional[int] = None # None for STORAGE_OFFSET
  397. def __post_init__(self) -> None:
  398. assert self.base is not None
  399. if self.prop is TensorProperty.STORAGE_OFFSET:
  400. assert self.idx is None
  401. else:
  402. assert self.idx is not None
  403. def reconstruct(self, codegen: "PyCodegen") -> None:
  404. codegen.add_push_null(
  405. lambda: codegen.load_import_from(
  406. utils.__name__, f"call_{_esc_str(self.prop.method_name())}"
  407. )
  408. )
  409. codegen(self.base)
  410. if self.idx is not None:
  411. codegen.append_output(codegen.create_load_const(self.idx))
  412. codegen.extend_output(
  413. create_call_function(2 if self.idx is not None else 1, False)
  414. )
  415. @functools.cached_property
  416. def _name_template(self) -> str:
  417. if self.prop is TensorProperty.SIZE:
  418. return f"{{0}}.size()[{_esc_str(self.idx)}]"
  419. elif self.prop is TensorProperty.STRIDE:
  420. return f"{{0}}.stride()[{_esc_str(self.idx)}]"
  421. elif self.prop is TensorProperty.STORAGE_OFFSET:
  422. assert self.idx is None
  423. return "{0}.storage_offset()"
  424. else:
  425. raise AssertionError(f"unhandled {_esc_str(self.prop)}")
  426. @dataclass_with_cached_hash(frozen=True)
  427. class IndexedSource(ChainedSource):
  428. idx: int
  429. def __post_init__(self) -> None:
  430. assert self.base is not None
  431. def reconstruct(self, codegen: "PyCodegen") -> None:
  432. raise NotImplementedError
  433. @functools.cached_property
  434. def _name_template(self) -> str:
  435. return f"({_esc_str(self.idx)}, {{0}})"
  436. @dataclass_with_cached_hash(frozen=True)
  437. class NegateSource(ChainedSource):
  438. def __post_init__(self) -> None:
  439. assert self.base is not None
  440. def reconstruct(self, codegen: "PyCodegen") -> None:
  441. raise NotImplementedError
  442. @property
  443. def _name_template(self) -> str:
  444. # NB: use method call so that function stripping regexes work
  445. return "{0}.__neg__()"
  446. @dataclass_with_cached_hash(frozen=True)
  447. class ConvertIntSource(ChainedSource):
  448. def __post_init__(self) -> None:
  449. assert self.base is not None
  450. def reconstruct(self, codegen: "PyCodegen") -> None:
  451. codegen(self.base)
  452. @property
  453. def _name_template(self) -> str:
  454. return "cast_symbool_to_symint_guardless({0})"
  455. @dataclass_with_cached_hash(frozen=True)
  456. class DynamicScalarSource(ChainedSource):
  457. is_int: bool
  458. def __post_init__(self) -> None:
  459. assert self.base is not None
  460. def reconstruct(self, codegen: "PyCodegen") -> None:
  461. # Integer casting at reconstruction helps reduce the amount of DynamicInts returned
  462. # to the user, in favor of plain ints.
  463. # For example, a compiled region that only does int arithmetic could return a
  464. # DynamicInt without the casting here.
  465. codegen.add_push_null(lambda: codegen.load_import_from("builtins", "int"))
  466. codegen(self.base)
  467. codegen.extend_output(create_call_function(1, False))
  468. @property
  469. def _name_template(self) -> str:
  470. return "int({0})"
  471. @dataclass_with_cached_hash(frozen=True)
  472. class FlattenScriptObjectSource(ChainedSource):
  473. def __post_init__(self) -> None:
  474. assert self.base is not None
  475. def reconstruct(self, codegen: "PyCodegen") -> None:
  476. codegen(self.base)
  477. @property
  478. def _name_template(self) -> str:
  479. return "{0}.__obj_flatten__()"
  480. @dataclass_with_cached_hash(frozen=True)
  481. class ScriptObjectQualifiedNameSource(ChainedSource):
  482. def __post_init__(self) -> None:
  483. assert self.base is not None
  484. def reconstruct(self, codegen: "PyCodegen") -> None:
  485. codegen(self.base)
  486. @property
  487. def _name_template(self) -> str:
  488. return "{0}._type().qualified_name()"
  489. class AttrProxySource(ChainedSource):
  490. def reconstruct(self, codegen: "PyCodegen") -> None:
  491. codegen(self.base)
  492. @property
  493. def _name_template(self) -> str:
  494. return "{0}.get_base()"
  495. @dataclass_with_cached_hash(frozen=True)
  496. class DefaultsSource(ChainedSource):
  497. idx_key: Union[int, str]
  498. is_kw: bool = False
  499. field: str = dataclasses.field(init=False, repr=False, compare=False)
  500. _name: str = dataclasses.field(init=False, repr=False, compare=False)
  501. def __post_init__(self) -> None:
  502. assert self.base, (
  503. "Base must be a valid source in order to properly track and guard this Defaults to its origin."
  504. )
  505. if self.is_kw:
  506. assert isinstance(self.idx_key, str)
  507. object.__setattr__(self, "field", "__kwdefaults__")
  508. object.__setattr__(
  509. self,
  510. "_name",
  511. f"{{0}}.{_esc_str(self.field)}['{_esc_str(self.idx_key)}']",
  512. )
  513. else:
  514. assert isinstance(self.idx_key, int)
  515. object.__setattr__(self, "field", "__defaults__")
  516. object.__setattr__(
  517. self, "_name", f"{{0}}.{_esc_str(self.field)}[{_esc_str(self.idx_key)}]"
  518. )
  519. def reconstruct(self, codegen: "PyCodegen") -> None:
  520. codegen(self.base)
  521. codegen.extend_output(codegen.create_load_attrs(self.field))
  522. codegen.append_output(codegen.create_load_const(self.idx_key))
  523. codegen.append_output(create_binary_subscr())
  524. @functools.cached_property
  525. def _name_template(self) -> str:
  526. return self._name
  527. @dataclass_with_cached_hash(frozen=True)
  528. class GetItemSource(ChainedSource):
  529. index: Any
  530. index_is_slice: bool = False
  531. def __post_init__(self) -> None:
  532. assert self.base is not None
  533. if isinstance(self.index, slice):
  534. # store the hashable version of the slice so the whole GetItemSource is hashable
  535. super().__setattr__("index", self.index.__reduce__())
  536. super().__setattr__("index_is_slice", True)
  537. def reconstruct(self, codegen: "PyCodegen") -> None:
  538. codegen(self.base)
  539. if self.index_is_slice:
  540. codegen.append_output(codegen.create_load_const(self.unpack_slice()))
  541. else:
  542. codegen.append_output(codegen.create_load_const(self.index))
  543. codegen.append_output(create_binary_subscr())
  544. def unpack_slice(self) -> slice:
  545. assert self.index_is_slice
  546. slice_class, slice_args = self.index
  547. return slice_class(*slice_args)
  548. @functools.cached_property
  549. def _name_template(self) -> str:
  550. # Index can be of following types
  551. # 1) index is a slice - example 1:4
  552. # 2) index is a constant - example string, integer
  553. assert not isinstance(self.index, Source)
  554. if self.index_is_slice:
  555. return f"{{0}}[{_esc_str(self.unpack_slice(), apply_repr=True)}]"
  556. else:
  557. return f"{{0}}[{_esc_str(self.index, apply_repr=True)}]"
  558. @dataclass_with_cached_hash(frozen=True)
  559. class ConstDictKeySource(ChainedSource):
  560. index: Any
  561. def reconstruct(self, codegen: "PyCodegen") -> None:
  562. codegen.add_push_null(
  563. lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem")
  564. )
  565. codegen(self.base)
  566. codegen.append_output(codegen.create_load_const(self.index))
  567. codegen.extend_output(create_call_function(2, False))
  568. @functools.cached_property
  569. def _name_template(self) -> str:
  570. # The list creation will be CSE'd by PyExprCSEPass
  571. return f"list(dict.keys({{0}}))[{_esc_str(self.index, apply_repr=True)}]"
  572. def is_dict_key(self) -> bool:
  573. return True
  574. @dataclass_with_cached_hash(frozen=True)
  575. class NonSerializableSetGetItemSource(ChainedSource):
  576. index: int
  577. def __post_init__(self) -> None:
  578. from .variables import ConstantVariable
  579. assert ConstantVariable.is_literal(self.index)
  580. def reconstruct(self, codegen: "PyCodegen") -> None:
  581. codegen.add_push_null(
  582. lambda: codegen.load_import_from(utils.__name__, "set_getitem")
  583. )
  584. codegen(self.base)
  585. codegen.append_output(codegen.create_load_const(self.index))
  586. codegen.extend_output(create_call_function(2, False))
  587. @functools.cached_property
  588. def _name_template(self) -> str:
  589. # set ordering might not be stable
  590. return f"list({{0}})[{_esc_str(self.index, apply_repr=True)}]"
  591. def is_dict_key(self) -> bool:
  592. return False
  593. # Used to access an item from the dictionary
  594. @dataclass_with_cached_hash(frozen=True)
  595. class DictGetItemSource(ChainedSource):
  596. # Key to access in the dictionary. It can be one of the following types
  597. # 1) ConstDictKeySource
  598. # 2) constant - like string, integer
  599. index: Any
  600. def __post_init__(self) -> None:
  601. from .variables import ConstantVariable
  602. assert isinstance(
  603. self.index, ConstDictKeySource
  604. ) or ConstantVariable.is_literal(self.index)
  605. def reconstruct(self, codegen: "PyCodegen") -> None:
  606. # Load dict
  607. codegen(self.base)
  608. # Load key
  609. if isinstance(self.index, Source):
  610. codegen(self.index)
  611. else:
  612. codegen.append_output(codegen.create_load_const(self.index))
  613. codegen.append_output(create_binary_subscr())
  614. @functools.cached_property
  615. def _name_template(self) -> str:
  616. if isinstance(self.index, ConstDictKeySource):
  617. return f"{{0}}[{_esc_str(self.index.name)}]"
  618. else:
  619. return f"{{0}}[{_esc_str(self.index, apply_repr=True)}]"
  620. # Same as DictGetItemSource but used for dict.__getitem__ calls to ensure that
  621. # torch.compile does not run the overridden __getitem__ method
  622. @dataclass_with_cached_hash(frozen=True)
  623. class DictSubclassGetItemSource(ChainedSource):
  624. # Key to access in the dictionary. It can be one of the following types
  625. # 1) ConstDictKeySource
  626. # 2) constant - like string, integer
  627. index: Any
  628. def __post_init__(self) -> None:
  629. from .variables import ConstantVariable
  630. assert isinstance(
  631. self.index, ConstDictKeySource
  632. ) or ConstantVariable.is_literal(self.index)
  633. def reconstruct(self, codegen: "PyCodegen") -> None:
  634. # reconstruct dict.__getitem__(dct, key)
  635. # Load dict.__getitem__
  636. codegen.add_push_null(
  637. lambda: codegen.load_import_from(utils.__name__, "dict_getitem")
  638. )
  639. # Load dict
  640. codegen(self.base)
  641. # Load key
  642. if isinstance(self.index, Source):
  643. codegen(self.index)
  644. else:
  645. codegen.append_output(codegen.create_load_const(self.index))
  646. codegen.extend_output(create_call_function(2, False))
  647. @functools.cached_property
  648. def _name_template(self) -> str:
  649. if isinstance(self.index, ConstDictKeySource):
  650. return f"dict.__getitem__({{0}}, {_esc_str(self.index.name)})"
  651. else:
  652. return f"{{0}}[{_esc_str(self.index, apply_repr=True)}]"
  653. @dataclass_with_cached_hash(frozen=True)
  654. class ListGetItemSource(GetItemSource):
  655. """
  656. Same as GetItemSource with reconstruct and name overridden to be list specific.
  657. """
  658. def reconstruct(self, codegen: "PyCodegen") -> None:
  659. # Reconstruct list.__getitem__(lst, index) to avoid any side effects
  660. # from possibly overridden __getitem__.
  661. # Load list.__getitem__
  662. codegen.add_push_null(
  663. lambda: codegen.load_import_from(utils.__name__, "list_getitem")
  664. )
  665. # Load the list
  666. codegen(self.base)
  667. # Load the index
  668. if self.index_is_slice:
  669. raise RuntimeError(
  670. "List[slice] is a temporary object and should not have a source"
  671. )
  672. else:
  673. codegen.append_output(codegen.create_load_const(self.index))
  674. codegen.extend_output(create_call_function(2, False))
  675. @functools.cached_property
  676. def _name_template(self) -> str:
  677. # Index can be of following types
  678. # 1) index is a slice - example 1:4
  679. # 2) index is a constant - example string, integer
  680. assert not isinstance(self.index, Source)
  681. if self.index_is_slice:
  682. raise RuntimeError(
  683. "List[slice] is a temporary object and should not have a source"
  684. )
  685. else:
  686. return f"list.__getitem__({{0}}, {_esc_str(self.index, apply_repr=True)})"
  687. @dataclass_with_cached_hash(frozen=True)
  688. class TupleIteratorGetItemSource(GetItemSource):
  689. def reconstruct(self, codegen: "PyCodegen") -> None:
  690. codegen.add_push_null(
  691. lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
  692. )
  693. codegen(self.base)
  694. codegen.append_output(codegen.create_load_const(self.index))
  695. codegen.extend_output(create_call_function(2, False))
  696. @functools.cached_property
  697. def _name_template(self) -> str:
  698. return (
  699. f"___tuple_iterator_getitem({{0}}, {_esc_str(self.index, apply_repr=True)})"
  700. )
  701. @dataclass_with_cached_hash(frozen=True)
  702. class NamedTupleFieldsSource(ChainedSource):
  703. def reconstruct(self, codegen: "PyCodegen") -> None:
  704. codegen(self.base)
  705. codegen.extend_output(codegen.create_load_attrs("_fields"))
  706. @property
  707. def _name_template(self) -> str:
  708. return "___namedtuple_fields({0})"
  709. @dataclass_with_cached_hash(frozen=True)
  710. class DataclassFieldsSource(ChainedSource):
  711. def reconstruct(self, codegen: "PyCodegen") -> None:
  712. codegen.add_push_null(
  713. lambda: codegen.load_import_from(utils.__name__, "dataclass_fields")
  714. )
  715. codegen(self.base)
  716. codegen.extend_output(create_call_function(1, False))
  717. @property
  718. def _name_template(self) -> str:
  719. return "___dataclass_fields({0})"
  720. @dataclass_with_cached_hash(frozen=True)
  721. class TypeSource(ChainedSource):
  722. def __post_init__(self) -> None:
  723. assert self.base is not None
  724. def reconstruct(self, codegen: "PyCodegen") -> None:
  725. codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type"))
  726. codegen(self.base)
  727. codegen.extend_output(create_call_function(1, False))
  728. @property
  729. def _name_template(self) -> str:
  730. return "type({0})"
  731. @dataclass_with_cached_hash(frozen=True)
  732. class OptimizerSource(ChainedSource):
  733. def reconstruct(self, codegen: "PyCodegen") -> None:
  734. codegen(self.base)
  735. @property
  736. def _name_template(self) -> str:
  737. return "{0}"
  738. @dataclass_with_cached_hash(frozen=True)
  739. class NNModuleSource(ChainedSource):
  740. def reconstruct(self, codegen: "PyCodegen") -> None:
  741. codegen(self.base)
  742. @functools.cached_property
  743. def guard_source(self) -> GuardSource:
  744. return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source]
  745. @property
  746. def _name_template(self) -> str:
  747. return "{0}"
  748. @dataclass_with_cached_hash(frozen=True)
  749. class UnspecializedNNModuleSource(NNModuleSource):
  750. @functools.cached_property
  751. def guard_source(self) -> GuardSource:
  752. return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source]
  753. @dataclass_with_cached_hash(frozen=True)
  754. class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource):
  755. @functools.cached_property
  756. def guard_source(self) -> GuardSource:
  757. return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source]
  758. @dataclass_with_cached_hash(frozen=True)
  759. class FSDPNNModuleSource(NNModuleSource):
  760. @functools.cached_property
  761. def guard_source(self) -> GuardSource:
  762. return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source]
  763. @dataclass_with_cached_hash(frozen=True)
  764. class GlobalStateSource(Source):
  765. @property
  766. def _name_template(self) -> str:
  767. return ""
  768. @property
  769. def guard_source(self) -> GuardSource:
  770. return GuardSource.GLOBAL
  771. @dataclass_with_cached_hash(frozen=True)
  772. class ImportSource(Source):
  773. """Points to an imported module - used instead of GlobalSource
  774. in case the user has overridden the module name in their local namespace"""
  775. module_name: str
  776. def __post_init__(self) -> None:
  777. from .guards import GuardBuilder, install_guard
  778. install_guard(self.make_guard(GuardBuilder.ID_MATCH))
  779. @functools.cached_property
  780. def _name_template(self) -> str:
  781. return f"__import__('{self.module_name}')"
  782. def reconstruct(self, codegen: "PyCodegen") -> None:
  783. codegen.extend_output(
  784. [
  785. codegen.create_load_const(0), # level
  786. create_build_tuple(0), # fromlist
  787. codegen.create_import_name(self.module_name),
  788. ]
  789. )
  790. @property
  791. def guard_source(self) -> GuardSource:
  792. return GuardSource.GLOBAL
  793. @dataclass_with_cached_hash(frozen=True)
  794. class TorchFunctionModeStackSource(Source):
  795. ind: int
  796. @functools.cached_property
  797. def _name_template(self) -> str:
  798. return f"___get_torch_function_mode_stack_at({_esc_str(self._get_index())})"
  799. def _get_index(self) -> int:
  800. from .variables.torch_function import TorchFunctionModeStackVariable
  801. return TorchFunctionModeStackVariable.get_mode_index(self.ind)
  802. def reconstruct(self, codegen: "PyCodegen") -> None:
  803. codegen.add_push_null(
  804. lambda: codegen.load_import_from(
  805. utils.__name__, "get_torch_function_mode_stack_at"
  806. )
  807. )
  808. codegen.extend_output([codegen.create_load_const(self._get_index())])
  809. codegen.extend_output(create_call_function(1, False))
  810. @property
  811. def guard_source(self) -> GuardSource:
  812. return GuardSource.GLOBAL
  813. @dataclass_with_cached_hash(frozen=True)
  814. class ConstantSource(Source):
  815. source_name: str
  816. def reconstruct(self, codegen: "PyCodegen") -> None:
  817. codegen.append_output(codegen.create_load_global(self.source_name, add=False))
  818. @property
  819. def guard_source(self) -> GuardSource:
  820. return GuardSource.CONSTANT
  821. @functools.cached_property
  822. def _name_template(self) -> str:
  823. return self.source_name
  824. def make_guard(self, fn: Any) -> Any:
  825. raise NotImplementedError
  826. @dataclass_with_cached_hash(frozen=True)
  827. class NumpyTensorSource(ChainedSource):
  828. @property
  829. def _name_template(self) -> str:
  830. return "___from_numpy({0})"
  831. def reconstruct(self, codegen: "PyCodegen") -> None:
  832. codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor"))
  833. codegen(self.base)
  834. codegen.extend_output(create_call_function(1, False))
  835. @dataclass_with_cached_hash(frozen=True)
  836. class SubclassAttrListSource(ChainedSource):
  837. @property
  838. def _name_template(self) -> str:
  839. return "{0}.__tensor_flatten__()[0]"
  840. # NB: We don't expect you to actually ever generate guards against this
  841. # source, it is ephemeral
  842. @dataclass_with_cached_hash(frozen=True)
  843. class FloatTensorSource(ChainedSource):
  844. @property
  845. def _name_template(self) -> str:
  846. return "___as_tensor({0})"
  847. @dataclass_with_cached_hash(frozen=True)
  848. class CallMethodItemSource(ChainedSource):
  849. @property
  850. def _name_template(self) -> str:
  851. return "{0}.item()"
  852. # This is a synthetic source that is associated with the singleton
  853. # shape env guard we always register for all frames. We get the actual
  854. # guard contents from the ambient ShapeEnv
  855. @dataclass_with_cached_hash(frozen=True)
  856. class ShapeEnvSource(Source):
  857. @property
  858. def _name_template(self) -> str:
  859. return ""
  860. @property
  861. def guard_source(self) -> GuardSource:
  862. return GuardSource.SHAPE_ENV
  863. @dataclass_with_cached_hash(frozen=True)
  864. class CurrentStreamSource(Source):
  865. device: device_type
  866. @functools.cached_property
  867. def _name_template(self) -> str:
  868. return f"___get_current_stream(torch.device('{_esc_str(self.device.type)}', {_esc_str(self.device.index)}))"
  869. def reconstruct(self, codegen: "PyCodegen") -> None:
  870. num_args = 1
  871. codegen.add_push_null(
  872. lambda: codegen.load_import_from(utils.__name__, "get_current_stream")
  873. )
  874. codegen.add_push_null(lambda: codegen.load_import_from("torch", "device"))
  875. codegen.extend_output([codegen.create_load_const(self.device.type)])
  876. if self.device.index is not None:
  877. num_args += 1
  878. codegen.extend_output([codegen.create_load_const(self.device.index)])
  879. codegen.extend_output(create_call_function(num_args, False))
  880. codegen.extend_output(create_call_function(1, False))
  881. @property
  882. def guard_source(self) -> GuardSource:
  883. return GuardSource.GLOBAL
  884. @dataclass_with_cached_hash(frozen=True)
  885. class BackwardStateSource(Source):
  886. @property
  887. def _name_template(self) -> str:
  888. return ""
  889. @property
  890. def guard_source(self) -> GuardSource:
  891. return GuardSource.BACKWARD_STATE
  892. @functools.lru_cache
  893. def get_local_source_name(
  894. source: Source, *, only_allow_input: bool = False
  895. ) -> Optional[str]:
  896. if isinstance(source, ChainedSource):
  897. return get_local_source_name(source.base, only_allow_input=only_allow_input)
  898. if not isinstance(source, LocalSource):
  899. return None
  900. if only_allow_input and not source.is_input:
  901. return None
  902. return source.local_name
  903. @functools.lru_cache
  904. def is_from_local_source(source: Source, *, only_allow_input: bool = False) -> bool:
  905. return get_local_source_name(source, only_allow_input=only_allow_input) is not None
  906. @functools.lru_cache
  907. def is_from_global_source(source: Source) -> bool:
  908. return get_global_source_name(source) is not None
  909. @functools.lru_cache
  910. def get_global_source_name(source: Source | None) -> str | None:
  911. if isinstance(source, ChainedSource):
  912. return get_global_source_name(source.base)
  913. if not isinstance(source, GlobalSource):
  914. return None
  915. return source.global_name
  916. @functools.lru_cache
  917. def is_from_nonlocal_source(source: Source) -> bool:
  918. if isinstance(source, ChainedSource):
  919. return is_from_nonlocal_source(source.base)
  920. return (
  921. isinstance(source, LocalSource)
  922. and source.is_derefed_cell_contents
  923. and not source.is_input
  924. )
  925. @functools.lru_cache
  926. def is_from_closure_source(source: Source) -> bool:
  927. if isinstance(source, ClosureSource):
  928. return True
  929. if isinstance(source, ChainedSource):
  930. return is_from_closure_source(source.base)
  931. return False
  932. @functools.lru_cache
  933. def is_from_source(source: Source, target: Source) -> bool:
  934. if isinstance(source, ChainedSource):
  935. return is_from_source(source.base, target)
  936. return source == target
  937. @functools.lru_cache
  938. def is_from_unspecialized_nn_module_source(source: Source) -> bool:
  939. if isinstance(source, UnspecializedNNModuleSource):
  940. return True
  941. if isinstance(source, ChainedSource):
  942. return is_from_unspecialized_nn_module_source(source.base)
  943. return False
  944. @functools.lru_cache
  945. def is_from_unspecialized_builtin_nn_module_source(source: Source) -> bool:
  946. if isinstance(source, UnspecializedBuiltinNNModuleSource):
  947. return True
  948. if isinstance(source, ChainedSource):
  949. return is_from_unspecialized_builtin_nn_module_source(source.base)
  950. return False
  951. @functools.lru_cache
  952. def is_from_unspecialized_param_buffer_source(source: Source) -> bool:
  953. if isinstance(source, UnspecializedParamBufferSource):
  954. return True
  955. if isinstance(source, ChainedSource):
  956. return is_from_unspecialized_param_buffer_source(source.base)
  957. return False
  958. @functools.lru_cache
  959. def is_from_flatten_script_object_source(source: Source) -> bool:
  960. if isinstance(source, FlattenScriptObjectSource):
  961. return True
  962. elif isinstance(source, ChainedSource):
  963. return is_from_flatten_script_object_source(source.base)
  964. return False
  965. @functools.lru_cache
  966. def is_from_optimizer_source(source: Source) -> bool:
  967. if isinstance(source, OptimizerSource):
  968. return True
  969. if isinstance(source, ChainedSource):
  970. return is_from_optimizer_source(source.base)
  971. return False
  972. # TODO: can probably write a generic "test this on everything in the chain"
  973. # helper
  974. @functools.lru_cache
  975. def is_from_defaults(source: Source) -> bool:
  976. if isinstance(source, DefaultsSource):
  977. return True
  978. # Accessed with func.__kwdefaults__["foo"]
  979. if (
  980. isinstance(source, DictGetItemSource)
  981. and isinstance(source.base, AttrSource)
  982. and source.base.member == "__kwdefaults__"
  983. ):
  984. return True
  985. # Accessed with func.__defaults__[0]
  986. if (
  987. isinstance(source, GetItemSource)
  988. and isinstance(source.base, AttrSource)
  989. and source.base.member == "__defaults__"
  990. ):
  991. return True
  992. if isinstance(source, ChainedSource):
  993. return is_from_defaults(source.base)
  994. return False
  995. @functools.lru_cache
  996. def is_from_skip_guard_source(source: Source) -> bool:
  997. if isinstance(source, SkipGuardSource):
  998. return True
  999. if isinstance(source, ChainedSource):
  1000. return is_from_skip_guard_source(source.base)
  1001. return False