modeling_locateanything.py
21.7 KB · 537 lines · python Raw
1 # --------------------------------------------------------
2 # NVIDIA
3 # Copyright (c) 2025 NVIDIA
4 # Licensed under The MIT License [see LICENSE for details]
5 # --------------------------------------------------------
6
7 import time
8 from typing import List, Optional, Tuple, Union
9
10 import numpy as np
11 import torch
12 from torch import nn
13 from torch.nn import CrossEntropyLoss
14 from transformers.generation import GenerationMixin
15 from transformers.modeling_outputs import CausalLMOutputWithPast
16 from transformers.modeling_utils import PreTrainedModel
17 from transformers.utils import add_start_docstrings, is_flash_attn_2_available, logging
18 from peft import LoraConfig, get_peft_model
19
20 from .configuration_locateanything import LocateAnythingConfig
21 from .modeling_qwen2 import Qwen2ForCausalLM
22 from .modeling_vit import MoonVitPretrainedModel
23 from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
24 from .mask_sdpa_utils import *
25 from .mask_magi_utils import *
26 from .configuration_qwen2 import Qwen2Config
27
28 from .generate_utils import (
29 sample_tokens,
30 handle_pattern,
31 get_token_ids_from_config,
32 )
33
# Module logger (transformers' logging wrapper, so verbosity follows transformers settings).
logger = logging.get_logger(__name__)


# Shared docstring preamble (inheritance notes plus the `config` parameter description).
# It is passed to `add_start_docstrings` when decorating `LocateAnythingPreTrainedModel`.
LOCATEANYTHING_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LocateAnythingConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
50
@add_start_docstrings(
    "The bare LocateAnything Model outputting raw hidden-states without any specific head on top.",
    LOCATEANYTHING_START_DOCSTRING,
)
class LocateAnythingPreTrainedModel(PreTrainedModel):
    # Common base for LocateAnything models. It binds LocateAnythingConfig to the
    # Hugging Face PreTrainedModel machinery, declares the capability flags below,
    # and lets the custom 'magi' attention backend bypass the library's validation.

    config_class = LocateAnythingConfig
    base_model_prefix = "model"
    main_input_name = 'input_ids'
    supports_gradient_checkpointing = True
    # Default device-map split unit; subclasses may override per instance.
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_static_cache = True
    _supports_quantized_cache = True
    _supports_sdpa = True

    @classmethod
    def _autoset_attn_implementation(cls, config, *args, **kwargs):
        # Anything other than 'magi' goes through the standard resolution logic;
        # a 'magi' config is handed back untouched.
        if getattr(config, '_attn_implementation', None) != 'magi':
            return super()._autoset_attn_implementation(config, *args, **kwargs)
        return config

    def _check_and_adjust_attn_implementation(self, attn_implementation, is_init_check=False):
        # Same idea as above, for the newer transformers hook: accept 'magi' as-is.
        if attn_implementation != "magi":
            return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)
        return "magi"

    def _init_weights(self, module):
        # Normal(0, std) init for Linear/Conv2d/Embedding weights, with zeroed biases
        # and a zeroed padding row. std prefers the top-level config value and falls
        # back to the text config's initializer_range.
        weight_std = getattr(self.config, 'initializer_range', None) or self.config.text_config.initializer_range

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=weight_std)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=weight_std)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()
90
class LocateAnythingForConditionalGeneration(LocateAnythingPreTrainedModel, GenerationMixin):
    """LocateAnything model: MoonViT vision encoder + MLP projector (``mlp1``) + Qwen2/Qwen3 causal LM.

    Image features from ``vision_model`` are projected by ``mlp1`` and written into the
    language-model input embeddings at positions holding ``image_token_index``.
    ``generate`` implements a custom decoding loop (multi-token prediction, auto-regressive,
    or a hybrid of the two) and returns decoded text rather than token ids.
    """

    config_class = LocateAnythingConfig

    def __init__(self, config: LocateAnythingConfig, vision_model=None, language_model=None):
        """Build (or adopt) the vision encoder, language model and projector.

        Args:
            config: Top-level config. Reads ``template``, ``mlp_checkpoint``, ``vision_config``,
                ``text_config``, ``image_token_index``, ``use_backbone_lora`` and ``use_llm_lora``.
                The resolved attention backend is written back to
                ``text_config._attn_implementation``, and also to
                ``vision_config._attn_implementation`` when MoonViT is built here.
            vision_model: Optional pre-built vision encoder, used instead of constructing MoonViT.
            language_model: Optional pre-built LLM, used instead of constructing Qwen2/Qwen3.

        Raises:
            ValueError: If a component has to be built here and its vision model type or
                LLM architecture is unsupported.
        """
        super().__init__(config)

        self.template = config.template
        self.mlp_checkpoint = config.mlp_checkpoint
        logger.info(f'mlp_checkpoint: {self.mlp_checkpoint}')

        # LLM architecture name, read once. Falls back to Qwen2 when the config lists none,
        # so LLM construction and _no_split_modules always agree.
        text_arch = (
            config.text_config.architectures[0]
            if getattr(config.text_config, 'architectures', None)
            else 'Qwen2ForCausalLM'
        )

        if vision_model is not None:
            self.vision_model = vision_model
        elif config.vision_config.model_type == 'moonvit':
            vision_attn_impl = getattr(config.vision_config, '_attn_implementation', None) or 'flash_attention_2'
            if vision_attn_impl == 'flash_attention_2' and not is_flash_attn_2_available():
                logger.warning_once(
                    "flash_attn is not available for MoonViT inference; falling back to sdpa."
                )
                vision_attn_impl = 'sdpa'
            config.vision_config._attn_implementation = vision_attn_impl
            self.vision_model = MoonVitPretrainedModel(config.vision_config)
        else:
            raise ValueError(f'Unsupported vision model type: {config.vision_config.model_type}. Only moonvit is supported.')

        # Text attention backend: explicit text setting, then top-level setting, then 'magi'.
        text_attn_impl = (
            getattr(config.text_config, '_attn_implementation', None)
            or getattr(config, '_attn_implementation', None)
            or 'magi'
        )
        config.text_config._attn_implementation = text_attn_impl

        if language_model is not None:
            self.language_model = language_model
        elif text_arch == 'Qwen2ForCausalLM':
            self.language_model = Qwen2ForCausalLM(config.text_config)
        elif text_arch == 'Qwen3ForCausalLM':
            self.language_model = Qwen3ForCausalLM(config.text_config)
        else:
            raise ValueError(f'Unsupported language model architecture: {text_arch}. Only Qwen2ForCausalLM and Qwen3ForCausalLM are supported.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        # Projector from vision features to LLM embeddings (no pixel-shuffle step here).
        # Its input width is 4 * vit_hidden_size; presumably the vision encoder already
        # merges 2x2 patches. NOTE(review): confirm against MoonVitPretrainedModel.
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * 4),
            nn.Linear(vit_hidden_size * 4, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )
        self.image_token_index = config.image_token_index
        self.neftune_alpha = None

        # The use_*_lora values double as the LoRA rank; alpha is twice the rank.
        if config.use_backbone_lora:
            self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)

        self.use_llm_lora = config.use_llm_lora
        if config.use_llm_lora:
            self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)

        self.token_ids = get_token_ids_from_config(config)

        # Device-map split unit must match the decoder layer class of the actual LLM.
        self._no_split_modules = ["Qwen3DecoderLayer"] if 'Qwen3' in text_arch else ["Qwen2DecoderLayer"]

    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Wrap ``vision_model`` with PEFT LoRA adapters on its attention and MLP projections."""
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj',
                            'mlp.fc1', 'mlp.fc2'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.vision_model = get_peft_model(self.vision_model, lora_config)
        self.vision_model.print_trainable_parameters()

    def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Wrap ``language_model`` with causal-LM LoRA adapters and mark LoRA as enabled."""
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                            'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            task_type='CAUSAL_LM'
        )
        self.language_model = get_peft_model(self.language_model, lora_config)
        # Needed so gradients flow to adapters when the embedding layer is frozen.
        self.language_model.enable_input_require_grads()
        self.language_model.print_trainable_parameters()
        self.use_llm_lora = True

    def forward(
        self,
        pixel_values: List[torch.FloatTensor],
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_grid_hws: Optional[torch.Tensor] = None,
        image_flags: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Encode images, splice them into the token embeddings, and run the language model.

        Args:
            pixel_values: Image inputs for ``vision_model``; encoded on every call.
            input_ids: Token ids of shape (B, N). Positions equal to ``image_token_index``
                receive projected image features, in order.
            attention_mask, position_ids, past_key_values, use_cache, output_attentions,
                output_hidden_states: Forwarded to ``language_model`` unchanged.
            image_grid_hws: Patch grid sizes forwarded to the vision encoder.
            image_flags: Optional per-sample image counts. A non-zero ``k`` keeps the next ``k``
                feature entries. A zero skips exactly one entry (presumably a placeholder image;
                TODO confirm with the data collator).
            labels: Optional (B, N) targets. Enables the shifted next-token cross-entropy loss
                (``CrossEntropyLoss`` defaults, so label -100 is ignored).
            return_dict: Defaults to ``config.use_return_dict``.

        Returns:
            ``CausalLMOutputWithPast``, or a tuple ``(loss, logits, ...)`` / ``(logits, ...)``
            when ``return_dict`` is False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        has_images = image_flags is not None and image_flags.sum() > 0

        vit_embeds = self.extract_feature(pixel_values, image_grid_hws)

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        if has_images:
            # Keep only the feature entries of real images, as described by image_flags.
            filtered_vit_embeds = []
            idx = 0
            for flag in image_flags:
                flag_val = flag.item()
                if flag_val != 0:
                    filtered_vit_embeds.extend(vit_embeds[idx:idx + flag_val])
                    idx += flag_val
                else:
                    idx += 1

            vit_embeds = torch.cat(filtered_vit_embeds, dim=0)

            vit_embeds = self.mlp1(vit_embeds)
            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.image_token_index)

            # "* 0.0 +" keeps the original embeddings in the autograd graph while
            # replacing their values with the image features.
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:selected.sum()]
        elif vit_embeds:
            vit_embeds = torch.cat(vit_embeds, dim=0)
            vit_embeds = self.mlp1(vit_embeds)
            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.image_token_index)
            if selected.sum() > 0:
                input_embeds[selected] = vit_embeds[:selected.sum()]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Labels may live on another device under model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def extract_feature(self, pixel_values, image_grid_hws):
        """Run ``vision_model`` on ``pixel_values`` with ``grid_hws=image_grid_hws``.

        The result is returned unchanged. Callers treat it as a sequence of per-image tensors
        that they ``torch.cat`` before projecting with ``mlp1``.
        """
        return self.vision_model(pixel_values=pixel_values, grid_hws=image_grid_hws)

    def get_input_embeddings(self):
        """Return the language model's input embedding module."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the language model's input embedding module."""
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the language model's output (LM head) module."""
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Replace the language model's output (LM head) module."""
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        """Replace the language model's decoder."""
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        """Return the language model's decoder."""
        return self.language_model.get_decoder()

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        image_grid_hws: Optional[torch.Tensor] = None,
        tokenizer=None,
        n_future_tokens: int = 6,
        **generate_kwargs,
    ) -> Union[str, Tuple[str, list, str]]:
        """Decode a response using MTP, pure AR, or a hybrid of both, and return the text.

        Args:
            pixel_values: Image inputs, cast to the language model's dtype. May be None when
                ``visual_features`` is given or the prompt has no image.
            input_ids: Prompt ids of shape (1, seq_len); only batch size 1 is supported.
            attention_mask: Accepted for API compatibility; not used by this method.
            visual_features: Precomputed vision-encoder outputs. When given, the vision encoder
                is skipped; they are still concatenated and projected by ``mlp1`` when
                ``image_grid_hws`` is given.
            image_grid_hws: Per-image patch grids for the vision encoder. numpy arrays are
                converted to int32 tensors.
            tokenizer: Required. Supplies ``model_max_length`` and decodes the output ids.
            n_future_tokens: Number of trailing positions read per MTP forward pass.
            **generate_kwargs: ``verbose`` is popped. ``use_cache`` must be truthy.
                ``max_new_tokens`` defaults to 2048. ``generation_mode`` is 'fast', 'slow' or
                'hybrid' (default 'hybrid'). All remaining kwargs are forwarded to ``sample_tokens``.

        Returns:
            The decoded continuation (special tokens kept). When ``verbose`` is set, returns
            ``(response, sampling_history, stats_string)`` instead.

        Raises:
            AssertionError: On batch size != 1, falsy ``use_cache``, or unknown ``generation_mode``.
            ValueError: If ``tokenizer`` is None.
        """
        verbose = generate_kwargs.pop('verbose', False)
        start_time = time.time()
        prefill_time = None

        # pixel_values is optional (e.g. visual_features supplied), so only cast when present.
        if pixel_values is not None:
            pixel_values = pixel_values.to(self.language_model.dtype)
        # Convert numpy array to tensor if needed, on the same device as the image inputs.
        if isinstance(image_grid_hws, np.ndarray):
            grid_device = pixel_values.device if pixel_values is not None else input_ids.device
            image_grid_hws = torch.from_numpy(image_grid_hws).to(grid_device, dtype=torch.int32)

        batch_size, seq_len = input_ids.shape
        assert batch_size == 1, 'only batch size = 1 is supported now'
        assert generate_kwargs.get('use_cache', False), "Only use_cache=True is supported."
        if tokenizer is None:
            raise ValueError('generate() requires a tokenizer to bound the output length and decode it.')

        generated = input_ids.clone()
        total_gen_length = min(tokenizer.model_max_length, seq_len + generate_kwargs.get('max_new_tokens', 2048))
        iter_round = 0
        past_key_values = None

        # Extract visual features once, before the decoding loop.
        if visual_features is not None:
            vit_embeds = visual_features
        elif pixel_values is not None:
            vit_embeds = self.extract_feature(pixel_values, image_grid_hws)
        else:
            vit_embeds = None

        # Concatenate the per-image features and project them into the LLM embedding space.
        if image_grid_hws is not None and vit_embeds is not None:
            vit_embeds = self.mlp1(torch.cat(vit_embeds, dim=0))

        # ==================== Generation Mode ====================
        # 'fast'   : MTP only, never fall back to AR
        # 'slow'   : AR only, pure auto-regressive decoding
        # 'hybrid' : MTP first, fall back to AR on error, switch back on box_end
        generation_mode = generate_kwargs.get('generation_mode', 'hybrid')
        assert generation_mode in ('fast', 'slow', 'hybrid'), \
            f"Unsupported generation_mode='{generation_mode}'. Use 'fast', 'slow', or 'hybrid'."

        sampling_history = []
        use_mtp = generation_mode in ('fast', 'hybrid')
        switch_to_ar_count = 0

        # Pre-allocate mask tokens and position ids.
        default_mask_token_id = self.token_ids['default_mask_token_id']
        pre_mask_tokens = torch.full(
            (batch_size, n_future_tokens - 1),
            default_mask_token_id,
            dtype=generated.dtype,
            device=generated.device
        )
        max_possible_len = total_gen_length + n_future_tokens
        full_position_ids = torch.arange(0, max_possible_len, device=generated.device).unsqueeze(0)

        def _prepare_inputs_in_mtp(generated):
            """Build MTP inputs: committed tokens + a copy of the last token + mask tokens."""
            generated_with_mask = torch.cat(
                (
                    generated,
                    generated[:, -1].unsqueeze(1),
                    pre_mask_tokens
                ),
                dim=1
            )  # [batch_size, seq_len + 1 + n_future_tokens - 1]

            # Positions start after the cached prefix. The last n_future_tokens positions are
            # shifted back by one, so the duplicated last token shares its original position.
            start_idx = past_key_values[0][0].size(2) if past_key_values is not None else 0
            position_ids = full_position_ids[:, start_idx: generated_with_mask.size(1)].clone()
            position_ids[0, -n_future_tokens:] -= 1

            return self.language_model.prepare_inputs_for_generation(
                generated_with_mask,
                past_key_values,
                None,
                inputs_embeds=None,
                use_cache=True,
                position_ids=position_ids
            )

        def _prepare_input_in_ar(generated):
            """Build plain auto-regressive inputs for the uncached suffix of ``generated``."""
            start_idx = past_key_values[0][0].size(2) if past_key_values is not None else 0
            position_ids = full_position_ids[:, start_idx: generated.size(1)]
            return self.language_model.prepare_inputs_for_generation(
                generated,
                past_key_values,
                None,
                inputs_embeds=None,
                use_cache=True,
                position_ids=position_ids
            )

        def _sample_token_in_mtp(generated, outputs):
            """Sample tokens using MTP (Multi-Token Prediction) mode."""
            next_token_logits = outputs.logits[:, -n_future_tokens:, :]
            _, _, x0, box_avg = sample_tokens(
                next_token_logits, generated, self.token_ids, keep_k=5, **generate_kwargs
            )

            # Prefer the box-averaged tokens unless sample_tokens returned an all-zero box.
            is_box_empty = (box_avg[0] == 0).all()
            new_tokens = x0[0] if is_box_empty else box_avg[0]

            out_pattern = handle_pattern(new_tokens, self.token_ids, generation_mode)
            out_type = out_pattern['type']
            out_token = torch.tensor(out_pattern['tokens'], dtype=x0.dtype, device=x0.device)

            return out_type, out_token

        def _sample_token_in_ar(generated, outputs):
            """Sample a single token using AR (Auto-Regressive) mode and classify it."""
            next_token_logits = outputs.logits[:, -1:, :]
            _, _, x0, _ = sample_tokens(
                next_token_logits, generated, self.token_ids, **generate_kwargs
            )

            out_token = x0[0]
            out_type = 'continue_ar'
            token_val = out_token[0].item()

            box_end_token_id = self.token_ids['box_end_token_id']
            coord_start_token_id = self.token_ids['coord_start_token_id']
            coord_end_token_id = self.token_ids['coord_end_token_id']
            none_token_id = self.token_ids['none_token_id']
            im_end_token_id = self.token_ids['im_end_token_id']

            if generation_mode == 'hybrid':
                # Hybrid AR phase: box_end switches back to MTP, coord/none tokens continue
                # in AR, and any other token ends generation.
                if token_val == box_end_token_id:
                    out_type = 'box_end_ar'
                elif coord_start_token_id <= token_val <= coord_end_token_id or token_val == none_token_id:
                    out_type = 'coord_ar'
                else:
                    out_type = 'im_end'
            elif token_val == im_end_token_id:
                # Slow mode: pure AR, only stop on im_end
                out_type = 'im_end'

            return out_type, out_token

        # Generate loop
        while generated.size(1) < total_gen_length:
            iter_round += 1

            # Step 1: Prepare inputs
            prepare_inputs = _prepare_inputs_in_mtp(generated) if use_mtp else _prepare_input_in_ar(generated)

            if iter_round == 1:
                # Only the prefill pass sees the image features; later passes reuse the KV cache.
                prepare_inputs.update({
                    'visual_features': vit_embeds,
                    'image_token_index': self.config.image_token_index,
                })

            # Step 2: Model forward & update KV cache (the method is already under no_grad).
            outputs = self.language_model(**prepare_inputs)

            # Keep only committed tokens in the cache; drop the duplicated and mask positions.
            past_key_values = tuple(
                (kv[0][:, :, :generated.shape[1], :], kv[1][:, :, :generated.shape[1], :])
                for kv in outputs.past_key_values
            )

            # Step 3: Sample tokens
            if use_mtp:
                out_type, out_token = _sample_token_in_mtp(generated, outputs)
            else:
                out_type, out_token = _sample_token_in_ar(generated, outputs)

            if verbose:
                # Label by the mode actually used for this step. out_type is not a reliable
                # signal: AR steps can return 'im_end'.
                sampling_history.append(
                    ('mtp' if use_mtp else 'ar', tokenizer.decode(out_token, skip_special_tokens=False))
                )

            generated = torch.cat([generated, out_token.unsqueeze(0)], dim=1)

            # Record prefill time before any break, so it is set even if round 1 ends generation.
            if prefill_time is None:
                prefill_time = time.time() - start_time

            # Step 4: Mode switching & termination
            if out_type == 'im_end':
                break

            if generation_mode == 'hybrid':
                if out_type == 'error_box':
                    use_mtp = False
                    switch_to_ar_count += 1
                elif out_type == 'box_end_ar':
                    use_mtp = True
            # fast mode: use_mtp stays True always
            # slow mode: use_mtp stays False always

        # Decode and return
        generated_ids = generated[:, seq_len:]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)

        if verbose:
            total_time = time.time() - start_time
            num_tokens = generated_ids.size(1)
            num_boxes = response[0].count("<box>")
            # prefill_time stays None only when the decoding loop never ran.
            prefill_seconds = prefill_time if prefill_time is not None else float('nan')

            out_info = (
                f"\nStatistic Info, num_tokens={num_tokens}; "
                f"generate_time(s)={total_time:.4f}; "
                f"tps={(num_tokens / total_time):.4f}; "
                f"forward_step={iter_round}; "
                f"num_boxes={num_boxes}; "
                f"bps={(num_boxes / total_time):.4f}; "
                f"prefill_time={prefill_seconds:.4f}; "
                f"switch_to_ar={switch_to_ar_count}\n"
            )
            print(out_info)

            return response[0], sampling_history, out_info

        return response[0]