compression.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. """Helper functions for a standard streaming compression API"""
  2. import sys
  3. from zipfile import ZipFile
  4. import fsspec.utils
  5. from fsspec.spec import AbstractBufferedFile
  6. def noop_file(file, mode, **kwargs):
  7. return file
  8. # TODO: files should also be available as contexts
  9. # should be functions of the form func(infile, mode=, **kwargs) -> file-like
  10. compr = {None: noop_file}
  11. def register_compression(name, callback, extensions, force=False):
  12. """Register an "inferable" file compression type.
  13. Registers transparent file compression type for use with fsspec.open.
  14. Compression can be specified by name in open, or "infer"-ed for any files
  15. ending with the given extensions.
  16. Args:
  17. name: (str) The compression type name. Eg. "gzip".
  18. callback: A callable of form (infile, mode, **kwargs) -> file-like.
  19. Accepts an input file-like object, the target mode and kwargs.
  20. Returns a wrapped file-like object.
  21. extensions: (str, Iterable[str]) A file extension, or list of file
  22. extensions for which to infer this compression scheme. Eg. "gz".
  23. force: (bool) Force re-registration of compression type or extensions.
  24. Raises:
  25. ValueError: If name or extensions already registered, and not force.
  26. """
  27. if isinstance(extensions, str):
  28. extensions = [extensions]
  29. # Validate registration
  30. if name in compr and not force:
  31. raise ValueError(f"Duplicate compression registration: {name}")
  32. for ext in extensions:
  33. if ext in fsspec.utils.compressions and not force:
  34. raise ValueError(f"Duplicate compression file extension: {ext} ({name})")
  35. compr[name] = callback
  36. for ext in extensions:
  37. fsspec.utils.compressions[ext] = name
  38. def unzip(infile, mode="rb", filename=None, **kwargs):
  39. if "r" not in mode:
  40. filename = filename or "file"
  41. z = ZipFile(infile, mode="w", **kwargs)
  42. fo = z.open(filename, mode="w")
  43. fo.close = lambda closer=fo.close: closer() or z.close()
  44. return fo
  45. z = ZipFile(infile)
  46. if filename is None:
  47. filename = z.namelist()[0]
  48. return z.open(filename, mode="r", **kwargs)
  49. register_compression("zip", unzip, "zip")
  50. try:
  51. from bz2 import BZ2File
  52. except ImportError:
  53. pass
  54. else:
  55. register_compression("bz2", BZ2File, "bz2")
  56. try: # pragma: no cover
  57. from isal import igzip
  58. def isal(infile, mode="rb", **kwargs):
  59. return igzip.IGzipFile(fileobj=infile, mode=mode, **kwargs)
  60. register_compression("gzip", isal, "gz")
  61. except ImportError:
  62. from gzip import GzipFile
  63. register_compression(
  64. "gzip", lambda f, **kwargs: GzipFile(fileobj=f, **kwargs), "gz"
  65. )
  66. try:
  67. from lzma import LZMAFile
  68. register_compression("lzma", LZMAFile, "lzma")
  69. register_compression("xz", LZMAFile, "xz")
  70. except ImportError:
  71. pass
  72. try:
  73. import lzmaffi
  74. register_compression("lzma", lzmaffi.LZMAFile, "lzma", force=True)
  75. register_compression("xz", lzmaffi.LZMAFile, "xz", force=True)
  76. except ImportError:
  77. pass
  78. class SnappyFile(AbstractBufferedFile):
  79. def __init__(self, infile, mode, **kwargs):
  80. import snappy
  81. super().__init__(
  82. fs=None, path="snappy", mode=mode.strip("b") + "b", size=999999999, **kwargs
  83. )
  84. self.infile = infile
  85. if "r" in mode:
  86. self.codec = snappy.StreamDecompressor()
  87. else:
  88. self.codec = snappy.StreamCompressor()
  89. def _upload_chunk(self, final=False):
  90. self.buffer.seek(0)
  91. out = self.codec.add_chunk(self.buffer.read())
  92. self.infile.write(out)
  93. return True
  94. def seek(self, loc, whence=0):
  95. raise NotImplementedError("SnappyFile is not seekable")
  96. def seekable(self):
  97. return False
  98. def _fetch_range(self, start, end):
  99. """Get the specified set of bytes from remote"""
  100. data = self.infile.read(end - start)
  101. return self.codec.decompress(data)
  102. try:
  103. import snappy
  104. snappy.compress(b"")
  105. # Snappy may use the .sz file extension, but this is not part of the
  106. # standard implementation.
  107. register_compression("snappy", SnappyFile, [])
  108. except (ImportError, NameError, AttributeError):
  109. pass
  110. try:
  111. import lz4.frame
  112. register_compression("lz4", lz4.frame.open, "lz4")
  113. except ImportError:
  114. pass
  115. try:
  116. if sys.version_info >= (3, 14):
  117. from compression import zstd
  118. else:
  119. from backports import zstd
  120. register_compression("zstd", zstd.ZstdFile, "zst")
  121. except ImportError:
  122. try:
  123. import zstandard as zstd
  124. def zstandard_file(infile, mode="rb"):
  125. if "r" in mode:
  126. cctx = zstd.ZstdDecompressor()
  127. return cctx.stream_reader(infile)
  128. else:
  129. cctx = zstd.ZstdCompressor(level=10)
  130. return cctx.stream_writer(infile)
  131. register_compression("zstd", zstandard_file, "zst")
  132. except ImportError:
  133. pass
  134. pass
  135. def available_compressions():
  136. """Return a list of the implemented compressions."""
  137. return list(compr)