modeling_pi0.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/pi0/modular_pi0.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_pi0.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. import torch
  23. import torch.nn.functional as F
  24. from torch import nn
  25. from ... import initialization as init
  26. from ...cache_utils import Cache
  27. from ...masking_utils import create_bidirectional_mask
  28. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  29. from ...modeling_utils import PreTrainedModel
  30. from ...utils import auto_docstring, can_return_tuple
  31. from ...utils.generic import maybe_autocast
  32. from ..auto import AutoModel
  33. from .configuration_pi0 import PI0Config
  34. class PI0TimestepEmbeddings(nn.Module):
  35. def __init__(self, config):
  36. super().__init__()
  37. self.config = config
  38. sinusoid_freq = self.compute_freqs(config)
  39. self.register_buffer("sinusoid_freq", sinusoid_freq, persistent=False)
  40. @staticmethod
  41. def compute_freqs(config):
  42. fraction = torch.linspace(0.0, 1.0, config.dit_config.hidden_size // 2, dtype=torch.float32)
  43. period = config.min_period * (config.max_period / config.min_period) ** fraction
  44. sinusoid_freq = 1.0 / period * 2 * math.pi
  45. return sinusoid_freq
  46. def forward(self, time):
  47. device_type = time.device.type if isinstance(time.device.type, str) and time.device.type != "mps" else "cpu"
  48. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  49. sinusoid_freq = self.sinusoid_freq[None, :]
  50. emb = sinusoid_freq * time[:, None]
  51. time_embeds = torch.cat([emb.sin(), emb.cos()], dim=1)
  52. return time_embeds
  53. class PI0ActionTimeEmbedding(nn.Module):
  54. def __init__(self, config):
  55. super().__init__()
  56. self.sinusoid_embeds = PI0TimestepEmbeddings(config)
  57. self.action_in_proj = nn.Linear(config.max_action_dim, config.dit_config.hidden_size)
  58. self.state_proj = nn.Linear(config.max_state_dim, config.dit_config.hidden_size)
  59. self.action_time_mlp_in = nn.Linear(2 * config.dit_config.hidden_size, config.dit_config.hidden_size)
  60. self.action_time_mlp_out = nn.Linear(config.dit_config.hidden_size, config.dit_config.hidden_size)
  61. def forward(self, state, noise, timestep):
  62. state_embeds = self.state_proj(state)
  63. action_embeds = self.action_in_proj(noise)
  64. time_embeds = self.sinusoid_embeds(timestep)
  65. time_embeds = time_embeds[:, None, :].expand_as(action_embeds).to(dtype=action_embeds.dtype)
  66. action_time_embeds = torch.cat([action_embeds, time_embeds], dim=2)
  67. action_time_embeds = self.action_time_mlp_out(F.silu(self.action_time_mlp_in(action_time_embeds)))
  68. action_embeds_merged = torch.cat([state_embeds[:, None, :], action_time_embeds], dim=1)
  69. return action_embeds_merged
  70. @auto_docstring
  71. class PI0PreTrainedModel(PreTrainedModel):
  72. config: PI0Config
  73. base_model_prefix = "model"
  74. main_input_name = "state"
  75. supports_gradient_checkpointing = True
  76. _skip_keys_device_placement = ["past_key_values"]
  77. _supports_flash_attn = True
  78. _supports_sdpa = True
  79. _supports_flex_attn = True
  80. _can_compile_fullgraph = True
  81. _supports_attention_backend = True
  82. input_modalities = ("image", "text")
  83. def _init_weights(self, module):
  84. super()._init_weights(module)
  85. if isinstance(module, PI0TimestepEmbeddings):
  86. init.copy_(module.sinusoid_freq, module.compute_freqs(module.config))
  87. def blockwise_bidirectional_mask(block_boundaries: torch.Tensor) -> Callable:
  88. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  89. q_block = torch.bucketize(q_idx, block_boundaries)
  90. kv_block = torch.bucketize(kv_idx, block_boundaries)
  91. return kv_block <= q_block
  92. return inner_mask
  93. @auto_docstring
  94. class PI0Model(PI0PreTrainedModel):
  95. def __init__(self, config: PI0Config):
  96. super().__init__(config)
  97. self.dit = AutoModel.from_config(config.dit_config)
  98. self.vlm = AutoModel.from_config(config.vlm_config)
  99. self.post_init()
  100. def get_input_embeddings(self):
  101. return self.vlm.get_input_embeddings()
  102. def set_input_embeddings(self, value):
  103. self.vlm.set_input_embeddings(value)
  104. def embed_prefix(self, input_ids, pixel_values, pixel_attention_mask, attention_mask=None):
  105. max_num_cameras = pixel_attention_mask.shape[1]
  106. pixel_values = pixel_values.flatten(0, 1)
  107. image_features = self.vlm.get_image_features(pixel_values).pooler_output
  108. image_features = image_features.reshape(-1, max_num_cameras, image_features.shape[1], image_features.shape[2])
  109. total_image_features = []
  110. for batch_idx, mask in enumerate(pixel_attention_mask):
  111. unpadded_image_features = image_features[batch_idx][mask]
  112. total_image_features.append(unpadded_image_features)
  113. total_image_features = torch.cat(total_image_features, dim=0)
  114. llm_input_ids = input_ids.clone()
  115. llm_input_ids[input_ids == self.config.vlm_config.image_token_id] = 0
  116. inputs_embeds = self.vlm.get_input_embeddings()(llm_input_ids)
  117. special_image_mask = (
  118. (input_ids == self.config.vlm_config.image_token_id)
  119. .unsqueeze(-1)
  120. .expand_as(inputs_embeds)
  121. .to(inputs_embeds.device)
  122. )
  123. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, total_image_features)
  124. return inputs_embeds
  125. @can_return_tuple
  126. @auto_docstring
  127. def forward(
  128. self,
  129. action_embeds: torch.Tensor, # aka `suffix_emb` (noise + state + timestep)
  130. input_ids: torch.Tensor | None = None,
  131. pixel_values: torch.Tensor | None = None,
  132. attention_mask: torch.Tensor | None = None,
  133. pixel_attention_mask: torch.Tensor | None = None,
  134. position_ids: torch.LongTensor | None = None,
  135. inputs_embeds: torch.Tensor | None = None, # aka `prefix_emb` or merged image+text emb
  136. past_key_values: Cache | None = None, # must-have for prefix tuning
  137. **kwargs,
  138. ) -> BaseModelOutputWithPast:
  139. r"""
  140. action_embeds (`torch.Tensor`, *optional*):
  141. The embeddings of input actions and robot states.
  142. pixel_attention_mask (`torch.Tensor`, *optional*):
  143. The mask indicating padded positions in the input image.
  144. """
  145. if pixel_values is not None and past_key_values is None:
  146. if attention_mask is not None and position_ids is None:
  147. position_ids = attention_mask.cumsum(-1) - 1
  148. if inputs_embeds is None:
  149. inputs_embeds = self.embed_prefix(input_ids, pixel_values, pixel_attention_mask)
  150. token_type_ids = torch.zeros_like(inputs_embeds)[:, :, 0]
  151. past_key_values = self.vlm(
  152. inputs_embeds=inputs_embeds,
  153. attention_mask=attention_mask,
  154. position_ids=position_ids,
  155. token_type_ids=token_type_ids,
  156. use_cache=True,
  157. ).past_key_values
  158. if attention_mask is not None and attention_mask.ndim != 2:
  159. raise ValueError("Only two-dimensional attention masks are accepted for now!")
  160. # Merge masks if needed, same for position ids
  161. dit_position_ids = dit_attention_mask = None
  162. if attention_mask is not None:
  163. noise_mask = torch.ones(
  164. action_embeds.shape[0],
  165. action_embeds.shape[1],
  166. dtype=attention_mask.dtype,
  167. device=attention_mask.device,
  168. )
  169. dit_attention_mask = torch.cat([attention_mask, noise_mask], dim=1)
  170. dit_position_ids = (torch.cumsum(dit_attention_mask, dim=1) - 1)[:, -action_embeds.shape[1] :]
  171. # We have three blocks: vlm-inputss, state and actions from which only 1 token is `state`
  172. # The mask should be bidirectional within each block and to prev blocks, but not to next blocks
  173. vlm_input_length = past_key_values.get_seq_length()
  174. block_sizes = torch.tensor([vlm_input_length + 1, action_embeds.shape[1] - 1], device=action_embeds.device)
  175. block_boundaries = torch.cumsum(block_sizes, dim=0) - 1
  176. bidirectional_mask = create_bidirectional_mask(
  177. config=self.config.dit_config,
  178. inputs_embeds=action_embeds,
  179. attention_mask=dit_attention_mask,
  180. past_key_values=past_key_values,
  181. and_mask_function=blockwise_bidirectional_mask(block_boundaries),
  182. )
  183. dit_output = self.dit(
  184. inputs_embeds=action_embeds,
  185. attention_mask=bidirectional_mask,
  186. position_ids=dit_position_ids,
  187. past_key_values=past_key_values,
  188. **kwargs,
  189. )
  190. return dit_output
  191. class PI0ForConditionalGeneration(PI0PreTrainedModel):
  192. """PI0 model with action projection heads and flow matching."""
  193. _tp_plan = {"action_out_proj": "colwise_gather_output"}
  194. def __init__(self, config: PI0Config):
  195. super().__init__(config)
  196. self.model = PI0Model(config)
  197. self.expert_hidden_size = config.dit_config.hidden_size
  198. self.embed_action_time = PI0ActionTimeEmbedding(config)
  199. self.action_out_proj = nn.Linear(self.expert_hidden_size, config.max_action_dim)
  200. self.post_init()
  201. @can_return_tuple
  202. @auto_docstring
  203. def forward(
  204. self,
  205. state: torch.FloatTensor,
  206. noise: torch.FloatTensor | None = None,
  207. timestep: torch.FloatTensor | None = None,
  208. input_ids: torch.Tensor | None = None,
  209. pixel_values: torch.Tensor | None = None,
  210. pixel_attention_mask: torch.BoolTensor | None = None,
  211. attention_mask: torch.Tensor | None = None,
  212. position_ids: torch.LongTensor | None = None,
  213. inputs_embeds: torch.Tensor | None = None,
  214. past_key_values: Cache | None = None,
  215. actions: torch.FloatTensor = None, # aka labels
  216. **kwargs,
  217. ) -> CausalLMOutputWithPast:
  218. r"""
  219. state (`torch.Tensor`, *optional*):
  220. Current robot state.
  221. noise (`torch.Tensor`, *optional*):
  222. Random noise at current timestep that needs to be denoised
  223. timestep (`torch.Tensor`, *optional*):
  224. Current denoising timestep.
  225. pixel_attention_mask (`torch.Tensor`, *optional*):
  226. The mask indicating padded positions in the input image.
  227. actions (`torch.Tensor`, *optional*):
  228. Input actions that need to be predicted. Used only when training to compiute loss.
  229. """
  230. batch_size = state.shape[0]
  231. # 1.Sample the timestep
  232. if timestep is None:
  233. alpha_t = torch.tensor(self.config.time_sampling_beta_alpha, dtype=torch.float32)
  234. beta_t = torch.tensor(self.config.time_sampling_beta_beta, dtype=torch.float32)
  235. dist = torch.distributions.Beta(alpha_t, beta_t)
  236. time_beta = dist.sample((batch_size,)).to(state.device)
  237. timestep = (time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset).float()
  238. # 2. Create random noise if not provided
  239. if noise is None:
  240. noise = torch.randn(
  241. batch_size,
  242. self.config.chunk_size,
  243. self.config.max_action_dim,
  244. device=state.device,
  245. dtype=state.dtype,
  246. )
  247. # 3. If training: merge noise with the ground truth actions (aka labels)
  248. # Target velocity is the label we want to predict and will compute loss upon
  249. if actions is not None:
  250. time_expanded = timestep[:, None, None]
  251. noisy_actions = (time_expanded * noise + (1 - time_expanded) * actions).to(actions.dtype)
  252. target_velocity = noise - actions
  253. else:
  254. noisy_actions = noise
  255. # 4. Embed 'state + noise + actions' for DiT blocks
  256. action_time_embeds = self.embed_action_time(state, noisy_actions, timestep)
  257. outputs = self.model(
  258. input_ids=input_ids,
  259. pixel_values=pixel_values,
  260. attention_mask=attention_mask,
  261. pixel_attention_mask=pixel_attention_mask,
  262. position_ids=position_ids,
  263. inputs_embeds=inputs_embeds,
  264. action_embeds=action_time_embeds,
  265. past_key_values=past_key_values,
  266. **kwargs,
  267. )
  268. last_hidden_states = outputs.last_hidden_state[:, -self.config.chunk_size :]
  269. predicted_velocity = self.action_out_proj(last_hidden_states)
  270. loss = None
  271. if actions is not None:
  272. # Let the users reduce loss themselves and return fine-grained per sample loss
  273. loss = F.mse_loss(target_velocity, predicted_velocity, reduction=self.config.loss_reduction)
  274. return CausalLMOutputWithPast(
  275. loss=loss,
  276. logits=predicted_velocity,
  277. past_key_values=outputs.past_key_values,
  278. hidden_states=outputs.hidden_states,
  279. attentions=outputs.attentions,
  280. )
  281. @torch.no_grad()
  282. def sample_actions(
  283. self,
  284. state: torch.FloatTensor,
  285. input_ids: torch.LongTensor,
  286. pixel_values: torch.FloatTensor,
  287. noise: torch.FloatTensor | None = None,
  288. attention_mask: torch.Tensor | None = None,
  289. pixel_attention_mask: torch.BoolTensor | None = None,
  290. num_steps: int | None = None,
  291. **kwargs,
  292. ) -> torch.FloatTensor:
  293. """Run flow matching inference to generate actions."""
  294. num_steps = num_steps or self.config.num_inference_steps
  295. batch_size = input_ids.shape[0]
  296. device = input_ids.device
  297. # 1. Sample random noise
  298. if noise is None:
  299. noise = torch.normal(
  300. mean=0.0,
  301. std=1.0,
  302. size=(
  303. batch_size,
  304. self.config.chunk_size,
  305. self.config.max_action_dim,
  306. ),
  307. dtype=pixel_values.dtype,
  308. device=device,
  309. )
  310. # 2. Run VLM once and obtain prefix cache. Must infer positions here!
  311. if attention_mask is not None:
  312. position_ids = attention_mask.cumsum(-1) - 1
  313. inputs_embeds = self.model.embed_prefix(input_ids, pixel_values, pixel_attention_mask)
  314. past_key_values = self.model.vlm(
  315. inputs_embeds=inputs_embeds,
  316. attention_mask=attention_mask,
  317. position_ids=position_ids,
  318. use_cache=True,
  319. return_dict=True,
  320. ).past_key_values
  321. prefix_length = past_key_values.get_seq_length()
  322. # 3. Denoise `num_steps` times
  323. dt = -1.0 / num_steps
  324. for step in range(num_steps):
  325. time = 1.0 + step * dt
  326. time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(batch_size)
  327. output = self(
  328. state=state,
  329. noise=noise,
  330. timestep=time_tensor,
  331. pixel_attention_mask=pixel_attention_mask,
  332. attention_mask=attention_mask,
  333. past_key_values=past_key_values,
  334. )
  335. # We need to keep only the "vlm-prefix", no attention to past denoising steps!
  336. past_key_values.crop(prefix_length)
  337. noise = noise + dt * output.logits
  338. return noise
  339. __all__ = ["PI0PreTrainedModel", "PI0Model", "PI0ForConditionalGeneration"]