dynamic_module_utils.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810
  1. # Copyright 2021 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Utilities to dynamically load objects from the Hub."""
  15. import ast
  16. import filecmp
  17. import hashlib
  18. import importlib
  19. import importlib.metadata
  20. import importlib.util
  21. import keyword
  22. import os
  23. import re
  24. import shutil
  25. import signal
  26. import sys
  27. import threading
  28. from pathlib import Path
  29. from types import ModuleType
  30. from typing import Any
  31. from huggingface_hub import is_offline_mode, try_to_load_from_cache
  32. from packaging import version
  33. from .utils import (
  34. HF_MODULES_CACHE,
  35. TRANSFORMERS_DYNAMIC_MODULE_NAME,
  36. cached_file,
  37. extract_commit_hash,
  38. logging,
  39. )
  40. from .utils.import_utils import VersionComparison, split_package_version
# Module-level logger used by the warning/info messages emitted throughout this file.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
  42. def _sanitize_module_name(name: str) -> str:
  43. r"""
  44. Tries to sanitize a module name so that it can be used as a Python module.
  45. The following transformations are applied:
  46. 1. Replace `.` in module names with `_dot_`.
  47. 2. Replace `-` in module names with `_hyphen_`.
  48. 3. If the module name starts with a digit, prepend it with `_`.
  49. 4. Warn if the sanitized name is a Python reserved keyword or not a valid identifier.
  50. If the input name is already a valid identifier, it is returned unchanged.
  51. """
  52. # We not replacing `\W` characters with `_` to avoid collisions. Because `_` is a very common
  53. # separator used in module names, replacing `\W` with `_` would create too many collisions.
  54. # Once a module is imported, it is cached in `sys.modules` and the second import would return
  55. # the first module, which might not be the expected behavior if name collisions happen.
  56. new_name = name.replace(".", "_dot_").replace("-", "_hyphen_")
  57. if new_name and new_name[0].isdigit():
  58. new_name = f"_{new_name}"
  59. if keyword.iskeyword(new_name):
  60. logger.warning(
  61. f"The module name {new_name} (originally {name}) is a reserved keyword in Python. "
  62. "Please rename the original module to avoid import issues."
  63. )
  64. elif not new_name.isidentifier():
  65. logger.warning(
  66. f"The module name {new_name} (originally {name}) is not a valid Python identifier. "
  67. "Please rename the original module to avoid import issues."
  68. )
  69. return new_name
# Process-wide lock held by `get_class_in_module` while it reads/mutates `sys.modules` and (re-)executes a dynamic
# module, so that concurrent threads loading remote code do not interleave on the same module entry.
_HF_REMOTE_CODE_LOCK = threading.Lock()
  71. def init_hf_modules():
  72. """
  73. Creates the cache directory for modules with an init, and adds it to the Python path.
  74. """
  75. # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
  76. if HF_MODULES_CACHE in sys.path:
  77. return
  78. sys.path.append(HF_MODULES_CACHE)
  79. os.makedirs(HF_MODULES_CACHE, exist_ok=True)
  80. init_path = Path(HF_MODULES_CACHE) / "__init__.py"
  81. if not init_path.exists():
  82. init_path.touch()
  83. importlib.invalidate_caches()
  84. def create_dynamic_module(name: str | os.PathLike) -> None:
  85. """
  86. Creates a dynamic module in the cache directory for modules.
  87. Args:
  88. name (`str` or `os.PathLike`):
  89. The name of the dynamic module to create.
  90. """
  91. init_hf_modules()
  92. dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()
  93. # If the parent module does not exist yet, recursively create it.
  94. if not dynamic_module_path.parent.exists():
  95. create_dynamic_module(dynamic_module_path.parent)
  96. os.makedirs(dynamic_module_path, exist_ok=True)
  97. init_path = dynamic_module_path / "__init__.py"
  98. if not init_path.exists():
  99. init_path.touch()
  100. # It is extremely important to invalidate the cache when we change stuff in those modules, or users end up
  101. # with errors about module that do not exist. Same for all other `invalidate_caches` in this file.
  102. importlib.invalidate_caches()
  103. def get_relative_imports(module_file: str | os.PathLike) -> list[str]:
  104. """
  105. Get the list of modules that are relatively imported in a module file.
  106. Args:
  107. module_file (`str` or `os.PathLike`): The module file to inspect.
  108. Returns:
  109. `list[str]`: The list of relative imports in the module.
  110. """
  111. with open(module_file, encoding="utf-8") as f:
  112. content = f.read()
  113. # Imports of the form `import .xxx`
  114. relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
  115. # Imports of the form `from .xxx import yyy`
  116. relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
  117. # Unique-ify
  118. return list(set(relative_imports))
  119. def get_relative_import_files(module_file: str | os.PathLike) -> list[str]:
  120. """
  121. Get the list of all files that are needed for a given module. Note that this function recurses through the relative
  122. imports (if a imports b and b imports c, it will return module files for b and c).
  123. Args:
  124. module_file (`str` or `os.PathLike`): The module file to inspect.
  125. Returns:
  126. `list[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
  127. of module files a given module needs.
  128. """
  129. no_change = False
  130. files_to_check = [module_file]
  131. all_relative_imports = []
  132. # Let's recurse through all relative imports
  133. while not no_change:
  134. new_imports = []
  135. for f in files_to_check:
  136. new_imports.extend(get_relative_imports(f))
  137. module_path = Path(module_file).parent
  138. new_import_files = [f"{str(module_path / m)}.py" for m in new_imports]
  139. files_to_check = [f for f in new_import_files if f not in all_relative_imports]
  140. no_change = len(files_to_check) == 0
  141. all_relative_imports.extend(files_to_check)
  142. return all_relative_imports
  143. def get_imports(filename: str | os.PathLike) -> list[str]:
  144. """
  145. Extracts all the libraries (not relative imports this time) that are imported in a file.
  146. Args:
  147. filename (`str` or `os.PathLike`): The module file to inspect.
  148. Returns:
  149. `list[str]`: The list of all packages required to use the input module.
  150. """
  151. with open(filename, encoding="utf-8") as f:
  152. content = f.read()
  153. imported_modules = set()
  154. import transformers.utils
  155. def recursive_look_for_imports(node):
  156. if isinstance(node, ast.Try):
  157. return # Don't recurse into Try blocks and ignore imports in them
  158. elif isinstance(node, ast.If):
  159. test = node.test
  160. for condition_node in ast.walk(test):
  161. if isinstance(condition_node, ast.Call):
  162. check_function = getattr(condition_node.func, "id", "")
  163. if (
  164. check_function.endswith("available")
  165. and check_function.startswith("is_flash_attn")
  166. or hasattr(transformers.utils.import_utils, check_function)
  167. ):
  168. # Don't recurse into "if flash_attn_available()" or any "if library_available" blocks
  169. # that appears in `transformers.utils.import_utils` and ignore imports in them
  170. return
  171. elif isinstance(node, ast.Import):
  172. # Handle 'import x' statements
  173. for alias in node.names:
  174. top_module = alias.name.split(".")[0]
  175. if top_module:
  176. imported_modules.add(top_module)
  177. elif isinstance(node, ast.ImportFrom):
  178. # Handle 'from x import y' statements, ignoring relative imports
  179. if node.level == 0 and node.module:
  180. top_module = node.module.split(".")[0]
  181. if top_module:
  182. imported_modules.add(top_module)
  183. # Recursively visit all children
  184. for child in ast.iter_child_nodes(node):
  185. recursive_look_for_imports(child)
  186. tree = ast.parse(content)
  187. recursive_look_for_imports(tree)
  188. return sorted(imported_modules)
  189. def check_imports(filename: str | os.PathLike) -> list[str]:
  190. """
  191. Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
  192. library is missing.
  193. Args:
  194. filename (`str` or `os.PathLike`): The module file to check.
  195. Returns:
  196. `list[str]`: The list of relative imports in the file.
  197. """
  198. imports = get_imports(filename)
  199. missing_packages = []
  200. for imp in imports:
  201. try:
  202. importlib.import_module(imp)
  203. except ImportError as exception:
  204. logger.warning(f"Encountered exception while importing {imp}: {exception}")
  205. # Some packages can fail with an ImportError because of a dependency issue.
  206. # This check avoids hiding such errors.
  207. # See https://github.com/huggingface/transformers/issues/33604
  208. if "No module named" in str(exception):
  209. missing_packages.append(imp)
  210. else:
  211. raise
  212. if len(missing_packages) > 0:
  213. raise ImportError(
  214. "This modeling file requires the following packages that were not found in your environment: "
  215. f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
  216. )
  217. return get_relative_imports(filename)
  218. def get_class_in_module(
  219. class_name: str,
  220. module_path: str | os.PathLike,
  221. *,
  222. force_reload: bool = False,
  223. ) -> type:
  224. """
  225. Import a module on the cache directory for modules and extract a class from it.
  226. Args:
  227. class_name (`str`): The name of the class to import.
  228. module_path (`str` or `os.PathLike`): The path to the module to import.
  229. force_reload (`bool`, *optional*, defaults to `False`):
  230. Whether to reload the dynamic module from file if it already exists in `sys.modules`.
  231. Otherwise, the module is only reloaded if the file has changed.
  232. Returns:
  233. `typing.Type`: The class looked for.
  234. """
  235. name = os.path.normpath(module_path)
  236. name = name.removesuffix(".py")
  237. name = name.replace(os.path.sep, ".")
  238. module_file: Path = Path(HF_MODULES_CACHE) / module_path
  239. with _HF_REMOTE_CODE_LOCK:
  240. if force_reload:
  241. sys.modules.pop(name, None)
  242. importlib.invalidate_caches()
  243. cached_module: ModuleType | None = sys.modules.get(name)
  244. module_spec = importlib.util.spec_from_file_location(name, location=module_file)
  245. # Hash the module file and all its relative imports to check if we need to reload it
  246. module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
  247. module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
  248. module: ModuleType
  249. if cached_module is None:
  250. module = importlib.util.module_from_spec(module_spec)
  251. # insert it into sys.modules before any loading begins
  252. sys.modules[name] = module
  253. else:
  254. module = cached_module
  255. # reload in both cases, unless the module is already imported and the hash hits
  256. if getattr(module, "__transformers_module_hash__", "") != module_hash:
  257. module_spec.loader.exec_module(module)
  258. module.__transformers_module_hash__ = module_hash
  259. return getattr(module, class_name)
  260. def get_cached_module_file(
  261. pretrained_model_name_or_path: str | os.PathLike,
  262. module_file: str,
  263. cache_dir: str | os.PathLike | None = None,
  264. force_download: bool = False,
  265. proxies: dict[str, str] | None = None,
  266. token: bool | str | None = None,
  267. revision: str | None = None,
  268. local_files_only: bool = False,
  269. repo_type: str | None = None,
  270. _commit_hash: str | None = None,
  271. **deprecated_kwargs,
  272. ) -> str:
  273. """
  274. Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
  275. Transformers module.
  276. Args:
  277. pretrained_model_name_or_path (`str` or `os.PathLike`):
  278. This can be either:
  279. - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
  280. huggingface.co.
  281. - a path to a *directory* containing a configuration file saved using the
  282. [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
  283. module_file (`str`):
  284. The name of the module file containing the class to look for.
  285. cache_dir (`str` or `os.PathLike`, *optional*):
  286. Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
  287. cache should not be used.
  288. force_download (`bool`, *optional*, defaults to `False`):
  289. Whether or not to force to (re-)download the configuration files and override the cached versions if they
  290. exist.
  291. proxies (`dict[str, str]`, *optional*):
  292. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  293. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
  294. token (`str` or *bool*, *optional*):
  295. The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
  296. when running `hf auth login` (stored in `~/.huggingface`).
  297. revision (`str`, *optional*, defaults to `"main"`):
  298. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  299. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  300. identifier allowed by git.
  301. local_files_only (`bool`, *optional*, defaults to `False`):
  302. If `True`, will only try to load the tokenizer configuration from local files.
  303. repo_type (`str`, *optional*):
  304. Specify the repo type (useful when downloading from a space for instance).
  305. <Tip>
  306. Passing `token=True` is required when you want to use a private model.
  307. </Tip>
  308. Returns:
  309. `str`: The path to the module inside the cache.
  310. """
  311. if is_offline_mode() and not local_files_only:
  312. logger.info("Offline mode: forcing local_files_only=True")
  313. local_files_only = True
  314. # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
  315. pretrained_model_name_or_path = str(pretrained_model_name_or_path)
  316. is_local = os.path.isdir(pretrained_model_name_or_path)
  317. if is_local:
  318. submodule = _sanitize_module_name(os.path.basename(pretrained_model_name_or_path))
  319. else:
  320. submodule = os.path.sep.join(map(_sanitize_module_name, pretrained_model_name_or_path.split("/")))
  321. cached_module = try_to_load_from_cache(
  322. pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
  323. )
  324. new_files = []
  325. try:
  326. # Load from URL or cache if already cached
  327. resolved_module_file = cached_file(
  328. pretrained_model_name_or_path,
  329. module_file,
  330. cache_dir=cache_dir,
  331. force_download=force_download,
  332. proxies=proxies,
  333. local_files_only=local_files_only,
  334. token=token,
  335. revision=revision,
  336. repo_type=repo_type,
  337. _commit_hash=_commit_hash,
  338. )
  339. if not is_local and cached_module != resolved_module_file:
  340. new_files.append(module_file)
  341. except OSError:
  342. logger.info(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
  343. raise
  344. # Check we have all the requirements in our environment
  345. modules_needed = check_imports(resolved_module_file)
  346. # Now we move the module inside our cached dynamic modules.
  347. full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
  348. create_dynamic_module(full_submodule)
  349. submodule_path = Path(HF_MODULES_CACHE) / full_submodule
  350. if submodule == _sanitize_module_name(os.path.basename(pretrained_model_name_or_path)):
  351. # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
  352. # has changed since last copy.
  353. if not (submodule_path / module_file).exists() or not filecmp.cmp(
  354. resolved_module_file, str(submodule_path / module_file)
  355. ):
  356. (submodule_path / module_file).parent.mkdir(parents=True, exist_ok=True)
  357. shutil.copy(resolved_module_file, submodule_path / module_file)
  358. importlib.invalidate_caches()
  359. for module_needed in modules_needed:
  360. module_needed = Path(module_file).parent / f"{module_needed}.py"
  361. module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
  362. if not (submodule_path / module_needed).exists() or not filecmp.cmp(
  363. module_needed_file, str(submodule_path / module_needed)
  364. ):
  365. shutil.copy(module_needed_file, submodule_path / module_needed)
  366. importlib.invalidate_caches()
  367. else:
  368. # Get the commit hash
  369. commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)
  370. # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
  371. # benefit of versioning.
  372. submodule_path = submodule_path / commit_hash
  373. full_submodule = full_submodule + os.path.sep + commit_hash
  374. full_submodule_module_file_path = os.path.join(full_submodule, module_file)
  375. create_dynamic_module(Path(full_submodule_module_file_path).parent)
  376. if not (submodule_path / module_file).exists():
  377. shutil.copy(resolved_module_file, submodule_path / module_file)
  378. importlib.invalidate_caches()
  379. # Make sure we also have every file with relative
  380. for module_needed in modules_needed:
  381. if not ((submodule_path / module_file).parent / f"{module_needed}.py").exists():
  382. get_cached_module_file(
  383. pretrained_model_name_or_path,
  384. f"{Path(module_file).parent / module_needed}.py",
  385. cache_dir=cache_dir,
  386. force_download=force_download,
  387. proxies=proxies,
  388. token=token,
  389. revision=revision,
  390. local_files_only=local_files_only,
  391. _commit_hash=commit_hash,
  392. )
  393. new_files.append(f"{module_needed}.py")
  394. if len(new_files) > 0 and revision is None:
  395. new_files = "\n".join([f"- {f}" for f in new_files])
  396. repo_type_str = "" if repo_type is None else f"{repo_type}s/"
  397. url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
  398. logger.warning(
  399. f"A new version of the following files was downloaded from {url}:\n{new_files}"
  400. "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
  401. "versions of the code file, you can pin a revision."
  402. )
  403. return os.path.join(full_submodule, module_file)
  404. def get_class_from_dynamic_module(
  405. class_reference: str,
  406. pretrained_model_name_or_path: str | os.PathLike,
  407. cache_dir: str | os.PathLike | None = None,
  408. force_download: bool = False,
  409. proxies: dict[str, str] | None = None,
  410. token: bool | str | None = None,
  411. revision: str | None = None,
  412. local_files_only: bool = False,
  413. repo_type: str | None = None,
  414. code_revision: str | None = None,
  415. **kwargs,
  416. ) -> type:
  417. """
  418. Extracts a class from a module file, present in the local folder or repository of a model.
  419. <Tip warning={true}>
  420. Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
  421. therefore only be called on trusted repos.
  422. </Tip>
  423. Args:
  424. class_reference (`str`):
  425. The full name of the class to load, including its module and optionally its repo.
  426. pretrained_model_name_or_path (`str` or `os.PathLike`):
  427. This can be either:
  428. - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
  429. huggingface.co.
  430. - a path to a *directory* containing a configuration file saved using the
  431. [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
  432. This is used when `class_reference` does not specify another repo.
  433. module_file (`str`):
  434. The name of the module file containing the class to look for.
  435. class_name (`str`):
  436. The name of the class to import in the module.
  437. cache_dir (`str` or `os.PathLike`, *optional*):
  438. Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
  439. cache should not be used.
  440. force_download (`bool`, *optional*, defaults to `False`):
  441. Whether or not to force to (re-)download the configuration files and override the cached versions if they
  442. exist.
  443. proxies (`dict[str, str]`, *optional*):
  444. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  445. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
  446. token (`str` or `bool`, *optional*):
  447. The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
  448. when running `hf auth login` (stored in `~/.huggingface`).
  449. revision (`str`, *optional*, defaults to `"main"`):
  450. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  451. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  452. identifier allowed by git.
  453. local_files_only (`bool`, *optional*, defaults to `False`):
  454. If `True`, will only try to load the tokenizer configuration from local files.
  455. repo_type (`str`, *optional*):
  456. Specify the repo type (useful when downloading from a space for instance).
  457. code_revision (`str`, *optional*, defaults to `"main"`):
  458. The specific revision to use for the code on the Hub, if the code leaves in a different repository than the
  459. rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for
  460. storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
  461. <Tip>
  462. Passing `token=True` is required when you want to use a private model.
  463. </Tip>
  464. Returns:
  465. `typing.Type`: The class, dynamically imported from the module.
  466. Examples:
  467. ```python
  468. # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
  469. # module.
  470. cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")
  471. # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
  472. # module.
  473. cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
  474. ```"""
  475. # Catch the name of the repo if it's specified in `class_reference`
  476. if "--" in class_reference:
  477. repo_id, class_reference = class_reference.split("--")
  478. else:
  479. repo_id = pretrained_model_name_or_path
  480. module_file, class_name = class_reference.split(".")
  481. if code_revision is None and pretrained_model_name_or_path == repo_id:
  482. code_revision = revision
  483. # And lastly we get the class inside our newly created module
  484. final_module = get_cached_module_file(
  485. repo_id,
  486. module_file + ".py",
  487. cache_dir=cache_dir,
  488. force_download=force_download,
  489. proxies=proxies,
  490. token=token,
  491. revision=code_revision,
  492. local_files_only=local_files_only,
  493. repo_type=repo_type,
  494. )
  495. return get_class_in_module(class_name, final_module, force_reload=force_download)
  496. def custom_object_save(obj: Any, folder: str | os.PathLike, config: dict | None = None) -> list[str]:
  497. """
  498. Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
  499. adds the proper fields in a config.
  500. Args:
  501. obj (`Any`): The object for which to save the module files.
  502. folder (`str` or `os.PathLike`): The folder where to save.
  503. config (`PreTrainedConfig` or dictionary, `optional`):
  504. A config in which to register the auto_map corresponding to this custom object.
  505. Returns:
  506. `list[str]`: The list of files saved.
  507. """
  508. if obj.__module__ == "__main__":
  509. logger.warning(
  510. f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
  511. "this code in a separate module so we can include it in the saved folder and make it easier to share via "
  512. "the Hub."
  513. )
  514. return
  515. def _set_auto_map_in_config(_config):
  516. module_name = obj.__class__.__module__
  517. last_module = module_name.split(".")[-1]
  518. full_name = f"{last_module}.{obj.__class__.__name__}"
  519. # Special handling for tokenizers
  520. if "Tokenizer" in full_name:
  521. slow_tokenizer_class = None
  522. fast_tokenizer_class = None
  523. if obj.__class__.__name__.endswith("Fast"):
  524. # Fast tokenizer: we have the fast tokenizer class and we may have the slow one has an attribute.
  525. fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
  526. if getattr(obj, "slow_tokenizer_class", None) is not None:
  527. slow_tokenizer = getattr(obj, "slow_tokenizer_class")
  528. slow_tok_module_name = slow_tokenizer.__module__
  529. last_slow_tok_module = slow_tok_module_name.split(".")[-1]
  530. slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
  531. else:
  532. # Slow tokenizer: no way to have the fast class
  533. slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
  534. full_name = (slow_tokenizer_class, fast_tokenizer_class)
  535. if isinstance(_config, dict):
  536. auto_map = _config.get("auto_map", {})
  537. auto_map[obj._auto_class] = full_name
  538. _config["auto_map"] = auto_map
  539. elif getattr(_config, "auto_map", None) is not None:
  540. _config.auto_map[obj._auto_class] = full_name
  541. else:
  542. _config.auto_map = {obj._auto_class: full_name}
  543. # Add object class to the config auto_map
  544. if isinstance(config, (list, tuple)):
  545. for cfg in config:
  546. _set_auto_map_in_config(cfg)
  547. elif config is not None:
  548. _set_auto_map_in_config(config)
  549. result = []
  550. # Copy module file to the output folder.
  551. object_file = sys.modules[obj.__module__].__file__
  552. dest_file = Path(folder) / (Path(object_file).name)
  553. shutil.copy(object_file, dest_file)
  554. result.append(dest_file)
  555. # Gather all relative imports recursively and make sure they are copied as well.
  556. for needed_file in get_relative_import_files(object_file):
  557. dest_file = Path(folder) / (Path(needed_file).name)
  558. shutil.copy(needed_file, dest_file)
  559. result.append(dest_file)
  560. return result
  561. def _raise_timeout_error(signum, frame):
  562. raise ValueError(
  563. "Loading this model requires you to execute custom code contained in the model repository on your local "
  564. "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
  565. )
  566. TIME_OUT_REMOTE_CODE = 15
  567. def resolve_trust_remote_code(
  568. trust_remote_code, model_name, has_local_code, has_remote_code, error_message=None, upstream_repo=None
  569. ):
  570. """
  571. Resolves the `trust_remote_code` argument. If there is remote code to be loaded, the user must opt-in to loading
  572. it.
  573. Args:
  574. trust_remote_code (`bool` or `None`):
  575. User-defined `trust_remote_code` value.
  576. model_name (`str`):
  577. The name of the model repository in huggingface.co.
  578. has_local_code (`bool`):
  579. Whether the model has local code.
  580. has_remote_code (`bool`):
  581. Whether the model has remote code.
  582. error_message (`str`, *optional*):
  583. Custom error message to display if there is remote code to load and the user didn't opt-in. If unset, the error
  584. message will be regarding loading a model with custom code.
  585. Returns:
  586. The resolved `trust_remote_code` value.
  587. """
  588. if error_message is None:
  589. if upstream_repo is not None:
  590. error_message = (
  591. f"The repository {model_name} references custom code contained in {upstream_repo} which "
  592. f"must be executed to correctly load the model. You can inspect the repository "
  593. f"content at https://hf.co/{upstream_repo} .\n"
  594. )
  595. elif os.path.isdir(model_name):
  596. error_message = (
  597. f"The repository {model_name} contains custom code which must be executed "
  598. f"to correctly load the model. You can inspect the repository "
  599. f"content at {os.path.abspath(model_name)} .\n"
  600. )
  601. else:
  602. error_message = (
  603. f"The repository {model_name} contains custom code which must be executed "
  604. f"to correctly load the model. You can inspect the repository "
  605. f"content at https://hf.co/{model_name} .\n"
  606. )
  607. if trust_remote_code is None:
  608. if has_local_code:
  609. trust_remote_code = False
  610. elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
  611. prev_sig_handler = None
  612. try:
  613. prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
  614. signal.alarm(TIME_OUT_REMOTE_CODE)
  615. while trust_remote_code is None:
  616. answer = input(
  617. f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
  618. f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
  619. f"Do you wish to run the custom code? [y/N] "
  620. )
  621. if answer.lower() in ["yes", "y", "1"]:
  622. trust_remote_code = True
  623. elif answer.lower() in ["no", "n", "0", ""]:
  624. trust_remote_code = False
  625. signal.alarm(0)
  626. except Exception:
  627. # OS which does not support signal.SIGALRM
  628. raise ValueError(
  629. f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
  630. f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
  631. )
  632. finally:
  633. if prev_sig_handler is not None:
  634. signal.signal(signal.SIGALRM, prev_sig_handler)
  635. signal.alarm(0)
  636. elif has_remote_code:
  637. # For the CI which puts the timeout at 0
  638. _raise_timeout_error(None, None)
  639. if has_remote_code and not has_local_code and not trust_remote_code:
  640. raise ValueError(
  641. f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
  642. f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
  643. )
  644. return trust_remote_code
  645. def check_python_requirements(path_or_repo_id, requirements_file="requirements.txt", **kwargs):
  646. """
  647. Tries to locate `requirements_file` in a local folder or repo, and confirms that the environment has all the
  648. python dependencies installed.
  649. Args:
  650. path_or_repo_id (`str` or `os.PathLike`):
  651. This can be either:
  652. - a string, the *model id* of a model repo on huggingface.co.
  653. - a path to a *directory* potentially containing the file.
  654. kwargs (`dict[str, Any]`, *optional*):
  655. Additional arguments to pass to `cached_file`.
  656. """
  657. failed = [] # error messages regarding requirements
  658. try:
  659. requirements = cached_file(path_or_repo_id=path_or_repo_id, filename=requirements_file, **kwargs)
  660. with open(requirements, "r") as f:
  661. requirements = f.readlines()
  662. for requirement in requirements:
  663. requirement = requirement.strip()
  664. if not requirement or requirement.startswith("#"): # skip empty lines and comments
  665. continue
  666. try:
  667. # e.g. "torch>2.6.0" -> "torch", ">", "2.6.0"
  668. package_name, delimiter, version_number = split_package_version(requirement)
  669. except ValueError: # e.g. "torch", as opposed to "torch>2.6.0"
  670. package_name = requirement
  671. delimiter, version_number = None, None
  672. try:
  673. local_package_version = importlib.metadata.version(package_name)
  674. except importlib.metadata.PackageNotFoundError:
  675. failed.append(f"{requirement} (installed: None)")
  676. continue
  677. if delimiter is not None and version_number is not None:
  678. is_satisfied = VersionComparison.from_string(delimiter).value(
  679. version.parse(local_package_version), version.parse(version_number)
  680. )
  681. else:
  682. is_satisfied = True
  683. if not is_satisfied:
  684. failed.append(f"{requirement} (installed: {local_package_version})")
  685. except OSError: # no requirements.txt
  686. pass
  687. if failed:
  688. raise ImportError(
  689. f"Missing requirements in your local environment for `{path_or_repo_id}`:\n" + "\n".join(failed)
  690. )