1 """
2 modeling_prismatic.py
3
4 Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting
5 from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, but exactly replicate the
6 logic in `prismatic.models.vlms.prismatic.py`.
7
8 Note =>> for the time being, not adding the custom HF "docstring" formatting.
9
10 References [LLaVa, IDEFICS-2]:
11 => https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py
12 => https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py
13 """
14
15 import logging
16 from dataclasses import dataclass
17 from functools import partial
18 from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
19
20 import numpy as np
21 import timm
22 import tokenizers
23 import torch
24 import torch.nn as nn
25 import transformers
26 from timm.models.vision_transformer import LayerScale
27 from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
28 from transformers.modeling_outputs import ModelOutput
29
30 from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
31
32 # Get Logger
33 logger = logging.getLogger(__name__)
34
35
36 # === PyTorch/HuggingFace Default IGNORE_INDEX (for CrossEntropyLoss labels)
37 IGNORE_INDEX = -100
38
39
# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result

    return wrapper


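# Note (sketch) :: timm's `get_intermediate_layers` returns a Tuple of per-layer outputs even when a single
# layer index is requested; wrapping it with `unpack_tuple` lets the patched `forward()` below return a plain
# [bsz, n_patches, embed_dim] Tensor instead. Illustrative only (mirrors the patch applied in __init__):
#   >>> featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={len(featurizer.blocks) - 2}))
#   >>> patches = featurizer(pixel_values)  # Tensor, not Tuple[Tensor]

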
# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
#   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
#   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma


# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
class PrismaticVisionBackbone(nn.Module):
    def __init__(
        self,
        use_fused_vision_backbone: bool,
        image_sizes: List[int],
        timm_model_ids: List[str],
        timm_override_act_layers: List[Optional[str]],
    ) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone

        # [Contract] Validate the number of (fused) vision backbones, then instantiate the "alpha" featurizer
        #   =>> Note :: Monkey-Patch the `forward()` function of the backbone to ensure FSDP-compatibility
        #               Hardcodes `get_intermediate_layers` to return the **SECOND-TO-LAST** layer patches!
        assert len(timm_model_ids) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!"
        self.featurizer = timm.create_model(
            timm_model_ids[0],
            pretrained=False,
            num_classes=0,
            img_size=image_sizes[0],
            act_layer=timm_override_act_layers[0],
        )
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
        )
        self.embed_dim = self.featurizer.embed_dim

        # If `use_fused_vision_backbone` =>> create "beta" featurizer
        if self.use_fused_vision_backbone:
            self.fused_featurizer = timm.create_model(
                timm_model_ids[1],
                pretrained=False,
                num_classes=0,
                img_size=image_sizes[1],
                act_layer=timm_override_act_layers[1],
            )
            self.fused_featurizer.forward = unpack_tuple(
                partial(self.fused_featurizer.get_intermediate_layers, n={len(self.fused_featurizer.blocks) - 2})
            )
            self.embed_dim += self.fused_featurizer.embed_dim

        # Patch `vision_backbone.featurizer` and `vision_backbone.fused_featurizer` with HF-Compatible LayerScale
        for module in self.featurizer.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

        if self.use_fused_vision_backbone:
            for module in self.fused_featurizer.modules():
                if isinstance(module, LayerScale):
                    ls_apply_patch(module)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run image (`pixel_values`) through featurizer; if channel-stacked, dispatch to both and concatenate."""
        if not self.use_fused_vision_backbone:
            return self.featurizer(pixel_values)

        # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> concatenate along embed dim
        img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
        patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)

        return torch.cat([patches, patches_fused], dim=2)


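# Shape sketch (illustrative, not guaranteed for every checkpoint) :: for a fused two-backbone configuration
# at 224x224 resolution, `pixel_values` is [bsz, 6, 224, 224] (two RGB views channel-stacked); each featurizer
# returns [bsz, n_patches, embed_dim_i] patch features, and the fused output is
# [bsz, n_patches, embed_dim_1 + embed_dim_2].

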
# === Prismatic Projector (nn.Module) Definitions ===
class PrismaticProjector(nn.Module):
    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.vision_dim, self.llm_dim = vision_dim, llm_dim

        # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
        if not self.use_fused_vision_backbone:
            self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
            self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
        else:
            initial_projection_dim = 4 * vision_dim
            self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
            self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
            self.act_fn2 = nn.GELU()

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        if not self.use_fused_vision_backbone:
            projected_features = self.fc1(img_patches)
            projected_features = self.act_fn1(projected_features)
            projected_features = self.fc2(projected_features)
        else:
            projected_features = self.fc1(img_patches)
            projected_features = self.act_fn1(projected_features)
            projected_features = self.fc2(projected_features)
            projected_features = self.act_fn2(projected_features)
            projected_features = self.fc3(projected_features)

        return projected_features


# === Main HF Class Definitions ===
@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
    """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Additions for VLMs
    projector_features: Optional[torch.FloatTensor] = None


class PrismaticPreTrainedModel(PreTrainedModel):
    config_class: PretrainedConfig = PrismaticConfig
    base_model_prefix: str = "model"
    supports_gradient_checkpointing: bool = True

    _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
    _skip_keys_device_placement: str = "past_key_values"
    _supports_flash_attn_2: bool = True

    def _init_weights(self, module: nn.Module) -> None:
        # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
        #   => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
        #      https://github.com/TRI-ML/prismatic-vlms
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self) -> bool:
        """Check whether the underlying LLM supports SDPA Attention"""
        return self.language_model._supports_sdpa


class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
    def __init__(self, config: PrismaticConfig) -> None:
        super().__init__(config)

        # [Validation] Lightweight Validation of `config` Fields + Dependency Versions
        if config.use_fused_vision_backbone is None:
            raise ValueError("Missing config field `use_fused_vision_backbone`")

        if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
            raise NotImplementedError(
                "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
                "if you urgently need support for latest TIMM versions."
            )

        if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
            logger.warning(
                f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
                f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
                f"there might be inference-time regressions due to dependency changes. If in doubt, please "
                f"use the above versions."
            )

        # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
        self.vision_backbone = PrismaticVisionBackbone(
            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
        )

        # Create Multimodal Projector
        self.projector = PrismaticProjector(
            config.use_fused_vision_backbone,
            vision_dim=self.vision_backbone.embed_dim,
            llm_dim=config.text_config.hidden_size,
        )

        # Instantiate LLM Backbone
        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.vocab_size = config.text_config.vocab_size
        self.pad_token_id = config.pad_token_id

        # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
        self.post_init()

    # === `PreTrainedModel` Boilerplate ===
    def get_input_embeddings(self) -> nn.Module:
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        return self.language_model.get_decoder()

    def set_decoder(self, decoder: nn.Module) -> None:
        self.language_model.set_decoder(decoder)

    def tie_weights(self) -> None:
        self.language_model.tie_weights()  # Note: `Llama-2` and `Mistral` don't tie weights (no-op)

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # Update config/instance variables
        self.config.text_config.vocab_size = updated_embeddings.num_embeddings
        self.vocab_size = updated_embeddings.num_embeddings

        return updated_embeddings

    # === Core Prismatic VLM `forward()` Logic ===
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
        """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_projector_features = output_projector_features if output_projector_features is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
        use_cache = use_cache and not self.training

        # Instantiate Placeholder for Projector Features
        projected_patch_embeddings = None

        # Note :: We only support forward passes with the following cases:
        #   => Cached Generation :: (input_ids.shape[1] == 1) and (past_key_values is not None)
        #   => Unimodal Forward :: (pixel_values is None)
        #   => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])

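        # Layout sketch (illustrative) :: in the multimodal case, the LLM sees the sequence
        #   [<BOS>] + [patch_1 ... patch_N] + [prompt tokens ...]
        # i.e., projected patch embeddings are spliced in immediately after the first (<BOS>) position, with
        # `attention_mask`/`labels` padded in lockstep (labels over patch positions are set to IGNORE_INDEX).
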
        # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_key_values` ===
        if (input_ids is not None) and (input_ids.shape[1] == 1):
            assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
            assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
            assert labels is None, "Unexpected key `labels` provided during cached generation!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=None,
                position_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=None,
                labels=None,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Unimodal Forward ===
        elif pixel_values is None:
            assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Multimodal Forward ===
        elif ((input_ids is not None) and (input_ids.shape[0] == pixel_values.shape[0])) or (
            (inputs_embeds is not None) and (inputs_embeds.shape[0] == pixel_values.shape[0])
        ):
            assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"

            # Visual Feature Extraction
            patch_features = self.vision_backbone(pixel_values)

            # Projection Logic =>> Update Attention Mask
            projected_patch_embeddings = self.projector(patch_features)
            projected_patch_attention_mask = None
            if attention_mask is not None:
                projected_patch_attention_mask = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

            # Get Input Embeddings (from Language Model Embeddings)
            input_embeddings = self.get_input_embeddings()(input_ids)

            # Build Multimodal Embeddings & Attention Mask =>> Prismatic defaults to inserting after <BOS> token (1:)
            multimodal_embeddings = torch.cat(
                [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
            )
            multimodal_attention_mask = None
            if attention_mask is not None:
                multimodal_attention_mask = torch.cat(
                    [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
                )

            # Build Labels (if specified) =>> Ignore Labels for Patch Embeddings
            multimodal_labels = None
            if labels is not None:
                projected_patch_labels = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=IGNORE_INDEX,
                    dtype=labels.dtype,
                    device=labels.device,
                )
                multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)

            # Dispatch to Language Model
            language_model_output = self.language_model(
                input_ids=None,
                attention_mask=multimodal_attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=multimodal_embeddings,
                labels=multimodal_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Otherwise =>> Assume Invalid! ===
        elif (pixel_values is not None) and (
            ((input_ids is not None) and (input_ids.shape[0] != pixel_values.shape[0]))
            or ((inputs_embeds is not None) and (inputs_embeds.shape[0] != pixel_values.shape[0]))
        ):
            raise ValueError("Non-homogeneous batch of (text, image) input -- forward() does not support mixed batches!")

        else:
            raise ValueError(
                "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `inputs_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )

        # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
        if not return_dict:
            if output_projector_features and (projected_patch_embeddings is not None):
                return *language_model_output, projected_patch_embeddings

            return language_model_output

        return PrismaticCausalLMOutputWithPast(
            loss=language_model_output.loss,
            logits=language_model_output.logits,
            past_key_values=language_model_output.past_key_values,
            hidden_states=language_model_output.hidden_states,
            attentions=language_model_output.attentions,
            projector_features=projected_patch_embeddings,
        )

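    # Example (sketch) :: a single multimodal training-style forward; all shapes/names here are
    # illustrative, not pulled from a specific checkpoint:
    #   >>> out = model(input_ids=ids, attention_mask=mask, pixel_values=px, labels=lbls, return_dict=True)
    #   >>> out.loss.backward()  # patch positions contribute no loss (their labels are IGNORE_INDEX)
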
    # === GenerationMixin Methods ===
    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: str,
    ) -> Dict[str, torch.Tensor]:
        """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
        if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
            (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
        ):
            raise ValueError("Generation with batch size > 1 is not currently supported!")

        # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # Make sure `pixel_values` are preserved in `model_inputs`
        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )

        return model_inputs

    # Defer to Language Model (all handle this differently, with different return types)
    def _reorder_cache(self, *args, **kwargs) -> Any:
        return self.language_model._reorder_cache(*args, **kwargs)


class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
    config_class: PretrainedConfig = OpenVLAConfig

    def __init__(self, config: OpenVLAConfig) -> None:
        super().__init__(config)
        self.norm_stats = config.norm_stats

        # Compute action bins
        self.bins = np.linspace(-1, 1, config.n_action_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        # Compute vocab size for de-tokenization -- strip the extra tokens added via `pad_to_multiple_of`
        self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
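        # Worked example (sketch, assuming the released defaults `n_action_bins=256`, `pad_to_multiple_of=64`) ::
        # `self.bins` has 256 edges =>> 255 bin centers; action tokens occupy the *last* slots of the original
        # (unpadded) vocab, so a generated token id `t` maps to discretized index `self.vocab_size - t - 1`,
        # which (after clipping) indexes into `self.bin_centers` in `predict_action` below.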

    def predict_action(
        self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs: str
    ) -> np.ndarray:
        """Thin wrapper around .generate() that decodes predicted actions and unnormalizes them."""
        # If the special empty token ('') does not already appear after the colon (':') token in the prompt
        # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
        if not torch.all(input_ids[:, -1] == 29871):
            input_ids = torch.cat(
                (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
            )

        # Run VLA inference
        generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)

        # Extract predicted action tokens and translate into (normalized) continuous actions
        predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
        discretized_actions = self.vocab_size - predicted_action_token_ids
        discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
        normalized_actions = self.bin_centers[discretized_actions]

        # Unnormalize actions
        action_norm_stats = self.get_action_stats(unnorm_key)
        mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
        action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
        actions = np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
            normalized_actions,
        )

        return actions

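    # Example (sketch) :: running `predict_action` end-to-end; the prompt format and `unnorm_key` are
    # illustrative -- consult the model card of the checkpoint you actually load:
    #   >>> inputs = processor("In: What action should the robot take to pick up the cup?\nOut:", image)
    #   >>> action = vla.predict_action(**inputs.to(device), unnorm_key="bridge_orig", do_sample=False)
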
    @staticmethod
    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
        if unnorm_key is None:
            assert len(norm_stats) == 1, (
                f"Your model was trained on more than one dataset; "
                f"please pass an `unnorm_key` from the following options to choose the statistics "
                f"used for un-normalizing actions: {norm_stats.keys()}"
            )
            unnorm_key = next(iter(norm_stats.keys()))

        assert unnorm_key in norm_stats, (
            f"The `unnorm_key` you chose is not in the set of available dataset statistics; "
            f"please choose from: {norm_stats.keys()}"
        )
        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Get the dimensionality of the policy's action space."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return len(self.norm_stats[unnorm_key]["action"]["q01"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
        """Get all the logged statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return self.norm_stats[unnorm_key]["action"]
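

# Example usage (sketch) :: loading the HF-ported model via `trust_remote_code`; the repo id, device,
# and dtype below are illustrative, not prescriptive:
#   >>> from transformers import AutoModelForVision2Seq, AutoProcessor
#   >>> processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
#   >>> vla = AutoModelForVision2Seq.from_pretrained(
#   ...     "openvla/openvla-7b", torch_dtype=torch.bfloat16, trust_remote_code=True
#   ... ).to("cuda:0")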