modeling_locateanything.py
| 1 | # -------------------------------------------------------- |
| 2 | # NVIDIA |
| 3 | # Copyright (c) 2025 NVIDIA |
| 4 | # Licensed under The MIT License [see LICENSE for details] |
| 5 | # -------------------------------------------------------- |
| 6 | |
| 7 | import time |
| 8 | from typing import List, Optional, Tuple, Union |
| 9 | |
| 10 | import numpy as np |
| 11 | import torch |
| 12 | from torch import nn |
| 13 | from torch.nn import CrossEntropyLoss |
| 14 | from transformers.generation import GenerationMixin |
| 15 | from transformers.modeling_outputs import CausalLMOutputWithPast |
| 16 | from transformers.modeling_utils import PreTrainedModel |
| 17 | from transformers.utils import add_start_docstrings, is_flash_attn_2_available, logging |
| 18 | from peft import LoraConfig, get_peft_model |
| 19 | |
| 20 | from .configuration_locateanything import LocateAnythingConfig |
| 21 | from .modeling_qwen2 import Qwen2ForCausalLM |
| 22 | from .modeling_vit import MoonVitPretrainedModel |
| 23 | from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM |
| 24 | from .mask_sdpa_utils import * |
| 25 | from .mask_magi_utils import * |
| 26 | from .configuration_qwen2 import Qwen2Config |
| 27 | |
| 28 | from .generate_utils import ( |
| 29 | sample_tokens, |
| 30 | handle_pattern, |
| 31 | get_token_ids_from_config, |
| 32 | ) |
| 33 | |
# Module-level logger obtained from ``transformers.utils.logging``.
logger = logging.get_logger(__name__)


# Shared model-card preamble. It is passed to ``add_start_docstrings`` when
# decorating ``LocateAnythingPreTrainedModel``.
LOCATEANYTHING_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)
    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.
    Parameters:
        config ([`LocateAnythingConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
| 50 | |
@add_start_docstrings(
    "The bare LocateAnything Model outputting raw hidden-states without any specific head on top.",
    LOCATEANYTHING_START_DOCSTRING,
)
class LocateAnythingPreTrainedModel(PreTrainedModel):
    # Base class that plugs LocateAnything into the ``PreTrainedModel`` machinery.
    # Besides the usual capability flags, it teaches the attention-implementation
    # hooks to accept the custom 'magi' backend verbatim.
    config_class = LocateAnythingConfig
    base_model_prefix = "model"
    main_input_name = 'input_ids'
    supports_gradient_checkpointing = True
    # Default for Qwen2 LMs; LocateAnythingForConditionalGeneration overrides this
    # per instance when the LM is Qwen3.
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_static_cache = True
    _supports_quantized_cache = True
    _supports_sdpa = True

    @classmethod
    def _autoset_attn_implementation(cls, config, *args, **kwargs):
        """Return ``config`` untouched when it requests 'magi'; otherwise defer to transformers."""
        if getattr(config, '_attn_implementation', None) == 'magi':
            return config
        return super()._autoset_attn_implementation(config, *args, **kwargs)

    def _check_and_adjust_attn_implementation(self, attn_implementation, is_init_check=False):
        """Accept 'magi' as-is; every other backend goes through the transformers checks."""
        if attn_implementation == "magi":
            return "magi"
        return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        Linear/Conv2d/Embedding weights are drawn from N(0, std). ``std`` is the
        top-level ``initializer_range`` if set (and truthy), else the text config's.
        Biases and the padding embedding row are zeroed. LayerNorm gets the
        identity affine (weight=1, bias=0).
        """
        std = getattr(self.config, 'initializer_range', None) or self.config.text_config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # The projector (mlp1) starts with a LayerNorm. Without this branch,
            # LayerNorm params initialised through this hook would be left as-is.
            # weight/bias are None when elementwise_affine=False.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
| 89 | |
| 90 | |
class LocateAnythingForConditionalGeneration(LocateAnythingPreTrainedModel, GenerationMixin):
    # Vision-language model. A vision tower (MoonViT when built from config)
    # produces features that ``mlp1`` projects into the embedding space of a
    # Qwen2/Qwen3 causal LM. ``generate`` implements a custom decoding loop that
    # mixes multi-token prediction (MTP) with plain auto-regressive (AR) decoding.
    config_class = LocateAnythingConfig

    def __init__(self, config: LocateAnythingConfig, vision_model=None, language_model=None):
        """Build the vision tower, language model, projector and optional LoRA adapters.

        Args:
            config: Composite config holding ``vision_config`` and ``text_config``.
                Note that ``_attn_implementation`` is written back into both sub-configs.
            vision_model: Optional pre-built vision tower. When ``None``, a MoonViT
                model is created from ``config.vision_config``.
            language_model: Optional pre-built LM. When ``None``, a Qwen2 or Qwen3
                causal LM is created according to ``config.text_config.architectures[0]``.

        Raises:
            ValueError: If the vision model type or the LM architecture is unsupported.
        """
        super().__init__(config)

        self.template = config.template
        self.mlp_checkpoint = config.mlp_checkpoint

        logger.info(f'mlp_checkpoint: {self.mlp_checkpoint}')
        if vision_model is not None:
            self.vision_model = vision_model
        elif config.vision_config.model_type == 'moonvit':
            # Prefer flash-attention for the vision tower, but degrade to SDPA when
            # flash_attn is not installed.
            vision_attn_impl = getattr(config.vision_config, '_attn_implementation', None) or 'flash_attention_2'
            if vision_attn_impl == 'flash_attention_2' and not is_flash_attn_2_available():
                logger.warning_once(
                    "flash_attn is not available for MoonViT inference; falling back to sdpa."
                )
                vision_attn_impl = 'sdpa'
            config.vision_config._attn_implementation = vision_attn_impl
            self.vision_model = MoonVitPretrainedModel(config.vision_config)
        else:
            raise ValueError(f'Unsupported vision model type: {config.vision_config.model_type}. Only moonvit is supported.')

        # LM attention backend resolution order: text_config, then the top-level
        # config, then the custom 'magi' backend.
        text_attn_impl = (
            getattr(config.text_config, '_attn_implementation', None)
            or getattr(config, '_attn_implementation', None)
            or 'magi'
        )
        config.text_config._attn_implementation = text_attn_impl

        if language_model is not None:
            self.language_model = language_model
        else:
            llm_arch = config.text_config.architectures[0]
            if llm_arch == 'Qwen2ForCausalLM':
                self.language_model = Qwen2ForCausalLM(config.text_config)
            elif llm_arch == 'Qwen3ForCausalLM':
                self.language_model = Qwen3ForCausalLM(config.text_config)
            else:
                raise ValueError(f'Unsupported language model architecture: {llm_arch}. Only Qwen2ForCausalLM and Qwen3ForCausalLM are supported.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        # MLP for moonvit (without pixel_shuffle_back, direct mapping).
        # The input width is 4 * vit_hidden_size. Presumably the vision tower emits
        # 2x2-merged patch features; TODO confirm against modeling_vit.
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * 4),
            nn.Linear(vit_hidden_size * 4, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )
        self.image_token_index = config.image_token_index
        self.neftune_alpha = None

        # use_backbone_lora / use_llm_lora double as the LoRA rank (alpha = 2 * rank).
        # A falsy value disables the adapter.
        if config.use_backbone_lora:
            self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)

        self.use_llm_lora = config.use_llm_lora
        if config.use_llm_lora:
            self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)

        self.token_ids = get_token_ids_from_config(config)

        # Set _no_split_modules dynamically based on the actual LLM architecture.
        text_archs = getattr(config.text_config, 'architectures', None)
        arch = text_archs[0] if text_archs else 'Qwen2ForCausalLM'
        if 'Qwen3' in arch:
            self._no_split_modules = ["Qwen3DecoderLayer"]
        else:
            self._no_split_modules = ["Qwen2DecoderLayer"]

    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Wrap the vision tower with PEFT LoRA on its attention and MLP projections."""
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj',
                            'mlp.fc1', 'mlp.fc2'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.vision_model = get_peft_model(self.vision_model, lora_config)
        self.vision_model.print_trainable_parameters()

    def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Wrap the language model with PEFT LoRA (CAUSAL_LM task) and mark ``use_llm_lora``."""
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                            'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            task_type='CAUSAL_LM'
        )
        self.language_model = get_peft_model(self.language_model, lora_config)
        # Input embeddings must require grad so gradients reach the adapters
        # when the base weights are frozen.
        self.language_model.enable_input_require_grads()
        self.language_model.print_trainable_parameters()
        self.use_llm_lora = True

    def forward(
        self,
        pixel_values: List[torch.FloatTensor],
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_grid_hws: Optional[torch.Tensor] = None,
        image_flags: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Teacher-forced forward pass with vision features spliced into the token embeddings.

        The embeddings of image-placeholder tokens (``input_ids == self.image_token_index``)
        are replaced, in order, by projected vision features. The result is fed to
        the language model.

        Args:
            pixel_values: Vision-tower input. May be ``None`` for text-only batches.
            input_ids: Token ids. Their embeddings are unpacked as ``(B, N, C)``.
            image_grid_hws: Grid sizes forwarded to the vision tower.
            image_flags: Per-image flags. A non-zero value ``k`` keeps the next ``k``
                feature entries; a zero value skips one entry (presumably a
                placeholder image — confirm with the data pipeline).
            labels: Optional targets for a shifted next-token cross-entropy loss.
            Remaining arguments are forwarded to the language model.

        Returns:
            ``CausalLMOutputWithPast``, or a tuple ``(loss?, logits, ...)`` when
            ``return_dict`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        has_images = image_flags is not None and image_flags.sum() > 0

        # Text-only batches may omit pixel_values; don't run the vision tower on None.
        vit_embeds = self.extract_feature(pixel_values, image_grid_hws) if pixel_values is not None else None

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)
        selected = input_ids.reshape(B * N) == self.image_token_index

        if has_images:
            filtered_vit_embeds = []
            idx = 0
            for flag in image_flags:
                flag_val = flag.item()
                if flag_val != 0:
                    filtered_vit_embeds.extend(vit_embeds[idx:idx + flag_val])
                    idx += flag_val
                else:
                    idx += 1

            vit_embeds = self.mlp1(torch.cat(filtered_vit_embeds, dim=0))
            # The "* 0.0 +" form keeps the original embedding rows in the autograd
            # graph while overwriting their values (presumably intentional; kept as-is).
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:selected.sum()]
        elif vit_embeds is not None and len(vit_embeds) > 0:
            # Explicit length check: plain truthiness is ambiguous for tensors.
            vit_embeds = self.mlp1(torch.cat(vit_embeds, dim=0))
            if selected.sum() > 0:
                input_embeds[selected] = vit_embeds[:selected.sum()]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            # Flatten using the actual logits width rather than config.vocab_size,
            # so the rows stay aligned with the labels.
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def extract_feature(self, pixel_values, image_grid_hws):
        """Run the vision tower and return its raw (unprojected) output.

        Callers concatenate the result along dim 0, so it is presumably a
        sequence of per-image tensors — confirm against modeling_vit.
        """
        vit_embeds = self.vision_model(pixel_values=pixel_values, grid_hws=image_grid_hws)

        return vit_embeds

    def get_input_embeddings(self):
        """Delegate to the language model's input embeddings."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Delegate to the language model's input embedding setter."""
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Delegate to the language model's output embeddings (LM head)."""
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Delegate to the language model's output embedding setter."""
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        """Delegate to the language model's decoder setter."""
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        """Delegate to the language model's decoder getter."""
        return self.language_model.get_decoder()

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        image_grid_hws: Optional[torch.Tensor] = None,
        tokenizer=None,
        n_future_tokens: int = 6,
        **generate_kwargs,
    ) -> Union[str, Tuple[str, list, str]]:
        """Decode a response for a single prompt with the custom MTP/AR loop.

        Args:
            pixel_values: Vision-tower input. Optional when ``visual_features`` is given.
            input_ids: Prompt token ids of shape ``(1, seq_len)``; batch size 1 is asserted.
            attention_mask: Accepted for API compatibility; not used by this loop.
            visual_features: Pre-computed features used instead of running the vision
                tower. When ``image_grid_hws`` is also given they are concatenated and
                passed through ``mlp1``; otherwise they are used as-is.
            image_grid_hws: Per-image grid sizes (tensor or numpy array).
            tokenizer: Required. Supplies ``model_max_length`` and decodes the output.
            n_future_tokens: Number of positions predicted per MTP forward step.
            **generate_kwargs: Must contain ``use_cache=True``. Also reads
                ``max_new_tokens`` (default 2048), ``generation_mode``
                ('fast' | 'slow' | 'hybrid', default 'hybrid') and ``verbose``.
                Every key except ``verbose`` is also forwarded to ``sample_tokens``.

        Returns:
            The decoded response string, or ``(response, sampling_history, stats_text)``
            when ``verbose`` is true.
        """
        verbose = generate_kwargs.pop('verbose', False)
        start_time = time.time()
        prefill_time = None

        # pixel_values is optional when pre-computed visual_features are passed,
        # so only cast it when it is present.
        if pixel_values is not None:
            pixel_values = pixel_values.to(self.language_model.dtype)
        # Convert numpy array to tensor if needed
        if isinstance(image_grid_hws, np.ndarray):
            grid_device = pixel_values.device if pixel_values is not None else input_ids.device
            image_grid_hws = torch.from_numpy(image_grid_hws).to(grid_device, dtype=torch.int32)

        batch_size, seq_len = input_ids.shape
        assert batch_size == 1, 'only batch size = 1 is supported now'
        assert generate_kwargs.get('use_cache', False), "Only use_cache=True is supported."

        generated = input_ids.clone()
        total_gen_length = min(tokenizer.model_max_length, seq_len + generate_kwargs.get('max_new_tokens', 2048))
        iter_round = 0
        past_key_values = None

        # Extract visual features once before the loop
        if visual_features is not None:
            vit_embeds = visual_features
        elif pixel_values is not None:
            vit_embeds = self.extract_feature(pixel_values, image_grid_hws)
        else:
            vit_embeds = None

        # With a grid present, the features are treated as per-image vision-tower
        # output: concatenate them and project into the LM embedding space.
        if image_grid_hws is not None and vit_embeds is not None:
            vit_embeds = torch.cat(vit_embeds, dim=0)
            vit_embeds = self.mlp1(vit_embeds)

        # ==================== Generation Mode ====================
        # 'fast'   : MTP only, never fall back to AR
        # 'slow'   : AR only, pure auto-regressive decoding
        # 'hybrid' : MTP first, fall back to AR on error, switch back on box_end
        generation_mode = generate_kwargs.get('generation_mode', 'hybrid')
        assert generation_mode in ('fast', 'slow', 'hybrid'), \
            f"Unsupported generation_mode='{generation_mode}'. Use 'fast', 'slow', or 'hybrid'."

        # (decoding mode, decoded tokens) for each step; only filled when verbose.
        sampling_history = []

        use_mtp = generation_mode in ('fast', 'hybrid')
        switch_to_ar_count = 0

        # Pre-allocate mask tokens and position ids
        default_mask_token_id = self.token_ids['default_mask_token_id']
        pre_mask_tokens = torch.full(
            (batch_size, n_future_tokens - 1),
            default_mask_token_id,
            dtype=generated.dtype,
            device=generated.device,
        )
        max_possible_len = total_gen_length + n_future_tokens
        full_position_ids = torch.arange(0, max_possible_len, device=generated.device).unsqueeze(0)

        # The helpers below read ``past_key_values`` from this scope at call time,
        # so they always see the cache kept from the previous step.
        def _prepare_inputs_in_mtp(generated):
            """Build LM inputs for one MTP step: tokens + duplicated last token + mask slots."""
            generated_with_mask = torch.cat(
                (
                    generated,
                    generated[:, -1].unsqueeze(1),
                    pre_mask_tokens,
                ),
                dim=1,
            )  # [batch_size, seq_len + 1 + n_future_tokens - 1]

            # Position ids start at the current KV-cache length. The trailing
            # n_future_tokens slots are shifted back by one position.
            start_idx = past_key_values[0][0].size(2) if past_key_values is not None else 0
            position_ids = full_position_ids[:, start_idx:generated_with_mask.size(1)].clone()
            position_ids[0, -n_future_tokens:] -= 1

            return self.language_model.prepare_inputs_for_generation(
                generated_with_mask,
                past_key_values,
                None,
                inputs_embeds=None,
                use_cache=True,
                position_ids=position_ids,
            )

        def _prepare_input_in_ar(generated):
            """Build LM inputs for one AR step; position ids start at the KV-cache length."""
            start_idx = past_key_values[0][0].size(2) if past_key_values is not None else 0
            position_ids = full_position_ids[:, start_idx:generated.size(1)]
            return self.language_model.prepare_inputs_for_generation(
                generated,
                past_key_values,
                None,
                inputs_embeds=None,
                use_cache=True,
                position_ids=position_ids,
            )

        def _sample_token_in_mtp(generated, outputs):
            """Sample tokens using MTP (Multi-Token Prediction) mode."""
            next_token_logits = outputs.logits[:, -n_future_tokens:, :]
            _, _, x0, box_avg = sample_tokens(
                next_token_logits, generated, self.token_ids, keep_k=5, **generate_kwargs
            )

            # Prefer box_avg unless it is all zeros, in which case use x0.
            is_box_empty = (box_avg[0] == 0).all()
            new_tokens = x0[0] if is_box_empty else box_avg[0]

            out_pattern = handle_pattern(new_tokens, self.token_ids, generation_mode)
            out_type = out_pattern['type']
            out_token = torch.tensor(out_pattern['tokens'], dtype=x0.dtype, device=x0.device)

            return out_type, out_token

        def _sample_token_in_ar(generated, outputs):
            """Sample a single token using AR (Auto-Regressive) mode."""
            next_token_logits = outputs.logits[:, -1:, :]
            _, _, x0, _ = sample_tokens(
                next_token_logits, generated, self.token_ids, **generate_kwargs
            )

            out_token = x0[0]
            out_type = 'continue_ar'
            token_val = out_token[0].item()

            box_end_token_id = self.token_ids['box_end_token_id']
            coord_start_token_id = self.token_ids['coord_start_token_id']
            coord_end_token_id = self.token_ids['coord_end_token_id']
            none_token_id = self.token_ids['none_token_id']
            im_end_token_id = self.token_ids['im_end_token_id']

            if generation_mode == 'hybrid':
                # Hybrid AR phase: detect box boundaries to switch back to MTP.
                # NOTE(review): any token other than box_end / coordinate / none is
                # reported as 'im_end', which ends generation — confirm intended.
                if token_val == box_end_token_id:
                    out_type = 'box_end_ar'
                elif coord_start_token_id <= token_val <= coord_end_token_id or token_val == none_token_id:
                    out_type = 'coord_ar'
                else:
                    out_type = 'im_end'
            elif token_val == im_end_token_id:
                # Slow mode: pure AR, only stop on im_end
                out_type = 'im_end'

            return out_type, out_token

        # Generate loop
        while generated.size(1) < total_gen_length:
            iter_round += 1

            # Step 1: Prepare inputs
            if use_mtp:
                prepare_inputs = _prepare_inputs_in_mtp(generated)
            else:
                prepare_inputs = _prepare_input_in_ar(generated)

            # Visual features are only passed on the first forward step.
            if iter_round == 1:
                prepare_inputs.update({
                    'visual_features': vit_embeds,
                    'image_token_index': self.config.image_token_index,
                })

            # Step 2: Model forward & update KV cache
            with torch.no_grad():
                outputs = self.language_model(**prepare_inputs)

            # Keep cache entries only for already-accepted tokens. In MTP mode this
            # drops the entries of the duplicated last token and the mask slots.
            past_key_values = tuple(
                (kv[0][:, :, :generated.shape[1], :], kv[1][:, :, :generated.shape[1], :])
                for kv in outputs.past_key_values
            )

            # Step 3: Sample tokens
            step_mode = 'mtp' if use_mtp else 'ar'
            if use_mtp:
                out_type, out_token = _sample_token_in_mtp(generated, outputs)
            else:
                out_type, out_token = _sample_token_in_ar(generated, outputs)

            if verbose:
                # Label by the path that actually produced the tokens; the old
                # substring test on out_type mislabelled AR 'im_end' steps as 'mtp'.
                sampling_history.append((step_mode, tokenizer.decode(out_token, skip_special_tokens=False)))

            generated = torch.cat([generated, out_token.unsqueeze(0)], dim=1)

            # Record prefill time before a possible early break. Otherwise an
            # 'im_end' on the first step leaves it None and the stats formatting fails.
            if prefill_time is None:
                prefill_time = time.time() - start_time

            # Step 4: Mode switching & termination
            if out_type == 'im_end':
                break

            if generation_mode == 'hybrid':
                if out_type == 'error_box':
                    use_mtp = False
                    switch_to_ar_count += 1
                elif out_type == 'box_end_ar':
                    use_mtp = True
            # fast mode: use_mtp stays True always
            # slow mode: use_mtp stays False always

        # Decode and return
        generated_ids = generated[:, seq_len:]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)

        if verbose:
            end_time = time.time()
            num_tokens = generated_ids.size(1)
            num_boxes = response[0].count("<box>")
            total_time = end_time - start_time
            # No forward pass ran if the prompt already reached total_gen_length.
            prefill = prefill_time if prefill_time is not None else 0.0
            tps = num_tokens / total_time if total_time > 0 else 0.0
            bps = num_boxes / total_time if total_time > 0 else 0.0

            out_info = f"\nStatistic Info, num_tokens={num_tokens}; " + \
                f"generate_time(s)={total_time:.4f}; " + \
                f"tps={tps:.4f}; " + \
                f"forward_step={iter_round}; " + \
                f"num_boxes={num_boxes}; " + \
                f"bps={bps:.4f}; " + \
                f"prefill_time={prefill:.4f}; " + \
                f"switch_to_ar={switch_to_ar_count}\n"
            print(out_info)

            return response[0], sampling_history, out_info

        return response[0]
| 537 | return response[0] |