# transform.py (7.2 KB)
  1. """Transforms for spectrogram augmentation.
  2. This module provides transforms specifically designed for augmenting spectrograms
  3. in audio processing tasks. Includes time reversal, time masking, and frequency
  4. masking transforms commonly used in audio machine learning applications.
  5. """
  6. from __future__ import annotations
  7. from warnings import warn
  8. from pydantic import Field
  9. from albumentations.augmentations.dropout.xy_masking import XYMasking
  10. from albumentations.augmentations.geometric.flip import HorizontalFlip
  11. from albumentations.core.transforms_interface import BaseTransformInitSchema
  12. from albumentations.core.type_definitions import ALL_TARGETS
  13. __all__ = [
  14. "FrequencyMasking",
  15. "TimeMasking",
  16. "TimeReverse",
  17. ]
  class TimeReverse(HorizontalFlip):
      """Flip a spectrogram along its time axis (time inversion).

      Reversing time in a spectrogram is the audio counterpart of the random
      horizontal flip used throughout computer vision, and can help audio
      classification models trained on spectrograms. It was used in AudioCLIP,
      which extends CLIP to image, text and audio inputs.

      Time runs along the horizontal axis of a spectrogram image, so this class
      is a plain subclass of HorizontalFlip and adds no transform logic.

      Args:
          p (float): Probability of applying the transform. Default: 0.5.

      Targets:
          image, mask, bboxes, keypoints, volume, mask3d

      Image types:
          uint8, float32

      Number of channels:
          Any

      Note:
          Functionally the same as HorizontalFlip; the separate name only makes
          pipelines over spectrograms and other time-series images easier to
          read. The constructor issues a UserWarning recommending
          HorizontalFlip instead.

      References:
          - AudioCLIP paper: https://arxiv.org/abs/2106.13043
          - Audiomentations: https://iver56.github.io/audiomentations/waveform_transforms/reverse/
      """

      class InitSchema(BaseTransformInitSchema):
          # Declares no fields of its own; `p` is presumably validated by
          # BaseTransformInitSchema -- confirm there.
          pass

      # Declare every target type this transform accepts.
      _targets = ALL_TARGETS

      def __init__(self, p: float = 0.5):
          alias_notice = (
              "TimeReverse is an alias for HorizontalFlip transform. "
              "Consider using HorizontalFlip directly from albumentations.HorizontalFlip. "
          )
          # stacklevel=2 attributes the warning to the frame that called __init__.
          warn(alias_notice, category=UserWarning, stacklevel=2)
          super().__init__(p=p)
  class TimeMasking(XYMasking):
      """Zero-fill a random span of a spectrogram's time axis (SpecAugment).

      Implements the time-masking augmentation from the SpecAugment paper: a
      contiguous stretch along the time axis is masked out, so the model learns
      to tolerate temporal variations and missing information in the audio.

      Internally this is XYMasking pinned to a single mask along the x (time)
      axis. Use XYMasking itself for anything this wrapper does not expose,
      such as several masks, custom fill values, or frequency masking.

      Args:
          time_mask_param (int): Upper bound on the mask length along the time
              axis. Must be a positive integer. The length is sampled from
              (0, time_mask_param). Default: 40.
          p (float): Probability of applying the transform. Default: 0.5.

      Targets:
          image, mask, bboxes, keypoints, volume, mask3d

      Image types:
          uint8, float32

      Number of channels:
          Any

      Note:
          The underlying XYMasking parameters are fixed to:
          - num_masks_x=1 (a single horizontal mask)
          - num_masks_y=0 (no vertical masks)
          - mask_x_length=(0, time_mask_param)
          - fill=0 and fill_mask=0
          The constructor issues a UserWarning pointing to XYMasking. For
          multiple masks, custom fill values, frequency masking, or combined
          time-frequency masking, use albumentations.XYMasking directly.

      References:
          - SpecAugment paper: https://arxiv.org/abs/1904.08779
          - Original implementation: https://pytorch.org/audio/stable/transforms.html#timemask
      """

      class InitSchema(BaseTransformInitSchema):
          # Field(gt=0) requires a strictly positive mask-length bound.
          time_mask_param: int = Field(gt=0)

      def __init__(self, time_mask_param: int = 40, p: float = 0.5):
          specialization_hint = (
              "TimeMasking is a specialized version of XYMasking. "
              "For more flexibility (multiple masks, custom fill values, frequency masking), "
              "consider using XYMasking directly from albumentations.XYMasking."
          )
          # stacklevel=2 attributes the warning to the frame that called __init__.
          warn(specialization_hint, category=UserWarning, stacklevel=2)
          # Exactly one mask along x (time) and none along y (frequency);
          # both fill and fill_mask are 0 inside the masked region.
          super().__init__(
              p=p,
              fill=0,
              fill_mask=0,
              num_masks_x=1,
              num_masks_y=0,
              mask_x_length=(0, time_mask_param),
          )
          # Keep the constructor argument on the instance. NOTE(review): this is
          # presumably read back when the transform is serialized -- confirm.
          self.time_mask_param = time_mask_param
  class FrequencyMasking(XYMasking):
      """Zero-fill a random band of a spectrogram's frequency axis (SpecAugment).

      Implements the frequency-masking augmentation from the SpecAugment paper:
      a contiguous band along the frequency axis is masked out, so the model
      learns to tolerate frequency variations and missing spectral information.

      Internally this is XYMasking pinned to a single mask along the y
      (frequency) axis. Use XYMasking itself for anything this wrapper does not
      expose, such as several masks, custom fill values, or time masking.

      Args:
          freq_mask_param (int): Upper bound on the mask length along the
              frequency axis. Must be a positive integer. The length is sampled
              from (0, freq_mask_param). Default: 30.
          p (float): Probability of applying the transform. Default: 0.5.

      Targets:
          image, mask, bboxes, keypoints, volume, mask3d

      Image types:
          uint8, float32

      Number of channels:
          Any

      Note:
          The underlying XYMasking parameters are fixed to:
          - num_masks_y=1 (a single vertical mask)
          - num_masks_x=0 (no horizontal masks)
          - mask_y_length=(0, freq_mask_param)
          - fill=0 and fill_mask=0
          The constructor issues a UserWarning pointing to XYMasking. For
          multiple masks, custom fill values, time masking, or combined
          time-frequency masking, use albumentations.XYMasking directly.

      References:
          - SpecAugment paper: https://arxiv.org/abs/1904.08779
          - Original implementation: https://pytorch.org/audio/stable/transforms.html#freqmask
      """

      class InitSchema(BaseTransformInitSchema):
          # Field(gt=0) requires a strictly positive mask-length bound.
          freq_mask_param: int = Field(gt=0)

      def __init__(self, freq_mask_param: int = 30, p: float = 0.5):
          specialization_hint = (
              "FrequencyMasking is a specialized version of XYMasking. "
              "For more flexibility (multiple masks, custom fill values, time masking), "
              "consider using XYMasking directly from albumentations.XYMasking."
          )
          # stacklevel=2 attributes the warning to the frame that called __init__.
          warn(specialization_hint, category=UserWarning, stacklevel=2)
          # Exactly one mask along y (frequency) and none along x (time);
          # both fill and fill_mask are 0 inside the masked region.
          super().__init__(
              p=p,
              fill=0,
              fill_mask=0,
              num_masks_x=0,
              num_masks_y=1,
              mask_y_length=(0, freq_mask_param),
          )
          # Keep the constructor argument on the instance. NOTE(review): this is
          # presumably read back when the transform is serialized -- confirm.
          self.freq_mask_param = freq_mask_param