batch_utils/engine_hybrid.py
51.2 KB · 1358 lines · python Raw
1 """Batched hybrid-mode generation for LocateAnything-3B.
2
3 This module keeps the stock hybrid state machine:
4
5 MTP -> error_box -> AR
6 AR -> box_end_ar -> MTP
7
8 Rows in a batch may be in different modes. The decode loop therefore stores
9 per-row KV caches, packs rows with the same mode for one forward call, then
10 unpacks the clean KV back per row.
11 """
12 import copy
13 import importlib
14 import os
15
16 import torch
17
18
19 from .hybrid_runtime import (
20 ATTN_MODE,
21 AR_BATCH_SAN,
22 BATCH_SAN,
23 DEV,
24 N_FUTURE,
25 _encode_images,
26 _helpers,
27 _pad_generated,
28 _set_llm_mode,
29 _tokenize,
30 _tokenize_cached_image,
31 build_magi_scheduler_ranges,
32 language_model_forward,
33 load,
34 sample_next_tokens_batched,
35 sample_tokens_batched,
36 )
37
38
# Default generation settings; generate_batch_hybrid uses these as its keyword
# defaults. NOTE(review): the README_ prefix presumably means they mirror the
# model README's recommended values — confirm.
README_MAX_NEW_TOKENS = 2048
README_TEMPERATURE = 0.7
README_TOP_P = 0.9
README_REPETITION_PENALTY = 1.1

# Deep copy of the stats dict last passed to _set_last_hybrid_stats; returned
# (again deep-copied) by get_last_hybrid_stats. None until something is recorded.
_LAST_HYBRID_STATS = None
45
46
def _row_len(kv):
    """Return the number of cached tokens in a per-row KV cache.

    Reads dim 2 of the first layer's key tensor; dim 2 is the token axis this
    module pads and concatenates along.
    """
    layer0_keys = kv[0][0]
    return layer0_keys.shape[2]
49
50
def _pack_stock_kv_rows(kv_rows, rows, dev):
    """Left-pad per-row real-token KV caches for stock-style decoding.

    Args:
        kv_rows: per-row KV caches indexed by row id. Each entry is ``None``
            (nothing cached yet) or a sequence of per-layer ``(k, v)`` pairs
            whose dim 2 is the token axis and dim 0 has size 1.
        rows: row ids to pack, in output batch order.
        dev: device for the returned key-valid mask.

    Returns:
        ``(packed, kvalid, lengths, kmax)``:

        * ``packed``: a tuple of per-layer ``(k, v)`` tensors with batch size
          ``len(rows)`` and token length ``kmax``, or ``None`` when no row has
          cached tokens;
        * ``kvalid``: a ``(len(rows), kmax)`` long mask with 1 on real cached
          tokens, which are right-aligned;
        * ``lengths``: the per-row cached lengths (0 for ``None`` rows);
        * ``kmax``: their maximum.
    """
    lengths = [0 if kv_rows[r] is None else _row_len(kv_rows[r]) for r in rows]
    kmax = max(lengths, default=0)
    if kmax == 0:
        return None, torch.zeros((len(rows), 0), dtype=torch.long, device=dev), lengths, 0

    def _left_pad(t, pad_len):
        # Pad from the tensor's OWN head count / head dim so keys and values
        # are padded independently (they need not share a head dim).
        zeros = t.new_zeros((1, t.shape[1], pad_len, t.shape[3]))
        return torch.cat([zeros, t], dim=2)

    # Any non-empty row serves as the shape/dtype/device template for rows
    # that have nothing cached yet.
    ref = next(kv_rows[r] for r in rows if kv_rows[r] is not None)
    packed = []
    for layer in range(len(ref)):
        ref_k, ref_v = ref[layer]
        ks, vs = [], []
        for r, length in zip(rows, lengths):
            if length == 0:
                k = ref_k.new_zeros((1, ref_k.shape[1], kmax, ref_k.shape[3]))
                v = ref_v.new_zeros((1, ref_v.shape[1], kmax, ref_v.shape[3]))
            else:
                k, v = kv_rows[r][layer]
                if length < kmax:
                    k = _left_pad(k, kmax - length)
                    v = _left_pad(v, kmax - length)
            ks.append(k)
            vs.append(v)
        packed.append((torch.cat(ks, dim=0), torch.cat(vs, dim=0)))

    # Real tokens sit at the right end of each padded row.
    kvalid = torch.zeros((len(rows), kmax), dtype=torch.long, device=dev)
    for i, length in enumerate(lengths):
        if length:
            kvalid[i, kmax - length:] = 1
    return tuple(packed), kvalid, lengths, kmax
82
83
def _unpack_stock_after_forward(out_kv, local_row, old_len, uncached_len, kmax, umax):
    """Extract one row's clean KV from a packed forward's cache.

    The packed layout along the token axis (dim 2) is ``kmax`` left-padded
    cached positions followed by ``umax`` left-padded new positions (and
    possibly more, e.g. a window, which is dropped). For ``local_row`` this
    keeps the last ``old_len`` cached positions and the last ``uncached_len``
    new positions, concatenated and made contiguous. The result is a tuple of
    per-layer ``(k, v)`` pairs with batch size 1.
    """
    spans = []
    if old_len:
        spans.append(slice(kmax - old_len, kmax))
    if uncached_len:
        spans.append(slice(kmax + umax - uncached_len, kmax + umax))
    batch_sel = slice(local_row, local_row + 1)

    def _gather(tensor):
        return torch.cat([tensor[batch_sel, :, span, :] for span in spans], dim=2).contiguous()

    return tuple((_gather(k), _gather(v)) for k, v in out_kv)
100
101
def _mk_generate_kwargs(temperature, top_p, top_k, repetition_penalty, row_temp=None):
    """Build ``model.generate``-style kwargs for hybrid mode.

    ``row_temp`` (when not None) overrides ``temperature``. Sampling keys
    (``temperature`` plus any non-None ``top_p``/``top_k``) are only added for
    a positive temperature; otherwise just the repetition penalty and
    generation mode are returned.
    """
    temp = row_temp if row_temp is not None else temperature
    kwargs = {"repetition_penalty": repetition_penalty, "generation_mode": "hybrid"}
    if not (temp and temp > 0):
        return kwargs
    kwargs["temperature"] = temp
    if top_p is not None:
        kwargs["top_p"] = top_p
    if top_k is not None:
        kwargs["top_k"] = top_k
    return kwargs
112
113
def _classify_ar_token(token_val, tids):
    """Map one AR-sampled token id to the hybrid state-machine event.

    ``box_end_token_id`` -> ``"box_end_ar"``; a coordinate token (inclusive
    ``coord_start_token_id``..``coord_end_token_id`` range) or
    ``none_token_id`` -> ``"coord_ar"``; anything else -> ``"im_end"``.
    """
    if token_val == tids["box_end_token_id"]:
        return "box_end_ar"
    is_coord = tids["coord_start_token_id"] <= token_val <= tids["coord_end_token_id"]
    if is_coord or token_val == tids["none_token_id"]:
        return "coord_ar"
    return "im_end"
122
123
def _env_flag(name, default=False):
    """Read a boolean environment flag.

    Returns ``default`` when the variable is unset. Otherwise returns False
    for ``0``/``false``/``no``/``off``/empty (case-insensitive) and True for
    anything else.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    falsy = ("0", "false", "no", "off", "")
    return raw.lower() not in falsy
129
130
def _env_int(name, default):
    """Read an integer environment variable; unset or empty yields ``default``.

    Non-numeric values propagate ``int()``'s ``ValueError``.
    """
    raw = os.environ.get(name)
    return default if raw in (None, "") else int(raw)
136
137
def _kv_pack_token_budget():
    """Token budget for one dense KV pack from ``LA_FLASH_KV_PACK_TOKEN_BUDGET``.

    Negative values are clamped to 0; unset/empty gives 0.
    """
    budget = _env_int("LA_FLASH_KV_PACK_TOKEN_BUDGET", 0)
    return budget if budget > 0 else 0
140
141
def _debug_enabled(debug):
    """Resolve the debug switch: an explicit value wins, else ``LA_FLASH_DEBUG``."""
    if debug is not None:
        return bool(debug)
    return _env_flag("LA_FLASH_DEBUG", False)
144
145
def _new_hybrid_stats(total_rows, scheduler, group_size, hold_max_steps, adaptive_hold_mtp_max=0):
    """Create a fresh scheduler/forward statistics dict for one hybrid run.

    Configuration fields are copied (int-converted) from the arguments. Every
    counter starts at 0, and ``group_sizes`` and both row histograms start
    empty. ``prompt_prefill_mode`` is resolved from
    ``LA_FLASH_HYBRID_PREFILL`` via ``_hybrid_prefill_mode``, which raises
    ``ValueError`` on an unknown value.
    """
    return dict(
        scheduler=scheduler,
        requested_group_size=int(group_size or 0),
        hold_max_steps=int(hold_max_steps),
        adaptive_hold_mtp_max=int(adaptive_hold_mtp_max),
        input_batches=1,
        input_rows=int(total_rows),
        groups=0,
        group_sizes=[],
        decode_loops=0,
        mixed_mode_cycles=0,
        eager_mtp_then_ar_cycles=0,
        ar_first_cycles=0,
        pipeline_ar_after_mtp_cycles=0,
        adaptive_hold_cycles=0,
        adaptive_ar_first_cycles=0,
        hold_ar_steps=0,
        hold_ar_held_mtp_rows=0,
        hold_ar_limit_mtp_forwards=0,
        mtp_forwards=0,
        ar_forwards=0,
        mtp_forward_rows=0,
        ar_forward_rows=0,
        mtp_forward_query_tokens=0,
        ar_forward_query_tokens=0,
        max_mtp_forward_rows=0,
        max_ar_forward_rows=0,
        mtp_max_uncached_len=0,
        ar_max_uncached_len=0,
        mtp_forward_row_hist={},
        ar_forward_row_hist={},
        prompt_prefill_mode=_hybrid_prefill_mode(),
        prompt_prefill_forwards=0,
        prompt_prefill_forward_rows=0,
        prompt_prefill_forward_query_tokens=0,
        prompt_prefill_real_tokens=0,
        prompt_prefill_shared_groups=0,
        prompt_prefill_shared_rows=0,
        prompt_prefill_shared_saved_tokens=0,
        kv_bucket_splits=0,
        kv_bucket_groups=0,
        kv_bucket_max_packed_tokens=0,
    )
190
191
def _set_last_hybrid_stats(stats):
    """Store a deep copy of ``stats`` (or None) as the module's last-run stats."""
    global _LAST_HYBRID_STATS
    if stats is None:
        _LAST_HYBRID_STATS = None
    else:
        _LAST_HYBRID_STATS = copy.deepcopy(stats)
195
196
def get_last_hybrid_stats():
    """Return scheduler/forward statistics from the most recent hybrid batch.

    The result is a deep copy, so callers may mutate it freely; it is None
    when nothing has been recorded.
    """
    snapshot = _LAST_HYBRID_STATS
    return copy.deepcopy(snapshot)
200
201
def _record_group_stats(stats, bsz):
    """Count one scheduled group of ``bsz`` rows (no-op when ``stats`` is None)."""
    if stats is not None:
        stats["groups"] += 1
        stats["group_sizes"].append(int(bsz))
207
208
def _bump_hist(hist, val):
    """Increment the histogram count stored under the key ``str(int(val))``."""
    bucket = str(int(val))
    hist[bucket] = int(hist.setdefault(bucket, 0)) + 1
212
213
def _record_forward_stats(stats, kind, rows, q_len, uncached_lens):
    """Accumulate counters for one packed MTP or AR forward.

    Does nothing when ``stats`` is None. ``kind == "mtp"`` updates the
    ``mtp`` counters; any other value updates the ``ar`` counters.
    ``q_len`` is the per-row query length, so query tokens grow by
    ``len(rows) * q_len``. ``uncached_lens`` feeds the running maximum of
    uncached tokens per row, and the row-count histogram is bumped once.
    """
    if stats is None:
        return
    tag = "mtp" if kind == "mtp" else "ar"
    n_rows = int(len(rows))
    per_row = int(q_len)
    stats[tag + "_forwards"] += 1
    stats[tag + "_forward_rows"] += n_rows
    stats[tag + "_forward_query_tokens"] += n_rows * per_row
    peak_rows_key = "max_" + tag + "_forward_rows"
    stats[peak_rows_key] = max(stats[peak_rows_key], n_rows)
    peak_len_key = tag + "_max_uncached_len"
    stats[peak_len_key] = max(
        stats[peak_len_key],
        max((int(x) for x in uncached_lens), default=0),
    )
    _bump_hist(stats[tag + "_forward_row_hist"], n_rows)
229
230
def _record_prefill_stats(stats, rows, q_len, real_tokens, shared_groups=0, shared_rows=0, saved_tokens=0):
    """Accumulate counters for one prompt-prefill forward.

    Does nothing when ``stats`` is None. ``rows`` is the forward's batch
    size and ``q_len`` its padded length, so query tokens grow by
    ``rows * q_len``. The ``shared_*``/``saved_tokens`` fields are only
    non-zero for shared-prefix prefill.
    """
    if stats is None:
        return
    batch = int(rows)
    increments = (
        ("prompt_prefill_forwards", 1),
        ("prompt_prefill_forward_rows", batch),
        ("prompt_prefill_forward_query_tokens", batch * int(q_len)),
        ("prompt_prefill_real_tokens", int(real_tokens)),
        ("prompt_prefill_shared_groups", int(shared_groups)),
        ("prompt_prefill_shared_rows", int(shared_rows)),
        ("prompt_prefill_shared_saved_tokens", int(saved_tokens)),
    )
    for key, amount in increments:
        stats[key] += amount
242
243
def _split_rows_by_kv_budget(rows, kv_rows):
    """Keep dense left-padded KV packs bounded when a few rows become long tails.

    The pack cost of a group is ``len(group) * longest cached length`` in the
    group. ``[rows]`` is returned unchanged in three cases:

    * ``LA_FLASH_KV_PACK_TOKEN_BUDGET`` is unset or <= 0;
    * there is at most one row;
    * the whole batch already fits the budget.

    Otherwise rows are stably sorted by cached length (shortest first) and
    packed greedily. A row that alone exceeds the budget still gets its own
    group.
    """
    budget = _kv_pack_token_budget()
    if budget <= 0 or len(rows) <= 1:
        return [rows]
    cached = [_row_len(kv_rows[r]) if kv_rows[r] is not None else 0 for r in rows]
    if not cached or max(cached) * len(rows) <= budget:
        return [rows]

    buckets = []
    bucket, bucket_max = [], 0
    for row, length in sorted(zip(rows, cached), key=lambda pair: pair[1]):
        widened = max(bucket_max, int(length))
        if bucket and widened * (len(bucket) + 1) > budget:
            # Adding this row would overflow the budget: close the bucket.
            buckets.append(bucket)
            bucket, bucket_max = [row], int(length)
        else:
            bucket.append(row)
            bucket_max = widened
    if bucket:
        buckets.append(bucket)
    return buckets or [rows]
268
269
def _record_kv_bucket_stats(stats, groups, kv_rows):
    """Record how a KV-budget split packed its groups.

    Does nothing when ``stats`` is None. Tracks the largest
    ``len(group) * longest cached length`` seen so far, ignoring empty
    groups. When the rows were split into more than one group, it also bumps
    ``kv_bucket_splits`` and adds the group count to ``kv_bucket_groups``.
    """
    if stats is None:
        return
    widest = 0
    for group in filter(None, groups):
        longest = max((_row_len(kv_rows[r]) if kv_rows[r] is not None else 0) for r in group)
        widest = max(widest, int(longest) * len(group))
    stats["kv_bucket_max_packed_tokens"] = max(stats["kv_bucket_max_packed_tokens"], widest)
    if len(groups) <= 1:
        return
    stats["kv_bucket_splits"] += 1
    stats["kv_bucket_groups"] += len(groups)
283
284
def _hybrid_scheduler(scheduler):
    """Normalize the hybrid scheduler name.

    ``scheduler=None`` falls back to ``LA_FLASH_HYBRID_SCHEDULER`` (default
    ``"eager"``). The value is stripped, lower-cased and mapped through the
    alias table. The result must be one of ``eager``, ``hold_ar``,
    ``ar_first``, ``pipeline`` or ``adaptive``.

    Raises:
        ValueError: for any other name.
    """
    if scheduler is None:
        scheduler = os.environ.get("LA_FLASH_HYBRID_SCHEDULER", "eager")
    name = str(scheduler).strip().lower()
    alias_of = {
        "": "eager",
        "default": "eager",
        "normal": "eager",
        "hold": "hold_ar",
        "hold-ar": "hold_ar",
        "hold_mtp": "hold_ar",
        "hold-mtp": "hold_ar",
        "repair_first": "ar_first",
        "repair-first": "ar_first",
        "ar-first": "ar_first",
    }
    name = alias_of.get(name, name)
    if name in ("eager", "hold_ar", "ar_first", "pipeline", "adaptive"):
        return name
    raise ValueError("scheduler must be one of: eager, hold_ar, ar_first, pipeline, adaptive")
304
305
def _hybrid_group_size(group_size):
    """Rows per scheduling group; None reads ``LA_FLASH_HYBRID_GROUP_SIZE``.

    The result is clamped at 0, which means "no grouping".
    """
    size = int(group_size) if group_size is not None else _env_int("LA_FLASH_HYBRID_GROUP_SIZE", 0)
    return max(0, size)
310
311
def _hybrid_prefill_mode():
    """Resolve ``LA_FLASH_HYBRID_PREFILL`` (default ``"shared"``) to a canonical mode.

    Canonical modes are ``none``, ``per_row``, ``batch`` and ``shared``. Each
    accepts the aliases listed below; matching is case-insensitive and
    ignores surrounding whitespace.

    Raises:
        ValueError: for an unrecognized value.
    """
    raw = os.environ.get("LA_FLASH_HYBRID_PREFILL", "shared").strip().lower()
    spellings_by_mode = {
        "none": ("0", "false", "off", "legacy"),
        "per_row": ("1", "true", "on", "single", "row", "rows"),
        "batch": ("batched",),
        "shared": ("prefix", "shared_prefix", "shared-image", "shared_image", "vision"),
    }
    for canonical, spellings in spellings_by_mode.items():
        if raw in spellings:
            raw = canonical
            break
    if raw in spellings_by_mode:
        return raw
    raise ValueError("LA_FLASH_HYBRID_PREFILL must be one of none, per_row, batch, shared")
336
337
def _tolist(t):
    """Detach ``t``, move it to host memory and return it as nested Python lists."""
    host = t.detach().to("cpu")
    return host.tolist()
340
341
def _safe_decode_rows(tok, input_ids):
    """Decode each row of ``input_ids`` (special tokens kept) for debug output.

    A row whose decode raises is rendered as ``"<decode failed>"`` instead of
    aborting the whole dump.
    """
    decoded = []
    for token_row in _tolist(input_ids):
        try:
            text = tok.decode(torch.tensor(token_row), skip_special_tokens=False)
        except Exception:
            text = "<decode failed>"
        decoded.append(text)
    return decoded
350
351
def _safe_decode_row(tok, row):
    """Decode one token-id list (special tokens kept); failures become ``"<decode failed>"``."""
    try:
        text = tok.decode(torch.tensor(row), skip_special_tokens=False)
    except Exception:
        text = "<decode failed>"
    return text
357
358
def _effective_allowed_mask(mask2d, q_len, past_len, mtp_window=False):
    """Readable 1/0 q-by-k mask derived from one row of the 2D key-valid mask.

    Entry ``[q][k]`` is 1 when key ``k`` is valid and not in the future of
    query ``q``, whose absolute position is ``past_len + q``. With
    ``mtp_window``, the last ``N_FUTURE`` queries also see the last
    ``N_FUTURE`` keys and never the key just before that window. This
    mirrors ``attn[-block:, -block:] = visible`` and
    ``attn[-block:, -block-1] = masked``.
    """
    flags = mask2d.detach().cpu().bool()
    n_keys = int(flags.numel())
    key_ok = [bool(flags[k]) for k in range(n_keys)]
    rows = [
        [1 if key_ok[k] and k <= past_len + q else 0 for k in range(n_keys)]
        for q in range(q_len)
    ]

    if mtp_window and q_len >= N_FUTURE and n_keys >= N_FUTURE:
        window_start = n_keys - N_FUTURE
        for q in range(q_len - N_FUTURE, q_len):
            rows[q][window_start:] = [1] * N_FUTURE
            if window_start >= 1:
                rows[q][window_start - 1] = 0
    return rows
385
386
def _tail_matrix(mat, rows=None, cols=None):
    """Return the bottom-right corner of a list-of-lists matrix.

    Args:
        mat: matrix as a list of row lists.
        rows: keep only the last ``rows`` rows (``None`` keeps all).
        cols: keep only the last ``cols`` entries of each kept row
            (``None`` keeps all).

    Returns:
        The trimmed matrix. Counts larger than the matrix keep everything;
        ``0`` (or a negative count) keeps nothing.
    """
    # Slice from an explicit start index: ``seq[-n:]`` returns the whole
    # sequence when n == 0 because -0 == 0.
    if rows is not None:
        rows = max(0, int(rows))
        mat = mat[max(0, len(mat) - rows):]
    if cols is not None:
        cols = max(0, int(cols))
        mat = [row[max(0, len(row) - cols):] for row in mat]
    return mat
393
394
def _format_01_matrix(mat):
    """Render a 0/1 matrix as text: one line per row, each value int-cast and preceded by a space."""
    lines = (" " + " ".join(map(str, map(int, row))) for row in mat)
    return "\n".join(lines)
397
398
def _safe_sdpa_mask_enabled():
    """Whether the safe 4D SDPA mask is enabled (``LA_FLASH_SDPA_SAFE_4D_MASK``, default on)."""
    enabled = _env_flag("LA_FLASH_SDPA_SAFE_4D_MASK", True)
    return enabled
401
402
def _build_safe_sdpa_visible_mask(attention_mask_2d, input_ids, past_len, mtp_window=False):
    """Build a 4D 1/0 visible mask, with harmless visibility for all-masked pad queries.

    The remote Qwen2 SDPA path uses a 2D key-valid mask and can create fully
    masked query rows for left-padded, no-cache prefill. Those rows can produce
    NaNs inside SDPA and later contaminate real tokens through masked K columns.
    This 4D mask keeps real-token visibility identical, and only gives otherwise
    all-masked query rows one valid fallback key so their activations stay finite.

    Returns:
        ``(mask, fallback_rows)``. ``mask`` is a ``(bsz, 1, q_len, key_len)``
        bfloat16 tensor (1 = visible) tagged with
        ``_la_flash_visible_mask = True`` when the attribute can be set.
        ``fallback_rows`` counts the query rows that needed the fallback key.
    """
    batch, q_len = int(input_ids.shape[0]), int(input_ids.shape[1])
    n_keys = int(attention_mask_2d.shape[1])
    device = input_ids.device
    key_ok = attention_mask_2d.to(dtype=torch.bool, device=device)
    key_pos = torch.arange(n_keys, device=device)[None, None, :]
    query_pos = (past_len + torch.arange(q_len, device=device))[None, :, None]
    # Causal visibility restricted to valid keys: (batch, q_len, n_keys).
    visible = key_ok.unsqueeze(1) & (key_pos <= query_pos)

    if mtp_window and q_len >= N_FUTURE and n_keys >= N_FUTURE:
        win_start = n_keys - N_FUTURE
        # MTP window queries see the whole (valid part of the) window ...
        visible[:, -N_FUTURE:, win_start:] = key_ok[:, None, win_start:]
        # ... and never the key immediately before it.
        if win_start >= 1:
            visible[:, -N_FUTURE:, win_start - 1] = False

    orphan = ~visible.any(dim=-1)
    fallback_rows = int(orphan.sum().item())
    if fallback_rows:
        for b in range(batch):
            missing = torch.nonzero(orphan[b], as_tuple=False).flatten()
            if not missing.numel():
                continue
            valid = torch.nonzero(key_ok[b], as_tuple=False).flatten()
            fallback = int(valid[0].item()) if valid.numel() else 0
            visible[b, missing, fallback] = True

    mask = visible.unsqueeze(1).to(dtype=torch.bfloat16)
    try:
        mask._la_flash_visible_mask = True
    except Exception:
        pass
    return mask, fallback_rows
443
444
def _mask_desc(mask):
    """Short label for a forward mask object, for debug/log lines."""
    if mask is None:
        return "none"
    if isinstance(mask, dict):
        return "magi_ranges"
    if not hasattr(mask, "dim"):
        return type(mask).__name__
    return "4d_safe_sdpa" if mask.dim() == 4 else "2d_key_valid"
453
454
def _forward_attention_mask(model, input_ids, attention_mask_2d, past_len, mtp_window=False, range_plan=False):
    """Choose the attention-mask object to pass to the language-model forward.

    Returns ``(mask, fallback_rows)``:

    * requested attention ``magi``/``la_flash``: the range plan from
      ``build_magi_scheduler_ranges``, when it returns one (``fallback_rows=0``);
    * SDPA with no cache (``past_len == 0``), more than one row, a 2D mask and
      ``LA_FLASH_SDPA_SAFE_4D_MASK`` enabled: the 4D mask and fallback-row
      count from ``_build_safe_sdpa_visible_mask``;
    * otherwise: ``attention_mask_2d`` unchanged, with ``fallback_rows=0``.

    ``range_plan`` is accepted for call compatibility; its value is not read.
    """
    llm = model.language_model.model
    requested = getattr(model, "_la_flash_requested_attn", ATTN_MODE)
    if requested in {"magi", "la_flash"}:
        plan = build_magi_scheduler_ranges(
            model, attention_mask_2d, input_ids, past_len, mtp_window=mtp_window)
        if plan is not None:
            return plan, 0

    # Fully masked query rows (the NaN hazard the safe mask guards against)
    # come from left-padded, no-cache, multi-row prefill.
    use_safe_mask = (
        past_len == 0
        and attention_mask_2d is not None
        and attention_mask_2d.dim() == 2
        and input_ids.shape[0] > 1
        and getattr(llm, "_attn_implementation", None) == "sdpa"
        and _safe_sdpa_mask_enabled()
    )
    if not use_safe_mask:
        return attention_mask_2d, 0
    return _build_safe_sdpa_visible_mask(attention_mask_2d, input_ids, past_len, mtp_window)
477
478
def _actual_sdpa_allowed_masks(model, input_ids, attention_mask, past_len):
    """Recreate the remote Qwen2 SDPA 4D additive mask and return a 0/1 view.

    Debug helper: it calls the mask builders from the module that defines
    the language model's class. The result therefore shows what that remote
    code builds, which can be compared with the mask this file expects.

    Args:
        model: loaded model. The module of ``type(model.language_model.model)``
            must expose ``_prepare_4d_causal_attention_mask`` and
            ``update_causal_mask_for_one_gen_window_2d``.
        input_ids: ``(bsz, q_len)`` query token ids.
        attention_mask: mask handed to the remote 4D builder; presumably the
            2D key-valid mask — TODO confirm against callers.
        past_len: number of cached tokens preceding the queries.

    Returns:
        ``(allowed, mask_shape, remote_ar_decode)``:

        * ``allowed``: nested list ``[bsz][q][k]`` of 0/1, where 1 marks
          additive-mask entries ``>= 0``;
        * ``mask_shape``: the 4D mask's shape tuple;
        * ``remote_ar_decode``: True when the MTP window update was skipped.

        NOTE(review): assumes the remote builder returns a 4D tensor; a None
        mask would fail at the final ``mask4[:, 0]`` indexing.
    """
    llm = model.language_model.model
    mod = importlib.import_module(type(llm).__module__)
    bsz, q_len = int(input_ids.shape[0]), int(input_ids.shape[1])
    # Placeholder embeddings tensor for the builder's signature; presumably
    # only its dtype/device are read — TODO confirm in the remote module.
    dummy = torch.empty(
        (bsz, q_len, 1),
        dtype=torch.bfloat16,
        device=input_ids.device,
    )
    mask4 = mod._prepare_4d_causal_attention_mask(
        attention_mask,
        (bsz, q_len),
        dummy,
        past_len,
        sliding_window=getattr(llm.config, "sliding_window", None),
    )
    # AR decode when there is a single query token, or when the last query
    # token is not the text mask token (i.e. no MTP window is being forwarded).
    # NOTE(review): the `input_ids is not None` guard is redundant;
    # input_ids.shape was already read above.
    remote_ar_decode = q_len == 1 or (
        input_ids is not None and int(input_ids[0, -1].item()) != int(llm.text_mask_token_id)
    )
    if not remote_ar_decode and mask4 is not None and mask4.dim() == 4:
        # Apply the remote per-row MTP window update to each row's (q, k) slice.
        rows = []
        for b in range(bsz):
            rows.append(
                mod.update_causal_mask_for_one_gen_window_2d(
                    input_ids[b],
                    mask4[b][0].clone(),
                    block_size=int(llm.block_size),
                    use_cache=True,
                    causal_attn=bool(getattr(llm, "causal_attn", False)),
                ).unsqueeze(0)
            )
        mask4 = torch.stack(rows, dim=0)
    allowed = (mask4[:, 0] >= 0).to(torch.int8).detach().cpu().tolist()
    return allowed, tuple(mask4.shape), remote_ar_decode
514
515
def _debug_magi_ranges(q_len, past_len, mtp_window=False):
    """Describe the magi range plan a single row would use, for debug output.

    AR decode (``mtp_window=False``) is one causal range over all keys. For
    an MTP forward the last ``N_FUTURE`` queries form a window:

    * window queries attend fully to the window and to keys before
      ``prefix_len - 1``; the key just before the window is excluded;
    * with a full forward (``q_len == kv_len``), the prefix is causal;
    * otherwise each recomputed non-window query gets its own full range up
      to its absolute position.

    Returns a dict with ``q_ranges``, ``k_ranges`` and ``attn_type_map``, or
    ``{"error": ...}`` for an impossible shape.
    """
    total_kv = past_len + q_len
    if not mtp_window:
        return {
            "q_ranges": [[0, q_len]],
            "k_ranges": [[0, total_kv]],
            "attn_type_map": ["CAUSAL"],
        }

    window = N_FUTURE
    if not (0 < window <= q_len <= total_kv):
        return {"error": f"invalid magi MTP shape: block={window}, q_len={q_len}, kv_len={total_kv}"}

    prefix_end = total_kv - window
    masked_key = prefix_end - 1
    entries = []  # (q_range, k_range, attn_type) in emission order
    if q_len == total_kv:
        if prefix_end > 0:
            entries.append(([0, prefix_end], [0, prefix_end], "CAUSAL"))
        if masked_key > 0:
            entries.append(([prefix_end, total_kv], [0, masked_key], "FULL"))
        entries.append(([prefix_end, total_kv], [prefix_end, total_kv], "FULL"))
    else:
        n_recompute = q_len - window
        first_global = total_kv - q_len
        entries.extend(
            ([i, i + 1], [0, first_global + i + 1], "FULL") for i in range(n_recompute)
        )
        window_q = [n_recompute, q_len]
        if masked_key > 0:
            entries.append((window_q, [0, masked_key], "FULL"))
        entries.append((window_q, [prefix_end, total_kv], "FULL"))

    return {
        "q_ranges": [q for q, _k, _t in entries],
        "k_ranges": [k for _q, k, _t in entries],
        "attn_type_map": [t for _q, _k, t in entries],
    }
563
564
def _print_debug_forward(label, model, tok, input_ids, attention_mask, position_ids,
                         past_len, mtp_window=False, extra=None, attention_impl="sdpa"):
    """Print a debug dump of one forward call to stdout.

    The dump contains:

    * shapes, dtypes and devices;
    * per-row tails of input ids, decoded text and position ids;
    * the expected 0/1 attention mask tail from ``_effective_allowed_mask``.

    When ``attention_impl`` is ``sdpa``/``eager``/``la_flash`` it also prints
    the mask rebuilt by the remote model code (``_actual_sdpa_allowed_masks``)
    and a count of entries where the two disagree. That count is computed
    over the full matrices, not just the printed tails; failures of that
    rebuild are printed, not raised. For ``magi`` it prints a single-row
    range template from ``_debug_magi_ranges``.

    Args:
        label: header text for this dump.
        model, tok: loaded model and tokenizer.
        input_ids, position_ids: ``(bsz, q_len)`` tensors for the forward.
        attention_mask: key-valid mask; ``shape[1]`` and per-row indexing
            are used, so this assumes the 2D ``(bsz, key_len)`` form.
        past_len: number of cached tokens before the queries.
        mtp_window: whether this forward carries an MTP window.
        extra: optional dict printed as ``key: value`` lines first.
        attention_impl: attention backend name, selecting the extra checks.

    Env:
        LA_FLASH_DEBUG_TAIL: trailing tokens/columns to show (default 15).
            NOTE(review): with 0, slices like ``[-tail:]`` return the whole
            row rather than an empty one.
        LA_FLASH_DEBUG_FULL_MASK: also print every row's full expected mask.
    """
    print(f"\n========== LA Flash DEBUG {label} ==========", flush=True)
    if extra:
        for k, v in extra.items():
            print(f"{k}: {v}", flush=True)
    tail = int(os.environ.get("LA_FLASH_DEBUG_TAIL", "15"))
    bsz, q_len = int(input_ids.shape[0]), int(input_ids.shape[1])
    key_len = int(attention_mask.shape[1])
    q_tail, k_tail = min(tail, q_len), min(tail, key_len)
    print(
        "shapes: "
        f"input_ids={tuple(input_ids.shape)} "
        f"position_ids={tuple(position_ids.shape)} "
        f"attention_mask_key_valid={tuple(attention_mask.shape)} "
        f"mask_2d_q_by_k=({bsz}, {q_len}, {key_len}) "
        f"mask_2d_tail=({bsz}, {q_tail}, {k_tail}) "
        f"past_len={past_len} q_len={q_len} "
        f"mtp_window={mtp_window} ar_decode={not mtp_window}",
        flush=True,
    )
    print(f"dtypes/devices: input_ids={input_ids.dtype}@{input_ids.device} position_ids={position_ids.dtype}@{position_ids.device} attention_mask={attention_mask.dtype}@{attention_mask.device}", flush=True)
    print(f"attention_impl={attention_impl}", flush=True)
    input_rows = _tolist(input_ids)
    pos_rows = _tolist(position_ids)
    print(f"tail_window_last={tail}", flush=True)
    print(f"input_ids_tail.shape=({bsz}, {q_tail})", flush=True)
    print(f"position_ids_tail.shape=({bsz}, {q_tail})", flush=True)
    # Best-effort rebuild of the remote model's own mask; any failure is
    # reported in the dump instead of interrupting generation.
    actual_sdpa = None
    if attention_impl in {"sdpa", "eager", "la_flash"}:
        try:
            actual_sdpa = _actual_sdpa_allowed_masks(model, input_ids, attention_mask, past_len)
            print(
                f"actual_sdpa_4d_mask_shape={actual_sdpa[1]} "
                f"remote_ar_decode={actual_sdpa[2]}",
                flush=True,
            )
        except Exception as e:
            print(f"actual_sdpa_4d_mask_debug_failed={type(e).__name__}: {e}", flush=True)

    for b in range(input_ids.shape[0]):
        ids_tail = input_rows[b][-tail:]
        pos_tail = pos_rows[b][-tail:]
        allowed = _effective_allowed_mask(attention_mask[b], input_ids.shape[1], past_len, mtp_window)
        # Per-row tail sizes clamp to the actual matrix size (shadowing the
        # header's q_tail/k_tail).
        q_tail = min(tail, len(allowed))
        k_tail = min(tail, len(allowed[0]) if allowed else 0)
        allowed_tail = _tail_matrix(allowed, rows=q_tail, cols=k_tail)
        print(f"batch_row={b} ar_decode={not mtp_window}", flush=True)
        print(f"input_ids_tail[-{tail}:]: {ids_tail}", flush=True)
        print(f"decoded_tail[-{tail}:]: {_safe_decode_row(tok, ids_tail)}", flush=True)
        print(f"position_ids_tail[-{tail}:]: {pos_tail}", flush=True)
        print(f"expected_mask_2d_tail[-{q_tail}:,-{k_tail}:].shape=({q_tail}, {k_tail})", flush=True)
        print(_format_01_matrix(allowed_tail), flush=True)
        if actual_sdpa is not None:
            actual = actual_sdpa[0][b]
            actual_tail = _tail_matrix(actual, rows=q_tail, cols=k_tail)
            # Disagreements are counted over the whole q-by-k matrix.
            mismatch = sum(
                int(allowed[qi][ki] != actual[qi][ki])
                for qi in range(len(allowed))
                for ki in range(len(allowed[qi]))
            )
            print(
                f"actual_sdpa_mask_2d_tail[-{q_tail}:,-{k_tail}:].shape=({q_tail}, {k_tail})",
                flush=True,
            )
            print(_format_01_matrix(actual_tail), flush=True)
            print(f"expected_vs_actual_sdpa_mismatch_count={mismatch}", flush=True)

    if _env_flag("LA_FLASH_DEBUG_FULL_MASK", False):
        masks = [
            _effective_allowed_mask(attention_mask[b], input_ids.shape[1], past_len, mtp_window)
            for b in range(input_ids.shape[0])
        ]
        print("effective_allowed_mask_q_by_k_FULL:", masks, flush=True)
    if attention_impl == "magi":
        if bsz == 1:
            print(
                "magi_ranges:",
                _debug_magi_ranges(input_ids.shape[1], past_len, mtp_window),
                flush=True,
            )
        else:
            print(
                "magi_ranges: built once per forward from the batched scheduler mask",
                flush=True,
            )
            print(
                "magi_ranges_single_row_template:",
                _debug_magi_ranges(input_ids.shape[1], past_len, mtp_window),
                flush=True,
            )
656
657
def _common_prefix_len(prompt_ids, rows):
    """Return how many leading tokens all selected prompts share.

    Args:
        prompt_ids: sequence of 1D token-id tensors, indexed by row id.
        rows: row ids to compare; ``prompt_ids[rows[0]]`` is the reference.

    Returns:
        Length of the longest common prefix of ``prompt_ids[r]`` over ``rows``.
        This is 0 when ``rows`` is empty, and the whole prompt length when
        ``rows`` holds a single row.
    """
    if not rows:
        return 0
    max_len = min(int(prompt_ids[r].numel()) for r in rows)
    ref = prompt_ids[rows[0]].reshape(-1)[:max_len]
    # Compare whole prefixes with tensor ops. The old loop made one ``.item()``
    # call per token per row; prompts contain one id per image token, and on
    # an accelerator every ``.item()`` is a host sync.
    differs = torch.zeros(max_len, dtype=torch.bool, device=ref.device)
    for r in rows[1:]:
        differs |= prompt_ids[r].reshape(-1)[:max_len].to(ref.device) != ref
    if not bool(differs.any()):
        return max_len
    return int(torch.nonzero(differs, as_tuple=False)[0, 0].item())
671
672
def _prefill_shared_prefix_kv_rows(model, prompt_ids, vit_list, img_tok, pad, dev, stats=None, debug=False):
    """Cache one common prompt prefix per image-feature group.

    Multi-category split repeats the same image feature tensor for each
    category prompt. Token ids are identical through the image tokens and the
    fixed prompt prefix, so we prefill that shared prefix once and let each
    category row forward only its text suffix.

    A group gets a cache only when all three conditions hold:

    * it has at least two rows;
    * their common prefix is at least ``LA_FLASH_SHARED_PREFILL_MIN_PREFIX``
      tokens long (default 64, minimum 1);
    * the prefix contains exactly as many ``img_tok`` ids as the group's
      feature tensor has rows.

    All qualifying groups are prefilled together in one left-padded forward.

    Args:
        model: loaded model; ``model.language_model.model`` runs the forward.
        prompt_ids: per-row 1D prompt token-id tensors.
        vit_list: per-row visual feature tensors. Rows are grouped by tensor
            identity, so only rows handed the very same object can share.
        img_tok: image placeholder token id.
        pad: token id used for left padding.
        dev: device for the packed forward inputs.
        stats: optional stats dict updated via ``_record_prefill_stats``.
        debug: print grouping/skip/fallback diagnostics.

    Returns:
        ``(kv_rows, cached_lens)``, one entry per row. Rows in a qualifying
        group get the group's prefix KV, and every row of the group receives
        the same tuple object rather than a copy. All other rows keep
        ``None`` and 0.
    """
    bsz = len(prompt_ids)
    kv_rows = [None] * bsz
    cached_lens = [0] * bsz
    # Group rows by the identity of their feature tensor (not by value).
    groups = {}
    for row, vit in enumerate(vit_list):
        groups.setdefault(id(vit), []).append(row)

    items = []
    min_prefix_len = max(1, _env_int("LA_FLASH_SHARED_PREFILL_MIN_PREFIX", 64))
    for rows in groups.values():
        if len(rows) < 2:
            continue
        prefix_len = _common_prefix_len(prompt_ids, rows)
        if prefix_len < min_prefix_len:
            continue
        prefix_ids = prompt_ids[rows[0]][:prefix_len]
        # The prefix forward receives this group's whole feature tensor, so
        # every image token must fall inside the shared prefix.
        image_token_count = int((prefix_ids == img_tok).sum().item())
        if image_token_count != int(vit_list[rows[0]].shape[0]):
            if debug:
                print(
                    "LA Flash shared prefill skip group: "
                    f"rows={rows} prefix_len={prefix_len} "
                    f"image_tokens={image_token_count} visual_rows={int(vit_list[rows[0]].shape[0])}",
                    flush=True,
                )
            continue
        items.append((rows, prefix_ids, vit_list[rows[0]]))

    if not items:
        return kv_rows, cached_lens

    # Left-pad every group's prefix to pmax; pad slots get token `pad`,
    # mask 0 and position id 1, while real tokens get positions 0..length-1.
    lengths = [int(ids.numel()) for _rows, ids, _vit in items]
    pmax = max(lengths)
    input_ids = torch.full((len(items), pmax), pad, dtype=torch.long, device=dev)
    amask = torch.zeros((len(items), pmax), dtype=torch.long, device=dev)
    pos = torch.ones((len(items), pmax), dtype=torch.long, device=dev)
    for item_idx, (_rows, ids, _vit) in enumerate(items):
        length = lengths[item_idx]
        left = pmax - length
        input_ids[item_idx, left:] = ids.to(dev)
        amask[item_idx, left:] = 1
        pos[item_idx, left:] = torch.arange(length, dtype=torch.long, device=dev)

    # One feature tensor per group, concatenated in item order.
    visual_features = torch.cat([vit for _rows, _ids, vit in items], dim=0)
    assert int((input_ids == img_tok).sum().item()) == visual_features.shape[0], \
        "shared-prefix image-token count != supplied visual_features rows"

    if debug:
        group_sizes = [len(rows) for rows, _ids, _vit in items]
        print(
            "LA Flash hybrid shared prompt prefill "
            f"groups={len(items)} group_sizes={group_sizes} prefix_lens={lengths}",
            flush=True,
        )

    forward_mask, fallback_rows = _forward_attention_mask(
        model, input_ids, amask, 0, mtp_window=False)
    if debug and fallback_rows:
        print(
            "LA Flash hybrid shared prefill safe SDPA fallback "
            f"query_rows={fallback_rows}",
            flush=True,
        )
    forward_kwargs = dict(
        input_ids=input_ids,
        visual_features=visual_features,
        image_token_index=img_tok,
        attention_mask=forward_mask,
        position_ids=pos,
        past_key_values=None,
        use_cache=True,
    )
    # A dict mask is a magi range plan, which goes through the
    # language_model_forward helper; tensor masks call the base model directly.
    if isinstance(forward_mask, dict):
        out = language_model_forward(model, **forward_kwargs, return_logits=False)
    else:
        out = model.language_model.model(**forward_kwargs)

    real_tokens = sum(lengths)
    shared_rows = sum(len(rows) for rows, _ids, _vit in items)
    # Each group's prefix is forwarded once instead of len(rows) times.
    saved_tokens = sum((len(rows) - 1) * length for (rows, _ids, _vit), length in zip(items, lengths))
    _record_prefill_stats(
        stats,
        rows=len(items),
        q_len=pmax,
        real_tokens=real_tokens,
        shared_groups=len(items),
        shared_rows=shared_rows,
        saved_tokens=saved_tokens,
    )

    # Keep only each group's right-aligned real prefix tokens (no old cache,
    # so kmax=0), then hand the same KV tuple to every row of the group.
    for item_idx, (rows, _ids, _vit) in enumerate(items):
        prefix_len = lengths[item_idx]
        prefix_kv = _unpack_stock_after_forward(out.past_key_values, item_idx, 0, prefix_len, 0, pmax)
        for row in rows:
            kv_rows[row] = prefix_kv
            cached_lens[row] = prefix_len

    return kv_rows, cached_lens
779
780
@torch.no_grad()
def _prefill_prompt_kv_rows(model, prompt_ids, vit_list, img_tok, pad, dev, mode, debug=False, stats=None):
    """Return per-row prompt KV caches and cached lengths.

    ``mode='none'`` preserves the legacy stock-like first MTP forward where the
    whole prompt and the 6-token MTP window are forwarded together. The split
    prefill modes keep prompt KV clean before the scheduler batches only short
    suffix/window forwards, which avoids ragged prompt+window masking in the
    first decode step.

    Modes:
        ``none``: no prefill; every row gets ``None`` / 0.
        ``shared``: delegates to ``_prefill_shared_prefix_kv_rows``, which
            caches only shared prefixes; rows outside a qualifying group stay
            ``None`` / 0.
        ``per_row``: one unpadded forward per prompt, with no attention mask.
            The cache is stored exactly as the base model returns it.
        otherwise (``batch``): one left-padded forward over all prompts; each
            row's real-token KV is then extracted with
            ``_unpack_stock_after_forward``.

    Args:
        model: loaded model; ``model.language_model.model`` runs the forwards.
        prompt_ids: per-row 1D prompt token-id tensors.
        vit_list: per-row visual feature tensors.
        img_tok: image placeholder token id.
        pad: token id used for left padding in batch mode.
        dev: device for forward inputs.
        mode: prefill mode, presumably from ``_hybrid_prefill_mode``.
        debug: print prefill diagnostics.
        stats: optional stats dict updated via ``_record_prefill_stats``.

    Returns:
        ``(kv_rows, cached_lens)``: one entry per row.
    """
    bsz = len(prompt_ids)
    lengths = [int(p.numel()) for p in prompt_ids]
    if mode == "none":
        return [None] * bsz, [0] * bsz

    base = model.language_model.model
    if debug:
        print(f"LA Flash hybrid prompt prefill mode={mode} rows={bsz} lengths={lengths}", flush=True)

    if mode == "shared":
        return _prefill_shared_prefix_kv_rows(
            model, prompt_ids, vit_list, img_tok, pad, dev, stats=stats, debug=debug)

    if mode == "per_row":
        kv_rows = []
        for b, ids in enumerate(prompt_ids):
            ids = ids.to(dev).unsqueeze(0)
            pos = torch.arange(ids.shape[1], dtype=torch.long, device=dev).unsqueeze(0)
            # Single unpadded row: no attention mask is needed.
            out = base(
                input_ids=ids,
                visual_features=vit_list[b],
                image_token_index=img_tok,
                attention_mask=None,
                position_ids=pos,
                past_key_values=None,
                use_cache=True,
            )
            # NOTE(review): stored as returned by the base model (not re-packed
            # into tuples like the batch path) — confirm downstream accepts both.
            kv_rows.append(out.past_key_values)
            _record_prefill_stats(stats, rows=1, q_len=ids.shape[1], real_tokens=ids.shape[1])
        return kv_rows, lengths

    # Batch mode: left-pad all prompts to pmax. Pad slots get token `pad`,
    # mask 0 and position id 1; real tokens get positions 0..length-1.
    pmax = max(lengths)
    input_ids = torch.full((bsz, pmax), pad, dtype=torch.long, device=dev)
    amask = torch.zeros((bsz, pmax), dtype=torch.long, device=dev)
    pos = torch.ones((bsz, pmax), dtype=torch.long, device=dev)
    for b, ids in enumerate(prompt_ids):
        left = pmax - lengths[b]
        input_ids[b, left:] = ids.to(dev)
        amask[b, left:] = 1
        pos[b, left:] = torch.arange(lengths[b], dtype=torch.long, device=dev)

    visual_features = torch.cat(vit_list, dim=0)
    assert int((input_ids == img_tok).sum().item()) == visual_features.shape[0], \
        "image-token count != supplied visual_features rows"
    forward_mask, fallback_rows = _forward_attention_mask(
        model, input_ids, amask, 0, mtp_window=False)
    if debug and fallback_rows:
        print(
            "LA Flash hybrid batch prefill safe SDPA fallback "
            f"query_rows={fallback_rows}",
            flush=True,
        )
    forward_kwargs = dict(
        input_ids=input_ids,
        visual_features=visual_features,
        image_token_index=img_tok,
        attention_mask=forward_mask,
        position_ids=pos,
        past_key_values=None,
        use_cache=True,
    )
    # A dict mask is a magi range plan, which goes through the
    # language_model_forward helper; tensor masks call the base model directly.
    if isinstance(forward_mask, dict):
        out = language_model_forward(model, **forward_kwargs, return_logits=False)
    else:
        out = base(**forward_kwargs)
    _record_prefill_stats(stats, rows=bsz, q_len=pmax, real_tokens=sum(lengths))
    # Drop each row's left padding: keep only its right-aligned real tokens.
    kv_rows = [
        _unpack_stock_after_forward(out.past_key_values, b, 0, lengths[b], 0, pmax)
        for b in range(bsz)
    ]
    return kv_rows, lengths
862
863
864 @torch.no_grad()
865 def generate_batch_hybrid(pairs, temperature=README_TEMPERATURE, top_p=README_TOP_P, top_k=None,
866 repetition_penalty=README_REPETITION_PENALTY,
867 max_new_tokens=README_MAX_NEW_TOKENS, temps=None,
868 debug=None, scheduler=None, group_size=None,
869 vision_features=None, _stats=None):
870 """Batched stock-style LocateAnything-3B hybrid generation.
871
872 This mirrors ``model.generate(..., generation_mode='hybrid')``: each row
873 owns a full ``generated`` token stream plus a KV cache truncated to real
874 generated tokens before sampling. MTP forwards
875 ``generated[cached_len:] + duplicate-last + mask*5``; AR forwards
876 ``generated[cached_len:]``.
877 """
878 tok, _, model = load()
879 san, hpat = _helpers()
880 tids = model.token_ids
881 img_tok = model.config.image_token_index
882 mask_tok = tids["default_mask_token_id"]
883 im_end = tids["im_end_token_id"]
884 pad = tok.pad_token_id if tok.pad_token_id is not None else im_end
885 dev = DEV
886
887 if not pairs:
888 return []
889 if temps is not None and len(temps) != len(pairs):
890 raise ValueError("temps must have the same length as pairs")
891 if vision_features is not None and len(vision_features) != len(pairs):
892 raise ValueError("vision_features must have the same length as pairs")
893 debug = _debug_enabled(debug)
894 scheduler = _hybrid_scheduler(scheduler)
895 group_size = _hybrid_group_size(group_size)
896 requested_attn = getattr(model, "_la_flash_requested_attn", ATTN_MODE)
897 use_magi = requested_attn == "magi"
898 prefill_mode = _hybrid_prefill_mode()
899 hold_max_steps = max(0, _env_int("LA_FLASH_HYBRID_HOLD_MAX_STEPS", 5))
900 adaptive_hold_mtp_max = max(0, _env_int("LA_FLASH_HYBRID_ADAPTIVE_HOLD_MTP_MAX", 3))
901 top_level_stats = _stats is None
902 if top_level_stats:
903 _stats = _new_hybrid_stats(
904 len(pairs), scheduler, group_size, hold_max_steps, adaptive_hold_mtp_max)
905 if os.environ.get("LA_FLASH_PLAN_STATS", "0") == "1":
906 model._la_flash_sparse_plan_stats = None
907 if group_size and len(pairs) > group_size:
908 outs = []
909 if debug:
910 print(
911 f"LA Flash hybrid grouped scheduling: total_rows={len(pairs)} "
912 f"group_size={group_size} scheduler={scheduler} hold_max_steps={hold_max_steps} "
913 f"adaptive_hold_mtp_max={adaptive_hold_mtp_max}",
914 flush=True,
915 )
916 for start in range(0, len(pairs), group_size):
917 end = min(start + group_size, len(pairs))
918 chunk_temps = temps[start:end] if temps is not None else None
919 chunk_vision_features = (
920 vision_features[start:end] if vision_features is not None else None
921 )
922 if debug:
923 print(f"LA Flash hybrid group rows=[{start}:{end}]", flush=True)
924 outs.extend(generate_batch_hybrid(
925 pairs[start:end],
926 temperature=temperature,
927 top_p=top_p,
928 top_k=top_k,
929 repetition_penalty=repetition_penalty,
930 max_new_tokens=max_new_tokens,
931 temps=chunk_temps,
932 debug=debug,
933 scheduler=scheduler,
934 group_size=0,
935 vision_features=chunk_vision_features,
936 _stats=_stats,
937 ))
938 if top_level_stats:
939 _set_last_hybrid_stats(_stats)
940 return outs
941
942 use_cached_tokenize = (
943 vision_features is not None
944 and os.environ.get("LA_FLASH_CACHE_TOKENIZE", "1") != "0"
945 )
946 if use_cached_tokenize:
947 try:
948 prompt_ids = [
949 _tokenize_cached_image(q, int(v.shape[0]), im=im)
950 for (im, q), v in zip(pairs, vision_features)
951 ]
952 except Exception as exc:
953 if os.environ.get("LA_FLASH_CACHE_TOKENIZE_STRICT", "0") == "1":
954 raise
955 if debug:
956 print(f"LA Flash cached tokenize fallback: {exc}", flush=True)
957 prompt_ids = [_tokenize(im, q) for im, q in pairs]
958 else:
959 prompt_ids = [_tokenize(im, q) for im, q in pairs]
960 vit_list = (
961 list(vision_features)
962 if vision_features is not None
963 else _encode_images([im for im, _ in pairs])
964 )
965 lengths = [int(p.numel()) for p in prompt_ids]
966 bsz = len(pairs)
967 _record_group_stats(_stats, bsz)
968
969 _set_llm_mode(model, requested_attn)
970
971 modes = ["mtp"] * bsz
972 finished = [False] * bsz
973 gen_ids = [[] for _ in range(bsz)]
974 full_ids = [list(ids.detach().cpu().tolist()) for ids in prompt_ids]
975 kv_rows, cached_lens = _prefill_prompt_kv_rows(
976 model, prompt_ids, vit_list, img_tok, pad, dev, prefill_mode, debug=debug, stats=_stats)
977 total_limits = [lengths[b] + max_new_tokens for b in range(bsz)]
978
979 row_temps = [float(temperature or 0.0)] * bsz if temps is None else [float(t or 0.0) for t in temps]
980
981 def run_ar(ar_rows, step_idx):
982 row_groups = _split_rows_by_kv_budget(ar_rows, kv_rows)
983 _record_kv_bucket_stats(_stats, row_groups, kv_rows)
984 for row_group in row_groups:
985 _step_stock_ar_rows(
986 model, san, tids, prompt_ids, kv_rows, row_group,
987 cached_lens, full_ids, gen_ids, modes, finished, total_limits,
988 pad, img_tok, row_temps, temperature, top_p, top_k,
989 repetition_penalty, dev, tok, debug, step_idx, use_magi, _stats,
990 )
991
992 def run_mtp(mtp_rows, step_idx):
993 if any(cached_lens[r] == 0 for r in mtp_rows) and any(cached_lens[r] > 0 for r in mtp_rows):
994 first_rows = [r for r in mtp_rows if cached_lens[r] == 0]
995 cached_rows = [r for r in mtp_rows if cached_lens[r] > 0]
996 if first_rows:
997 run_mtp(first_rows, step_idx)
998 if cached_rows:
999 run_mtp(cached_rows, step_idx)
1000 return
1001 row_groups = _split_rows_by_kv_budget(mtp_rows, kv_rows)
1002 _record_kv_bucket_stats(_stats, row_groups, kv_rows)
1003 if len(row_groups) > 1:
1004 for row_group in row_groups:
1005 run_mtp(row_group, step_idx)
1006 return
1007 _step_stock_mtp_rows(
1008 model, san, hpat, tids, prompt_ids, kv_rows, mtp_rows,
1009 cached_lens, full_ids, gen_ids, modes, finished, total_limits,
1010 vit_list, pad, mask_tok, img_tok, row_temps, top_p, top_k,
1011 repetition_penalty, dev, tok, debug, step_idx, use_magi, _stats,
1012 )
1013
1014 def live_rows(mode):
1015 return [b for b in range(bsz) if not finished[b] and modes[b] == mode]
1016
1017 step = 0
1018 hold_steps = 0
1019 while not all(finished) and step <= max_new_tokens:
1020 step += 1
1021 if _stats is not None:
1022 _stats["decode_loops"] += 1
1023 if scheduler == "hold_ar" and hold_max_steps > 0:
1024 ar_rows = live_rows("ar")
1025 mtp_rows = live_rows("mtp")
1026 if ar_rows and mtp_rows and _stats is not None:
1027 _stats["mixed_mode_cycles"] += 1
1028 if ar_rows and (hold_steps < hold_max_steps or not mtp_rows):
1029 if mtp_rows and _stats is not None:
1030 _stats["hold_ar_steps"] += 1
1031 _stats["hold_ar_held_mtp_rows"] += len(mtp_rows)
1032 run_ar(ar_rows, step)
1033 hold_steps += 1
1034 continue
1035 if mtp_rows:
1036 if ar_rows and _stats is not None:
1037 _stats["hold_ar_limit_mtp_forwards"] += 1
1038 run_mtp(mtp_rows, step)
1039 hold_steps = 0
1040 continue
1041
1042 if scheduler in {"ar_first", "pipeline", "adaptive"}:
1043 ar_rows_at_loop_start = live_rows("ar")
1044 mtp_rows_at_loop_start = live_rows("mtp")
1045 mixed = bool(ar_rows_at_loop_start and mtp_rows_at_loop_start)
1046 if mixed and _stats is not None:
1047 _stats["mixed_mode_cycles"] += 1
1048
1049 if scheduler == "adaptive" and mixed and hold_max_steps > 0:
1050 should_hold = len(mtp_rows_at_loop_start) <= adaptive_hold_mtp_max
1051 if should_hold and hold_steps < hold_max_steps:
1052 if _stats is not None:
1053 _stats["adaptive_hold_cycles"] += 1
1054 _stats["hold_ar_steps"] += 1
1055 _stats["hold_ar_held_mtp_rows"] += len(mtp_rows_at_loop_start)
1056 run_ar(ar_rows_at_loop_start, step)
1057 hold_steps += 1
1058 continue
1059
1060 if ar_rows_at_loop_start:
1061 if mixed and _stats is not None:
1062 if scheduler == "adaptive":
1063 _stats["adaptive_ar_first_cycles"] += 1
1064 else:
1065 _stats["ar_first_cycles"] += 1
1066 run_ar(ar_rows_at_loop_start, step)
1067
1068 mtp_rows = live_rows("mtp")
1069 if mtp_rows:
1070 run_mtp(mtp_rows, step)
1071 hold_steps = 0
1072
1073 if scheduler == "pipeline" and mtp_rows:
1074 old_ar = set(ar_rows_at_loop_start)
1075 new_ar_rows = [b for b in live_rows("ar") if b not in old_ar]
1076 if new_ar_rows:
1077 if _stats is not None:
1078 _stats["pipeline_ar_after_mtp_cycles"] += 1
1079 run_ar(new_ar_rows, step)
1080 continue
1081
1082 mtp_rows = live_rows("mtp")
1083 ar_rows_at_loop_start = live_rows("ar")
1084 if mtp_rows and ar_rows_at_loop_start and _stats is not None:
1085 _stats["mixed_mode_cycles"] += 1
1086 if mtp_rows:
1087 run_mtp(mtp_rows, step)
1088
1089 ar_rows = [b for b in range(bsz) if not finished[b] and modes[b] == "ar"]
1090 if mtp_rows and ar_rows and _stats is not None:
1091 _stats["eager_mtp_then_ar_cycles"] += 1
1092 if ar_rows:
1093 run_ar(ar_rows, step)
1094
1095 outs = [
1096 tok.decode(torch.tensor(gen_ids[b], dtype=torch.long, device=dev),
1097 skip_special_tokens=False) if gen_ids[b] else ""
1098 for b in range(bsz)
1099 ]
1100 if top_level_stats:
1101 if os.environ.get("LA_FLASH_PLAN_STATS", "0") == "1":
1102 _stats["sparse_plan_stats"] = copy.deepcopy(
1103 getattr(model, "_la_flash_sparse_plan_stats", None) or {}
1104 )
1105 _set_last_hybrid_stats(_stats)
1106 return outs
1107
1108
1109 @torch.no_grad()
1110 def _step_stock_mtp_rows(model, san, hpat, tids, prompt_ids, kv_rows, rows,
1111 cached_lens, full_ids, gen_ids, modes, finished, total_limits,
1112 vit_list, pad, mask_tok, img_tok, row_temps, top_p, top_k,
1113 repetition_penalty, dev, tok, debug, step_idx, use_magi, stats=None):
1114 kv, kvalid, old_lens, kmax = _pack_stock_kv_rows(kv_rows, rows, dev)
1115 uncached_lens = [len(full_ids[r]) - cached_lens[r] for r in rows]
1116 umax = max(uncached_lens)
1117 seq_len = umax + N_FUTURE
1118 _record_forward_stats(stats, "mtp", rows, seq_len, uncached_lens)
1119
1120 suf_ids = torch.full((len(rows), seq_len), pad, dtype=torch.long, device=dev)
1121 suf_pos = torch.ones((len(rows), seq_len), dtype=torch.long, device=dev)
1122 q_valid = torch.zeros((len(rows), seq_len), dtype=torch.long, device=dev)
1123
1124 for i, r in enumerate(rows):
1125 uncached = full_ids[r][cached_lens[r] :]
1126 left = umax - len(uncached)
1127 if uncached:
1128 suf_ids[i, left : left + len(uncached)] = torch.tensor(uncached, dtype=torch.long, device=dev)
1129 suf_pos[i, left : left + len(uncached)] = torch.arange(
1130 cached_lens[r], len(full_ids[r]), dtype=torch.long, device=dev)
1131 q_valid[i, left : left + len(uncached)] = 1
1132
1133 rep = full_ids[r][-1]
1134 cur_len = len(full_ids[r])
1135 suf_ids[i, umax] = rep
1136 suf_pos[i, umax] = cur_len - 1
1137 q_valid[i, umax] = 1
1138 for j in range(1, N_FUTURE):
1139 suf_ids[i, umax + j] = mask_tok
1140 suf_pos[i, umax + j] = cur_len + (j - 1)
1141 q_valid[i, umax + j] = 1
1142
1143 full_mask = torch.cat([kvalid, q_valid], dim=1)
1144
1145 if debug:
1146 forward_mask, fallback_rows = _forward_attention_mask(
1147 model, suf_ids, full_mask, kmax, mtp_window=True, range_plan=True)
1148 _print_debug_forward(
1149 f"MTP step={step_idx}",
1150 model,
1151 tok,
1152 suf_ids,
1153 full_mask,
1154 suf_pos,
1155 past_len=kmax,
1156 mtp_window=True,
1157 extra={
1158 "global_rows": rows,
1159 "old_kv_lens": old_lens,
1160 "cached_lens": [cached_lens[r] for r in rows],
1161 "full_lens": [len(full_ids[r]) for r in rows],
1162 "uncached_lens": uncached_lens,
1163 "forward_attention_mask": _mask_desc(forward_mask),
1164 "safe_sdpa_fallback_query_rows": fallback_rows,
1165 },
1166 attention_impl="magi" if use_magi else ATTN_MODE,
1167 )
1168 else:
1169 forward_mask, _ = _forward_attention_mask(
1170 model, suf_ids, full_mask, kmax, mtp_window=True, range_plan=True)
1171 first_rows = [r for r in rows if cached_lens[r] == 0]
1172 visual_features = None
1173 if first_rows:
1174 if first_rows != rows:
1175 raise RuntimeError("mixed first/non-first MTP rows are not supported")
1176 visual_features = torch.cat([vit_list[r] for r in rows], dim=0)
1177 assert int((suf_ids == img_tok).sum().item()) == visual_features.shape[0], \
1178 "image-token count != supplied visual_features rows"
1179 out = language_model_forward(
1180 model, input_ids=suf_ids, attention_mask=forward_mask,
1181 position_ids=suf_pos, past_key_values=kv, use_cache=True,
1182 visual_features=visual_features,
1183 image_token_index=img_tok if visual_features is not None else None,
1184 logits_slice=slice(-N_FUTURE, None))
1185
1186 for i, r in enumerate(rows):
1187 kv_rows[r] = _unpack_stock_after_forward(
1188 out.past_key_values, i, old_lens[i], uncached_lens[i], kmax, umax)
1189 cached_lens[r] = len(full_ids[r])
1190
1191 wlogits = out.logits[:, -N_FUTURE:, :]
1192 local_prompts = [prompt_ids[r] for r in rows]
1193 local_gen = [gen_ids[r] for r in rows]
1194 gen_pad = _pad_generated(local_prompts, local_gen, img_tok, dev)
1195 per_row_temp = torch.tensor([row_temps[r] for r in rows], dtype=torch.float32, device=dev)
1196
1197 if BATCH_SAN:
1198 x0_all, boxes_all = sample_tokens_batched(
1199 wlogits, gen_pad, tids, per_row_temp,
1200 repetition_penalty=repetition_penalty, top_p=top_p, top_k=top_k,
1201 keep_k_avg=4, generation_mode="hybrid")
1202
1203 for i, r in enumerate(rows):
1204 if finished[r]:
1205 continue
1206 if BATCH_SAN:
1207 x0b, boxb = x0_all[i], boxes_all[i]
1208 else:
1209 gk = _mk_generate_kwargs(row_temps[r], top_p, top_k, repetition_penalty)
1210 _, _, x0, box_avg = san(wlogits[i : i + 1], gen_pad[i : i + 1], tids, keep_k=5, **gk)
1211 x0b, boxb = x0[0], box_avg[0]
1212 nt = x0b if bool((boxb == 0).all()) else boxb
1213 op = hpat(nt, tids, "hybrid")
1214
1215 toks = [int(t) for t in op["tokens"]]
1216 for t in toks:
1217 gen_ids[r].append(t)
1218 full_ids[r].append(t)
1219
1220 if op["type"] == "im_end":
1221 finished[r] = True
1222 elif op["type"] == "error_box":
1223 modes[r] = "ar"
1224 if len(full_ids[r]) >= total_limits[r]:
1225 finished[r] = True
1226
1227
1228 @torch.no_grad()
1229 def _step_stock_ar_rows(model, san, tids, prompt_ids, kv_rows, rows,
1230 cached_lens, full_ids, gen_ids, modes, finished, total_limits,
1231 pad, img_tok, row_temps, temperature, top_p, top_k,
1232 repetition_penalty, dev, tok, debug, step_idx, use_magi, stats=None):
1233 kv, kvalid, old_lens, kmax = _pack_stock_kv_rows(kv_rows, rows, dev)
1234 uncached_lens = [len(full_ids[r]) - cached_lens[r] for r in rows]
1235 if any(n <= 0 for n in uncached_lens):
1236 raise RuntimeError(f"AR rows have no uncached tokens: {rows}")
1237 umax = max(uncached_lens)
1238 _record_forward_stats(stats, "ar", rows, umax, uncached_lens)
1239
1240 suf_ids = torch.full((len(rows), umax), pad, dtype=torch.long, device=dev)
1241 suf_pos = torch.ones((len(rows), umax), dtype=torch.long, device=dev)
1242 q_valid = torch.zeros((len(rows), umax), dtype=torch.long, device=dev)
1243
1244 for i, r in enumerate(rows):
1245 uncached = full_ids[r][cached_lens[r] :]
1246 left = umax - len(uncached)
1247 suf_ids[i, left:] = torch.tensor(uncached, dtype=torch.long, device=dev)
1248 suf_pos[i, left:] = torch.arange(cached_lens[r], len(full_ids[r]), dtype=torch.long, device=dev)
1249 q_valid[i, left:] = 1
1250
1251 full_mask = torch.cat([kvalid, q_valid], dim=1)
1252 if debug:
1253 forward_mask, fallback_rows = _forward_attention_mask(
1254 model, suf_ids, full_mask, kmax, mtp_window=False, range_plan=True)
1255 _print_debug_forward(
1256 f"AR step={step_idx}",
1257 model,
1258 tok,
1259 suf_ids,
1260 full_mask,
1261 suf_pos,
1262 past_len=kmax,
1263 mtp_window=False,
1264 extra={
1265 "global_rows": rows,
1266 "old_kv_lens": old_lens,
1267 "cached_lens": [cached_lens[r] for r in rows],
1268 "full_lens": [len(full_ids[r]) for r in rows],
1269 "uncached_lens": uncached_lens,
1270 "forward_attention_mask": _mask_desc(forward_mask),
1271 "safe_sdpa_fallback_query_rows": fallback_rows,
1272 },
1273 attention_impl="magi" if use_magi else ATTN_MODE,
1274 )
1275 else:
1276 forward_mask, _ = _forward_attention_mask(
1277 model, suf_ids, full_mask, kmax, mtp_window=False, range_plan=True)
1278
1279 out = language_model_forward(
1280 model, input_ids=suf_ids, attention_mask=forward_mask,
1281 position_ids=suf_pos, past_key_values=kv, use_cache=True,
1282 logits_slice=slice(-1, None))
1283
1284 for i, r in enumerate(rows):
1285 kv_rows[r] = _unpack_stock_after_forward(
1286 out.past_key_values, i, old_lens[i], uncached_lens[i], kmax, umax)
1287 cached_lens[r] = len(full_ids[r])
1288
1289 if AR_BATCH_SAN:
1290 local_prompts = [prompt_ids[r] for r in rows]
1291 local_gen = [gen_ids[r] for r in rows]
1292 gen_pad = _pad_generated(local_prompts, local_gen, img_tok, dev)
1293 per_row_temp = torch.tensor([row_temps[r] for r in rows], dtype=torch.float32, device=dev)
1294 x0_all = sample_next_tokens_batched(
1295 out.logits[:, -1:, :],
1296 gen_pad,
1297 per_row_temp,
1298 repetition_penalty=repetition_penalty,
1299 top_p=top_p,
1300 top_k=top_k,
1301 )
1302
1303 for i, r in enumerate(rows):
1304 if AR_BATCH_SAN:
1305 token_val = int(x0_all[i, 0].item())
1306 else:
1307 logits = out.logits[i : i + 1, -1:, :]
1308 gen_pad = _pad_generated([prompt_ids[r]], [gen_ids[r]], img_tok, dev)
1309 gk = _mk_generate_kwargs(temperature, top_p, top_k, repetition_penalty, row_temp=row_temps[r])
1310 _, _, x0, _ = san(logits, gen_pad, tids, **gk)
1311 token_val = int(x0[0, 0].item())
1312 out_type = _classify_ar_token(token_val, tids)
1313
1314 gen_ids[r].append(token_val)
1315 full_ids[r].append(token_val)
1316
1317 if out_type == "im_end":
1318 finished[r] = True
1319 elif out_type == "box_end_ar":
1320 modes[r] = "mtp"
1321
1322 if len(full_ids[r]) >= total_limits[r]:
1323 finished[r] = True
1324
1325
1326 def generate_batch_grouped_hybrid(groups, temperature=README_TEMPERATURE, top_p=README_TOP_P,
1327 top_k=None, repetition_penalty=README_REPETITION_PENALTY,
1328 max_new_tokens=README_MAX_NEW_TOKENS, temps=None,
1329 debug=None, scheduler=None, group_size=None,
1330 vision_features=None):
1331 """Hybrid grouped API shape.
1332
1333 This preserves grouped return shape, but intentionally uses the generic
1334 hybrid decoder rather than the fast engine's shared-prefix optimization.
1335 """
1336 flat = []
1337 flat_vision_features = [] if vision_features is not None else None
1338 counts = []
1339 for group_idx, (im, queries) in enumerate(groups):
1340 counts.append(len(queries))
1341 flat.extend((im, q) for q in queries)
1342 if flat_vision_features is not None:
1343 flat_vision_features.extend([vision_features[group_idx]] * len(queries))
1344
1345 outs = generate_batch_hybrid(
1346 flat, temperature=temperature, top_p=top_p, top_k=top_k,
1347 repetition_penalty=repetition_penalty, max_new_tokens=max_new_tokens,
1348 temps=temps, debug=debug, scheduler=scheduler, group_size=group_size,
1349 vision_features=flat_vision_features)
1350 res, offset = [], 0
1351 for n in counts:
1352 res.append(outs[offset : offset + n])
1353 offset += n
1354 return res
1355
1356
1357 __all__ = ["generate_batch_hybrid", "generate_batch_grouped_hybrid", "get_last_hybrid_stats"]
1358