# transforms.py

"""Transforms for text rendering and augmentation on images.

This module provides transforms for adding and manipulating text on images,
including text augmentation techniques like word insertion, deletion, and swapping.
"""
  5. from __future__ import annotations
  6. import re
  7. from pathlib import Path
  8. from typing import Annotated, Any, Literal
  9. import numpy as np
  10. from pydantic import AfterValidator
  11. import albumentations.augmentations.text.functional as ftext
  12. from albumentations.core.bbox_utils import check_bboxes, denormalize_bboxes
  13. from albumentations.core.pydantic import check_range_bounds, nondecreasing
  14. from albumentations.core.transforms_interface import BaseTransformInitSchema, ImageOnlyTransform
  15. __all__ = ["TextImage"]
  16. class TextImage(ImageOnlyTransform):
  17. """Apply text rendering transformations on images.
  18. This class supports rendering text directly onto images using a variety of configurations,
  19. such as custom fonts, font sizes, colors, and augmentation methods. The text can be placed
  20. inside specified bounding boxes.
  21. Args:
  22. font_path (str | Path): Path to the font file to use for rendering text.
  23. stopwords (list[str] | None): List of stopwords for text augmentation.
  24. augmentations (tuple[str | None, ...]): List of text augmentations to apply.
  25. None: text is printed as is
  26. "insertion": insert random stop words into the text.
  27. "swap": swap random words in the text.
  28. "deletion": delete random words from the text.
  29. fraction_range (tuple[float, float]): Range for selecting a fraction of bounding boxes to modify.
  30. font_size_fraction_range (tuple[float, float]): Range for selecting the font size as a fraction of
  31. bounding box height.
  32. font_color (tuple[float, ...]): Font color as RGB values (e.g., (0, 0, 0) for black).
  33. clear_bg (bool): Whether to clear the background before rendering text.
  34. metadata_key (str): Key to access metadata in the parameters.
  35. p (float): Probability of applying the transform.
  36. Targets:
  37. image, volume
  38. Image types:
  39. uint8, float32
  40. References:
  41. doc-augmentation: https://github.com/danaaubakirova/doc-augmentation
  42. Examples:
  43. >>> import albumentations as A
  44. >>> transform = A.Compose([
  45. A.TextImage(
  46. font_path=Path("/path/to/font.ttf"),
  47. stopwords=("the", "is", "in"),
  48. augmentations=("insertion", "deletion"),
  49. fraction_range=(0.5, 1.0),
  50. font_size_fraction_range=(0.5, 0.9),
  51. font_color=(255, 0, 0), # red in RGB
  52. metadata_key="text_metadata",
  53. p=0.5
  54. )
  55. ])
  56. >>> transformed = transform(image=my_image, text_metadata=my_metadata)
  57. >>> image = transformed['image']
  58. # This will render text on `my_image` based on the metadata provided in `my_metadata`.
  59. """
  60. class InitSchema(BaseTransformInitSchema):
  61. font_path: str | Path
  62. stopwords: tuple[str, ...]
  63. augmentations: tuple[str | None, ...]
  64. fraction_range: Annotated[
  65. tuple[float, float],
  66. AfterValidator(nondecreasing),
  67. AfterValidator(check_range_bounds(0, 1)),
  68. ]
  69. font_size_fraction_range: Annotated[
  70. tuple[float, float],
  71. AfterValidator(nondecreasing),
  72. AfterValidator(check_range_bounds(0, 1)),
  73. ]
  74. font_color: tuple[float, ...]
  75. clear_bg: bool
  76. metadata_key: str
  77. def __init__(
  78. self,
  79. font_path: str | Path,
  80. stopwords: tuple[str, ...] = ("the", "is", "in", "at", "of"),
  81. augmentations: tuple[Literal["insertion", "swap", "deletion"] | None, ...] = (None,),
  82. fraction_range: tuple[float, float] = (1.0, 1.0),
  83. font_size_fraction_range: tuple[float, float] = (0.8, 0.9),
  84. font_color: tuple[float, ...] = (0, 0, 0), # black in RGB
  85. clear_bg: bool = False,
  86. metadata_key: str = "textimage_metadata",
  87. p: float = 0.5,
  88. ) -> None:
  89. super().__init__(p=p)
  90. self.metadata_key = metadata_key
  91. self.font_path = font_path
  92. self.fraction_range = fraction_range
  93. self.stopwords = stopwords
  94. self.augmentations = list(augmentations)
  95. self.font_size_fraction_range = font_size_fraction_range
  96. self.font_color = font_color
  97. self.clear_bg = clear_bg
  98. @property
  99. def targets_as_params(self) -> list[str]:
  100. """Get list of targets that should be passed as parameters to transforms.
  101. Returns:
  102. list[str]: List containing the metadata key name
  103. """
  104. return [self.metadata_key]
  105. def random_aug(
  106. self,
  107. text: str,
  108. fraction: float,
  109. choice: Literal["insertion", "swap", "deletion"],
  110. ) -> str:
  111. """Apply a random text augmentation to the input text.
  112. Args:
  113. text (str): Original text to augment
  114. fraction (float): Fraction of words to modify
  115. choice (Literal["insertion", "swap", "deletion"]): Type of augmentation to apply
  116. Returns:
  117. str: Augmented text or empty string if no change was made
  118. Raises:
  119. ValueError: If an invalid choice is provided
  120. """
  121. words = [word for word in text.strip().split() if word]
  122. num_words = len(words)
  123. num_words_to_modify = max(1, int(fraction * num_words))
  124. if choice == "insertion":
  125. result_sentence = ftext.insert_random_stopwords(words, num_words_to_modify, self.stopwords, self.py_random)
  126. elif choice == "swap":
  127. result_sentence = ftext.swap_random_words(words, num_words_to_modify, self.py_random)
  128. elif choice == "deletion":
  129. result_sentence = ftext.delete_random_words(words, num_words_to_modify, self.py_random)
  130. else:
  131. raise ValueError("Invalid choice. Choose from 'insertion', 'swap', or 'deletion'.")
  132. result_sentence = re.sub(" +", " ", result_sentence).strip()
  133. return result_sentence if result_sentence != text else ""
  134. def preprocess_metadata(
  135. self,
  136. image: np.ndarray,
  137. bbox: tuple[float, float, float, float],
  138. text: str,
  139. bbox_index: int,
  140. ) -> dict[str, Any]:
  141. """Preprocess text metadata for a single bounding box.
  142. Args:
  143. image (np.ndarray): Input image
  144. bbox (tuple[float, float, float, float]): Normalized bounding box coordinates
  145. text (str): Text to render in the bounding box
  146. bbox_index (int): Index of the bounding box in the original metadata
  147. Returns:
  148. dict[str, Any]: Processed metadata including font, position, and text information
  149. Raises:
  150. ImportError: If PIL.ImageFont is not installed
  151. """
  152. try:
  153. from PIL import ImageFont
  154. except ImportError as err:
  155. raise ImportError(
  156. "ImageFont from PIL is required to use TextImage transform. Install it with `pip install Pillow`.",
  157. ) from err
  158. check_bboxes(np.array([bbox]))
  159. denormalized_bbox = denormalize_bboxes(np.array([bbox]), image.shape[:2])[0]
  160. x_min, y_min, x_max, y_max = (int(x) for x in denormalized_bbox[:4])
  161. bbox_height = y_max - y_min
  162. font_size_fraction = self.py_random.uniform(*self.font_size_fraction_range)
  163. font = ImageFont.truetype(str(self.font_path), int(font_size_fraction * bbox_height))
  164. if not self.augmentations or self.augmentations is None:
  165. augmented_text = text
  166. else:
  167. augmentation = self.py_random.choice(self.augmentations)
  168. augmented_text = text if augmentation is None else self.random_aug(text, 0.5, choice=augmentation)
  169. font_color = self.font_color
  170. return {
  171. "bbox_coords": (x_min, y_min, x_max, y_max),
  172. "bbox_index": bbox_index,
  173. "original_text": text,
  174. "text": augmented_text,
  175. "font": font,
  176. "font_color": font_color,
  177. }
  178. def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
  179. """Generate parameters based on input data.
  180. Args:
  181. params (dict[str, Any]): Dictionary of existing parameters
  182. data (dict[str, Any]): Dictionary containing input data with image and metadata
  183. Returns:
  184. dict[str, Any]: Dictionary containing the overlay data for text rendering
  185. """
  186. image = data["image"] if "image" in data else data["images"][0]
  187. metadata = data[self.metadata_key]
  188. if metadata == []:
  189. return {
  190. "overlay_data": [],
  191. }
  192. if isinstance(metadata, dict):
  193. metadata = [metadata]
  194. fraction = self.py_random.uniform(*self.fraction_range)
  195. num_bboxes_to_modify = int(len(metadata) * fraction)
  196. bbox_indices_to_update = self.py_random.sample(range(len(metadata)), num_bboxes_to_modify)
  197. overlay_data = [
  198. self.preprocess_metadata(image, metadata[bbox_index]["bbox"], metadata[bbox_index]["text"], bbox_index)
  199. for bbox_index in bbox_indices_to_update
  200. ]
  201. return {
  202. "overlay_data": overlay_data,
  203. }
  204. def apply(
  205. self,
  206. img: np.ndarray,
  207. overlay_data: list[dict[str, Any]],
  208. **params: Any,
  209. ) -> np.ndarray:
  210. """Apply text rendering to the input image.
  211. Args:
  212. img (np.ndarray): Input image
  213. overlay_data (list[dict[str, Any]]): List of dictionaries containing text rendering information
  214. **params (Any): Additional parameters
  215. Returns:
  216. np.ndarray: Image with rendered text
  217. """
  218. return ftext.render_text(img, overlay_data, clear_bg=self.clear_bg)
  219. def apply_with_params(self, params: dict[str, Any], *args: Any, **kwargs: Any) -> dict[str, Any]:
  220. """Apply the transform and include overlay data in the result.
  221. Args:
  222. params (dict[str, Any]): Parameters for the transform
  223. *args (Any): Additional positional arguments
  224. **kwargs (Any): Additional keyword arguments
  225. Returns:
  226. dict[str, Any]: Dictionary containing the transformed data and simplified overlay information
  227. """
  228. res = super().apply_with_params(params, *args, **kwargs)
  229. res["overlay_data"] = [
  230. {
  231. "bbox_coords": overlay["bbox_coords"],
  232. "text": overlay["text"],
  233. "original_text": overlay["original_text"],
  234. "bbox_index": overlay["bbox_index"],
  235. "font_color": overlay["font_color"],
  236. }
  237. for overlay in params["overlay_data"]
  238. ]
  239. return res