| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206 |
- # Copyright The Lightning AI team.
- # Licensed under the Apache License, Version 2.0 (the "License");
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- import glob
- import os.path
- import re
- import warnings
- from collections.abc import Sequence
- from pprint import pprint
- from typing import Union
REQUIREMENT_ROOT = "requirements.txt"
# Requirement files discovered relative to the current working directory at import time.
# With ``recursive=True`` the ``**`` pattern also matches zero directories, so the recursive glob
# re-finds every top-level ``requirements/*.txt``; duplicates are removed at the end of this block.
REQUIREMENT_FILES_ALL: list = glob.glob(os.path.join("requirements", "*.txt"))
REQUIREMENT_FILES_ALL += glob.glob(os.path.join("requirements", "**", "*.txt"), recursive=True)
# NOTE(review): without ``recursive=True`` this ``**`` behaves like ``*`` and matches only
# ``<subdir>/pyproject.toml`` (the root pyproject.toml is not included) -- confirm this is intended.
REQUIREMENT_FILES_ALL += glob.glob(os.path.join("**", "pyproject.toml"))
if os.path.isfile(REQUIREMENT_ROOT):
    REQUIREMENT_FILES_ALL += [REQUIREMENT_ROOT]
# Drop duplicates while keeping first-seen order, so every file is processed exactly once.
REQUIREMENT_FILES_ALL = list(dict.fromkeys(REQUIREMENT_FILES_ALL))
def prune_packages_in_requirements(
    packages: Union[str, Sequence[str]], req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
) -> None:
    """Drop the given package(s) from each of the listed requirement files, editing them in place.

    Args:
        packages: A single package name, or a sequence of names, to remove.
        req_files: A single requirement-file path, or a sequence of paths, to process.
    """
    # Normalise both arguments so a bare string is treated as a one-element list.
    pkg_names = [packages] if isinstance(packages, str) else packages
    paths = [req_files] if isinstance(req_files, str) else req_files
    for path in paths:
        _prune_packages(path, pkg_names)
- def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
- """Remove all occurrences of the given packages (by line prefix) from a requirements file.
- Args:
- req_file: Path to a requirements file.
- packages: Package names to remove. Lines starting with any of these names will be dropped.
- """
- with open(req_file) as fp:
- lines = fp.readlines()
- if isinstance(packages, str):
- packages = [packages]
- for pkg in packages:
- lines = [ln for ln in lines if not ln.startswith(pkg)]
- pprint(lines)
- with open(req_file, "w") as fp:
- fp.writelines(lines)
- def _replace_min_req_in_txt(req_file: str) -> None:
- """Replace all occurrences of '>=' with '==' in a plain text requirements file.
- Args:
- req_file: Path to the requirements.txt-like file to update.
- """
- with open(req_file) as fopen:
- req = fopen.read().replace(">=", "==")
- with open(req_file, "w") as fw:
- fw.write(req)
def _replace_min_req_in_pyproject_toml(proj_file: str = "pyproject.toml") -> None:
    """Replace '>=' with '==' in the [project.dependencies] section of a standard pyproject.toml.

    Only the requirement part of each dependency string is rewritten. An environment marker after
    the first ``;`` (e.g. ``python_version >= "3.9"``) is kept verbatim, since pinning it would
    restrict the dependency to a single interpreter version. Formatting and comments are preserved
    via tomlkit, and only entries that actually change are re-assigned.

    Args:
        proj_file: Path to the pyproject.toml file.
    """
    import tomlkit

    # Load and parse the existing pyproject.toml
    with open(proj_file, encoding="utf-8") as f:
        content = f.read()
    doc = tomlkit.parse(content)

    # todo: consider also replace extras in [dependency-groups] -> extras = [...]
    deps = doc.get("project", {}).get("dependencies")
    if not deps:
        return

    for i, req in enumerate(deps):
        text = str(req)
        spec, sep, marker = text.partition(";")
        pinned = spec.replace(">=", "==") + sep + marker
        # Re-assigning an unchanged entry could alter its quoting style, so skip those.
        if pinned != text:
            deps[i] = pinned

    # Dump back out, preserving layout
    with open(proj_file, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(doc))
def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None:
    """Convert minimal version specifiers (>=) to pinned ones (==) in the given requirement files.

    Supports plain *.txt requirements and pyproject.toml files. Any other file triggers a
    ``UserWarning`` and is skipped.

    Args:
        req_files: A path or list of paths to requirement files to process.
    """
    if isinstance(req_files, str):
        req_files = [req_files]
    for fname in req_files:
        if fname.endswith(".txt"):
            _replace_min_req_in_txt(fname)
        elif os.path.basename(fname) == "pyproject.toml":
            _replace_min_req_in_pyproject_toml(fname)
        else:
            # The two literals are implicitly concatenated; keep the separating space.
            warnings.warn(
                "Only *.txt with plain list of requirements or standard pyproject.toml are supported. "
                f"Provided '{fname}' is not supported.",
                UserWarning,
                stacklevel=2,
            )
- def _replace_package_name_in_txt(req_file: str, old_package: str, new_package: str) -> None:
- """Rename a package in a plain text requirements file, preserving version specifiers and markers.
- Args:
- req_file: Path to the requirements.txt-like file to update.
- old_package: The original package name to replace.
- new_package: The new package name to use.
- """
- # load file
- with open(req_file) as fopen:
- requirements = fopen.readlines()
- # replace all occurrences
- for i, req in enumerate(requirements):
- requirements[i] = re.sub(r"^" + re.escape(old_package) + r"(?=[ <=>#]|$)", new_package, req)
- # save file
- with open(req_file, "w") as fw:
- fw.writelines(requirements)
def _replace_package_name_in_pyproject_toml(proj_file: str, old_package: str, new_package: str) -> None:
    """Rename a package in the [project.dependencies] section of a standard pyproject.toml, preserving constraints.

    An entry is renamed only when it starts with ``old_package`` followed by whitespace, extras ``[``,
    a version operator, a ``;`` marker, a direct-URL ``@``, or the end of the string, so similarly
    prefixed packages (e.g. ``torchvision`` vs ``torch``) are untouched. Formatting and comments are
    preserved via tomlkit, and only entries that actually change are re-assigned.

    Args:
        proj_file: Path to the pyproject.toml file.
        old_package: The original package name to replace.
        new_package: The new package name to use.
    """
    import tomlkit

    # Load and parse the existing pyproject.toml
    with open(proj_file, encoding="utf-8") as f:
        content = f.read()
    doc = tomlkit.parse(content)

    # todo: consider also replace extras in [dependency-groups] -> extras = [...]
    deps = doc.get("project", {}).get("dependencies")
    if not deps:
        return

    # Rename the leading package name of each matching dependency; ``match`` anchors at the start.
    pattern = re.compile(re.escape(old_package) + r"(?=[\s\[<=>!~;@#]|$)")
    for i, req in enumerate(deps):
        text = str(req)
        if pattern.match(text):
            # Literal splice keeps the version specifier and markers exactly as written.
            deps[i] = new_package + text[len(old_package):]

    # Dump back out, preserving layout
    with open(proj_file, "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(doc))
def replace_package_in_requirements(
    old_package: str, new_package: str, req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
) -> None:
    """Rename a package across multiple requirement files while keeping version constraints intact.

    Supports plain *.txt requirements and pyproject.toml files. Any other file triggers a
    ``UserWarning`` and is skipped.

    Args:
        old_package: The original package name to replace.
        new_package: The new package name to use.
        req_files: A path or list of paths to requirement files to process.
    """
    if isinstance(req_files, str):
        req_files = [req_files]
    for fname in req_files:
        if fname.endswith(".txt"):
            _replace_package_name_in_txt(fname, old_package, new_package)
        elif os.path.basename(fname) == "pyproject.toml":
            _replace_package_name_in_pyproject_toml(fname, old_package, new_package)
        else:
            # The two literals are implicitly concatenated; keep the separating space.
            warnings.warn(
                "Only *.txt with plain list of requirements or standard pyproject.toml are supported. "
                f"Provided '{fname}' is not supported.",
                UserWarning,
                stacklevel=2,
            )
|