_serialization.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import pickle
  2. from dataclasses import dataclass
  3. from io import BufferedIOBase
  4. from typing import Any
  5. import torch
  6. import torch._weights_only_unpickler as _weights_only_unpickler
  7. from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION
  8. __all__: list[str] = []
  9. @dataclass
  10. class _Entry:
  11. key: str
  12. is_storage: bool
  13. length: int
  14. _weights_only_unpickler._add_safe_globals([_Entry])
  15. class _PseudoZipFile:
  16. def __init__(self) -> None:
  17. self.records: dict[str, tuple[object, int]] = {}
  18. def write_record(self, key: str, data: object, length: int) -> None:
  19. self.records[key] = (data, length)
  20. def write_to(self, f: BufferedIOBase) -> None:
  21. entries = []
  22. for key, (data, length) in self.records.items():
  23. entries.append(
  24. _Entry(
  25. key=key,
  26. is_storage=isinstance(data, torch.UntypedStorage),
  27. length=length,
  28. )
  29. )
  30. pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL)
  31. for data, _ in self.records.values():
  32. if isinstance(data, bytes):
  33. f.write(data)
  34. elif isinstance(data, str):
  35. f.write(data.encode("utf-8"))
  36. elif isinstance(data, torch.UntypedStorage):
  37. data._write_file(f, False, False, 1)
  38. else:
  39. raise TypeError(f"unknown type: {type(data)}")
  40. def read_from(self, f: BufferedIOBase) -> None:
  41. entries = _weights_only_unpickler.load(f)
  42. for entry in entries:
  43. data = f.read(entry.length)
  44. if entry.is_storage:
  45. if entry.length == 0:
  46. storage = torch.UntypedStorage(0)
  47. else:
  48. storage = torch.frombuffer(
  49. data,
  50. dtype=torch.uint8,
  51. ).untyped_storage()
  52. self.records[entry.key] = (
  53. storage,
  54. entry.length,
  55. )
  56. else:
  57. self.records[entry.key] = (data, entry.length)
  58. def has_record(self, key: str) -> bool:
  59. return key in self.records
  60. def get_record(self, key: str) -> object:
  61. return self.records[key][0]
  62. def get_storage_from_record(
  63. self, key: str, _length: int, _type: int
  64. ) -> torch.Tensor:
  65. return torch.tensor(self.records[key][0], dtype=torch.uint8)
  66. def serialization_id(self) -> str:
  67. return "torchft"
  68. def _streaming_save(
  69. obj: object,
  70. f: BufferedIOBase,
  71. pickle_module: Any = pickle,
  72. pickle_protocol: int = DEFAULT_PROTOCOL,
  73. ) -> None:
  74. """
  75. Save the object to a file-like object in a streaming fashion compatible with
  76. network sockets.
  77. This behaves similarly to :func:`torch.save` with a few notable differences:
  78. * A non-seekable file like object can be used when loading.
  79. * No forwards/backwards compatibility is provided for the serialization
  80. format. This is only intended to be used with a single version of PyTorch
  81. with transient storage (i.e. sockets or temp files).
  82. * mmap is not supported
  83. See :func:`torch.save` for more details on specific arguments.
  84. """
  85. zip_file = _PseudoZipFile()
  86. _save(
  87. obj,
  88. zip_file=zip_file,
  89. pickle_module=pickle_module,
  90. pickle_protocol=pickle_protocol,
  91. _disable_byteorder_record=False,
  92. )
  93. zip_file.write_to(f)
  94. def _streaming_load(
  95. f: BufferedIOBase,
  96. map_location: MAP_LOCATION = None,
  97. pickle_module: Any = None,
  98. *,
  99. weights_only: bool = True,
  100. **pickle_load_args: Any,
  101. ) -> object:
  102. """
  103. Load the object from a file-like object in a streaming fashion compatible with
  104. network sockets.
  105. See :func:`_streaming_save` for more details about the streaming behavior.
  106. See :func:`torch.load` for more details on specific arguments.
  107. """
  108. if weights_only:
  109. if pickle_module is not None:
  110. raise RuntimeError(
  111. "Can not safely load weights when explicit pickle_module is specified"
  112. )
  113. pickle_module = _weights_only_unpickler
  114. else:
  115. if pickle_module is None:
  116. pickle_module = pickle
  117. if "encoding" not in pickle_load_args:
  118. pickle_load_args["encoding"] = "utf-8"
  119. zip_file = _PseudoZipFile()
  120. zip_file.read_from(f)
  121. return _load(
  122. zip_file=zip_file,
  123. map_location=map_location,
  124. pickle_module=pickle_module,
  125. **pickle_load_args,
  126. )