_composable_state.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. import weakref
  2. from typing import cast
  3. import torch.nn as nn
  4. class _State:
  5. pass
  6. _module_state_mapping: weakref.WeakKeyDictionary[
  7. nn.Module, weakref.ReferenceType[_State]
  8. ] = weakref.WeakKeyDictionary()
  9. def _insert_module_state(module: nn.Module, state: _State) -> None:
  10. global _module_state_mapping
  11. if module in _module_state_mapping:
  12. raise AssertionError(f"Inserting {module} more than once.")
  13. _module_state_mapping[module] = weakref.ref(state)
  14. def _get_module_state(module: nn.Module) -> _State | None:
  15. """
  16. Return the ``_State`` in ``model``.
  17. Given a ``module``, this API finds out if the module is also a ``_State``
  18. instance or if the module is managed by a composable API. If the module
  19. is also a ``_State``, ``module`` will be casted to ``_State` and returned.
  20. If it is managed by a composable API, the corresponding ``_State`` will
  21. be returned.
  22. """
  23. global _module_state_mapping
  24. if isinstance(module, _State):
  25. return cast(_State, module)
  26. else:
  27. # https://github.com/pytorch/pytorch/issues/107054
  28. if module in _module_state_mapping:
  29. state_ref = _module_state_mapping[module]
  30. state = state_ref()
  31. if state is None:
  32. raise AssertionError("State has already been garbage collected")
  33. return state
  34. else:
  35. return None