formatting.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # Licensed under the Apache License, Version 2.0 (the "License");
  2. # http://www.apache.org/licenses/LICENSE-2.0
  3. #
  4. import glob
  5. import importlib
  6. import inspect
  7. import logging
  8. import os
  9. import re
  10. import sys
  11. from collections.abc import Iterable
  12. from typing import Optional, Union
  def _transform_changelog(path_in: str, path_out: str) -> None:
      """Prefix changelog sub-section headers with their release label so they are unique.

      Each line starting with ``## `` sets the current release label: the text between
      the ``##`` marker and the first ``-``, stripped of surrounding whitespace.
      Each following line starting with ``### `` gets that label inserted right after
      its leading marker (``### Added`` -> ``### <label> - Added``). All other lines
      are copied unchanged.

      Args:
          path_in: Input Markdown file path.
          path_out: Output Markdown file path.

      """
      # Read as UTF-8 explicitly so the result does not depend on the platform default encoding.
      with open(path_in, encoding="UTF-8") as fp:
          chlog_lines = fp.readlines()
      # enrich short subsub-titles to be unique
      chlog_ver = ""
      for i, ln in enumerate(chlog_lines):
          if ln.startswith("## "):
              chlog_ver = ln[2:].split("-")[0].strip()
          elif ln.startswith("### "):
              # Only the leading marker is rewritten; a later "###" inside the title stays as is.
              chlog_lines[i] = ln.replace("###", f"### {chlog_ver} -", 1)
      with open(path_out, "w", encoding="UTF-8") as fp:
          fp.writelines(chlog_lines)
  def _linkcode_resolve(
      domain: str,
      info: dict,
      github_user: str,
      github_repo: str,
      main_branch: str = "master",
      stable_branch: str = "release/stable",
  ) -> str:
      """Build the GitHub URL of the source code behind a documented Python object.

      Args:
          domain: Documentation domain of the object; anything other than ``"py"`` gives ``""``.
          info: Mapping with the ``"module"`` and ``"fullname"`` of the object.
          github_user: GitHub user or organization owning the repository.
          github_repo: GitHub repository name.
          main_branch: Branch substituted when the path starts with ``latest``; also
              used as the path prefix for local builds.
          stable_branch: Branch substituted when the path starts with ``stable``.

      Returns:
          A ``https://github.com/<user>/<repo>/blob/...`` URL, or an empty string when
          the domain is not ``"py"`` or the module name is empty.

      """
      if not (domain == "py" and info["module"]):
          return ""

      def find_source() -> tuple[str, int, int]:
          # locate the file and line span, based on code from numpy:
          # https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286
          target = sys.modules[info["module"]]
          for attr_name in info["fullname"].split("."):
              target = getattr(target, attr_name)
          src_path = str(inspect.getsourcefile(target))
          # https://github.com/rtfd/readthedocs.org/issues/5735
          hosted_build = any(marker in src_path for marker in ("readthedocs", "rtfd", "checkouts"))
          if hosted_build:
              # /home/docs/checkouts/readthedocs.org/user_builds/pytorch_lightning/checkouts/
              # devel/pytorch_lightning/utilities/cls_experiment.py#L26-L176
              checkout_root = os.path.abspath(os.path.join("..", "..", ".."))
              src_path = str(os.path.relpath(src_path, start=checkout_root))
          else:
              # local build, pretend the sources live on the main branch
              local_path = os.path.relpath(src_path, start=os.path.abspath(".."))
              src_path = f"{main_branch}/{local_path}"
          src_lines, first_line = inspect.getsourcelines(target)
          last_line = first_line + len(src_lines) - 1
          return src_path, first_line, last_line

      try:
          located_path, line_from, line_to = find_source()
          filename = f"{located_path}#L{line_from:d}-L{line_to:d}"
      except Exception:
          # the object could not be inspected, link to the module file instead
          filename = info["module"].replace(".", "/") + ".py"

      # the leading path component acts as the branch; map docs aliases to real branches
      branch, *remainder = filename.split("/")
      if branch == "latest":
          branch = main_branch
      elif branch == "stable":
          branch = stable_branch
      filename = "/".join([branch, *remainder])
      return f"https://github.com/{github_user}/{github_repo}/blob/{filename}"
  def _load_pypi_versions(package_name: str) -> list[str]:
      """Fetch the released versions of a package from PyPI, sorted in ascending order.

      Only versions made up purely of digits and dots are returned.

      >>> _load_pypi_versions("numpy")  # doctest: +ELLIPSIS
      ['0.9.6', '0.9.8', '1.0', ...]
      >>> _load_pypi_versions("scikit-learn")  # doctest: +ELLIPSIS
      ['0.9', '0.10', '0.11', '0.12', ...]

      """
      import requests
      from packaging.version import Version

      response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=10)
      release_names = response.json()["releases"].keys()
      # keep only the versions which include just numbers and dots
      numeric_version = re.compile(r"^\d+(\.\d+)*$")
      plain_versions = set(filter(numeric_version.match, release_names))
      return sorted(plain_versions, key=Version)
  def _update_link_based_imported_package(link: str, pkg_ver: str, version_digits: Optional[int]) -> str:
      """Resolve a ``{package.version}`` placeholder in a link to a concrete version.

      The newest version listed on PyPI for the top-level package is tried first. If
      that lookup fails for any reason, the version is read from the importable
      package instead, by resolving ``pkg_ver`` as ``<module path>.<attribute>``.

      Args:
          link: The link template containing a ``{...}`` placeholder to replace.
          pkg_ver: A dotted path to resolve the version (e.g., ``"numpy.__version__"``).
          version_digits: Number of version components to keep (e.g., ``2`` -> ``"1.26"``). If ``None``, keep all.

      Returns:
          The link with the ``{...}`` placeholder replaced by a version string.

      """
      pkg_att = pkg_ver.split(".")
      try:
          ver = _load_pypi_versions(pkg_att[0])[-1]
      except Exception:
          # load the package with all additional sub-modules
          module = importlib.import_module(".".join(pkg_att[:-1]))
          # load the attribute named by the LAST path component (e.g. `__version__`),
          # not the package name itself
          ver = getattr(module, pkg_att[-1])
          # drop any additional context after `+`
          ver = ver.split("+")[0]
      # crop the version to the number of digits
      ver = ".".join(ver.split(".")[:version_digits])
      # replace the version
      return link.replace(f"{{{pkg_ver}}}", ver)
  def adjust_linked_external_docs(
      source_link: str,
      target_link: str,
      browse_folder: Union[str, Iterable[str]],
      file_extensions: Iterable[str] = (".rst", ".py"),
      version_digits: int = 2,
  ) -> None:
      r"""Replace an external docs link with another one in all matching files, in place.

      Lines belonging to a call of this very function (from the line containing
      ``adjust_linked_external_docs(`` up to the next line containing ``)``) are left
      untouched, so the call site keeps its original arguments. Files without any
      replaced link are not rewritten.

      Args:
          source_link: the link to be replaced
          target_link: the replacement link; every ``{package.version}`` placeholder in it
              is resolved to a concrete version first
          browse_folder: the folder, or several folders, searched recursively for files
          file_extensions: what kind of files shall be scanned
          version_digits: for semantic versioning, how many digits to be considered

      Examples:
          >>> adjust_linked_external_docs(
          ...     "https://numpy.org/doc/stable/",
          ...     "https://numpy.org/doc/{numpy.__version__}/",
          ...     "docs/source",
          ... )

      """
      list_files = []
      if isinstance(browse_folder, str):
          browse_folder = [browse_folder]
      for folder in browse_folder:
          for ext in file_extensions:
              list_files += glob.glob(os.path.join(folder, "**", f"*{ext}"), recursive=True)
      if not list_files:
          logging.warning(f'No files were listed in folder "{browse_folder}" and pattern "{file_extensions}"')
          return

      # Collect each `{...}` placeholder separately. The brace-excluding character class keeps
      # a link with several placeholders from being captured as one span.
      pkg_ver_all = re.findall(r"\{([^{}]+)\}", target_link)
      for pkg_ver in pkg_ver_all:
          target_link = _update_link_based_imported_package(target_link, pkg_ver, version_digits)

      # marker of a call to this function; such calls must keep their own link arguments
      own_call = f"{adjust_linked_external_docs.__name__}("
      # replace the source link with target link
      for fpath in set(list_files):
          with open(fpath, encoding="UTF-8") as fopen:
              lines = fopen.readlines()
          found, skip = False, False
          for i, ln in enumerate(lines):
              # prevent the replacement its own function calls
              if own_call in ln:
                  skip = True
              if not skip and source_link in ln:
                  # replace the link if any found
                  lines[i] = ln.replace(source_link, target_link)
                  # record the found link for later write file
                  found = True
              # the call ends on the first line holding a closing bracket
              if skip and ")" in ln:
                  skip = False
          if not found:
              continue
          logging.debug(f'links adjusting in {fpath}: "{source_link}" -> "{target_link}"')
          with open(fpath, "w", encoding="UTF-8") as fw:
              fw.writelines(lines)