brain_typing.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
  2. # For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
  4. """Astroid hooks for typing.py support."""
  5. from __future__ import annotations
  6. import textwrap
  7. import typing
  8. from collections.abc import Iterator
  9. from functools import partial
  10. from typing import Final
  11. from astroid import context, nodes
  12. from astroid.brain.helpers import register_module_extender
  13. from astroid.builder import AstroidBuilder, _extract_single_node, extract_node
  14. from astroid.const import PY312_PLUS, PY313_PLUS, PY314_PLUS
  15. from astroid.exceptions import (
  16. AstroidSyntaxError,
  17. AttributeInferenceError,
  18. InferenceError,
  19. UseInferenceDefault,
  20. )
  21. from astroid.inference_tip import inference_tip
  22. from astroid.manager import AstroidManager
  23. TYPING_TYPEVARS = {"TypeVar", "NewType"}
  24. TYPING_TYPEVARS_QUALIFIED: Final = {
  25. "typing.TypeVar",
  26. "typing.NewType",
  27. "typing_extensions.TypeVar",
  28. }
  29. TYPING_TYPEDDICT_QUALIFIED: Final = {"typing.TypedDict", "typing_extensions.TypedDict"}
# Source template for a synthetic stand-in class. ``{0}`` is filled with the
# class name via ``str.format``. The metaclass's ``__getitem__`` returns the
# class itself, so subscripting the generated class infers back to it, and
# ``__args__`` is an empty tuple.
TYPING_TYPE_TEMPLATE = """
class Meta(type):
    def __getitem__(self, item):
        return self
    @property
    def __args__(self):
        return ()
class {0}(metaclass=Meta):
    pass
"""
  40. TYPING_MEMBERS = set(getattr(typing, "__all__", []))
  41. TYPING_ALIAS = frozenset(
  42. (
  43. "typing.Hashable",
  44. "typing.Awaitable",
  45. "typing.Coroutine",
  46. "typing.AsyncIterable",
  47. "typing.AsyncIterator",
  48. "typing.Iterable",
  49. "typing.Iterator",
  50. "typing.Reversible",
  51. "typing.Sized",
  52. "typing.Container",
  53. "typing.Collection",
  54. "typing.Callable",
  55. "typing.AbstractSet",
  56. "typing.MutableSet",
  57. "typing.Mapping",
  58. "typing.MutableMapping",
  59. "typing.Sequence",
  60. "typing.MutableSequence",
  61. "typing.ByteString", # scheduled for removal in 3.17
  62. "typing.Tuple",
  63. "typing.List",
  64. "typing.Deque",
  65. "typing.Set",
  66. "typing.FrozenSet",
  67. "typing.MappingView",
  68. "typing.KeysView",
  69. "typing.ItemsView",
  70. "typing.ValuesView",
  71. "typing.ContextManager",
  72. "typing.AsyncContextManager",
  73. "typing.Dict",
  74. "typing.DefaultDict",
  75. "typing.OrderedDict",
  76. "typing.Counter",
  77. "typing.ChainMap",
  78. "typing.Generator",
  79. "typing.AsyncGenerator",
  80. "typing.Type",
  81. "typing.Pattern",
  82. "typing.Match",
  83. )
  84. )
# Source for a ``__class_getitem__`` classmethod that returns the class
# unchanged. The inference tips in this module parse it and store the result
# in a class's ``locals`` so subscripting that class infers to the class itself.
CLASS_GETITEM_TEMPLATE = """
@classmethod
def __class_getitem__(cls, item):
    return cls
"""
  90. def looks_like_typing_typevar_or_newtype(node) -> bool:
  91. func = node.func
  92. if isinstance(func, nodes.Attribute):
  93. return func.attrname in TYPING_TYPEVARS
  94. if isinstance(func, nodes.Name):
  95. return func.name in TYPING_TYPEVARS
  96. return False
  97. def infer_typing_typevar_or_newtype(
  98. node: nodes.Call, context_itton: context.InferenceContext | None = None
  99. ) -> Iterator[nodes.ClassDef]:
  100. """Infer a typing.TypeVar(...) or typing.NewType(...) call."""
  101. try:
  102. func = next(node.func.infer(context=context_itton))
  103. except (InferenceError, StopIteration) as exc:
  104. raise UseInferenceDefault from exc
  105. if func.qname() not in TYPING_TYPEVARS_QUALIFIED:
  106. raise UseInferenceDefault
  107. if not node.args:
  108. raise UseInferenceDefault
  109. # Cannot infer from a dynamic class name (f-string)
  110. if isinstance(node.args[0], nodes.JoinedStr):
  111. raise UseInferenceDefault
  112. typename = node.args[0].as_string().strip("'")
  113. try:
  114. node = extract_node(TYPING_TYPE_TEMPLATE.format(typename))
  115. except AstroidSyntaxError as exc:
  116. raise InferenceError from exc
  117. return node.infer(context=context_itton)
  118. def _looks_like_typing_subscript(node) -> bool:
  119. """Try to figure out if a Subscript node *might* be a typing-related subscript."""
  120. if isinstance(node, nodes.Name):
  121. return node.name in TYPING_MEMBERS
  122. if isinstance(node, nodes.Attribute):
  123. return node.attrname in TYPING_MEMBERS
  124. if isinstance(node, nodes.Subscript):
  125. return _looks_like_typing_subscript(node.value)
  126. return False
def infer_typing_attr(
    node: nodes.Subscript, ctx: context.InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
    """Infer a typing.X[...] subscript.

    :param node: subscript whose ``value`` may infer to a ``typing`` object.
    :param ctx: inference context; only forwarded when inferring the synthetic
        stand-in class built at the end.
    :returns: an iterator over the inferred result. That result is either the
        subscripted ``typing`` object itself (``Annotated`` / ``Generic`` cases)
        or a stand-in class built from ``TYPING_TYPE_TEMPLATE``.
    :raises UseInferenceDefault: if ``node.value`` cannot be inferred, does not
        have a ``typing.``-prefixed qualified name, or is one of ``TYPING_ALIAS``.
    """
    try:
        # NOTE(review): ``ctx`` is not forwarded to this inference; confirm intended.
        value = next(node.value.infer())  # type: ignore[union-attr] # value shouldn't be None for Subscript.
    except (InferenceError, StopIteration) as exc:
        raise UseInferenceDefault from exc
    if not value.qname().startswith("typing.") or value.qname() in TYPING_ALIAS:
        # If typing subscript belongs to an alias handle it separately.
        raise UseInferenceDefault
    if (
        PY313_PLUS
        and isinstance(value, nodes.FunctionDef)
        and value.qname() == "typing.Annotated"
    ):
        # typing.Annotated is a FunctionDef on 3.13+
        # Pin the result on the node so later inferences return it directly.
        node._explicit_inference = lambda node, context: iter([value])
        return iter([value])
    # NOTE(review): "typing_extensions.Annotated" looks unreachable here, because
    # the startswith("typing.") guard above already rejected such qnames. Confirm.
    if isinstance(value, nodes.ClassDef) and value.qname() in {
        "typing.Generic",
        "typing.Annotated",
        "typing_extensions.Annotated",
    }:
        # typing.Generic and typing.Annotated (PY39) are subscriptable
        # through __class_getitem__. Since astroid can't easily
        # infer the native methods, replace them for an easy inference tip
        func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
        value.locals["__class_getitem__"] = [func_to_add]
        if (
            isinstance(node.parent, nodes.ClassDef)
            and node in node.parent.bases
            and getattr(node.parent, "__cache", None)
        ):
            # node.parent.slots is evaluated and cached before the inference tip
            # is first applied. Remove the last result to allow a recalculation of slots
            cache = node.parent.__cache  # type: ignore[attr-defined] # Unrecognized getattr
            if cache.get(node.parent.slots) is not None:
                del cache[node.parent.slots]
        # Avoid re-instantiating this class every time it's seen
        node._explicit_inference = lambda node, context: iter([value])
        return iter([value])
    # Any other typing object: substitute a stand-in class named after the
    # last component of its qualified name.
    node = extract_node(TYPING_TYPE_TEMPLATE.format(value.qname().split(".")[-1]))
    return node.infer(context=ctx)
  171. def _looks_like_generic_class_pep695(node: nodes.ClassDef) -> bool:
  172. """Check if class is using type parameter. Python 3.12+."""
  173. return len(node.type_params) > 0
  174. def infer_typing_generic_class_pep695(
  175. node: nodes.ClassDef, ctx: context.InferenceContext | None = None
  176. ) -> Iterator[nodes.ClassDef]:
  177. """Add __class_getitem__ for generic classes. Python 3.12+."""
  178. func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
  179. node.locals["__class_getitem__"] = [func_to_add]
  180. return iter([node])
  181. def _looks_like_typedDict( # pylint: disable=invalid-name
  182. node: nodes.FunctionDef | nodes.ClassDef,
  183. ) -> bool:
  184. """Check if node is TypedDict FunctionDef."""
  185. return node.qname() in TYPING_TYPEDDICT_QUALIFIED
  186. def infer_typedDict( # pylint: disable=invalid-name
  187. node: nodes.FunctionDef, ctx: context.InferenceContext | None = None
  188. ) -> Iterator[nodes.ClassDef]:
  189. """Replace TypedDict FunctionDef with ClassDef."""
  190. class_def = nodes.ClassDef(
  191. name="TypedDict",
  192. lineno=node.lineno,
  193. col_offset=node.col_offset,
  194. parent=node.parent,
  195. end_lineno=node.end_lineno,
  196. end_col_offset=node.end_col_offset,
  197. )
  198. class_def.postinit(bases=[extract_node("dict")], body=[], decorators=None)
  199. func_to_add = _extract_single_node("dict")
  200. class_def.locals["__call__"] = [func_to_add]
  201. return iter([class_def])
  202. def _looks_like_typing_alias(node: nodes.Call) -> bool:
  203. """
  204. Returns True if the node corresponds to a call to _alias function.
  205. For example :
  206. MutableSet = _alias(collections.abc.MutableSet, T)
  207. :param node: call node
  208. """
  209. return (
  210. isinstance(node.func, nodes.Name)
  211. # TODO: remove _DeprecatedGenericAlias when Py3.14 min
  212. and node.func.name in {"_alias", "_DeprecatedGenericAlias"}
  213. and len(node.args) == 2
  214. and (
  215. # _alias function works also for builtins object such as list and dict
  216. isinstance(node.args[0], (nodes.Attribute, nodes.Name))
  217. )
  218. )
  219. def _forbid_class_getitem_access(node: nodes.ClassDef) -> None:
  220. """Disable the access to __class_getitem__ method for the node in parameters."""
  221. def full_raiser(origin_func, attr, *args, **kwargs):
  222. """
  223. Raises an AttributeInferenceError in case of access to __class_getitem__ method.
  224. Otherwise, just call origin_func.
  225. """
  226. if attr == "__class_getitem__":
  227. raise AttributeInferenceError("__class_getitem__ access is not allowed")
  228. return origin_func(attr, *args, **kwargs)
  229. try:
  230. node.getattr("__class_getitem__")
  231. # If we are here, then we are sure to modify an object that does have
  232. # __class_getitem__ method (which origin is the protocol defined in
  233. # collections module) whereas the typing module considers it should not.
  234. # We do not want __class_getitem__ to be found in the classdef
  235. partial_raiser = partial(full_raiser, node.getattr)
  236. node.getattr = partial_raiser
  237. except AttributeInferenceError:
  238. pass
  239. def infer_typing_alias(
  240. node: nodes.Call, ctx: context.InferenceContext | None = None
  241. ) -> Iterator[nodes.ClassDef]:
  242. """
  243. Infers the call to _alias function
  244. Insert ClassDef, with same name as aliased class,
  245. in mro to simulate _GenericAlias.
  246. :param node: call node
  247. :param context: inference context
  248. # TODO: evaluate if still necessary when Py3.12 is minimum
  249. """
  250. if not (
  251. isinstance(node.parent, nodes.Assign)
  252. and len(node.parent.targets) == 1
  253. and isinstance(node.parent.targets[0], nodes.AssignName)
  254. ):
  255. raise UseInferenceDefault
  256. try:
  257. res = next(node.args[0].infer(context=ctx))
  258. except StopIteration as e:
  259. raise InferenceError(node=node.args[0], context=ctx) from e
  260. assign_name = node.parent.targets[0]
  261. class_def = nodes.ClassDef(
  262. name=assign_name.name,
  263. lineno=assign_name.lineno,
  264. col_offset=assign_name.col_offset,
  265. parent=node.parent,
  266. end_lineno=assign_name.end_lineno,
  267. end_col_offset=assign_name.end_col_offset,
  268. )
  269. if isinstance(res, nodes.ClassDef):
  270. # Only add `res` as base if it's a `ClassDef`
  271. # This isn't the case for `typing.Pattern` and `typing.Match`
  272. class_def.postinit(bases=[res], body=[], decorators=None)
  273. maybe_type_var = node.args[1]
  274. if isinstance(maybe_type_var, nodes.Const) and maybe_type_var.value > 0:
  275. # If typing alias is subscriptable, add `__class_getitem__` to ClassDef
  276. func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
  277. class_def.locals["__class_getitem__"] = [func_to_add]
  278. else:
  279. # If not, make sure that `__class_getitem__` access is forbidden.
  280. # This is an issue in cases where the aliased class implements it,
  281. # but the typing alias isn't subscriptable. E.g., `typing.ByteString` for PY39+
  282. _forbid_class_getitem_access(class_def)
  283. # Avoid re-instantiating this class every time it's seen
  284. node._explicit_inference = lambda node, context: iter([class_def])
  285. return iter([class_def])
  286. def _looks_like_special_alias(node: nodes.Call) -> bool:
  287. """Return True if call is for Tuple or Callable alias.
  288. In PY37 and PY38 the call is to '_VariadicGenericAlias' with 'tuple' as
  289. first argument. In PY39+ it is replaced by a call to '_TupleType'.
  290. PY37: Tuple = _VariadicGenericAlias(tuple, (), inst=False, special=True)
  291. PY39: Tuple = _TupleType(tuple, -1, inst=False, name='Tuple')
  292. PY37: Callable = _VariadicGenericAlias(collections.abc.Callable, (), special=True)
  293. PY39: Callable = _CallableType(collections.abc.Callable, 2)
  294. """
  295. return (
  296. isinstance(node.func, nodes.Name)
  297. and node.args
  298. and (
  299. (
  300. node.func.name == "_TupleType"
  301. and isinstance(node.args[0], nodes.Name)
  302. and node.args[0].name == "tuple"
  303. )
  304. or (
  305. node.func.name == "_CallableType"
  306. and isinstance(node.args[0], nodes.Attribute)
  307. and node.args[0].as_string() == "collections.abc.Callable"
  308. )
  309. )
  310. )
  311. def infer_special_alias(
  312. node: nodes.Call, ctx: context.InferenceContext | None = None
  313. ) -> Iterator[nodes.ClassDef]:
  314. """Infer call to tuple alias as new subscriptable class typing.Tuple."""
  315. if not (
  316. isinstance(node.parent, nodes.Assign)
  317. and len(node.parent.targets) == 1
  318. and isinstance(node.parent.targets[0], nodes.AssignName)
  319. ):
  320. raise UseInferenceDefault
  321. try:
  322. res = next(node.args[0].infer(context=ctx))
  323. except StopIteration as e:
  324. raise InferenceError(node=node.args[0], context=ctx) from e
  325. assign_name = node.parent.targets[0]
  326. class_def = nodes.ClassDef(
  327. name=assign_name.name,
  328. parent=node.parent,
  329. lineno=assign_name.lineno,
  330. col_offset=assign_name.col_offset,
  331. end_lineno=assign_name.end_lineno,
  332. end_col_offset=assign_name.end_col_offset,
  333. )
  334. class_def.postinit(bases=[res], body=[], decorators=None)
  335. func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
  336. class_def.locals["__class_getitem__"] = [func_to_add]
  337. # Avoid re-instantiating this class every time it's seen
  338. node._explicit_inference = lambda node, context: iter([class_def])
  339. return iter([class_def])
  340. def _looks_like_typing_cast(node: nodes.Call) -> bool:
  341. return (isinstance(node.func, nodes.Name) and node.func.name == "cast") or (
  342. isinstance(node.func, nodes.Attribute) and node.func.attrname == "cast"
  343. )
  344. def infer_typing_cast(
  345. node: nodes.Call, ctx: context.InferenceContext | None = None
  346. ) -> Iterator[nodes.NodeNG]:
  347. """Infer call to cast() returning same type as casted-from var."""
  348. if not isinstance(node.func, (nodes.Name, nodes.Attribute)):
  349. raise UseInferenceDefault
  350. try:
  351. func = next(node.func.infer(context=ctx))
  352. except (InferenceError, StopIteration) as exc:
  353. raise UseInferenceDefault from exc
  354. if not (
  355. isinstance(func, nodes.FunctionDef)
  356. and func.qname() == "typing.cast"
  357. and len(node.args) == 2
  358. ):
  359. raise UseInferenceDefault
  360. return node.args[1].infer(context=ctx)
def _typing_transform():
    """Build an astroid module of stub classes for ``typing``.

    ``register`` passes this function to ``register_module_extender`` for the
    ``typing`` module on Python 3.12+. The stubs give ``Generic``, ``Type``,
    ``TypeVar``, ``ContextManager``, ``AsyncContextManager``, ``Pattern`` and
    ``Match`` a ``__class_getitem__`` that returns the class itself.
    ``ParamSpec.args`` / ``ParamSpec.kwargs`` are properties returning
    ``ParamSpecArgs`` / ``ParamSpecKwargs`` instances. On Python 3.14+ a
    ``ForwardRef`` import from ``annotationlib`` and a subscriptable ``Union``
    stub are appended.

    :returns: the module built from the stub source by ``AstroidBuilder``.
    """
    code = textwrap.dedent(
        """
    class Generic:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class ParamSpec:
        @property
        def args(self):
            return ParamSpecArgs(self)
        @property
        def kwargs(self):
            return ParamSpecKwargs(self)
    class ParamSpecArgs: ...
    class ParamSpecKwargs: ...
    class TypeAlias: ...
    class Type:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class TypeVar:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class TypeVarTuple: ...
    class ContextManager:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class AsyncContextManager:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class Pattern:
        @classmethod
        def __class_getitem__(cls, item): return cls
    class Match:
        @classmethod
        def __class_getitem__(cls, item): return cls
    """
    )
    if PY314_PLUS:
        # Extra stubs appended only when running on Python 3.14+.
        code += textwrap.dedent(
            """
        from annotationlib import ForwardRef
        class Union:
            @classmethod
            def __class_getitem__(cls, item): return cls
        """
        )
    return AstroidBuilder(AstroidManager()).string_build(code)
  408. def register(manager: AstroidManager) -> None:
  409. manager.register_transform(
  410. nodes.Call,
  411. inference_tip(infer_typing_typevar_or_newtype),
  412. looks_like_typing_typevar_or_newtype,
  413. )
  414. manager.register_transform(
  415. nodes.Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript
  416. )
  417. manager.register_transform(
  418. nodes.Call, inference_tip(infer_typing_cast), _looks_like_typing_cast
  419. )
  420. manager.register_transform(
  421. nodes.FunctionDef, inference_tip(infer_typedDict), _looks_like_typedDict
  422. )
  423. manager.register_transform(
  424. nodes.Call, inference_tip(infer_typing_alias), _looks_like_typing_alias
  425. )
  426. manager.register_transform(
  427. nodes.Call, inference_tip(infer_special_alias), _looks_like_special_alias
  428. )
  429. if PY312_PLUS:
  430. register_module_extender(manager, "typing", _typing_transform)
  431. manager.register_transform(
  432. nodes.ClassDef,
  433. inference_tip(infer_typing_generic_class_pep695),
  434. _looks_like_generic_class_pep695,
  435. )