_backend.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import scipy._lib.uarray as ua
  2. from scipy._lib._array_api import xp_capabilities
  3. from . import _basic_backend
  4. from . import _realtransforms_backend
  5. from . import _fftlog_backend
  6. class _ScipyBackend:
  7. """The default backend for fft calculations
  8. Notes
  9. -----
  10. We use the domain ``numpy.scipy`` rather than ``scipy`` because ``uarray``
  11. treats the domain as a hierarchy. This means the user can install a single
  12. backend for ``numpy`` and have it implement ``numpy.scipy.fft`` as well.
  13. """
  14. __ua_domain__ = "numpy.scipy.fft"
  15. @staticmethod
  16. def __ua_function__(method, args, kwargs):
  17. fn = getattr(_basic_backend, method.__name__, None)
  18. if fn is None:
  19. fn = getattr(_realtransforms_backend, method.__name__, None)
  20. if fn is None:
  21. fn = getattr(_fftlog_backend, method.__name__, None)
  22. if fn is None:
  23. return NotImplemented
  24. return fn(*args, **kwargs)
  25. _named_backends = {
  26. 'scipy': _ScipyBackend,
  27. }
  28. def _backend_from_arg(backend):
  29. """Maps strings to known backends and validates the backend"""
  30. if isinstance(backend, str):
  31. try:
  32. backend = _named_backends[backend]
  33. except KeyError as e:
  34. raise ValueError(f'Unknown backend {backend}') from e
  35. if backend.__ua_domain__ != 'numpy.scipy.fft':
  36. raise ValueError('Backend does not implement "numpy.scipy.fft"')
  37. return backend
  38. @xp_capabilities(out_of_scope=True)
  39. def set_global_backend(backend, coerce=False, only=False, try_last=False):
  40. """Sets the global fft backend
  41. This utility method replaces the default backend for permanent use. It
  42. will be tried in the list of backends automatically, unless the
  43. ``only`` flag is set on a backend. This will be the first tried
  44. backend outside the :obj:`set_backend` context manager.
  45. Parameters
  46. ----------
  47. backend : {object, 'scipy'}
  48. The backend to use.
  49. Can either be a ``str`` containing the name of a known backend
  50. {'scipy'} or an object that implements the uarray protocol.
  51. coerce : bool
  52. Whether to coerce input types when trying this backend.
  53. only : bool
  54. If ``True``, no more backends will be tried if this fails.
  55. Implied by ``coerce=True``.
  56. try_last : bool
  57. If ``True``, the global backend is tried after registered backends.
  58. Raises
  59. ------
  60. ValueError: If the backend does not implement ``numpy.scipy.fft``.
  61. Notes
  62. -----
  63. This will overwrite the previously set global backend, which, by default, is
  64. the SciPy implementation.
  65. Examples
  66. --------
  67. We can set the global fft backend:
  68. >>> from scipy.fft import fft, set_global_backend
  69. >>> set_global_backend("scipy") # Sets global backend (default is "scipy").
  70. >>> fft([1]) # Calls the global backend
  71. array([1.+0.j])
  72. """
  73. backend = _backend_from_arg(backend)
  74. ua.set_global_backend(backend, coerce=coerce, only=only, try_last=try_last)
  75. @xp_capabilities(out_of_scope=True)
  76. def register_backend(backend):
  77. """
  78. Register a backend for permanent use.
  79. Registered backends have the lowest priority and will be tried after the
  80. global backend.
  81. Parameters
  82. ----------
  83. backend : {object, 'scipy'}
  84. The backend to use.
  85. Can either be a ``str`` containing the name of a known backend
  86. {'scipy'} or an object that implements the uarray protocol.
  87. Raises
  88. ------
  89. ValueError: If the backend does not implement ``numpy.scipy.fft``.
  90. Examples
  91. --------
  92. We can register a new fft backend:
  93. >>> from scipy.fft import fft, register_backend, set_global_backend
  94. >>> class NoopBackend: # Define an invalid Backend
  95. ... __ua_domain__ = "numpy.scipy.fft"
  96. ... def __ua_function__(self, func, args, kwargs):
  97. ... return NotImplemented
  98. >>> set_global_backend(NoopBackend()) # Set the invalid backend as global
  99. >>> register_backend("scipy") # Register a new backend
  100. # The registered backend is called because
  101. # the global backend returns `NotImplemented`
  102. >>> fft([1])
  103. array([1.+0.j])
  104. >>> set_global_backend("scipy") # Restore global backend to default
  105. """
  106. backend = _backend_from_arg(backend)
  107. ua.register_backend(backend)
  108. @xp_capabilities(out_of_scope=True)
  109. def set_backend(backend, coerce=False, only=False):
  110. """Context manager to set the backend within a fixed scope.
  111. Upon entering the ``with`` statement, the given backend will be added to
  112. the list of available backends with the highest priority. Upon exit, the
  113. backend is reset to the state before entering the scope.
  114. Parameters
  115. ----------
  116. backend : {object, 'scipy'}
  117. The backend to use.
  118. Can either be a ``str`` containing the name of a known backend
  119. {'scipy'} or an object that implements the uarray protocol.
  120. coerce : bool, optional
  121. Whether to allow expensive conversions for the ``x`` parameter. e.g.,
  122. copying a NumPy array to the GPU for a CuPy backend. Implies ``only``.
  123. only : bool, optional
  124. If only is ``True`` and this backend returns ``NotImplemented``, then a
  125. BackendNotImplemented error will be raised immediately. Ignoring any
  126. lower priority backends.
  127. Examples
  128. --------
  129. >>> import scipy.fft as fft
  130. >>> with fft.set_backend('scipy', only=True):
  131. ... fft.fft([1]) # Always calls the scipy implementation
  132. array([1.+0.j])
  133. """
  134. backend = _backend_from_arg(backend)
  135. return ua.set_backend(backend, coerce=coerce, only=only)
  136. @xp_capabilities(out_of_scope=True)
  137. def skip_backend(backend):
  138. """Context manager to skip a backend within a fixed scope.
  139. Within the context of a ``with`` statement, the given backend will not be
  140. called. This covers backends registered both locally and globally. Upon
  141. exit, the backend will again be considered.
  142. Parameters
  143. ----------
  144. backend : {object, 'scipy'}
  145. The backend to skip.
  146. Can either be a ``str`` containing the name of a known backend
  147. {'scipy'} or an object that implements the uarray protocol.
  148. Examples
  149. --------
  150. >>> import scipy.fft as fft
  151. >>> fft.fft([1]) # Calls default SciPy backend
  152. array([1.+0.j])
  153. >>> with fft.skip_backend('scipy'): # We explicitly skip the SciPy backend
  154. ... fft.fft([1]) # leaving no implementation available
  155. Traceback (most recent call last):
  156. ...
  157. BackendNotImplementedError: No selected backends had an implementation ...
  158. """
  159. backend = _backend_from_arg(backend)
  160. return ua.skip_backend(backend)
  161. set_global_backend('scipy', try_last=True)