utils.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. # mypy: allow-untyped-defs
  2. from itertools import chain
  3. from typing import Any
  4. from torch import nn
  5. from torch.nn.utils.parametrize import is_parametrized, type_before_parametrizations
# Names exported by a star-import of this module.
__all__ = [
    "module_contains_param",
    "swap_module",
    "module_to_fqn",
    "fqn_to_module",
    "get_arg_info_from_tensor_fqn",
    "FakeSparsity",
]
  14. def module_contains_param(module: nn.Module, parametrization: type[nn.Module]) -> bool:
  15. if is_parametrized(module):
  16. # see if any of the module tensors have a parametriztion attached that matches the one passed in
  17. return any(
  18. any(isinstance(param, parametrization) for param in param_list)
  19. for key, param_list in module.parametrizations.items() # type: ignore[union-attr,operator]
  20. )
  21. return False
  22. def swap_module(
  23. mod: nn.Module, mapping: dict[type[nn.Module], type[nn.Module]]
  24. ) -> nn.Module:
  25. r"""Swaps the module using from_dense according to the mapping passed in.
  26. Args:
  27. mod: input module
  28. mapping: a dictionary that maps from nn module to sparse nn module
  29. Return:
  30. The corresponding sparse module of `mod` according to mapping, created using from_dense
  31. """
  32. if type_before_parametrizations(mod) in mapping:
  33. sparse_mod = mapping[type_before_parametrizations(mod)]
  34. # TODO Fix this typing, as Type[Module] has no attribute "from_dense"
  35. new_mod = sparse_mod.from_dense(mod) # type: ignore[attr-defined]
  36. # Preserve module's pre forward hooks. They'll be called on quantized input
  37. for pre_hook_fn in mod._forward_pre_hooks.values():
  38. new_mod.register_forward_pre_hook(pre_hook_fn)
  39. # Preserve module's post forward hooks except _observer_forward_hook
  40. # After convert they'll work with quantized output
  41. for hook_fn in mod._forward_hooks.values():
  42. new_mod.register_forward_hook(hook_fn)
  43. # respect device affinity when swapping modules
  44. # pyrefly: ignore [bad-argument-type]
  45. devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
  46. if len(devices) > 1:
  47. raise AssertionError(
  48. f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
  49. )
  50. device = next(iter(devices)) if len(devices) > 0 else None
  51. if device:
  52. new_mod.to(device)
  53. return new_mod
  54. else:
  55. return mod
  56. def module_to_fqn(model: nn.Module, module: nn.Module, prefix: str = "") -> str | None:
  57. """
  58. Returns the fqn for a module or None if module not a descendent of model.
  59. """
  60. if module is model:
  61. return ""
  62. for name, child in model.named_children():
  63. fqn = module_to_fqn(child, module, ".")
  64. if isinstance(fqn, str):
  65. return prefix + name + fqn
  66. return None
  67. def fqn_to_module(model: nn.Module | None, path: str) -> nn.Module | None:
  68. """
  69. Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path`
  70. doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors.
  71. """
  72. if path != "":
  73. for name in path.split("."):
  74. model = getattr(model, name, None)
  75. return model
  76. def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> dict[str, Any]:
  77. """
  78. Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name
  79. """
  80. # string manip to split tensor_fqn into module_fqn and tensor_name
  81. # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight'
  82. # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight'
  83. tensor_name = tensor_fqn.rsplit(".", maxsplit=1)[-1]
  84. module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]
  85. module = fqn_to_module(model, module_fqn)
  86. return {
  87. "module_fqn": module_fqn,
  88. "module": module,
  89. "tensor_name": tensor_name,
  90. "tensor_fqn": tensor_fqn,
  91. }
  92. # Parametrizations
  93. class FakeSparsity(nn.Module):
  94. r"""Parametrization for the weights. Should be attached to the 'weight' or
  95. any other parameter that requires a mask applied to it.
  96. Note::
  97. Once the mask is passed, the variable should not change the id. The
  98. contents of the mask can change, but the mask reference itself should
  99. not.
  100. """
  101. def __init__(self, mask):
  102. super().__init__()
  103. self.register_buffer("mask", mask)
  104. def forward(self, x):
  105. if self.mask.shape != x.shape:
  106. raise AssertionError(
  107. f"mask shape ({self.mask.shape}) must match x shape ({x.shape})"
  108. )
  109. return self.mask * x
  110. def state_dict(self, *args, **kwargs):
  111. # We don't want to let the parametrizations to save the mask.
  112. # That way we make sure that the linear module doesn't store the masks
  113. # alongside their parametrizations.
  114. return {}