wrap.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. #
  4. # This source code is licensed under the BSD license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import contextlib
  7. import copy
  8. from abc import ABC, abstractmethod
  9. from collections.abc import Callable, Generator, Iterable, Sequence
  10. from typing import Any, cast
  11. import torch.nn as nn
  12. __all__ = [
  13. "always_wrap_policy",
  14. "lambda_auto_wrap_policy",
  15. "transformer_auto_wrap_policy",
  16. "size_based_auto_wrap_policy",
  17. "enable_wrap",
  18. "wrap",
  19. "CustomPolicy",
  20. "ModuleWrapPolicy",
  21. ]
  22. # NOTE: We intentionally keep this function simple and isolate the complexity
  23. # to `fn` to enable using this function generically. We may move this to a
  24. # non-FSDP-specific folder and/or make it public in the future.
  25. def _post_order_apply(
  26. root_module: nn.Module,
  27. fn: Callable[[nn.Module], nn.Module | None],
  28. ):
  29. """
  30. This applies ``fn`` to every module in the module tree of ``root_module``
  31. following a post-order traversal. If ``fn`` returns an :class:`nn.Module`,
  32. then this replaces the original module with the newly returned one in the
  33. tree. Otherwise, ``fn`` should return ``None``, in which case the module is
  34. not changed.
  35. """
  36. # Track visited modules to avoid visiting shared modules multiple times
  37. visited_modules: set[nn.Module] = {root_module}
  38. def _post_order_apply_inner(
  39. module: nn.Module,
  40. module_name: str,
  41. parent_module: nn.Module | None,
  42. ):
  43. for child_module_name, child_module in module.named_children():
  44. if child_module not in visited_modules:
  45. visited_modules.add(child_module)
  46. _post_order_apply_inner(child_module, child_module_name, module)
  47. optional_module = fn(module)
  48. if optional_module is not None:
  49. if not isinstance(parent_module, nn.Module):
  50. raise AssertionError(
  51. "Non-root modules should have their parent module set but got "
  52. f"{parent_module} for {module}"
  53. )
  54. if not module_name:
  55. raise AssertionError(
  56. "Non-root modules should have their module name set but got "
  57. f"an empty module name for {module}"
  58. )
  59. if not isinstance(optional_module, nn.Module):
  60. raise AssertionError(
  61. f"fn should return None or an nn.Module but got {optional_module}"
  62. )
  63. setattr(parent_module, module_name, optional_module)
  64. _post_order_apply_inner(root_module, "", None)
  65. def _construct_wrap_fn(
  66. root_module: nn.Module,
  67. target_module_to_kwargs: dict[nn.Module, dict[str, Any]],
  68. fsdp_fn: Callable,
  69. ) -> Callable[[nn.Module], nn.Module | None]:
  70. """
  71. This constructs the "wrap" function to pass to :func:`_post_order_apply`
  72. based on ``target_module_to_kwargs``, which should be constructed from the
  73. wrapping policy.
  74. """
  75. def fn(module: nn.Module) -> nn.Module | None:
  76. # Explicitly avoid wrapping the root module since for FSDP, it is
  77. # handled by the caller
  78. if module in target_module_to_kwargs and module is not root_module:
  79. kwargs = target_module_to_kwargs[module]
  80. return fsdp_fn(module, **kwargs)
  81. return None
  82. return fn
  83. def _run_mixed_precision_override_policy(
  84. root_module: nn.Module,
  85. module_classes: Iterable[type[nn.Module]],
  86. ignored_modules: set[nn.Module],
  87. root_kwargs: dict[str, Any],
  88. target_module_to_kwargs: dict[nn.Module, dict[str, Any]],
  89. ):
  90. module_classes_tuple = tuple(set(module_classes))
  91. for module in root_module.modules():
  92. if module in ignored_modules:
  93. continue
  94. elif isinstance(module, module_classes_tuple):
  95. # This policy overrides any existing policy
  96. if module not in target_module_to_kwargs:
  97. # Only inherit from the root kwargs if not already specified
  98. target_module_to_kwargs[module] = root_kwargs
  99. target_module_to_kwargs[module]["mixed_precision"] = None
  100. return target_module_to_kwargs
  101. def always_wrap_policy(*args, **kwargs) -> bool:
  102. """
  103. A simple recursive wrap policy that always returns ``True``. This means
  104. that every submodule is wrapped by the wrapper class in
  105. :func:`_recursive_wrap`.
  106. """
  107. return True
  108. class _Policy(ABC):
  109. """
  110. This defines an abstract base class that represents a policy for applying
  111. a module-level API.
  112. """
  113. @abstractmethod
  114. def _run_policy(
  115. self,
  116. root_module: nn.Module,
  117. ignored_modules: set[nn.Module],
  118. root_kwargs: dict[str, Any],
  119. ) -> dict[nn.Module, dict[str, Any]]:
  120. """
  121. This should return a dict ``target_module_to_kwargs`` that maps from
  122. each target module to wrap to its kwargs.
  123. """
  124. ...
  125. def _module_wrap_policy(
  126. module: nn.Module,
  127. recurse: bool,
  128. nonwrapped_numel: int,
  129. module_classes: set[type[nn.Module]],
  130. ) -> bool:
  131. """
  132. This auto wrap policy wraps every module that is an instance of any type in
  133. ``module_classes`` as its own FSDP instance. The root module given by
  134. ``module`` is always wrapped as an FSDP instance regardless. Since the
  135. wrapping proceeds bottom up, each FSDP instance manages the parameters in
  136. its subtree excluding any already managed by a child FSDP instance.
  137. Args:
  138. module (nn.Module): Current module being considered.
  139. recurse (bool): If ``False``, then this function must decide whether
  140. ``module`` should be wrapped as an FSDP instance or not. If
  141. ``True``, then the function is still recursing down the module
  142. tree as a part of the DFS.
  143. nonwrapped_numel (int): Parameter numel not yet wrapped.
  144. module_classes (Set[Type[nn.Module]]): Set of module classes that are
  145. wrapped as FSDP instances.
  146. Returns:
  147. ``True`` if ``recurse=True``, and whether ``module`` should be wrapped
  148. if ``recurse=False``.
  149. """
  150. if recurse:
  151. return True # always recurse
  152. return isinstance(module, tuple(module_classes))
  153. class ModuleWrapPolicy(_Policy):
  154. """
  155. This policy applies to every module of the specified module classes,
  156. passing in the kwargs given to the root.
  157. """
  158. def __init__(self, module_classes: Iterable[type[nn.Module]]):
  159. module_classes_set = set(module_classes)
  160. self._module_classes = module_classes_set
  161. self._module_classes_str = str(module_classes_set)
  162. def _run_policy(
  163. self,
  164. root_module: nn.Module,
  165. ignored_modules: set[nn.Module],
  166. root_kwargs: dict[str, Any],
  167. ) -> dict[nn.Module, dict[str, Any]]:
  168. module_classes = tuple(self._module_classes)
  169. target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {}
  170. for module in root_module.modules():
  171. if module in ignored_modules:
  172. continue
  173. elif isinstance(module, module_classes):
  174. # Shallow copy to avoid coupling changes across modules
  175. target_module_to_kwargs[module] = copy.copy(root_kwargs)
  176. return target_module_to_kwargs
  177. def __call__(self, module, recurse, *args, **kwargs):
  178. # nonwrapped_numel is not used.
  179. return _module_wrap_policy(
  180. module, recurse, nonwrapped_numel=-1, module_classes=self._module_classes
  181. )
  182. def __repr__(self) -> str:
  183. return super().__repr__() + f"({self._module_classes_str})"
  184. class CustomPolicy(_Policy):
  185. """
  186. This policy takes in a lambda function that maps a given ``nn.Module`` to
  187. either ``False``, ``True``, or a kwarg dictionary.
  188. - If the function returns ``False`` or an empty dictionary, then the module
  189. does not have the API applied.
  190. - If the function returns ``True``, then the module has the API applied
  191. with the root's kwargs.
  192. - If the function returns a non-empty dictionary, then the module has the
  193. API applied, and the dictionary overrides the root's kwargs.
  194. Example::
  195. >>> # xdoctest: +SKIP("undefined variables")
  196. >>> model = init_transformer_model(...)
  197. >>> def lambda_fn(module: nn.Module):
  198. >>> if module is model.lm_head:
  199. >>> return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
  200. >>> elif isinstance(module, TransformerBlock):
  201. >>> return True
  202. >>> return False
  203. >>> policy = CustomPolicy(lambda_fn)
  204. >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)
  205. """
  206. def __init__(self, lambda_fn: Callable[[nn.Module], bool | dict[str, Any]]):
  207. self._lambda_fn = lambda_fn
  208. def _run_policy(
  209. self,
  210. root_module: nn.Module,
  211. ignored_modules: set[nn.Module],
  212. root_kwargs: dict[str, Any],
  213. ) -> dict[nn.Module, dict[str, Any]]:
  214. target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {}
  215. for module in root_module.modules():
  216. if module in ignored_modules:
  217. continue
  218. res = self._lambda_fn(module)
  219. if not isinstance(res, (dict, bool)):
  220. raise ValueError(
  221. "The lambda_fn passed to CustomPolicy should return "
  222. f"False/True or a kwarg dict, but it returned {res}"
  223. )
  224. if not res:
  225. continue
  226. kwargs = copy.copy(root_kwargs)
  227. if isinstance(res, dict):
  228. # Override the root kwargs with the ones specified by the
  229. # lambda function
  230. kwargs.update(res)
  231. target_module_to_kwargs[module] = kwargs
  232. return target_module_to_kwargs
  233. def lambda_auto_wrap_policy(
  234. module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
  235. ) -> bool:
  236. """
  237. A convenient auto wrap policy to wrap submodules based on an arbitrary user
  238. function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as
  239. a `wrapper_cls` unit.
  240. Return if a module should be wrapped during auto wrapping.
  241. The first three parameters are required by :func:`_recursive_wrap`.
  242. Args:
  243. module (nn.Module): Current module being considered.
  244. recurse (bool): If ``False``, then this function must decide whether
  245. ``module`` should be wrapped as an FSDP instance or not. If
  246. ``True``, then the function is still recursing down the module
  247. tree as a part of the DFS.
  248. nonwrapped_numel (int): Parameter numel not yet wrapped.
  249. lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
  250. this module will be wrapped.
  251. """
  252. if recurse:
  253. return True # always recurse
  254. return lambda_fn(module)
  255. def transformer_auto_wrap_policy(
  256. module: nn.Module,
  257. recurse: bool,
  258. nonwrapped_numel: int,
  259. transformer_layer_cls: set[type[nn.Module]],
  260. ) -> bool:
  261. """
  262. See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the
  263. same as ``module_classes``. Note that shared parameters must be wrapped in
  264. the same FSDP instance, so this auto wrap policy can help wrap shared
  265. embeddings into the same FSDP instance for transformer models.
  266. """
  267. return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls)
  268. def _wrap_module_cls_individually(
  269. module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs
  270. ):
  271. if recurse:
  272. # always recurse
  273. return True
  274. else:
  275. # if not recursing, decide whether we should wrap based on whether the type of module
  276. # is in `module_classes`.
  277. return isinstance(module, tuple(module_classes))
  278. def _or_policy(
  279. module: nn.Module,
  280. recurse: bool,
  281. nonwrapped_numel: int,
  282. policies,
  283. ) -> bool:
  284. """
  285. A policy that wraps ``module`` if any policy in the passed in iterable of
  286. ``policies`` returns ``True``.
  287. """
  288. return any(
  289. policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel)
  290. for policy in policies
  291. )
  292. def size_based_auto_wrap_policy(
  293. module: nn.Module,
  294. recurse: bool,
  295. nonwrapped_numel: int,
  296. # Additional custom arguments
  297. min_num_params: int = int(1e8),
  298. force_leaf_modules: set[type[nn.Module]] | None = None,
  299. exclude_wrap_modules: set[type[nn.Module]] | None = None,
  300. ) -> bool:
  301. """
  302. A size-based auto wrap policy.
  303. Args:
  304. module (nn.Module): Current module being considered.
  305. recurse (bool): If ``False``, then this function must decide whether
  306. ``module`` should be wrapped as an FSDP instance or not. If
  307. ``True``, then the function is still recursing down the module
  308. tree as a part of the DFS.
  309. nonwrapped_numel (int): Parameter numel not yet wrapped.
  310. min_num_params (int): Customizable policy input that controls the size
  311. threshold over which a module is ready to be wrapped. This is in
  312. units of numel.
  313. force_leaf_modules (Optional[set[type[nn.Module]]]): Set of module types to keep
  314. as leaves, i.e. their children will never be wrapped.
  315. exclude_wrap_modules (Optional[set[type[nn.Module]]]): Set of module types to be
  316. excluded in wrapping.
  317. Returns:
  318. Whether ``module`` should be wrapped.
  319. """
  320. force_leaf_modules = (
  321. size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined]
  322. if force_leaf_modules is None
  323. else force_leaf_modules
  324. )
  325. exclude_wrap_modules = (
  326. size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore[attr-defined]
  327. if exclude_wrap_modules is None
  328. else exclude_wrap_modules
  329. )
  330. # Keep the argument `min_num_params` for BC for now, but it represents the
  331. # minimum non-wrapped *numel* before triggering a wrapping
  332. min_nonwrapped_numel = min_num_params
  333. is_large = nonwrapped_numel >= min_nonwrapped_numel
  334. if recurse:
  335. # We should recurse if the module is big enough but not in force_leaf_modules list.
  336. return is_large and not isinstance(module, tuple(force_leaf_modules))
  337. else:
  338. # If we are not recursing, determine if we should wrap.
  339. return is_large and not isinstance(module, tuple(exclude_wrap_modules))
  340. # Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported.
  341. size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict} # type: ignore[attr-defined]
  342. size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore[attr-defined]
  343. @contextlib.contextmanager
  344. def enable_wrap(
  345. *, wrapper_cls: Any, **wrapper_kwargs: Any
  346. ) -> Generator[None, None, None]:
  347. """
  348. Context manager to wrap modules using a wrapper.
  349. Useful for when you'd like to apply the same configuration arguments to all
  350. child modules that you wrap. A particularly important use case is wrapping
  351. large layers so that they get sharded (in-place) during initialization, to
  352. avoid running out of system memory. Large layers can indicate that they
  353. should be sharded via the ``wrap`` annotation and this context manager can
  354. provide the exact configuration for these nested instances.
  355. Usage::
  356. with enable_wrap(wrapper_cls, **params):
  357. # Wraps layer in FSDP by default if within context
  358. self.l1 = wrap(torch.nn.Linear(5, 5))
  359. Args:
  360. wrapper_cls:
  361. Class that `wrap` annotation will `wrap` modules with, such as
  362. `FullyShardedDataParallel`.
  363. **wrapper_kwargs:
  364. Configuration settings that will be passed to all ``wrap``
  365. instances inside the context
  366. """
  367. kwargs = {
  368. "wrapper_cls": wrapper_cls,
  369. **wrapper_kwargs,
  370. }
  371. with _ConfigAutoWrap(**kwargs):
  372. yield
  373. def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
  374. """
  375. Annotate that a module should be wrapped. Annotated modules will only be
  376. wrapped if inside of an :func:`enable_wrap` context manager. This allows
  377. a module to be initialized both with and without a wrapper without code
  378. change.
  379. The class that this function wraps the passed in ``nn.Module`` with is the
  380. passed in ``wrapper_cls`` argument into ``enable_wrap``. Both
  381. ``enable_wrap`` and ``wrap`` can take in kwargs specifying how to construct
  382. the ``wrapper_cls`` instance. In the case of duplicate kwargs in
  383. ``enable_wrap`` and ``wrap``, the argument passed into ``wrap`` will be
  384. respected.
  385. Usage::
  386. with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
  387. # Wraps layer in FSDP by default if within context
  388. self.l1 = wrap(torch.nn.Linear(5, 5))
  389. Args:
  390. module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
  391. **wrap_overrides: configuration overrides that will take priority over
  392. the values provided by the :func:`enable_wrap` context
  393. """
  394. if _ConfigAutoWrap.in_autowrap_context:
  395. if _ConfigAutoWrap.wrapper_cls is None:
  396. raise AssertionError("Expected _ConfigAutoWrap.wrapper_cls to be set")
  397. wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
  398. return _wrap(
  399. module,
  400. _ConfigAutoWrap.wrapper_cls,
  401. **wrap_overrides,
  402. )
  403. return module
  404. def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
  405. if wrapper_cls is None:
  406. raise AssertionError("Expected wrapper_cls to be set")
  407. if hasattr(module, "_wrap_overrides"):
  408. # If module has a _wrap_overrides attribute, we force overriding the
  409. # FSDP config with these attributes for this module. Currently this
  410. # is only used to disable mixed precision for BatchNorm when
  411. # auto_wrapping.
  412. overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type, dict-item]
  413. return wrapper_cls(module, **overrides)
  414. return wrapper_cls(module, **kwargs)
  415. def _recursive_wrap(
  416. module: nn.Module,
  417. auto_wrap_policy: Callable,
  418. wrapper_cls: Callable,
  419. ignored_modules: set[nn.Module],
  420. ignored_params: set[nn.Parameter],
  421. only_wrap_children: bool = False,
  422. **kwargs: Any,
  423. ) -> tuple[nn.Module, int]:
  424. """
  425. Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns
  426. ``True`` with ``wrapper_cls``.
  427. Args:
  428. module (nn.Module): Module to recursively wrap.
  429. auto_wrap_policy (Callable): A callable representing a policy that
  430. determines which modules to recursively wrap with ``wrapper_cls``.
  431. ignored_modules (set[torch.nn.Module]): Modules to ignore when
  432. wrapping.
  433. ignored_params (set[torch.nn.Parameter]): Parameters to ignore when
  434. wrapping; these should be the parameters contained in the modules
  435. in ``ignored_modules``.
  436. Returns:
  437. (nn.Module, int):
  438. ``module`` after wrapping and the numel recursively wrapped.
  439. """
  440. if auto_wrap_policy is None:
  441. raise AssertionError("Must specify auto_wrap_policy.")
  442. if wrapper_cls is None:
  443. raise AssertionError("Must specify wrapper_cls")
  444. # Make sure no child is already wrapped.
  445. for _, child in module.named_modules():
  446. if child in ignored_modules:
  447. continue
  448. try:
  449. if isinstance(child, cast(type, wrapper_cls)):
  450. raise AssertionError(
  451. f"Child module {child} is already wrapped by {wrapper_cls}"
  452. )
  453. except TypeError:
  454. # wrapper_cls is a function as opposed to a class type, just bypass above check.
  455. pass
  456. # We count all params, assuming none of them are already wrapped.
  457. nonwrapped_numel = sum(
  458. p.numel() for p in module.parameters() if p not in ignored_params
  459. )
  460. if auto_wrap_policy is None:
  461. raise AssertionError("Expected auto_wrap_policy to be set")
  462. if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
  463. total_wrapped_numel = 0
  464. # Iterate through the children, recursively wrap if necessary
  465. for name, child in module.named_children():
  466. if child in ignored_modules:
  467. continue
  468. wrapped_child, num_wrapped_params = _recursive_wrap(
  469. module=child,
  470. auto_wrap_policy=auto_wrap_policy,
  471. wrapper_cls=wrapper_cls,
  472. ignored_modules=ignored_modules,
  473. ignored_params=ignored_params,
  474. **kwargs,
  475. )
  476. setattr(module, name, wrapped_child)
  477. # Keep track of how many parameters have been wrapped
  478. total_wrapped_numel += num_wrapped_params
  479. # decide if we need to wrap the current module,
  480. # since the left over parameters exceed the number of params to wrap
  481. remainder = nonwrapped_numel - total_wrapped_numel
  482. if not only_wrap_children and auto_wrap_policy(
  483. module=module, recurse=False, nonwrapped_numel=remainder
  484. ):
  485. # Leaf node or final wrapping of the remainder both happen here.
  486. return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  487. else:
  488. return module, total_wrapped_numel
  489. return module, 0
  490. class _ConfigAutoWrap:
  491. """
  492. Helper class to wrap modules based on default config args via a context manager.
  493. See :func:`enable_wrap` for more information.
  494. """
  495. in_autowrap_context: bool = False # Context flag
  496. wrapper_cls: Callable | None = None # The wrapper class
  497. kwargs: dict[str, Any] = {} # Wrapper's args
  498. def __init__(self, **kwargs: dict[str, Any]):
  499. self.kwargs = kwargs
  500. @staticmethod
  501. def enable_autowrap_context(kwargs: Any) -> None:
  502. if _ConfigAutoWrap.in_autowrap_context:
  503. raise NotImplementedError(
  504. "You are already within an autowrap context and we currently do not supported nested autowrap."
  505. )
  506. _ConfigAutoWrap.in_autowrap_context = True
  507. # Get and save the wrapper cls for the context.
  508. if "wrapper_cls" not in kwargs:
  509. raise AssertionError(
  510. "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
  511. )
  512. _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
  513. del kwargs["wrapper_cls"]
  514. # Save the rest.
  515. _ConfigAutoWrap.kwargs = kwargs
  516. @staticmethod
  517. def disable_autowrap_context() -> None:
  518. _ConfigAutoWrap.in_autowrap_context = False
  519. _ConfigAutoWrap.wrapper_cls = None
  520. _ConfigAutoWrap.kwargs = {}
  521. def __enter__(self) -> None:
  522. self.enable_autowrap_context(self.kwargs)
  523. def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
  524. self.disable_autowrap_context()