gds.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import os
  2. import sys
  3. from collections.abc import Callable
  4. import torch
  5. from torch.types import Storage
  6. __all__: list[str] = [
  7. "gds_register_buffer",
  8. "gds_deregister_buffer",
  9. "GdsFile",
  10. ]
  11. def _dummy_fn(name: str) -> Callable:
  12. def fn(*args, **kwargs): # type: ignore[no-untyped-def]
  13. raise RuntimeError(f"torch._C.{name} is not supported on this platform")
  14. return fn
  15. if not hasattr(torch._C, "_gds_register_buffer"):
  16. if hasattr(torch._C, "_gds_deregister_buffer"):
  17. raise AssertionError(
  18. "_gds_deregister_buffer exists but _gds_register_buffer does not"
  19. )
  20. if hasattr(torch._C, "_gds_register_handle"):
  21. raise AssertionError(
  22. "_gds_register_handle exists but _gds_register_buffer does not"
  23. )
  24. if hasattr(torch._C, "_gds_deregister_handle"):
  25. raise AssertionError(
  26. "_gds_deregister_handle exists but _gds_register_buffer does not"
  27. )
  28. if hasattr(torch._C, "_gds_load_storage"):
  29. raise AssertionError(
  30. "_gds_load_storage exists but _gds_register_buffer does not"
  31. )
  32. if hasattr(torch._C, "_gds_save_storage"):
  33. raise AssertionError(
  34. "_gds_save_storage exists but _gds_register_buffer does not"
  35. )
  36. # Define functions
  37. torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer")
  38. torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer")
  39. torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle")
  40. torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle")
  41. torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage")
  42. torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage")
  43. def gds_register_buffer(s: Storage) -> None:
  44. """Registers a storage on a CUDA device as a cufile buffer.
  45. Example::
  46. >>> # xdoctest: +SKIP("gds filesystem requirements")
  47. >>> src = torch.randn(1024, device="cuda")
  48. >>> s = src.untyped_storage()
  49. >>> gds_register_buffer(s)
  50. Args:
  51. s (Storage): Buffer to register.
  52. """
  53. torch._C._gds_register_buffer(s)
  54. def gds_deregister_buffer(s: Storage) -> None:
  55. """Deregisters a previously registered storage on a CUDA device as a cufile buffer.
  56. Example::
  57. >>> # xdoctest: +SKIP("gds filesystem requirements")
  58. >>> src = torch.randn(1024, device="cuda")
  59. >>> s = src.untyped_storage()
  60. >>> gds_register_buffer(s)
  61. >>> gds_deregister_buffer(s)
  62. Args:
  63. s (Storage): Buffer to register.
  64. """
  65. torch._C._gds_deregister_buffer(s)
  66. class GdsFile:
  67. r"""Wrapper around cuFile.
  68. cuFile is a file-like interface to the GPUDirect Storage (GDS) API.
  69. See the `cufile docs <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api>`_
  70. for more details.
  71. Args:
  72. filename (str): Name of the file to open.
  73. flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
  74. be added automatically.
  75. Example::
  76. >>> # xdoctest: +SKIP("gds filesystem requirements")
  77. >>> src1 = torch.randn(1024, device="cuda")
  78. >>> src2 = torch.randn(2, 1024, device="cuda")
  79. >>> file = torch.cuda.gds.GdsFile(f, os.O_CREAT | os.O_RDWR)
  80. >>> file.save_storage(src1.untyped_storage(), offset=0)
  81. >>> file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
  82. >>> dest1 = torch.empty(1024, device="cuda")
  83. >>> dest2 = torch.empty(2, 1024, device="cuda")
  84. >>> file.load_storage(dest1.untyped_storage(), offset=0)
  85. >>> file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
  86. >>> torch.equal(src1, dest1)
  87. True
  88. >>> torch.equal(src2, dest2)
  89. True
  90. """
  91. def __init__(self, filename: str, flags: int):
  92. if sys.platform == "win32":
  93. raise RuntimeError("GdsFile is not supported on this platform.")
  94. self.filename = filename
  95. self.flags = flags
  96. self.fd = os.open(filename, flags | os.O_DIRECT) # type: ignore[attr-defined]
  97. self.handle: int | None = None
  98. self.register_handle()
  99. def __del__(self) -> None:
  100. if self.handle is not None:
  101. self.deregister_handle()
  102. os.close(self.fd)
  103. def register_handle(self) -> None:
  104. """Registers file descriptor to cuFile Driver.
  105. This is a wrapper around ``cuFileHandleRegister``.
  106. """
  107. if self.handle is not None:
  108. raise AssertionError("Cannot register a handle that is already registered.")
  109. self.handle = torch._C._gds_register_handle(self.fd)
  110. def deregister_handle(self) -> None:
  111. """Deregisters file descriptor from cuFile Driver.
  112. This is a wrapper around ``cuFileHandleDeregister``.
  113. """
  114. if self.handle is None:
  115. raise AssertionError("Cannot deregister a handle that is not registered.")
  116. torch._C._gds_deregister_handle(self.handle)
  117. self.handle = None
  118. def load_storage(self, storage: Storage, offset: int = 0) -> None:
  119. """Loads data from the file into the storage.
  120. This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
  121. will be loaded from the file at ``offset`` into the storage.
  122. Args:
  123. storage (Storage): Storage to load data into.
  124. offset (int, optional): Offset into the file to start loading from. (Default: 0)
  125. """
  126. if self.handle is None:
  127. raise AssertionError("Cannot load data from a file that is not registered.")
  128. torch._C._gds_load_storage(self.handle, storage, offset)
  129. def save_storage(self, storage: Storage, offset: int = 0) -> None:
  130. """Saves data from the storage into the file.
  131. This is a wrapper around ``cuFileWrite``. All bytes of the storage
  132. will be written to the file at ``offset``.
  133. Args:
  134. storage (Storage): Storage to save data from.
  135. offset (int, optional): Offset into the file to start saving to. (Default: 0)
  136. """
  137. if self.handle is None:
  138. raise AssertionError("Cannot save data to a file that is not registered.")
  139. torch._C._gds_save_storage(self.handle, storage, offset)