batch_utils/hybrid_runtime.py
79.2 KB · 1843 lines · python Raw
1 """Internal runtime support for the LocateAnything-3B hybrid batch decoder.
2
3 This file keeps only the model-loading, tokenization, image-encoding, stock
4 processor, and sample-token helpers that ``engine_hybrid.py`` needs.
5
6 Important env knobs:
7 LA_FLASH_MODEL HF repo id / local path of the model (default nvidia/LocateAnything-3B)
8 HF_HUB_OFFLINE=1 read the local HF cache only (no network); unset -> download on first use
9 LA_FLASH_ATTN sdpa, eager, magi, or la_flash; la_flash uses FlashAttention sparse ranges
10 LA_FLASH_STRICT_ATTN 1 -> fail if the requested backend is unavailable;
11 default 0 falls back to sdpa
12 LA_FLASH_VISION_ATTN auto, flash_attention_2, sdpa, or eager (default auto)
13 LA_FLASH_HYBRID_PREFILL shared, none, per_row, or batch prompt KV prefill (default shared)
14 MTP_BATCH_VISION 0 -> per-image vision encode (default 1: batched when flash is present)
15 LA_FLASH_VISION_ENCODE_BATCH_SIZE
16 max images per MoonViT encode micro-batch (default 8; <=0 disables limit)
17 MTP_BATCH_SAN 0 -> per-row logits/sample pipeline (default 1: batched over [B,6,V])
18 AR_BATCH_SAN 0 -> per-row AR sample pipeline (default 1: batched over [B,1,V])
19 """
20 import inspect
21 import os, warnings, importlib, torch
22 from types import SimpleNamespace
23 import numpy as np
24 from transformers import AutoModel, AutoTokenizer, AutoProcessor
25
26
27 # By default let transformers fetch the model on first use; set HF_HUB_OFFLINE=1 yourself
28 # to read the local HF cache only (e.g. air-gapped / already-downloaded runs).
29 MODEL = os.environ.get("LA_FLASH_MODEL", "nvidia/LocateAnything-3B")
30
31
32 LLM_ATTN_MODES = ("sdpa", "eager", "magi", "la_flash")
33 VISION_ATTN_MODES = ("auto", "flash_attention_2", "sdpa", "eager")
34
35
36 def _normalize_attn_mode(value):
37 mode = (value or "sdpa").strip().lower().replace("-", "_")
38 aliases = {
39 "": "sdpa",
40 "manual": "eager",
41 "torch": "eager",
42 "torch_eager": "eager",
43 "torch_sdpa": "sdpa",
44 "scaled_dot_product_attention": "sdpa",
45 "flash": "la_flash",
46 "la_flash": "la_flash",
47 "kernel": "la_flash",
48 "cuda": "la_flash",
49 "range": "la_flash",
50 "range_attention": "la_flash",
51 "flex_flash": "magi",
52 "flex_flash_attention": "magi",
53 "flex_flash_attn": "magi",
54 }
55 mode = aliases.get(mode, mode)
56 if mode not in LLM_ATTN_MODES:
57 raise ValueError(
58 f"LA_FLASH_ATTN must be one of {', '.join(LLM_ATTN_MODES)}; got {value!r}"
59 )
60 return mode
61
62
63 def _normalize_vision_attn_mode(value):
64 mode = (value or "auto").strip().lower().replace("-", "_")
65 aliases = {
66 "": "auto",
67 "flash": "flash_attention_2",
68 "flash_attention2": "flash_attention_2",
69 "fa2": "flash_attention_2",
70 "manual": "eager",
71 }
72 mode = aliases.get(mode, mode)
73 if mode not in VISION_ATTN_MODES:
74 raise ValueError(
75 f"LA_FLASH_VISION_ATTN must be one of {', '.join(VISION_ATTN_MODES)}; got {value!r}"
76 )
77 return mode
78
79
80 ATTN_MODE = _normalize_attn_mode(os.environ.get("LA_FLASH_ATTN", "sdpa"))
81 REMOTE_ATTN_MODE = "sdpa" if ATTN_MODE in {"la_flash", "magi"} else ATTN_MODE
82 VISION_ATTN_MODE = _normalize_vision_attn_mode(os.environ.get("LA_FLASH_VISION_ATTN", "auto"))
83 MAX_DIM = 1024
84 DEV, DT = "cuda", torch.bfloat16
85 N_FUTURE = 6 # = config.block_size (MTP window)
86 _PROMPT = "Locate all the instances that matches the following description: "
87
88
89 def _env_flag(name, default=False):
90 val = os.environ.get(name)
91 if val is None:
92 return default
93 return val.strip().lower() not in {"0", "false", "no", "off"}
94
95
96 def _env_int(name):
97 val = os.environ.get(name)
98 if val is None or val.strip() == "":
99 return None
100 return int(val)
101
102
103 def _strict_attn():
104 return _env_flag("LA_FLASH_STRICT_ATTN", False)
105
106
107 def _fallback_to_sdpa(model, requested, reason):
108 if requested == "sdpa":
109 raise RuntimeError(f"LA_FLASH_ATTN=sdpa failed: {reason}") from reason
110 message = f"LA_FLASH_ATTN={requested} is unavailable; falling back to sdpa. Reason: {reason}"
111 if _strict_attn():
112 raise RuntimeError(message) from reason
113 warnings.warn(message)
114 _set_llm_mode(model, "sdpa")
115 model._la_flash_requested_attn_original = requested
116 model._la_flash_attn_fallback_reason = str(reason)
117 return "sdpa"
118
119
120 # Optional compile for the shared Qwen2 core. This is off by default because the
121 # hybrid scheduler already varies query/cache shapes and first-call compile cost is high.
122 MTP_COMPILE = os.environ.get("MTP_COMPILE", "0") == "1"
123
124 # Batch the MoonViT vision encode across a micro-batch's images: pack N images into ONE
125 # extract_feature. With flash present, MoonViT's varlen cu_seqlens path is block-diagonal per
126 # image and equivalent to per-image encode.
127 # Without flash, sdpa builds a dense [1,S,S] mask -> O(S^2) N^2 -> per-image fallback (auto, see
128 # _vision_is_flash). Default ON; set MTP_BATCH_VISION=0 to force per-image.
129 BATCH_VISION = os.environ.get("MTP_BATCH_VISION", "1") == "1"
130 _vision_encode_batch_size = _env_int("LA_FLASH_VISION_ENCODE_BATCH_SIZE")
131 VISION_ENCODE_BATCH_SIZE = 8 if _vision_encode_batch_size is None else max(0, _vision_encode_batch_size)
132
133 # Batch the per-row box-decode (sample_tokens): run the row-independent logits pipeline
134 # (rep-penalty / per-row temperature / top_p / top_k / softmax / sample) ONCE over the whole
135 # [B,6,V] step instead of B times on [1,6,V]; only the variable-length box assembly stays per-row.
136 # Greedy is BIT-IDENTICAL to the per-row san (argmax, no RNG). Default ON; MTP_BATCH_SAN=0 -> per-row.
137 BATCH_SAN = os.environ.get("MTP_BATCH_SAN", "1") == "1"
138
139 # Batch the AR repair sampler over [B,1,V]. This shares the exact filtering
140 # helpers with MTP batching but skips box/ref decoding, so it only replaces the
141 # repeated stock one-token sample calls. Sampling itself stays row-ordered by
142 # default to preserve the stock RNG consumption pattern for AR repair.
143 AR_BATCH_SAN = os.environ.get("AR_BATCH_SAN", "1") == "1"
144
145 _tok = _proc = _model = None
146
147 def _magi_diag():
148 lines = []
149 try:
150 import magi_attention
151 lines.append(f"magi_attention: OK file={getattr(magi_attention, '__file__', None)}")
152 lines.append(f"magi_attention.__version__={getattr(magi_attention, '__version__', '<missing>')}")
153 except Exception as e:
154 lines.append(f"magi_attention: FAIL {type(e).__name__}: {e}")
155 return "\n".join(lines)
156 try:
157 from magi_attention.functional.flex_flash_attn import flex_flash_attn_func
158 lines.append(f"magi_attention.functional.flex_flash_attn: OK func={flex_flash_attn_func}")
159 except Exception as e:
160 lines.append(f"magi_attention.functional.flex_flash_attn: FAIL {type(e).__name__}: {e}")
161 return "\n".join(lines)
162
163 def _remote_magi_diag(model=None):
164 lines = []
165 try:
166 if model is not None:
167 mod = importlib.import_module(type(model.language_model.model).__module__)
168 else:
169 # Best effort: if the dynamic module is not imported yet this may fail;
170 # the post-load diagnostic below will still work.
171 mod = importlib.import_module("transformers_modules.LocateAnything-3B.modeling_qwen2")
172 lines.append(f"remote_qwen2_module={getattr(mod, '__file__', None)}")
173 lines.append(f"remote_qwen2._MAGI_AVAILABLE={getattr(mod, '_MAGI_AVAILABLE', '<missing>')!r}")
174 lines.append(f"remote_qwen2.flex_flash_attn_func={getattr(mod, 'flex_flash_attn_func', '<missing>')}")
175 except Exception as e:
176 lines.append(f"remote_qwen2: diagnostic failed {type(e).__name__}: {e}")
177 return "\n".join(lines)
178
179 def _attn_class_diag(model):
180 try:
181 llm = model.language_model.model
182 classes = [type(layer.self_attn).__name__ for layer in llm.layers[:4]]
183 return (
184 f"llm._attn_implementation={getattr(llm, '_attn_implementation', None)!r}\n"
185 f"config._attn_implementation={getattr(llm.config, '_attn_implementation', None)!r}\n"
186 f"first_attn_classes={classes}"
187 )
188 except Exception as e:
189 return f"attention class diagnostic failed {type(e).__name__}: {e}"
190
191
192 def _set_vision_attention_mode(model):
193 """Match HF's MoonViT policy: prefer flash_attention_2, then sdpa, then eager."""
194 vm = getattr(model, "vision_model", None)
195 if vm is None:
196 return None
197 mod = importlib.import_module(type(vm).__module__)
198 funcs = getattr(mod, "VL_VISION_ATTENTION_FUNCTIONS", {})
199 has_flash = getattr(mod, "flash_attn_varlen_func", None) is not None
200 requested = VISION_ATTN_MODE
201
202 if requested == "auto":
203 candidates = ("flash_attention_2", "sdpa", "eager")
204 else:
205 candidates = (requested, "flash_attention_2", "sdpa", "eager")
206
207 chosen = None
208 for candidate in candidates:
209 if candidate == "flash_attention_2" and not has_flash:
210 continue
211 if candidate in funcs:
212 chosen = candidate
213 break
214 if chosen is None:
215 raise RuntimeError("MoonViT has no supported attention implementation.")
216
217 if requested == "flash_attention_2" and chosen != "flash_attention_2":
218 warnings.warn("LA_FLASH_VISION_ATTN=flash_attention_2 requested but flash-attn is unavailable; "
219 f"using {chosen}.")
220 elif requested not in {"auto", chosen}:
221 warnings.warn(f"LA_FLASH_VISION_ATTN={requested} is unavailable; using {chosen}.")
222
223 if hasattr(model.config, "vision_config"):
224 model.config.vision_config._attn_implementation = chosen
225 try:
226 vm.config._attn_implementation = chosen
227 except Exception:
228 pass
229 try:
230 for block in vm.encoder.blocks:
231 block.attn_implementation = chosen
232 except Exception as exc:
233 raise RuntimeError("Failed to configure MoonViT attention implementation.") from exc
234 model._la_flash_vision_attn = chosen
235 return chosen
236
237
238 def load():
239 """Lazy model load with HF remote-code semantics plus release backends.
240
241 The text decoder is pinned to one of sdpa/eager/magi/la_flash. MoonViT is
242 configured independently and follows the HF policy: flash_attention_2 when
243 flash-attn is importable, otherwise sdpa, otherwise eager.
244 """
245 global _tok, _proc, _model
246 if _model is None:
247 _tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
248 _proc = AutoProcessor.from_pretrained(MODEL, trust_remote_code=True)
249 attn_impl = REMOTE_ATTN_MODE
250 if ATTN_MODE == "magi" and os.environ.get("LA_FLASH_DEBUG", "0") != "0":
251 print("LA Flash magi pre-load diagnostic:", flush=True)
252 print(_magi_diag(), flush=True)
253 _model = AutoModel.from_pretrained(MODEL, torch_dtype=DT, trust_remote_code=True,
254 attn_implementation=attn_impl).to(DEV).eval()
255 _set_vision_attention_mode(_model)
256 actual_attn = getattr(_model.language_model.model, "_attn_implementation", None)
257 if ATTN_MODE == "magi" and os.environ.get("LA_FLASH_DEBUG", "0") != "0":
258 print("LA Flash magi post-load diagnostic:", flush=True)
259 print(_remote_magi_diag(_model), flush=True)
260 print(_attn_class_diag(_model), flush=True)
261 if ATTN_MODE == "magi":
262 try:
263 qwen2_mod = importlib.import_module(type(_model.language_model.model).__module__)
264 if not getattr(qwen2_mod, "_MAGI_AVAILABLE", False):
265 raise RuntimeError(
266 "remote module reports _MAGI_AVAILABLE=False.\n"
267 f"{_remote_magi_diag(_model)}\n{_magi_diag()}"
268 )
269 first_attn = type(_model.language_model.model.layers[0].self_attn).__name__
270 if actual_attn != "sdpa" or first_attn != "_BatchedMagiAttention":
271 _set_llm_mode(_model, "magi")
272 actual_attn = getattr(_model.language_model.model, "_attn_implementation", None)
273 first_attn = type(_model.language_model.model.layers[0].self_attn).__name__
274 if os.environ.get("LA_FLASH_DEBUG", "0") != "0":
275 print("LA Flash magi post-swap diagnostic:", flush=True)
276 print(_attn_class_diag(_model), flush=True)
277 if actual_attn != "sdpa" or first_attn != "_BatchedMagiAttention":
278 raise RuntimeError(
279 "batched magi attention did not activate. "
280 f"actual_attn={actual_attn!r}; first_attn={first_attn!r}; "
281 f"{_remote_magi_diag(_model)}; {_attn_class_diag(_model)}"
282 )
283 _model._la_flash_requested_attn = "magi"
284 except Exception as exc:
285 _fallback_to_sdpa(_model, "magi", exc)
286 else:
287 try:
288 _set_llm_mode(_model, ATTN_MODE) # decode-safe mask plumbing for sdpa/eager/la_flash
289 except Exception as exc:
290 _fallback_to_sdpa(_model, ATTN_MODE, exc)
291 if MTP_COMPILE:
292 _maybe_compile(_model)
293 return _tok, _proc, _model
294
295
296 def _maybe_compile(model):
297 """Compile the shared Qwen2Model core (base.forward). It backs BOTH prefill (called directly)
298 and decode (language_model.forward -> self.model). lm_head + MoonViT left eager. dynamic=True
299 so the varying decode S/kvlen don't trigger a recompile storm. No-op + warning if triton is
300 missing (inductor needs it on GPU). First call pays the compile cost (~42s warm / ~187s cold)."""
301 try:
302 import triton # noqa: F401
303 except Exception:
304 warnings.warn("MTP_COMPILE set but triton is unavailable; running without torch.compile.")
305 return
306 import torch._dynamo as _dyn
307 _dyn.config.cache_size_limit = max(_dyn.config.cache_size_limit, 64)
308 base = model.language_model.model
309 if not getattr(base, "_mtp_compiled", False):
310 base.forward = torch.compile(base.forward, dynamic=True)
311 base._mtp_compiled = True
312
313
314 def build_batched_magi_attention_class(mod):
315 """Build a Qwen2 attention subclass backed by Magi's flex_flash_attn.
316
317 The official LocateAnything ``Qwen2MagiAttention`` asserts ``bsz == 1`` and
318 relies on ``Qwen2Model._attn_implementation == "magi"`` to build a single
319 sample range plan. For release batch inference the hybrid scheduler passes
320 a batched Magi range plan directly to this layer; a 4D-mask conversion path
321 remains as a compatibility fallback.
322 """
323 flex_flash_attn_func = getattr(mod, "flex_flash_attn_func", None)
324 if flex_flash_attn_func is None:
325 try:
326 from magi_attention.functional.flex_flash_attn import flex_flash_attn_func
327 except Exception as exc:
328 raise RuntimeError(
329 "LA_FLASH_ATTN=magi requires "
330 "magi_attention.functional.flex_flash_attn.flex_flash_attn_func."
331 ) from exc
332
333 FULL, CAUSAL = 0, 1
334 causal_plan_cache = {}
335 try:
336 magi_params = set(inspect.signature(flex_flash_attn_func).parameters)
337 except (TypeError, ValueError):
338 magi_params = set()
339 supports_disable_fwd_atomic = "disable_fwd_atomic_reduction" in magi_params
340
341 def _disjoint_q_ranges(q_ranges):
342 seen = set()
343 for start, end in q_ranges:
344 key = (int(start), int(end))
345 if key in seen:
346 return False
347 seen.add(key)
348 return True
349
350 def _plan_disjoint_q_ranges(plan):
351 cached = plan.get("_la_flash_disjoint_q_ranges")
352 if cached is not None:
353 return bool(cached)
354 q_ranges = plan["q_ranges"].detach().to(device="cpu", dtype=torch.int32).tolist()
355 disjoint = _disjoint_q_ranges(q_ranges)
356 try:
357 plan["_la_flash_disjoint_q_ranges"] = disjoint
358 except Exception:
359 pass
360 return disjoint
361
362 def _tensor_plan(q_ranges, k_ranges, types, device):
363 return {
364 "q_ranges": torch.tensor(q_ranges, dtype=torch.int32, device=device).contiguous(),
365 "k_ranges": torch.tensor(k_ranges, dtype=torch.int32, device=device).contiguous(),
366 "attn_type_map": torch.tensor(types, dtype=torch.int32, device=device).contiguous(),
367 "_la_flash_disjoint_q_ranges": _disjoint_q_ranges(q_ranges),
368 }
369
370 def _offset_plan(plan, q_offset, k_offset):
371 return (
372 (plan["q_ranges"] + int(q_offset)).tolist(),
373 (plan["k_ranges"] + int(k_offset)).tolist(),
374 plan["attn_type_map"].tolist(),
375 )
376
377 def _causal_plan(bsz, q_len, kv_seq_len, device):
378 key = (int(bsz), int(q_len), int(kv_seq_len), device.type, device.index)
379 cached = causal_plan_cache.get(key)
380 if cached is not None:
381 return cached
382 q_ranges, k_ranges, types = [], [], []
383 for b in range(int(bsz)):
384 q_base = b * int(q_len)
385 k_base = b * int(kv_seq_len)
386 q_ranges.append([q_base, q_base + int(q_len)])
387 k_ranges.append([k_base, k_base + int(kv_seq_len)])
388 types.append(CAUSAL)
389 plan = _tensor_plan(q_ranges, k_ranges, types, device)
390 plan.update(
391 {
392 "flash_cu_seqlens_q": torch.arange(
393 0,
394 (int(bsz) + 1) * int(q_len),
395 int(q_len),
396 dtype=torch.int32,
397 device=device,
398 ),
399 "flash_cu_seqlens_k": torch.arange(
400 0,
401 (int(bsz) + 1) * int(kv_seq_len),
402 int(kv_seq_len),
403 dtype=torch.int32,
404 device=device,
405 ),
406 "flash_causal": True,
407 }
408 )
409 causal_plan_cache[key] = plan
410 return plan
411
412 def _row_segments(row):
413 idx = np.flatnonzero(row)
414 if idx.size == 0:
415 return ((0, 1),)
416 split = np.flatnonzero(np.diff(idx) > 1) + 1
417 starts = np.concatenate((idx[:1], idx[split]))
418 ends = np.concatenate((idx[split - 1], idx[-1:])) + 1
419 return tuple((int(s), int(e)) for s, e in zip(starts, ends))
420
421 def _visible_from_4d_mask(attention_mask, kv_seq_len):
422 mask = attention_mask[:, :, :, :kv_seq_len]
423 if mask.dtype == torch.bool:
424 return mask[:, 0].detach().to(device="cpu", dtype=torch.bool).contiguous()
425 mask_cpu = mask[:, 0].detach().to(device="cpu").contiguous()
426 if getattr(attention_mask, "_la_flash_visible_mask", False):
427 return (mask_cpu > 0).to(dtype=torch.bool)
428
429 max_value = float(mask_cpu.max().item()) if mask_cpu.numel() else 0.0
430 min_value = float(mask_cpu.min().item()) if mask_cpu.numel() else 0.0
431 if max_value > 0.0 and min_value >= 0.0:
432 return (mask_cpu > 0).to(dtype=torch.bool)
433 return (mask_cpu >= 0).to(dtype=torch.bool)
434
435 def _plan_from_visible_mask(attention_mask, bsz, q_len, kv_seq_len, device):
436 cache_key = (int(bsz), int(q_len), int(kv_seq_len), device.type, device.index)
437 cached = getattr(attention_mask, "_la_flash_magi_plan", None)
438 if cached is not None and cached[0] == cache_key:
439 return cached[1]
440
441 visible = _visible_from_4d_mask(attention_mask, int(kv_seq_len)).numpy()
442 q_ranges, k_ranges, types = [], [], []
443 for b in range(int(bsz)):
444 q_base = b * int(q_len)
445 k_base = b * int(kv_seq_len)
446 run_start = 0
447 run_segments = _row_segments(visible[b, 0])
448 for q in range(1, int(q_len)):
449 segments = _row_segments(visible[b, q])
450 if segments == run_segments:
451 continue
452 for start, end in run_segments:
453 q_ranges.append([q_base + run_start, q_base + q])
454 k_ranges.append([k_base + start, k_base + end])
455 types.append(FULL)
456 run_start = q
457 run_segments = segments
458 for start, end in run_segments:
459 q_ranges.append([q_base + run_start, q_base + int(q_len)])
460 k_ranges.append([k_base + start, k_base + end])
461 types.append(FULL)
462
463 plan = _tensor_plan(q_ranges, k_ranges, types, device)
464 try:
465 attention_mask._la_flash_magi_plan = (cache_key, plan)
466 except Exception:
467 pass
468 return plan
469
470 def _plan_from_magi_dict(attention_mask, bsz, q_len, kv_seq_len, device):
471 if int(bsz) == 1:
472 return attention_mask
473 q_ranges, k_ranges, types = [], [], []
474 for b in range(int(bsz)):
475 qs, ks, ts = _offset_plan(
476 attention_mask,
477 q_offset=b * int(q_len),
478 k_offset=b * int(kv_seq_len),
479 )
480 q_ranges.extend(qs)
481 k_ranges.extend(ks)
482 types.extend(ts)
483 return _tensor_plan(q_ranges, k_ranges, types, device)
484
485 def _magi_plan(attention_mask, bsz, q_len, kv_seq_len, device):
486 if isinstance(attention_mask, dict):
487 if attention_mask.get("_la_flash_batched", False):
488 return attention_mask
489 return _plan_from_magi_dict(attention_mask, bsz, q_len, kv_seq_len, device)
490 if attention_mask is None:
491 return _causal_plan(bsz, q_len, kv_seq_len, device)
492 if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
493 raise ValueError(
494 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
495 f"but is {attention_mask.size()}"
496 )
497 return _plan_from_visible_mask(attention_mask, bsz, q_len, kv_seq_len, device)
498
499 class _BatchedMagiAttention(mod.Qwen2Attention):
500 """MagiAttention path with true batch inference via packed token ranges."""
501
502 def forward(
503 self,
504 hidden_states: torch.Tensor,
505 attention_mask=None,
506 position_ids=None,
507 past_key_value=None,
508 output_attentions=False,
509 use_cache=False,
510 **kwargs,
511 ):
512 if output_attentions:
513 raise NotImplementedError("MagiAttention does not support output_attentions=True")
514
515 bsz, q_len, _ = hidden_states.size()
516 query_states = self.q_proj(hidden_states)
517 key_states = self.k_proj(hidden_states)
518 value_states = self.v_proj(hidden_states)
519
520 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
521 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
522 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
523
524 kv_seq_len = key_states.shape[-2]
525 if past_key_value is not None:
526 if self.layer_idx is None:
527 raise ValueError(
528 f"The cache structure has changed since version v4.36. If you are using "
529 f"{self.__class__.__name__} for auto-regressive decoding with k/v caching, "
530 "please initialize the attention class with a layer index."
531 )
532 kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
533
534 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
535 query_states, key_states = mod.apply_rotary_pos_emb(
536 query_states, key_states, cos, sin, position_ids)
537
538 if past_key_value is not None:
539 cache_kwargs = {"sin": sin, "cos": cos}
540 key_states, value_states = past_key_value.update(
541 key_states, value_states, self.layer_idx, cache_kwargs)
542
543 kv_seq_len = key_states.shape[-2]
544 plan = _magi_plan(attention_mask, bsz, q_len, kv_seq_len, query_states.device)
545 magi_extra_kwargs = {}
546 if supports_disable_fwd_atomic:
547 magi_extra_kwargs["disable_fwd_atomic_reduction"] = (
548 (not self.training) and _plan_disjoint_q_ranges(plan)
549 )
550
551 query_states = query_states.transpose(1, 2).reshape(
552 bsz * q_len, self.num_heads, self.head_dim).contiguous()
553 key_states = key_states.transpose(1, 2).reshape(
554 bsz * kv_seq_len, self.num_key_value_heads, self.head_dim).contiguous()
555 value_states = value_states.transpose(1, 2).reshape(
556 bsz * kv_seq_len, self.num_key_value_heads, self.head_dim).contiguous()
557
558 attn_output, _ = flex_flash_attn_func(
559 query_states,
560 key_states,
561 value_states,
562 q_ranges=plan["q_ranges"],
563 k_ranges=plan["k_ranges"],
564 attn_type_map=plan["attn_type_map"],
565 softmax_scale=getattr(self, "softmax_scale", self.head_dim ** -0.5),
566 softcap=0.0,
567 deterministic=False,
568 **magi_extra_kwargs,
569 )
570 attn_output = attn_output.view(bsz, q_len, self.hidden_size)
571 attn_output = self.o_proj(attn_output)
572 return attn_output, None, past_key_value
573
574 return _BatchedMagiAttention
575
576
577 def build_la_flash_attention_class(mod):
578 """Build a Qwen2 attention subclass backed by LA Flash sparse ranges."""
579 try:
580 from kernel_utils import is_available, range_attention
581 except Exception as exc:
582 raise RuntimeError(
583 "LA_FLASH_ATTN=la_flash requires kernel_utils and FlashAttention."
584 ) from exc
585 if not is_available():
586 raise RuntimeError(
587 "LA_FLASH_ATTN=la_flash requires flash_attn.flash_attn_varlen_func."
588 )
589
590 FULL, CAUSAL = 0, 1
591 causal_plan_cache = {}
592
593 def _tensor_plan(q_ranges, k_ranges, types, device):
594 max_q_len = max((int(end) - int(start) for start, end in q_ranges), default=0)
595 max_k_len = max((int(end) - int(start) for start, end in k_ranges), default=0)
596 plan = {
597 "q_ranges": torch.tensor(q_ranges, dtype=torch.int32, device=device).contiguous(),
598 "k_ranges": torch.tensor(k_ranges, dtype=torch.int32, device=device).contiguous(),
599 "attn_type_map": torch.tensor(types, dtype=torch.int32, device=device).contiguous(),
600 "max_q_len": max_q_len,
601 "max_k_len": max_k_len,
602 }
603 plan.update(_la_flash_group_plan_tensors(q_ranges, types, device))
604 return plan
605
606 def _offset_plan(plan, q_offset, k_offset):
607 return (
608 (plan["q_ranges"] + int(q_offset)).tolist(),
609 (plan["k_ranges"] + int(k_offset)).tolist(),
610 plan["attn_type_map"].tolist(),
611 )
612
613 def _causal_plan(bsz, q_len, kv_seq_len, device):
614 key = (int(bsz), int(q_len), int(kv_seq_len), device.type, device.index)
615 cached = causal_plan_cache.get(key)
616 if cached is not None:
617 return cached
618 q_ranges, k_ranges, types = [], [], []
619 for b in range(int(bsz)):
620 q_base = b * int(q_len)
621 k_base = b * int(kv_seq_len)
622 q_ranges.append([q_base, q_base + int(q_len)])
623 k_ranges.append([k_base, k_base + int(kv_seq_len)])
624 types.append(CAUSAL)
625 plan = _tensor_plan(q_ranges, k_ranges, types, device)
626 plan.update(
627 {
628 "flash_cu_seqlens_q": torch.arange(
629 0,
630 (int(bsz) + 1) * int(q_len),
631 int(q_len),
632 dtype=torch.int32,
633 device=device,
634 ),
635 "flash_cu_seqlens_k": torch.arange(
636 0,
637 (int(bsz) + 1) * int(kv_seq_len),
638 int(kv_seq_len),
639 dtype=torch.int32,
640 device=device,
641 ),
642 "flash_causal": True,
643 }
644 )
645 causal_plan_cache[key] = plan
646 return plan
647
648 def _row_segments(row):
649 idx = np.flatnonzero(row)
650 if idx.size == 0:
651 return ((0, 1),)
652 split = np.flatnonzero(np.diff(idx) > 1) + 1
653 starts = np.concatenate((idx[:1], idx[split]))
654 ends = np.concatenate((idx[split - 1], idx[-1:])) + 1
655 return tuple((int(s), int(e)) for s, e in zip(starts, ends))
656
657 def _visible_from_4d_mask(attention_mask, kv_seq_len):
658 mask = attention_mask[:, :, :, :kv_seq_len]
659 if mask.dtype == torch.bool:
660 return mask[:, 0].detach().to(device="cpu", dtype=torch.bool).contiguous()
661 mask_cpu = mask[:, 0].detach().to(device="cpu").contiguous()
662 if getattr(attention_mask, "_la_flash_visible_mask", False):
663 return (mask_cpu > 0).to(dtype=torch.bool)
664
665 max_value = float(mask_cpu.max().item()) if mask_cpu.numel() else 0.0
666 min_value = float(mask_cpu.min().item()) if mask_cpu.numel() else 0.0
667 if max_value > 0.0 and min_value >= 0.0:
668 return (mask_cpu > 0).to(dtype=torch.bool)
669 return (mask_cpu >= 0).to(dtype=torch.bool)
670
671 def _prefix_len(row):
672 idx = np.flatnonzero(row)
673 if idx.size == 0:
674 return None
675 end = int(idx[-1]) + 1
676 if not bool(row[:end].all()) or bool(row[end:].any()):
677 return None
678 return end
679
680 def _causal_plan_from_visible(visible, bsz, q_len, kv_seq_len, device):
681 q_ranges, k_ranges, types = [], [], []
682 packed_flash = True
683 for b in range(int(bsz)):
684 first_len = _prefix_len(visible[b, 0])
685 if first_len is None:
686 return None
687 valid_len = int(first_len) + int(q_len) - 1
688 if valid_len < int(q_len) or valid_len > int(kv_seq_len):
689 return None
690 for q in range(int(q_len)):
691 row_len = _prefix_len(visible[b, q])
692 expected = valid_len - int(q_len) + q + 1
693 if row_len != expected:
694 return None
695 q_base = b * int(q_len)
696 k_base = b * int(kv_seq_len)
697 q_ranges.append([q_base, q_base + int(q_len)])
698 k_ranges.append([k_base, k_base + valid_len])
699 types.append(CAUSAL)
700 packed_flash = packed_flash and valid_len == int(kv_seq_len)
701
702 plan = _tensor_plan(q_ranges, k_ranges, types, device)
703 plan["_la_flash_disjoint_q_ranges"] = True
704 if packed_flash:
705 plan.update(
706 {
707 "flash_cu_seqlens_q": torch.arange(
708 0,
709 (int(bsz) + 1) * int(q_len),
710 int(q_len),
711 dtype=torch.int32,
712 device=device,
713 ),
714 "flash_cu_seqlens_k": torch.arange(
715 0,
716 (int(bsz) + 1) * int(kv_seq_len),
717 int(kv_seq_len),
718 dtype=torch.int32,
719 device=device,
720 ),
721 "flash_causal": True,
722 }
723 )
724 return plan
725
726 def _plan_from_visible_mask(attention_mask, bsz, q_len, kv_seq_len, device):
727 cache_key = (int(bsz), int(q_len), int(kv_seq_len), device.type, device.index, "la_flash")
728 cached = getattr(attention_mask, "_la_flash_range_plan", None)
729 if cached is not None and cached[0] == cache_key:
730 return cached[1]
731
732 visible = _visible_from_4d_mask(attention_mask, int(kv_seq_len)).numpy()
733 plan = _causal_plan_from_visible(visible, bsz, q_len, kv_seq_len, device)
734 if plan is not None:
735 try:
736 attention_mask._la_flash_range_plan = (cache_key, plan)
737 except Exception:
738 pass
739 return plan
740
741 q_ranges, k_ranges, types = [], [], []
742 for b in range(int(bsz)):
743 q_base = b * int(q_len)
744 k_base = b * int(kv_seq_len)
745 run_start = 0
746 run_segments = _row_segments(visible[b, 0])
747 for q in range(1, int(q_len)):
748 segments = _row_segments(visible[b, q])
749 if segments == run_segments:
750 continue
751 for start, end in run_segments:
752 q_ranges.append([q_base + run_start, q_base + q])
753 k_ranges.append([k_base + start, k_base + end])
754 types.append(FULL)
755 run_start = q
756 run_segments = segments
757 for start, end in run_segments:
758 q_ranges.append([q_base + run_start, q_base + int(q_len)])
759 k_ranges.append([k_base + start, k_base + end])
760 types.append(FULL)
761
762 plan = _tensor_plan(q_ranges, k_ranges, types, device)
763 try:
764 attention_mask._la_flash_range_plan = (cache_key, plan)
765 except Exception:
766 pass
767 return plan
768
769 def _plan_from_magi_dict(attention_mask, bsz, q_len, kv_seq_len, device):
770 if int(bsz) == 1:
771 return attention_mask
772 q_ranges, k_ranges, types = [], [], []
773 for b in range(int(bsz)):
774 qs, ks, ts = _offset_plan(
775 attention_mask,
776 q_offset=b * int(q_len),
777 k_offset=b * int(kv_seq_len),
778 )
779 q_ranges.extend(qs)
780 k_ranges.extend(ks)
781 types.extend(ts)
782 return _tensor_plan(q_ranges, k_ranges, types, device)
783
784 def _range_plan(attention_mask, bsz, q_len, kv_seq_len, device):
785 if isinstance(attention_mask, dict):
786 if attention_mask.get("_la_flash_batched", False):
787 return attention_mask
788 return _plan_from_magi_dict(attention_mask, bsz, q_len, kv_seq_len, device)
789 if attention_mask is None:
790 return _causal_plan(bsz, q_len, kv_seq_len, device)
791 if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
792 raise ValueError(
793 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
794 f"but is {attention_mask.size()}"
795 )
796 return _plan_from_visible_mask(attention_mask, bsz, q_len, kv_seq_len, device)
797
798 class _LaFlashAttention(mod.Qwen2Attention):
799 """Range-plan attention path backed by FlashAttention sparse ranges."""
800
801 def forward(
802 self,
803 hidden_states: torch.Tensor,
804 attention_mask=None,
805 position_ids=None,
806 past_key_value=None,
807 output_attentions=False,
808 use_cache=False,
809 **kwargs,
810 ):
811 if output_attentions:
812 raise NotImplementedError("LA Flash attention does not support output_attentions=True")
813
814 bsz, q_len, _ = hidden_states.size()
815 query_states = self.q_proj(hidden_states)
816 key_states = self.k_proj(hidden_states)
817 value_states = self.v_proj(hidden_states)
818
819 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
820 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
821 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
822
823 kv_seq_len = key_states.shape[-2]
824 if past_key_value is not None:
825 if self.layer_idx is None:
826 raise ValueError(
827 f"The cache structure has changed since version v4.36. If you are using "
828 f"{self.__class__.__name__} for auto-regressive decoding with k/v caching, "
829 "please initialize the attention class with a layer index."
830 )
831 kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
832
833 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
834 query_states, key_states = mod.apply_rotary_pos_emb(
835 query_states, key_states, cos, sin, position_ids)
836
837 if past_key_value is not None:
838 cache_kwargs = {"sin": sin, "cos": cos}
839 key_states, value_states = past_key_value.update(
840 key_states, value_states, self.layer_idx, cache_kwargs)
841
842 kv_seq_len = key_states.shape[-2]
843 dense_backend = os.environ.get("LA_FLASH_DENSE_BACKEND", "sdpa").strip().lower()
844 if dense_backend == "sdpa" and not isinstance(attention_mask, dict):
845 dense_key_states = mod.repeat_kv(key_states, self.num_key_value_groups)
846 dense_value_states = mod.repeat_kv(value_states, self.num_key_value_groups)
847 if attention_mask is not None:
848 if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
849 raise ValueError(
850 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
851 f"but is {attention_mask.size()}"
852 )
853 query_for_sdpa = query_states.contiguous()
854 key_for_sdpa = dense_key_states.contiguous()
855 value_for_sdpa = dense_value_states.contiguous()
856 is_causal = False
857 elif past_key_value is None:
858 query_for_sdpa = query_states
859 key_for_sdpa = dense_key_states
860 value_for_sdpa = dense_value_states
861 is_causal = bool(self.is_causal and q_len > 1)
862 else:
863 query_for_sdpa = key_for_sdpa = value_for_sdpa = None
864 is_causal = False
865 if query_for_sdpa is not None:
866 attn_output = torch.nn.functional.scaled_dot_product_attention(
867 query_for_sdpa,
868 key_for_sdpa,
869 value_for_sdpa,
870 attn_mask=attention_mask,
871 dropout_p=self.attention_dropout if self.training else 0.0,
872 is_causal=is_causal,
873 )
874 attn_output = attn_output.transpose(1, 2).contiguous()
875 attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
876 attn_output = self.o_proj(attn_output)
877 return attn_output, None, past_key_value
878
879 plan = _range_plan(attention_mask, bsz, q_len, kv_seq_len, query_states.device)
880
881 query_states = query_states.transpose(1, 2).reshape(
882 bsz * q_len, self.num_heads, self.head_dim).contiguous()
883 key_states = key_states.transpose(1, 2).reshape(
884 bsz * kv_seq_len, self.num_key_value_heads, self.head_dim).contiguous()
885 value_states = value_states.transpose(1, 2).reshape(
886 bsz * kv_seq_len, self.num_key_value_heads, self.head_dim).contiguous()
887
888 attn_output = range_attention(
889 query_states,
890 key_states,
891 value_states,
892 plan["q_ranges"],
893 plan["k_ranges"],
894 plan["attn_type_map"],
895 getattr(self, "softmax_scale", self.head_dim ** -0.5),
896 segment_offsets=plan.get("segment_offsets"),
897 group_q_ranges=plan.get("group_q_ranges"),
898 group_attn_type_map=plan.get("group_attn_type_map"),
899 max_q_len=plan.get("max_q_len"),
900 max_k_len=plan.get("max_k_len"),
901 flash_cu_seqlens_q=plan.get("flash_cu_seqlens_q"),
902 flash_cu_seqlens_k=plan.get("flash_cu_seqlens_k"),
903 flash_causal=plan.get("flash_causal"),
904 disjoint_q_ranges=plan.get("_la_flash_disjoint_q_ranges"),
905 )
906 attn_output = attn_output.view(bsz, q_len, self.hidden_size)
907 attn_output = self.o_proj(attn_output)
908 return attn_output, None, past_key_value
909
910 return _LaFlashAttention
911
912
913 def _is_magi_plan(obj):
914 return isinstance(obj, dict) and {
915 "q_ranges",
916 "k_ranges",
917 "attn_type_map",
918 }.issubset(obj.keys())
919
920
921 def _la_flash_group_plan_tensors(q_ranges, types, device):
922 """Group consecutive Magi range entries that share the same query span.
923
924 Magi-style plans may represent one query span with multiple disjoint key
925 spans. LA Flash consumes those as one FlashAttention-backed softmax group.
926 """
927 if not q_ranges:
928 return {
929 "group_q_ranges": torch.empty((0, 2), dtype=torch.int32, device=device),
930 "segment_offsets": torch.zeros((1,), dtype=torch.int32, device=device),
931 "group_attn_type_map": torch.empty((0,), dtype=torch.int32, device=device),
932 }
933
934 grouped_q, grouped_types, offsets = [], [], [0]
935 last_q = None
936 last_type = None
937 for idx, (q_range, attn_type) in enumerate(zip(q_ranges, types)):
938 key = (int(q_range[0]), int(q_range[1]))
939 attn_type = int(attn_type)
940 if last_q is None:
941 grouped_q.append([key[0], key[1]])
942 grouped_types.append(attn_type)
943 last_q = key
944 last_type = attn_type
945 continue
946 if key == last_q and attn_type == last_type:
947 continue
948 offsets.append(idx)
949 grouped_q.append([key[0], key[1]])
950 grouped_types.append(attn_type)
951 last_q = key
952 last_type = attn_type
953 offsets.append(len(q_ranges))
954
955 return {
956 "group_q_ranges": torch.tensor(grouped_q, dtype=torch.int32, device=device).contiguous(),
957 "segment_offsets": torch.tensor(offsets, dtype=torch.int32, device=device).contiguous(),
958 "group_attn_type_map": torch.tensor(grouped_types, dtype=torch.int32, device=device).contiguous(),
959 "max_q_len": max((end - start for start, end in grouped_q), default=0),
960 }
961
962
963 def _record_sparse_plan_stats(model, q_ranges, k_ranges, types):
964 if os.environ.get("LA_FLASH_PLAN_STATS", "0") != "1":
965 return
966 stats = getattr(model, "_la_flash_sparse_plan_stats", None)
967 if stats is None:
968 stats = {
969 "calls": 0,
970 "ranges": 0,
971 "q_tokens": 0,
972 "k_tokens": 0,
973 "max_q_len": 0,
974 "max_k_len": 0,
975 "full_ranges": 0,
976 "causal_ranges": 0,
977 "other_ranges": 0,
978 }
979 model._la_flash_sparse_plan_stats = stats
980 stats["calls"] += 1
981 stats["ranges"] += len(q_ranges)
982 for (q_start, q_end), (k_start, k_end), attn_type in zip(q_ranges, k_ranges, types):
983 q_len = int(q_end) - int(q_start)
984 k_len = int(k_end) - int(k_start)
985 stats["q_tokens"] += q_len
986 stats["k_tokens"] += k_len
987 stats["max_q_len"] = max(stats["max_q_len"], q_len)
988 stats["max_k_len"] = max(stats["max_k_len"], k_len)
989 attn_type = int(attn_type)
990 if attn_type == 0:
991 stats["full_ranges"] += 1
992 elif attn_type == 1:
993 stats["causal_ranges"] += 1
994 else:
995 stats["other_ranges"] += 1
996
997
998 def build_magi_scheduler_ranges(model, attention_mask_2d, input_ids, past_len, mtp_window=False):
999 """Build batched Magi ranges directly from the hybrid scheduler mask.
1000
1001 The official Qwen2 SDPA dispatcher may optimize an all-valid 2D mask to
1002 ``None`` before decoder layers see it. That is correct for plain causal
1003 attention but loses LocateAnything's MTP generation-window rule. Building
1004 ranges here keeps Magi batch inference exact and avoids per-layer dense
1005 mask conversion.
1006 """
1007 requested_attn = getattr(model, "_la_flash_requested_attn", ATTN_MODE)
1008 if requested_attn not in {"magi", "la_flash"}:
1009 return None
1010 if attention_mask_2d is None or not hasattr(attention_mask_2d, "dim") or attention_mask_2d.dim() != 2:
1011 return None
1012
1013 bsz, q_len = int(input_ids.shape[0]), int(input_ids.shape[1])
1014 key_len = int(attention_mask_2d.shape[1])
1015 dev = input_ids.device
1016 llm = model.language_model.model
1017 block = int(getattr(llm, "block_size", N_FUTURE))
1018 causal_attn = bool(getattr(llm, "causal_attn", False))
1019 use_mtp_window = bool(mtp_window and q_len >= block and key_len >= block)
1020 q0 = max(0, q_len - block)
1021 k0 = max(0, key_len - block)
1022 blocked_k = k0 - 1
1023 past_len = int(past_len)
1024
1025 key_valid = attention_mask_2d.detach().to(device="cpu", dtype=torch.bool).contiguous().numpy()
1026 key_idx = np.arange(key_len)
1027 q_ranges, k_ranges, types = [], [], []
1028 if not use_mtp_window:
1029 causal_q_ranges, causal_k_ranges, causal_types = [], [], []
1030 causal_fast_path = True
1031 packed_flash = True
1032 for b in range(bsz):
1033 valid = np.flatnonzero(key_valid[b])
1034 if valid.size == 0:
1035 causal_fast_path = False
1036 break
1037 valid_len = int(valid[-1]) + 1
1038 if valid_len < q_len or not bool(key_valid[b, :valid_len].all()) or bool(key_valid[b, valid_len:].any()):
1039 causal_fast_path = False
1040 break
1041 packed_flash = packed_flash and valid_len == key_len
1042 q_base = b * q_len
1043 k_base = b * key_len
1044 causal_q_ranges.append([q_base, q_base + q_len])
1045 causal_k_ranges.append([k_base, k_base + valid_len])
1046 causal_types.append(1)
1047 if causal_fast_path:
1048 plan = {
1049 "q_ranges": torch.tensor(causal_q_ranges, dtype=torch.int32, device=dev).contiguous(),
1050 "k_ranges": torch.tensor(causal_k_ranges, dtype=torch.int32, device=dev).contiguous(),
1051 "attn_type_map": torch.tensor(causal_types, dtype=torch.int32, device=dev).contiguous(),
1052 "max_q_len": q_len,
1053 "max_k_len": max((end - start for start, end in causal_k_ranges), default=0),
1054 "_la_flash_batched": True,
1055 "_la_flash_disjoint_q_ranges": True,
1056 }
1057 if packed_flash:
1058 plan.update(
1059 {
1060 "flash_cu_seqlens_q": torch.arange(
1061 0,
1062 (bsz + 1) * q_len,
1063 q_len,
1064 dtype=torch.int32,
1065 device=dev,
1066 ),
1067 "flash_cu_seqlens_k": torch.arange(
1068 0,
1069 (bsz + 1) * key_len,
1070 key_len,
1071 dtype=torch.int32,
1072 device=dev,
1073 ),
1074 "flash_causal": True,
1075 }
1076 )
1077 plan.update(_la_flash_group_plan_tensors(causal_q_ranges, causal_types, dev))
1078 _record_sparse_plan_stats(model, causal_q_ranges, causal_k_ranges, causal_types)
1079 return plan
1080
1081 def row_segments(row):
1082 idx = np.flatnonzero(row)
1083 if idx.size == 0:
1084 return ((0, 1),)
1085 split = np.flatnonzero(np.diff(idx) > 1) + 1
1086 starts = np.concatenate((idx[:1], idx[split]))
1087 ends = np.concatenate((idx[split - 1], idx[-1:])) + 1
1088 return tuple((int(s), int(e)) for s, e in zip(starts, ends))
1089
1090 for b in range(bsz):
1091 q_base = b * q_len
1092 k_base = b * key_len
1093 run_start = 0
1094 run_segments = None
1095 if use_mtp_window and not causal_attn:
1096 prefix_q_len = q0
1097 prefix_k_end = past_len + prefix_q_len
1098 prefix_ok = (
1099 prefix_q_len > 0
1100 and prefix_k_end <= key_len
1101 and bool(key_valid[b, :prefix_k_end].all())
1102 )
1103 window_prefix_ok = blocked_k <= 0 or bool(key_valid[b, :blocked_k].all())
1104 window_ok = bool(key_valid[b, k0:key_len].all())
1105 if prefix_ok:
1106 q_ranges.append([q_base, q_base + prefix_q_len])
1107 k_ranges.append([k_base, k_base + prefix_k_end])
1108 types.append(1)
1109 run_start = prefix_q_len
1110 if run_start == prefix_q_len and prefix_q_len < q_len and window_prefix_ok and window_ok:
1111 if blocked_k > 0:
1112 q_ranges.append([q_base + prefix_q_len, q_base + q_len])
1113 k_ranges.append([k_base, k_base + blocked_k])
1114 types.append(0)
1115 q_ranges.append([q_base + prefix_q_len, q_base + q_len])
1116 k_ranges.append([k_base + k0, k_base + key_len])
1117 types.append(0)
1118 continue
1119
1120 for q in range(run_start, q_len):
1121 visible = key_valid[b] & (key_idx <= q + past_len)
1122 if use_mtp_window and q >= q0:
1123 if not causal_attn:
1124 visible = visible.copy()
1125 visible[k0:key_len] = key_valid[b, k0:key_len]
1126 if blocked_k >= 0:
1127 if visible.base is None:
1128 visible[blocked_k] = False
1129 else:
1130 visible = visible.copy()
1131 visible[blocked_k] = False
1132 segments = row_segments(visible)
1133 if run_segments is None:
1134 run_segments = segments
1135 continue
1136 if segments == run_segments:
1137 continue
1138 for start, end in run_segments:
1139 q_ranges.append([q_base + run_start, q_base + q])
1140 k_ranges.append([k_base + start, k_base + end])
1141 types.append(0)
1142 run_start = q
1143 run_segments = segments
1144 for start, end in run_segments:
1145 q_ranges.append([q_base + run_start, q_base + q_len])
1146 k_ranges.append([k_base + start, k_base + end])
1147 types.append(0)
1148
1149 seen_q_ranges = set()
1150 disjoint_q_ranges = True
1151 for start, end in q_ranges:
1152 key = (int(start), int(end))
1153 if key in seen_q_ranges:
1154 disjoint_q_ranges = False
1155 break
1156 seen_q_ranges.add(key)
1157
1158 plan = {
1159 "q_ranges": torch.tensor(q_ranges, dtype=torch.int32, device=dev).contiguous(),
1160 "k_ranges": torch.tensor(k_ranges, dtype=torch.int32, device=dev).contiguous(),
1161 "attn_type_map": torch.tensor(types, dtype=torch.int32, device=dev).contiguous(),
1162 "max_q_len": max((end - start for start, end in q_ranges), default=0),
1163 "max_k_len": max((end - start for start, end in k_ranges), default=0),
1164 "_la_flash_batched": True,
1165 "_la_flash_disjoint_q_ranges": disjoint_q_ranges,
1166 }
1167 plan.update(_la_flash_group_plan_tensors(q_ranges, types, dev))
1168 _record_sparse_plan_stats(model, q_ranges, k_ranges, types)
1169 return plan
1170
1171
1172 def _direct_base_forward(
1173 base,
1174 input_ids=None,
1175 visual_features=None,
1176 image_token_index=None,
1177 attention_mask=None,
1178 position_ids=None,
1179 past_key_values=None,
1180 inputs_embeds=None,
1181 use_cache=None,
1182 output_attentions=None,
1183 output_hidden_states=None,
1184 return_dict=None,
1185 ):
1186 mod = importlib.import_module(type(base).__module__)
1187 output_attentions = output_attentions if output_attentions is not None else base.config.output_attentions
1188 output_hidden_states = (
1189 output_hidden_states if output_hidden_states is not None else base.config.output_hidden_states
1190 )
1191 use_cache = use_cache if use_cache is not None else base.config.use_cache
1192
1193 if input_ids is not None and inputs_embeds is not None:
1194 raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1195 if input_ids is not None:
1196 batch_size, seq_length = input_ids.shape
1197 elif inputs_embeds is not None:
1198 batch_size, seq_length, _ = inputs_embeds.shape
1199 else:
1200 raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1201
1202 past_key_values_length = 0
1203 use_legacy_cache = False
1204 if use_cache:
1205 Cache = getattr(mod, "Cache")
1206 DynamicCache = getattr(mod, "DynamicCache")
1207 use_legacy_cache = not isinstance(past_key_values, Cache)
1208 if use_legacy_cache:
1209 if past_key_values is None:
1210 past_key_values = DynamicCache()
1211 else:
1212 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1213 past_key_values_length = past_key_values.get_seq_length()
1214
1215 if position_ids is None:
1216 dev = input_ids.device if input_ids is not None else inputs_embeds.device
1217 position_ids = torch.arange(
1218 past_key_values_length,
1219 seq_length + past_key_values_length,
1220 dtype=torch.long,
1221 device=dev,
1222 ).unsqueeze(0).view(-1, seq_length)
1223 else:
1224 position_ids = position_ids.view(-1, seq_length).long()
1225
1226 if inputs_embeds is None:
1227 inputs_embeds = base.image_processing(input_ids, visual_features, image_token_index)
1228
1229 hidden_states = inputs_embeds
1230 all_hidden_states = () if output_hidden_states else None
1231 all_self_attns = () if output_attentions else None
1232 next_decoder_cache = None
1233
1234 for decoder_layer in base.layers:
1235 if output_hidden_states:
1236 all_hidden_states += (hidden_states,)
1237 layer_outputs = decoder_layer(
1238 hidden_states,
1239 attention_mask=attention_mask,
1240 position_ids=position_ids,
1241 past_key_value=past_key_values,
1242 output_attentions=output_attentions,
1243 use_cache=use_cache,
1244 )
1245 hidden_states = layer_outputs[0]
1246 if use_cache:
1247 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1248 if output_attentions:
1249 all_self_attns += (layer_outputs[1],)
1250
1251 hidden_states = base.norm(hidden_states)
1252 if output_hidden_states:
1253 all_hidden_states += (hidden_states,)
1254 next_cache = None
1255 if use_cache:
1256 next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1257 return SimpleNamespace(
1258 last_hidden_state=hidden_states,
1259 past_key_values=next_cache,
1260 hidden_states=all_hidden_states,
1261 attentions=all_self_attns,
1262 )
1263
1264
1265 def language_model_forward(model, **kwargs):
1266 """Forward through the text LM, bypassing official dense-mask prep for sparse plans."""
1267 lm = model.language_model
1268 return_logits = kwargs.pop("return_logits", True)
1269 logits_slice = kwargs.pop("logits_slice", None)
1270 attention_mask = kwargs.get("attention_mask")
1271 use_direct_sparse = (
1272 getattr(model, "_la_flash_requested_attn", ATTN_MODE) in {"magi", "la_flash"}
1273 and _is_magi_plan(attention_mask)
1274 )
1275 if not use_direct_sparse:
1276 return lm(**kwargs)
1277
1278 labels = kwargs.pop("labels", None)
1279 if labels is not None:
1280 raise NotImplementedError("labels are not supported in the direct sparse-plan decode forward")
1281 output_attentions = kwargs.get("output_attentions", None)
1282 output_hidden_states = kwargs.get("output_hidden_states", None)
1283 base_out = _direct_base_forward(lm.model, **kwargs)
1284 logits = None
1285 if return_logits:
1286 hidden_states = base_out.last_hidden_state
1287 if logits_slice is not None:
1288 hidden_states = hidden_states[:, logits_slice, :]
1289 logits = lm.lm_head(hidden_states).float()
1290 return SimpleNamespace(
1291 logits=logits,
1292 past_key_values=base_out.past_key_values,
1293 hidden_states=base_out.hidden_states if output_hidden_states else None,
1294 attentions=base_out.attentions if output_attentions else None,
1295 )
1296
1297
1298 _EagerCls = _SdpaCls = _LaFlashCls = _MagiCls = None
1299 def _attn_classes(mode=None):
1300 """Attention classes from the dynamic Qwen2 remote module.
1301
1302 The official Qwen2Model mask dispatcher only implements ``sdpa`` and
1303 single-row ``magi``. Eager, LA Flash, and batched Magi inference
1304 therefore swap the layer class while keeping the model's mask dispatcher
1305 pinned to ``sdpa``.
1306 """
1307 global _EagerCls, _SdpaCls, _LaFlashCls, _MagiCls
1308 mode = _normalize_attn_mode(mode) if mode is not None else None
1309 if _SdpaCls is None:
1310 mod = importlib.import_module(type(_model.language_model.model).__module__)
1311 _EagerCls = mod.Qwen2Attention
1312 _SdpaCls = mod.Qwen2SdpaAttention
1313 else:
1314 mod = importlib.import_module(type(_model.language_model.model).__module__)
1315 if (mode is None or mode == "la_flash") and _LaFlashCls is None:
1316 _LaFlashCls = build_la_flash_attention_class(mod)
1317 if (mode is None or mode == "magi") and _MagiCls is None:
1318 _MagiCls = build_batched_magi_attention_class(mod) if getattr(mod, "_MAGI_AVAILABLE", False) else None
1319 return _EagerCls, _SdpaCls, _LaFlashCls, _MagiCls
1320
1321 def _set_llm_mode(model, mode):
1322 """Swap every Qwen2 decoder layer's attention class.
1323
1324 Release backends keep ``Qwen2Model._attn_implementation='sdpa'`` so the
1325 official Qwen2 mask dispatcher stays available for dense-mask modes. The
1326 local ``la_flash`` and batched ``magi`` wrappers can also consume scheduler-built
1327 sparse plans directly, avoiding repeated per-layer dense mask conversion.
1328 """
1329 mode = _normalize_attn_mode(mode)
1330 eager, sdpa, la_flash, magi = _attn_classes(mode)
1331 impl = "sdpa"
1332 if mode == "sdpa":
1333 cls = sdpa
1334 elif mode == "eager":
1335 cls = eager
1336 elif mode == "la_flash":
1337 cls = la_flash
1338 elif mode == "magi":
1339 if magi is None:
1340 raise RuntimeError("MagiAttention is unavailable in the current Python environment.")
1341 cls = magi
1342 else:
1343 raise ValueError(f"unknown LLM attention mode: {mode}")
1344 llm = model.language_model.model
1345 for lyr in llm.layers:
1346 lyr.self_attn.__class__ = cls
1347 if mode == "magi":
1348 lyr.self_attn.softmax_scale = lyr.self_attn.head_dim ** -0.5
1349 llm._attn_implementation = impl
1350 llm.config._attn_implementation = llm._attn_implementation
1351 if hasattr(model.config, "text_config"):
1352 model.config.text_config._attn_implementation = llm._attn_implementation
1353 model.config._attn_implementation = llm._attn_implementation
1354 model._la_flash_requested_attn = mode
1355
1356 _st = _hp = None
1357 def _helpers():
1358 """The model's own sample_tokens / handle_pattern (the exact box decoders)."""
1359 global _st, _hp
1360 if _st is None:
1361 m = importlib.import_module(type(load()[2]).__module__)
1362 _st, _hp = m.sample_tokens, m.handle_pattern
1363 return _st, _hp
1364
1365
1366 _gu = None
1367 def _gen_utils():
1368 """The model's generate_utils module (apply_repetition_penalty / top_p_logits / top_k_logits /
1369 decode_bbox_avg / decode_ref / dists) -- the pieces sample_tokens_batched reuses verbatim."""
1370 global _gu
1371 if _gu is None:
1372 m = importlib.import_module(type(load()[2]).__module__)
1373 _gu = importlib.import_module(m.sample_tokens.__module__)
1374 return _gu
1375
1376
1377 def _env_float(name, default):
1378 val = os.environ.get(name)
1379 if val is None or val.strip() == "":
1380 return float(default)
1381 return float(val)
1382
1383
1384 def _coord_fallback_mode():
1385 mode = os.environ.get("LA_FLASH_COORD_FALLBACK_MODE", "legacy").strip().lower().replace("-", "_")
1386 aliases = {
1387 "": "legacy",
1388 "official": "legacy",
1389 "range": "legacy",
1390 "spread": "legacy",
1391 "none": "off",
1392 "disable": "off",
1393 "disabled": "off",
1394 "entropy_variance": "uncertainty",
1395 "entropy_var": "uncertainty",
1396 "ent_var": "uncertainty",
1397 "entropy_std": "uncertainty",
1398 }
1399 mode = aliases.get(mode, mode)
1400 if mode not in {"legacy", "uncertainty", "off"}:
1401 raise ValueError(
1402 "LA_FLASH_COORD_FALLBACK_MODE must be one of legacy, uncertainty, off"
1403 )
1404 return mode
1405
1406
1407 def _coord_uncertainty_threshold(coord_start_token_id, coord_end_token_id):
1408 """Return the coord uncertainty threshold in raw coord-token units.
1409
1410 Backward-compatible behavior:
1411 - LA_FLASH_COORD_UNCERTAINTY_THRESH > 1 is treated as raw coord-token RMSE.
1412 - LA_FLASH_COORD_UNCERTAINTY_THRESH <= 1 is treated as normalized by coord span.
1413 - LA_FLASH_COORD_UNCERTAINTY_NORM_THRESH is an explicit normalized override.
1414 """
1415 coord_span = max(float(coord_end_token_id - coord_start_token_id + 1), 1.0)
1416 norm_val = os.environ.get("LA_FLASH_COORD_UNCERTAINTY_NORM_THRESH")
1417 if norm_val is not None and norm_val.strip() != "":
1418 return float(norm_val) * coord_span
1419
1420 val = os.environ.get("LA_FLASH_COORD_UNCERTAINTY_THRESH")
1421 if val is None or val.strip() == "":
1422 return 20.0
1423 threshold = float(val)
1424 if 0.0 < threshold <= 1.0:
1425 return threshold * coord_span
1426 return threshold
1427
1428
1429 def _decode_bbox_with_uncertainty(logits, probs, token_ids, keep_k=4, generation_mode="hybrid"):
1430 """Decode an MTP box with configurable coord uncertainty fallback.
1431
1432 The default mode is the official LocateAnything rule. ``uncertainty`` keeps
1433 the same frame checks and top-k coord selection, but uses one scalar
1434 criterion per coordinate: the posterior RMSE of committing to the current
1435 MAP coordinate among valid coord candidates. This is the Bayes risk under
1436 squared coordinate error, so probabilities and token distances are folded
1437 into one threshold in coordinate-token units.
1438 """
1439 gu = _gen_utils()
1440 mode = _coord_fallback_mode()
1441 if mode == "legacy" or generation_mode != "hybrid":
1442 return gu.decode_bbox_avg(logits, probs, token_ids, keep_k=keep_k, generation_mode=generation_mode)
1443
1444 coord_start_token_id = token_ids["coord_start_token_id"]
1445 coord_end_token_id = token_ids["coord_end_token_id"]
1446 box_start_token_id = token_ids["box_start_token_id"]
1447 box_end_token_id = token_ids["box_end_token_id"]
1448 none_token_id = token_ids["none_token_id"]
1449 null_token_id = token_ids["null_token_id"]
1450 device = logits.device
1451
1452 box_type = gu.is_valid_box_frame(
1453 probs,
1454 token_ids,
1455 start_thresh=_env_float("LA_FLASH_COORD_BOX_START_THRESH", 0.7),
1456 end_thresh=_env_float("LA_FLASH_COORD_BOX_END_THRESH", 0.2),
1457 topk=keep_k,
1458 )
1459 if box_type == "empty_box":
1460 return torch.tensor([
1461 box_start_token_id,
1462 none_token_id,
1463 box_end_token_id,
1464 null_token_id,
1465 null_token_id,
1466 null_token_id,
1467 ], dtype=torch.long, device=device)
1468 if box_type == "illegal_box":
1469 return None
1470
1471 pos_probs, pos_ids = torch.topk(probs[1:5], k=keep_k, dim=-1)
1472 valid = (pos_ids >= coord_start_token_id) & (pos_ids <= coord_end_token_id)
1473 has_valid = valid.any(dim=-1)
1474 if not has_valid.all():
1475 return None
1476
1477 first_valid_idx = valid.long().argmax(dim=-1, keepdim=True)
1478 first_valid_ids = pos_ids.gather(-1, first_valid_idx).squeeze(-1)
1479 if mode == "off":
1480 final_coords = first_valid_ids
1481 else:
1482 valid_counts = valid.sum(dim=-1)
1483 valid_probs = torch.where(valid, pos_probs, torch.zeros_like(pos_probs))
1484 valid_mass = valid_probs.sum(dim=-1).clamp_min(1e-12)
1485 weights = valid_probs / valid_mass.unsqueeze(-1)
1486 coord_values = (pos_ids - coord_start_token_id).to(dtype=torch.float32)
1487 map_coord = (first_valid_ids - coord_start_token_id).to(dtype=torch.float32)
1488 uncertainty = (weights * (coord_values - map_coord.unsqueeze(-1)).pow(2)).sum(dim=-1).sqrt()
1489 is_abnormal = (
1490 (valid_counts > 1)
1491 & (uncertainty > _coord_uncertainty_threshold(coord_start_token_id, coord_end_token_id))
1492 )
1493 final_coords = torch.where(is_abnormal, torch.tensor(0, device=device), first_valid_ids)
1494
1495 start_t = torch.tensor([box_start_token_id], dtype=final_coords.dtype, device=device)
1496 end_t = torch.tensor([box_end_token_id], dtype=final_coords.dtype, device=device)
1497 return torch.cat([start_t, final_coords, end_t])
1498
1499
1500 def _apply_repetition_penalty_lowmem(logits, generated, repetition_penalty):
1501 """Apply the stock repetition penalty without allocating a [B, S, V] mask."""
1502 if repetition_penalty == 1.0:
1503 return logits
1504 _, _, vocab_size = logits.shape
1505 for row in range(logits.shape[0]):
1506 valid_tokens = generated[row].unique()
1507 valid_tokens = valid_tokens[(valid_tokens >= 0) & (valid_tokens < vocab_size)]
1508 if valid_tokens.numel() == 0:
1509 continue
1510 row_logits = logits[row, :, valid_tokens]
1511 logits[row, :, valid_tokens] = torch.where(
1512 row_logits > 0,
1513 row_logits / repetition_penalty,
1514 row_logits * repetition_penalty,
1515 )
1516 return logits
1517
1518
1519 def _finite_logit_bounds(dtype):
1520 finfo = torch.finfo(dtype)
1521 return finfo.min, finfo.max
1522
1523
1524 def _finite_logits(logits):
1525 if not logits.dtype.is_floating_point:
1526 logits = logits.float()
1527 min_val, max_val = _finite_logit_bounds(logits.dtype)
1528 return torch.nan_to_num(logits, nan=min_val, posinf=max_val, neginf=min_val)
1529
1530
1531 def _finite_logits_(logits):
1532 if not logits.dtype.is_floating_point:
1533 return logits.float()
1534 min_val, max_val = _finite_logit_bounds(logits.dtype)
1535 return logits.nan_to_num_(nan=min_val, posinf=max_val, neginf=min_val)
1536
1537
1538 def _top_p_logits_slice_(logits, top_p):
1539 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
1540 cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
1541 sorted_indices_to_remove = cumulative_probs > top_p
1542 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
1543 sorted_indices_to_remove[..., 0] = False
1544
1545 remove = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
1546 remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
1547 logits.masked_fill_(remove, torch.finfo(logits.dtype).min)
1548 return logits
1549
1550
1551 def _top_p_logits_(logits, top_p):
1552 """In-place nucleus filtering with bounded sort workspace.
1553
1554 The MTP sampler uses logits shaped ``[B, 6, V]``. Top-p is independent for
1555 each row and each future position, so filtering one position at a time keeps
1556 the expensive sorted-index workspace at ``[B, V]`` instead of ``[B, 6, V]``.
1557 """
1558 if logits.dim() == 3 and logits.shape[1] > 1:
1559 for pos in range(logits.shape[1]):
1560 _top_p_logits_slice_(logits[:, pos, :], top_p)
1561 return logits
1562 return _top_p_logits_slice_(logits, top_p)
1563
1564
1565 def _top_k_logits_(logits, top_k):
1566 """In-place top-k filtering mirroring generate_utils.top_k_logits."""
1567 top_k = min(int(top_k), logits.size(-1))
1568 threshold = torch.topk(logits, top_k)[0][..., -1, None]
1569 logits.masked_fill_(logits < threshold, torch.finfo(logits.dtype).min)
1570 return logits
1571
1572
1573 def _safe_probs(filtered_logits):
1574 """Softmax with CUDA-multinomial-safe cleanup and row-wise argmax fallback."""
1575 filtered_logits = _finite_logits(filtered_logits)
1576 probs = torch.softmax(filtered_logits, dim=-1, dtype=torch.float32)
1577 probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0).clamp_min_(0.0)
1578 row_sum = probs.sum(dim=-1, keepdim=True)
1579 bad = (~torch.isfinite(row_sum)) | (row_sum <= 0)
1580 if bool(bad.any().item()):
1581 fallback = torch.zeros_like(probs)
1582 fallback.scatter_(-1, filtered_logits.argmax(dim=-1, keepdim=True), 1.0)
1583 probs = torch.where(bad, fallback, probs)
1584 row_sum = probs.sum(dim=-1, keepdim=True)
1585 return probs / row_sum.clamp_min(1.0e-20)
1586
1587
1588 def _sample_top_p_sorted_tokens(logits, top_p):
1589 """Sample from top-p filtered logits without scattering back to vocab order."""
1590 sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
1591 cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
1592 remove = cumulative_probs > top_p
1593 remove[..., 1:] = remove[..., :-1].clone()
1594 remove[..., 0] = False
1595 sorted_logits.masked_fill_(remove, torch.finfo(sorted_logits.dtype).min)
1596 sorted_probs = _safe_probs(sorted_logits)
1597 sample_idx = sorted_probs.argmax(dim=-1)
1598 try:
1599 sample_idx = torch.distributions.Categorical(probs=sorted_probs).sample()
1600 except Exception:
1601 pass
1602 return sorted_indices.gather(-1, sample_idx.unsqueeze(-1)).squeeze(-1)
1603
1604
1605 @torch.no_grad()
1606 def sample_tokens_batched(logits, generated, token_ids, per_row_temp,
1607 repetition_penalty=1.0, top_p=None, top_k=None,
1608 keep_k_avg=4, generation_mode='fast'):
1609 """Batched fork of generate_utils.sample_tokens for the MTP window [B,6,V]. The logits pipeline
1610 (rep-penalty / per-row temperature / top_p / top_k / softmax / sample) is ROW-INDEPENDENT, so run
1611 it ONCE over the whole batch instead of B times on [1,6,V] (the per-row san defeats batching by
1612 slicing wlogits[b:b+1]). Only the variable-length box ASSEMBLY (decode_bbox_avg -> ragged shapes,
1613 where sample_tokens' final torch.stack throws) stays per-row, returned as a LIST.
1614
1615 Equivalence to per-row san: every pipeline op reduces on dim=-1 only (never crosses the row dim),
1616 so row b's processed logits/probs are bit-identical to slicing first -> greedy (per_row_temp==0,
1617 argmax branch, no RNG) is BIT-EXACT. Under sampling, one batched Categorical changes the global
1618 RNG consumption order vs B per-row draws -> box-size jitter (blessed; greedy is the exact gate).
1619 apply_repetition_penalty already loops per-row internally, so passing the full [B,M] `generated`
1620 is row-correct. keep_k_avg/generation_mode mirror sample_tokens' decode_bbox_avg call EXACTLY
1621 (note: the per-row san passes keep_k=5 but decode_bbox_avg reads keep_k_avg, default 4 -- so 5 is
1622 a no-op there; we replicate keep_k_avg=4). Returns (x0[B,6], boxes: list of B 1-D LongTensors)."""
1623 gu = _gen_utils()
1624 B, S, V = logits.shape # S = N_FUTURE = 6
1625 if repetition_penalty != 1.0:
1626 logits = _apply_repetition_penalty_lowmem(logits, generated, repetition_penalty)
1627 t = per_row_temp.to(dtype=logits.dtype).view(B, 1, 1)
1628 sample_rows = per_row_temp > 0
1629 if bool(sample_rows.all().item()):
1630 logits.div_(t.clamp(min=1e-8))
1631 elif bool(sample_rows.any().item()):
1632 idx = sample_rows.nonzero(as_tuple=True)[0]
1633 logits[idx].div_(t[idx].clamp(min=1e-8))
1634 logits = _finite_logits_(logits)
1635 if top_p is not None and top_p < 1:
1636 logits = _top_p_logits_(logits, top_p)
1637 if top_k is not None and top_k > 0:
1638 logits = _top_k_logits_(logits, top_k)
1639 probs = _safe_probs(logits)
1640 x0 = probs.argmax(dim=-1) # [B,6]; greedy rows are final here
1641 samp = per_row_temp > 0
1642 if bool(samp.any()): # sampling rows: ONE batched Categorical draw
1643 idx = samp.nonzero(as_tuple=True)[0]
1644 try:
1645 x0[idx] = gu.dists.Categorical(probs=probs[idx]).sample()
1646 except Exception:
1647 pass # keep argmax (matches san's except: probs.max)
1648 boxes = []
1649 fallback = torch.zeros(1, dtype=x0.dtype, device=x0.device)
1650 for b in range(B): # variable-length box assembly (per-row, exact)
1651 db = _decode_bbox_with_uncertainty(
1652 logits[b], probs[b], token_ids,
1653 keep_k=keep_k_avg, generation_mode=generation_mode)
1654 if db is not None:
1655 boxes.append(db)
1656 else:
1657 ref = gu.decode_ref(logits[b], probs[b], token_ids)
1658 if ref is None:
1659 boxes.append(fallback)
1660 elif torch.is_tensor(ref):
1661 boxes.append(ref.to(dtype=x0.dtype, device=x0.device))
1662 else:
1663 boxes.append(torch.tensor(ref, dtype=x0.dtype, device=x0.device))
1664 return x0, boxes
1665
1666
1667 @torch.no_grad()
1668 def sample_next_tokens_batched(logits, generated, per_row_temp,
1669 repetition_penalty=1.0, top_p=None, top_k=None):
1670 """Batched one-token sampler for AR repair rows.
1671
1672 This mirrors the row-independent part of ``sample_tokens`` for logits shaped
1673 ``[B,1,V]``. It intentionally does not run bbox/ref assembly because AR mode
1674 only needs the next token before the state machine classifies it.
1675 """
1676 gu = _gen_utils()
1677 if logits.dim() != 3 or logits.shape[1] != 1:
1678 raise ValueError(f"AR batched sampler expects logits [B,1,V], got {tuple(logits.shape)}")
1679 B = int(logits.shape[0])
1680 if repetition_penalty != 1.0:
1681 logits = _apply_repetition_penalty_lowmem(logits, generated, repetition_penalty)
1682 t = per_row_temp.to(dtype=logits.dtype).view(B, 1, 1)
1683 sample_rows = per_row_temp > 0
1684 if bool(sample_rows.all().item()):
1685 logits.div_(t.clamp(min=1e-8))
1686 elif bool(sample_rows.any().item()):
1687 idx = sample_rows.nonzero(as_tuple=True)[0]
1688 logits[idx].div_(t[idx].clamp(min=1e-8))
1689 logits = _finite_logits_(logits)
1690 sorted_top_p = os.environ.get("AR_SORTED_TOPP", "0") == "1"
1691 default_top_p = sorted_top_p and top_p is not None and top_p < 1 and (top_k is None or top_k <= 0)
1692 if default_top_p and bool(sample_rows.all().item()):
1693 return _sample_top_p_sorted_tokens(logits, top_p)
1694 if top_p is not None and top_p < 1:
1695 logits = _top_p_logits_(logits, top_p)
1696 if top_k is not None and top_k > 0:
1697 logits = _top_k_logits_(logits, top_k)
1698 probs = _safe_probs(logits)
1699 x0 = probs.argmax(dim=-1)
1700 if bool(sample_rows.any().item()):
1701 # Keep row-ordered sampling as the release default. A single batched
1702 # Categorical is faster, but it consumes RNG differently from stock AR
1703 # repair and can alter default-temperature termination behavior.
1704 for row in sample_rows.nonzero(as_tuple=True)[0].tolist():
1705 try:
1706 x0[row : row + 1] = gu.dists.Categorical(probs=probs[row : row + 1]).sample()
1707 except Exception:
1708 pass
1709 return x0
1710
1711
1712 def load_pil(p):
1713 from PIL import Image
1714 im = Image.open(p).convert("RGB"); w, h = im.size
1715 if max(w, h) > MAX_DIM:
1716 s = MAX_DIM / max(w, h); im = im.resize((max(1, round(w*s)), max(1, round(h*s))), Image.LANCZOS)
1717 return im
1718
1719 def _preproc_one(im):
1720 """CPU-side processor for one image -> (pixel_values[bf16], grid[int32]). Split out of
1721 _encode_image so _encode_images can batch the GPU encode while preprocessing stays per-image."""
1722 tok, proc, model = load()
1723 msg = [{"role": "user", "content": [{"type": "image", "image": im}, {"type": "text", "text": "x"}]}]
1724 text = proc.py_apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
1725 imgs, vids = proc.process_vision_info(msg)
1726 inp = proc(text=[text], images=imgs, videos=vids, return_tensors="pt").to(DEV)
1727 grid = inp.get("image_grid_hws")
1728 if isinstance(grid, np.ndarray): grid = torch.from_numpy(grid).to(DEV, dtype=torch.int32)
1729 return inp["pixel_values"].to(DT), grid
1730
1731
1732 def _vision_is_flash():
1733 """True iff MoonViT will actually run flash_attn_varlen (so cross-image packing is
1734 block-diagonal = exact AND a win). If the vision blocks are on sdpa/eager, OR the flash
1735 wheel is absent (multihead_attention falls back to the dense-mask sdpa path), packing is
1736 O(S^2) N^2 -> caller must stay per-image."""
1737 vm = load()[2].vision_model
1738 mod = importlib.import_module(type(vm).__module__)
1739 if getattr(mod, "flash_attn_varlen_func", None) is None:
1740 return False
1741 try:
1742 return vm.encoder.blocks[0].attn_implementation == "flash_attention_2"
1743 except Exception:
1744 return False
1745
1746
1747 @torch.no_grad()
1748 def _encode_images(ims):
1749 """N images -> list of [n_img_tokens, C] mlp1-projected visual_features, one per image
1750 (row-order). Drop-in for [_encode_image(im) for im in ims].
1751
1752 With flash present (_vision_is_flash) and N>1, packs images into
1753 extract_feature micro-batches: MoonViT's varlen cu_seqlens path is
1754 block-diagonal by image. Without flash, the dense SDPA fallback would scale
1755 with the packed total sequence length, so this function falls back to
1756 per-image encode. MTP_BATCH_VISION=0 also forces per-image encode."""
1757 tok, proc, model = load()
1758 pvs, grids = [], []
1759 for im in ims:
1760 pv, g = _preproc_one(im)
1761 pvs.append(pv); grids.append(g)
1762 if BATCH_VISION and len(ims) > 1 and _vision_is_flash():
1763 if VISION_ENCODE_BATCH_SIZE <= 0 or VISION_ENCODE_BATCH_SIZE >= len(ims):
1764 vit_list = model.extract_feature(torch.cat(pvs, dim=0), torch.cat(grids, dim=0))
1765 else:
1766 vit_list = []
1767 for start in range(0, len(ims), VISION_ENCODE_BATCH_SIZE):
1768 end = min(start + VISION_ENCODE_BATCH_SIZE, len(ims))
1769 vit_list.extend(
1770 model.extract_feature(
1771 torch.cat(pvs[start:end], dim=0),
1772 torch.cat(grids[start:end], dim=0),
1773 )
1774 )
1775 return [model.mlp1(v) for v in vit_list] # one [P_i, C] per image (patch_merger split)
1776 return [model.mlp1(torch.cat(model.extract_feature(pv, g), dim=0))
1777 for pv, g in zip(pvs, grids)] # per-image (flash absent / N==1 / forced off)
1778
1779
1780 @torch.no_grad()
1781 def _encode_image(im):
1782 """Single-image convenience wrapper (single-image callers); = _encode_images([im])[0]
1783 (takes the per-image path inside _encode_images, so bit-identical to the original)."""
1784 return _encode_images([im])[0]
1785
1786 @torch.no_grad()
1787 def _tokenize(im, query):
1788 """1-D prompt token ids for (image, query). Uses the model's own chat template."""
1789 tok, proc, model = load()
1790 msg = [{"role": "user", "content": [{"type": "image", "image": im},
1791 {"type": "text", "text": _PROMPT + query + "."}]}]
1792 text = proc.py_apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
1793 imgs, vids = proc.process_vision_info(msg)
1794 return proc(text=[text], images=imgs, videos=vids, return_tensors="pt").to(DEV)["input_ids"][0]
1795
1796
1797 @torch.no_grad()
1798 def _tokenize_cached_image(query, image_token_count, im=None):
1799 """Tokenize a prompt when the image token count is already known.
1800
1801 This keeps the processor's chat template, but directly expands ``<image-1>``
1802 from the cached visual feature length. It avoids re-running the CPU image
1803 processor for every category prompt that shares the same image.
1804 """
1805 tok, proc, model = load()
1806 msg = [{"role": "user", "content": [{"type": "image", "image": im},
1807 {"type": "text", "text": _PROMPT + query + "."}]}]
1808 text = proc.py_apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
1809 placeholder = f"<{getattr(proc, 'image_placeholder', 'image')}-1>"
1810 image_token = getattr(proc, "image_token", "<IMG_CONTEXT>")
1811 image_start = getattr(proc, "image_start_token", "<img>")
1812 image_end = getattr(proc, "image_end_token", "</img>")
1813 replacement = f"<image 1>{image_start}{image_token * int(image_token_count)}{image_end}"
1814 if placeholder not in text:
1815 raise ValueError(f"cached image placeholder {placeholder!r} was not found in chat template")
1816 text = text.replace(placeholder, replacement, 1)
1817 return tok([text], return_tensors="pt").to(DEV)["input_ids"][0]
1818
1819
1820 def _proc_full(im, query):
1821 """Full processor dict (input_ids, attention_mask, pixel_values, image_grid_hws) —
1822 used by the bench to drive the STOCK generate for the equivalence check."""
1823 tok, proc, model = load()
1824 msg = [{"role": "user", "content": [{"type": "image", "image": im},
1825 {"type": "text", "text": _PROMPT + query + "."}]}]
1826 text = proc.py_apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
1827 imgs, vids = proc.process_vision_info(msg)
1828 inp = proc(text=[text], images=imgs, videos=vids, return_tensors="pt").to(DEV)
1829 grid = inp.get("image_grid_hws")
1830 if isinstance(grid, np.ndarray): grid = torch.from_numpy(grid).to(DEV, dtype=torch.int32)
1831 inp["image_grid_hws"] = grid
1832 return inp
1833
1834 def _pad_generated(prompt_ids, gen_ids, img_tok, dev):
1835 """Per-row [prompt + accepted] left-padded with the image token (already in every
1836 prompt -> .unique() unchanged -> repetition penalty identical to single-run)."""
1837 rows = [list(prompt_ids[b].tolist()) + gen_ids[b] for b in range(len(prompt_ids))]
1838 M = max(len(r) for r in rows)
1839 out = torch.full((len(rows), M), img_tok, dtype=torch.long, device=dev)
1840 for b, r in enumerate(rows):
1841 out[b, M - len(r):] = torch.tensor(r, dtype=torch.long, device=dev)
1842 return out
1843