retriever.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Licensed under the Apache License, Version 2.0 (the "License");
  2. # http://www.apache.org/licenses/LICENSE-2.0
  3. #
  4. import glob
  5. import logging
  6. import os
  7. import re
  8. import requests
  9. def _download_file(file_url: str, folder: str) -> str:
  10. """Download a file from a URL into the given folder.
  11. If a file with the same name already exists, it will be overwritten.
  12. Returns the basename of the downloaded file. Network-related exceptions from
  13. ``requests.get`` (e.g., timeouts or connection errors) may propagate to the caller.
  14. """
  15. fname = os.path.basename(file_url)
  16. file_path = os.path.join(folder, fname)
  17. if os.path.isfile(file_path):
  18. logging.warning(f'given file "{file_path}" already exists and will be overwritten with {file_url}')
  19. # see: https://stackoverflow.com/a/34957875
  20. rq = requests.get(file_url, timeout=10)
  21. with open(file_path, "wb") as outfile:
  22. outfile.write(rq.content)
  23. return fname
  24. def _search_all_occurrences(list_files: list[str], pattern: str) -> list[str]:
  25. """Search for all occurrences of a regular-expression pattern across files.
  26. Args:
  27. list_files: The list of file paths to scan.
  28. pattern: A regular-expression pattern to search for in each file.
  29. Returns:
  30. A list with all matches found across the provided files (order preserved per file).
  31. """
  32. collected = []
  33. for file_path in list_files:
  34. with open(file_path, encoding="UTF-8") as fopem:
  35. body = fopem.read()
  36. found = re.findall(pattern, body)
  37. collected += found
  38. return collected
  39. def _replace_remote_with_local(file_path: str, docs_folder: str, pairs_url_path: list[tuple[str, str]]) -> None:
  40. """Replace all matching remote URLs with local file paths in a given file.
  41. Args:
  42. file_path: The file in which replacements should be performed.
  43. docs_folder: The documentation root folder (used to compute relative paths).
  44. pairs_url_path: Pairs of (remote_url, local_relative_path) to replace.
  45. """
  46. # drop the default/global path to the docs
  47. relt_path = os.path.dirname(file_path).replace(docs_folder, "")
  48. # filter the path starting with / as not empty folder names
  49. depth = len([p for p in relt_path.split(os.path.sep) if p])
  50. with open(file_path, encoding="UTF-8") as fopen:
  51. body = fopen.read()
  52. for url, fpath in pairs_url_path:
  53. if depth:
  54. path_up = [".."] * depth
  55. fpath = os.path.join(*path_up, fpath)
  56. body = body.replace(url, fpath)
  57. with open(file_path, "w", encoding="UTF-8") as fw:
  58. fw.write(body)
  59. def fetch_external_assets(
  60. docs_folder: str = "docs/source",
  61. assets_folder: str = "fetched-s3-assets",
  62. file_pattern: str = "*.rst",
  63. retrieve_pattern: str = r"https?://[-a-zA-Z0-9_]+\.s3\.[-a-zA-Z0-9()_\\+.\\/=]+",
  64. ) -> None:
  65. """Find S3 (or HTTP) asset URLs in docs, download them locally, and rewrite references to local paths.
  66. Args:
  67. docs_folder: The documentation root relative to the project.
  68. assets_folder: Subfolder inside ``docs_folder`` used to store downloaded assets (created if missing).
  69. file_pattern: Glob pattern of files to scan.
  70. retrieve_pattern: Regular-expression pattern used to find remote asset URLs.
  71. """
  72. list_files = glob.glob(os.path.join(docs_folder, "**", file_pattern), recursive=True)
  73. if not list_files:
  74. logging.warning(f'no files were listed in folder "{docs_folder}" and pattern "{file_pattern}"')
  75. return
  76. urls = _search_all_occurrences(list_files, pattern=retrieve_pattern)
  77. if not urls:
  78. logging.info(f"no resources/assets were match in {docs_folder} for {retrieve_pattern}")
  79. return
  80. target_folder = os.path.join(docs_folder, assets_folder)
  81. os.makedirs(target_folder, exist_ok=True)
  82. pairs_url_file = []
  83. for i, url in enumerate(set(urls)):
  84. logging.info(f" >> downloading ({i}/{len(urls)}): {url}")
  85. fname = _download_file(url, target_folder)
  86. pairs_url_file.append((url, os.path.join(assets_folder, fname)))
  87. for fpath in list_files:
  88. _replace_remote_with_local(fpath, docs_folder, pairs_url_file)