utils.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # mypy: allow-untyped-defs
  2. import collections
  3. from itertools import repeat
  4. from typing import Any
  5. __all__ = ["consume_prefix_in_state_dict_if_present"]
  6. def _ntuple(n, name="parse"):
  7. def parse(x):
  8. if isinstance(x, collections.abc.Iterable):
  9. return tuple(x)
  10. return tuple(repeat(x, n))
  11. parse.__name__ = name
  12. return parse
  13. _single = _ntuple(1, "_single")
  14. _pair = _ntuple(2, "_pair")
  15. _triple = _ntuple(3, "_triple")
  16. _quadruple = _ntuple(4, "_quadruple")
  17. def _reverse_repeat_tuple(t, n):
  18. r"""Reverse the order of `t` and repeat each element for `n` times.
  19. This can be used to translate padding arg used by Conv and Pooling modules
  20. to the ones used by `F.pad`.
  21. """
  22. return tuple(x for x in reversed(t) for _ in range(n))
  23. def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]:
  24. import torch
  25. if isinstance(out_size, (int, torch.SymInt)):
  26. return out_size
  27. if len(defaults) <= len(out_size):
  28. raise ValueError(f"Input dimension should be at least {len(out_size) + 1}")
  29. return [
  30. v if v is not None else d
  31. for v, d in zip(out_size, defaults[-len(out_size) :], strict=False)
  32. ]
  33. def consume_prefix_in_state_dict_if_present(
  34. state_dict: dict[str, Any],
  35. prefix: str,
  36. ) -> None:
  37. r"""Strip the prefix in state_dict in place, if any.
  38. .. note::
  39. Given a `state_dict` from a DP/DDP model, a local model can load it by applying
  40. `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling
  41. :meth:`torch.nn.Module.load_state_dict`.
  42. Args:
  43. state_dict (OrderedDict): a state-dict to be loaded to the model.
  44. prefix (str): prefix.
  45. """
  46. keys = list(state_dict.keys())
  47. for key in keys:
  48. if key.startswith(prefix):
  49. newkey = key[len(prefix) :]
  50. state_dict[newkey] = state_dict.pop(key)
  51. # also strip the prefix in metadata if any.
  52. if hasattr(state_dict, "_metadata"):
  53. keys = list(state_dict._metadata.keys())
  54. for key in keys:
  55. # for the metadata dict, the key can be:
  56. # '': for the DDP module, which we want to remove.
  57. # 'module': for the actual model.
  58. # 'module.xx.xx': for the rest.
  59. if len(key) == 0:
  60. continue
  61. # handling both, 'module' case and 'module.' cases
  62. if key == prefix.replace(".", "") or key.startswith(prefix):
  63. newkey = key[len(prefix) :]
  64. state_dict._metadata[newkey] = state_dict._metadata.pop(key)