graph.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # mypy: allow-untyped-defs
  2. import io
  3. import pickle
  4. import warnings
  5. from collections.abc import Collection
  6. from torch.utils._import_utils import dill_available
  7. from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
# Public API of this module; everything else here is a private helper.
__all__ = ["traverse", "traverse_dps"]
# Any node the traversal functions accept: an IterDataPipe or a MapDataPipe.
DataPipe = IterDataPipe | MapDataPipe
# Recursive graph type: maps ``id(datapipe)`` to ``(datapipe, subgraph)``, where
# ``subgraph`` has this same shape (hence the string forward reference).
DataPipeGraph = dict[int, tuple[DataPipe, "DataPipeGraph"]]
  11. def _stub_unpickler() -> str:
  12. return "STUB"
  13. # TODO(VitalyFedyunin): Make sure it works without dill module installed
  14. def _list_connected_datapipes(
  15. scan_obj: DataPipe, only_datapipe: bool, cache: set[int]
  16. ) -> list[DataPipe]:
  17. f = io.BytesIO()
  18. p = pickle.Pickler(
  19. f
  20. ) # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is
  21. if dill_available():
  22. from dill import Pickler as dill_Pickler
  23. d = dill_Pickler(f)
  24. else:
  25. d = None
  26. captured_connections = []
  27. def getstate_hook(ori_state):
  28. state = None
  29. if isinstance(ori_state, dict):
  30. state = {}
  31. for k, v in ori_state.items():
  32. if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
  33. state[k] = v
  34. elif isinstance(ori_state, (tuple, list)):
  35. state = [] # type: ignore[assignment]
  36. for v in ori_state:
  37. if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
  38. state.append(v) # type: ignore[attr-defined]
  39. elif isinstance(ori_state, (IterDataPipe, MapDataPipe, Collection)):
  40. state = ori_state # type: ignore[assignment]
  41. return state
  42. def reduce_hook(obj):
  43. if obj == scan_obj or id(obj) in cache:
  44. raise NotImplementedError
  45. else:
  46. captured_connections.append(obj)
  47. # Adding id to remove duplicate DataPipe serialized at the same level
  48. cache.add(id(obj))
  49. return _stub_unpickler, ()
  50. datapipe_classes: tuple[type[DataPipe]] = (IterDataPipe, MapDataPipe) # type: ignore[assignment]
  51. try:
  52. for cls in datapipe_classes:
  53. cls.set_reduce_ex_hook(reduce_hook)
  54. if only_datapipe:
  55. cls.set_getstate_hook(getstate_hook)
  56. try:
  57. p.dump(scan_obj)
  58. except (pickle.PickleError, AttributeError, TypeError):
  59. if dill_available():
  60. # pyrefly: ignore [missing-attribute]
  61. d.dump(scan_obj)
  62. else:
  63. raise
  64. finally:
  65. for cls in datapipe_classes:
  66. cls.set_reduce_ex_hook(None)
  67. if only_datapipe:
  68. cls.set_getstate_hook(None)
  69. if dill_available():
  70. from dill import extend as dill_extend
  71. dill_extend(False) # Undo change to dispatch table
  72. return captured_connections
  73. def traverse_dps(datapipe: DataPipe) -> DataPipeGraph:
  74. r"""
  75. Traverse the DataPipes and their attributes to extract the DataPipe graph.
  76. This only looks into the attribute from each DataPipe that is either a
  77. DataPipe and a Python collection object such as ``list``, ``tuple``,
  78. ``set`` and ``dict``.
  79. Args:
  80. datapipe: the end DataPipe of the graph
  81. Returns:
  82. A graph represented as a nested dictionary, where keys are ids of DataPipe instances
  83. and values are tuples of DataPipe instance and the sub-graph
  84. """
  85. cache: set[int] = set()
  86. return _traverse_helper(datapipe, only_datapipe=True, cache=cache)
  87. def traverse(datapipe: DataPipe, only_datapipe: bool | None = None) -> DataPipeGraph:
  88. r"""
  89. Traverse the DataPipes and their attributes to extract the DataPipe graph.
  90. [Deprecated]
  91. When ``only_dataPipe`` is specified as ``True``, it would only look into the
  92. attribute from each DataPipe that is either a DataPipe and a Python collection object
  93. such as ``list``, ``tuple``, ``set`` and ``dict``.
  94. Note:
  95. This function is deprecated. Please use `traverse_dps` instead.
  96. Args:
  97. datapipe: the end DataPipe of the graph
  98. only_datapipe: If ``False`` (default), all attributes of each DataPipe are traversed.
  99. This argument is deprecating and will be removed after the next release.
  100. Returns:
  101. A graph represented as a nested dictionary, where keys are ids of DataPipe instances
  102. and values are tuples of DataPipe instance and the sub-graph
  103. """
  104. msg = (
  105. "`traverse` function and will be removed after 1.13. "
  106. "Please use `traverse_dps` instead."
  107. )
  108. if not only_datapipe:
  109. msg += " And, the behavior will be changed to the equivalent of `only_datapipe=True`."
  110. warnings.warn(msg, FutureWarning, stacklevel=2)
  111. if only_datapipe is None:
  112. only_datapipe = False
  113. cache: set[int] = set()
  114. return _traverse_helper(datapipe, only_datapipe, cache)
  115. # Add cache here to prevent infinite recursion on DataPipe
  116. def _traverse_helper(
  117. datapipe: DataPipe, only_datapipe: bool, cache: set[int]
  118. ) -> DataPipeGraph:
  119. if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
  120. raise RuntimeError(
  121. f"Expected `IterDataPipe` or `MapDataPipe`, but {type(datapipe)} is found"
  122. )
  123. dp_id = id(datapipe)
  124. if dp_id in cache:
  125. return {}
  126. cache.add(dp_id)
  127. # Using cache.copy() here is to prevent the same DataPipe pollutes the cache on different paths
  128. items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
  129. d: DataPipeGraph = {dp_id: (datapipe, {})}
  130. for item in items:
  131. # Using cache.copy() here is to prevent recursion on a single path rather than global graph
  132. # Single DataPipe can present multiple times in different paths in graph
  133. # pyrefly: ignore [no-matching-overload]
  134. d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  135. return d