parallel_apply.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. import threading
  2. from collections.abc import Sequence
  3. from typing import Any, cast
  4. import torch
  5. from torch._utils import ExceptionWrapper
  6. from torch.cuda._utils import _get_device_index
  7. from torch.nn.modules import Module
  8. __all__ = ["get_a_var", "parallel_apply"]
  9. def get_a_var(
  10. obj: torch.Tensor | list[Any] | tuple[Any, ...] | dict[Any, Any],
  11. ) -> torch.Tensor | None:
  12. if isinstance(obj, torch.Tensor):
  13. return obj
  14. if isinstance(obj, (list, tuple)):
  15. for result in map(get_a_var, obj):
  16. if isinstance(result, torch.Tensor):
  17. return result
  18. if isinstance(obj, dict):
  19. for result in map(get_a_var, obj.items()):
  20. if isinstance(result, torch.Tensor):
  21. return result
  22. return None
  23. def parallel_apply(
  24. modules: Sequence[Module],
  25. inputs: Sequence[Any],
  26. kwargs_tup: Sequence[dict[str, Any]] | None = None,
  27. devices: Sequence[int | torch.device | None] | None = None,
  28. ) -> list[Any]:
  29. r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`.
  30. Args:
  31. modules (Module): modules to be parallelized
  32. inputs (tensor): inputs to the modules
  33. devices (list of int or torch.device): CUDA devices
  34. :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
  35. :attr:`devices` (if given) should all have same length. Moreover, each
  36. element of :attr:`inputs` can either be a single object as the only argument
  37. to a module, or a collection of positional arguments.
  38. """
  39. if len(modules) != len(inputs):
  40. raise AssertionError(
  41. f"The number of modules {len(modules)} is not equal to "
  42. f"the number of inputs {len(inputs)}"
  43. )
  44. if kwargs_tup is not None:
  45. if len(modules) != len(kwargs_tup):
  46. raise AssertionError(
  47. f"The number of modules {len(modules)} is not equal to "
  48. f"the number of kwargs_tup {len(kwargs_tup)}"
  49. )
  50. else:
  51. kwargs_tup = (cast(dict[str, Any], {}),) * len(modules)
  52. if devices is not None:
  53. if len(modules) != len(devices):
  54. raise AssertionError(
  55. f"The number of modules {len(modules)} is not equal to "
  56. f"the number of devices {len(devices)}"
  57. )
  58. else:
  59. devices = [None] * len(modules)
  60. devices = [_get_device_index(x, True) for x in devices]
  61. streams = [torch.accelerator.current_stream(x) for x in devices]
  62. if not torch.accelerator.is_available():
  63. raise AssertionError("No available accelerator found.")
  64. device_type = torch.accelerator.current_accelerator().type # type: ignore[union-attr]
  65. lock = threading.Lock()
  66. results = {}
  67. grad_enabled, autocast_enabled = (
  68. torch.is_grad_enabled(),
  69. torch.is_autocast_enabled(),
  70. )
  71. def _worker(
  72. i: int,
  73. module: Module,
  74. input: Any,
  75. kwargs: dict[str, Any],
  76. device: int | torch.device | None = None,
  77. stream: torch.Stream | None = None,
  78. ) -> None:
  79. torch.set_grad_enabled(grad_enabled)
  80. if device is None:
  81. t = get_a_var(input)
  82. if t is None:
  83. with lock:
  84. results[i] = ExceptionWrapper(
  85. where=f"in replica {i}, no device was provided and no tensor input was found; "
  86. "device cannot be resolved"
  87. )
  88. return
  89. device = t.get_device()
  90. if isinstance(device, torch.device):
  91. device = device.index
  92. if stream is None:
  93. stream = torch.accelerator.current_stream(device)
  94. try:
  95. with (
  96. torch.accelerator.device_index(device),
  97. stream,
  98. torch.amp.autocast(device_type, enabled=autocast_enabled),
  99. ):
  100. # this also avoids accidental slicing of `input` if it is a Tensor
  101. if not isinstance(input, (list, tuple)):
  102. input = (input,)
  103. output = module(*input, **kwargs)
  104. with lock:
  105. results[i] = output
  106. except Exception:
  107. with lock:
  108. results[i] = ExceptionWrapper(
  109. where=f"in replica {i} on device {device}"
  110. )
  111. if len(modules) > 1:
  112. threads = [
  113. threading.Thread(
  114. target=_worker, args=(i, module, input, kwargs, device, stream)
  115. )
  116. for i, (module, input, kwargs, device, stream) in enumerate(
  117. zip(modules, inputs, kwargs_tup, devices, streams, strict=True)
  118. )
  119. ]
  120. for thread in threads:
  121. thread.start()
  122. for thread in threads:
  123. thread.join()
  124. else:
  125. _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])
  126. outputs = []
  127. for i in range(len(inputs)):
  128. output = results[i]
  129. if isinstance(output, ExceptionWrapper):
  130. output.reraise()
  131. outputs.append(output)
  132. return outputs