dependencies.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # Copyright The Lightning AI team.
  2. # Licensed under the Apache License, Version 2.0 (the "License");
  3. # http://www.apache.org/licenses/LICENSE-2.0
  4. #
  5. import glob
  6. import os.path
  7. import re
  8. import warnings
  9. from collections.abc import Sequence
  10. from pprint import pprint
  11. from typing import Union
  12. REQUIREMENT_ROOT = "requirements.txt"
  13. REQUIREMENT_FILES_ALL: list = glob.glob(os.path.join("requirements", "*.txt"))
  14. REQUIREMENT_FILES_ALL += glob.glob(os.path.join("requirements", "**", "*.txt"), recursive=True)
  15. REQUIREMENT_FILES_ALL += glob.glob(os.path.join("**", "pyproject.toml"))
  16. if os.path.isfile(REQUIREMENT_ROOT):
  17. REQUIREMENT_FILES_ALL += [REQUIREMENT_ROOT]
  18. def prune_packages_in_requirements(
  19. packages: Union[str, Sequence[str]], req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
  20. ) -> None:
  21. """Remove one or more packages from the specified requirement files.
  22. Args:
  23. packages: A package name or list of package names to remove.
  24. req_files: A path or list of paths to requirement files to process.
  25. """
  26. if isinstance(packages, str):
  27. packages = [packages]
  28. if isinstance(req_files, str):
  29. req_files = [req_files]
  30. for req in req_files:
  31. _prune_packages(req, packages)
  32. def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
  33. """Remove all occurrences of the given packages (by line prefix) from a requirements file.
  34. Args:
  35. req_file: Path to a requirements file.
  36. packages: Package names to remove. Lines starting with any of these names will be dropped.
  37. """
  38. with open(req_file) as fp:
  39. lines = fp.readlines()
  40. if isinstance(packages, str):
  41. packages = [packages]
  42. for pkg in packages:
  43. lines = [ln for ln in lines if not ln.startswith(pkg)]
  44. pprint(lines)
  45. with open(req_file, "w") as fp:
  46. fp.writelines(lines)
  47. def _replace_min_req_in_txt(req_file: str) -> None:
  48. """Replace all occurrences of '>=' with '==' in a plain text requirements file.
  49. Args:
  50. req_file: Path to the requirements.txt-like file to update.
  51. """
  52. with open(req_file) as fopen:
  53. req = fopen.read().replace(">=", "==")
  54. with open(req_file, "w") as fw:
  55. fw.write(req)
  56. def _replace_min_req_in_pyproject_toml(proj_file: str = "pyproject.toml") -> None:
  57. """Replace all '>=' with '==' in the [project.dependencies] section of a standard pyproject.toml.
  58. Preserves formatting and comments using tomlkit.
  59. Args:
  60. proj_file: Path to the pyproject.toml file.
  61. """
  62. import tomlkit
  63. # Load and parse the existing pyproject.toml
  64. with open(proj_file, encoding="utf-8") as f:
  65. content = f.read()
  66. doc = tomlkit.parse(content)
  67. # todo: consider also replace extras in [dependency-groups] -> extras = [...]
  68. deps = doc.get("project", {}).get("dependencies")
  69. if not deps:
  70. return
  71. # Replace '>=version' with '==version' in each dependency
  72. for i, req in enumerate(deps):
  73. # Simple string value
  74. deps[i] = req.replace(">=", "==")
  75. # Dump back out, preserving layout
  76. with open(proj_file, "w", encoding="utf-8") as f:
  77. f.write(tomlkit.dumps(doc))
  78. def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None:
  79. """Convert minimal version specifiers (>=) to pinned ones (==) in the given requirement files.
  80. Supports plain *.txt requirements and pyproject.toml files. Unsupported file types trigger a warning.
  81. Args:
  82. req_files: A path or list of paths to requirement files to process.
  83. """
  84. if isinstance(req_files, str):
  85. req_files = [req_files]
  86. for fname in req_files:
  87. if fname.endswith(".txt"):
  88. _replace_min_req_in_txt(fname)
  89. elif os.path.basename(fname) == "pyproject.toml":
  90. _replace_min_req_in_pyproject_toml(fname)
  91. else:
  92. warnings.warn(
  93. "Only *.txt with plain list of requirements or standard pyproject.toml are supported."
  94. f"Provided '{fname}' is not supported.",
  95. UserWarning,
  96. stacklevel=2,
  97. )
  98. def _replace_package_name_in_txt(req_file: str, old_package: str, new_package: str) -> None:
  99. """Rename a package in a plain text requirements file, preserving version specifiers and markers.
  100. Args:
  101. req_file: Path to the requirements.txt-like file to update.
  102. old_package: The original package name to replace.
  103. new_package: The new package name to use.
  104. """
  105. # load file
  106. with open(req_file) as fopen:
  107. requirements = fopen.readlines()
  108. # replace all occurrences
  109. for i, req in enumerate(requirements):
  110. requirements[i] = re.sub(r"^" + re.escape(old_package) + r"(?=[ <=>#]|$)", new_package, req)
  111. # save file
  112. with open(req_file, "w") as fw:
  113. fw.writelines(requirements)
  114. def _replace_package_name_in_pyproject_toml(proj_file: str, old_package: str, new_package: str) -> None:
  115. """Rename a package in the [project.dependencies] section of a standard pyproject.toml, preserving constraints.
  116. Args:
  117. proj_file: Path to the pyproject.toml file.
  118. old_package: The original package name to replace.
  119. new_package: The new package name to use.
  120. """
  121. import tomlkit
  122. # Load and parse the existing pyproject.toml
  123. with open(proj_file, encoding="utf-8") as f:
  124. content = f.read()
  125. doc = tomlkit.parse(content)
  126. # todo: consider also replace extras in [dependency-groups] -> extras = [...]
  127. deps = doc.get("project", {}).get("dependencies")
  128. if not deps:
  129. return
  130. # Replace '>=version' with '==version' in each dependency
  131. for i, req in enumerate(deps):
  132. # Simple string value
  133. deps[i] = re.sub(r"^" + re.escape(old_package) + r"(?=[ <=>]|$)", new_package, req)
  134. # Dump back out, preserving layout
  135. with open(proj_file, "w", encoding="utf-8") as f:
  136. f.write(tomlkit.dumps(doc))
  137. def replace_package_in_requirements(
  138. old_package: str, new_package: str, req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
  139. ) -> None:
  140. """Rename a package across multiple requirement files while keeping version constraints intact.
  141. Supports plain *.txt requirements and pyproject.toml files. Unsupported file types trigger a warning.
  142. Args:
  143. old_package: The original package name to replace.
  144. new_package: The new package name to use.
  145. req_files: A path or list of paths to requirement files to process.
  146. """
  147. if isinstance(req_files, str):
  148. req_files = [req_files]
  149. for fname in req_files:
  150. if fname.endswith(".txt"):
  151. _replace_package_name_in_txt(fname, old_package, new_package)
  152. elif os.path.basename(fname) == "pyproject.toml":
  153. _replace_package_name_in_pyproject_toml(fname, old_package, new_package)
  154. else:
  155. warnings.warn(
  156. "Only *.txt with plain list of requirements or standard pyproject.toml are supported."
  157. f"Provided '{fname}' is not supported.",
  158. UserWarning,
  159. stacklevel=2,
  160. )