replicate.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. from collections import OrderedDict
  2. from collections.abc import Iterator, Sequence
  3. from typing import cast, TYPE_CHECKING, TypeVar
  4. from typing_extensions import TypeIs
  5. import torch
  6. from torch._utils import _get_device_index
  7. from torch.nn.modules import Module
  8. from torch.nn.parallel import comm
  9. if TYPE_CHECKING:
  10. from torch._C import ScriptMethod
  11. from torch.jit import ScriptModule
  12. from torch.jit._state import EnabledProxy
# Only ``replicate`` is exported as this module's public API.
__all__ = ["replicate"]
  14. def _is_script_module(module: Module) -> TypeIs["ScriptModule"]:
  15. import torch.jit
  16. return isinstance(module, torch.jit.ScriptModule)
  17. def _is_script_method(module: object) -> TypeIs["ScriptMethod"]:
  18. import torch.jit
  19. return isinstance(module, torch._C.ScriptMethod)
  20. def _init_script_module() -> "ScriptModule":
  21. import torch.jit
  22. return torch.jit.ScriptModule()
  23. def _is_jit_enabled() -> "EnabledProxy":
  24. import torch.jit._state
  25. return torch.jit._state._enabled
  26. # Check if we can safely replicate the module.
  27. # there are two types of module:
  28. # 1. python modules
  29. # 2. ScriptModule
  30. #
  31. # currently a module cannot be replicated properly if the descendants of
  32. # any ScriptModule contains python module (type 1 above)
  33. def _replicatable_module(module: Module, memo: set[Module] | None = None) -> bool:
  34. # module.modules() contains module itself as the first element
  35. def descendant_modules(module: Module) -> Iterator[Module]:
  36. gen = module.modules()
  37. next(gen)
  38. return gen
  39. if not _is_jit_enabled():
  40. return True
  41. if memo is None:
  42. memo = set()
  43. # memoize visited modules
  44. memo.add(module)
  45. if _is_script_module(module):
  46. memo.update(descendant_modules(module))
  47. return all(
  48. _is_script_module(descendant) for descendant in descendant_modules(module)
  49. )
  50. for child in module.children():
  51. # since any unreplicatable module will cause the check to return
  52. # False early, visited modules here can be safely ignored.
  53. if child in memo:
  54. continue
  55. if not _replicatable_module(child, memo):
  56. return False
  57. return True
  58. def _broadcast_coalesced_reshape(
  59. tensors: Sequence[torch.Tensor],
  60. devices: Sequence[int | torch.device],
  61. detach: bool = False,
  62. ) -> list[list[torch.Tensor]]:
  63. from torch.nn.parallel._functions import Broadcast
  64. if len(tensors) == 0:
  65. return []
  66. if detach:
  67. complex_mask = [
  68. not isinstance(t, torch.nn.UninitializedParameter) and t.is_complex()
  69. for t in tensors
  70. ]
  71. outputs = comm.broadcast_coalesced(tensors, devices)
  72. for device_outputs in outputs:
  73. for i, is_complex in enumerate(complex_mask):
  74. if is_complex:
  75. device_outputs[i] = torch.view_as_complex(device_outputs[i])
  76. return outputs
  77. else:
  78. tensor_copies = Broadcast.apply(devices, *tensors)
  79. return [
  80. list(tensor_copies[i : i + len(tensors)])
  81. for i in range(0, len(tensor_copies), len(tensors))
  82. ]
# Type of the network passed to ``replicate``; each returned replica is
# annotated with the same (Module-bounded) type.
T = TypeVar("T", bound=Module)
  84. def replicate(
  85. network: T,
  86. devices: Sequence[int | torch.device],
  87. detach: bool = False,
  88. ) -> list[T]:
  89. if not _replicatable_module(network):
  90. raise RuntimeError(
  91. "Cannot replicate network where python modules are children of ScriptModule"
  92. )
  93. if not devices:
  94. return []
  95. devices = [_get_device_index(x, True) for x in devices]
  96. num_replicas = len(devices)
  97. params = list(network.parameters())
  98. param_indices = {param: idx for idx, param in enumerate(params)}
  99. param_copies = _broadcast_coalesced_reshape(params, devices, detach)
  100. buffers = list(network.buffers())
  101. buffers_rg: list[torch.Tensor] = []
  102. buffers_not_rg: list[torch.Tensor] = []
  103. for buf in buffers:
  104. if buf.requires_grad and not detach:
  105. buffers_rg.append(buf)
  106. else:
  107. buffers_not_rg.append(buf)
  108. buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
  109. buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}
  110. buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
  111. buffer_copies_not_rg = _broadcast_coalesced_reshape(
  112. buffers_not_rg, devices, detach=True
  113. )
  114. modules = list(network.modules())
  115. module_copies: list[list[Module]] = [[] for _ in devices]
  116. module_indices: dict[Module, int] = {}
  117. for i, module in enumerate(modules):
  118. module_indices[module] = i
  119. for j in range(num_replicas):
  120. replica = module._replicate_for_data_parallel()
  121. # This is a temporary fix for DDP. DDP needs to access the
  122. # replicated model parameters. It used to do so through
  123. # `mode.parameters()`. The fix added in #33907 for DP stops the
  124. # `parameters()` API from exposing the replicated parameters.
  125. # Hence, we add a `_former_parameters` dict here to support DDP.
  126. replica._former_parameters = OrderedDict()
  127. module_copies[j].append(replica)
  128. for i, module in enumerate(modules):
  129. for key, child in module._modules.items():
  130. if child is None:
  131. for j in range(num_replicas):
  132. replica = module_copies[j][i]
  133. replica._modules[key] = None
  134. else:
  135. module_idx = module_indices[child]
  136. for j in range(num_replicas):
  137. replica = module_copies[j][i]
  138. setattr(replica, key, module_copies[j][module_idx])
  139. for key, param in module._parameters.items():
  140. if param is None:
  141. for j in range(num_replicas):
  142. replica = module_copies[j][i]
  143. replica._parameters[key] = None
  144. else:
  145. param_idx = param_indices[param]
  146. for j in range(num_replicas):
  147. replica = module_copies[j][i]
  148. param_copy = param_copies[j][param_idx]
  149. # parameters in replicas are no longer leaves,
  150. # so setattr them as non-parameter attributes
  151. setattr(replica, key, param_copy)
  152. # expose the parameter for DDP
  153. replica._former_parameters[key] = param_copy # type: ignore[operator, index]
  154. for key, buf in module._buffers.items(): # type: ignore[assignment]
  155. if buf is None:
  156. for j in range(num_replicas):
  157. replica = module_copies[j][i]
  158. replica._buffers[key] = None
  159. else:
  160. if buf.requires_grad and not detach:
  161. buffer_copies = buffer_copies_rg
  162. buffer_idx = buffer_indices_rg[buf]
  163. else:
  164. buffer_copies = buffer_copies_not_rg
  165. buffer_idx = buffer_indices_not_rg[buf]
  166. for j in range(num_replicas):
  167. replica = module_copies[j][i]
  168. setattr(replica, key, buffer_copies[j][buffer_idx])
  169. return [cast(T, module_copies[j][0]) for j in range(num_replicas)]