modeling_prismatic.py
47.8 KB · 1086 lines · python Raw
1 """
2 modeling_prismatic.py
3
4 Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
5 Inherits from the default `transformers.PreTrainedModel`. Meant to be standalone and self-contained,
6 but to exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.
7 """
8
9 import logging
10 from dataclasses import dataclass
11 from functools import partial
12 from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
13
14 import numpy as np
15 import timm
16 import tokenizers
17 import torch
18 import torch.nn as nn
19 import transformers
20 from timm.models.vision_transformer import LayerScale
21 from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
22 from transformers.modeling_outputs import ModelOutput
23
24 from prismatic.training.train_utils import (
25 get_current_action_mask,
26 get_next_actions_mask,
27 )
28 from prismatic.vla.constants import (
29 ACTION_DIM,
30 ACTION_PROPRIO_NORMALIZATION_TYPE,
31 ACTION_TOKEN_BEGIN_IDX,
32 IGNORE_INDEX,
33 NUM_ACTIONS_CHUNK,
34 STOP_INDEX,
35 NormalizationType,
36 )
37
38 from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
39
# Module-scoped logger, named after this module
logger: logging.Logger = logging.getLogger(name=__name__)
42
43
# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """Wrap `fn` so that a tuple result collapses to its first element.

    Non-tuple results are passed through untouched. The vision backbone uses this to turn
    `get_intermediate_layers` (whose result may be a tuple of per-layer outputs) into a drop-in `forward`.
    """

    def _unpacked(*args: Any, **kwargs: Any) -> Any:
        out = fn(*args, **kwargs)
        if isinstance(out, tuple):
            return out[0]
        return out

    return _unpacked
51
52
# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
# =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
# =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Replacement LayerScale forward that reads the renamed `scale_factor` parameter instead of `gamma`.

    Scales `x` in place (returning `x` itself) when `self.inplace` is set, otherwise returns a new tensor.
    """
    if not self.inplace:
        return x * self.scale_factor
    return x.mul_(self.scale_factor)
58
59
def ls_apply_patch(ls_module: LayerScale):
    """Rename a TIMM LayerScale's `gamma` parameter to `scale_factor` and rebind its forward.

    HF Transformers overwrites parameters whose names contain `gamma`, so the value is moved to a new
    parameter name and the instance's `forward` is replaced by `_ls_new_forward`, which reads `scale_factor`.
    """
    original_gamma = ls_module.gamma

    # Register the renamed copy first, then drop the old parameter
    ls_module.scale_factor = nn.Parameter(original_gamma.clone())
    del ls_module.gamma

    # Bind the replacement forward as a method of this particular module instance
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
64
65
# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
class PrismaticVisionBackbone(nn.Module):
    """
    Vision backbone for Prismatic models that handles image feature extraction.

    Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
    For fused backbones, features from both models are concatenated along the feature dimension.
    """

    def __init__(
        self,
        use_fused_vision_backbone: bool,
        image_sizes: List[int],
        timm_model_ids: List[str],
        timm_override_act_layers: List[Optional[str]],
    ) -> None:
        """
        Initialize the vision backbone.

        Args:
            use_fused_vision_backbone: Whether to use two backbones and fuse their features
            image_sizes: List of image sizes for each backbone
            timm_model_ids: List of TIMM model IDs to use for each backbone
            timm_override_act_layers: List of activation layer overrides for each backbone

        Raises:
            ValueError: If more than two backbone IDs are given, or if any per-backbone list has fewer entries
                than the number of backbones in use (1, or 2 when fused).
        """
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.num_images_in_input = 1  # Default value, can be overridden via `set_num_images_in_input`

        # Validate number of (fused) vision backbones
        if len(timm_model_ids) > 2:
            raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")

        # Validate up front so a malformed config fails with a clear message instead of an IndexError below
        num_backbones = 2 if use_fused_vision_backbone else 1
        for field_name, values in (
            ("image_sizes", image_sizes),
            ("timm_model_ids", timm_model_ids),
            ("timm_override_act_layers", timm_override_act_layers),
        ):
            if len(values) < num_backbones:
                raise ValueError(
                    f"`{field_name}` needs {num_backbones} entries "
                    f"(use_fused_vision_backbone={use_fused_vision_backbone}) but got {len(values)}!"
                )

        # Create primary featurizer
        self.featurizer = self._create_featurizer(
            model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
        )
        self.embed_dim = self.featurizer.embed_dim

        # Create secondary featurizer if using fused backbone; its features are concatenated, so dims add up
        if self.use_fused_vision_backbone:
            self.fused_featurizer = self._create_featurizer(
                model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
            )
            self.embed_dim += self.fused_featurizer.embed_dim

        # Patch LayerScale modules for HF compatibility
        self._patch_layer_scales()

    def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
        """
        Create a TIMM-based featurizer model with appropriate configurations.

        Args:
            model_id: The TIMM model ID to load
            img_size: Input image size for the model
            act_layer: Override for the activation layer type

        Returns:
            A configured featurizer whose `forward` returns the second-to-last block's features
        """
        featurizer = timm.create_model(
            model_id,
            pretrained=False,
            num_classes=0,
            img_size=img_size,
            act_layer=act_layer,
        )

        # Monkey-patch the forward function to extract the second-to-last layer features
        num_blocks = len(featurizer.blocks)
        featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))

        return featurizer

    def _patch_layer_scales(self) -> None:
        """
        Patch all LayerScale modules to be compatible with HF's parameter naming.

        HF Transformers overwrites parameters with names containing 'gamma',
        so we need to rename and modify the forward method.
        """
        featurizers = [self.featurizer]
        if self.use_fused_vision_backbone:
            featurizers.append(self.fused_featurizer)

        for featurizer in featurizers:
            for module in featurizer.modules():
                if isinstance(module, LayerScale):
                    ls_apply_patch(module)

    def get_num_patches(self) -> int:
        """
        Returns the number of vision patches output by the vision backbone.

        Returns:
            Number of patches per image (read from the primary featurizer's patch embedding)
        """
        return self.featurizer.patch_embed.num_patches

    def get_num_images_in_input(self) -> int:
        """
        Returns the number of input images for the vision backbone.

        Returns:
            Number of images expected in the input
        """
        return self.num_images_in_input

    def set_num_images_in_input(self, num_images_in_input: int) -> None:
        """
        Sets the number of input images for the vision backbone.

        Args:
            num_images_in_input: Number of images to expect in the input
        """
        self.num_images_in_input = num_images_in_input

    def _featurize_fused_image(self, image: torch.Tensor) -> torch.Tensor:
        """Featurize one 6-channel image with both backbones and concatenate along the hidden dimension."""
        # Split into two stacks of 3 channels: the first for `featurizer`, the second for `fused_featurizer`
        img, img_fused = torch.split(image, [3, 3], dim=1)
        patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
        return torch.cat([patches, patches_fused], dim=2)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Implements the forward pass for the vision backbone.

        If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
        (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).

        Args:
            pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W). With a fused backbone each image
                occupies 6 channels, so C == 6 * num_images_in_input.

        Returns:
            Patch features; images are concatenated along dim 1 and fused features along dim 2
            (presumably (B, num_patches * num_images_in_input, embed_dim) -- shape depends on the featurizers).

        Raises:
            ValueError: If more than one image is expected but the backbone is not fused.
        """
        if not self.use_fused_vision_backbone:
            if self.num_images_in_input != 1:
                raise ValueError("Multi-image inputs require using fused backbone!")
            return self.featurizer(pixel_values)

        if self.num_images_in_input == 1:
            return self._featurize_fused_image(pixel_values)

        # Split `pixel_values` into individual 6-channel images, featurize each, and stack along the patch dim
        images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
        return torch.cat([self._featurize_fused_image(img) for img in images], dim=1)
228
229
# === Prismatic Projector (nn.Module) Definitions ===
class PrismaticProjector(nn.Module):
    """MLP that maps vision patch features (`vision_dim`) into the language model embedding space (`llm_dim`).

    Single backbone: Linear -> GELU -> Linear.
    Fused backbone:  Linear (to 4 * vision_dim) -> GELU -> Linear -> GELU -> Linear.
    Submodule names (fc1/fc2/fc3, act_fn1/act_fn2) define checkpoint keys and must not change.
    """

    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.vision_dim = vision_dim
        self.llm_dim = llm_dim

        if self.use_fused_vision_backbone:
            # Fused features are wider, so expand first before projecting down to the LLM width
            hidden_dim = 4 * vision_dim
            self.fc1 = nn.Linear(self.vision_dim, hidden_dim, bias=True)
            self.fc2 = nn.Linear(hidden_dim, self.llm_dim, bias=True)
            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
            self.act_fn2 = nn.GELU()
        else:
            self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
            self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        """Project `img_patches` (last dim `vision_dim`) to `llm_dim`."""
        hidden = self.act_fn1(self.fc1(img_patches))
        if not self.use_fused_vision_backbone:
            return self.fc2(hidden)
        hidden = self.act_fn2(self.fc2(hidden))
        return self.fc3(hidden)
263
264
# === Main HF Class Definitions ===
@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
    """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""

    # The fields below are populated from the wrapped language model's output
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Additions for VLMs
    # Projected vision patch embeddings (plus any appended proprio / diffusion-timestep tokens) that were spliced
    # into the LLM input; left as `None` when the forward pass did not take the multimodal path
    projector_features: Optional[torch.FloatTensor] = None
278
279
class PrismaticPreTrainedModel(PreTrainedModel):
    """Shared HF `PreTrainedModel` base for Prismatic models: config binding, weight init, and capability flags."""

    config_class: PretrainedConfig = PrismaticConfig
    base_model_prefix: str = "model"
    supports_gradient_checkpointing: bool = True

    _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
    _skip_keys_device_placement: str = "past_key_values"
    _supports_flash_attn_2: bool = True

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize `module` in place with N(0, std) weights, zero biases, and a zeroed padding embedding row.

        `std` comes from `config.initializer_range`, or from `config.text_config.initializer_range` when the
        top-level config has no such attribute.
        """
        # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
        # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
        #    https://github.com/TRI-ML/prismatic-vlms
        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            std = self.config.text_config.initializer_range

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self) -> bool:
        """Whether SDPA attention is supported, delegated to the wrapped language model."""
        return self.language_model._supports_sdpa
315
316
class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
    """
    Prismatic vision-language model: (fused) TIMM vision backbone -> MLP projector -> HF causal language model.

    In the multimodal forward pass, projected image patch embeddings (optionally followed by a proprioceptive
    token and a diffusion-timestep token) are inserted right after the first (<BOS>) token of the text sequence,
    and the embeddings at action-token positions are either replaced by projected noisy actions or zeroed out.
    """

    def __init__(self, config: PrismaticConfig) -> None:
        super().__init__(config)

        # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
        if config.use_fused_vision_backbone is None:
            raise ValueError("Missing config field `use_fused_vision_backbone`")

        if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
            raise NotImplementedError(
                "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
                "if you urgently need support for latest TIMM versions."
            )

        if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
            logger.warning(
                f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
                f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
                f"there might be inference-time regressions due to dependency changes. If in doubt, please "
                f"use the above versions."
            )

        # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
        self.vision_backbone = PrismaticVisionBackbone(
            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
        )

        # Create Multimodal Projector
        self.projector = PrismaticProjector(
            config.use_fused_vision_backbone,
            vision_dim=self.vision_backbone.embed_dim,
            llm_dim=config.text_config.hidden_size,
        )

        # Instantiate LLM Backbone
        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.vocab_size = config.text_config.vocab_size
        self.pad_token_id = config.pad_token_id
        self.llm_dim = config.text_config.hidden_size

        # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
        self.post_init()

    # === `PreTrainedModel` Boilerplate ===
    def get_input_embeddings(self) -> nn.Module:
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        return self.language_model.get_decoder()

    def set_decoder(self, decoder: nn.Module) -> None:
        self.language_model.set_decoder(decoder)

    def tie_weights(self) -> None:
        self.language_model.tie_weights()  # Note: `Llama-2` and `Mistral` don't tie weights (no-op)

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        """Resize the LLM's token embeddings and keep `config.text_config.vocab_size` / `self.vocab_size` in sync."""
        updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # Update config/instance variables
        self.config.text_config.vocab_size = updated_embeddings.num_embeddings
        self.vocab_size = updated_embeddings.num_embeddings

        return updated_embeddings

    def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
        """
        Replace embeddings in input_embeddings at positions where all_actions_mask is True
        with embeddings from noisy_action_features, using vectorized operations.

        Args:
            input_embeddings: Tensor of shape (B, S, D)
            all_actions_mask: Boolean tensor of shape (B, S)
            noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample

        Returns:
            A new tensor; `input_embeddings` itself is not modified
        """
        # Tensor shaped like `input_embeddings` that holds the noisy action features at their target positions
        repositioned_noisy_action_features = torch.zeros_like(input_embeddings)

        # Create batch indices for splicing: (B, K)
        batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
        batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])

        # Action-token positions per sample; stacking requires the same count K in every sample
        masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])

        # Move the noisy action features into their correct positions
        repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features

        # `torch.where` allocates a fresh output, so no defensive clone of `input_embeddings` is needed
        return torch.where(
            all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, input_embeddings
        )

    def _process_action_masks(self, labels):
        """Combine the current-action and next-actions masks derived from `labels` into one (B, seq_len) mask."""
        current_action_mask = get_current_action_mask(labels)
        next_actions_mask = get_next_actions_mask(labels)
        all_actions_mask = current_action_mask | next_actions_mask  # (B, seq_len)
        return all_actions_mask

    def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
        """
        Run the vision backbone and project its patch features into the LLM embedding space.

        With `use_film=True` the backbone is called with `language_embeddings` as a second argument; the plain
        `PrismaticVisionBackbone.forward` only accepts `pixel_values`, so this assumes a FiLM-capable backbone has
        been swapped in (TODO confirm).
        """
        if use_film:
            # FiLM: Infuse language inputs into visual features
            patch_features = self.vision_backbone(pixel_values, language_embeddings)  # (bsz, 256 * num_images, D)
        else:
            patch_features = self.vision_backbone(pixel_values)  # (bsz, 256 * num_images, D)

        # Project patch embeddings into language embedding space
        return self.projector(patch_features)

    def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
        """Project proprioceptive state and append it as one extra token after the vision tokens (if both given)."""
        if proprio_projector is not None and proprio is not None:
            # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
            # proprio: (bsz, proprio_dim) or (proprio_dim,)
            proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1)  # (bsz, proprio_dim)
            proprio_features = proprio_projector(proprio)  # (bsz, llm_dim)
            proprio_features = proprio_features.unsqueeze(dim=1)  # (bsz, 1, llm_dim)
            # For simplicity, just append proprio token to the end of projected vision patch tokens
            return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
        return projected_patch_embeddings

    def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
        """Insert the projected patch embeddings after <BOS>, extending `attention_mask` with ones to match."""
        # Update attention mask
        projected_patch_attention_mask = None
        if attention_mask is not None:
            projected_patch_attention_mask = torch.full(
                (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                fill_value=True,
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )

        # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
        multimodal_embeddings = torch.cat(
            [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
        )

        multimodal_attention_mask = None
        if attention_mask is not None:
            multimodal_attention_mask = torch.cat(
                [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
            )

        return multimodal_embeddings, multimodal_attention_mask

    def _build_multimodal_labels(self, labels, projected_patch_embeddings):
        """Insert IGNORE_INDEX labels after <BOS> for every projected patch token (returns None if no labels)."""
        if labels is not None:
            projected_patch_labels = torch.full(
                (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                fill_value=IGNORE_INDEX,
                dtype=labels.dtype,
                device=labels.device,
            )
            return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
        return None

    # === Core Prismatic VLM `forward()` Logic ===
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        proprio=None,
        proprio_projector=None,
        noisy_actions=None,
        noisy_action_projector=None,
        diffusion_timestep_embeddings=None,
        use_film: bool = False,
    ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
        """
        Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance.

        The call is dispatched on its inputs:
            * `input_ids.shape[1] == 1`    -> cached generation step (batch size 1, `past_key_values` required)
            * `pixel_values is None`       -> language-only forward
            * matching text/image batches  -> multimodal forward; `labels` are used to derive action-token masks

        Raises:
            ValueError: If `input_ids` is missing (every branch reads it), or if the text and image batch sizes
                differ.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_projector_features = output_projector_features if output_projector_features is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
        use_cache = use_cache and not self.training

        # Instantiate Placeholder for Projector Features
        projected_patch_embeddings = None

        # Every branch below reads `input_ids` (the multimodal path embeds it directly), so fail clearly up front
        if input_ids is None:
            raise ValueError(
                "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `input_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )

        # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
        if input_ids.shape[1] == 1:
            assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
            assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
            assert labels is None, "Unexpected key `labels` provided during cached generation!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=None,
                position_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=None,
                labels=None,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Unimodal Forward ===
        elif pixel_values is None:
            assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Multimodal Forward ===
        # (`inputs_embeds` is only inspected when given, so a batch mismatch reaches the ValueError below)
        elif (input_ids.shape[0] == pixel_values.shape[0]) or (
            inputs_embeds is not None and inputs_embeds.shape[0] == pixel_values.shape[0]
        ):
            assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"

            # Get input embeddings (from language model embeddings)
            input_embeddings = self.get_input_embeddings()(input_ids)  # (B, seq_len, D)

            # Extract action masks (computed once; reused below for the action-token replacement)
            all_actions_mask = self._process_action_masks(labels)

            # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
            language_embeddings = input_embeddings[~all_actions_mask].reshape(
                input_embeddings.shape[0], -1, input_embeddings.shape[2]
            )  # (B, lang_seq_len, llm_dim)

            # Get visual features
            projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)

            # Add proprioceptive state if provided
            projected_patch_embeddings = self._process_proprio_features(
                projected_patch_embeddings, proprio, proprio_projector
            )

            # [Diffusion] Add diffusion timestep embedding if provided
            if diffusion_timestep_embeddings is not None:
                # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
                projected_patch_embeddings = torch.cat(
                    (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
                )

            # Process action embeddings
            if noisy_actions is not None:
                # Reshape noisy actions into individual action tokens
                # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
                B = noisy_actions.shape[0]
                noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)

                # Project noisy action tokens into language model embedding space
                noisy_action_features = noisy_action_projector(noisy_actions)  # (B, chunk_len * action_dim, llm_dim)

                # Replace embeddings of the action tokens with noisy action embeddings
                input_embeddings = self._replace_input_embeddings(
                    input_embeddings, all_actions_mask, noisy_action_features
                )
            else:
                # Replace the embeddings of the action tokens with zeros
                # (Later on, the positional embeddings will be added to them)
                input_embeddings = input_embeddings * ~all_actions_mask.unsqueeze(-1)  # mask: (B, seq_len, 1)

            # Build multimodal embeddings & attention mask
            multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
                input_embeddings, projected_patch_embeddings, attention_mask
            )

            # Build labels for multimodal sequence if needed
            multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)

            # Dispatch to language model
            language_model_output = self.language_model(
                input_ids=None,
                attention_mask=multimodal_attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=multimodal_embeddings,
                labels=multimodal_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Otherwise =>> Text and image batch sizes differ ===
        else:
            raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")

        # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
        if not return_dict:
            if output_projector_features and (projected_patch_embeddings is not None):
                return *language_model_output, projected_patch_embeddings

            return language_model_output

        return PrismaticCausalLMOutputWithPast(
            loss=language_model_output.loss,
            logits=language_model_output.logits,
            past_key_values=language_model_output.past_key_values,
            hidden_states=language_model_output.hidden_states,
            attentions=language_model_output.attentions,
            projector_features=projected_patch_embeddings,
        )

    # === GenerationMixin Methods ===
    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Dict[str, torch.Tensor]:
        """
        Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic.

        Returned keys match `forward()`'s parameter names so they can be passed straight through.

        Raises:
            ValueError: If the batch size is greater than 1.
        """
        if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
            (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
        ):
            raise ValueError("Generation with batch size > 1 is not currently supported!")

        # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # Make sure `pixel_values` are preserved in `model_inputs`
        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )

        return model_inputs

    # Defer to Language Model (all handle this differently, with different return types)
    def _reorder_cache(self, *args, **kwargs) -> Any:
        return self.language_model._reorder_cache(*args, **kwargs)
718
719
720 class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
721 config_class: PretrainedConfig = OpenVLAConfig
722
723 def __init__(self, config: OpenVLAConfig) -> None:
724 super().__init__(config)
725 self.norm_stats = config.norm_stats
726
727 # Compute action bins
728 self.bins = np.linspace(-1, 1, config.n_action_bins)
729 self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
730
731 # Compute vocab size for de-tokenization -- revert added "multiple of"
732 self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
733
734 def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
735 """Prepares input for action prediction by adding necessary tokens"""
736 # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
737 placeholder_action_token_ids = (
738 torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
739 )
740 input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
741
742 # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
743 stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
744 input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
745
746 # Extend the attention mask to fit the new shape of input
747 # Note: Only batch size == 1 supported right now
748 mask_extension = (
749 torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
750 .to(attention_mask.device)
751 .to(attention_mask.dtype)
752 )
753 attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
754
755 return input_ids, attention_mask
756
757 def _prepare_labels_for_action_prediction(self, labels, input_ids):
758 """Creates labels tensor for action prediction if not provided"""
759 # Extend labels tensor with fake action labels
760 ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
761 labels_extension = (
762 torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
763 * ARBITRARY_ACTION_TOKEN_IDX
764 )
765 labels = torch.cat([labels, labels_extension], dim=-1)
766
767 # Replace last label token with stop token
768 labels[:, -1] = STOP_INDEX
769
770 return labels
771
def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
    """Map normalized actions back to the dataset's action scale.

    The bounds depend on ``ACTION_PROPRIO_NORMALIZATION_TYPE``:
        - ``BOUNDS`` uses the ``("min", "max")`` statistics;
        - ``BOUNDS_Q99`` uses the ``("q01", "q99")`` statistics.
    The statistics come from ``self.get_action_stats(unnorm_key)``.

    Each value is rescaled as ``0.5 * (x + 1) * (high - low + 1e-8) + low``, so -1 maps to ``low``
    and 1 maps to (approximately) ``high``. Dimensions where the optional ``"mask"`` statistic is
    False are returned unchanged.

    Raises:
        ValueError: If the configured normalization type is neither BOUNDS nor BOUNDS_Q99.
    """
    stats = self.get_action_stats(unnorm_key)

    if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
        low_name, high_name = "min", "max"
    elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
        low_name, high_name = "q01", "q99"
    else:
        raise ValueError("Unsupported action/proprio normalization type detected!")

    # Per-dimension switch; when absent, every dimension is un-normalized.
    unnorm_mask = stats.get("mask", np.ones_like(stats[low_name], dtype=bool))
    high = np.array(stats[high_name])
    low = np.array(stats[low_name])

    rescaled = 0.5 * (normalized_actions + 1) * (high - low + 1e-8) + low
    return np.where(unnorm_mask, rescaled, normalized_actions)
792
def _run_diffusion_prediction(
    self,
    input_embeddings,
    all_actions_mask,
    noise,
    action_head,
    projected_patch_embeddings,
    labels,
    attention_mask,
    NUM_PATCHES,
    NUM_PROMPT_TOKENS,
    noisy_action_projector,
):
    """Generate an action chunk by iteratively denoising ``noise`` with the VLA backbone.

    For each timestep in ``action_head.noise_scheduler.timesteps``:
      1. Encode the timestep with ``action_head.time_encoder``. The result is appended as one extra
         token after the projected vision patches.
      2. Project the current noisy actions with ``noisy_action_projector``. The projections are
         written into the action-token slots of ``input_embeddings``.
      3. Run the language model and slice out the hidden states at the action-token positions.
      4. Ask ``action_head.predict_noise`` for a noise estimate and take one scheduler step.

    Args:
        input_embeddings: Token embeddings of the prompt plus action placeholders.
        all_actions_mask: Mask selecting the action-token positions, passed to
            ``_replace_input_embeddings``.
        noise: Initial noisy actions (the starting state for reverse diffusion).
        action_head: Provides ``noise_scheduler``, ``time_encoder`` and ``predict_noise``.
        projected_patch_embeddings: Vision (and possibly proprio) tokens, before the timestep token.
        labels: Only used for its ``.device``, where the timestep tensor is placed.
        attention_mask: Text attention mask, passed to ``_build_multimodal_attention``.
        NUM_PATCHES: Number of non-text tokens preceding the prompt. It is expected to already count
            the appended timestep token.
        NUM_PROMPT_TOKENS: Number of prompt tokens preceding the action tokens.
        noisy_action_projector: Maps flattened noisy actions into the LLM embedding space.

    Returns:
        Tuple ``(actions, actions_hidden_states)``:
            - ``actions``: the final denoised actions, reshaped to ``(NUM_ACTIONS_CHUNK, ACTION_DIM)``
              as a float32 numpy array. That reshape implies batch size 1.
            - ``actions_hidden_states``: the action-token hidden states from the last denoising step.
    """
    # Pristine copy of the vision tokens; every step appends a fresh timestep token to it.
    base_patch_embeddings = projected_patch_embeddings.clone()
    noisy_actions = noise

    act_start = NUM_PATCHES + NUM_PROMPT_TOKENS
    act_end = act_start + ACTION_DIM * NUM_ACTIONS_CHUNK

    for t in action_head.noise_scheduler.timesteps:
        # Timestep embedding: (B, llm_dim) -> (B, 1, llm_dim), appended after the vision tokens.
        t_tensor = torch.Tensor([t]).to(labels.device)
        t_embedding = action_head.time_encoder(t_tensor).to(noisy_actions.dtype).to(noisy_actions.device)
        step_patch_embeddings = torch.cat((base_patch_embeddings, t_embedding.unsqueeze(1)), dim=1)

        # Flatten to one scalar per action position, then project into the language embedding space.
        batch_size = noisy_actions.shape[0]
        noisy_action_features = noisy_action_projector(noisy_actions.reshape(batch_size, -1).unsqueeze(-1))

        # Swap the action-token embeddings for the projected noisy actions.
        input_embeddings = self._replace_input_embeddings(
            input_embeddings.clone(), all_actions_mask, noisy_action_features
        )

        multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
            input_embeddings, step_patch_embeddings, attention_mask
        )

        lm_output = self.language_model(
            input_ids=None,
            attention_mask=multimodal_attention_mask,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=multimodal_embeddings,
            labels=None,
            use_cache=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )

        # Hidden states of the action tokens from the last layer: (B, act_chunk_len, D).
        actions_hidden_states = lm_output.hidden_states[-1][:, act_start:act_end, :]

        # Predict noise and step the scheduler: x_t -> x_{t-1}.
        noise_pred = action_head.predict_noise(actions_hidden_states)
        noisy_actions = action_head.noise_scheduler.step(noise_pred, t, noisy_actions).prev_sample

    final_actions = noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
    return final_actions.float().cpu().detach().numpy(), actions_hidden_states
876
def _regression_or_discrete_prediction(
    self,
    input_embeddings,
    all_actions_mask,
    projected_patch_embeddings,
    attention_mask,
    labels,
    NUM_PATCHES,
    NUM_PROMPT_TOKENS,
    action_head=None,
):
    """Predict a normalized action chunk in a single language-model forward pass.

    Action-token embeddings are zeroed out before the forward pass. The output is then decoded in one
    of two ways:
        - If ``action_head`` is given, ``action_head.predict_action`` maps the action-token hidden
          states to continuous actions (L1 regression).
        - Otherwise, the argmax logits at the action positions are decoded as discrete action tokens.
          The tokens are counted back from ``self.vocab_size`` and looked up in ``self.bin_centers``.

    Args:
        input_embeddings: Token embeddings of the prompt plus action placeholders.
        all_actions_mask: Boolean mask selecting the action-token positions.
        projected_patch_embeddings: Vision (and possibly proprio) tokens.
        attention_mask: Text attention mask, passed to ``_build_multimodal_attention``.
        labels: Not used by this method.
        NUM_PATCHES: Number of non-text tokens preceding the prompt.
        NUM_PROMPT_TOKENS: Number of prompt tokens preceding the action tokens.
        action_head: Optional L1 regression head.

    Returns:
        Tuple ``(normalized_actions, actions_hidden_states)``. ``normalized_actions`` is a numpy array
        reshaped to ``(NUM_ACTIONS_CHUNK, ACTION_DIM)``, which implies batch size 1.
    """
    act_start = NUM_PATCHES + NUM_PROMPT_TOKENS
    act_end = act_start + ACTION_DIM * NUM_ACTIONS_CHUNK

    # Zero out the action-token embeddings; the mask is broadcast over the hidden dimension.
    keep_mask = ~all_actions_mask.unsqueeze(-1)
    masked_embeddings = input_embeddings * keep_mask

    multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
        masked_embeddings, projected_patch_embeddings, attention_mask
    )

    lm_output = self.language_model(
        input_ids=None,
        attention_mask=multimodal_attention_mask,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=multimodal_embeddings,
        labels=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=True,
        return_dict=True,
    )

    # Last-layer hidden states at the action-token positions: (B, act_chunk_len, D).
    actions_hidden_states = lm_output.hidden_states[-1][:, act_start:act_end, :]

    if action_head is None:
        # Discrete decoding: predicted token id -> bin index (clipped into range) -> bin center.
        token_ids = lm_output.logits[:, act_start:act_end].argmax(dim=2).cpu().numpy()
        bin_indices = np.clip(self.vocab_size - token_ids - 1, 0, self.bin_centers.shape[0] - 1)
        normalized_actions = self.bin_centers[bin_indices].reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
    else:
        # Continuous decoding via the L1 regression head.
        predicted = action_head.predict_action(actions_hidden_states).reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
        normalized_actions = predicted.float().cpu().detach().numpy()

    return normalized_actions, actions_hidden_states
943
def predict_action(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    unnorm_key: Optional[str] = None,
    proprio=None,
    proprio_projector=None,
    action_head=None,
    noisy_action_projector=None,
    use_film: bool = False,
    **kwargs: Any,
) -> Tuple[np.ndarray, torch.Tensor]:
    """Predict actions from an input sequence, with options for different prediction methods.

    The prediction method is chosen as follows:
        - Diffusion, when ``noisy_action_projector`` is set and ``action_head`` has a
          ``noise_scheduler``.
        - L1 regression via ``action_head.predict_action``, when ``action_head`` is given otherwise.
        - Discrete action-token decoding, when ``action_head`` is None.

    Args:
        input_ids: Prompt token ids. Downstream code reshapes outputs to a single action chunk, so
            only batch size 1 is supported.
        unnorm_key: Key into ``self.norm_stats`` selecting the statistics used to un-normalize
            actions. May be None when only one dataset is available.
        proprio: Proprioceptive state; used only together with ``proprio_projector``.
        proprio_projector: Projector for proprioceptive features.
        action_head: Optional head for L1 regression or diffusion-based prediction.
        noisy_action_projector: Projector for noisy actions in diffusion-based prediction.
        use_film: Forwarded to ``_process_vision_features`` to enable FiLM conditioning.
        **kwargs: Must contain ``pixel_values`` and ``attention_mask``.

    Returns:
        Tuple of (unnormalized_actions, action_hidden_states).

    Raises:
        KeyError: If ``pixel_values`` or ``attention_mask`` is missing from ``kwargs``.
    """
    # If the special empty token ('') does not already appear after the colon (':') token in the prompt
    # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time.
    empty_token_id = 29871
    if not torch.all(input_ids[:, -1] == empty_token_id):
        # One empty token per row, created with input_ids' own dtype and device.
        empty_token = input_ids.new_full((input_ids.shape[0], 1), empty_token_id)
        input_ids = torch.cat((input_ids, empty_token), dim=1)

    pixel_values = kwargs["pixel_values"]
    # NOTE: the mask is not extended for the empty token here.
    # `_prepare_input_for_action_prediction` pads it up to the full new input_ids length.
    attention_mask = kwargs["attention_mask"]

    # Create fake labels tensor (needed for action mask)
    labels = input_ids.clone()
    labels[:] = IGNORE_INDEX

    # Number of prompt tokens, excluding the start token.
    # Computed before the action placeholder and stop tokens are appended below.
    NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1

    # Prepare inputs by adding action placeholder tokens and the stop token
    input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)

    # Update labels tensor for action mask computation later
    labels = self._prepare_labels_for_action_prediction(labels, input_ids)

    # Get input embeddings and action masks
    input_embeddings = self.get_input_embeddings()(input_ids)
    all_actions_mask = self._process_action_masks(labels)

    # Extract language (non-action) embeddings for the vision feature processing
    language_embeddings = input_embeddings[~all_actions_mask].reshape(
        input_embeddings.shape[0], -1, input_embeddings.shape[2]
    )

    # Process vision features
    projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)

    # Add proprioceptive features if provided
    use_proprio = proprio_projector is not None and proprio is not None
    if use_proprio:
        proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
        projected_patch_embeddings = self._process_proprio_features(
            projected_patch_embeddings, proprio, proprio_projector
        )

    # Use diffusion if provided, otherwise use regression or discrete prediction
    use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")

    # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
    NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
    if use_proprio:
        NUM_PATCHES += 1
    if use_diffusion:
        NUM_PATCHES += 1

    if use_diffusion:
        # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
        noise = torch.randn(
            size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
        )

        # Run diffusion-based prediction
        normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
            input_embeddings,
            all_actions_mask,
            noise,
            action_head,
            projected_patch_embeddings,
            labels,
            attention_mask,
            NUM_PATCHES,
            NUM_PROMPT_TOKENS,
            noisy_action_projector,
        )
    else:
        # Run regression or discrete token-based prediction
        normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
            input_embeddings,
            all_actions_mask,
            projected_patch_embeddings,
            attention_mask,
            labels,
            NUM_PATCHES,
            NUM_PROMPT_TOKENS,
            action_head,
        )

    # Unnormalize predicted actions
    actions = self._unnormalize_actions(normalized_actions, unnorm_key)

    return actions, actions_hidden_states
1059
@staticmethod
def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
    """Validate and resolve the unnormalization key for action statistics.

    Args:
        norm_stats: Mapping from dataset key to that dataset's statistics.
        unnorm_key: Requested dataset key. If None, ``norm_stats`` must contain exactly one entry,
            and that entry's key is used.

    Returns:
        The resolved key, guaranteed to be present in ``norm_stats``.

    Raises:
        AssertionError: In either of these cases:
            - ``unnorm_key`` is None and ``norm_stats`` does not have exactly one entry;
            - the resolved key is not in ``norm_stats``.
    """
    # Raise explicitly instead of using `assert`: assert statements are stripped under `python -O`.
    # With them stripped, a None key on a multi-dataset model would silently resolve to the first
    # dataset's statistics. AssertionError is kept so existing callers' error handling is unchanged.
    if unnorm_key is None:
        if len(norm_stats) != 1:
            raise AssertionError(
                f"Your model was trained on more than one dataset, "
                f"please pass a `unnorm_key` from the following options to choose the statistics "
                f"used for un-normalizing actions: {norm_stats.keys()}"
            )
        unnorm_key = next(iter(norm_stats.keys()))

    if unnorm_key not in norm_stats:
        raise AssertionError(
            f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
            f"please choose from: {norm_stats.keys()}"
        )
    return unnorm_key
1076
def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
    """Return the dimensionality of the policy's action space.

    Args:
        unnorm_key: Dataset key in ``self.norm_stats``. May be None when only one dataset is present.

    Returns:
        The length of the recorded per-dimension ``"min"`` action statistic for that dataset.
    """
    resolved_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
    action_stats = self.norm_stats[resolved_key]["action"]
    return len(action_stats["min"])
1081
def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
    """Return all logged action statistics for one dataset.

    Args:
        unnorm_key: Dataset key in ``self.norm_stats``. May be None when only one dataset is present.

    Returns:
        The ``"action"`` entry of ``self.norm_stats`` for the resolved key (the stored object itself,
        not a copy).
    """
    dataset_stats = self.norm_stats[self._check_unnorm_key(self.norm_stats, unnorm_key)]
    return dataset_stats["action"]
1086