# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
#     http://www.apache.org/licenses/LICENSE-2.0
#
import glob
import os.path
import re
import warnings
from collections.abc import Sequence
from pprint import pprint
from typing import Union

REQUIREMENT_ROOT = "requirements.txt"
REQUIREMENT_FILES_ALL: list = glob.glob(os.path.join("requirements", "*.txt"))
# NOTE(review): with ``recursive=True`` the ``**`` segment also matches zero directories, so every top-level
#  ``requirements/*.txt`` file is listed a second time here and is processed twice by the functions below.
REQUIREMENT_FILES_ALL += glob.glob(os.path.join("requirements", "**", "*.txt"), recursive=True)
# NOTE(review): without ``recursive=True`` this ``**`` behaves like ``*`` and matches exactly one directory level,
#  so a ``pyproject.toml`` in the current working directory itself is not picked up -- confirm this is intended.
REQUIREMENT_FILES_ALL += glob.glob(os.path.join("**", "pyproject.toml"))
if os.path.isfile(REQUIREMENT_ROOT):
    REQUIREMENT_FILES_ALL += [REQUIREMENT_ROOT]

# Lookahead marking the end of a bare distribution name at the start of a requirement string. Per PEP 508 a name may
# be followed by whitespace, a version operator (``<``, ``>``, ``=``, ``!``, ``~``), extras (``[``), a direct URL
# (``@``) or an environment marker (``;``); requirement files may also carry an inline ``#`` comment or end the line.
_NAME_END_LOOKAHEAD = r"(?=[\s<=>!~;\[@#]|$)"


def _warn_unsupported_file(fname: str) -> None:
    """Emit a ``UserWarning`` saying that ``fname`` is neither a ``*.txt`` file nor a ``pyproject.toml``."""
    warnings.warn(
        "Only *.txt with plain list of requirements or standard pyproject.toml are supported."
        f" Provided '{fname}' is not supported.",
        UserWarning,
        # skip this helper and the public dispatcher so the warning is attributed to the dispatcher's caller
        stacklevel=3,
    )


def _pin_min_version(requirement: str) -> str:
    """Turn ``>=`` into ``==`` in the version part of one requirement, leaving its environment marker untouched.

    Only the text before the first ``;`` is rewritten, so ``numpy>=1.20; python_version >= "3.8"`` becomes
    ``numpy==1.20; python_version >= "3.8"`` instead of restricting the marker to a single Python version.
    """
    spec, sep, marker = requirement.partition(";")
    return spec.replace(">=", "==") + sep + marker


def _rename_requirement(requirement: str, old_package: str, new_package: str) -> str:
    """Replace a leading ``old_package`` with ``new_package`` when it is the complete distribution name.

    ``torchvision`` is left alone when renaming ``torch``; ``torch[extra]``, ``torch~=2.0`` or ``torch; marker``
    are renamed with their extras, specifiers and markers preserved.
    """
    pattern = r"^" + re.escape(old_package) + _NAME_END_LOOKAHEAD
    # a callable replacement inserts ``new_package`` verbatim (no backslash / group-reference interpretation)
    return re.sub(pattern, lambda _: new_package, requirement)


def prune_packages_in_requirements(
    packages: Union[str, Sequence[str]], req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
) -> None:
    """Remove one or more packages from the specified requirement files.

    Each file is edited as plain text: a line is dropped when it starts with any of the given names, i.e. this is a
    prefix match (pruning ``"torch"`` also drops a ``torchvision`` line).
    NOTE(review): this also applies to any ``pyproject.toml`` in ``req_files`` (the default list may contain some),
    which is edited line by line rather than through its TOML structure.

    Args:
        packages: A package name or list of package names to remove.
        req_files: A path or list of paths to requirement files to process.
    """
    if isinstance(packages, str):
        packages = [packages]
    if isinstance(req_files, str):
        req_files = [req_files]
    for req in req_files:
        _prune_packages(req, packages)


def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
    """Remove all occurrences of the given packages (by line prefix) from a requirements file.

    The surviving lines are printed with ``pprint`` before the file is rewritten in place.

    Args:
        req_file: Path to a requirements file.
        packages: Package names to remove. Lines starting with any of these names will be dropped.
    """
    with open(req_file) as fp:
        lines = fp.readlines()

    if isinstance(packages, str):
        packages = [packages]
    # ``str.startswith`` accepts a tuple, so all prefixes are tested in a single pass over the lines
    prefixes = tuple(packages)
    lines = [ln for ln in lines if not ln.startswith(prefixes)]
    pprint(lines)

    with open(req_file, "w") as fp:
        fp.writelines(lines)


def _replace_min_req_in_txt(req_file: str) -> None:
    """Pin minimal versions in a plain text requirements file by replacing '>=' with '=='.

    Environment markers (the text after ``;`` on a line) are kept as they are.

    Args:
        req_file: Path to the requirements.txt-like file to update.
    """
    with open(req_file) as fopen:
        lines = fopen.readlines()
    lines = [_pin_min_version(ln) for ln in lines]
    with open(req_file, "w") as fw:
        fw.writelines(lines)


def _replace_min_req_in_pyproject_toml(proj_file: str = "pyproject.toml") -> None:
    """Replace '>=' with '==' in the [project.dependencies] section of a standard pyproject.toml.

    Environment markers (the text after ``;`` in an entry) are kept as they are. Formatting and comments of the
    file are preserved by round-tripping through tomlkit.

    Args:
        proj_file: Path to the pyproject.toml file.
    """
    import tomlkit

    # Load and parse the existing pyproject.toml
    with open(proj_file, encoding="utf-8") as f:
        content = f.read()
    doc = tomlkit.parse(content)

    # todo: consider also replace extras in [dependency-groups] -> extras = [...]
    deps = doc.get("project", {}).get("dependencies")
    if not deps:
        return

    # Pin '>=version' to '==version' in each dependency string
    for i, req in enumerate(deps):
        deps[i] = _pin_min_version(req)

    # Dump back out, preserving layout
    with open(proj_file, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(doc))


def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None:
    """Convert minimal version specifiers (>=) to pinned ones (==) in the given requirement files.

    Supports plain *.txt requirements and pyproject.toml files. Unsupported file types trigger a warning.

    Args:
        req_files: A path or list of paths to requirement files to process.
    """
    if isinstance(req_files, str):
        req_files = [req_files]
    for fname in req_files:
        if fname.endswith(".txt"):
            _replace_min_req_in_txt(fname)
        elif os.path.basename(fname) == "pyproject.toml":
            _replace_min_req_in_pyproject_toml(fname)
        else:
            _warn_unsupported_file(fname)


def _replace_package_name_in_txt(req_file: str, old_package: str, new_package: str) -> None:
    """Rename a package in a plain text requirements file, preserving version specifiers and markers.

    Args:
        req_file: Path to the requirements.txt-like file to update.
        old_package: The original package name to replace.
        new_package: The new package name to use.
    """
    # load file
    with open(req_file) as fopen:
        requirements = fopen.readlines()
    # replace the name only where it starts a line as a complete distribution name
    requirements = [_rename_requirement(req, old_package, new_package) for req in requirements]
    # save file
    with open(req_file, "w") as fw:
        fw.writelines(requirements)


def _replace_package_name_in_pyproject_toml(proj_file: str, old_package: str, new_package: str) -> None:
    """Rename a package in the [project.dependencies] section of a standard pyproject.toml, preserving constraints.

    Args:
        proj_file: Path to the pyproject.toml file.
        old_package: The original package name to replace.
        new_package: The new package name to use.
    """
    import tomlkit

    # Load and parse the existing pyproject.toml
    with open(proj_file, encoding="utf-8") as f:
        content = f.read()
    doc = tomlkit.parse(content)

    # todo: consider also replace extras in [dependency-groups] -> extras = [...]
    deps = doc.get("project", {}).get("dependencies")
    if not deps:
        return

    # Rename the package in each dependency string, keeping its extras, specifiers and markers
    for i, req in enumerate(deps):
        deps[i] = _rename_requirement(req, old_package, new_package)

    # Dump back out, preserving layout
    with open(proj_file, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(doc))


def replace_package_in_requirements(
    old_package: str, new_package: str, req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
) -> None:
    """Rename a package across multiple requirement files while keeping version constraints intact.

    Supports plain *.txt requirements and pyproject.toml files. Unsupported file types trigger a warning.

    Args:
        old_package: The original package name to replace.
        new_package: The new package name to use.
        req_files: A path or list of paths to requirement files to process.
    """
    if isinstance(req_files, str):
        req_files = [req_files]
    for fname in req_files:
        if fname.endswith(".txt"):
            _replace_package_name_in_txt(fname, old_package, new_package)
        elif os.path.basename(fname) == "pyproject.toml":
            _replace_package_name_in_pyproject_toml(fname, old_package, new_package)
        else:
            _warn_unsupported_file(fname)