_nbit_base.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. """A module with the precisions of generic `~numpy.number` types."""
  2. from .._utils import set_module
  3. from typing import final
  4. @final # Disallow the creation of arbitrary `NBitBase` subclasses
  5. @set_module("numpy.typing")
  6. class NBitBase:
  7. """
  8. A type representing `numpy.number` precision during static type checking.
  9. Used exclusively for the purpose static type checking, `NBitBase`
  10. represents the base of a hierarchical set of subclasses.
  11. Each subsequent subclass is herein used for representing a lower level
  12. of precision, *e.g.* ``64Bit > 32Bit > 16Bit``.
  13. .. versionadded:: 1.20
  14. Examples
  15. --------
  16. Below is a typical usage example: `NBitBase` is herein used for annotating
  17. a function that takes a float and integer of arbitrary precision
  18. as arguments and returns a new float of whichever precision is largest
  19. (*e.g.* ``np.float16 + np.int64 -> np.float64``).
  20. .. code-block:: python
  21. >>> from __future__ import annotations
  22. >>> from typing import TypeVar, TYPE_CHECKING
  23. >>> import numpy as np
  24. >>> import numpy.typing as npt
  25. >>> S = TypeVar("S", bound=npt.NBitBase)
  26. >>> T = TypeVar("T", bound=npt.NBitBase)
  27. >>> def add(a: np.floating[S], b: np.integer[T]) -> np.floating[S | T]:
  28. ... return a + b
  29. >>> a = np.float16()
  30. >>> b = np.int64()
  31. >>> out = add(a, b)
  32. >>> if TYPE_CHECKING:
  33. ... reveal_locals()
  34. ... # note: Revealed local types are:
  35. ... # note: a: numpy.floating[numpy.typing._16Bit*]
  36. ... # note: b: numpy.signedinteger[numpy.typing._64Bit*]
  37. ... # note: out: numpy.floating[numpy.typing._64Bit*]
  38. """
  39. def __init_subclass__(cls) -> None:
  40. allowed_names = {
  41. "NBitBase", "_256Bit", "_128Bit", "_96Bit", "_80Bit",
  42. "_64Bit", "_32Bit", "_16Bit", "_8Bit",
  43. }
  44. if cls.__name__ not in allowed_names:
  45. raise TypeError('cannot inherit from final class "NBitBase"')
  46. super().__init_subclass__()
  47. @final
  48. @set_module("numpy._typing")
  49. # Silence errors about subclassing a `@final`-decorated class
  50. class _256Bit(NBitBase): # type: ignore[misc]
  51. pass
  52. @final
  53. @set_module("numpy._typing")
  54. class _128Bit(_256Bit): # type: ignore[misc]
  55. pass
  56. @final
  57. @set_module("numpy._typing")
  58. class _96Bit(_128Bit): # type: ignore[misc]
  59. pass
  60. @final
  61. @set_module("numpy._typing")
  62. class _80Bit(_96Bit): # type: ignore[misc]
  63. pass
  64. @final
  65. @set_module("numpy._typing")
  66. class _64Bit(_80Bit): # type: ignore[misc]
  67. pass
  68. @final
  69. @set_module("numpy._typing")
  70. class _32Bit(_64Bit): # type: ignore[misc]
  71. pass
  72. @final
  73. @set_module("numpy._typing")
  74. class _16Bit(_32Bit): # type: ignore[misc]
  75. pass
  76. @final
  77. @set_module("numpy._typing")
  78. class _8Bit(_16Bit): # type: ignore[misc]
  79. pass