immutable_collections.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from collections.abc import Iterable
  2. from typing import Any, NoReturn, TypeVar
  3. from typing_extensions import Self
  4. from torch.utils._pytree import (
  5. _dict_flatten,
  6. _dict_flatten_with_keys,
  7. _dict_unflatten,
  8. _list_flatten,
  9. _list_flatten_with_keys,
  10. _list_unflatten,
  11. Context,
  12. register_pytree_node,
  13. )
  14. from ._compatibility import compatibility
  15. __all__ = ["immutable_list", "immutable_dict"]
  16. _help_mutation = """
  17. If you are attempting to modify the kwargs or args of a torch.fx.Node object,
  18. instead create a new copy of it and assign the copy to the node:
  19. new_args = ... # copy and mutate args
  20. node.args = new_args
  21. """.strip()
  22. _T = TypeVar("_T")
  23. _KT = TypeVar("_KT")
  24. _VT = TypeVar("_VT")
  25. def _no_mutation(self: Any, *args: Any, **kwargs: Any) -> NoReturn:
  26. raise TypeError(
  27. f"{type(self).__name__!r} object does not support mutation. {_help_mutation}",
  28. )
  29. @compatibility(is_backward_compatible=True)
  30. class immutable_list(list[_T]):
  31. """An immutable version of :class:`list`."""
  32. __delitem__ = _no_mutation
  33. __iadd__ = _no_mutation
  34. __imul__ = _no_mutation
  35. __setitem__ = _no_mutation
  36. append = _no_mutation
  37. clear = _no_mutation
  38. extend = _no_mutation
  39. insert = _no_mutation
  40. pop = _no_mutation
  41. remove = _no_mutation
  42. reverse = _no_mutation
  43. sort = _no_mutation
  44. def __hash__(self) -> int: # type: ignore[override]
  45. return hash(tuple(self))
  46. def __reduce__(self) -> tuple[type[Self], tuple[tuple[_T, ...]]]:
  47. return (type(self), (tuple(self),))
  48. @compatibility(is_backward_compatible=True)
  49. class immutable_dict(dict[_KT, _VT]):
  50. """An immutable version of :class:`dict`."""
  51. __delitem__ = _no_mutation
  52. __ior__ = _no_mutation
  53. __setitem__ = _no_mutation
  54. clear = _no_mutation
  55. pop = _no_mutation
  56. popitem = _no_mutation
  57. setdefault = _no_mutation
  58. update = _no_mutation # type: ignore[assignment]
  59. def __hash__(self) -> int: # type: ignore[override]
  60. return hash(frozenset(self.items()))
  61. def __reduce__(self) -> tuple[type[Self], tuple[tuple[tuple[_KT, _VT], ...]]]:
  62. return (type(self), (tuple(self.items()),))
  63. # Register immutable collections for PyTree operations
  64. def _immutable_list_flatten(d: immutable_list[_T]) -> tuple[list[_T], Context]:
  65. return _list_flatten(d)
  66. def _immutable_list_unflatten(
  67. values: Iterable[_T],
  68. context: Context,
  69. ) -> immutable_list[_T]:
  70. return immutable_list(_list_unflatten(values, context))
  71. def _immutable_dict_flatten(d: immutable_dict[Any, _VT]) -> tuple[list[_VT], Context]:
  72. return _dict_flatten(d)
  73. def _immutable_dict_unflatten(
  74. values: Iterable[_VT],
  75. context: Context,
  76. ) -> immutable_dict[Any, _VT]:
  77. return immutable_dict(_dict_unflatten(values, context))
  78. register_pytree_node(
  79. immutable_list,
  80. _immutable_list_flatten,
  81. _immutable_list_unflatten,
  82. serialized_type_name="torch.fx.immutable_collections.immutable_list",
  83. flatten_with_keys_fn=_list_flatten_with_keys,
  84. )
  85. register_pytree_node(
  86. immutable_dict,
  87. _immutable_dict_flatten,
  88. _immutable_dict_unflatten,
  89. serialized_type_name="torch.fx.immutable_collections.immutable_dict",
  90. flatten_with_keys_fn=_dict_flatten_with_keys,
  91. )