graph_settings.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. import warnings
  4. from typing import Any
  5. from typing_extensions import deprecated
  6. import torch
  7. from torch.utils.data.datapipes.iter.sharding import (
  8. _ShardingIterDataPipe,
  9. SHARDING_PRIORITIES,
  10. )
  11. from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps
  12. __all__ = [
  13. "apply_random_seed",
  14. "apply_sharding",
  15. "apply_shuffle_seed",
  16. "apply_shuffle_settings",
  17. "get_all_graph_pipes",
  18. ]
  19. def get_all_graph_pipes(graph: DataPipeGraph) -> list[DataPipe]:
  20. return _get_all_graph_pipes_helper(graph, set())
  21. def _get_all_graph_pipes_helper(
  22. graph: DataPipeGraph, id_cache: set[int]
  23. ) -> list[DataPipe]:
  24. results: list[DataPipe] = []
  25. for dp_id, (datapipe, sub_graph) in graph.items():
  26. if dp_id in id_cache:
  27. continue
  28. id_cache.add(dp_id)
  29. results.append(datapipe)
  30. results.extend(_get_all_graph_pipes_helper(sub_graph, id_cache))
  31. return results
  32. def _is_sharding_datapipe(datapipe: DataPipe) -> bool:
  33. return isinstance(datapipe, _ShardingIterDataPipe) or (
  34. hasattr(datapipe, "apply_sharding")
  35. and inspect.ismethod(datapipe.apply_sharding)
  36. )
  37. def apply_sharding(
  38. datapipe: DataPipe,
  39. num_of_instances: int,
  40. instance_id: int,
  41. sharding_group=SHARDING_PRIORITIES.DEFAULT,
  42. ) -> DataPipe:
  43. r"""
  44. Apply dynamic sharding over the ``sharding_filter`` DataPipe that has a method ``apply_sharding``.
  45. RuntimeError will be raised when multiple ``sharding_filter`` are presented in the same branch.
  46. """
  47. graph = traverse_dps(datapipe)
  48. def _helper(graph, prev_applied=None) -> None:
  49. for dp, sub_graph in graph.values():
  50. applied = None
  51. if _is_sharding_datapipe(dp):
  52. if prev_applied is not None:
  53. raise RuntimeError(
  54. "Sharding twice on a single pipeline is likely unintended and will cause data loss. "
  55. f"Sharding already applied to {prev_applied} while trying to apply to {dp}"
  56. )
  57. # For BC, only provide sharding_group if accepted
  58. sig = inspect.signature(dp.apply_sharding)
  59. if len(sig.parameters) < 3:
  60. dp.apply_sharding(num_of_instances, instance_id)
  61. else:
  62. dp.apply_sharding(
  63. num_of_instances, instance_id, sharding_group=sharding_group
  64. )
  65. applied = dp
  66. if applied is None:
  67. applied = prev_applied
  68. _helper(sub_graph, applied)
  69. _helper(graph)
  70. return datapipe
  71. def _is_shuffle_datapipe(datapipe: DataPipe) -> bool:
  72. return (
  73. hasattr(datapipe, "set_shuffle")
  74. and hasattr(datapipe, "set_seed")
  75. and inspect.ismethod(datapipe.set_shuffle)
  76. and inspect.ismethod(datapipe.set_seed)
  77. )
  78. def apply_shuffle_settings(datapipe: DataPipe, shuffle: bool | None = None) -> DataPipe:
  79. r"""
  80. Traverse the graph of ``DataPipes`` to find and set shuffle attribute.
  81. Apply the method to each `DataPipe` that has APIs of ``set_shuffle``
  82. and ``set_seed``.
  83. Args:
  84. datapipe: DataPipe that needs to set shuffle attribute
  85. shuffle: Shuffle option (default: ``None`` and no-op to the graph)
  86. """
  87. if shuffle is None:
  88. return datapipe
  89. graph = traverse_dps(datapipe)
  90. all_pipes = get_all_graph_pipes(graph)
  91. shufflers = [pipe for pipe in all_pipes if _is_shuffle_datapipe(pipe)]
  92. if not shufflers and shuffle:
  93. warnings.warn(
  94. "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. "
  95. "Be aware that the default buffer size might not be sufficient for your task.",
  96. stacklevel=2,
  97. )
  98. datapipe = datapipe.shuffle()
  99. shufflers = [
  100. datapipe,
  101. ]
  102. for shuffler in shufflers:
  103. shuffler.set_shuffle(shuffle)
  104. return datapipe
  105. @deprecated(
  106. "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases. "
  107. "Please use `apply_random_seed` instead.",
  108. category=FutureWarning,
  109. )
  110. def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:
  111. return apply_random_seed(datapipe, rng)
  112. def _is_random_datapipe(datapipe: DataPipe) -> bool:
  113. return hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed)
  114. def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe:
  115. r"""
  116. Traverse the graph of ``DataPipes`` to find random ``DataPipe`` with an API of ``set_seed``.
  117. Then set the random seed based on the provided RNG to those ``DataPipe``.
  118. Args:
  119. datapipe: DataPipe that needs to set randomness
  120. rng: Random number generator to generate random seeds
  121. """
  122. graph = traverse_dps(datapipe)
  123. all_pipes = get_all_graph_pipes(graph)
  124. # Using a set to track id of DataPipe to prevent setting randomness per DataPipe more than once.
  125. # And, `id` is used in case of unhashable DataPipe
  126. cache = set()
  127. random_datapipes = []
  128. for pipe in all_pipes:
  129. if id(pipe) in cache:
  130. continue
  131. if _is_random_datapipe(pipe):
  132. random_datapipes.append(pipe)
  133. cache.add(id(pipe))
  134. for pipe in random_datapipes:
  135. random_seed = int(
  136. torch.empty((), dtype=torch.int64).random_(generator=rng).item()
  137. )
  138. pipe.set_seed(random_seed)
  139. return datapipe