_check.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. # mypy: allow-untyped-defs
  2. import ast
  3. import inspect
  4. import sys
  5. import textwrap
  6. import warnings
  7. import torch
  8. class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
  9. """Check the ``__init__`` method of a given ``nn.Module``.
  10. It ensures that all instance-level attributes can be properly initialized.
  11. Specifically, we do type inference based on attribute values...even
  12. if the attribute in question has already been typed using
  13. Python3-style annotations or ``torch.jit.annotate``. This means that
  14. setting an instance-level attribute to ``[]`` (for ``List``),
  15. ``{}`` for ``Dict``), or ``None`` (for ``Optional``) isn't enough
  16. information for us to properly initialize that attribute.
  17. An object of this class can walk a given ``nn.Module``'s AST and
  18. determine if it meets our requirements or not.
  19. Known limitations
  20. 1. We can only check the AST nodes for certain constructs; we can't
  21. ``eval`` arbitrary expressions. This means that function calls,
  22. class instantiations, and complex expressions that resolve to one of
  23. the "empty" values specified above will NOT be flagged as
  24. problematic.
  25. 2. We match on string literals, so if the user decides to use a
  26. non-standard import (e.g. `from typing import List as foo`), we
  27. won't catch it.
  28. Example:
  29. .. code-block:: python
  30. class M(torch.nn.Module):
  31. def fn(self):
  32. return []
  33. def __init__(self) -> None:
  34. super().__init__()
  35. self.x: List[int] = []
  36. def forward(self, x: List[int]):
  37. self.x = x
  38. return 1
  39. The above code will pass the ``AttributeTypeIsSupportedChecker``
  40. check since we have a function call in ``__init__``. However,
  41. it will still fail later with the ``RuntimeError`` "Tried to set
  42. nonexistent attribute: x. Did you forget to initialize it in
  43. __init__()?".
  44. Args:
  45. nn_module - The instance of ``torch.nn.Module`` whose
  46. ``__init__`` method we wish to check
  47. """
  48. def check(self, nn_module: torch.nn.Module) -> None:
  49. source_lines = inspect.getsource(nn_module.__class__.__init__)
  50. # Ignore comments no matter the indentation
  51. def is_useless_comment(line):
  52. line = line.strip()
  53. return line.startswith("#") and not line.startswith("# type:")
  54. source_lines = "\n".join(
  55. [l for l in source_lines.split("\n") if not is_useless_comment(l)]
  56. )
  57. # This AST only contains the `__init__` method of the nn.Module
  58. init_ast = ast.parse(textwrap.dedent(source_lines))
  59. # Get items annotated in the class body
  60. if sys.version_info >= (3, 14):
  61. import annotationlib
  62. self.class_level_annotations = list(
  63. annotationlib.get_annotations(
  64. nn_module, format=annotationlib.Format.FORWARDREF
  65. ).keys()
  66. )
  67. else:
  68. self.class_level_annotations = list(nn_module.__annotations__.keys())
  69. # Flag for later
  70. self.visiting_class_level_ann = False
  71. self.visit(init_ast)
  72. def _is_empty_container(self, node: ast.AST, ann_type: str) -> bool:
  73. if ann_type == "List":
  74. # Assigning `[]` to a `List` type gives you a Node where
  75. # value=List(elts=[], ctx=Load())
  76. if not isinstance(node, ast.List):
  77. return False
  78. if node.elts:
  79. return False
  80. elif ann_type == "Dict":
  81. # Assigning `{}` to a `Dict` type gives you a Node where
  82. # value=Dict(keys=[], values=[])
  83. if not isinstance(node, ast.Dict):
  84. return False
  85. if node.keys:
  86. return False
  87. elif ann_type == "Optional":
  88. # Assigning `None` to an `Optional` type gives you a
  89. # Node where value=Constant(value=None, kind=None)
  90. if not isinstance(node, ast.Constant):
  91. return False
  92. if node.value: # type: ignore[attr-defined]
  93. return False
  94. return True
  95. def visit_Assign(self, node) -> None:
  96. """Store assignment state when assigning to a Call Node.
  97. If we're visiting a Call Node (the right-hand side of an
  98. assignment statement), we won't be able to check the variable
  99. that we're assigning to (the left-hand side of an assignment).
  100. Because of this, we need to store this state in visitAssign.
  101. (Luckily, we only have to do this if we're assigning to a Call
  102. Node, i.e. ``torch.jit.annotate``. If we're using normal Python
  103. annotations, we'll be visiting an AnnAssign Node, which has its
  104. target built in.)
  105. """
  106. try:
  107. if (
  108. isinstance(node.value, ast.Call)
  109. and node.targets[0].attr in self.class_level_annotations
  110. ):
  111. self.visiting_class_level_ann = True
  112. except AttributeError:
  113. return
  114. self.generic_visit(node)
  115. self.visiting_class_level_ann = False
  116. def visit_AnnAssign(self, node) -> None:
  117. """Visit an AnnAssign node in an ``nn.Module``'s ``__init__`` method.
  118. It checks if it conforms to our attribute annotation rules."""
  119. # If we have a local variable
  120. try:
  121. if node.target.value.id != "self":
  122. return
  123. except AttributeError:
  124. return
  125. # If we have an attribute that's already been annotated at the
  126. # class level
  127. if node.target.attr in self.class_level_annotations:
  128. return
  129. # TODO @ansley: add `Union` once landed
  130. # NB: Even though `Tuple` is a "container", we don't want to
  131. # check for it here. `Tuple` functions as an type with an
  132. # "infinite" number of subtypes, in the sense that you can have
  133. # `Tuple[())]`, `Tuple[T1]`, `Tuple[T2]`, `Tuple[T1, T2]`,
  134. # `Tuple[T2, T1]` and so on, and none of these subtypes can be
  135. # used in place of the other. Therefore, assigning an empty
  136. # tuple in `__init__` CORRECTLY means that that variable
  137. # cannot be reassigned later to a non-empty tuple. Same
  138. # deal with `NamedTuple`
  139. containers = {"List", "list", "Dict", "dict", "Optional"}
  140. # If we're not evaluating one of the specified problem types
  141. try:
  142. if node.annotation.value.id not in containers:
  143. return
  144. except AttributeError:
  145. # To evaluate a base type (`str`, `int`, etc.), we would
  146. # have needed to get the name through `node.annotation.id`
  147. # instead of `node.annotation.value.id`. Seems that we're
  148. # not evaluating one of our "containers"
  149. return
  150. # Check if the assigned variable is empty
  151. ann_type = node.annotation.value.id
  152. if not self._is_empty_container(node.value, ann_type):
  153. return
  154. warnings.warn(
  155. "The TorchScript type system doesn't support "
  156. "instance-level annotations on empty non-base "
  157. "types in `__init__`. Instead, either 1) use a "
  158. "type annotation in the class body, or 2) wrap "
  159. "the type in `torch.jit.Attribute`.",
  160. stacklevel=2,
  161. )
  162. def visit_Call(self, node) -> None:
  163. """Determine if a Call node is 'torch.jit.annotate' in __init__.
  164. Visit a Call node in an ``nn.Module``'s ``__init__``
  165. method and determine if it's ``torch.jit.annotate``. If so,
  166. see if it conforms to our attribute annotation rules.
  167. """
  168. # If we have an attribute that's already been annotated at the
  169. # class level
  170. if self.visiting_class_level_ann:
  171. return
  172. # If this isn't a call to `torch.jit.annotate`
  173. try:
  174. if (
  175. node.func.value.value.id != "torch"
  176. or node.func.value.attr != "jit"
  177. or node.func.attr != "annotate"
  178. ):
  179. self.generic_visit(node)
  180. elif (
  181. node.func.value.value.id != "jit" or node.func.value.attr != "annotate"
  182. ):
  183. self.generic_visit(node)
  184. except AttributeError:
  185. # Looks like we didn't even have the right node structure
  186. # to check for `torch.jit.annotate` in the first place
  187. self.generic_visit(node)
  188. # Invariant: we have a `torch.jit.annotate` or a
  189. # `torch.annotate` call
  190. # A Call Node for `torch.jit.annotate` should have an `args`
  191. # list of length 2 where args[0] represents the annotation and
  192. # args[1] represents the actual value
  193. if len(node.args) != 2:
  194. return
  195. if not isinstance(node.args[0], ast.Subscript):
  196. return
  197. # See notes in `visit_AnnAssign` r.e. containers
  198. containers = {"List", "Dict", "Optional"}
  199. try:
  200. ann_type = node.args[0].value.id # type: ignore[attr-defined]
  201. except AttributeError:
  202. return
  203. if ann_type not in containers:
  204. return
  205. # Check if the assigned variable is empty
  206. if not self._is_empty_container(node.args[1], ann_type):
  207. return
  208. warnings.warn(
  209. "The TorchScript type system doesn't support "
  210. "instance-level annotations on empty non-base "
  211. "types in `__init__`. Instead, either 1) use a "
  212. "type annotation in the class body, or 2) wrap "
  213. "the type in `torch.jit.Attribute`.",
  214. stacklevel=2,
  215. )