configuration_sam3.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. # Copyright 2025 Meta AI and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """SAM3 model configuration"""
  15. from huggingface_hub.dataclasses import strict
  16. from transformers import CLIPTextConfig
  17. from ...configuration_utils import PreTrainedConfig
  18. from ...utils import auto_docstring
  19. from ..auto import CONFIG_MAPPING, AutoConfig
  @auto_docstring(checkpoint="facebook/sam3")
  @strict
  class Sam3ViTConfig(PreTrainedConfig):
      r"""
      rope_theta (`float`, *optional*, defaults to 10000.0):
          Base frequency of the rotary position embedding (RoPE).
      window_size (`int`, *optional*, defaults to 24):
          Size of the window used by windowed-attention layers.
      global_attn_indexes (`list[int]`, *optional*, defaults to `[7, 15, 23, 31]`):
          Indices of the layers that use global attention.
      pretrain_image_size (`int`, *optional*, defaults to 336):
          Image size used at pre-training time, used to initialize the position embeddings.
      hidden_dropout (`float`, *optional*, defaults to 0.0):
          Dropout probability applied to hidden states.
      """

      # Nested under this key when embedded in a parent vision config.
      base_config_key = "backbone_config"
      model_type = "sam3_vit_model"

      # Transformer trunk size.
      hidden_size: int = 1024
      intermediate_size: int = 4736
      num_hidden_layers: int = 32
      num_attention_heads: int = 16
      # Input image / patch embedding.
      num_channels: int = 3
      image_size: int | list[int] | tuple[int, int] = 1008
      patch_size: int | list[int] | tuple[int, int] = 14
      # Activation, normalization and attention dropout.
      hidden_act: str = "gelu"
      layer_norm_eps: float = 1e-6
      attention_dropout: float | int = 0.0
      # Positional encoding and attention layout.
      rope_theta: float = 10000.0
      window_size: int = 24
      # None means "use the default list" (filled in by __post_init__).
      global_attn_indexes: list[int] | None = None
      layer_scale_init_value: float | None = None
      pretrain_image_size: int | list[int] | tuple[int, int] = 336
      hidden_dropout: float | int = 0.0
      initializer_range: float = 0.02

      def __post_init__(self, **kwargs):
          """Run the parent post-init hook, then fill in the default global-attention layers."""
          super().__post_init__(**kwargs)
          if self.global_attn_indexes is not None:
              return
          self.global_attn_indexes = [7, 15, 23, 31]
  @auto_docstring(checkpoint="facebook/sam3")
  @strict
  class Sam3VisionConfig(PreTrainedConfig):
      r"""
      backbone_config (`dict` or `PreTrainedConfig`, *optional*):
          Configuration of the vision backbone. A dict is converted to a config object through
          `CONFIG_MAPPING`, keyed by its `"model_type"` entry (defaulting to `"sam3_vit_model"`).
          The dict passed in is not modified. If `None`, a default `"sam3_vit_model"` config is created.
      fpn_hidden_size (`int`, *optional*, defaults to 256):
          The hidden dimension of the FPN.
      backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[288, 288], [144, 144], [72, 72]]`):
          The spatial sizes (height, width) of the feature maps from the backbone at different scales.
      scale_factors (`list[float]`, *optional*, defaults to `[4.0, 2.0, 1.0, 0.5]`):
          Scale factors for FPN multi-scale features. List of scaling factors for each FPN level.
      """

      base_config_key = "vision_config"
      model_type = "sam3_vision_model"
      sub_configs = {
          "backbone_config": AutoConfig,
      }

      backbone_config: dict | PreTrainedConfig | None = None
      fpn_hidden_size: int = 256
      backbone_feature_sizes: list | None = None
      scale_factors: list[float] | None = None
      hidden_act: str = "gelu"
      layer_norm_eps: float = 1e-6
      initializer_range: float = 0.02

      def __post_init__(self, **kwargs):
          """Fill list defaults, resolve `backbone_config` to a config object, then run the parent hook."""
          # List defaults are filled here instead of being declared as field defaults,
          # so instances never share one mutable list.
          if self.scale_factors is None:
              self.scale_factors = [4.0, 2.0, 1.0, 0.5]
          if self.backbone_feature_sizes is None:
              self.backbone_feature_sizes = [[288, 288], [144, 144], [72, 72]]

          if isinstance(self.backbone_config, dict):
              # Work on a shallow copy so the caller's dict does not gain a "model_type" key.
              backbone_kwargs = dict(self.backbone_config)
              backbone_kwargs.setdefault("model_type", "sam3_vit_model")
              self.backbone_config = CONFIG_MAPPING[backbone_kwargs["model_type"]](**backbone_kwargs)
          elif self.backbone_config is None:
              self.backbone_config = CONFIG_MAPPING["sam3_vit_model"]()

          # backbone_config is resolved before the parent hook runs, because the
          # `image_size` property below forwards to it.
          super().__post_init__(**kwargs)

      @property
      def image_size(self):
          """Image size for the vision encoder (read from the backbone config)."""
          return self.backbone_config.image_size

      @image_size.setter
      def image_size(self, value):
          """Set the image size on the backbone config."""
          self.backbone_config.image_size = value
  @auto_docstring(checkpoint="facebook/sam3")
  @strict
  class Sam3GeometryEncoderConfig(PreTrainedConfig):
      r"""
      roi_size (`int`, *optional*, defaults to 7):
          Size of the region of interest (ROI) used by box pooling operations.
      """

      model_type = "sam3_geometry_encoder"

      # Layer stack dimensions.
      hidden_size: int = 256
      num_layers: int = 3
      num_attention_heads: int = 8
      intermediate_size: int = 2048
      # Regularization, activation and normalization.
      dropout: float | int = 0.1
      hidden_act: str = "relu"
      hidden_dropout: float | int = 0.0
      layer_norm_eps: float = 1e-6
      # Box pooling.
      roi_size: int = 7
      initializer_range: float = 0.02
  @auto_docstring(checkpoint="facebook/sam3")
  @strict
  class Sam3DETREncoderConfig(PreTrainedConfig):
      r"""
      hidden_dropout (`float`, *optional*, defaults to 0.0):
          Dropout probability applied to hidden states.
      """

      model_type = "sam3_detr_encoder"

      # Layer stack dimensions.
      hidden_size: int = 256
      num_layers: int = 6
      num_attention_heads: int = 8
      intermediate_size: int = 2048
      # Regularization, activation and normalization.
      dropout: float | int = 0.1
      hidden_act: str = "relu"
      hidden_dropout: float | int = 0.0
      layer_norm_eps: float = 1e-6
      initializer_range: float = 0.02
  @auto_docstring(checkpoint="facebook/sam3")
  @strict
  class Sam3DETRDecoderConfig(PreTrainedConfig):
      r"""
      num_queries (`int`, *optional*, defaults to 200):
          Number of object queries used by the decoder.
      """

      model_type = "sam3_detr_decoder"

      # Layer stack dimensions and query count.
      hidden_size: int = 256
      num_layers: int = 6
      num_queries: int = 200
      num_attention_heads: int = 8
      intermediate_size: int = 2048
      # Regularization, activation and normalization.
      dropout: float | int = 0.1
      hidden_act: str = "relu"
      hidden_dropout: float | int = 0.0
      layer_norm_eps: float = 1e-6
      initializer_range: float = 0.02
  @auto_docstring(checkpoint="facebook/sam3")
  @strict
  class Sam3MaskDecoderConfig(PreTrainedConfig):
      r"""
      num_upsampling_stages (`int`, *optional*, defaults to 3):
          How many upsampling stages the pixel decoder (FPN) has.
      """

      model_type = "sam3_mask_decoder"

      hidden_size: int = 256
      num_upsampling_stages: int = 3
      layer_norm_eps: float = 1e-6
      dropout: float | int = 0.0
      num_attention_heads: int = 8
      initializer_range: float = 0.02
  @auto_docstring(checkpoint="facebook/sam3")
  @strict
  class Sam3Config(PreTrainedConfig):
      r"""
      geometry_encoder_config (`dict` or `Sam3GeometryEncoderConfig`, *optional*):
          Configuration of the geometry encoder. A dict is expanded into a `Sam3GeometryEncoderConfig`.
      detr_encoder_config (`dict` or `Sam3DETREncoderConfig`, *optional*):
          Configuration of the DETR encoder. A dict is expanded into a `Sam3DETREncoderConfig`.
      detr_decoder_config (`dict` or `Sam3DETRDecoderConfig`, *optional*):
          Configuration of the DETR decoder. A dict is expanded into a `Sam3DETRDecoderConfig`.
      mask_decoder_config (`dict` or `Sam3MaskDecoderConfig`, *optional*):
          Configuration of the mask decoder. A dict is expanded into a `Sam3MaskDecoderConfig`.

      Example:

      ```python
      >>> from transformers import Sam3Config, Sam3Model

      >>> # Initializing a SAM3 configuration
      >>> configuration = Sam3Config()

      >>> # Initializing a model from the configuration
      >>> model = Sam3Model(configuration)

      >>> # Accessing the model configuration
      >>> configuration = model.config
      ```
      """

      model_type = "sam3"
      is_composition = True
      sub_configs = {
          "vision_config": Sam3VisionConfig,
          "text_config": CLIPTextConfig,
          "geometry_encoder_config": Sam3GeometryEncoderConfig,
          "detr_encoder_config": Sam3DETREncoderConfig,
          "detr_decoder_config": Sam3DETRDecoderConfig,
          "mask_decoder_config": Sam3MaskDecoderConfig,
      }

      vision_config: dict | PreTrainedConfig | None = None
      text_config: dict | PreTrainedConfig | None = None
      geometry_encoder_config: dict | PreTrainedConfig | None = None
      detr_encoder_config: dict | PreTrainedConfig | None = None
      detr_decoder_config: dict | PreTrainedConfig | None = None
      mask_decoder_config: dict | PreTrainedConfig | None = None
      initializer_range: float = 0.02

      def __post_init__(self, **kwargs):
          """Turn every `None` or dict sub-config into a config object, then run the parent hook."""
          # Keyword arguments for the default CLIP text config (used only when text_config is None).
          clip_text_defaults = {
              "vocab_size": 49408,
              "hidden_size": 1024,
              "intermediate_size": 4096,  # hidden_size * mlp_ratio (1024 * 4)
              "projection_dim": 512,  # CLIP's internal projection dimension
              "num_hidden_layers": 24,
              "num_attention_heads": 16,
              "max_position_embeddings": 32,
              "hidden_act": "gelu",
          }
          # (attribute name, config class, kwargs used when the attribute is None),
          # processed in this fixed order.
          sub_config_specs = (
              ("vision_config", Sam3VisionConfig, {}),
              ("text_config", CLIPTextConfig, clip_text_defaults),
              ("geometry_encoder_config", Sam3GeometryEncoderConfig, {}),
              ("detr_encoder_config", Sam3DETREncoderConfig, {}),
              ("detr_decoder_config", Sam3DETRDecoderConfig, {}),
              ("mask_decoder_config", Sam3MaskDecoderConfig, {}),
          )
          for attr_name, config_class, default_kwargs in sub_config_specs:
              current = getattr(self, attr_name)
              if current is None:
                  setattr(self, attr_name, config_class(**default_kwargs))
              elif isinstance(current, dict):
                  setattr(self, attr_name, config_class(**current))
              # Anything else (an existing config object) is left untouched.
          super().__post_init__(**kwargs)

      @property
      def image_size(self):
          """Image size of the SAM3 model (read from the vision config)."""
          return self.vision_config.image_size

      @image_size.setter
      def image_size(self, value):
          """Set the image size on the vision config."""
          self.vision_config.image_size = value
  # Public configuration classes exported by this module (top-level composite first).
  __all__ = [
      "Sam3Config",
      "Sam3ViTConfig",
      "Sam3VisionConfig",
      "Sam3GeometryEncoderConfig",
      "Sam3DETREncoderConfig",
      "Sam3DETRDecoderConfig",
      "Sam3MaskDecoderConfig",
  ]