configuration_pi0.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/pi0/modular_pi0.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_pi0.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from huggingface_hub.dataclasses import strict
  21. from ...configuration_utils import PreTrainedConfig
  22. from ...utils import auto_docstring
  23. from ..auto import CONFIG_MAPPING, AutoConfig
  24. @auto_docstring(checkpoint="lerobot/pi0_base")
  25. @strict
  26. class PI0Config(PreTrainedConfig):
  27. r"""
  28. vlm_config (`dict`, *optional*):
  29. Configuration for the vlm backbone (PaliGemmaModel).
  30. dit_config (`dict`, *optional*):
  31. Configuration for the DiT backbone. Defaults to a Gemma 300M variant.
  32. chunk_size (`int`, *optional*, defaults to 50):
  33. Number of action steps to predict per chunk.
  34. max_state_dim (`int`, *optional*, defaults to 32):
  35. Maximum state vector dimension (shorter vectors are zero-padded).
  36. max_action_dim (`int`, *optional*, defaults to 32):
  37. Maximum action vector dimension (shorter vectors are zero-padded).
  38. num_inference_steps (`int`, *optional*, defaults to 10):
  39. Number of denoising steps during inference.
  40. time_sampling_beta_alpha (`float`, *optional*, defaults to 1.5):
  41. Alpha parameter for Beta distribution used to sample diffusion time during training.
  42. time_sampling_beta_beta (`float`, *optional*, defaults to 1.0):
  43. Beta parameter for Beta distribution used to sample diffusion time during training.
  44. time_sampling_scale (`float`, *optional*, defaults to 0.999):
  45. Scale factor for sampled time values.
  46. time_sampling_offset (`float`, *optional*, defaults to 0.001):
  47. Offset added to sampled time values.
  48. min_period (`float`, *optional*, defaults to 0.004):
  49. Minimum period for sinusoidal time embedding.
  50. max_period (`float`, *optional*, defaults to 4.0):
  51. Maximum period for sinusoidal time embedding.
  52. loss_reduction (`str`, *optional*, defaults to `"mean"`):
  53. The reduction to use on MSE loss.
  54. Example:
  55. ```python
  56. >>> from transformers import PI0ForConditionalGeneration, PI0Config
  57. >>> config = PI0Config()
  58. >>> model = PI0ForConditionalGeneration(config)
  59. ```
  60. """
  61. model_type = "pi0"
  62. sub_configs = {"vlm_config": AutoConfig, "dit_config": AutoConfig}
  63. vlm_config: dict | PreTrainedConfig | None = None
  64. dit_config: dict | PreTrainedConfig | None = None
  65. chunk_size: int = 50
  66. max_state_dim: int = 32
  67. max_action_dim: int = 32
  68. num_inference_steps: int = 10
  69. time_sampling_beta_alpha: float = 1.5
  70. time_sampling_beta_beta: float = 1.0
  71. time_sampling_scale: float = 0.999
  72. time_sampling_offset: float = 0.001
  73. min_period: float = 4e-3
  74. max_period: float = 4.0
  75. loss_reduction: str = "mean"
  76. def __post_init__(self, **kwargs):
  77. if isinstance(self.vlm_config, dict):
  78. vlm_model_type = self.vlm_config.get("model_type", "paligemma")
  79. self.vlm_config = CONFIG_MAPPING[vlm_model_type](**self.vlm_config)
  80. elif self.vlm_config is None:
  81. self.vlm_config = CONFIG_MAPPING["paligemma"](
  82. text_config={
  83. "model_type": "gemma",
  84. "hidden_size": 2048,
  85. "num_hidden_layers": 18,
  86. "intermediate_size": 16384,
  87. "num_attention_heads": 8,
  88. "num_key_value_heads": 1,
  89. "vocab_size": 257152,
  90. },
  91. vision_config={
  92. "model_type": "siglip_vision_model",
  93. "intermediate_size": 4304,
  94. "hidden_size": 1152,
  95. "patch_size": 14,
  96. "image_size": 224,
  97. "num_hidden_layers": 27,
  98. "num_attention_heads": 16,
  99. "vocab_size": 257152,
  100. "vision_use_head": False,
  101. },
  102. projection_dim=2048,
  103. image_token_id=257152,
  104. )
  105. if isinstance(self.dit_config, dict):
  106. dit_model_type = self.dit_config.get("model_type", "gemma")
  107. self.dit_config = CONFIG_MAPPING[dit_model_type](**self.dit_config)
  108. elif self.dit_config is None:
  109. self.dit_config = CONFIG_MAPPING["gemma"](
  110. hidden_size=1024,
  111. num_hidden_layers=18,
  112. intermediate_size=4096,
  113. num_attention_heads=8,
  114. num_key_value_heads=1,
  115. head_dim=256,
  116. vocab_size=self.vlm_config.text_config.vocab_size,
  117. )
  118. # Force bidirectional attention
  119. self.dit_config.is_causal = False
  120. self.dit_config.use_bidirectional_attention = True
  121. self.vlm_config.text_config.use_bidirectional_attention = True
  122. super().__post_init__(**kwargs)
  123. def validate_architecture(self):
  124. """Part of `@strict`-powered validation. Validates the architecture of the config."""
  125. if self.dit_config.hidden_size % 2 != 0:
  126. raise ValueError(f"DiT hidden dim=({self.config.dit_config.hidden_size}) must be divisible by 2")
  127. __all__ = ["PI0Config"]