mypy_plugin.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. """A mypy_ plugin for managing a number of platform-specific annotations.
  2. Its functionality can be split into three distinct parts:
  3. * Assigning the (platform-dependent) precisions of certain `~numpy.number`
  4. subclasses, including the likes of `~numpy.int_`, `~numpy.intp` and
  5. `~numpy.longlong`. See the documentation on
  6. :ref:`scalar types <arrays.scalars.built-in>` for a comprehensive overview
  7. of the affected classes. Without the plugin the precision of all relevant
  8. classes will be inferred as `~typing.Any`.
  9. * Removing all extended-precision `~numpy.number` subclasses that are
  10. unavailable for the platform in question. Most notably this includes the
  11. likes of `~numpy.float128` and `~numpy.complex256`. Without the plugin *all*
  12. extended-precision types will, as far as mypy is concerned, be available
  13. to all platforms.
  14. * Assigning the (platform-dependent) precision of `~numpy.ctypeslib.c_intp`.
  15. Without the plugin the type will default to `ctypes.c_int64`.
  16. .. versionadded:: 1.22
  17. Examples
  18. --------
  19. To enable the plugin, one must add it to their mypy `configuration file`_:
  20. .. code-block:: ini
  21. [mypy]
  22. plugins = numpy.typing.mypy_plugin
  23. .. _mypy: https://mypy-lang.org/
  24. .. _configuration file: https://mypy.readthedocs.io/en/stable/config_file.html
  25. """
  26. from __future__ import annotations
  27. from typing import Final, TYPE_CHECKING, Callable
  28. import numpy as np
  29. if TYPE_CHECKING:
  30. from collections.abc import Iterable
  31. try:
  32. import mypy.types
  33. from mypy.types import Type
  34. from mypy.plugin import Plugin, AnalyzeTypeContext
  35. from mypy.nodes import MypyFile, ImportFrom, Statement
  36. from mypy.build import PRI_MED
  37. _HookFunc = Callable[[AnalyzeTypeContext], Type]
  38. MYPY_EX: None | ModuleNotFoundError = None
  39. except ModuleNotFoundError as ex:
  40. MYPY_EX = ex
#: This module exports no public names; mypy loads it through the
#: module-level ``plugin`` entry point defined further below.
__all__: list[str] = []
  42. def _get_precision_dict() -> dict[str, str]:
  43. names = [
  44. ("_NBitByte", np.byte),
  45. ("_NBitShort", np.short),
  46. ("_NBitIntC", np.intc),
  47. ("_NBitIntP", np.intp),
  48. ("_NBitInt", np.int_),
  49. ("_NBitLong", np.long),
  50. ("_NBitLongLong", np.longlong),
  51. ("_NBitHalf", np.half),
  52. ("_NBitSingle", np.single),
  53. ("_NBitDouble", np.double),
  54. ("_NBitLongDouble", np.longdouble),
  55. ]
  56. ret = {}
  57. module = "numpy._typing"
  58. for name, typ in names:
  59. n: int = 8 * typ().dtype.itemsize
  60. ret[f'{module}._nbit.{name}'] = f"{module}._nbit_base._{n}Bit"
  61. return ret
  62. def _get_extended_precision_list() -> list[str]:
  63. extended_names = [
  64. "uint128",
  65. "uint256",
  66. "int128",
  67. "int256",
  68. "float80",
  69. "float96",
  70. "float128",
  71. "float256",
  72. "complex160",
  73. "complex192",
  74. "complex256",
  75. "complex512",
  76. ]
  77. return [i for i in extended_names if hasattr(np, i)]
  78. def _get_c_intp_name() -> str:
  79. # Adapted from `np.core._internal._getintp_ctype`
  80. char = np.dtype('n').char
  81. if char == 'i':
  82. return "c_int"
  83. elif char == 'l':
  84. return "c_long"
  85. elif char == 'q':
  86. return "c_longlong"
  87. else:
  88. return "c_long"
#: A dictionary mapping the type-aliases in `numpy._typing._nbit` to the
#: fully qualified names of concrete `numpy.typing.NBitBase` subclasses,
#: as measured on the current platform.
_PRECISION_DICT: Final = _get_precision_dict()
#: The names of the extended-precision `np.number` subclasses that are
#: actually available (as ``numpy`` attributes) on the current platform.
_EXTENDED_PRECISION_LIST: Final = _get_extended_precision_list()
#: The name of the `ctypes` equivalent of `np.intp` (e.g. ``"c_long"``).
_C_INTP: Final = _get_c_intp_name()
  96. def _hook(ctx: AnalyzeTypeContext) -> Type:
  97. """Replace a type-alias with a concrete ``NBitBase`` subclass."""
  98. typ, _, api = ctx
  99. name = typ.name.split(".")[-1]
  100. name_new = _PRECISION_DICT[f"numpy._typing._nbit.{name}"]
  101. return api.named_type(name_new)
  102. if TYPE_CHECKING or MYPY_EX is None:
  103. def _index(iterable: Iterable[Statement], id: str) -> int:
  104. """Identify the first ``ImportFrom`` instance the specified `id`."""
  105. for i, value in enumerate(iterable):
  106. if getattr(value, "id", None) == id:
  107. return i
  108. raise ValueError("Failed to identify a `ImportFrom` instance "
  109. f"with the following id: {id!r}")
  110. def _override_imports(
  111. file: MypyFile,
  112. module: str,
  113. imports: list[tuple[str, None | str]],
  114. ) -> None:
  115. """Override the first `module`-based import with new `imports`."""
  116. # Construct a new `from module import y` statement
  117. import_obj = ImportFrom(module, 0, names=imports)
  118. import_obj.is_top_level = True
  119. # Replace the first `module`-based import statement with `import_obj`
  120. for lst in [file.defs, file.imports]: # type: list[Statement]
  121. i = _index(lst, module)
  122. lst[i] = import_obj
  123. class _NumpyPlugin(Plugin):
  124. """A mypy plugin for handling versus numpy-specific typing tasks."""
  125. def get_type_analyze_hook(self, fullname: str) -> None | _HookFunc:
  126. """Set the precision of platform-specific `numpy.number`
  127. subclasses.
  128. For example: `numpy.int_`, `numpy.longlong` and `numpy.longdouble`.
  129. """
  130. if fullname in _PRECISION_DICT:
  131. return _hook
  132. return None
  133. def get_additional_deps(
  134. self, file: MypyFile
  135. ) -> list[tuple[int, str, int]]:
  136. """Handle all import-based overrides.
  137. * Import platform-specific extended-precision `numpy.number`
  138. subclasses (*e.g.* `numpy.float96`, `numpy.float128` and
  139. `numpy.complex256`).
  140. * Import the appropriate `ctypes` equivalent to `numpy.intp`.
  141. """
  142. ret = [(PRI_MED, file.fullname, -1)]
  143. if file.fullname == "numpy":
  144. _override_imports(
  145. file, "numpy._typing._extended_precision",
  146. imports=[(v, v) for v in _EXTENDED_PRECISION_LIST],
  147. )
  148. elif file.fullname == "numpy.ctypeslib":
  149. _override_imports(
  150. file, "ctypes",
  151. imports=[(_C_INTP, "_c_intp")],
  152. )
  153. return ret
  154. def plugin(version: str) -> type[_NumpyPlugin]:
  155. """An entry-point for mypy."""
  156. return _NumpyPlugin
  157. else:
  158. def plugin(version: str) -> type[_NumpyPlugin]:
  159. """An entry-point for mypy."""
  160. raise MYPY_EX