spawn.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. # mypy: allow-untyped-defs
  2. import logging
  3. import multiprocessing
  4. import multiprocessing.connection
  5. import os
  6. import pickle
  7. import signal
  8. import sys
  9. import tempfile
  10. import time
  11. import warnings
  12. from concurrent.futures import as_completed, ThreadPoolExecutor
  13. from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
  14. ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"
  15. log = logging.getLogger(__name__)
  16. __all__ = [
  17. "ProcessContext",
  18. "ProcessException",
  19. "ProcessExitedException",
  20. "ProcessRaisedException",
  21. "spawn",
  22. "SpawnContext",
  23. "start_processes",
  24. ]
  25. class ProcessException(Exception):
  26. __slots__ = ["error_index", "error_pid"]
  27. def __init__(self, msg: str, error_index: int, error_pid: int):
  28. super().__init__(msg)
  29. self.msg = msg
  30. self.error_index = error_index
  31. self.error_pid = error_pid
  32. def __reduce__(self):
  33. return type(self), (self.msg, self.error_index, self.error_pid)
  34. class ProcessRaisedException(ProcessException):
  35. """Exception raised when a process failed due to an exception raised by the code."""
  36. class ProcessExitedException(ProcessException):
  37. """Exception raised when a process failed due to signal or exited with a specific code."""
  38. __slots__ = ["exit_code"]
  39. def __init__(
  40. self,
  41. msg: str,
  42. error_index: int,
  43. error_pid: int,
  44. exit_code: int,
  45. signal_name: str | None = None,
  46. ):
  47. super().__init__(msg, error_index, error_pid)
  48. self.exit_code = exit_code
  49. self.signal_name = signal_name
  50. def __reduce__(self):
  51. return (
  52. type(self),
  53. (
  54. self.msg,
  55. self.error_index,
  56. self.error_pid,
  57. self.exit_code,
  58. self.signal_name,
  59. ),
  60. )
  61. def _wrap(fn, i, args, error_file):
  62. # prctl(2) is a Linux specific system call.
  63. # On other systems the following function call has no effect.
  64. # This is set to ensure that non-daemonic child processes can
  65. # terminate if their parent terminates before they do.
  66. _prctl_pr_set_pdeathsig(signal.SIGINT)
  67. try:
  68. fn(i, *args)
  69. except KeyboardInterrupt:
  70. pass # SIGINT; Killed by parent, do nothing
  71. except Exception:
  72. # Propagate exception to parent process, keeping original traceback
  73. import traceback
  74. with open(error_file, "wb") as fh:
  75. pickle.dump(traceback.format_exc(), fh)
  76. sys.exit(1)
  77. class ProcessContext:
  78. def __init__(self, processes, error_files):
  79. self.error_files = error_files
  80. self.processes = processes
  81. self.sentinels = {
  82. process.sentinel: index for index, process in enumerate(processes)
  83. }
  84. def pids(self):
  85. return [int(process.pid) for process in self.processes]
  86. def _join_procs_with_timeout(self, timeout: float):
  87. """Attempt to join all processes with a shared timeout."""
  88. end = time.monotonic() + timeout
  89. for process in self.processes:
  90. # pyrefly: ignore [no-matching-overload]
  91. time_to_wait = max(0, end - time.monotonic())
  92. process.join(time_to_wait)
  93. def join(self, timeout: float | None = None, grace_period: float | None = None):
  94. r"""Join one or more processes within spawn context.
  95. Attempt to join one or more processes in this spawn context.
  96. If one of them exited with a non-zero exit status, this function
  97. kills the remaining processes (optionally with a grace period)
  98. and raises an exception with the cause of the first process exiting.
  99. Returns ``True`` if all processes have been joined successfully,
  100. ``False`` if there are more processes that need to be joined.
  101. Args:
  102. timeout (float): Wait this long (in seconds) before giving up on waiting.
  103. grace_period (float): When any processes fail, wait this long (in seconds)
  104. for others to shutdown gracefully before terminating them. If they
  105. still don't exit, wait another grace period before killing them.
  106. """
  107. # Ensure this function can be called even when we're done.
  108. if len(self.sentinels) == 0:
  109. return True
  110. # Wait for any process to fail or all of them to succeed.
  111. ready = multiprocessing.connection.wait(
  112. self.sentinels.keys(),
  113. timeout=timeout,
  114. )
  115. error_index = None
  116. for sentinel in ready:
  117. index = self.sentinels.pop(sentinel)
  118. process = self.processes[index]
  119. process.join()
  120. if process.exitcode != 0:
  121. error_index = index
  122. break
  123. # Return if there was no error.
  124. if error_index is None:
  125. # Return whether or not all processes have been joined.
  126. return len(self.sentinels) == 0
  127. # An error occurred. Clean-up all processes before returning.
  128. # First, allow a grace period for processes to shutdown themselves.
  129. if grace_period is not None:
  130. self._join_procs_with_timeout(grace_period)
  131. # Then, terminate processes that are still alive. Try SIGTERM first.
  132. for process in self.processes:
  133. if process.is_alive():
  134. log.warning("Terminating process %s via signal SIGTERM", process.pid)
  135. process.terminate()
  136. # Try SIGKILL if the process isn't going down after another grace_period.
  137. # The reason is related to python signal handling is limited
  138. # to main thread and if that is in c/c++ land and stuck it won't
  139. # to handle it. We have seen processes getting stuck not handling
  140. # SIGTERM for the above reason.
  141. self._join_procs_with_timeout(30 if grace_period is None else grace_period)
  142. for process in self.processes:
  143. if process.is_alive():
  144. log.warning(
  145. "Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL",
  146. process.pid,
  147. )
  148. process.kill()
  149. process.join()
  150. # The file will only be created if the process crashed.
  151. failed_process = self.processes[error_index]
  152. if not os.access(self.error_files[error_index], os.R_OK):
  153. exitcode = self.processes[error_index].exitcode
  154. if exitcode < 0:
  155. try:
  156. name = signal.Signals(-exitcode).name
  157. except ValueError:
  158. name = f"<Unknown signal {-exitcode}>"
  159. raise ProcessExitedException(
  160. f"process {error_index:d} terminated with signal {name}",
  161. error_index=error_index,
  162. error_pid=failed_process.pid,
  163. exit_code=exitcode,
  164. signal_name=name,
  165. )
  166. else:
  167. raise ProcessExitedException(
  168. f"process {error_index:d} terminated with exit code {exitcode:d}",
  169. error_index=error_index,
  170. error_pid=failed_process.pid,
  171. exit_code=exitcode,
  172. )
  173. with open(self.error_files[error_index], "rb") as fh:
  174. original_trace = pickle.load(fh)
  175. msg = f"\n\n-- Process {error_index:d} terminated with the following error:\n"
  176. msg += original_trace
  177. raise ProcessRaisedException(msg, error_index, failed_process.pid)
  178. class SpawnContext(ProcessContext):
  179. def __init__(self, processes, error_files):
  180. warnings.warn(
  181. "SpawnContext is renamed to ProcessContext since 1.4 release.", stacklevel=2
  182. )
  183. super().__init__(processes, error_files)
# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It is meant to be a
# more general API than mp.spawn. Currently we only document mp.spawn, because 'spawn' is
# the CUDA-compatible start_method. However, in environments like IPython notebooks, 'fork'
# works better than 'spawn'. Every helper function we created for mp.spawn is general
# enough that backends like XLA can reuse them in Colab notebooks as well.
# For now we only add the API itself; we can consider documenting it in the future as
# needed.
  192. def start_processes(
  193. fn,
  194. args=(),
  195. nprocs=1,
  196. join=True,
  197. daemon=False,
  198. start_method="spawn",
  199. ):
  200. # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
  201. # this func will start processes in parallel if start_method is 'forkserver'.
  202. # Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1.
  203. # todo: investigate why spawn does not work with threadpool and raises SIGINT
  204. if (
  205. start_method == "forkserver"
  206. and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
  207. ):
  208. log.info("Starting processes in parallel.")
  209. start_parallel = True
  210. else:
  211. # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
  212. start_parallel = False
  213. mp = multiprocessing.get_context(start_method)
  214. error_files = [None] * nprocs
  215. processes = [None] * nprocs
  216. def start_process(i):
  217. # Each process is assigned a file to write tracebacks to. We
  218. # use the file being non-empty to indicate an exception
  219. # occurred (vs an expected shutdown). Note: this previously
  220. # used a multiprocessing.Queue but that can be prone to
  221. # deadlocks, so we went with a simpler solution for a one-shot
  222. # message between processes.
  223. tf = tempfile.NamedTemporaryFile( # noqa: SIM115
  224. prefix="pytorch-errorfile-", suffix=".pickle", delete=False
  225. )
  226. tf.close()
  227. os.unlink(tf.name)
  228. process = mp.Process( # pyrefly: ignore # missing-attribute
  229. target=_wrap,
  230. args=(fn, i, args, tf.name),
  231. daemon=daemon,
  232. )
  233. process.start()
  234. return i, process, tf.name
  235. if not start_parallel:
  236. for i in range(nprocs):
  237. idx, process, tf_name = start_process(i)
  238. error_files[idx] = tf_name
  239. processes[idx] = process
  240. else:
  241. with ThreadPoolExecutor(max_workers=nprocs) as executor:
  242. futures = [executor.submit(start_process, i) for i in range(nprocs)]
  243. for fut in as_completed(futures):
  244. idx, process, tf_name = fut.result()
  245. # idx and process rank needs to be the same.
  246. error_files[idx] = tf_name
  247. processes[idx] = process
  248. context = ProcessContext(processes, error_files)
  249. if not join:
  250. return context
  251. # Loop on join until it returns True or raises an exception.
  252. while not context.join():
  253. pass
  254. def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"):
  255. r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.
  256. If one of the processes exits with a non-zero exit status, the
  257. remaining processes are killed and an exception is raised with the
  258. cause of termination. In the case an exception was caught in the
  259. child process, it is forwarded and its traceback is included in
  260. the exception raised in the parent process.
  261. Args:
  262. fn (function): Function is called as the entrypoint of the
  263. spawned process. This function must be defined at the top
  264. level of a module so it can be pickled and spawned. This
  265. is a requirement imposed by multiprocessing.
  266. The function is called as ``fn(i, *args)``, where ``i`` is
  267. the process index and ``args`` is the passed through tuple
  268. of arguments.
  269. args (tuple): Arguments passed to ``fn``.
  270. nprocs (int): Number of processes to spawn.
  271. join (bool): Perform a blocking join on all processes.
  272. daemon (bool): The spawned processes' daemon flag. If set to True,
  273. daemonic processes will be created.
  274. start_method (str): (deprecated) this method will always use ``spawn``
  275. as the start method. To use a different start method
  276. use ``start_processes()``.
  277. Returns:
  278. None if ``join`` is ``True``,
  279. :class:`~ProcessContext` if ``join`` is ``False``
  280. """
  281. if start_method != "spawn":
  282. msg = (
  283. f"This method only supports start_method=spawn (got: {start_method}).\n"
  284. "To use a different start_method use:\n\t\t"
  285. " torch.multiprocessing.start_processes(...)"
  286. )
  287. warnings.warn(msg, FutureWarning, stacklevel=2)
  288. return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")