modular_pi0.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649
  1. # Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PI0 model: PaliGemma + Action Expert with flow matching for robot action prediction."""
  15. import math
  16. from collections.abc import Callable
  17. import numpy as np
  18. import torch
  19. import torch.nn.functional as F
  20. from huggingface_hub.dataclasses import strict
  21. from torch import nn
  22. from ... import initialization as init
  23. from ...cache_utils import Cache
  24. from ...configuration_utils import PreTrainedConfig
  25. from ...feature_extraction_utils import BatchFeature
  26. from ...image_utils import ImageInput, make_nested_list_of_images
  27. from ...masking_utils import create_bidirectional_mask
  28. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  29. from ...modeling_utils import PreTrainedModel
  30. from ...processing_utils import ProcessingKwargs, Unpack
  31. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  32. from ...utils import auto_docstring, can_return_tuple, logging
  33. from ...utils.generic import maybe_autocast
  34. from ...utils.import_utils import requires
  35. from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
  36. from ..paligemma.processing_paligemma import PaligemmaProcessor
  37. from ..siglip.image_processing_siglip import SiglipImageProcessor
# Module-level logger; the processor uses it to warn (once) when called without a text prompt.
logger = logging.get_logger(__name__)
  39. @auto_docstring
  40. class PI0ImageProcessor(SiglipImageProcessor):
  41. size = {"max_height": 224, "max_width": 224}
  42. pad_size = {"height": 224, "width": 224}
  43. do_pad = True
  44. class PI0ProcessorKwargs(ProcessingKwargs, total=False):
  45. _defaults = {
  46. "text_kwargs": {
  47. "padding": "max_length",
  48. "max_length": 48,
  49. "padding_side": "right",
  50. },
  51. "common_kwargs": {"return_tensors": "pt"},
  52. }
  53. @auto_docstring
  54. @requires(backends=("vision", "torch"))
  55. class PI0Processor(PaligemmaProcessor):
  56. def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
  57. self.height, self.width = image_processor.size["height"], image_processor.size["width"]
  58. state_mean = kwargs.get("state_mean", [-0.0419, 0.0354, 0.8257, 2.9083, -0.5562, -0.1665, 0.0283, -0.0286])
  59. state_std = kwargs.get("state_std", [0.1074, 0.1442, 0.2572, 0.3441, 1.2344, 0.3580, 0.0133, 0.0132])
  60. actions_mean = kwargs.get("actions_mean", [0.0182, 0.0586, -0.0559, 0.0046, 0.0029, -0.0077, -0.0916])
  61. actions_std = kwargs.get("actions_std", [0.2825, 0.3590, 0.3674, 0.0377, 0.0543, 0.0872, 0.9958])
  62. self.state_mean = torch.tensor(state_mean)
  63. self.state_std = torch.tensor(state_std)
  64. self.actions_mean = torch.tensor(actions_mean)
  65. self.actions_std = torch.tensor(actions_std)
  66. self.max_state_dim = kwargs.get("max_state_dim", 32)
  67. self.chunk_size = kwargs.get("chunk_size", 50)
  68. super().__init__(image_processor, tokenizer)
  69. def __call__(
  70. self,
  71. images: ImageInput | list[ImageInput] | list[list[ImageInput]] | None,
  72. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] | None = None,
  73. actions: list | np.ndarray | torch.Tensor | None = None,
  74. state: list | np.ndarray | torch.Tensor | None = None,
  75. **kwargs: Unpack[PI0ProcessorKwargs],
  76. ) -> BatchFeature:
  77. r"""
  78. actions (`list | np.ndarray | torch.Tensor`, *optional*):
  79. Actions to be predicted by the model. If provided, padding, mean and std normalization will be applied.
  80. state (`list | np.ndarray | torch.Tensor`, *optional*):
  81. Robotic states to be predicted by the model. If provided, padding, mean and std normalization will be applied.
  82. Returns:
  83. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  84. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
  85. is provided, the `input_ids` will also contain the suffix input ids.
  86. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  87. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  88. `None`).
  89. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  90. - **pixel_attention_mask** -- Pixel values padding mask to be fed to a model. Returned when `images` is not `None`.
  91. - **state** -- Robot state compatible with model if `state` is not None
  92. - **actions** -- Label-actions compatible with training if `actions` is not None
  93. """
  94. output_kwargs = self._merge_kwargs(
  95. PI0ProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs
  96. )
  97. if text is None:
  98. logger.warning_once("You are using PI0 without a text prefix. The processor will use an empty prompt.")
  99. text = ""
  100. if isinstance(text, str):
  101. text = [text]
  102. batched_images = make_nested_list_of_images(images)
  103. if len(batched_images) != len(text):
  104. raise ValueError(
  105. f"Received {len(batched_images)} image samples for {len(text)} prompts. "
  106. "Each prompt should be associated with one sample (with one or more camera images)."
  107. )
  108. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  109. output_kwargs["images_kwargs"].pop("return_tensors", None)
  110. prompt_strings = []
  111. for sample, image_list in zip(text, batched_images):
  112. sample = (
  113. f"{self.image_token * self.image_seq_length * len(image_list)}{self.tokenizer.bos_token}{sample}\n"
  114. )
  115. prompt_strings.append(sample)
  116. text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
  117. # Here is the diff from PaliGemma. Ideally we'd create a new ImageProcessor if it were a VLM
  118. max_num_cameras = max(len(sample_images) for sample_images in batched_images)
  119. pixel_attention_mask = torch.zeros((len(batched_images), max_num_cameras), dtype=torch.bool)
  120. padded_pixel_values = torch.zeros(len(batched_images), max_num_cameras, 3, self.height, self.width)
  121. for batch, sample_images in enumerate(batched_images):
  122. processed = self.image_processor(sample_images, return_tensors="pt", **output_kwargs["images_kwargs"])
  123. num_cameras = len(sample_images)
  124. pixel_attention_mask[batch, :num_cameras] = True
  125. padded_pixel_values[batch, :num_cameras] = processed["pixel_values"]
  126. return_data = {
  127. **text_inputs,
  128. "pixel_values": padded_pixel_values,
  129. "pixel_attention_mask": pixel_attention_mask,
  130. }
  131. if actions is not None:
  132. actions = (torch.tensor(actions) - self.actions_mean) / (self.actions_std + 1e-08)
  133. if actions.shape[-1] < self.max_state_dim:
  134. actions = F.pad(actions, (0, self.max_state_dim - actions.shape[-1]))
  135. return_data["actions"] = actions.view(-1, self.chunk_size, self.max_state_dim)
  136. if state is not None:
  137. state = (torch.tensor(state) - self.state_mean) / (self.state_std + 1e-08)
  138. if state.shape[-1] < self.max_state_dim:
  139. state = F.pad(state, (0, self.max_state_dim - state.shape[-1]))
  140. return_data["state"] = state.view(-1, self.max_state_dim)
  141. return BatchFeature(data=return_data, tensor_type=return_tensors)
  142. @property
  143. def model_input_names(self):
  144. return super().model_input_names + ["pixel_attention_mask"]
  145. @auto_docstring(checkpoint="lerobot/pi0_base")
  146. @strict
  147. class PI0Config(PreTrainedConfig):
  148. r"""
  149. vlm_config (`dict`, *optional*):
  150. Configuration for the vlm backbone (PaliGemmaModel).
  151. dit_config (`dict`, *optional*):
  152. Configuration for the DiT backbone. Defaults to a Gemma 300M variant.
  153. chunk_size (`int`, *optional*, defaults to 50):
  154. Number of action steps to predict per chunk.
  155. max_state_dim (`int`, *optional*, defaults to 32):
  156. Maximum state vector dimension (shorter vectors are zero-padded).
  157. max_action_dim (`int`, *optional*, defaults to 32):
  158. Maximum action vector dimension (shorter vectors are zero-padded).
  159. num_inference_steps (`int`, *optional*, defaults to 10):
  160. Number of denoising steps during inference.
  161. time_sampling_beta_alpha (`float`, *optional*, defaults to 1.5):
  162. Alpha parameter for Beta distribution used to sample diffusion time during training.
  163. time_sampling_beta_beta (`float`, *optional*, defaults to 1.0):
  164. Beta parameter for Beta distribution used to sample diffusion time during training.
  165. time_sampling_scale (`float`, *optional*, defaults to 0.999):
  166. Scale factor for sampled time values.
  167. time_sampling_offset (`float`, *optional*, defaults to 0.001):
  168. Offset added to sampled time values.
  169. min_period (`float`, *optional*, defaults to 0.004):
  170. Minimum period for sinusoidal time embedding.
  171. max_period (`float`, *optional*, defaults to 4.0):
  172. Maximum period for sinusoidal time embedding.
  173. loss_reduction (`str`, *optional*, defaults to `"mean"`):
  174. The reduction to use on MSE loss.
  175. Example:
  176. ```python
  177. >>> from transformers import PI0ForConditionalGeneration, PI0Config
  178. >>> config = PI0Config()
  179. >>> model = PI0ForConditionalGeneration(config)
  180. ```
  181. """
  182. model_type = "pi0"
  183. sub_configs = {"vlm_config": AutoConfig, "dit_config": AutoConfig}
  184. vlm_config: dict | PreTrainedConfig | None = None
  185. dit_config: dict | PreTrainedConfig | None = None
  186. chunk_size: int = 50
  187. max_state_dim: int = 32
  188. max_action_dim: int = 32
  189. num_inference_steps: int = 10
  190. time_sampling_beta_alpha: float = 1.5
  191. time_sampling_beta_beta: float = 1.0
  192. time_sampling_scale: float = 0.999
  193. time_sampling_offset: float = 0.001
  194. min_period: float = 4e-3
  195. max_period: float = 4.0
  196. loss_reduction: str = "mean"
  197. def __post_init__(self, **kwargs):
  198. if isinstance(self.vlm_config, dict):
  199. vlm_model_type = self.vlm_config.get("model_type", "paligemma")
  200. self.vlm_config = CONFIG_MAPPING[vlm_model_type](**self.vlm_config)
  201. elif self.vlm_config is None:
  202. self.vlm_config = CONFIG_MAPPING["paligemma"](
  203. text_config={
  204. "model_type": "gemma",
  205. "hidden_size": 2048,
  206. "num_hidden_layers": 18,
  207. "intermediate_size": 16384,
  208. "num_attention_heads": 8,
  209. "num_key_value_heads": 1,
  210. "vocab_size": 257152,
  211. },
  212. vision_config={
  213. "model_type": "siglip_vision_model",
  214. "intermediate_size": 4304,
  215. "hidden_size": 1152,
  216. "patch_size": 14,
  217. "image_size": 224,
  218. "num_hidden_layers": 27,
  219. "num_attention_heads": 16,
  220. "vocab_size": 257152,
  221. "vision_use_head": False,
  222. },
  223. projection_dim=2048,
  224. image_token_id=257152,
  225. )
  226. if isinstance(self.dit_config, dict):
  227. dit_model_type = self.dit_config.get("model_type", "gemma")
  228. self.dit_config = CONFIG_MAPPING[dit_model_type](**self.dit_config)
  229. elif self.dit_config is None:
  230. self.dit_config = CONFIG_MAPPING["gemma"](
  231. hidden_size=1024,
  232. num_hidden_layers=18,
  233. intermediate_size=4096,
  234. num_attention_heads=8,
  235. num_key_value_heads=1,
  236. head_dim=256,
  237. vocab_size=self.vlm_config.text_config.vocab_size,
  238. )
  239. # Force bidirectional attention
  240. self.dit_config.is_causal = False
  241. self.dit_config.use_bidirectional_attention = True
  242. self.vlm_config.text_config.use_bidirectional_attention = True
  243. super().__post_init__(**kwargs)
  244. def validate_architecture(self):
  245. """Part of `@strict`-powered validation. Validates the architecture of the config."""
  246. if self.dit_config.hidden_size % 2 != 0:
  247. raise ValueError(f"DiT hidden dim=({self.config.dit_config.hidden_size}) must be divisible by 2")
  248. def blockwise_bidirectional_mask(block_boundaries: torch.Tensor) -> Callable:
  249. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  250. q_block = torch.bucketize(q_idx, block_boundaries)
  251. kv_block = torch.bucketize(kv_idx, block_boundaries)
  252. return kv_block <= q_block
  253. return inner_mask
  254. class PI0TimestepEmbeddings(nn.Module):
  255. def __init__(self, config):
  256. super().__init__()
  257. self.config = config
  258. sinusoid_freq = self.compute_freqs(config)
  259. self.register_buffer("sinusoid_freq", sinusoid_freq, persistent=False)
  260. @staticmethod
  261. def compute_freqs(config):
  262. fraction = torch.linspace(0.0, 1.0, config.dit_config.hidden_size // 2, dtype=torch.float32)
  263. period = config.min_period * (config.max_period / config.min_period) ** fraction
  264. sinusoid_freq = 1.0 / period * 2 * math.pi
  265. return sinusoid_freq
  266. def forward(self, time):
  267. device_type = time.device.type if isinstance(time.device.type, str) and time.device.type != "mps" else "cpu"
  268. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  269. sinusoid_freq = self.sinusoid_freq[None, :]
  270. emb = sinusoid_freq * time[:, None]
  271. time_embeds = torch.cat([emb.sin(), emb.cos()], dim=1)
  272. return time_embeds
  273. class PI0ActionTimeEmbedding(nn.Module):
  274. def __init__(self, config):
  275. super().__init__()
  276. self.sinusoid_embeds = PI0TimestepEmbeddings(config)
  277. self.action_in_proj = nn.Linear(config.max_action_dim, config.dit_config.hidden_size)
  278. self.state_proj = nn.Linear(config.max_state_dim, config.dit_config.hidden_size)
  279. self.action_time_mlp_in = nn.Linear(2 * config.dit_config.hidden_size, config.dit_config.hidden_size)
  280. self.action_time_mlp_out = nn.Linear(config.dit_config.hidden_size, config.dit_config.hidden_size)
  281. def forward(self, state, noise, timestep):
  282. state_embeds = self.state_proj(state)
  283. action_embeds = self.action_in_proj(noise)
  284. time_embeds = self.sinusoid_embeds(timestep)
  285. time_embeds = time_embeds[:, None, :].expand_as(action_embeds).to(dtype=action_embeds.dtype)
  286. action_time_embeds = torch.cat([action_embeds, time_embeds], dim=2)
  287. action_time_embeds = self.action_time_mlp_out(F.silu(self.action_time_mlp_in(action_time_embeds)))
  288. action_embeds_merged = torch.cat([state_embeds[:, None, :], action_time_embeds], dim=1)
  289. return action_embeds_merged
  290. @auto_docstring
  291. class PI0PreTrainedModel(PreTrainedModel):
  292. config: PI0Config
  293. base_model_prefix = "model"
  294. main_input_name = "state"
  295. supports_gradient_checkpointing = True
  296. _skip_keys_device_placement = ["past_key_values"]
  297. _supports_flash_attn = True
  298. _supports_sdpa = True
  299. _supports_flex_attn = True
  300. _can_compile_fullgraph = True
  301. _supports_attention_backend = True
  302. input_modalities = ("image", "text")
  303. def _init_weights(self, module):
  304. super()._init_weights(module)
  305. if isinstance(module, PI0TimestepEmbeddings):
  306. init.copy_(module.sinusoid_freq, module.compute_freqs(module.config))
  307. @auto_docstring
  308. class PI0Model(PI0PreTrainedModel):
  309. def __init__(self, config: PI0Config):
  310. super().__init__(config)
  311. self.dit = AutoModel.from_config(config.dit_config)
  312. self.vlm = AutoModel.from_config(config.vlm_config)
  313. self.post_init()
  314. def get_input_embeddings(self):
  315. return self.vlm.get_input_embeddings()
  316. def set_input_embeddings(self, value):
  317. self.vlm.set_input_embeddings(value)
  318. def embed_prefix(self, input_ids, pixel_values, pixel_attention_mask, attention_mask=None):
  319. max_num_cameras = pixel_attention_mask.shape[1]
  320. pixel_values = pixel_values.flatten(0, 1)
  321. image_features = self.vlm.get_image_features(pixel_values).pooler_output
  322. image_features = image_features.reshape(-1, max_num_cameras, image_features.shape[1], image_features.shape[2])
  323. total_image_features = []
  324. for batch_idx, mask in enumerate(pixel_attention_mask):
  325. unpadded_image_features = image_features[batch_idx][mask]
  326. total_image_features.append(unpadded_image_features)
  327. total_image_features = torch.cat(total_image_features, dim=0)
  328. llm_input_ids = input_ids.clone()
  329. llm_input_ids[input_ids == self.config.vlm_config.image_token_id] = 0
  330. inputs_embeds = self.vlm.get_input_embeddings()(llm_input_ids)
  331. special_image_mask = (
  332. (input_ids == self.config.vlm_config.image_token_id)
  333. .unsqueeze(-1)
  334. .expand_as(inputs_embeds)
  335. .to(inputs_embeds.device)
  336. )
  337. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, total_image_features)
  338. return inputs_embeds
  339. @can_return_tuple
  340. @auto_docstring
  341. def forward(
  342. self,
  343. action_embeds: torch.Tensor, # aka `suffix_emb` (noise + state + timestep)
  344. input_ids: torch.Tensor | None = None,
  345. pixel_values: torch.Tensor | None = None,
  346. attention_mask: torch.Tensor | None = None,
  347. pixel_attention_mask: torch.Tensor | None = None,
  348. position_ids: torch.LongTensor | None = None,
  349. inputs_embeds: torch.Tensor | None = None, # aka `prefix_emb` or merged image+text emb
  350. past_key_values: Cache | None = None, # must-have for prefix tuning
  351. **kwargs,
  352. ) -> BaseModelOutputWithPast:
  353. r"""
  354. action_embeds (`torch.Tensor`, *optional*):
  355. The embeddings of input actions and robot states.
  356. pixel_attention_mask (`torch.Tensor`, *optional*):
  357. The mask indicating padded positions in the input image.
  358. """
  359. if pixel_values is not None and past_key_values is None:
  360. if attention_mask is not None and position_ids is None:
  361. position_ids = attention_mask.cumsum(-1) - 1
  362. if inputs_embeds is None:
  363. inputs_embeds = self.embed_prefix(input_ids, pixel_values, pixel_attention_mask)
  364. token_type_ids = torch.zeros_like(inputs_embeds)[:, :, 0]
  365. past_key_values = self.vlm(
  366. inputs_embeds=inputs_embeds,
  367. attention_mask=attention_mask,
  368. position_ids=position_ids,
  369. token_type_ids=token_type_ids,
  370. use_cache=True,
  371. ).past_key_values
  372. if attention_mask is not None and attention_mask.ndim != 2:
  373. raise ValueError("Only two-dimensional attention masks are accepted for now!")
  374. # Merge masks if needed, same for position ids
  375. dit_position_ids = dit_attention_mask = None
  376. if attention_mask is not None:
  377. noise_mask = torch.ones(
  378. action_embeds.shape[0],
  379. action_embeds.shape[1],
  380. dtype=attention_mask.dtype,
  381. device=attention_mask.device,
  382. )
  383. dit_attention_mask = torch.cat([attention_mask, noise_mask], dim=1)
  384. dit_position_ids = (torch.cumsum(dit_attention_mask, dim=1) - 1)[:, -action_embeds.shape[1] :]
  385. # We have three blocks: vlm-inputss, state and actions from which only 1 token is `state`
  386. # The mask should be bidirectional within each block and to prev blocks, but not to next blocks
  387. vlm_input_length = past_key_values.get_seq_length()
  388. block_sizes = torch.tensor([vlm_input_length + 1, action_embeds.shape[1] - 1], device=action_embeds.device)
  389. block_boundaries = torch.cumsum(block_sizes, dim=0) - 1
  390. bidirectional_mask = create_bidirectional_mask(
  391. config=self.config.dit_config,
  392. inputs_embeds=action_embeds,
  393. attention_mask=dit_attention_mask,
  394. past_key_values=past_key_values,
  395. and_mask_function=blockwise_bidirectional_mask(block_boundaries),
  396. )
  397. dit_output = self.dit(
  398. inputs_embeds=action_embeds,
  399. attention_mask=bidirectional_mask,
  400. position_ids=dit_position_ids,
  401. past_key_values=past_key_values,
  402. **kwargs,
  403. )
  404. return dit_output
  405. class PI0ForConditionalGeneration(PI0PreTrainedModel):
  406. """PI0 model with action projection heads and flow matching."""
  407. _tp_plan = {"action_out_proj": "colwise_gather_output"}
  408. def __init__(self, config: PI0Config):
  409. super().__init__(config)
  410. self.model = PI0Model(config)
  411. self.expert_hidden_size = config.dit_config.hidden_size
  412. self.embed_action_time = PI0ActionTimeEmbedding(config)
  413. self.action_out_proj = nn.Linear(self.expert_hidden_size, config.max_action_dim)
  414. self.post_init()
  415. @can_return_tuple
  416. @auto_docstring
  417. def forward(
  418. self,
  419. state: torch.FloatTensor,
  420. noise: torch.FloatTensor | None = None,
  421. timestep: torch.FloatTensor | None = None,
  422. input_ids: torch.Tensor | None = None,
  423. pixel_values: torch.Tensor | None = None,
  424. pixel_attention_mask: torch.BoolTensor | None = None,
  425. attention_mask: torch.Tensor | None = None,
  426. position_ids: torch.LongTensor | None = None,
  427. inputs_embeds: torch.Tensor | None = None,
  428. past_key_values: Cache | None = None,
  429. actions: torch.FloatTensor = None, # aka labels
  430. **kwargs,
  431. ) -> CausalLMOutputWithPast:
  432. r"""
  433. state (`torch.Tensor`, *optional*):
  434. Current robot state.
  435. noise (`torch.Tensor`, *optional*):
  436. Random noise at current timestep that needs to be denoised
  437. timestep (`torch.Tensor`, *optional*):
  438. Current denoising timestep.
  439. pixel_attention_mask (`torch.Tensor`, *optional*):
  440. The mask indicating padded positions in the input image.
  441. actions (`torch.Tensor`, *optional*):
  442. Input actions that need to be predicted. Used only when training to compiute loss.
  443. """
  444. batch_size = state.shape[0]
  445. # 1.Sample the timestep
  446. if timestep is None:
  447. alpha_t = torch.tensor(self.config.time_sampling_beta_alpha, dtype=torch.float32)
  448. beta_t = torch.tensor(self.config.time_sampling_beta_beta, dtype=torch.float32)
  449. dist = torch.distributions.Beta(alpha_t, beta_t)
  450. time_beta = dist.sample((batch_size,)).to(state.device)
  451. timestep = (time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset).float()
  452. # 2. Create random noise if not provided
  453. if noise is None:
  454. noise = torch.randn(
  455. batch_size,
  456. self.config.chunk_size,
  457. self.config.max_action_dim,
  458. device=state.device,
  459. dtype=state.dtype,
  460. )
  461. # 3. If training: merge noise with the ground truth actions (aka labels)
  462. # Target velocity is the label we want to predict and will compute loss upon
  463. if actions is not None:
  464. time_expanded = timestep[:, None, None]
  465. noisy_actions = (time_expanded * noise + (1 - time_expanded) * actions).to(actions.dtype)
  466. target_velocity = noise - actions
  467. else:
  468. noisy_actions = noise
  469. # 4. Embed 'state + noise + actions' for DiT blocks
  470. action_time_embeds = self.embed_action_time(state, noisy_actions, timestep)
  471. outputs = self.model(
  472. input_ids=input_ids,
  473. pixel_values=pixel_values,
  474. attention_mask=attention_mask,
  475. pixel_attention_mask=pixel_attention_mask,
  476. position_ids=position_ids,
  477. inputs_embeds=inputs_embeds,
  478. action_embeds=action_time_embeds,
  479. past_key_values=past_key_values,
  480. **kwargs,
  481. )
  482. last_hidden_states = outputs.last_hidden_state[:, -self.config.chunk_size :]
  483. predicted_velocity = self.action_out_proj(last_hidden_states)
  484. loss = None
  485. if actions is not None:
  486. # Let the users reduce loss themselves and return fine-grained per sample loss
  487. loss = F.mse_loss(target_velocity, predicted_velocity, reduction=self.config.loss_reduction)
  488. return CausalLMOutputWithPast(
  489. loss=loss,
  490. logits=predicted_velocity,
  491. past_key_values=outputs.past_key_values,
  492. hidden_states=outputs.hidden_states,
  493. attentions=outputs.attentions,
  494. )
  495. @torch.no_grad()
  496. def sample_actions(
  497. self,
  498. state: torch.FloatTensor,
  499. input_ids: torch.LongTensor,
  500. pixel_values: torch.FloatTensor,
  501. noise: torch.FloatTensor | None = None,
  502. attention_mask: torch.Tensor | None = None,
  503. pixel_attention_mask: torch.BoolTensor | None = None,
  504. num_steps: int | None = None,
  505. **kwargs,
  506. ) -> torch.FloatTensor:
  507. """Run flow matching inference to generate actions."""
  508. num_steps = num_steps or self.config.num_inference_steps
  509. batch_size = input_ids.shape[0]
  510. device = input_ids.device
  511. # 1. Sample random noise
  512. if noise is None:
  513. noise = torch.normal(
  514. mean=0.0,
  515. std=1.0,
  516. size=(
  517. batch_size,
  518. self.config.chunk_size,
  519. self.config.max_action_dim,
  520. ),
  521. dtype=pixel_values.dtype,
  522. device=device,
  523. )
  524. # 2. Run VLM once and obtain prefix cache. Must infer positions here!
  525. if attention_mask is not None:
  526. position_ids = attention_mask.cumsum(-1) - 1
  527. inputs_embeds = self.model.embed_prefix(input_ids, pixel_values, pixel_attention_mask)
  528. past_key_values = self.model.vlm(
  529. inputs_embeds=inputs_embeds,
  530. attention_mask=attention_mask,
  531. position_ids=position_ids,
  532. use_cache=True,
  533. return_dict=True,
  534. ).past_key_values
  535. prefix_length = past_key_values.get_seq_length()
  536. # 3. Denoise `num_steps` times
  537. dt = -1.0 / num_steps
  538. for step in range(num_steps):
  539. time = 1.0 + step * dt
  540. time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(batch_size)
  541. output = self(
  542. state=state,
  543. noise=noise,
  544. timestep=time_tensor,
  545. pixel_attention_mask=pixel_attention_mask,
  546. attention_mask=attention_mask,
  547. past_key_values=past_key_values,
  548. )
  549. # We need to keep only the "vlm-prefix", no attention to past denoising steps!
  550. past_key_values.crop(prefix_length)
  551. noise = noise + dt * output.logits
  552. return noise
# Names exported by `from ... import *` on this module.
__all__ = [
    "PI0Config",
    "PI0PreTrainedModel",
    "PI0Model",
    "PI0ForConditionalGeneration",
    "PI0Processor",
    "PI0ImageProcessor",
]