_exposed_in.py 720 B

123456789101112131415161718192021
  1. from collections.abc import Callable
  2. from typing import TypeVar
  3. F = TypeVar("F")
  4. # Allows one to expose an API in a private submodule publicly as per the definition
  5. # in PyTorch's public api policy.
  6. #
  7. # It is a temporary solution while we figure out if it should be the long-term solution
  8. # or if we should amend PyTorch's public api policy. The concern is that this approach
  9. # may not be very robust because it's not clear what __module__ is used for.
  10. # However, both numpy and jax overwrite the __module__ attribute of their APIs
  11. # without problem, so it seems fine.
  12. def exposed_in(module: str) -> Callable[[F], F]:
  13. def wrapper(fn: F) -> F:
  14. fn.__module__ = module
  15. return fn
  16. return wrapper