mypy_plugin.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
"""A mypy_ plugin for managing a number of platform-specific annotations.

Its functionality can be split into three distinct parts:

* Assigning the (platform-dependent) precisions of certain `~numpy.number`
  subclasses, including the likes of `~numpy.int_`, `~numpy.intp` and
  `~numpy.longlong`. See the documentation on
  :ref:`scalar types <arrays.scalars.built-in>` for a comprehensive overview
  of the affected classes. Without the plugin the precision of all relevant
  classes will be inferred as `~typing.Any`.
* Removing all extended-precision `~numpy.number` subclasses that are
  unavailable for the platform in question. Most notably this includes the
  likes of `~numpy.float128` and `~numpy.complex256`. Without the plugin *all*
  extended-precision types will, as far as mypy is concerned, be available
  to all platforms.
* Assigning the (platform-dependent) precision of `~numpy.ctypeslib.c_intp`.
  Without the plugin the type will default to `ctypes.c_int64`.

.. versionadded:: 1.22

.. deprecated:: 2.3
    The :mod:`numpy.typing.mypy_plugin` entry-point is deprecated in favor of
    platform-agnostic static type inference. Remove
    ``numpy.typing.mypy_plugin`` from the ``plugins`` section of your mypy
    configuration; if that surfaces new errors, please open an issue with a
    minimal reproducer.

Examples
--------
To enable the plugin, one must add it to their mypy `configuration file`_:

.. code-block:: ini

    [mypy]
    plugins = numpy.typing.mypy_plugin

.. _mypy: https://mypy-lang.org/
.. _configuration file: https://mypy.readthedocs.io/en/stable/config_file.html

  31. """
  32. from collections.abc import Callable, Iterable
  33. from typing import TYPE_CHECKING, Final, TypeAlias, cast
  34. import numpy as np
# Intentionally empty: this module re-exports nothing via ``from ... import *``.
__all__: list[str] = []
  36. def _get_precision_dict() -> dict[str, str]:
  37. names = [
  38. ("_NBitByte", np.byte),
  39. ("_NBitShort", np.short),
  40. ("_NBitIntC", np.intc),
  41. ("_NBitIntP", np.intp),
  42. ("_NBitInt", np.int_),
  43. ("_NBitLong", np.long),
  44. ("_NBitLongLong", np.longlong),
  45. ("_NBitHalf", np.half),
  46. ("_NBitSingle", np.single),
  47. ("_NBitDouble", np.double),
  48. ("_NBitLongDouble", np.longdouble),
  49. ]
  50. ret: dict[str, str] = {}
  51. for name, typ in names:
  52. n = 8 * np.dtype(typ).itemsize
  53. ret[f"{_MODULE}._nbit.{name}"] = f"{_MODULE}._nbit_base._{n}Bit"
  54. return ret
  55. def _get_extended_precision_list() -> list[str]:
  56. extended_names = [
  57. "float96",
  58. "float128",
  59. "complex192",
  60. "complex256",
  61. ]
  62. return [i for i in extended_names if hasattr(np, i)]
  63. def _get_c_intp_name() -> str:
  64. # Adapted from `np.core._internal._getintp_ctype`
  65. return {
  66. "i": "c_int",
  67. "l": "c_long",
  68. "q": "c_longlong",
  69. }.get(np.dtype("n").char, "c_long")
  70. _MODULE: Final = "numpy._typing"
  71. #: A dictionary mapping type-aliases in `numpy._typing._nbit` to
  72. #: concrete `numpy.typing.NBitBase` subclasses.
  73. _PRECISION_DICT: Final = _get_precision_dict()
  74. #: A list with the names of all extended precision `np.number` subclasses.
  75. _EXTENDED_PRECISION_LIST: Final = _get_extended_precision_list()
  76. #: The name of the ctypes equivalent of `np.intp`
  77. _C_INTP: Final = _get_c_intp_name()
  78. try:
  79. if TYPE_CHECKING:
  80. from mypy.typeanal import TypeAnalyser
  81. import mypy.types
  82. from mypy.build import PRI_MED
  83. from mypy.nodes import ImportFrom, MypyFile, Statement
  84. from mypy.plugin import AnalyzeTypeContext, Plugin
  85. except ModuleNotFoundError as e:
  86. def plugin(version: str) -> type:
  87. raise e
  88. else:
  89. _HookFunc: TypeAlias = Callable[[AnalyzeTypeContext], mypy.types.Type]
  90. def _hook(ctx: AnalyzeTypeContext) -> mypy.types.Type:
  91. """Replace a type-alias with a concrete ``NBitBase`` subclass."""
  92. typ, _, api = ctx
  93. name = typ.name.split(".")[-1]
  94. name_new = _PRECISION_DICT[f"{_MODULE}._nbit.{name}"]
  95. return cast("TypeAnalyser", api).named_type(name_new)
  96. def _index(iterable: Iterable[Statement], id: str) -> int:
  97. """Identify the first ``ImportFrom`` instance the specified `id`."""
  98. for i, value in enumerate(iterable):
  99. if getattr(value, "id", None) == id:
  100. return i
  101. raise ValueError("Failed to identify a `ImportFrom` instance "
  102. f"with the following id: {id!r}")
  103. def _override_imports(
  104. file: MypyFile,
  105. module: str,
  106. imports: list[tuple[str, str | None]],
  107. ) -> None:
  108. """Override the first `module`-based import with new `imports`."""
  109. # Construct a new `from module import y` statement
  110. import_obj = ImportFrom(module, 0, names=imports)
  111. import_obj.is_top_level = True
  112. # Replace the first `module`-based import statement with `import_obj`
  113. for lst in [file.defs, cast("list[Statement]", file.imports)]:
  114. i = _index(lst, module)
  115. lst[i] = import_obj
  116. class _NumpyPlugin(Plugin):
  117. """A mypy plugin for handling versus numpy-specific typing tasks."""
  118. def get_type_analyze_hook(self, fullname: str) -> _HookFunc | None:
  119. """Set the precision of platform-specific `numpy.number`
  120. subclasses.
  121. For example: `numpy.int_`, `numpy.longlong` and `numpy.longdouble`.
  122. """
  123. if fullname in _PRECISION_DICT:
  124. return _hook
  125. return None
  126. def get_additional_deps(
  127. self, file: MypyFile
  128. ) -> list[tuple[int, str, int]]:
  129. """Handle all import-based overrides.
  130. * Import platform-specific extended-precision `numpy.number`
  131. subclasses (*e.g.* `numpy.float96` and `numpy.float128`).
  132. * Import the appropriate `ctypes` equivalent to `numpy.intp`.
  133. """
  134. fullname = file.fullname
  135. if fullname == "numpy":
  136. _override_imports(
  137. file,
  138. f"{_MODULE}._extended_precision",
  139. imports=[(v, v) for v in _EXTENDED_PRECISION_LIST],
  140. )
  141. elif fullname == "numpy.ctypeslib":
  142. _override_imports(
  143. file,
  144. "ctypes",
  145. imports=[(_C_INTP, "_c_intp")],
  146. )
  147. return [(PRI_MED, fullname, -1)]
  148. def plugin(version: str) -> type:
  149. import warnings
  150. plugin = "numpy.typing.mypy_plugin"
  151. # Deprecated 2025-01-10, NumPy 2.3
  152. warn_msg = (
  153. f"`{plugin}` is deprecated, and will be removed in a future "
  154. f"release. Please remove `plugins = {plugin}` in your mypy config."
  155. f"(deprecated in NumPy 2.3)"
  156. )
  157. warnings.warn(warn_msg, DeprecationWarning, stacklevel=3)
  158. return _NumpyPlugin