functional.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. """Functional implementations for text manipulation and rendering.
  2. This module provides utility functions for manipulating text in strings and
  3. rendering text onto images. Includes functions for word manipulation, text drawing,
  4. and handling text regions in images.
  5. """
  6. from __future__ import annotations
  7. import random
  8. from typing import TYPE_CHECKING, Any
  9. import cv2
  10. import numpy as np
  11. from albucore import (
  12. MONO_CHANNEL_DIMENSIONS,
  13. NUM_MULTI_CHANNEL_DIMENSIONS,
  14. NUM_RGB_CHANNELS,
  15. preserve_channel_dim,
  16. uint8_io,
  17. )
  18. from albumentations.core.type_definitions import PAIR
# Pillow is imported here only for type checking; at runtime the drawing helpers import it lazily so it stays an optional dependency.
  20. if TYPE_CHECKING:
  21. from PIL import Image
  22. def delete_random_words(words: list[str], num_words: int, py_random: random.Random) -> str:
  23. """Delete a specified number of random words from a list.
  24. This function randomly removes words from the input list and joins the remaining
  25. words with spaces to form a new string.
  26. Args:
  27. words (list[str]): List of words to process.
  28. num_words (int): Number of words to delete.
  29. py_random (random.Random): Random number generator for reproducibility.
  30. Returns:
  31. str: New string with specified words removed. Returns empty string if
  32. num_words is greater than or equal to the length of words.
  33. """
  34. if num_words >= len(words):
  35. return ""
  36. indices_to_delete = py_random.sample(range(len(words)), num_words)
  37. new_words = [word for idx, word in enumerate(words) if idx not in indices_to_delete]
  38. return " ".join(new_words)
  39. def swap_random_words(words: list[str], num_words: int, py_random: random.Random) -> str:
  40. """Swap random pairs of words in a list of words.
  41. This function randomly selects pairs of words and swaps their positions
  42. a specified number of times.
  43. Args:
  44. words (list[str]): List of words to process.
  45. num_words (int): Number of swaps to perform.
  46. py_random (random.Random): Random number generator for reproducibility.
  47. Returns:
  48. str: New string with words swapped. If num_words is 0 or the list has fewer
  49. than 2 words, returns the original string.
  50. """
  51. if num_words == 0 or len(words) < PAIR:
  52. return " ".join(words)
  53. words = words.copy()
  54. for _ in range(num_words):
  55. idx1, idx2 = py_random.sample(range(len(words)), 2)
  56. words[idx1], words[idx2] = words[idx2], words[idx1]
  57. return " ".join(words)
  58. def insert_random_stopwords(
  59. words: list[str],
  60. num_insertions: int,
  61. stopwords: tuple[str, ...] | None,
  62. py_random: random.Random,
  63. ) -> str:
  64. """Insert random stopwords into a list of words.
  65. This function randomly inserts stopwords at random positions in the
  66. list of words a specified number of times.
  67. Args:
  68. words (list[str]): List of words to process.
  69. num_insertions (int): Number of stopwords to insert.
  70. stopwords (tuple[str, ...] | None): Tuple of stopwords to choose from.
  71. If None, default stopwords will be used.
  72. py_random (random.Random): Random number generator for reproducibility.
  73. Returns:
  74. str: New string with stopwords inserted.
  75. """
  76. if stopwords is None:
  77. stopwords = ("and", "the", "is", "in", "at", "of") # Default stopwords if none provided
  78. for _ in range(num_insertions):
  79. idx = py_random.randint(0, len(words))
  80. words.insert(idx, py_random.choice(stopwords))
  81. return " ".join(words)
  82. def convert_image_to_pil(image: np.ndarray) -> Image:
  83. """Convert a NumPy array image to a PIL image."""
  84. try:
  85. from PIL import Image
  86. except ImportError:
  87. raise ImportError("Pillow is not installed") from ImportError
  88. if image.ndim == MONO_CHANNEL_DIMENSIONS: # (height, width)
  89. return Image.fromarray(image)
  90. if image.ndim == NUM_MULTI_CHANNEL_DIMENSIONS and image.shape[2] == 1: # (height, width, 1)
  91. return Image.fromarray(image[:, :, 0], mode="L")
  92. if image.ndim == NUM_MULTI_CHANNEL_DIMENSIONS and image.shape[2] == NUM_RGB_CHANNELS: # (height, width, 3)
  93. return Image.fromarray(image)
  94. raise TypeError(f"Unsupported image shape: {image.shape}")
  95. def draw_text_on_pil_image(pil_image: Image, metadata_list: list[dict[str, Any]]) -> Image:
  96. """Draw text on a PIL image."""
  97. try:
  98. from PIL import ImageDraw
  99. except ImportError:
  100. raise ImportError("Pillow is not installed") from ImportError
  101. draw = ImageDraw.Draw(pil_image)
  102. for metadata in metadata_list:
  103. bbox_coords = metadata["bbox_coords"]
  104. text = metadata["text"]
  105. font = metadata["font"]
  106. font_color = metadata["font_color"]
  107. # Adapt font_color based on image mode
  108. if pil_image.mode == "L": # Grayscale
  109. # For grayscale images, use only the first value or average the RGB values
  110. if isinstance(font_color, tuple):
  111. if len(font_color) >= 3:
  112. # Average RGB values for grayscale
  113. font_color = int(sum(font_color[:3]) / 3)
  114. elif len(font_color) == 1:
  115. font_color = int(font_color[0])
  116. # For RGB and other modes, ensure font_color is a tuple of integers
  117. elif isinstance(font_color, tuple):
  118. font_color = tuple(int(c) for c in font_color)
  119. position = bbox_coords[:2]
  120. draw.text(position, text, font=font, fill=font_color)
  121. return pil_image
  122. def draw_text_on_multi_channel_image(image: np.ndarray, metadata_list: list[dict[str, Any]]) -> np.ndarray:
  123. """Draw text on a multi-channel image with more than three channels."""
  124. try:
  125. from PIL import Image, ImageDraw
  126. except ImportError:
  127. raise ImportError("Pillow is not installed") from ImportError
  128. channels = [Image.fromarray(image[:, :, i]) for i in range(image.shape[2])]
  129. pil_images = [ImageDraw.Draw(channel) for channel in channels]
  130. for metadata in metadata_list:
  131. bbox_coords = metadata["bbox_coords"]
  132. text = metadata["text"]
  133. font = metadata["font"]
  134. font_color = metadata["font_color"]
  135. # Handle font_color as tuple[float, ...]
  136. # Ensure we have enough color values for all channels
  137. if len(font_color) < image.shape[2]:
  138. # If fewer values than channels, pad with zeros
  139. font_color = tuple(list(font_color) + [0] * (image.shape[2] - len(font_color)))
  140. elif len(font_color) > image.shape[2]:
  141. # If more values than channels, truncate
  142. font_color = font_color[: image.shape[2]]
  143. # Convert to integers for PIL
  144. font_color = [int(c) for c in font_color]
  145. position = bbox_coords[:2]
  146. # For each channel, use the corresponding color value
  147. for channel_id, pil_image in enumerate(pil_images):
  148. # For single-channel PIL images, color must be an integer
  149. pil_image.text(position, text, font=font, fill=font_color[channel_id])
  150. return np.stack([np.array(channel) for channel in channels], axis=2)
  151. @uint8_io
  152. @preserve_channel_dim
  153. def render_text(image: np.ndarray, metadata_list: list[dict[str, Any]], clear_bg: bool) -> np.ndarray:
  154. """Render text onto an image based on provided metadata.
  155. This function draws text on an image using metadata that specifies text content,
  156. position, font, and color. It can optionally clear the background before rendering.
  157. The function handles different image types (grayscale, RGB, multi-channel).
  158. Args:
  159. image (np.ndarray): Image to draw text on.
  160. metadata_list (list[dict[str, Any]]): List of metadata dictionaries containing:
  161. - bbox_coords: Bounding box coordinates (x_min, y_min, x_max, y_max)
  162. - text: Text string to render
  163. - font: PIL ImageFont object
  164. - font_color: Color for the text
  165. clear_bg (bool): Whether to clear (inpaint) the background under the text.
  166. Returns:
  167. np.ndarray: Image with text rendered on it.
  168. """
  169. # First clean background under boxes using seamless clone if clear_bg is True
  170. if clear_bg:
  171. image = inpaint_text_background(image, metadata_list)
  172. if len(image.shape) == MONO_CHANNEL_DIMENSIONS or (
  173. len(image.shape) == NUM_MULTI_CHANNEL_DIMENSIONS and image.shape[2] in {1, NUM_RGB_CHANNELS}
  174. ):
  175. pil_image = convert_image_to_pil(image)
  176. pil_image = draw_text_on_pil_image(pil_image, metadata_list)
  177. return np.array(pil_image)
  178. return draw_text_on_multi_channel_image(image, metadata_list)
  179. def inpaint_text_background(
  180. image: np.ndarray,
  181. metadata_list: list[dict[str, Any]],
  182. method: int = cv2.INPAINT_TELEA,
  183. ) -> np.ndarray:
  184. """Inpaint (clear) regions in an image where text will be rendered.
  185. This function creates a clean background for text by inpainting rectangular
  186. regions specified in the metadata. It removes any existing content in those
  187. regions to provide a clean slate for rendering text.
  188. Args:
  189. image (np.ndarray): Image to inpaint.
  190. metadata_list (list[dict[str, Any]]): List of metadata dictionaries containing:
  191. - bbox_coords: Bounding box coordinates (x_min, y_min, x_max, y_max)
  192. method (int, optional): Inpainting method to use. Defaults to cv2.INPAINT_TELEA.
  193. Options include:
  194. - cv2.INPAINT_TELEA: Fast Marching Method
  195. - cv2.INPAINT_NS: Navier-Stokes method
  196. Returns:
  197. np.ndarray: Image with specified regions inpainted.
  198. """
  199. result_image = image.copy()
  200. mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
  201. for metadata in metadata_list:
  202. x_min, y_min, x_max, y_max = metadata["bbox_coords"]
  203. # Black out the region
  204. result_image[y_min:y_max, x_min:x_max] = 0
  205. # Update the mask to indicate the region to inpaint
  206. mask[y_min:y_max, x_min:x_max] = 255
  207. # Inpaint the blacked-out regions
  208. return cv2.inpaint(result_image, mask, inpaintRadius=3, flags=method)