| 1 | """ |
| 2 | modeling_prismatic.py |
| 3 | |
| 4 | Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting |
| 5 | from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, but exactly replicate the |
| 6 | logic in `prismatic.models.vlms.prismatic.py`. |
| 7 | |
| 8 | Note =>> for the time being, not adding the custom HF "docstring" formatting. |
| 9 | |
| 10 | References [LLaVa, IDEFICS-2]: |
| 11 | => https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py |
| 12 | => https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py |
| 13 | """ |

import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union

import numpy as np
import timm
import tokenizers
import torch
import torch.nn as nn
import transformers
from timm.models.vision_transformer import LayerScale
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration_prismatic import OpenVLAConfig, PrismaticConfig

# Get Logger
logger = logging.getLogger(__name__)


# === PyTorch/HuggingFace Default IGNORE_INDEX (for CrossEntropyLoss labels)
IGNORE_INDEX = -100


# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result

    return wrapper
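

# Example (illustrative shapes, not enforced anywhere in this file) =>> with `n={depth - 2}`, TIMM's
# `get_intermediate_layers` returns a 1-tuple of patch features; `unpack_tuple` lets the patched `forward()`
# below yield the bare [bsz, n_patches, embed_dim] tensor instead of a 1-tuple.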


# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
#   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
#   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma
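

# Note (hedged context) =>> HF's checkpoint loading has historically rewritten parameter names containing
# "gamma"/"beta" to "weight"/"bias"; renaming `LayerScale.gamma` to `scale_factor` keeps TIMM weights intact on load.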


# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
class PrismaticVisionBackbone(nn.Module):
    def __init__(
        self,
        use_fused_vision_backbone: bool,
        image_sizes: List[int],
        timm_model_ids: List[str],
        timm_override_act_layers: List[Optional[str]],
    ) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone

        # [Contract] Validate number of (fused) vision backbones, then create and instantiate the "alpha" featurizer
        #   =>> Note :: Monkey-Patch the `forward()` function of the backbone to ensure FSDP-compatibility
        #               Hardcodes `get_intermediate_layers` to return the **SECOND-TO-LAST** layer patches!
        assert len(timm_model_ids) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!"
        self.featurizer = timm.create_model(
            timm_model_ids[0],
            pretrained=False,
            num_classes=0,
            img_size=image_sizes[0],
            act_layer=timm_override_act_layers[0],
        )
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
        )
        self.embed_dim = self.featurizer.embed_dim

        # If `use_fused_vision_backbone` =>> create "beta" featurizer
        if self.use_fused_vision_backbone:
            self.fused_featurizer = timm.create_model(
                timm_model_ids[1],
                pretrained=False,
                num_classes=0,
                img_size=image_sizes[1],
                act_layer=timm_override_act_layers[1],
            )
            self.fused_featurizer.forward = unpack_tuple(
                partial(self.fused_featurizer.get_intermediate_layers, n={len(self.fused_featurizer.blocks) - 2})
            )
            self.embed_dim += self.fused_featurizer.embed_dim

        # Patch `vision_backbone.featurizer` and `vision_backbone.fused_featurizer` with HF-Compatible LayerScale
        for module in self.featurizer.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

        if self.use_fused_vision_backbone:
            for module in self.fused_featurizer.modules():
                if isinstance(module, LayerScale):
                    ls_apply_patch(module)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run image (`pixel_values`) through featurizer; if channel-stacked, then dispatch and channel stack."""
        if not self.use_fused_vision_backbone:
            return self.featurizer(pixel_values)

        # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
        img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
        patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)

        return torch.cat([patches, patches_fused], dim=2)
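

# Shape sketch for the fused path (illustrative numbers) =>> with two 224px backbones emitting 256 patches each,
# `forward()` maps [bsz, 6, 224, 224] -> two [bsz, 256, embed_dim_i] tensors -> one
# [bsz, 256, embed_dim_0 + embed_dim_1] tensor via the `dim=2` concatenation above.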


# === Prismatic Projector (nn.Module) Definitions ===
class PrismaticProjector(nn.Module):
    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.vision_dim, self.llm_dim = vision_dim, llm_dim

        # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
        if not self.use_fused_vision_backbone:
            self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
            self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
        else:
            initial_projection_dim = 4 * vision_dim
            self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
            self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
            self.act_fn2 = nn.GELU()

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        if not self.use_fused_vision_backbone:
            projected_features = self.fc1(img_patches)
            projected_features = self.act_fn1(projected_features)
            projected_features = self.fc2(projected_features)
        else:
            projected_features = self.fc1(img_patches)
            projected_features = self.act_fn1(projected_features)
            projected_features = self.fc2(projected_features)
            projected_features = self.act_fn2(projected_features)
            projected_features = self.fc3(projected_features)

        return projected_features
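

# Dimensionality sketch (illustrative) =>> non-fused: vision_dim -> llm_dim -> llm_dim (2-layer GELU MLP);
# fused: vision_dim -> 4 * vision_dim -> llm_dim -> llm_dim (3-layer GELU MLP over the concatenated patch features).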


# === Main HF Class Definitions ===
@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
    """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Additions for VLMs
    projector_features: Optional[torch.FloatTensor] = None


class PrismaticPreTrainedModel(PreTrainedModel):
    config_class: PretrainedConfig = PrismaticConfig
    base_model_prefix: str = "model"
    supports_gradient_checkpointing: bool = True

    _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
    _skip_keys_device_placement: str = "past_key_values"
    _supports_flash_attn_2: bool = True

    def _init_weights(self, module: nn.Module) -> None:
        # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
        #   => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
        #      https://github.com/TRI-ML/prismatic-vlms
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self) -> bool:
        """Check LLM supports SDPA Attention"""
        return self.language_model._supports_sdpa


class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
    def __init__(self, config: PrismaticConfig) -> None:
        super().__init__(config)

        # [Validation] Lightweight Validation of `config` Fields + Dependency Versions
        if config.use_fused_vision_backbone is None:
            raise ValueError("Missing config field `use_fused_vision_backbone`")

        if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
            raise NotImplementedError(
                "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
                "if you urgently need support for latest TIMM versions."
            )

        if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
            logger.warning(
                f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
                f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
                f"there might be inference-time regressions due to dependency changes. If in doubt, please "
                f"use the above versions."
            )

        # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
        self.vision_backbone = PrismaticVisionBackbone(
            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
        )

        # Create Multimodal Projector
        self.projector = PrismaticProjector(
            config.use_fused_vision_backbone,
            vision_dim=self.vision_backbone.embed_dim,
            llm_dim=config.text_config.hidden_size,
        )

        # Instantiate LLM Backbone
        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.vocab_size = config.text_config.vocab_size
        self.pad_token_id = config.pad_token_id

        # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
        self.post_init()

    # === `PreTrainedModel` Boilerplate ===
    def get_input_embeddings(self) -> nn.Module:
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        return self.language_model.get_decoder()

    def set_decoder(self, decoder: nn.Module) -> None:
        self.language_model.set_decoder(decoder)

    def tie_weights(self) -> None:
        self.language_model.tie_weights()  # Note: `Llama-2` and `Mistral` don't tie weights (no-op)

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # Update config/instance variables
        self.config.text_config.vocab_size = updated_embeddings.num_embeddings
        self.vocab_size = updated_embeddings.num_embeddings

        return updated_embeddings
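
    # Hedged usage sketch (tokenizer and token names are illustrative, not part of this file):
    #   tokenizer.add_special_tokens({"additional_special_tokens": ["<ACT>"]})
    #   model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)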

    # === Core Prismatic VLM `forward()` Logic ===
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
        """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_projector_features = output_projector_features if output_projector_features is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
        use_cache = use_cache and not self.training

        # Instantiate Placeholder for Projector Features
        projected_patch_embeddings = None

        # Note :: We only support forward passes with the following cases:
        #   => Cached Generation :: (input_ids.shape[1] == 1) and (past_key_values is not None)
        #   => Unimodal Forward :: (pixel_values is None)
        #   => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])

        # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_key_values` ===
        if input_ids.shape[1] == 1:
            assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
            assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
            assert labels is None, "Unexpected key `labels` provided during cached generation!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=None,
                position_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=None,
                labels=None,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Unimodal Forward ===
        elif pixel_values is None:
            assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Multimodal Forward ===
        elif (input_ids is not None and input_ids.shape[0] == pixel_values.shape[0]) or (
            inputs_embeds is not None and inputs_embeds.shape[0] == pixel_values.shape[0]
        ):
            assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"

            # Visual Feature Extraction
            patch_features = self.vision_backbone(pixel_values)

            # Projection Logic =>> Update Attention Mask
            projected_patch_embeddings = self.projector(patch_features)
            projected_patch_attention_mask = None
            if attention_mask is not None:
                projected_patch_attention_mask = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

            # Get Input Embeddings (from Language Model Embeddings)
            input_embeddings = self.get_input_embeddings()(input_ids)

            # Build Multimodal Embeddings & Attention Mask =>> Prismatic defaults to inserting after <BOS> token (1:)
            multimodal_embeddings = torch.cat(
                [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
            )
            multimodal_attention_mask = None
            if attention_mask is not None:
                multimodal_attention_mask = torch.cat(
                    [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
                )

            # Build Labels (if specified) =>> Ignore Labels for Patch Embeddings
            multimodal_labels = None
            if labels is not None:
                projected_patch_labels = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=IGNORE_INDEX,
                    dtype=labels.dtype,
                    device=labels.device,
                )
                multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)

            # Dispatch to Language Model
            language_model_output = self.language_model(
                input_ids=None,
                attention_mask=multimodal_attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=multimodal_embeddings,
                labels=multimodal_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
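
            # Sequence-length sketch (illustrative) =>> with P projected patches and a T-token prompt, the LLM sees
            # 1 (<BOS>) + P + (T - 1) = T + P positions; patch positions carry IGNORE_INDEX labels, so any loss is
            # computed over text tokens only.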

        # === Otherwise =>> Assume Invalid! ===
        elif (input_ids is not None and input_ids.shape[0] != pixel_values.shape[0]) or (
            inputs_embeds is not None and inputs_embeds.shape[0] != pixel_values.shape[0]
        ):
            raise ValueError("Non-homogeneous batch of (text, image) input -- forward() does not support mixed batches!")

        else:
            raise ValueError(
                "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `inputs_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )

        # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
        if not return_dict:
            if output_projector_features and (projected_patch_embeddings is not None):
                return *language_model_output, projected_patch_embeddings

            return language_model_output

        return PrismaticCausalLMOutputWithPast(
            loss=language_model_output.loss,
            logits=language_model_output.logits,
            past_key_values=language_model_output.past_key_values,
            hidden_states=language_model_output.hidden_states,
            attentions=language_model_output.attentions,
            projector_features=projected_patch_embeddings,
        )

    # === GenerationMixin Methods ===
    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: str,
    ) -> Dict[str, torch.Tensor]:
        """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
        if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
            (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
        ):
            raise ValueError("Generation with batch size > 1 is not currently supported!")

        # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # Make sure `pixel_values` are preserved in `model_inputs`
        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )

        return model_inputs

    # Defer to Language Model (all handle this differently, with different return types)
    def _reorder_cache(self, *args, **kwargs) -> Any:
        return self.language_model._reorder_cache(*args, **kwargs)


class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
    config_class: PretrainedConfig = OpenVLAConfig

    def __init__(self, config: OpenVLAConfig) -> None:
        super().__init__(config)
        self.norm_stats = config.norm_stats

        # Compute action bins
        self.bins = np.linspace(-1, 1, config.n_action_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        # Compute vocab size for de-tokenization -- revert added "multiple of"
        self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
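
    # Worked example (assuming the OpenVLA default `n_action_bins=256`) =>> `self.bins` holds 256 uniform edges in
    # [-1, 1], giving 255 `bin_centers`; `predict_action` below maps a predicted token ID back to a center via
    # `vocab_size - token_id - 1` (clipped to a valid bin index).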

    def predict_action(
        self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs: str
    ) -> np.ndarray:
        """Thin wrapper around .generate() that decodes predicted actions and unnormalizes them."""
        # If the special empty token ('') does not already appear after the colon (':') token in the prompt
        # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
        if not torch.all(input_ids[:, -1] == 29871):
            input_ids = torch.cat(
                (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
            )

        # Run VLA inference
        generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)

        # Extract predicted action tokens and translate into (normalized) continuous actions
        predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
        discretized_actions = self.vocab_size - predicted_action_token_ids
        discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
        normalized_actions = self.bin_centers[discretized_actions]

        # Unnormalize actions
        action_norm_stats = self.get_action_stats(unnorm_key)
        mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
        action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
        actions = np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
            normalized_actions,
        )

        return actions
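
    # Hedged usage sketch (model ID and prompt are illustrative; mirrors the OpenVLA README pattern):
    #   processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
    #   vla = AutoModelForVision2Seq.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
    #   inputs = processor("In: What action should the robot take to pick up the cup?\nOut:", image)
    #   action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)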

    @staticmethod
    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
        if unnorm_key is None:
            assert len(norm_stats) == 1, (
                f"Your model was trained on more than one dataset; "
                f"please pass an `unnorm_key` from the following options to choose the statistics "
                f"used for un-normalizing actions: {norm_stats.keys()}"
            )
            unnorm_key = next(iter(norm_stats.keys()))

        assert unnorm_key in norm_stats, (
            f"The `unnorm_key` you chose is not in the set of available dataset statistics; "
            f"please choose from: {norm_stats.keys()}"
        )
        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Get the dimensionality of the policy's action space."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return len(self.norm_stats[unnorm_key]["action"]["q01"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
        """Get all the logged statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return self.norm_stats[unnorm_key]["action"]