| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/pi0/modular_pi0.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_pi0.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from collections.abc import Callable
- import torch
- import torch.nn.functional as F
- from torch import nn
- from ... import initialization as init
- from ...cache_utils import Cache
- from ...masking_utils import create_bidirectional_mask
- from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
- from ...modeling_utils import PreTrainedModel
- from ...utils import auto_docstring, can_return_tuple
- from ...utils.generic import maybe_autocast
- from ..auto import AutoModel
- from .configuration_pi0 import PI0Config
class PI0TimestepEmbeddings(nn.Module):
    """Sinusoidal embedding of the flow-matching timestep.

    Periods are spaced geometrically between `config.min_period` and `config.max_period`.
    The output concatenates sines and cosines, giving `2 * (config.dit_config.hidden_size // 2)`
    features per timestep.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        sinusoid_freq = self.compute_freqs(config)
        # Derived purely from the config, so it is registered as non-persistent (not saved in checkpoints).
        self.register_buffer("sinusoid_freq", sinusoid_freq, persistent=False)

    @staticmethod
    def compute_freqs(config):
        """Return the float32 angular frequencies `2 * pi / period` for each sin/cos pair."""
        fraction = torch.linspace(0.0, 1.0, config.dit_config.hidden_size // 2, dtype=torch.float32)
        period = config.min_period * (config.max_period / config.min_period) ** fraction
        sinusoid_freq = 1.0 / period * 2 * math.pi
        return sinusoid_freq

    def forward(self, time):
        """Embed timesteps of shape `(batch,)` into float32 features of shape `(batch, 2 * (hidden_size // 2))`."""
        # Autocast is switched off for this computation; "mps" is mapped to the "cpu" autocast device type.
        device_type = time.device.type if isinstance(time.device.type, str) and time.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            # Disabling autocast alone does not upcast: the buffer follows the module dtype (e.g. after
            # `.to(torch.bfloat16)`) and `time` may be half precision, so cast both operands explicitly.
            sinusoid_freq = self.sinusoid_freq[None, :].float()
            emb = sinusoid_freq * time[:, None].float()
            time_embeds = torch.cat([emb.sin(), emb.cos()], dim=1)
        return time_embeds
class PI0ActionTimeEmbedding(nn.Module):
    """Project the robot state and noisy actions into the action-expert hidden space.

    The timestep embedding is fused into every action token through a small SiLU MLP.
    The projected state is prepended as the first token of the returned sequence.
    """

    def __init__(self, config):
        super().__init__()
        width = config.dit_config.hidden_size
        self.sinusoid_embeds = PI0TimestepEmbeddings(config)
        self.action_in_proj = nn.Linear(config.max_action_dim, width)
        self.state_proj = nn.Linear(config.max_state_dim, width)
        self.action_time_mlp_in = nn.Linear(2 * width, width)
        self.action_time_mlp_out = nn.Linear(width, width)

    def forward(self, state, noise, timestep):
        """Return `[state_token, fused_action_tokens...]` concatenated along dim 1."""
        state_token = self.state_proj(state).unsqueeze(1)
        noise_tokens = self.action_in_proj(noise)

        # Broadcast the per-sample time embedding over every action token, in the action dtype.
        timestep_tokens = self.sinusoid_embeds(timestep).unsqueeze(1)
        timestep_tokens = timestep_tokens.expand_as(noise_tokens).to(dtype=noise_tokens.dtype)

        hidden = self.action_time_mlp_in(torch.cat((noise_tokens, timestep_tokens), dim=2))
        fused_tokens = self.action_time_mlp_out(F.silu(hidden))
        return torch.cat((state_token, fused_tokens), dim=1)
@auto_docstring
class PI0PreTrainedModel(PreTrainedModel):
    config: PI0Config
    base_model_prefix = "model"
    main_input_name = "state"
    input_modalities = ("image", "text")

    # Feature / attention-backend capability flags read by the base class.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_compile_fullgraph = True

    # Presumably excluded from automatic device placement (the prefix cache) -- confirm against the base class.
    _skip_keys_device_placement = ["past_key_values"]

    def _init_weights(self, module):
        super()._init_weights(module)
        # The sinusoid frequencies live in a non-persistent buffer, so they are recomputed
        # from the config rather than restored from a checkpoint.
        if not isinstance(module, PI0TimestepEmbeddings):
            return
        fresh_freqs = module.compute_freqs(module.config)
        init.copy_(module.sinusoid_freq, fresh_freqs)
def blockwise_bidirectional_mask(block_boundaries: torch.Tensor) -> Callable:
    """Build a mask function that allows attention within a block and to every earlier block.

    `block_boundaries` holds the inclusive last index of each block, in ascending order.
    `torch.bucketize` maps a position to the index of the block containing it.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # Visible iff the key's block does not come after the query's block.
        return torch.bucketize(kv_idx, block_boundaries) <= torch.bucketize(q_idx, block_boundaries)

    return inner_mask
@auto_docstring
class PI0Model(PI0PreTrainedModel):
    # Pairs a vision-language model (`vlm`), which encodes the image+text prefix into a
    # key/value cache, with an action expert (`dit`) that processes the state/action tokens
    # while attending to that cache.

    def __init__(self, config: PI0Config):
        super().__init__(config)
        self.dit = AutoModel.from_config(config.dit_config)
        self.vlm = AutoModel.from_config(config.vlm_config)
        self.post_init()

    def get_input_embeddings(self):
        return self.vlm.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.vlm.set_input_embeddings(value)

    def embed_prefix(self, input_ids, pixel_values, pixel_attention_mask=None, attention_mask=None):
        """Embed the text prefix and scatter image features into the image-token positions.

        Args:
            input_ids: Token ids. Positions equal to `config.vlm_config.image_token_id` are
                placeholders that receive image features.
            pixel_values: Images whose first two dims are (batch, camera). They are flattened
                over those dims before being encoded by the VLM.
            pixel_attention_mask: Per-(batch, camera) mask selecting the real (non-padded)
                images. Only their features are scattered. When `None`, every camera slot is
                treated as real.
            attention_mask: Unused; kept for signature compatibility.

        Returns:
            The prefix embeddings, with image-token rows replaced by image features.
        """
        if pixel_attention_mask is None:
            # No padding information: treat every camera slot as a real image.
            pixel_attention_mask = torch.ones(pixel_values.shape[:2], dtype=torch.bool, device=pixel_values.device)
        # Boolean indexing needs a bool mask; an integer 0/1 mask would be read as gather indices.
        pixel_attention_mask = pixel_attention_mask.to(dtype=torch.bool)
        max_num_cameras = pixel_attention_mask.shape[1]

        image_features = self.vlm.get_image_features(pixel_values.flatten(0, 1)).pooler_output
        image_features = image_features.reshape(-1, max_num_cameras, image_features.shape[1], image_features.shape[2])
        # Keep the real images only, in (batch, camera) order -> (num_real_images, seq, dim).
        image_features = image_features[pixel_attention_mask.to(image_features.device)]

        image_token_mask = input_ids == self.config.vlm_config.image_token_id
        # Placeholder ids are swapped for 0 before the embedding lookup (presumably to keep the
        # lookup in range); those rows are overwritten with image features below.
        llm_input_ids = input_ids.masked_fill(image_token_mask, 0)
        inputs_embeds = self.vlm.get_input_embeddings()(llm_input_ids)

        special_image_mask = image_token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        # `masked_scatter` needs the source to match the destination's dtype and device.
        image_features = image_features.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
        return inputs_embeds

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        action_embeds: torch.Tensor,  # aka `suffix_emb` (noise + state + timestep)
        input_ids: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pixel_attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.Tensor | None = None,  # aka `prefix_emb` or merged image+text emb
        past_key_values: Cache | None = None,  # must-have for prefix tuning
        **kwargs,
    ) -> BaseModelOutputWithPast:
        r"""
        action_embeds (`torch.Tensor`):
            The embeddings of the robot state and the noisy actions, processed by the action expert after the prefix.
        pixel_attention_mask (`torch.Tensor`, *optional*):
            The mask indicating which camera images are real (non-padded) inputs.
        """
        # Validate before the mask is used anywhere (including by the VLM prefix pass).
        if attention_mask is not None and attention_mask.ndim != 2:
            raise ValueError("Only two-dimensional attention masks are accepted for now!")

        # 1. Encode the prefix once with the VLM, unless its cache is already supplied.
        if past_key_values is None:
            if inputs_embeds is None:
                if input_ids is None or pixel_values is None:
                    raise ValueError(
                        "PI0Model needs `past_key_values`, `inputs_embeds`, or both `input_ids` and "
                        "`pixel_values` to build the prefix."
                    )
                inputs_embeds = self.embed_prefix(input_ids, pixel_values, pixel_attention_mask)
            if attention_mask is not None and position_ids is None:
                position_ids = attention_mask.cumsum(-1) - 1
            # All-zero token type ids for every prefix token, shaped (batch, prefix_len).
            token_type_ids = torch.zeros(inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device)
            past_key_values = self.vlm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
                use_cache=True,
            ).past_key_values

        # 2. Extend the padding mask over the suffix tokens and derive their positions.
        dit_position_ids = dit_attention_mask = None
        if attention_mask is not None:
            suffix_mask = attention_mask.new_ones(action_embeds.shape[0], action_embeds.shape[1])
            dit_attention_mask = torch.cat([attention_mask, suffix_mask], dim=1)
            dit_position_ids = (torch.cumsum(dit_attention_mask, dim=1) - 1)[:, -action_embeds.shape[1] :]

        # 3. Block mask over prefix + suffix. Block 0 covers the cached VLM prefix plus the first
        # suffix token (the state); block 1 covers the remaining (action) tokens. Attention is
        # bidirectional within a block and allowed to earlier blocks, but not to later ones.
        vlm_input_length = past_key_values.get_seq_length()
        block_sizes = torch.tensor([vlm_input_length + 1, action_embeds.shape[1] - 1], device=action_embeds.device)
        block_boundaries = torch.cumsum(block_sizes, dim=0) - 1  # inclusive last index of each block
        bidirectional_mask = create_bidirectional_mask(
            config=self.config.dit_config,
            inputs_embeds=action_embeds,
            attention_mask=dit_attention_mask,
            past_key_values=past_key_values,
            and_mask_function=blockwise_bidirectional_mask(block_boundaries),
        )

        # 4. Run the action expert on the suffix, attending to the prefix cache.
        dit_output = self.dit(
            inputs_embeds=action_embeds,
            attention_mask=bidirectional_mask,
            position_ids=dit_position_ids,
            past_key_values=past_key_values,
            **kwargs,
        )
        return dit_output
class PI0ForConditionalGeneration(PI0PreTrainedModel):
    """PI0 model with action projection heads and flow matching."""

    _tp_plan = {"action_out_proj": "colwise_gather_output"}

    def __init__(self, config: PI0Config):
        super().__init__(config)
        self.model = PI0Model(config)
        self.expert_hidden_size = config.dit_config.hidden_size
        self.embed_action_time = PI0ActionTimeEmbedding(config)
        self.action_out_proj = nn.Linear(self.expert_hidden_size, config.max_action_dim)
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        state: torch.FloatTensor,
        noise: torch.FloatTensor | None = None,
        timestep: torch.FloatTensor | None = None,
        input_ids: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_attention_mask: torch.BoolTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        actions: torch.FloatTensor | None = None,  # aka labels
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        state (`torch.FloatTensor`):
            Current robot state.
        noise (`torch.FloatTensor`, *optional*):
            Noise to be denoised into actions. Sampled from a standard normal distribution when not provided.
        timestep (`torch.FloatTensor`, *optional*):
            Current denoising timestep. Sampled from the configured Beta distribution when not provided.
        pixel_attention_mask (`torch.BoolTensor`, *optional*):
            The mask indicating which camera images are real (non-padded) inputs.
        actions (`torch.FloatTensor`, *optional*):
            Ground-truth actions to predict. Used only during training to compute the loss.
        """
        batch_size = state.shape[0]

        # 1. Sample the timestep (scaled and shifted Beta sample) if not provided.
        if timestep is None:
            alpha_t = torch.tensor(self.config.time_sampling_beta_alpha, dtype=torch.float32)
            beta_t = torch.tensor(self.config.time_sampling_beta_beta, dtype=torch.float32)
            dist = torch.distributions.Beta(alpha_t, beta_t)
            time_beta = dist.sample((batch_size,)).to(state.device)
            timestep = (time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset).float()

        # 2. Create random noise if not provided.
        if noise is None:
            noise = torch.randn(
                batch_size,
                self.config.chunk_size,
                self.config.max_action_dim,
                device=state.device,
                dtype=state.dtype,
            )

        # 3. If training: interpolate between the noise and the ground-truth actions (aka labels).
        # The target velocity is what the model learns to predict; the loss is computed on it.
        if actions is not None:
            time_expanded = timestep[:, None, None]
            noisy_actions = (time_expanded * noise + (1 - time_expanded) * actions).to(actions.dtype)
            target_velocity = noise - actions
        else:
            noisy_actions = noise

        # 4. Embed state + noisy actions + timestep for the action expert.
        action_time_embeds = self.embed_action_time(state, noisy_actions, timestep)
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            pixel_attention_mask=pixel_attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            action_embeds=action_time_embeds,
            past_key_values=past_key_values,
            **kwargs,
        )

        # Only the trailing `chunk_size` tokens are projected to actions.
        last_hidden_states = outputs.last_hidden_state[:, -self.config.chunk_size :]
        predicted_velocity = self.action_out_proj(last_hidden_states)

        loss = None
        if actions is not None:
            # Let the users reduce loss themselves and return fine-grained per sample loss.
            # `mse_loss(input, target)`: prediction first, label second.
            loss = F.mse_loss(predicted_velocity, target_velocity, reduction=self.config.loss_reduction)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=predicted_velocity,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def sample_actions(
        self,
        state: torch.FloatTensor,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        noise: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pixel_attention_mask: torch.BoolTensor | None = None,
        num_steps: int | None = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Run flow matching inference to generate actions.

        The VLM prefix is encoded once and cached. The noise is then integrated from t=1 to t=0
        with `num_steps` Euler steps, using the predicted velocity at each step.

        Args:
            state: Current robot state.
            input_ids: Prefix token ids, including image placeholder tokens.
            pixel_values: Camera images for the prefix.
            noise: Initial noise; sampled from a standard normal distribution when `None`.
            attention_mask: Optional 2-D padding mask for the prefix.
            pixel_attention_mask: Mask selecting the real (non-padded) camera images.
            num_steps: Number of denoising steps; a falsy value falls back to
                `config.num_inference_steps`.
            **kwargs: Currently unused.

        Returns:
            The denoised actions, shaped like `noise`.
        """
        num_steps = num_steps or self.config.num_inference_steps
        batch_size = input_ids.shape[0]
        device = input_ids.device

        # 1. Sample random noise.
        if noise is None:
            noise = torch.normal(
                mean=0.0,
                std=1.0,
                size=(
                    batch_size,
                    self.config.chunk_size,
                    self.config.max_action_dim,
                ),
                dtype=pixel_values.dtype,
                device=device,
            )

        # 2. Run the VLM once and obtain the prefix cache. Positions must be inferred here from the
        # padding mask; without a mask, the VLM's default positions are used.
        position_ids = attention_mask.cumsum(-1) - 1 if attention_mask is not None else None
        inputs_embeds = self.model.embed_prefix(input_ids, pixel_values, pixel_attention_mask)
        # Same all-zero token types that `PI0Model.forward` supplies, so the inference prefix is
        # encoded exactly like the training prefix.
        token_type_ids = torch.zeros(inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device)
        past_key_values = self.model.vlm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            use_cache=True,
            return_dict=True,
        ).past_key_values
        prefix_length = past_key_values.get_seq_length()

        # 3. Denoise `num_steps` times (Euler integration from t=1 towards t=0).
        dt = -1.0 / num_steps
        for step in range(num_steps):
            time = 1.0 + step * dt
            time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(batch_size)
            output = self(
                state=state,
                noise=noise,
                timestep=time_tensor,
                pixel_attention_mask=pixel_attention_mask,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
            # Keep only the VLM prefix in the cache; no attention to past denoising steps.
            past_key_values.crop(prefix_length)
            noise = noise + dt * output.logits
        return noise
# Public model classes exported by this module.
__all__ = ["PI0PreTrainedModel", "PI0Model", "PI0ForConditionalGeneration"]
|