helpers.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. """Reusable functions and classes for different types of integration tests.
  2. For example ``Archive`` can be used to check the contents of distributions built
  3. with setuptools, and ``run`` will always try to be as verbose as possible to
  4. facilitate debugging.
  5. """
  6. from __future__ import annotations
  7. import os
  8. import subprocess
  9. import tarfile
  10. from collections.abc import Iterator
  11. from pathlib import Path
  12. from zipfile import ZipFile, ZipInfo
  13. def run(cmd, env=None):
  14. r = subprocess.run(
  15. cmd,
  16. capture_output=True,
  17. text=True,
  18. encoding="utf-8",
  19. env={**os.environ, **(env or {})},
  20. # ^-- allow overwriting instead of discarding the current env
  21. )
  22. out = r.stdout + "\n" + r.stderr
  23. # pytest omits stdout/err by default, if the test fails they help debugging
  24. print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
  25. print(f"Command: {cmd}\nreturn code: {r.returncode}\n\n{out}")
  26. if r.returncode == 0:
  27. return out
  28. raise subprocess.CalledProcessError(r.returncode, cmd, r.stdout, r.stderr)
  29. class Archive:
  30. """Compatibility layer for ZipFile/Info and TarFile/Info"""
  31. def __init__(self, filename) -> None:
  32. self._filename = filename
  33. if filename.endswith("tar.gz"):
  34. self._obj: tarfile.TarFile | ZipFile = tarfile.open(filename, "r:gz")
  35. elif filename.endswith("zip"):
  36. self._obj = ZipFile(filename)
  37. else:
  38. raise ValueError(f"{filename} doesn't seem to be a zip or tar.gz")
  39. def __iter__(self) -> Iterator[ZipInfo] | Iterator[tarfile.TarInfo]:
  40. if hasattr(self._obj, "infolist"):
  41. return iter(self._obj.infolist())
  42. return iter(self._obj)
  43. def get_name(self, zip_or_tar_info):
  44. if hasattr(zip_or_tar_info, "filename"):
  45. return zip_or_tar_info.filename
  46. return zip_or_tar_info.name
  47. def get_content(self, zip_or_tar_info):
  48. if hasattr(self._obj, "extractfile"):
  49. content = self._obj.extractfile(zip_or_tar_info)
  50. if content is None:
  51. msg = f"Invalid {zip_or_tar_info.name} in {self._filename}"
  52. raise ValueError(msg)
  53. return str(content.read(), "utf-8")
  54. return str(self._obj.read(zip_or_tar_info), "utf-8")
  55. def get_sdist_members(sdist_path):
  56. with tarfile.open(sdist_path, "r:gz") as tar:
  57. files = [Path(f) for f in tar.getnames()]
  58. # remove root folder
  59. relative_files = ("/".join(f.parts[1:]) for f in files)
  60. return {f for f in relative_files if f}
  61. def get_wheel_members(wheel_path):
  62. with ZipFile(wheel_path) as zipfile:
  63. return set(zipfile.namelist())