serialization.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. """Module for serialization and deserialization of Albumentations transforms.
  2. This module provides functionality to serialize transforms to JSON or YAML format and
  3. deserialize them back. It implements the Serializable interface that allows transforms
  4. to be converted to and from dictionaries, which can then be saved to disk or transmitted
  5. over a network. This is particularly useful for saving augmentation pipelines and
  6. restoring them later with the exact same configuration.
  7. """
  8. from __future__ import annotations
  9. import importlib.util
  10. import json
  11. import warnings
  12. from abc import ABC, ABCMeta, abstractmethod
  13. from collections.abc import Mapping, Sequence
  14. from enum import Enum
  15. from pathlib import Path
  16. from typing import Any, Literal, TextIO
  17. from warnings import warn
  18. try:
  19. import yaml
  20. yaml_available = True
  21. except ImportError:
  22. yaml_available = False
  23. from albumentations import __version__
  24. __all__ = ["from_dict", "load", "save", "to_dict"]
  25. SERIALIZABLE_REGISTRY: dict[str, SerializableMeta] = {}
  26. NON_SERIALIZABLE_REGISTRY: dict[str, SerializableMeta] = {}
  27. def shorten_class_name(class_fullname: str) -> str:
  28. # Split the class_fullname once at the last '.' to separate the class name
  29. split_index = class_fullname.rfind(".")
  30. # If there's no '.' or the top module is not 'albumentations', return the full name
  31. if split_index == -1 or not class_fullname.startswith("albumentations."):
  32. return class_fullname
  33. # Extract the class name after the last '.'
  34. return class_fullname[split_index + 1 :]
  35. class SerializableMeta(ABCMeta):
  36. """A metaclass that is used to register classes in `SERIALIZABLE_REGISTRY` or `NON_SERIALIZABLE_REGISTRY`
  37. so they can be found later while deserializing transformation pipeline using classes full names.
  38. """
  39. def __new__(cls, name: str, bases: tuple[type, ...], *args: Any, **kwargs: Any) -> SerializableMeta:
  40. cls_obj = super().__new__(cls, name, bases, *args, **kwargs)
  41. if name != "Serializable" and ABC not in bases:
  42. if cls_obj.is_serializable():
  43. SERIALIZABLE_REGISTRY[cls_obj.get_class_fullname()] = cls_obj
  44. else:
  45. NON_SERIALIZABLE_REGISTRY[cls_obj.get_class_fullname()] = cls_obj
  46. return cls_obj
  47. @classmethod
  48. def is_serializable(cls) -> bool:
  49. return False
  50. @classmethod
  51. def get_class_fullname(cls) -> str:
  52. return get_shortest_class_fullname(cls)
  53. @classmethod
  54. def _to_dict(cls) -> dict[str, Any]:
  55. return {}
  56. class Serializable(metaclass=SerializableMeta):
  57. @classmethod
  58. @abstractmethod
  59. def is_serializable(cls) -> bool:
  60. raise NotImplementedError
  61. @classmethod
  62. @abstractmethod
  63. def get_class_fullname(cls) -> str:
  64. raise NotImplementedError
  65. @abstractmethod
  66. def to_dict_private(self) -> dict[str, Any]:
  67. raise NotImplementedError
  68. def to_dict(self, on_not_implemented_error: str = "raise") -> dict[str, Any]:
  69. """Take a transform pipeline and convert it to a serializable representation that uses only standard
  70. python data types: dictionaries, lists, strings, integers, and floats.
  71. Args:
  72. self (Serializable): A transform that should be serialized. If the transform doesn't implement the `to_dict`
  73. method and `on_not_implemented_error` equals to 'raise' then `NotImplementedError` is raised.
  74. If `on_not_implemented_error` equals to 'warn' then `NotImplementedError` will be ignored
  75. but no transform parameters will be serialized.
  76. on_not_implemented_error (str): `raise` or `warn`.
  77. """
  78. if on_not_implemented_error not in {"raise", "warn"}:
  79. msg = f"Unknown on_not_implemented_error value: {on_not_implemented_error}. Supported values are: 'raise' "
  80. "and 'warn'"
  81. raise ValueError(msg)
  82. try:
  83. transform_dict = self.to_dict_private()
  84. except NotImplementedError:
  85. if on_not_implemented_error == "raise":
  86. raise
  87. transform_dict = {}
  88. warnings.warn(
  89. f"Got NotImplementedError while trying to serialize {self}. Object arguments are not preserved. "
  90. f"The transform class '{self.__class__.__name__}' needs to implement 'to_dict_private' or inherit from "
  91. f"BasicTransform to be properly serialized.",
  92. stacklevel=2,
  93. )
  94. return {"__version__": __version__, "transform": transform_dict}
  95. def to_dict(transform: Serializable, on_not_implemented_error: str = "raise") -> dict[str, Any]:
  96. """Take a transform pipeline and convert it to a serializable representation that uses only standard
  97. python data types: dictionaries, lists, strings, integers, and floats.
  98. Args:
  99. transform (Serializable): A transform that should be serialized. If the transform doesn't implement
  100. the `to_dict` method and `on_not_implemented_error` equals to 'raise' then `NotImplementedError` is raised.
  101. If `on_not_implemented_error` equals to 'warn' then `NotImplementedError` will be ignored
  102. but no transform parameters will be serialized.
  103. on_not_implemented_error (str): `raise` or `warn`.
  104. """
  105. return transform.to_dict(on_not_implemented_error)
  106. def instantiate_nonserializable(
  107. transform: dict[str, Any],
  108. nonserializable: dict[str, Any] | None = None,
  109. ) -> Serializable | None:
  110. if transform.get("__class_fullname__") in NON_SERIALIZABLE_REGISTRY:
  111. name = transform["__name__"]
  112. if nonserializable is None:
  113. msg = f"To deserialize a non-serializable transform with name {name} you need to pass a dict with"
  114. "this transform as the `lambda_transforms` argument"
  115. raise ValueError(msg)
  116. result_transform = nonserializable.get(name)
  117. if transform is None:
  118. raise ValueError(f"Non-serializable transform with {name} was not found in `nonserializable`")
  119. return result_transform
  120. return None
  121. def from_dict(
  122. transform_dict: dict[str, Any],
  123. nonserializable: dict[str, Any] | None = None,
  124. ) -> Serializable | None:
  125. """Args:
  126. transform_dict: A dictionary with serialized transform pipeline.
  127. nonserializable (dict): A dictionary that contains non-serializable transforms.
  128. This dictionary is required when you are restoring a pipeline that contains non-serializable transforms.
  129. Keys in that dictionary should be named same as `name` arguments in respective transforms from
  130. a serialized pipeline.
  131. """
  132. register_additional_transforms()
  133. transform = transform_dict["transform"]
  134. lmbd = instantiate_nonserializable(transform, nonserializable)
  135. if lmbd:
  136. return lmbd
  137. name = transform["__class_fullname__"]
  138. args = {k: v for k, v in transform.items() if k != "__class_fullname__"}
  139. # Ensure 'p' is included, default to 0.5 if missing for backward compatibility
  140. if "p" not in args and name not in ("Compose", "Sequential"):
  141. warn(f"Transform {name} has no 'p' parameter in serialized data, defaulting to 0.5", stacklevel=2)
  142. args["p"] = 0.5
  143. cls = SERIALIZABLE_REGISTRY[shorten_class_name(name)]
  144. if "transforms" in args:
  145. args["transforms"] = [from_dict({"transform": t}, nonserializable=nonserializable) for t in args["transforms"]]
  146. return cls(**args)
  147. def check_data_format(data_format: Literal["json", "yaml"]) -> None:
  148. if data_format not in {"json", "yaml"}:
  149. raise ValueError(f"Unknown data_format {data_format}. Supported formats are: 'json' and 'yaml'")
  150. def serialize_enum(obj: Any) -> Any:
  151. """Recursively search for Enum objects and convert them to their value.
  152. Also handle any Mapping or Sequence types.
  153. """
  154. if isinstance(obj, Mapping):
  155. return {k: serialize_enum(v) for k, v in obj.items()}
  156. if isinstance(obj, Sequence) and not isinstance(obj, str): # exclude strings since they're also sequences
  157. return [serialize_enum(v) for v in obj]
  158. return obj.value if isinstance(obj, Enum) else obj
  159. def save(
  160. transform: Serializable,
  161. filepath_or_buffer: str | Path | TextIO,
  162. data_format: Literal["json", "yaml"] = "json",
  163. on_not_implemented_error: Literal["raise", "warn"] = "raise",
  164. ) -> None:
  165. """Serialize a transform pipeline and save it to either a file specified by a path or a file-like object
  166. in either JSON or YAML format.
  167. Args:
  168. transform (Serializable): The transform pipeline to serialize.
  169. filepath_or_buffer (Union[str, Path, TextIO]): The file path or file-like object to write the serialized
  170. data to.
  171. If a string is provided, it is interpreted as a path to a file. If a file-like object is provided,
  172. the serialized data will be written to it directly.
  173. data_format (str): The format to serialize the data in. Valid options are 'json' and 'yaml'.
  174. Defaults to 'json'.
  175. on_not_implemented_error (str): Determines the behavior if a transform does not implement the `to_dict` method.
  176. If set to 'raise', a `NotImplementedError` is raised. If set to 'warn', the exception is ignored, and
  177. no transform arguments are saved. Defaults to 'raise'.
  178. Raises:
  179. ValueError: If `data_format` is 'yaml' but PyYAML is not installed.
  180. """
  181. check_data_format(data_format)
  182. transform_dict = transform.to_dict(on_not_implemented_error=on_not_implemented_error)
  183. transform_dict = serialize_enum(transform_dict)
  184. # Determine whether to write to a file or a file-like object
  185. if isinstance(filepath_or_buffer, (str, Path)): # It's a filepath
  186. with Path(filepath_or_buffer).open("w") as f:
  187. if data_format == "yaml":
  188. if not yaml_available:
  189. msg = "You need to install PyYAML to save a pipeline in YAML format"
  190. raise ValueError(msg)
  191. yaml.safe_dump(transform_dict, f, default_flow_style=False)
  192. elif data_format == "json":
  193. json.dump(transform_dict, f)
  194. elif data_format == "yaml":
  195. if not yaml_available:
  196. msg = "You need to install PyYAML to save a pipeline in YAML format"
  197. raise ValueError(msg)
  198. yaml.safe_dump(transform_dict, filepath_or_buffer, default_flow_style=False)
  199. elif data_format == "json":
  200. json.dump(transform_dict, filepath_or_buffer, indent=2)
  201. def load(
  202. filepath_or_buffer: str | Path | TextIO,
  203. data_format: Literal["json", "yaml"] = "json",
  204. nonserializable: dict[str, Any] | None = None,
  205. ) -> object:
  206. """Load a serialized pipeline from a file or file-like object and construct a transform pipeline.
  207. Args:
  208. filepath_or_buffer (Union[str, Path, TextIO]): The file path or file-like object to read the serialized
  209. data from.
  210. If a string is provided, it is interpreted as a path to a file. If a file-like object is provided,
  211. the serialized data will be read from it directly.
  212. data_format (Literal["json", "yaml"]): The format of the serialized data.
  213. Defaults to 'json'.
  214. nonserializable (Optional[dict[str, Any]]): A dictionary that contains non-serializable transforms.
  215. This dictionary is required when restoring a pipeline that contains non-serializable transforms.
  216. Keys in the dictionary should be named the same as the `name` arguments in respective transforms
  217. from the serialized pipeline. Defaults to None.
  218. Returns:
  219. object: The deserialized transform pipeline.
  220. Raises:
  221. ValueError: If `data_format` is 'yaml' but PyYAML is not installed.
  222. """
  223. check_data_format(data_format)
  224. if isinstance(filepath_or_buffer, (str, Path)): # Assume it's a filepath
  225. with Path(filepath_or_buffer).open() as f:
  226. if data_format == "json":
  227. transform_dict = json.load(f)
  228. else:
  229. if not yaml_available:
  230. msg = "You need to install PyYAML to load a pipeline in yaml format"
  231. raise ValueError(msg)
  232. transform_dict = yaml.safe_load(f)
  233. elif data_format == "json":
  234. transform_dict = json.load(filepath_or_buffer)
  235. else:
  236. if not yaml_available:
  237. msg = "You need to install PyYAML to load a pipeline in yaml format"
  238. raise ValueError(msg)
  239. transform_dict = yaml.safe_load(filepath_or_buffer)
  240. return from_dict(transform_dict, nonserializable=nonserializable)
  241. def register_additional_transforms() -> None:
  242. """Register transforms that are not imported directly into the `albumentations` module by checking
  243. the availability of optional dependencies.
  244. """
  245. if importlib.util.find_spec("torch") is not None:
  246. try:
  247. # Import `albumentations.pytorch` only if `torch` is installed.
  248. import albumentations.pytorch
  249. # Use a dummy operation to acknowledge the use of the imported module and avoid linting errors.
  250. _ = albumentations.pytorch.ToTensorV2
  251. except ImportError:
  252. pass
  253. def get_shortest_class_fullname(cls: type[Any]) -> str:
  254. """The function `get_shortest_class_fullname` takes a class object as input and returns its shortened
  255. full name.
  256. :param cls: The parameter `cls` is of type `Type[BasicCompose]`, which means it expects a class that
  257. is a subclass of `BasicCompose`
  258. :type cls: Type[BasicCompose]
  259. :return: a string, which is the shortened version of the full class name.
  260. """
  261. class_fullname = f"{cls.__module__}.{cls.__name__}"
  262. return shorten_class_name(class_fullname)