generate_utils.py
| 1 | # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. |
| 2 | # |
| 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property |
| 4 | # and proprietary rights in and to this software, related documentation |
| 5 | # and any modifications thereto. Any use, reproduction, disclosure or |
| 6 | # distribution of this software and related documentation without an express |
| 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. |
| 8 | |
| 9 | import torch |
| 10 | import torch.nn.functional as F |
| 11 | import torch.distributions as dists |
| 12 | from typing import Dict, Optional |
| 13 | |
| 14 | |
def get_token_ids_from_config(config) -> Dict[str, int]:
    """Collect the special-token IDs used by the generation helpers.

    Each ID is looked up with ``getattr`` and falls back to a built-in default
    when the attribute is missing.

    Args:
        config: Configuration object (LocateAnythingConfig or similar). It may
            expose a nested ``text_config`` attribute.

    Returns:
        Dictionary mapping each token name to its integer ID.
    """
    # (name, default): the output key equals the attribute name on ``config``.
    main_fields = (
        ('box_start_token_id', 151668),
        ('box_end_token_id', 151669),
        ('coord_start_token_id', 151677),
        ('coord_end_token_id', 152677),
        ('ref_start_token_id', 151672),
        ('ref_end_token_id', 151673),
        ('none_token_id', 4064),
    )
    # (output key, attribute on text_config, default): two keys are renamed.
    text_fields = (
        ('null_token_id', 'null_token_id', 152678),
        ('im_end_token_id', 'eos_token_id', 151645),
        ('switch_token_id', 'switch_token_id', 152679),
        ('default_mask_token_id', 'text_mask_token_id', 151676),
    )

    token_ids = {
        name: getattr(config, name, default) for name, default in main_fields
    }

    text_config = getattr(config, 'text_config', None)
    for key, attr, default in text_fields:
        if text_config is None:
            # No text config at all: use the defaults directly.
            token_ids[key] = default
        else:
            token_ids[key] = getattr(text_config, attr, default)

    return token_ids
| 49 | |
| 50 | |
def top_p_logits(
    logits: torch.Tensor,
    top_p: Optional[float] = None
) -> torch.Tensor:
    """Mask out logits that fall outside the top-p (nucleus) set.

    Tokens are sorted by logit along the last dimension; the smallest prefix
    whose cumulative softmax probability exceeds ``top_p`` is kept (the token
    that crosses the threshold is included) and every other position is set to
    the minimum finite value of the logits dtype.

    Args:
        logits: Floating-point tensor whose last dimension is the vocabulary.
        top_p: Cumulative-probability threshold. ``None`` (the default)
            disables filtering and returns ``logits`` unchanged.

    Returns:
        A tensor of the same shape as ``logits``; the input is not modified.
    """
    # The declared default used to crash on ``cumulative_probs > None``.
    if top_p is None:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right by one so the first token that crosses the threshold is kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    # The most likely token is always kept.
    sorted_indices_to_remove[..., 0] = False

    # Map the removal flags from sorted order back to vocabulary order.
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)
| 66 | |
| 67 | |
def top_k_logits(
    logits: torch.Tensor,
    top_k: Optional[int] = None
) -> torch.Tensor:
    """Mask out every logit that is smaller than the k-th largest one.

    Args:
        logits: Floating-point tensor whose last dimension is the vocabulary.
        top_k: Number of highest logits to keep per row. Values larger than
            the vocabulary size are clamped. ``None`` or a non-positive value
            disables filtering and returns ``logits`` unchanged.

    Returns:
        A tensor of the same shape as ``logits`` in which filtered positions
        hold the minimum finite value of the logits dtype. Entries tied with
        the k-th largest value are all kept.
    """
    # ``None`` used to crash in ``min`` and ``0`` crashed when indexing the
    # empty ``topk`` result; both now mean "no filtering".
    if top_k is None or top_k <= 0:
        return logits

    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k == 0:
        # Empty vocabulary dimension: nothing to filter.
        return logits

    # Smallest value that is still inside the top-k, broadcast over the row.
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_largest, torch.finfo(logits.dtype).min)
| 77 | |
| 78 | |
def apply_repetition_penalty(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    repetition_penalty: float = 1.0
) -> torch.Tensor:
    """
    Penalize the logits of tokens that already occur in ``input_ids``.

    Args:
        logits: Shape [batch_size, seq_len, vocab_size] or [batch_size, vocab_size]
        input_ids: Previously generated token ids, shape [batch_size, seq_len]
        repetition_penalty: Penalty factor. > 1.0 penalizes repetition, < 1.0 encourages it.

    Returns:
        Logits of the same shape as the input. When the penalty is exactly 1.0
        the input tensor itself is returned.
    """
    if repetition_penalty == 1.0:
        return logits

    # Work on a [B, L, V] view; remember whether to drop the extra axis again.
    was_2d = logits.dim() == 2
    if was_2d:
        logits = logits.unsqueeze(1)

    batch_size, seq_len, vocab_size = logits.shape

    # seen[b, v] is True when token v occurs in input_ids[b].
    seen = torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=logits.device)
    for row in range(batch_size):
        ids = input_ids[row].unique()
        # Drop IDs outside the vocabulary so the index assignment cannot go out of bounds.
        in_vocab = ids[(ids >= 0) & (ids < vocab_size)]
        if in_vocab.numel() > 0:
            seen[row, in_vocab] = True

    # Every position of a sequence shares that sequence's set of seen tokens.
    seen = seen.unsqueeze(1).expand(-1, seq_len, -1)

    # Positive logits shrink by division; all others are multiplied.
    penalized = torch.where(
        logits > 0,
        logits / repetition_penalty,
        logits * repetition_penalty,
    )
    logits = torch.where(seen, penalized, logits)

    return logits.squeeze(1) if was_2d else logits
| 133 | |
| 134 | |
def sample_tokens(
    logits: torch.Tensor,
    generated: torch.Tensor,
    token_ids: Dict[str, int],
    **generate_kwargs,
):
    """Pick one token per position and, for multi-position input, decode a pattern.

    Args:
        logits: Tensor of shape [batch_size, seq_len, vocab_size].
        generated: Previously generated token ids; only used for the
            repetition penalty.
        token_ids: Dictionary containing all token IDs; forwarded to
            ``decode_bbox_avg`` and ``decode_ref``.
        **generate_kwargs: Recognised keys:
            ``repetition_penalty`` (default 1.0),
            ``temperature`` (default 0; ``None`` or a value <= 0 means greedy),
            ``top_p`` / ``top_k`` (only applied when sampling),
            ``keep_k_avg`` (default 4, passed as ``keep_k`` to ``decode_bbox_avg``),
            ``generation_mode`` (default ``'hybrid'``).

    Returns:
        Tuple ``(probs, confidence, x0, box_avg)``:
        ``probs`` is the softmax of the processed logits, ``x0`` the chosen
        token per position and ``confidence`` its probability. ``box_avg`` is
        ``None`` when ``seq_len == 1``; otherwise it stacks, per batch element,
        the result of ``decode_bbox_avg``, else of ``decode_ref``, else a
        single-zero fallback.
    """
    batch_size, seq_len, _ = logits.shape

    repetition_penalty = generate_kwargs.get('repetition_penalty', 1.0)
    temperature = generate_kwargs.get('temperature', 0)
    top_p = generate_kwargs.get('top_p', None)
    top_k = generate_kwargs.get('top_k', None)
    # An explicit ``temperature=None`` is treated like the default (greedy).
    do_sample = temperature is not None and temperature > 0

    # Apply repetition penalty based on all previously generated tokens
    if repetition_penalty != 1.0:
        logits = apply_repetition_penalty(logits, generated, repetition_penalty)

    if do_sample:
        logits = logits / temperature
        if top_p is not None and top_p < 1:
            logits = top_p_logits(logits, top_p)
        if top_k is not None:
            logits = top_k_logits(logits, top_k)

    probs = torch.softmax(logits, dim=-1)

    if do_sample:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Deliberate best effort: if sampling fails, fall back to greedy.
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if seq_len == 1:
        return probs, confidence, x0, None

    # Loop-invariant settings for the per-sample pattern decoding.
    keep_k = generate_kwargs.get('keep_k_avg', 4)
    generation_mode = generate_kwargs.get('generation_mode', 'hybrid')
    fallback_box = torch.zeros(1, dtype=x0.dtype, device=x0.device)

    boxes = []
    for b in range(batch_size):
        decoded = decode_bbox_avg(
            logits[b], probs[b], token_ids, keep_k=keep_k,
            generation_mode=generation_mode,
        )
        if decoded is None:
            out_ref = decode_ref(logits[b], probs[b], token_ids)
            if out_ref is None:
                decoded = fallback_box
            else:
                # as_tensor avoids the copy-construct warning that
                # torch.tensor() emits when handed an existing tensor.
                decoded = torch.as_tensor(out_ref, dtype=x0.dtype, device=x0.device)
        boxes.append(decoded)

    # The fallback has length 1 while decoded patterns are longer, and
    # torch.stack needs equal sizes. Right-pad with zeros only when lengths
    # differ, so uniform batches produce exactly the same result as before.
    # NOTE(review): zero padding mirrors the zero fallback value; confirm the
    # caller treats it the same way.
    lengths = {box.size(0) for box in boxes}
    if len(lengths) > 1:
        target_len = max(lengths)
        boxes = [
            torch.cat([box, box.new_zeros(target_len - box.size(0))])
            for box in boxes
        ]

    box_avg = torch.stack(boxes)

    return probs, confidence, x0, box_avg
| 193 | |
| 194 | |
def sample_tokens_ar(
    logits: torch.Tensor,
    generated: torch.Tensor,
    token_ids: Dict[str, int],
    **generate_kwargs,
):
    """
    Single-step sampling for autoregressive (AR) decoding.

    Args:
        logits: [batch_size, vocab_size] or [batch_size, 1, vocab_size]
        generated: [batch_size, seq_len]
        token_ids: Dictionary containing all token IDs (not read by this function)
        **generate_kwargs: ``repetition_penalty`` (default 1.0), ``temperature``
            (default 0, i.e. greedy), ``top_p`` and ``top_k`` (sampling only)

    Returns:
        A five-element tuple ``(probs, confidence, x0, None, None)`` where
        ``probs`` is [B, 1, V] and ``confidence`` / ``x0`` are [B, 1].
    """
    # Promote 2D input so the shared penalty / filtering helpers see [B, 1, V].
    if logits.dim() == 2:
        logits = logits.unsqueeze(1)
    _, seq_len, _ = logits.shape
    assert seq_len == 1, "sample_tokens_ar only supports single-step AR sampling (seq_len == 1)"

    penalty = generate_kwargs.get('repetition_penalty', 1.0)
    temperature = generate_kwargs.get('temperature', 0)
    nucleus_p = generate_kwargs.get('top_p', None)
    k_limit = generate_kwargs.get('top_k', None)

    # The penalty only looks at tokens generated so far.
    if penalty != 1.0:
        logits = apply_repetition_penalty(logits, generated, penalty)

    if temperature > 0:
        logits = logits / temperature
        if nucleus_p is not None and nucleus_p < 1:
            logits = top_p_logits(logits, nucleus_p)
        if k_limit is not None:
            logits = top_k_logits(logits, k_limit)

    probs = torch.softmax(logits, dim=-1)

    x0 = None
    confidence = None
    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Any failure while sampling falls through to the greedy choice.
            x0 = None

    if x0 is None:
        # Greedy: the most probable token and its probability.
        confidence, x0 = probs.max(dim=-1)

    # The two trailing None values are part of the return contract.
    return probs, confidence, x0, None, None
| 244 | |
| 245 | |
def is_valid_box_frame(
    probs,
    token_ids: Dict[str, int],
    start_thresh=0.6,
    end_thresh=0.2,
    topk=5,
):
    """Classify a frame of per-position probabilities as a box pattern.

    Args:
        probs: 2-D tensor indexed as ``probs[position, token_id]``. Positions
            0..5 are inspected.
        token_ids: Dictionary containing all token IDs. Reads
            ``box_start_token_id``, ``box_end_token_id``, ``null_token_id``,
            ``im_end_token_id`` and ``none_token_id``.
        start_thresh: Minimum probability of the box-start token at position 0.
        end_thresh: Minimum summed probability of the box-end, null and im_end
            tokens at position 5.
        topk: Unused; kept so existing callers keep working.

    Returns:
        ``'empty_box'`` when the frame looks like box-start, none, box-end
        followed by two nulls; ``'legal_box'`` when the start and end checks
        pass; ``'illegal_box'`` otherwise, including frames that are too short
        to contain the positions being tested.
    """
    box_start_token_id = token_ids['box_start_token_id']
    box_end_token_id = token_ids['box_end_token_id']
    null_token_id = token_ids['null_token_id']
    im_end_token_id = token_ids['im_end_token_id']
    none_token_id = token_ids['none_token_id']  # none

    num_positions = probs.size(0)
    # Previously a frame with fewer than six rows raised IndexError below.
    if num_positions == 0:
        return 'illegal_box'

    p_start = probs[0, box_start_token_id]
    if not (p_start >= start_thresh):
        return 'illegal_box'

    # <box>none</box> followed by two null tokens (needs positions 1..4).
    if (num_positions >= 5 and
            probs[1, none_token_id] > 0.2 and
            probs[2, box_end_token_id] > 0.2 and
            probs[3, null_token_id] > 0.1 and
            probs[4, null_token_id] > 0.1):
        return 'empty_box'

    # A coordinate box must be closed at position 5 by an end-like token.
    if num_positions >= 6:
        end_target_ids = torch.tensor(
            [box_end_token_id, null_token_id, im_end_token_id], device=probs.device
        )
        end_score = probs[5, end_target_ids].sum()
        if end_score >= end_thresh:
            return 'legal_box'

    return 'illegal_box'
| 274 | |
| 275 | |
def decode_bbox_avg(
    logits,
    probs,
    token_ids: Dict[str, int],
    keep_k=5,
    start_thresh=0.7,
    end_thresh=0.2,
    generation_mode: str = 'hybrid',
):
    """
    Decode a bounding-box token frame from per-position probabilities.

    The frame is first classified with ``is_valid_box_frame``. For a legal box,
    each of positions 1-4 takes the most probable coordinate token found among
    its top-k candidates (no averaging is performed despite the function name).

    Args:
        logits: Logits of shape (6, vocab_size). Not read; kept for interface
            compatibility.
        probs: Probability distribution of shape (6, vocab_size)
        token_ids: Dictionary containing all token IDs
        keep_k: Number of top-k candidate tokens to keep at each position
            (clamped to the vocabulary size)
        start_thresh: Confidence threshold for box start token
        end_thresh: Confidence threshold for box end token
        generation_mode: ``'fast'`` always takes the best coordinate token.
            Any other value is handled as ``'hybrid'``: a position whose
            candidates disagree is replaced with token id 0.

    Returns:
        1-D long tensor ``[box_start, x1, x2, y1, y2, box_end]``; for an empty
        box ``[box_start, none, box_end, null, null, null]``; or None if the
        frame is not a box or some position has no coordinate candidate.
    """
    coord_start_token_id = token_ids['coord_start_token_id']
    coord_end_token_id = token_ids['coord_end_token_id']
    box_start_token_id = token_ids['box_start_token_id']
    box_end_token_id = token_ids['box_end_token_id']
    none_token_id = token_ids['none_token_id']

    # Everything below is derived from ``probs``, so build new tensors there.
    device = probs.device

    box_type = is_valid_box_frame(
        probs,
        token_ids,
        start_thresh=start_thresh,
        end_thresh=end_thresh,
        topk=keep_k
    )
    if box_type == 'empty_box':
        # Handle the <box>none</box> case first
        null_token_id = token_ids['null_token_id']
        return torch.tensor([
            box_start_token_id,
            none_token_id,
            box_end_token_id,
            null_token_id,
            null_token_id,
            null_token_id,
        ], dtype=torch.long, device=device)
    if box_type == 'illegal_box':
        return None

    # torch.topk raises when k exceeds the size of the dimension.
    keep_k = min(keep_k, probs.size(-1))

    # Top-k candidates for the four coordinate positions at once: [4, keep_k]
    pos_probs, pos_ids = torch.topk(probs[1:5], k=keep_k, dim=-1)
    mask = (pos_ids >= coord_start_token_id) & (pos_ids <= coord_end_token_id)
    if not mask.any(dim=-1).all():
        return None  # some position has no coordinate candidate: not a box

    # topk output is sorted by probability, so the first True in each row of
    # ``mask`` marks the most probable coordinate token.
    first_valid_idx = mask.long().argmax(dim=-1, keepdim=True)  # [4, 1]
    first_valid_probs = pos_probs.gather(-1, first_valid_idx).squeeze(-1)  # [4]
    first_valid_ids = pos_ids.gather(-1, first_valid_idx).squeeze(-1)  # [4]

    if generation_mode == 'fast':
        final_coords = first_valid_ids
    else:
        # 'hybrid'. Unknown modes also land here (as they do in
        # handle_pattern) instead of leaving final_coords unbound.
        valid_counts = mask.sum(dim=-1)  # [4]
        # Fill non-coordinate candidates with extreme values so they cannot
        # win the max / min below.
        LARGE_NUM, SMALL_NUM = 999999, -999999
        valid_max = pos_ids.masked_fill(~mask, SMALL_NUM).max(dim=-1)[0]
        valid_min = pos_ids.masked_fill(~mask, LARGE_NUM).min(dim=-1)[0]

        # Abnormal: the best candidate is not confident and the coordinate
        # candidates are spread over more than 60 token ids.
        is_abnormal = (
            (first_valid_probs < 0.9)
            & (valid_counts > 1)
            & ((valid_max - valid_min) > 60)
        )

        # Normal positions take top-1 (first_valid_ids); abnormal positions are replaced with 0
        final_coords = first_valid_ids.masked_fill(is_abnormal, 0)

    start_t = torch.tensor([box_start_token_id], dtype=final_coords.dtype, device=device)
    end_t = torch.tensor([box_end_token_id], dtype=final_coords.dtype, device=device)

    return torch.cat([start_t, final_coords, end_t])
| 362 | |
| 363 | |
def decode_ref(
    logits,
    probs,
    token_ids: Dict[str, int],
    keep_k=5,
    start_thresh=0.6,
):
    """Decode a reference-text frame that starts with the ref-start token.

    Args:
        logits: Not read; kept for interface compatibility.
        probs: 2-D tensor indexed as ``probs[position, token_id]``.
        token_ids: Dictionary containing all token IDs. Reads
            ``ref_start_token_id`` (optional), ``coord_start_token_id`` and
            ``coord_end_token_id``.
        keep_k: Number of top-k candidates considered at each position
            (clamped to the vocabulary size).
        start_thresh: Minimum probability of the ref-start token at position 0.

    Returns:
        1-D tensor ``[ref_start, t1, ..., t(L-1)]`` where each ``t`` is the most
        probable non-coordinate token among that position's top-k, or None when
        the ref-start token is not configured, position 0 fails the threshold,
        or some position has only coordinate tokens in its top-k.
    """
    ref_start_token_id = token_ids.get('ref_start_token_id')
    # Without a configured ref-start token nothing can be decoded. Indexing
    # with None would add an axis and make the comparison below ambiguous.
    if ref_start_token_id is None or probs.size(0) == 0:
        return None

    coord_start_token_id = token_ids['coord_start_token_id']
    coord_end_token_id = token_ids['coord_end_token_id']

    # 1. Position 0 must be <ref> with probability of at least start_thresh
    if probs[0, ref_start_token_id] < start_thresh:
        return None

    # 2. Top-K token IDs for all subsequent positions: [L-1, keep_k]
    keep_k = min(keep_k, probs.size(-1))  # torch.topk raises when k > vocab size
    _, pos_ids = torch.topk(probs[1:], k=keep_k, dim=-1)

    # 3. Valid candidates are the non-coordinate tokens
    is_valid = ~((pos_ids >= coord_start_token_id) & (pos_ids <= coord_end_token_id))

    # Every position needs at least one non-coordinate candidate in its Top-K
    if not is_valid.any(dim=-1).all():
        return None

    # 4. topk output is sorted by descending probability, so argmax over the
    # 0/1 mask yields the index of the most probable valid token.
    first_valid_idx = is_valid.long().argmax(dim=-1, keepdim=True)  # [L-1, 1]
    final_text_ids = pos_ids.gather(-1, first_valid_idx).squeeze(-1)  # [L-1]

    start_t = torch.tensor(
        [ref_start_token_id], dtype=final_text_ids.dtype, device=probs.device
    )

    return torch.cat([start_t, final_text_ids])
| 406 | |
| 407 | |
def handle_pattern(x0, token_ids: Dict[str, int], generation_mode: str = 'hybrid'):
    """
    Classify a decoded token frame and return the tokens to emit.

    Args:
        x0: Token IDs of one frame (normally length 6). Either a tensor /
            array exposing ``tolist()`` or a plain sequence of ints.
        token_ids: Dictionary containing all token IDs
        generation_mode: ``'fast'`` keeps a malformed box as ``coord_box``;
            any other value (``'hybrid'``) reports it as ``error_box``.

    Returns:
        Dict with keys ``"type"``, ``"tokens"``, ``"need_switch_to_ar"`` and
        ``"is_terminal"``. ``"type"`` is one of ``im_end``, ``empty_box``,
        ``coord_box``, ``point_box``, ``error_box`` or ``ref_object``.
    """
    null_token_id = token_ids['null_token_id']
    im_end_token_id = token_ids['im_end_token_id']
    box_start_token_id = token_ids['box_start_token_id']
    box_end_token_id = token_ids['box_end_token_id']
    none_token_id = token_ids['none_token_id']
    coord_start_token_id = token_ids['coord_start_token_id']
    coord_end_token_id = token_ids['coord_end_token_id']
    ref_end_token_id = token_ids['ref_end_token_id']

    # Accept tensors as well as plain lists/tuples of ints.
    x0 = x0.tolist() if hasattr(x0, 'tolist') else list(x0)

    def _result(kind, tokens, need_switch_to_ar=False, is_terminal=False):
        # All outcomes share the same four-key shape.
        return {
            "type": kind,
            "tokens": tokens,
            "need_switch_to_ar": need_switch_to_ar,
            "is_terminal": is_terminal,
        }

    # A leading null or im_end token both end generation.
    if x0[0] in (null_token_id, im_end_token_id):
        return _result("im_end", [im_end_token_id], is_terminal=True)

    if x0[:2] == [box_start_token_id, none_token_id]:
        return _result("empty_box", [box_start_token_id, none_token_id, box_end_token_id])

    if x0[0] == box_start_token_id:
        # coord_ix ends up as the index of the first token after the run of
        # coordinate tokens that follows box_start (at most four are counted).
        coord_ix = 1
        for coord in x0[1:5]:
            if not (coord_start_token_id <= coord <= coord_end_token_id):
                break
            coord_ix += 1

        # Standard 4-coordinate bbox: <box><x1><x2><y1><y2></box>
        # The length check avoids an IndexError on a 5-token frame.
        if coord_ix == 5 and len(x0) > 5 and x0[5] == box_end_token_id:
            return _result("coord_box", x0)

        # Two-coordinate pointing: <box><x><y></box>
        # Anything after box_end is not part of the pattern and is cut off.
        if coord_ix == 3 and x0[3] == box_end_token_id:
            return _result("point_box", x0[:4])

        if generation_mode == 'fast':
            # fast mode: treat as coord_box, stay in MTP
            return _result("coord_box", x0)

        # hybrid mode: error_box, keep only the valid prefix and switch to AR
        return _result("error_box", x0[:coord_ix], need_switch_to_ar=True)

    # Anything else is reference text: cut at the first null token.
    if null_token_id in x0:
        x0 = x0[:x0.index(null_token_id)]

    # Collapse a doubled trailing ref-end token into one.
    if len(x0) >= 2 and x0[-1] == x0[-2] == ref_end_token_id:
        x0 = x0[:-1]

    return _result("ref_object", x0)