generate_utils.py
18.3 KB · 504 lines · python Raw
1 # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2 #
3 # NVIDIA CORPORATION and its licensors retain all intellectual property
4 # and proprietary rights in and to this software, related documentation
5 # and any modifications thereto. Any use, reproduction, disclosure or
6 # distribution of this software and related documentation without an express
7 # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
9 import torch
10 import torch.nn.functional as F
11 import torch.distributions as dists
12 from typing import Dict, Optional
13
14
def get_token_ids_from_config(config) -> Dict[str, int]:
    """Collect the special-token ids used by the decoding helpers in this module.

    Ids stored on the top-level config are read from ``config``; ids stored on
    the language-model config are read from ``config.text_config``. Any
    attribute that is missing -- or a missing ``text_config`` altogether --
    falls back to the hard-coded default listed below.

    Args:
        config: Configuration object (LocateAnythingConfig or similar).

    Returns:
        Mapping from token-id name to integer id.
    """
    # (result key, attribute read from ``config``, fallback id)
    top_level_fields = (
        ('box_start_token_id', 'box_start_token_id', 151668),
        ('box_end_token_id', 'box_end_token_id', 151669),
        ('coord_start_token_id', 'coord_start_token_id', 151677),
        ('coord_end_token_id', 'coord_end_token_id', 152677),
        ('ref_start_token_id', 'ref_start_token_id', 151672),
        ('ref_end_token_id', 'ref_end_token_id', 151673),
        ('none_token_id', 'none_token_id', 4064),
    )
    # (result key, attribute read from ``config.text_config``, fallback id).
    # Note: ``im_end_token_id`` is taken from the text config's ``eos_token_id``.
    text_fields = (
        ('null_token_id', 'null_token_id', 152678),
        ('im_end_token_id', 'eos_token_id', 151645),
        ('switch_token_id', 'switch_token_id', 152679),
        ('default_mask_token_id', 'text_mask_token_id', 151676),
    )

    token_ids = {key: getattr(config, attr, fallback) for key, attr, fallback in top_level_fields}

    text_config = getattr(config, 'text_config', None)
    for key, attr, fallback in text_fields:
        if text_config is None:
            token_ids[key] = fallback
        else:
            token_ids[key] = getattr(text_config, attr, fallback)

    return token_ids
49
50
def top_p_logits(
    logits: torch.Tensor,
    top_p: Optional[float] = None
) -> torch.Tensor:
    """Nucleus (top-p) filtering over the last dimension.

    Keeps the smallest set of highest-probability entries whose cumulative
    probability exceeds ``top_p`` (always keeping at least the top entry) and
    sets every other logit to the dtype's most negative finite value.

    Args:
        logits: Logits of any shape; filtering is applied along dim -1.
        top_p: Cumulative-probability threshold. ``None`` or ``>= 1`` disables
            filtering and returns ``logits`` unchanged (previously ``None``
            raised a TypeError, and ``1.0`` could drop tail tokens because the
            float cumulative sum may slightly exceed 1).

    Returns:
        Filtered logits with the same shape and dtype as the input.
    """
    if top_p is None or top_p >= 1:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    sorted_to_remove = cumulative_probs > top_p
    # Shift right by one so the token that first crosses the threshold is kept,
    # and never remove the most probable token.
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Map the removal decisions from sorted order back to vocabulary order.
    remove = torch.zeros_like(logits, dtype=torch.bool).scatter(-1, sorted_indices, sorted_to_remove)
    return logits.masked_fill(remove, torch.finfo(logits.dtype).min)
66
67
def top_k_logits(
    logits: torch.Tensor,
    top_k: Optional[int] = None
) -> torch.Tensor:
    """Top-k filtering over the last dimension.

    Every logit strictly below the k-th largest value is set to the dtype's
    most negative finite value (ties with the k-th value are kept).

    Args:
        logits: Logits of any shape; filtering is applied along dim -1.
        top_k: Number of entries to keep. ``None`` or ``<= 0`` disables
            filtering (previously ``None`` raised a TypeError and ``0`` raised
            an IndexError). Values larger than the vocabulary are clamped.

    Returns:
        Filtered logits with the same shape and dtype as the input.
    """
    if top_k is None or top_k <= 0:
        return logits

    top_k = min(top_k, logits.size(-1))
    # k-th largest value per row, kept as a trailing singleton dim for broadcasting.
    kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_value, torch.finfo(logits.dtype).min)
77
78
def apply_repetition_penalty(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    repetition_penalty: float = 1.0
) -> torch.Tensor:
    """
    Apply repetition penalty to logits.

    For every vocabulary entry that appears anywhere in ``input_ids[b]``,
    positive logits are divided by the penalty and non-positive logits are
    multiplied by it, at every sequence position of batch row ``b``. Ids
    outside ``[0, vocab_size)`` (e.g. padding sentinels) are ignored.

    Args:
        logits: Shape [batch_size, seq_len, vocab_size] or [batch_size, vocab_size]
        input_ids: Previously generated token ids, shape [batch_size, seq_len]
        repetition_penalty: Penalty factor. > 1.0 penalizes repetition, < 1.0 encourages it.

    Returns:
        Modified logits with repetition penalty applied (same shape as input).
        When ``repetition_penalty == 1.0`` the input tensor itself is returned.
    """
    if repetition_penalty == 1.0:
        return logits

    # Work in 3D so 2D and 3D inputs share one code path.
    squeeze_back = logits.dim() == 2
    if squeeze_back:
        logits = logits.unsqueeze(1)  # [B, 1, V]

    batch_size, _, vocab_size = logits.shape
    device = logits.device

    # Build a [B, V] "already seen" mask in one vectorized step (no per-row
    # Python loop / host sync). Out-of-range ids are routed to an extra
    # overflow column that is sliced off afterwards.
    ids = input_ids[:batch_size].to(device=device, dtype=torch.long)
    in_vocab = (ids >= 0) & (ids < vocab_size)
    ids = ids.masked_fill(~in_vocab, vocab_size)
    seen = torch.zeros(batch_size, vocab_size + 1, dtype=torch.bool, device=device)
    rows = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(ids)
    seen[rows, ids] = True
    seen = seen[:, :vocab_size].unsqueeze(1)  # [B, 1, V], broadcasts over seq_len

    # Divide positive values by penalty, multiply non-positive values by penalty.
    penalized = torch.where(logits > 0, logits / repetition_penalty, logits * repetition_penalty)
    logits = torch.where(seen, penalized, logits)

    if squeeze_back:
        logits = logits.squeeze(1)

    return logits
133
134
def sample_tokens(
    logits: torch.Tensor,
    generated: torch.Tensor,
    token_ids: Dict[str, int],
    **generate_kwargs,
):
    """Sample tokens for every position of a prediction frame.

    Applies repetition penalty, then (only when ``temperature > 0``)
    temperature scaling plus optional top-p / top-k filtering. It then samples
    (or greedily picks) one token per position. For multi-token frames it also
    decodes a structured frame per batch row via ``decode_bbox_avg``, falling
    back to ``decode_ref``.

    Args:
        logits: [batch_size, seq_len, vocab_size].
        generated: Previously generated ids; only read for the repetition penalty.
        token_ids: Special-token id mapping (see ``get_token_ids_from_config``).
        **generate_kwargs: ``repetition_penalty`` (default 1.0), ``temperature``
            (default 0 = greedy), ``top_p``, ``top_k``, plus ``keep_k_avg``
            (default 4) and ``generation_mode`` (default 'hybrid'), which are
            forwarded to ``decode_bbox_avg``.

    Returns:
        ``(probs, confidence, x0, box_avg)``: probs [B, L, V]; confidence and
        x0 [B, L]; box_avg is None when ``seq_len == 1``, otherwise the
        per-row decoded frames stacked along dim 0.
    """
    batch_size, seq_len, _ = logits.shape

    repetition_penalty = generate_kwargs.get('repetition_penalty', 1.0)
    temperature = generate_kwargs.get('temperature', 0)
    top_p = generate_kwargs.get('top_p', None)
    top_k = generate_kwargs.get('top_k', None)

    # Apply repetition penalty based on all previously generated tokens
    if repetition_penalty != 1.0:
        logits = apply_repetition_penalty(logits, generated, repetition_penalty)

    if temperature > 0:
        logits = logits / temperature
        if top_p is not None and top_p < 1:
            logits = top_p_logits(logits, top_p)
        if top_k is not None:
            logits = top_k_logits(logits, top_k)

    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Deliberate best-effort: if sampling fails (e.g. invalid probs),
            # fall back to greedy selection rather than aborting generation.
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if seq_len == 1:
        return probs, confidence, x0, None

    keep_k = generate_kwargs.get('keep_k_avg', 4)
    generation_mode = generate_kwargs.get('generation_mode', 'hybrid')
    # NOTE(review): this fallback has length 1, while decode_bbox_avg frames have
    # length 6 and decode_ref frames have length seq_len; torch.stack below would
    # raise if a batch mixes them -- confirm what callers expect here.
    fallback_box = torch.zeros(1, dtype=x0.dtype, device=x0.device)

    box_avg = []
    for b in range(batch_size):
        decoded = decode_bbox_avg(
            logits[b], probs[b], token_ids, keep_k=keep_k,
            generation_mode=generation_mode,
        )
        if decoded is None:
            out_ref = decode_ref(logits[b], probs[b], token_ids)
            # decode_ref already returns a tensor; convert with .to() instead of
            # re-wrapping it in torch.tensor(), which copies and warns.
            decoded = fallback_box if out_ref is None else out_ref.to(dtype=x0.dtype, device=x0.device)
        box_avg.append(decoded)

    return probs, confidence, x0, torch.stack(box_avg)
193
194
def sample_tokens_ar(
    logits: torch.Tensor,
    generated: torch.Tensor,
    token_ids: Dict[str, int],
    **generate_kwargs,
):
    """Sample one autoregressive step.

    Runs the same repetition-penalty / temperature / top-p / top-k pipeline as
    ``sample_tokens`` but never decodes structured frames.

    Args:
        logits: [batch_size, vocab_size] or [batch_size, 1, vocab_size].
        generated: Previously generated ids, [batch_size, seq_len]; only read
            when a repetition penalty is configured.
        token_ids: Not read here; accepted for signature parity with
            ``sample_tokens``.
        **generate_kwargs: ``repetition_penalty`` (default 1.0), ``temperature``
            (default 0 = greedy), ``top_p``, ``top_k``.

    Returns:
        A 5-tuple ``(probs, confidence, x0, None, None)``: probs is
        [B, 1, V], confidence and x0 are [B, 1].
    """
    # Normalize to [B, 1, V] so the shared filtering helpers see 3D logits.
    if logits.dim() == 2:
        logits = logits[:, None, :]
    _, steps, _ = logits.shape
    assert steps == 1, "sample_tokens_ar only supports single-step AR sampling (seq_len == 1)"

    penalty = generate_kwargs.get('repetition_penalty', 1.0)
    temperature = generate_kwargs.get('temperature', 0)
    top_p = generate_kwargs.get('top_p', None)
    top_k = generate_kwargs.get('top_k', None)

    if penalty != 1.0:
        logits = apply_repetition_penalty(logits, generated, penalty)

    stochastic = temperature > 0
    if stochastic:
        logits = logits / temperature
        if top_p is not None and top_p < 1:
            logits = top_p_logits(logits, top_p)
        if top_k is not None:
            logits = top_k_logits(logits, top_k)

    probs = F.softmax(logits, dim=-1)

    if not stochastic:
        # Greedy: the most probable token and its probability.
        confidence, x0 = probs.max(dim=-1)
        return probs, confidence, x0, None, None

    try:
        x0 = dists.Categorical(probs=probs).sample()
        confidence = probs.gather(-1, x0.unsqueeze(-1)).squeeze(-1)
    except Exception:
        # Best-effort fallback to greedy if sampling is not possible.
        confidence, x0 = probs.max(dim=-1)
    return probs, confidence, x0, None, None
244
245
def is_valid_box_frame(
    probs,
    token_ids: Dict[str, int],
    start_thresh=0.6,
    end_thresh=0.2,
    topk=5,
):
    """Classify a 6-row probability frame as an empty, legal or illegal box.

    Args:
        probs: Per-position probabilities; rows 0..5 are read, indexed as
            ``probs[row, token_id]``.
        token_ids: Special-token id mapping.
        start_thresh: Minimum probability of the box-start token at row 0.
        end_thresh: Minimum combined probability of box-end / null / im_end
            at row 5 for a legal box.
        topk: Accepted for interface compatibility; not read.

    Returns:
        ``'empty_box'``, ``'legal_box'`` or ``'illegal_box'``.
    """
    # All ids are read up front (a missing key raises KeyError immediately).
    start_id, end_id, null_id, eos_id, none_id = (
        token_ids[name]
        for name in (
            'box_start_token_id',
            'box_end_token_id',
            'null_token_id',
            'im_end_token_id',
            'none_token_id',
        )
    )

    # Written as ``not (>=)`` so a NaN start probability is still illegal.
    if not (probs[0, start_id] >= start_thresh):
        return 'illegal_box'

    # <box> none </box> <null> <null> pattern.
    looks_empty = (
        probs[1, none_id] > 0.2
        and probs[2, end_id] > 0.2
        and probs[3, null_id] > 0.1
        and probs[4, null_id] > 0.1
    )
    if looks_empty:
        return 'empty_box'

    closing_ids = torch.tensor([end_id, null_id, eos_id], device=probs.device)
    closing_mass = probs[5, closing_ids].sum()
    return 'legal_box' if closing_mass >= end_thresh else 'illegal_box'
274
275
def decode_bbox_avg(
    logits,
    probs,
    token_ids: Dict[str, int],
    keep_k=5,
    start_thresh=0.7,
    end_thresh=0.2,
    generation_mode: str = 'hybrid',
    abnormal_prob_thresh: float = 0.9,
    abnormal_spread: int = 60,
):
    """
    Decode a 6-position box frame into token ids.

    The frame is first classified with ``is_valid_box_frame``. For a legal box,
    each coordinate position (rows 1-4) takes the most probable coordinate
    token found among its top-``keep_k`` candidates. Despite the function name,
    no averaging is performed.

    Args:
        logits: Logits of shape (6, vocab_size). Not read; kept for interface
            compatibility.
        probs: Probability distribution of shape (6, vocab_size)
        token_ids: Dictionary containing all token IDs
        keep_k: Number of top-k candidate tokens to keep at each position
        start_thresh: Confidence threshold for box start token
        end_thresh: Confidence threshold for box end token
        generation_mode: ``'hybrid'`` replaces "abnormal" coordinates with
            token id 0 (a non-coordinate id under the default token layout);
            ``'fast'`` always keeps the top valid coordinate.
        abnormal_prob_thresh: (hybrid) a coordinate is abnormal when its top
            valid probability is below this value...
        abnormal_spread: ...more than one valid candidate exists, and the valid
            candidate ids span more than this many token ids.

    Returns:
        A 1-D long tensor ``[box_start, c1, c2, c3, c4, box_end]`` for a legal
        box, ``[box_start, none, box_end, null, null, null]`` for an empty box,
        or None if decoding fails.

    Raises:
        ValueError: if a legal box reaches coordinate selection with a
            ``generation_mode`` other than 'hybrid' or 'fast' (previously an
            UnboundLocalError).
    """
    coord_start_token_id = token_ids['coord_start_token_id']
    coord_end_token_id = token_ids['coord_end_token_id']
    box_start_token_id = token_ids['box_start_token_id']
    box_end_token_id = token_ids['box_end_token_id']
    none_token_id = token_ids['none_token_id']
    null_token_id = token_ids['null_token_id']

    # Every tensor built here derives from ``probs``; use one device consistently.
    device = probs.device

    box_type = is_valid_box_frame(
        probs,
        token_ids,
        start_thresh=start_thresh,
        end_thresh=end_thresh,
        topk=keep_k
    )
    if box_type == 'empty_box':
        # Handle the <box>none</box> case first
        return torch.tensor([
            box_start_token_id,
            none_token_id,
            box_end_token_id,
            null_token_id,
            null_token_id,
            null_token_id,
        ], dtype=torch.long, device=device)
    if box_type == 'illegal_box':
        return None

    # Top-K for the 4 coordinate positions at once; topk sorts descending.
    pos_probs, pos_ids = torch.topk(probs[1:5], k=keep_k, dim=-1)  # [4, keep_k]
    is_coord = (pos_ids >= coord_start_token_id) & (pos_ids <= coord_end_token_id)
    if not is_coord.any(dim=-1).all():
        return None  # some position has no coordinate candidate: not a box

    # argmax over the 0/1 mask gives the first (= most probable) valid candidate.
    first_valid_idx = is_coord.long().argmax(dim=-1, keepdim=True)  # [4, 1]
    first_valid_probs = pos_probs.gather(-1, first_valid_idx).squeeze(-1)  # [4]
    first_valid_ids = pos_ids.gather(-1, first_valid_idx).squeeze(-1)  # [4]

    if generation_mode == 'hybrid':
        valid_counts = is_coord.sum(dim=-1)  # [4]
        # Neutralize non-coordinate candidates so they don't affect max/min.
        # Every row has at least one valid id (checked above), so the
        # difference below never involves the sentinels.
        id_limits = torch.iinfo(pos_ids.dtype)
        valid_max = pos_ids.masked_fill(~is_coord, id_limits.min).amax(dim=-1)
        valid_min = pos_ids.masked_fill(~is_coord, id_limits.max).amin(dim=-1)

        is_abnormal = (
            (first_valid_probs < abnormal_prob_thresh)
            & (valid_counts > 1)
            & ((valid_max - valid_min) > abnormal_spread)
        )
        # Normal positions keep the top valid coordinate; abnormal ones become 0.
        final_coords = first_valid_ids.masked_fill(is_abnormal, 0)
    elif generation_mode == 'fast':
        final_coords = first_valid_ids
    else:
        raise ValueError(
            f"Unknown generation_mode {generation_mode!r}; expected 'hybrid' or 'fast'"
        )

    start_t = torch.tensor([box_start_token_id], dtype=final_coords.dtype, device=device)
    end_t = torch.tensor([box_end_token_id], dtype=final_coords.dtype, device=device)

    return torch.cat([start_t, final_coords, end_t])
362
363
def decode_ref(
    logits,
    probs,
    token_ids: Dict[str, int],
    keep_k=5,
    start_thresh=0.6,
):
    """Decode a reference (free-text) frame that starts with ``<ref>``.

    Args:
        logits: Not read; kept for interface compatibility.
        probs: Per-position probabilities, shape [L, vocab_size].
        token_ids: Special-token id mapping; needs ``coord_start_token_id`` and
            ``coord_end_token_id``, and optionally ``ref_start_token_id``.
        keep_k: Number of top candidates inspected per position.
        start_thresh: Minimum probability of the ``<ref>`` token at position 0.

    Returns:
        A 1-D tensor ``[ref_start, t1, ..., t_{L-1}]``, where each ``t_i`` is
        the most probable non-coordinate token among that position's top-
        ``keep_k`` candidates. Returns None if ``ref_start_token_id`` is absent,
        the ``<ref>`` probability is below ``start_thresh``, or some position
        has only coordinate tokens among its candidates.
    """
    ref_start_token_id = token_ids.get('ref_start_token_id')
    coord_start_token_id = token_ids['coord_start_token_id']
    coord_end_token_id = token_ids['coord_end_token_id']
    device = probs.device

    # Without a <ref> id there is nothing to decode (indexing with None would
    # otherwise yield a whole row and an ambiguous tensor truth test).
    if ref_start_token_id is None:
        return None

    # 1. The first position must be <ref> with enough probability.
    if probs[0, ref_start_token_id] < start_thresh:
        return None

    # 2. Top-K candidate ids for every subsequent position (sorted by probability).
    pos_ids = torch.topk(probs[1:], k=keep_k, dim=-1).indices  # [L-1, keep_k]

    # 3. Valid candidates are the non-coordinate tokens.
    is_text = (pos_ids < coord_start_token_id) | (pos_ids > coord_end_token_id)
    if not is_text.any(dim=-1).all():
        return None

    # 4. argmax returns the first True index, i.e. the most probable valid token.
    first_valid_idx = is_text.long().argmax(dim=-1, keepdim=True)  # [L-1, 1]
    final_text_ids = pos_ids.gather(-1, first_valid_idx).squeeze(-1)  # [L-1]

    start_t = torch.tensor([ref_start_token_id], dtype=final_text_ids.dtype, device=device)
    return torch.cat([start_t, final_text_ids])
406
407
def handle_pattern(x0, token_ids: Dict[str, int], generation_mode: str = 'hybrid'):
    """
    Classify one decoded frame and choose which tokens to commit.

    Args:
        x0: Token ids of one frame (normally length 6). Either a tensor (or
            anything with ``.tolist()``) or a plain sequence of ints. Shorter
            frames are handled without IndexError.
        token_ids: Dictionary containing all token IDs
        generation_mode: 'hybrid' or 'fast'; only affects malformed boxes
            ('fast' keeps them as coord_box, 'hybrid' requests a switch to AR).

    Returns:
        dict with keys ``"type"`` (one of 'im_end', 'empty_box', 'coord_box',
        'point_box', 'error_box', 'ref_object'), ``"tokens"`` (list of ids to
        commit), ``"need_switch_to_ar"`` and ``"is_terminal"``.
    """
    null_token_id = token_ids['null_token_id']
    im_end_token_id = token_ids['im_end_token_id']
    box_start_token_id = token_ids['box_start_token_id']
    box_end_token_id = token_ids['box_end_token_id']
    none_token_id = token_ids['none_token_id']
    coord_start_token_id = token_ids['coord_start_token_id']
    coord_end_token_id = token_ids['coord_end_token_id']
    ref_end_token_id = token_ids['ref_end_token_id']

    tokens = x0.tolist() if hasattr(x0, 'tolist') else list(x0)

    def _frame(kind, committed, switch_to_ar=False, terminal=False):
        # Uniform result record shared by every branch.
        return {
            "type": kind,
            "tokens": committed,
            "need_switch_to_ar": switch_to_ar,
            "is_terminal": terminal,
        }

    head = tokens[0] if tokens else None

    # A null or im_end token in the first slot ends the turn; both commit im_end.
    if head == null_token_id or head == im_end_token_id:
        return _frame("im_end", [im_end_token_id], terminal=True)

    if tokens[:2] == [box_start_token_id, none_token_id]:
        return _frame("empty_box", [box_start_token_id, none_token_id, box_end_token_id])

    if head == box_start_token_id:
        # coord_ix = index of the first non-coordinate token after <box>.
        coord_ix = 1
        for tok in tokens[1:5]:
            if not coord_start_token_id <= tok <= coord_end_token_id:
                break
            coord_ix += 1

        # Standard 4-coordinate bbox: <box><x1><x2><y1><y2></box>
        # (slice comparisons avoid IndexError on short frames).
        if coord_ix == 5 and tokens[5:6] == [box_end_token_id]:
            return _frame("coord_box", tokens)
        # Two-coordinate pointing: <box><x><y></box>; truncate at box_end.
        if coord_ix == 3 and tokens[3:4] == [box_end_token_id]:
            return _frame("point_box", tokens[:4])
        if generation_mode == 'fast':
            # fast mode: treat as coord_box, stay in MTP
            return _frame("coord_box", tokens)
        # hybrid mode: keep the well-formed prefix and switch to AR
        return _frame("error_box", tokens[:coord_ix], switch_to_ar=True)

    # Reference / free-text frame: cut at the first null token and collapse a
    # doubled trailing ref_end token.
    if null_token_id in tokens:
        tokens = tokens[:tokens.index(null_token_id)]
    if len(tokens) >= 2 and tokens[-1] == tokens[-2] == ref_end_token_id:
        tokens = tokens[:-1]

    return _frame("ref_object", tokens)