feature_extraction_utils.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672
  1. # Copyright 2021 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Feature extraction saving/loading class for common feature extractors.
  16. """
  17. import copy
  18. import json
  19. import os
  20. from collections import UserDict
  21. from typing import TYPE_CHECKING, Any, TypeVar, Union
  22. import numpy as np
  23. from huggingface_hub import create_repo, is_offline_mode
  24. from .dynamic_module_utils import custom_object_save
  25. from .utils import (
  26. FEATURE_EXTRACTOR_NAME,
  27. PROCESSOR_NAME,
  28. PushToHubMixin,
  29. TensorType,
  30. _is_tensor_or_array_like,
  31. copy_func,
  32. is_numpy_array,
  33. is_torch_available,
  34. is_torch_device,
  35. is_torch_dtype,
  36. logging,
  37. requires_backends,
  38. safe_load_json_file,
  39. )
  40. from .utils.hub import cached_file
  41. if TYPE_CHECKING:
  42. from .feature_extraction_sequence_utils import SequenceFeatureExtractor
# Module-level logger obtained from the project's logging helper.
logger = logging.get_logger(__name__)

# Single-member Union over a forward reference: `SequenceFeatureExtractor` is only imported under
# `TYPE_CHECKING`, so it is named as a string here.
PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"]

# type hinting: specifying the type of feature extractor class that inherits from FeatureExtractionMixin
SpecificFeatureExtractorType = TypeVar("SpecificFeatureExtractorType", bound="FeatureExtractionMixin")
  47. class BatchFeature(UserDict):
  48. r"""
  49. Holds the output of the [`~SequenceFeatureExtractor.pad`] and feature extractor specific `__call__` methods.
  50. This class is derived from a python dictionary and can be used as a dictionary.
  51. Args:
  52. data (`dict`, *optional*):
  53. Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask',
  54. etc.).
  55. tensor_type (`Union[None, str, TensorType]`, *optional*):
  56. You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at
  57. initialization.
  58. skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
  59. List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
  60. """
  61. def __init__(
  62. self,
  63. data: dict[str, Any] | None = None,
  64. tensor_type: None | str | TensorType = None,
  65. skip_tensor_conversion: list[str] | set[str] | None = None,
  66. ):
  67. super().__init__(data)
  68. self.skip_tensor_conversion = skip_tensor_conversion
  69. self.convert_to_tensors(tensor_type=tensor_type)
  70. def __getitem__(self, item: str) -> Any:
  71. """
  72. If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask',
  73. etc.).
  74. """
  75. if isinstance(item, str):
  76. return self.data[item]
  77. else:
  78. raise KeyError("Indexing with integers is not available when using Python based feature extractors")
  79. def __getattr__(self, item: str):
  80. try:
  81. return self.data[item]
  82. except KeyError:
  83. raise AttributeError
  84. def __getstate__(self):
  85. return {"data": self.data}
  86. def __setstate__(self, state):
  87. if "data" in state:
  88. self.data = state["data"]
  89. def _get_is_as_tensor_fns(self, tensor_type: str | TensorType | None = None):
  90. if tensor_type is None:
  91. return None, None
  92. # Convert to TensorType
  93. if not isinstance(tensor_type, TensorType):
  94. tensor_type = TensorType(tensor_type)
  95. if tensor_type == TensorType.PYTORCH:
  96. if not is_torch_available():
  97. raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
  98. import torch
  99. def as_tensor(value):
  100. if torch.is_tensor(value):
  101. return value
  102. # stack list of tensors if tensor_type is PyTorch (# torch.tensor() does not support list of tensors)
  103. if isinstance(value, (list, tuple)) and len(value) > 0 and torch.is_tensor(value[0]):
  104. return torch.stack(value)
  105. # convert list of numpy arrays to numpy array (stack) if tensor_type is Numpy
  106. if isinstance(value, (list, tuple)) and len(value) > 0:
  107. if isinstance(value[0], np.ndarray):
  108. value = np.array(value)
  109. elif (
  110. isinstance(value[0], (list, tuple))
  111. and len(value[0]) > 0
  112. and isinstance(value[0][0], np.ndarray)
  113. ):
  114. value = np.array(value)
  115. if isinstance(value, np.ndarray):
  116. return torch.from_numpy(value)
  117. else:
  118. return torch.tensor(value)
  119. is_tensor = torch.is_tensor
  120. else:
  121. def as_tensor(value, dtype=None):
  122. if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
  123. value_lens = [len(val) for val in value]
  124. if len(set(value_lens)) > 1 and dtype is None:
  125. # we have a ragged list so handle explicitly
  126. value = as_tensor([np.asarray(val) for val in value], dtype=object)
  127. return np.asarray(value, dtype=dtype)
  128. is_tensor = is_numpy_array
  129. return is_tensor, as_tensor
  130. def convert_to_tensors(
  131. self,
  132. tensor_type: str | TensorType | None = None,
  133. skip_tensor_conversion: list[str] | set[str] | None = None,
  134. ):
  135. """
  136. Convert the inner content to tensors.
  137. Args:
  138. tensor_type (`str` or [`~utils.TensorType`], *optional*):
  139. The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
  140. `None`, no modification is done.
  141. skip_tensor_conversion (`list[str]` or `set[str]`, *optional*):
  142. List or set of keys that should NOT be converted to tensors, even when `tensor_type` is specified.
  143. Note:
  144. Values that don't have an array-like structure (e.g., strings, dicts, lists of strings) are
  145. automatically skipped and won't be converted to tensors. Ragged arrays (lists of arrays with
  146. different lengths) are still attempted, though they may raise errors during conversion.
  147. """
  148. if tensor_type is None:
  149. return self
  150. is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
  151. skip_tensor_conversion = (
  152. skip_tensor_conversion if skip_tensor_conversion is not None else self.skip_tensor_conversion
  153. )
  154. # Do the tensor conversion in batch
  155. for key, value in self.items():
  156. # Skip keys explicitly marked for no conversion
  157. if skip_tensor_conversion and key in skip_tensor_conversion:
  158. continue
  159. # Skip values that are not array-like
  160. if not _is_tensor_or_array_like(value):
  161. continue
  162. try:
  163. if not is_tensor(value):
  164. tensor = as_tensor(value)
  165. self[key] = tensor
  166. except Exception as e:
  167. if key == "overflowing_values":
  168. raise ValueError(
  169. f"Unable to create tensor for '{key}' with overflowing values of different lengths. "
  170. f"Original error: {str(e)}"
  171. ) from e
  172. raise ValueError(
  173. f"Unable to convert output '{key}' (type: {type(value).__name__}) to tensor: {str(e)}\n"
  174. f"You can try:\n"
  175. f" 1. Use padding=True to ensure all outputs have the same shape\n"
  176. f" 2. Set return_tensors=None to return Python objects instead of tensors"
  177. ) from e
  178. return self
  179. def to(self, *args, **kwargs) -> "BatchFeature":
  180. """
  181. Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
  182. different `dtypes` and sending the `BatchFeature` to a different `device`.
  183. Args:
  184. args (`Tuple`):
  185. Will be passed to the `to(...)` function of the tensors.
  186. kwargs (`Dict`, *optional*):
  187. Will be passed to the `to(...)` function of the tensors.
  188. To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`).
  189. Returns:
  190. [`BatchFeature`]: The same instance after modification.
  191. """
  192. requires_backends(self, ["torch"])
  193. import torch
  194. device = kwargs.get("device")
  195. non_blocking = kwargs.get("non_blocking", False)
  196. # Check if the args are a device or a dtype
  197. if device is None and len(args) > 0:
  198. # device should be always the first argument
  199. arg = args[0]
  200. if is_torch_dtype(arg):
  201. # The first argument is a dtype
  202. pass
  203. elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
  204. device = arg
  205. else:
  206. # it's something else
  207. raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
  208. # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
  209. def maybe_to(v):
  210. # check if v is a floating point tensor
  211. if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
  212. # cast and send to device
  213. return v.to(*args, **kwargs)
  214. elif isinstance(v, torch.Tensor) and device is not None:
  215. return v.to(device=device, non_blocking=non_blocking)
  216. # recursively handle lists and tuples
  217. elif isinstance(v, (list, tuple)):
  218. return type(v)(maybe_to(item) for item in v)
  219. else:
  220. return v
  221. self.data = {k: maybe_to(v) for k, v in self.items()}
  222. return self
  223. class FeatureExtractionMixin(PushToHubMixin):
  224. """
  225. This is a feature extraction mixin used to provide saving/loading functionality for sequential and audio feature
  226. extractors.
  227. """
  228. _auto_class = None
  229. def __init__(self, **kwargs):
  230. """Set elements of `kwargs` as attributes."""
  231. # Pop "processor_class", it should not be saved in feature extractor config
  232. kwargs.pop("processor_class", None)
  233. # Additional attributes without default values
  234. for key, value in kwargs.items():
  235. try:
  236. setattr(self, key, value)
  237. except AttributeError as err:
  238. logger.error(f"Can't set {key} with value {value} for {self}")
  239. raise err
  240. @classmethod
  241. def from_pretrained(
  242. cls: type[SpecificFeatureExtractorType],
  243. pretrained_model_name_or_path: str | os.PathLike,
  244. cache_dir: str | os.PathLike | None = None,
  245. force_download: bool = False,
  246. local_files_only: bool = False,
  247. token: str | bool | None = None,
  248. revision: str = "main",
  249. **kwargs,
  250. ) -> SpecificFeatureExtractorType:
  251. r"""
  252. Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a
  253. derived class of [`SequenceFeatureExtractor`].
  254. Args:
  255. pretrained_model_name_or_path (`str` or `os.PathLike`):
  256. This can be either:
  257. - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
  258. huggingface.co.
  259. - a path to a *directory* containing a feature extractor file saved using the
  260. [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
  261. `./my_model_directory/`.
  262. - a path to a saved feature extractor JSON *file*, e.g.,
  263. `./my_model_directory/preprocessor_config.json`.
  264. cache_dir (`str` or `os.PathLike`, *optional*):
  265. Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
  266. standard cache should not be used.
  267. force_download (`bool`, *optional*, defaults to `False`):
  268. Whether or not to force to (re-)download the feature extractor files and override the cached versions
  269. if they exist.
  270. proxies (`dict[str, str]`, *optional*):
  271. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  272. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
  273. token (`str` or `bool`, *optional*):
  274. The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  275. the token generated when running `hf auth login` (stored in `~/.huggingface`).
  276. revision (`str`, *optional*, defaults to `"main"`):
  277. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  278. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  279. identifier allowed by git.
  280. <Tip>
  281. To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
  282. </Tip>
  283. return_unused_kwargs (`bool`, *optional*, defaults to `False`):
  284. If `False`, then this function returns just the final feature extractor object. If `True`, then this
  285. functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
  286. consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
  287. `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
  288. kwargs (`dict[str, Any]`, *optional*):
  289. The values in kwargs of any keys which are feature extractor attributes will be used to override the
  290. loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
  291. controlled by the `return_unused_kwargs` keyword parameter.
  292. Returns:
  293. A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`].
  294. Examples:
  295. ```python
  296. # We can't instantiate directly the base class *FeatureExtractionMixin* nor *SequenceFeatureExtractor* so let's show the examples on a
  297. # derived class: *Wav2Vec2FeatureExtractor*
  298. feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
  299. "facebook/wav2vec2-base-960h"
  300. ) # Download feature_extraction_config from huggingface.co and cache.
  301. feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
  302. "./test/saved_model/"
  303. ) # E.g. feature_extractor (or model) was saved using *save_pretrained('./test/saved_model/')*
  304. feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./test/saved_model/preprocessor_config.json")
  305. feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
  306. "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False
  307. )
  308. assert feature_extractor.return_attention_mask is False
  309. feature_extractor, unused_kwargs = Wav2Vec2FeatureExtractor.from_pretrained(
  310. "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False, return_unused_kwargs=True
  311. )
  312. assert feature_extractor.return_attention_mask is False
  313. assert unused_kwargs == {"foo": False}
  314. ```"""
  315. kwargs["cache_dir"] = cache_dir
  316. kwargs["force_download"] = force_download
  317. kwargs["local_files_only"] = local_files_only
  318. kwargs["revision"] = revision
  319. if token is not None:
  320. kwargs["token"] = token
  321. feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
  322. return cls.from_dict(feature_extractor_dict, **kwargs)
  323. def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
  324. """
  325. Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
  326. [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
  327. Args:
  328. save_directory (`str` or `os.PathLike`):
  329. Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
  330. push_to_hub (`bool`, *optional*, defaults to `False`):
  331. Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
  332. repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
  333. namespace).
  334. kwargs (`dict[str, Any]`, *optional*):
  335. Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
  336. """
  337. if os.path.isfile(save_directory):
  338. raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
  339. os.makedirs(save_directory, exist_ok=True)
  340. if push_to_hub:
  341. commit_message = kwargs.pop("commit_message", None)
  342. repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
  343. repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id
  344. files_timestamps = self._get_files_timestamps(save_directory)
  345. # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
  346. # loaded from the Hub.
  347. if self._auto_class is not None:
  348. custom_object_save(self, save_directory, config=self)
  349. # If we save using the predefined names, we can load using `from_pretrained`
  350. output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
  351. self.to_json_file(output_feature_extractor_file)
  352. logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
  353. if push_to_hub:
  354. self._upload_modified_files(
  355. save_directory,
  356. repo_id,
  357. files_timestamps,
  358. commit_message=commit_message,
  359. token=kwargs.get("token"),
  360. )
  361. return [output_feature_extractor_file]
  362. @classmethod
  363. def get_feature_extractor_dict(
  364. cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
  365. ) -> tuple[dict[str, Any], dict[str, Any]]:
  366. """
  367. From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
  368. feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] using `from_dict`.
  369. Parameters:
  370. pretrained_model_name_or_path (`str` or `os.PathLike`):
  371. The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
  372. Returns:
  373. `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object.
  374. """
  375. cache_dir = kwargs.pop("cache_dir", None)
  376. force_download = kwargs.pop("force_download", False)
  377. proxies = kwargs.pop("proxies", None)
  378. subfolder = kwargs.pop("subfolder", None)
  379. token = kwargs.pop("token", None)
  380. local_files_only = kwargs.pop("local_files_only", False)
  381. revision = kwargs.pop("revision", None)
  382. from_pipeline = kwargs.pop("_from_pipeline", None)
  383. from_auto_class = kwargs.pop("_from_auto", False)
  384. user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
  385. if from_pipeline is not None:
  386. user_agent["using_pipeline"] = from_pipeline
  387. if is_offline_mode() and not local_files_only:
  388. logger.info("Offline mode: forcing local_files_only=True")
  389. local_files_only = True
  390. pretrained_model_name_or_path = str(pretrained_model_name_or_path)
  391. is_local = os.path.isdir(pretrained_model_name_or_path)
  392. if os.path.isdir(pretrained_model_name_or_path):
  393. feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
  394. if os.path.isfile(pretrained_model_name_or_path):
  395. resolved_feature_extractor_file = pretrained_model_name_or_path
  396. resolved_processor_file = None
  397. is_local = True
  398. else:
  399. feature_extractor_file = FEATURE_EXTRACTOR_NAME
  400. try:
  401. # Load from local folder or from cache or download from model Hub and cache
  402. resolved_processor_file = cached_file(
  403. pretrained_model_name_or_path,
  404. filename=PROCESSOR_NAME,
  405. cache_dir=cache_dir,
  406. force_download=force_download,
  407. proxies=proxies,
  408. local_files_only=local_files_only,
  409. token=token,
  410. user_agent=user_agent,
  411. revision=revision,
  412. subfolder=subfolder,
  413. _raise_exceptions_for_missing_entries=False,
  414. )
  415. resolved_feature_extractor_file = cached_file(
  416. pretrained_model_name_or_path,
  417. filename=feature_extractor_file,
  418. cache_dir=cache_dir,
  419. force_download=force_download,
  420. proxies=proxies,
  421. local_files_only=local_files_only,
  422. token=token,
  423. user_agent=user_agent,
  424. revision=revision,
  425. subfolder=subfolder,
  426. _raise_exceptions_for_missing_entries=False,
  427. )
  428. except OSError:
  429. # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
  430. # the original exception.
  431. raise
  432. except Exception:
  433. # For any other exception, we throw a generic error.
  434. raise OSError(
  435. f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
  436. " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
  437. f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
  438. f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
  439. )
  440. # Load feature_extractor dict. Priority goes as (nested config if found -> image processor config)
  441. # We are downloading both configs because almost all models have a `processor_config.json` but
  442. # not all of these are nested. We need to check if it was saved recebtly as nested or if it is legacy style
  443. feature_extractor_dict = None
  444. if resolved_processor_file is not None:
  445. processor_dict = safe_load_json_file(resolved_processor_file)
  446. if "feature_extractor" in processor_dict or "audio_processor" in processor_dict:
  447. feature_extractor_dict = processor_dict.get("feature_extractor", processor_dict.get("audio_processor"))
  448. if resolved_feature_extractor_file is not None and feature_extractor_dict is None:
  449. feature_extractor_dict = safe_load_json_file(resolved_feature_extractor_file)
  450. if feature_extractor_dict is None:
  451. raise OSError(
  452. f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
  453. " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
  454. f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
  455. f" directory containing a {feature_extractor_file} file"
  456. )
  457. if is_local:
  458. logger.info(f"loading configuration file {resolved_feature_extractor_file}")
  459. else:
  460. logger.info(
  461. f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
  462. )
  463. return feature_extractor_dict, kwargs
  464. @classmethod
  465. def from_dict(
  466. cls, feature_extractor_dict: dict[str, Any], **kwargs
  467. ) -> Union["FeatureExtractionMixin", tuple["FeatureExtractionMixin", dict[str, Any]]]:
  468. """
  469. Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of
  470. parameters.
  471. Args:
  472. feature_extractor_dict (`dict[str, Any]`):
  473. Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be
  474. retrieved from a pretrained checkpoint by leveraging the
  475. [`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method.
  476. kwargs (`dict[str, Any]`):
  477. Additional parameters from which to initialize the feature extractor object.
  478. Returns:
  479. [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature extractor object instantiated from those
  480. parameters.
  481. """
  482. return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
  483. # Update feature_extractor with kwargs if needed
  484. to_remove = []
  485. for key, value in kwargs.items():
  486. if key in feature_extractor_dict:
  487. feature_extractor_dict[key] = value
  488. to_remove.append(key)
  489. for key in to_remove:
  490. kwargs.pop(key, None)
  491. feature_extractor = cls(**feature_extractor_dict)
  492. logger.info(f"Feature extractor {feature_extractor}")
  493. if return_unused_kwargs:
  494. return feature_extractor, kwargs
  495. else:
  496. return feature_extractor
  497. def to_dict(self) -> dict[str, Any]:
  498. """
  499. Serializes this instance to a Python dictionary. Returns:
  500. `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
  501. """
  502. output = copy.deepcopy(self.__dict__)
  503. output["feature_extractor_type"] = self.__class__.__name__
  504. if "mel_filters" in output:
  505. del output["mel_filters"]
  506. if "window" in output:
  507. del output["window"]
  508. return output
  509. @classmethod
  510. def from_json_file(cls, json_file: str | os.PathLike) -> "FeatureExtractionMixin":
  511. """
  512. Instantiates a feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] from the path to
  513. a JSON file of parameters.
  514. Args:
  515. json_file (`str` or `os.PathLike`):
  516. Path to the JSON file containing the parameters.
  517. Returns:
  518. A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor
  519. object instantiated from that JSON file.
  520. """
  521. with open(json_file, encoding="utf-8") as reader:
  522. text = reader.read()
  523. feature_extractor_dict = json.loads(text)
  524. return cls(**feature_extractor_dict)
  525. def to_json_string(self) -> str:
  526. """
  527. Serializes this instance to a JSON string.
  528. Returns:
  529. `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
  530. """
  531. dictionary = self.to_dict()
  532. for key, value in dictionary.items():
  533. if isinstance(value, np.ndarray):
  534. dictionary[key] = value.tolist()
  535. return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
  536. def to_json_file(self, json_file_path: str | os.PathLike):
  537. """
  538. Save this instance to a JSON file.
  539. Args:
  540. json_file_path (`str` or `os.PathLike`):
  541. Path to the JSON file in which this feature_extractor instance's parameters will be saved.
  542. """
  543. with open(json_file_path, "w", encoding="utf-8") as writer:
  544. writer.write(self.to_json_string())
  545. def __repr__(self):
  546. return f"{self.__class__.__name__} {self.to_json_string()}"
  547. @classmethod
  548. def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
  549. """
  550. Register this class with a given auto class. This should only be used for custom feature extractors as the ones
  551. in the library are already mapped with `AutoFeatureExtractor`.
  552. Args:
  553. auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
  554. The auto class to register this new feature extractor with.
  555. """
  556. if not isinstance(auto_class, str):
  557. auto_class = auto_class.__name__
  558. import transformers.models.auto as auto_module
  559. if not hasattr(auto_module, auto_class):
  560. raise ValueError(f"{auto_class} is not a valid auto class.")
  561. cls._auto_class = auto_class
# Give `FeatureExtractionMixin` its own copy of `push_to_hub` first, so that the docstring formatting below
# is applied to that copy and not to the function object it was copied from.
FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub)
# The docstring may be `None`, in which case there is nothing to format.
if FeatureExtractionMixin.push_to_hub.__doc__ is not None:
    # Fill the `{object}`, `{object_class}` and `{object_files}` placeholders with feature extractor wording.
    FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format(
        object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file"
    )