batch_utils/hybrid_runtime.py
| 1 | """Internal runtime support for the LocateAnything-3B hybrid batch decoder. |
| 2 | |
| 3 | This file keeps only the model-loading, tokenization, image-encoding, stock |
| 4 | processor, and sample-token helpers that ``engine_hybrid.py`` needs. |
| 5 | |
| 6 | Important env knobs: |
| 7 | LA_FLASH_MODEL HF repo id / local path of the model (default nvidia/LocateAnything-3B) |
| 8 | HF_HUB_OFFLINE=1 read the local HF cache only (no network); unset -> download on first use |
| 9 | LA_FLASH_ATTN sdpa, eager, magi, or la_flash; la_flash uses FlashAttention sparse ranges |
| 10 | LA_FLASH_STRICT_ATTN 1 -> fail if the requested backend is unavailable; |
| 11 | default 0 falls back to sdpa |
| 12 | LA_FLASH_VISION_ATTN auto, flash_attention_2, sdpa, or eager (default auto) |
| 13 | LA_FLASH_HYBRID_PREFILL shared, none, per_row, or batch prompt KV prefill (default shared) |
| 14 | MTP_BATCH_VISION 0 -> per-image vision encode (default 1: batched when flash is present) |
| 15 | LA_FLASH_VISION_ENCODE_BATCH_SIZE |
| 16 | max images per MoonViT encode micro-batch (default 8; <=0 disables limit) |
| 17 | MTP_BATCH_SAN 0 -> per-row logits/sample pipeline (default 1: batched over [B,6,V]) |
| 18 | AR_BATCH_SAN 0 -> per-row AR sample pipeline (default 1: batched over [B,1,V]) |
| 19 | """ |
| 20 | import inspect |
| 21 | import os, warnings, importlib, torch |
| 22 | from types import SimpleNamespace |
| 23 | import numpy as np |
| 24 | from transformers import AutoModel, AutoTokenizer, AutoProcessor |
| 25 | |
| 26 | |
| 27 | # By default let transformers fetch the model on first use; set HF_HUB_OFFLINE=1 yourself |
| 28 | # to read the local HF cache only (e.g. air-gapped / already-downloaded runs). |
| 29 | MODEL = os.environ.get("LA_FLASH_MODEL", "nvidia/LocateAnything-3B") |
| 30 | |
| 31 | |
| 32 | LLM_ATTN_MODES = ("sdpa", "eager", "magi", "la_flash") |
| 33 | VISION_ATTN_MODES = ("auto", "flash_attention_2", "sdpa", "eager") |
| 34 | |
| 35 | |
| 36 | def _normalize_attn_mode(value): |
| 37 | mode = (value or "sdpa").strip().lower().replace("-", "_") |
| 38 | aliases = { |
| 39 | "": "sdpa", |
| 40 | "manual": "eager", |
| 41 | "torch": "eager", |
| 42 | "torch_eager": "eager", |
| 43 | "torch_sdpa": "sdpa", |
| 44 | "scaled_dot_product_attention": "sdpa", |
| 45 | "flash": "la_flash", |
| 46 | "la_flash": "la_flash", |
| 47 | "kernel": "la_flash", |
| 48 | "cuda": "la_flash", |
| 49 | "range": "la_flash", |
| 50 | "range_attention": "la_flash", |
| 51 | "flex_flash": "magi", |
| 52 | "flex_flash_attention": "magi", |
| 53 | "flex_flash_attn": "magi", |
| 54 | } |
| 55 | mode = aliases.get(mode, mode) |
| 56 | if mode not in LLM_ATTN_MODES: |
| 57 | raise ValueError( |
| 58 | f"LA_FLASH_ATTN must be one of {', '.join(LLM_ATTN_MODES)}; got {value!r}" |
| 59 | ) |
| 60 | return mode |
| 61 | |
| 62 | |
| 63 | def _normalize_vision_attn_mode(value): |
| 64 | mode = (value or "auto").strip().lower().replace("-", "_") |
| 65 | aliases = { |
| 66 | "": "auto", |
| 67 | "flash": "flash_attention_2", |
| 68 | "flash_attention2": "flash_attention_2", |
| 69 | "fa2": "flash_attention_2", |
| 70 | "manual": "eager", |
| 71 | } |
| 72 | mode = aliases.get(mode, mode) |
| 73 | if mode not in VISION_ATTN_MODES: |
| 74 | raise ValueError( |
| 75 | f"LA_FLASH_VISION_ATTN must be one of {', '.join(VISION_ATTN_MODES)}; got {value!r}" |
| 76 | ) |
| 77 | return mode |
| 78 | |
| 79 | |
| 80 | ATTN_MODE = _normalize_attn_mode(os.environ.get("LA_FLASH_ATTN", "sdpa")) |
| 81 | REMOTE_ATTN_MODE = "sdpa" if ATTN_MODE in {"la_flash", "magi"} else ATTN_MODE |
| 82 | VISION_ATTN_MODE = _normalize_vision_attn_mode(os.environ.get("LA_FLASH_VISION_ATTN", "auto")) |
| 83 | MAX_DIM = 1024 |
| 84 | DEV, DT = "cuda", torch.bfloat16 |
| 85 | N_FUTURE = 6 # = config.block_size (MTP window) |
| 86 | _PROMPT = "Locate all the instances that matches the following description: " |
| 87 | |
| 88 | |
| 89 | def _env_flag(name, default=False): |
| 90 | val = os.environ.get(name) |
| 91 | if val is None: |
| 92 | return default |
| 93 | return val.strip().lower() not in {"0", "false", "no", "off"} |
| 94 | |
| 95 | |
| 96 | def _env_int(name): |
| 97 | val = os.environ.get(name) |
| 98 | if val is None or val.strip() == "": |
| 99 | return None |
| 100 | return int(val) |
| 101 | |
| 102 | |
| 103 | def _strict_attn(): |
| 104 | return _env_flag("LA_FLASH_STRICT_ATTN", False) |
| 105 | |
| 106 | |
| 107 | def _fallback_to_sdpa(model, requested, reason): |
| 108 | if requested == "sdpa": |
| 109 | raise RuntimeError(f"LA_FLASH_ATTN=sdpa failed: {reason}") from reason |
| 110 | message = f"LA_FLASH_ATTN={requested} is unavailable; falling back to sdpa. Reason: {reason}" |
| 111 | if _strict_attn(): |
| 112 | raise RuntimeError(message) from reason |
| 113 | warnings.warn(message) |
| 114 | _set_llm_mode(model, "sdpa") |
| 115 | model._la_flash_requested_attn_original = requested |
| 116 | model._la_flash_attn_fallback_reason = str(reason) |
| 117 | return "sdpa" |
| 118 | |
| 119 | |
| 120 | # Optional compile for the shared Qwen2 core. This is off by default because the |
| 121 | # hybrid scheduler already varies query/cache shapes and first-call compile cost is high. |
| 122 | MTP_COMPILE = os.environ.get("MTP_COMPILE", "0") == "1" |
| 123 | |
| 124 | # Batch the MoonViT vision encode across a micro-batch's images: pack N images into ONE |
| 125 | # extract_feature. With flash present, MoonViT's varlen cu_seqlens path is block-diagonal per |
| 126 | # image and equivalent to per-image encode. |
| 127 | # Without flash, sdpa builds a dense [1,S,S] mask -> O(S^2) N^2 -> per-image fallback (auto, see |
| 128 | # _vision_is_flash). Default ON; set MTP_BATCH_VISION=0 to force per-image. |
| 129 | BATCH_VISION = os.environ.get("MTP_BATCH_VISION", "1") == "1" |
| 130 | _vision_encode_batch_size = _env_int("LA_FLASH_VISION_ENCODE_BATCH_SIZE") |
| 131 | VISION_ENCODE_BATCH_SIZE = 8 if _vision_encode_batch_size is None else max(0, _vision_encode_batch_size) |
| 132 | |
| 133 | # Batch the per-row box-decode (sample_tokens): run the row-independent logits pipeline |
| 134 | # (rep-penalty / per-row temperature / top_p / top_k / softmax / sample) ONCE over the whole |
| 135 | # [B,6,V] step instead of B times on [1,6,V]; only the variable-length box assembly stays per-row. |
| 136 | # Greedy is BIT-IDENTICAL to the per-row san (argmax, no RNG). Default ON; MTP_BATCH_SAN=0 -> per-row. |
| 137 | BATCH_SAN = os.environ.get("MTP_BATCH_SAN", "1") == "1" |
| 138 | |
| 139 | # Batch the AR repair sampler over [B,1,V]. This shares the exact filtering |
| 140 | # helpers with MTP batching but skips box/ref decoding, so it only replaces the |
| 141 | # repeated stock one-token sample calls. Sampling itself stays row-ordered by |
| 142 | # default to preserve the stock RNG consumption pattern for AR repair. |
| 143 | AR_BATCH_SAN = os.environ.get("AR_BATCH_SAN", "1") == "1" |
| 144 | |
| 145 | _tok = _proc = _model = None |
| 146 | |
| 147 | def _magi_diag(): |
| 148 | lines = [] |
| 149 | try: |
| 150 | import magi_attention |
| 151 | lines.append(f"magi_attention: OK file={getattr(magi_attention, '__file__', None)}") |
| 152 | lines.append(f"magi_attention.__version__={getattr(magi_attention, '__version__', '<missing>')}") |
| 153 | except Exception as e: |
| 154 | lines.append(f"magi_attention: FAIL {type(e).__name__}: {e}") |
| 155 | return "\n".join(lines) |
| 156 | try: |
| 157 | from magi_attention.functional.flex_flash_attn import flex_flash_attn_func |
| 158 | lines.append(f"magi_attention.functional.flex_flash_attn: OK func={flex_flash_attn_func}") |
| 159 | except Exception as e: |
| 160 | lines.append(f"magi_attention.functional.flex_flash_attn: FAIL {type(e).__name__}: {e}") |
| 161 | return "\n".join(lines) |
| 162 | |
| 163 | def _remote_magi_diag(model=None): |
| 164 | lines = [] |
| 165 | try: |
| 166 | if model is not None: |
| 167 | mod = importlib.import_module(type(model.language_model.model).__module__) |
| 168 | else: |
| 169 | # Best effort: if the dynamic module is not imported yet this may fail; |
| 170 | # the post-load diagnostic below will still work. |
| 171 | mod = importlib.import_module("transformers_modules.LocateAnything-3B.modeling_qwen2") |
| 172 | lines.append(f"remote_qwen2_module={getattr(mod, '__file__', None)}") |
| 173 | lines.append(f"remote_qwen2._MAGI_AVAILABLE={getattr(mod, '_MAGI_AVAILABLE', '<missing>')!r}") |
| 174 | lines.append(f"remote_qwen2.flex_flash_attn_func={getattr(mod, 'flex_flash_attn_func', '<missing>')}") |
| 175 | except Exception as e: |
| 176 | lines.append(f"remote_qwen2: diagnostic failed {type(e).__name__}: {e}") |
| 177 | return "\n".join(lines) |
| 178 | |
| 179 | def _attn_class_diag(model): |
| 180 | try: |
| 181 | llm = model.language_model.model |
| 182 | classes = [type(layer.self_attn).__name__ for layer in llm.layers[:4]] |
| 183 | return ( |
| 184 | f"llm._attn_implementation={getattr(llm, '_attn_implementation', None)!r}\n" |
| 185 | f"config._attn_implementation={getattr(llm.config, '_attn_implementation', None)!r}\n" |
| 186 | f"first_attn_classes={classes}" |
| 187 | ) |
| 188 | except Exception as e: |
| 189 | return f"attention class diagnostic failed {type(e).__name__}: {e}" |
| 190 | |
| 191 | |
| 192 | def _set_vision_attention_mode(model): |
| 193 | """Match HF's MoonViT policy: prefer flash_attention_2, then sdpa, then eager.""" |
| 194 | vm = getattr(model, "vision_model", None) |
| 195 | if vm is None: |
| 196 | return None |
| 197 | mod = importlib.import_module(type(vm).__module__) |
| 198 | funcs = getattr(mod, "VL_VISION_ATTENTION_FUNCTIONS", {}) |
| 199 | has_flash = getattr(mod, "flash_attn_varlen_func", None) is not None |
| 200 | requested = VISION_ATTN_MODE |
| 201 | |
| 202 | if requested == "auto": |
| 203 | candidates = ("flash_attention_2", "sdpa", "eager") |
| 204 | else: |
| 205 | candidates = (requested, "flash_attention_2", "sdpa", "eager") |
| 206 | |
| 207 | chosen = None |
| 208 | for candidate in candidates: |
| 209 | if candidate == "flash_attention_2" and not has_flash: |
| 210 | continue |
| 211 | if candidate in funcs: |
| 212 | chosen = candidate |
| 213 | break |
| 214 | if chosen is None: |
| 215 | raise RuntimeError("MoonViT has no supported attention implementation.") |
| 216 | |
| 217 | if requested == "flash_attention_2" and chosen != "flash_attention_2": |
| 218 | warnings.warn("LA_FLASH_VISION_ATTN=flash_attention_2 requested but flash-attn is unavailable; " |
| 219 | f"using {chosen}.") |
| 220 | elif requested not in {"auto", chosen}: |
| 221 | warnings.warn(f"LA_FLASH_VISION_ATTN={requested} is unavailable; using {chosen}.") |
| 222 | |
| 223 | if hasattr(model.config, "vision_config"): |
| 224 | model.config.vision_config._attn_implementation = chosen |
| 225 | try: |
| 226 | vm.config._attn_implementation = chosen |
| 227 | except Exception: |
| 228 | pass |
| 229 | try: |
| 230 | for block in vm.encoder.blocks: |
| 231 | block.attn_implementation = chosen |
| 232 | except Exception as exc: |
| 233 | raise RuntimeError("Failed to configure MoonViT attention implementation.") from exc |
| 234 | model._la_flash_vision_attn = chosen |
| 235 | return chosen |
| 236 | |
| 237 | |
| 238 | def load(): |
| 239 | """Lazy model load with HF remote-code semantics plus release backends. |
| 240 | |
| 241 | The text decoder is pinned to one of sdpa/eager/magi/la_flash. MoonViT is |
| 242 | configured independently and follows the HF policy: flash_attention_2 when |
| 243 | flash-attn is importable, otherwise sdpa, otherwise eager. |
| 244 | """ |
| 245 | global _tok, _proc, _model |
| 246 | if _model is None: |
| 247 | _tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True) |
| 248 | _proc = AutoProcessor.from_pretrained(MODEL, trust_remote_code=True) |
| 249 | attn_impl = REMOTE_ATTN_MODE |
| 250 | if ATTN_MODE == "magi" and os.environ.get("LA_FLASH_DEBUG", "0") != "0": |
| 251 | print("LA Flash magi pre-load diagnostic:", flush=True) |
| 252 | print(_magi_diag(), flush=True) |
| 253 | _model = AutoModel.from_pretrained(MODEL, torch_dtype=DT, trust_remote_code=True, |
| 254 | attn_implementation=attn_impl).to(DEV).eval() |
| 255 | _set_vision_attention_mode(_model) |
| 256 | actual_attn = getattr(_model.language_model.model, "_attn_implementation", None) |
| 257 | if ATTN_MODE == "magi" and os.environ.get("LA_FLASH_DEBUG", "0") != "0": |
| 258 | print("LA Flash magi post-load diagnostic:", flush=True) |
| 259 | print(_remote_magi_diag(_model), flush=True) |
| 260 | print(_attn_class_diag(_model), flush=True) |
| 261 | if ATTN_MODE == "magi": |
| 262 | try: |
| 263 | qwen2_mod = importlib.import_module(type(_model.language_model.model).__module__) |
| 264 | if not getattr(qwen2_mod, "_MAGI_AVAILABLE", False): |
| 265 | raise RuntimeError( |
| 266 | "remote module reports _MAGI_AVAILABLE=False.\n" |
| 267 | f"{_remote_magi_diag(_model)}\n{_magi_diag()}" |
| 268 | ) |
| 269 | first_attn = type(_model.language_model.model.layers[0].self_attn).__name__ |
| 270 | if actual_attn != "sdpa" or first_attn != "_BatchedMagiAttention": |
| 271 | _set_llm_mode(_model, "magi") |
| 272 | actual_attn = getattr(_model.language_model.model, "_attn_implementation", None) |
| 273 | first_attn = type(_model.language_model.model.layers[0].self_attn).__name__ |
| 274 | if os.environ.get("LA_FLASH_DEBUG", "0") != "0": |
| 275 | print("LA Flash magi post-swap diagnostic:", flush=True) |
| 276 | print(_attn_class_diag(_model), flush=True) |
| 277 | if actual_attn != "sdpa" or first_attn != "_BatchedMagiAttention": |
| 278 | raise RuntimeError( |
| 279 | "batched magi attention did not activate. " |
| 280 | f"actual_attn={actual_attn!r}; first_attn={first_attn!r}; " |
| 281 | f"{_remote_magi_diag(_model)}; {_attn_class_diag(_model)}" |
| 282 | ) |
| 283 | _model._la_flash_requested_attn = "magi" |
| 284 | except Exception as exc: |
| 285 | _fallback_to_sdpa(_model, "magi", exc) |
| 286 | else: |
| 287 | try: |
| 288 | _set_llm_mode(_model, ATTN_MODE) # decode-safe mask plumbing for sdpa/eager/la_flash |
| 289 | except Exception as exc: |
| 290 | _fallback_to_sdpa(_model, ATTN_MODE, exc) |
| 291 | if MTP_COMPILE: |
| 292 | _maybe_compile(_model) |
| 293 | return _tok, _proc, _model |
| 294 | |
| 295 | |
| 296 | def _maybe_compile(model): |
| 297 | """Compile the shared Qwen2Model core (base.forward). It backs BOTH prefill (called directly) |
| 298 | and decode (language_model.forward -> self.model). lm_head + MoonViT left eager. dynamic=True |
| 299 | so the varying decode S/kvlen don't trigger a recompile storm. No-op + warning if triton is |
| 300 | missing (inductor needs it on GPU). First call pays the compile cost (~42s warm / ~187s cold).""" |
| 301 | try: |
| 302 | import triton # noqa: F401 |
| 303 | except Exception: |
| 304 | warnings.warn("MTP_COMPILE set but triton is unavailable; running without torch.compile.") |
| 305 | return |
| 306 | import torch._dynamo as _dyn |
| 307 | _dyn.config.cache_size_limit = max(_dyn.config.cache_size_limit, 64) |
| 308 | base = model.language_model.model |
| 309 | if not getattr(base, "_mtp_compiled", False): |
| 310 | base.forward = torch.compile(base.forward, dynamic=True) |
| 311 | base._mtp_compiled = True |
| 312 | |
| 313 | |
| 314 | def build_batched_magi_attention_class(mod): |
| 315 | """Build a Qwen2 attention subclass backed by Magi's flex_flash_attn. |
| 316 | |
| 317 | The official LocateAnything ``Qwen2MagiAttention`` asserts ``bsz == 1`` and |
| 318 | relies on ``Qwen2Model._attn_implementation == "magi"`` to build a single |
| 319 | sample range plan. For release batch inference the hybrid scheduler passes |
| 320 | a batched Magi range plan directly to this layer; a 4D-mask conversion path |
| 321 | remains as a compatibility fallback. |
| 322 | """ |
| 323 | flex_flash_attn_func = getattr(mod, "flex_flash_attn_func", None) |
| 324 | if flex_flash_attn_func is None: |
| 325 | try: |
| 326 | from magi_attention.functional.flex_flash_attn import flex_flash_attn_func |
| 327 | except Exception as exc: |
| 328 | raise RuntimeError( |
| 329 | "LA_FLASH_ATTN=magi requires " |
| 330 | "magi_attention.functional.flex_flash_attn.flex_flash_attn_func." |
| 331 | ) from exc |
| 332 | |
| 333 | FULL, CAUSAL = 0, 1 |
| 334 | causal_plan_cache = {} |
| 335 | try: |
| 336 | magi_params = set(inspect.signature(flex_flash_attn_func).parameters) |
| 337 | except (TypeError, ValueError): |
| 338 | magi_params = set() |
| 339 | supports_disable_fwd_atomic = "disable_fwd_atomic_reduction" in magi_params |
| 340 | |
| 341 | def _disjoint_q_ranges(q_ranges): |
| 342 | seen = set() |
| 343 | for start, end in q_ranges: |
| 344 | key = (int(start), int(end)) |
| 345 | if key in seen: |
| 346 | return False |
| 347 | seen.add(key) |
| 348 | return True |
| 349 | |
| 350 | def _plan_disjoint_q_ranges(plan): |
| 351 | cached = plan.get("_la_flash_disjoint_q_ranges") |
| 352 | if cached is not None: |
| 353 | return bool(cached) |
| 354 | q_ranges = plan["q_ranges"].detach().to(device="cpu", dtype=torch.int32).tolist() |
| 355 | disjoint = _disjoint_q_ranges(q_ranges) |
| 356 | try: |
| 357 | plan["_la_flash_disjoint_q_ranges"] = disjoint |
| 358 | except Exception: |
| 359 | pass |
| 360 | return disjoint |
| 361 | |
| 362 | def _tensor_plan(q_ranges, k_ranges, types, device): |
| 363 | return { |
| 364 | "q_ranges": torch.tensor(q_ranges, dtype=torch.int32, device=device).contiguous(), |
| 365 | "k_ranges": torch.tensor(k_ranges, dtype=torch.int32, device=device).contiguous(), |
| 366 | "attn_type_map": torch.tensor(types, dtype=torch.int32, device=device).contiguous(), |
| 367 | "_la_flash_disjoint_q_ranges": _disjoint_q_ranges(q_ranges), |
| 368 | } |
| 369 | |
| 370 | def _offset_plan(plan, q_offset, k_offset): |
| 371 | return ( |
| 372 | (plan["q_ranges"] + int(q_offset)).tolist(), |
| 373 | (plan["k_ranges"] + int(k_offset)).tolist(), |
| 374 | plan["attn_type_map"].tolist(), |
| 375 | ) |
| 376 | |
| 377 | def _causal_plan(bsz, q_len, kv_seq_len, device): |
| 378 | key = (int(bsz), int(q_len), int(kv_seq_len), device.type, device.index) |
| 379 | cached = causal_plan_cache.get(key) |
| 380 | if cached is not None: |
| 381 | return cached |
| 382 | q_ranges, k_ranges, types = [], [], [] |
| 383 | for b in range(int(bsz)): |
| 384 | q_base = b * int(q_len) |
| 385 | k_base = b * int(kv_seq_len) |
| 386 | q_ranges.append([q_base, q_base + int(q_len)]) |
| 387 | k_ranges.append([k_base, k_base + int(kv_seq_len)]) |
| 388 | types.append(CAUSAL) |
| 389 | plan = _tensor_plan(q_ranges, k_ranges, types, device) |
| 390 | plan.update( |
| 391 | { |
| 392 | "flash_cu_seqlens_q": torch.arange( |
| 393 | 0, |
| 394 | (int(bsz) + 1) * int(q_len), |
| 395 | int(q_len), |
| 396 | dtype=torch.int32, |
| 397 | device=device, |
| 398 | ), |
| 399 | "flash_cu_seqlens_k": torch.arange( |
| 400 | 0, |
| 401 | (int(bsz) + 1) * int(kv_seq_len), |
| 402 | int(kv_seq_len), |
| 403 | dtype=torch.int32, |
| 404 | device=device, |
| 405 | ), |
| 406 | "flash_causal": True, |
| 407 | } |
| 408 | ) |
| 409 | causal_plan_cache[key] = plan |
| 410 | return plan |
| 411 | |
| 412 | def _row_segments(row): |
| 413 | idx = np.flatnonzero(row) |
| 414 | if idx.size == 0: |
| 415 | return ((0, 1),) |
| 416 | split = np.flatnonzero(np.diff(idx) > 1) + 1 |
| 417 | starts = np.concatenate((idx[:1], idx[split])) |
| 418 | ends = np.concatenate((idx[split - 1], idx[-1:])) + 1 |
| 419 | return tuple((int(s), int(e)) for s, e in zip(starts, ends)) |
| 420 | |
| 421 | def _visible_from_4d_mask(attention_mask, kv_seq_len): |
| 422 | mask = attention_mask[:, :, :, :kv_seq_len] |
| 423 | if mask.dtype == torch.bool: |
| 424 | return mask[:, 0].detach().to(device="cpu", dtype=torch.bool).contiguous() |
| 425 | mask_cpu = mask[:, 0].detach().to(device="cpu").contiguous() |
| 426 | if getattr(attention_mask, "_la_flash_visible_mask", False): |
| 427 | return (mask_cpu > 0).to(dtype=torch.bool) |
| 428 | |
| 429 | max_value = float(mask_cpu.max().item()) if mask_cpu.numel() else 0.0 |
| 430 | min_value = float(mask_cpu.min().item()) if mask_cpu.numel() else 0.0 |
| 431 | if max_value > 0.0 and min_value >= 0.0: |
| 432 | return (mask_cpu > 0).to(dtype=torch.bool) |
| 433 | return (mask_cpu >= 0).to(dtype=torch.bool) |
| 434 | |
| 435 | def _plan_from_visible_mask(attention_mask, bsz, q_len, kv_seq_len, device): |
| 436 | cache_key = (int(bsz), int(q_len), int(kv_seq_len), device.type, device.index) |
| 437 | cached = getattr(attention_mask, "_la_flash_magi_plan", None) |
| 438 | if cached is not None and cached[0] == cache_key: |
| 439 | return cached[1] |
| 440 | |
| 441 | visible = _visible_from_4d_mask(attention_mask, int(kv_seq_len)).numpy() |
| 442 | q_ranges, k_ranges, types = [], [], [] |
| 443 | for b in range(int(bsz)): |
| 444 | q_base = b * int(q_len) |
| 445 | k_base = b * int(kv_seq_len) |
| 446 | run_start = 0 |
| 447 | run_segments = _row_segments(visible[b, 0]) |
| 448 | for q in range(1, int(q_len)): |
| 449 | segments = _row_segments(visible[b, q]) |
| 450 | if segments == run_segments: |
| 451 | continue |
| 452 | for start, end in run_segments: |
| 453 | q_ranges.append([q_base + run_start, q_base + q]) |
| 454 | k_ranges.append([k_base + start, k_base + end]) |
| 455 | types.append(FULL) |
| 456 | run_start = q |
| 457 | run_segments = segments |
| 458 | for start, end in run_segments: |
| 459 | q_ranges.append([q_base + run_start, q_base + int(q_len)]) |
| 460 | k_ranges.append([k_base + start, k_base + end]) |
| 461 | types.append(FULL) |
| 462 | |
| 463 | plan = _tensor_plan(q_ranges, k_ranges, types, device) |
| 464 | try: |
| 465 | attention_mask._la_flash_magi_plan = (cache_key, plan) |
| 466 | except Exception: |
| 467 | pass |
| 468 | return plan |
| 469 | |
| 470 | def _plan_from_magi_dict(attention_mask, bsz, q_len, kv_seq_len, device): |
| 471 | if int(bsz) == 1: |
| 472 | return attention_mask |
| 473 | q_ranges, k_ranges, types = [], [], [] |
| 474 | for b in range(int(bsz)): |
| 475 | qs, ks, ts = _offset_plan( |
| 476 | attention_mask, |
| 477 | q_offset=b * int(q_len), |
| 478 | k_offset=b * int(kv_seq_len), |
| 479 | ) |
| 480 | q_ranges.extend(qs) |
| 481 | k_ranges.extend(ks) |
| 482 | types.extend(ts) |
| 483 | return _tensor_plan(q_ranges, k_ranges, types, device) |
| 484 | |
| 485 | def _magi_plan(attention_mask, bsz, q_len, kv_seq_len, device): |
| 486 | if isinstance(attention_mask, dict): |
| 487 | if attention_mask.get("_la_flash_batched", False): |
| 488 | return attention_mask |
| 489 | return _plan_from_magi_dict(attention_mask, bsz, q_len, kv_seq_len, device) |
| 490 | if attention_mask is None: |
| 491 | return _causal_plan(bsz, q_len, kv_seq_len, device) |
| 492 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): |
| 493 | raise ValueError( |
| 494 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " |
| 495 | f"but is {attention_mask.size()}" |
| 496 | ) |
| 497 | return _plan_from_visible_mask(attention_mask, bsz, q_len, kv_seq_len, device) |
| 498 | |
| 499 | class _BatchedMagiAttention(mod.Qwen2Attention): |
| 500 | """MagiAttention path with true batch inference via packed token ranges.""" |
| 501 | |
| 502 | def forward( |
| 503 | self, |
| 504 | hidden_states: torch.Tensor, |
| 505 | attention_mask=None, |
| 506 | position_ids=None, |
| 507 | past_key_value=None, |
| 508 | output_attentions=False, |
| 509 | use_cache=False, |
| 510 | **kwargs, |
| 511 | ): |
| 512 | if output_attentions: |
| 513 | raise NotImplementedError("MagiAttention does not support output_attentions=True") |
| 514 | |
| 515 | bsz, q_len, _ = hidden_states.size() |
| 516 | query_states = self.q_proj(hidden_states) |
| 517 | key_states = self.k_proj(hidden_states) |
| 518 | value_states = self.v_proj(hidden_states) |
| 519 | |
| 520 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 521 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| 522 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| 523 | |
| 524 | kv_seq_len = key_states.shape[-2] |
| 525 | if past_key_value is not None: |
| 526 | if self.layer_idx is None: |
| 527 | raise ValueError( |
| 528 | f"The cache structure has changed since version v4.36. If you are using " |
| 529 | f"{self.__class__.__name__} for auto-regressive decoding with k/v caching, " |
| 530 | "please initialize the attention class with a layer index." |
| 531 | ) |
| 532 | kv_seq_len += past_key_value.get_seq_length(self.layer_idx) |
| 533 | |
| 534 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) |
| 535 | query_states, key_states = mod.apply_rotary_pos_emb( |
| 536 | query_states, key_states, cos, sin, position_ids) |
| 537 | |
| 538 | if past_key_value is not None: |
| 539 | cache_kwargs = {"sin": sin, "cos": cos} |
| 540 | key_states, value_states = past_key_value.update( |
| 541 | key_states, value_states, self.layer_idx, cache_kwargs) |
| 542 | |
| 543 | kv_seq_len = key_states.shape[-2] |
| 544 | plan = _magi_plan(attention_mask, bsz, q_len, kv_seq_len, query_states.device) |
| 545 | magi_extra_kwargs = {} |
| 546 | if supports_disable_fwd_atomic: |
| 547 | magi_extra_kwargs["disable_fwd_atomic_reduction"] = ( |
| 548 | (not self.training) and _plan_disjoint_q_ranges(plan) |
| 549 | ) |
| 550 | |
| 551 | query_states = query_states.transpose(1, 2).reshape( |
| 552 | bsz * q_len, self.num_heads, self.head_dim).contiguous() |
| 553 | key_states = key_states.transpose(1, 2).reshape( |
| 554 | bsz * kv_seq_len, self.num_key_value_heads, self.head_dim).contiguous() |
| 555 | value_states = value_states.transpose(1, 2).reshape( |
| 556 | bsz * kv_seq_len, self.num_key_value_heads, self.head_dim).contiguous() |
| 557 | |
| 558 | attn_output, _ = flex_flash_attn_func( |
| 559 | query_states, |
| 560 | key_states, |
| 561 | value_states, |
| 562 | q_ranges=plan["q_ranges"], |
| 563 | k_ranges=plan["k_ranges"], |
| 564 | attn_type_map=plan["attn_type_map"], |
| 565 | softmax_scale=getattr(self, "softmax_scale", self.head_dim ** -0.5), |
| 566 | softcap=0.0, |
| 567 | deterministic=False, |
| 568 | **magi_extra_kwargs, |
| 569 | ) |
| 570 | attn_output = attn_output.view(bsz, q_len, self.hidden_size) |
| 571 | attn_output = self.o_proj(attn_output) |
| 572 | return attn_output, None, past_key_value |
| 573 | |
| 574 | return _BatchedMagiAttention |
| 575 | |
| 576 | |
| 577 | def build_la_flash_attention_class(mod): |
| 578 | """Build a Qwen2 attention subclass backed by LA Flash sparse ranges.""" |
| 579 | try: |
| 580 | from kernel_utils import is_available, range_attention |
| 581 | except Exception as exc: |
| 582 | raise RuntimeError( |
| 583 | "LA_FLASH_ATTN=la_flash requires kernel_utils and FlashAttention." |
| 584 | ) from exc |
| 585 | if not is_available(): |
| 586 | raise RuntimeError( |
| 587 | "LA_FLASH_ATTN=la_flash requires flash_attn.flash_attn_varlen_func." |
| 588 | ) |
| 589 | |
| 590 | FULL, CAUSAL = 0, 1 |
| 591 | causal_plan_cache = {} |
| 592 | |
| 593 | def _tensor_plan(q_ranges, k_ranges, types, device): |
| 594 | max_q_len = max((int(end) - int(start) for start, end in q_ranges), default=0) |
| 595 | max_k_len = max((int(end) - int(start) for start, end in k_ranges), default=0) |
| 596 | plan = { |
| 597 | "q_ranges": torch.tensor(q_ranges, dtype=torch.int32, device=device).contiguous(), |
| 598 | "k_ranges": torch.tensor(k_ranges, dtype=torch.int32, device=device).contiguous(), |
| 599 | "attn_type_map": torch.tensor(types, dtype=torch.int32, device=device).contiguous(), |
| 600 | "max_q_len": max_q_len, |
| 601 | "max_k_len": max_k_len, |
| 602 | } |
| 603 | plan.update(_la_flash_group_plan_tensors(q_ranges, types, device)) |
| 604 | return plan |
| 605 | |
| 606 | def _offset_plan(plan, q_offset, k_offset): |
| 607 | return ( |
| 608 | (plan["q_ranges"] + int(q_offset)).tolist(), |
| 609 | (plan["k_ranges"] + int(k_offset)).tolist(), |
| 610 | plan["attn_type_map"].tolist(), |
| 611 | ) |
| 612 | |
| 613 | def _causal_plan(bsz, q_len, kv_seq_len, device): |
| 614 | key = (int(bsz), int(q_len), int(kv_seq_len), device.type, device.index) |
| 615 | cached = causal_plan_cache.get(key) |
| 616 | if cached is not None: |
| 617 | return cached |
| 618 | q_ranges, k_ranges, types = [], [], [] |
| 619 | for b in range(int(bsz)): |
| 620 | q_base = b * int(q_len) |
| 621 | k_base = b * int(kv_seq_len) |
| 622 | q_ranges.append([q_base, q_base + int(q_len)]) |
| 623 | k_ranges.append([k_base, k_base + int(kv_seq_len)]) |
| 624 | types.append(CAUSAL) |
| 625 | plan = _tensor_plan(q_ranges, k_ranges, types, device) |
| 626 | plan.update( |
| 627 | { |
| 628 | "flash_cu_seqlens_q": torch.arange( |
| 629 | 0, |
| 630 | (int(bsz) + 1) * int(q_len), |
| 631 | int(q_len), |
| 632 | dtype=torch.int32, |
| 633 | device=device, |
| 634 | ), |
| 635 | "flash_cu_seqlens_k": torch.arange( |
| 636 | 0, |
| 637 | (int(bsz) + 1) * int(kv_seq_len), |
| 638 | int(kv_seq_len), |
| 639 | dtype=torch.int32, |
| 640 | device=device, |
| 641 | ), |
| 642 | "flash_causal": True, |
| 643 | } |
| 644 | ) |
| 645 | causal_plan_cache[key] = plan |
| 646 | return plan |
| 647 | |
| 648 | def _row_segments(row): |
| 649 | idx = np.flatnonzero(row) |
| 650 | if idx.size == 0: |
| 651 | return ((0, 1),) |
| 652 | split = np.flatnonzero(np.diff(idx) > 1) + 1 |
| 653 | starts = np.concatenate((idx[:1], idx[split])) |
| 654 | ends = np.concatenate((idx[split - 1], idx[-1:])) + 1 |
| 655 | return tuple((int(s), int(e)) for s, e in zip(starts, ends)) |
| 656 | |
| 657 | def _visible_from_4d_mask(attention_mask, kv_seq_len): |
| 658 | mask = attention_mask[:, :, :, :kv_seq_len] |
| 659 | if mask.dtype == torch.bool: |
| 660 | return mask[:, 0].detach().to(device="cpu", dtype=torch.bool).contiguous() |
| 661 | mask_cpu = mask[:, 0].detach().to(device="cpu").contiguous() |
| 662 | if getattr(attention_mask, "_la_flash_visible_mask", False): |
| 663 | return (mask_cpu > 0).to(dtype=torch.bool) |
| 664 | |
| 665 | max_value = float(mask_cpu.max().item()) if mask_cpu.numel() else 0.0 |
| 666 | min_value = float(mask_cpu.min().item()) if mask_cpu.numel() else 0.0 |
| 667 | if max_value > 0.0 and min_value >= 0.0: |
| 668 | return (mask_cpu > 0).to(dtype=torch.bool) |
| 669 | return (mask_cpu >= 0).to(dtype=torch.bool) |
| 670 | |
| 671 | def _prefix_len(row): |
| 672 | idx = np.flatnonzero(row) |
| 673 | if idx.size == 0: |
| 674 | return None |
| 675 | end = int(idx[-1]) + 1 |
| 676 | if not bool(row[:end].all()) or bool(row[end:].any()): |
| 677 | return None |
| 678 | return end |
| 679 | |
| 680 | def _causal_plan_from_visible(visible, bsz, q_len, kv_seq_len, device): |
| 681 | q_ranges, k_ranges, types = [], [], [] |
| 682 | packed_flash = True |
| 683 | for b in range(int(bsz)): |
| 684 | first_len = _prefix_len(visible[b, 0]) |
| 685 | if first_len is None: |
| 686 | return None |
| 687 | valid_len = int(first_len) + int(q_len) - 1 |
| 688 | if valid_len < int(q_len) or valid_len > int(kv_seq_len): |
| 689 | return None |
| 690 | for q in range(int(q_len)): |
| 691 | row_len = _prefix_len(visible[b, q]) |
| 692 | expected = valid_len - int(q_len) + q + 1 |
| 693 | if row_len != expected: |
| 694 | return None |
| 695 | q_base = b * int(q_len) |
| 696 | k_base = b * int(kv_seq_len) |
| 697 | q_ranges.append([q_base, q_base + int(q_len)]) |
| 698 | k_ranges.append([k_base, k_base + valid_len]) |
| 699 | types.append(CAUSAL) |
| 700 | packed_flash = packed_flash and valid_len == int(kv_seq_len) |
| 701 | |
| 702 | plan = _tensor_plan(q_ranges, k_ranges, types, device) |
| 703 | plan["_la_flash_disjoint_q_ranges"] = True |
| 704 | if packed_flash: |
| 705 | plan.update( |
| 706 | { |
| 707 | "flash_cu_seqlens_q": torch.arange( |
| 708 | 0, |
| 709 | (int(bsz) + 1) * int(q_len), |
| 710 | int(q_len), |
| 711 | dtype=torch.int32, |
| 712 | device=device, |
| 713 | ), |
| 714 | "flash_cu_seqlens_k": torch.arange( |
| 715 | 0, |
| 716 | (int(bsz) + 1) * int(kv_seq_len), |
| 717 | int(kv_seq_len), |
| 718 | dtype=torch.int32, |
| 719 | device=device, |
| 720 | ), |
| 721 | "flash_causal": True, |
| 722 | } |
| 723 | ) |
| 724 | return plan |
| 725 | |
| 726 | def _plan_from_visible_mask(attention_mask, bsz, q_len, kv_seq_len, device): |
| 727 | cache_key = (int(bsz), int(q_len), int(kv_seq_len), device.type, device.index, "la_flash") |
| 728 | cached = getattr(attention_mask, "_la_flash_range_plan", None) |
| 729 | if cached is not None and cached[0] == cache_key: |
| 730 | return cached[1] |
| 731 | |
| 732 | visible = _visible_from_4d_mask(attention_mask, int(kv_seq_len)).numpy() |
| 733 | plan = _causal_plan_from_visible(visible, bsz, q_len, kv_seq_len, device) |
| 734 | if plan is not None: |
| 735 | try: |
| 736 | attention_mask._la_flash_range_plan = (cache_key, plan) |
| 737 | except Exception: |
| 738 | pass |
| 739 | return plan |
| 740 | |
| 741 | q_ranges, k_ranges, types = [], [], [] |
| 742 | for b in range(int(bsz)): |
| 743 | q_base = b * int(q_len) |
| 744 | k_base = b * int(kv_seq_len) |
| 745 | run_start = 0 |
| 746 | run_segments = _row_segments(visible[b, 0]) |
| 747 | for q in range(1, int(q_len)): |
| 748 | segments = _row_segments(visible[b, q]) |
| 749 | if segments == run_segments: |
| 750 | continue |
| 751 | for start, end in run_segments: |
| 752 | q_ranges.append([q_base + run_start, q_base + q]) |
| 753 | k_ranges.append([k_base + start, k_base + end]) |
| 754 | types.append(FULL) |
| 755 | run_start = q |
| 756 | run_segments = segments |
| 757 | for start, end in run_segments: |
| 758 | q_ranges.append([q_base + run_start, q_base + int(q_len)]) |
| 759 | k_ranges.append([k_base + start, k_base + end]) |
| 760 | types.append(FULL) |
| 761 | |
| 762 | plan = _tensor_plan(q_ranges, k_ranges, types, device) |
| 763 | try: |
| 764 | attention_mask._la_flash_range_plan = (cache_key, plan) |
| 765 | except Exception: |
| 766 | pass |
| 767 | return plan |
| 768 | |
| 769 | def _plan_from_magi_dict(attention_mask, bsz, q_len, kv_seq_len, device): |
| 770 | if int(bsz) == 1: |
| 771 | return attention_mask |
| 772 | q_ranges, k_ranges, types = [], [], [] |
| 773 | for b in range(int(bsz)): |
| 774 | qs, ks, ts = _offset_plan( |
| 775 | attention_mask, |
| 776 | q_offset=b * int(q_len), |
| 777 | k_offset=b * int(kv_seq_len), |
| 778 | ) |
| 779 | q_ranges.extend(qs) |
| 780 | k_ranges.extend(ks) |
| 781 | types.extend(ts) |
| 782 | return _tensor_plan(q_ranges, k_ranges, types, device) |
| 783 | |
| 784 | def _range_plan(attention_mask, bsz, q_len, kv_seq_len, device): |
| 785 | if isinstance(attention_mask, dict): |
| 786 | if attention_mask.get("_la_flash_batched", False): |
| 787 | return attention_mask |
| 788 | return _plan_from_magi_dict(attention_mask, bsz, q_len, kv_seq_len, device) |
| 789 | if attention_mask is None: |
| 790 | return _causal_plan(bsz, q_len, kv_seq_len, device) |
| 791 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): |
| 792 | raise ValueError( |
| 793 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " |
| 794 | f"but is {attention_mask.size()}" |
| 795 | ) |
| 796 | return _plan_from_visible_mask(attention_mask, bsz, q_len, kv_seq_len, device) |
| 797 | |
| 798 | class _LaFlashAttention(mod.Qwen2Attention): |
| 799 | """Range-plan attention path backed by FlashAttention sparse ranges.""" |
| 800 | |
| 801 | def forward( |
| 802 | self, |
| 803 | hidden_states: torch.Tensor, |
| 804 | attention_mask=None, |
| 805 | position_ids=None, |
| 806 | past_key_value=None, |
| 807 | output_attentions=False, |
| 808 | use_cache=False, |
| 809 | **kwargs, |
| 810 | ): |
| 811 | if output_attentions: |
| 812 | raise NotImplementedError("LA Flash attention does not support output_attentions=True") |
| 813 | |
| 814 | bsz, q_len, _ = hidden_states.size() |
| 815 | query_states = self.q_proj(hidden_states) |
| 816 | key_states = self.k_proj(hidden_states) |
| 817 | value_states = self.v_proj(hidden_states) |
| 818 | |
| 819 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 820 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| 821 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| 822 | |
| 823 | kv_seq_len = key_states.shape[-2] |
| 824 | if past_key_value is not None: |
| 825 | if self.layer_idx is None: |
| 826 | raise ValueError( |
| 827 | f"The cache structure has changed since version v4.36. If you are using " |
| 828 | f"{self.__class__.__name__} for auto-regressive decoding with k/v caching, " |
| 829 | "please initialize the attention class with a layer index." |
| 830 | ) |
| 831 | kv_seq_len += past_key_value.get_seq_length(self.layer_idx) |
| 832 | |
| 833 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) |
| 834 | query_states, key_states = mod.apply_rotary_pos_emb( |
| 835 | query_states, key_states, cos, sin, position_ids) |
| 836 | |
| 837 | if past_key_value is not None: |
| 838 | cache_kwargs = {"sin": sin, "cos": cos} |
| 839 | key_states, value_states = past_key_value.update( |
| 840 | key_states, value_states, self.layer_idx, cache_kwargs) |
| 841 | |
| 842 | kv_seq_len = key_states.shape[-2] |
| 843 | dense_backend = os.environ.get("LA_FLASH_DENSE_BACKEND", "sdpa").strip().lower() |
| 844 | if dense_backend == "sdpa" and not isinstance(attention_mask, dict): |
| 845 | dense_key_states = mod.repeat_kv(key_states, self.num_key_value_groups) |
| 846 | dense_value_states = mod.repeat_kv(value_states, self.num_key_value_groups) |
| 847 | if attention_mask is not None: |
| 848 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): |
| 849 | raise ValueError( |
| 850 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " |
| 851 | f"but is {attention_mask.size()}" |
| 852 | ) |
| 853 | query_for_sdpa = query_states.contiguous() |
| 854 | key_for_sdpa = dense_key_states.contiguous() |
| 855 | value_for_sdpa = dense_value_states.contiguous() |
| 856 | is_causal = False |
| 857 | elif past_key_value is None: |
| 858 | query_for_sdpa = query_states |
| 859 | key_for_sdpa = dense_key_states |
| 860 | value_for_sdpa = dense_value_states |
| 861 | is_causal = bool(self.is_causal and q_len > 1) |
| 862 | else: |
| 863 | query_for_sdpa = key_for_sdpa = value_for_sdpa = None |
| 864 | is_causal = False |
| 865 | if query_for_sdpa is not None: |
| 866 | attn_output = torch.nn.functional.scaled_dot_product_attention( |
| 867 | query_for_sdpa, |
| 868 | key_for_sdpa, |
| 869 | value_for_sdpa, |
| 870 | attn_mask=attention_mask, |
| 871 | dropout_p=self.attention_dropout if self.training else 0.0, |
| 872 | is_causal=is_causal, |
| 873 | ) |
| 874 | attn_output = attn_output.transpose(1, 2).contiguous() |
| 875 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) |
| 876 | attn_output = self.o_proj(attn_output) |
| 877 | return attn_output, None, past_key_value |
| 878 | |
| 879 | plan = _range_plan(attention_mask, bsz, q_len, kv_seq_len, query_states.device) |
| 880 | |
| 881 | query_states = query_states.transpose(1, 2).reshape( |
| 882 | bsz * q_len, self.num_heads, self.head_dim).contiguous() |
| 883 | key_states = key_states.transpose(1, 2).reshape( |
| 884 | bsz * kv_seq_len, self.num_key_value_heads, self.head_dim).contiguous() |
| 885 | value_states = value_states.transpose(1, 2).reshape( |
| 886 | bsz * kv_seq_len, self.num_key_value_heads, self.head_dim).contiguous() |
| 887 | |
| 888 | attn_output = range_attention( |
| 889 | query_states, |
| 890 | key_states, |
| 891 | value_states, |
| 892 | plan["q_ranges"], |
| 893 | plan["k_ranges"], |
| 894 | plan["attn_type_map"], |
| 895 | getattr(self, "softmax_scale", self.head_dim ** -0.5), |
| 896 | segment_offsets=plan.get("segment_offsets"), |
| 897 | group_q_ranges=plan.get("group_q_ranges"), |
| 898 | group_attn_type_map=plan.get("group_attn_type_map"), |
| 899 | max_q_len=plan.get("max_q_len"), |
| 900 | max_k_len=plan.get("max_k_len"), |
| 901 | flash_cu_seqlens_q=plan.get("flash_cu_seqlens_q"), |
| 902 | flash_cu_seqlens_k=plan.get("flash_cu_seqlens_k"), |
| 903 | flash_causal=plan.get("flash_causal"), |
| 904 | disjoint_q_ranges=plan.get("_la_flash_disjoint_q_ranges"), |
| 905 | ) |
| 906 | attn_output = attn_output.view(bsz, q_len, self.hidden_size) |
| 907 | attn_output = self.o_proj(attn_output) |
| 908 | return attn_output, None, past_key_value |
| 909 | |
| 910 | return _LaFlashAttention |
| 911 | |
| 912 | |
| 913 | def _is_magi_plan(obj): |
| 914 | return isinstance(obj, dict) and { |
| 915 | "q_ranges", |
| 916 | "k_ranges", |
| 917 | "attn_type_map", |
| 918 | }.issubset(obj.keys()) |
| 919 | |
| 920 | |
| 921 | def _la_flash_group_plan_tensors(q_ranges, types, device): |
| 922 | """Group consecutive Magi range entries that share the same query span. |
| 923 | |
| 924 | Magi-style plans may represent one query span with multiple disjoint key |
| 925 | spans. LA Flash consumes those as one FlashAttention-backed softmax group. |
| 926 | """ |
| 927 | if not q_ranges: |
| 928 | return { |
| 929 | "group_q_ranges": torch.empty((0, 2), dtype=torch.int32, device=device), |
| 930 | "segment_offsets": torch.zeros((1,), dtype=torch.int32, device=device), |
| 931 | "group_attn_type_map": torch.empty((0,), dtype=torch.int32, device=device), |
| 932 | } |
| 933 | |
| 934 | grouped_q, grouped_types, offsets = [], [], [0] |
| 935 | last_q = None |
| 936 | last_type = None |
| 937 | for idx, (q_range, attn_type) in enumerate(zip(q_ranges, types)): |
| 938 | key = (int(q_range[0]), int(q_range[1])) |
| 939 | attn_type = int(attn_type) |
| 940 | if last_q is None: |
| 941 | grouped_q.append([key[0], key[1]]) |
| 942 | grouped_types.append(attn_type) |
| 943 | last_q = key |
| 944 | last_type = attn_type |
| 945 | continue |
| 946 | if key == last_q and attn_type == last_type: |
| 947 | continue |
| 948 | offsets.append(idx) |
| 949 | grouped_q.append([key[0], key[1]]) |
| 950 | grouped_types.append(attn_type) |
| 951 | last_q = key |
| 952 | last_type = attn_type |
| 953 | offsets.append(len(q_ranges)) |
| 954 | |
| 955 | return { |
| 956 | "group_q_ranges": torch.tensor(grouped_q, dtype=torch.int32, device=device).contiguous(), |
| 957 | "segment_offsets": torch.tensor(offsets, dtype=torch.int32, device=device).contiguous(), |
| 958 | "group_attn_type_map": torch.tensor(grouped_types, dtype=torch.int32, device=device).contiguous(), |
| 959 | "max_q_len": max((end - start for start, end in grouped_q), default=0), |
| 960 | } |
| 961 | |
| 962 | |
| 963 | def _record_sparse_plan_stats(model, q_ranges, k_ranges, types): |
| 964 | if os.environ.get("LA_FLASH_PLAN_STATS", "0") != "1": |
| 965 | return |
| 966 | stats = getattr(model, "_la_flash_sparse_plan_stats", None) |
| 967 | if stats is None: |
| 968 | stats = { |
| 969 | "calls": 0, |
| 970 | "ranges": 0, |
| 971 | "q_tokens": 0, |
| 972 | "k_tokens": 0, |
| 973 | "max_q_len": 0, |
| 974 | "max_k_len": 0, |
| 975 | "full_ranges": 0, |
| 976 | "causal_ranges": 0, |
| 977 | "other_ranges": 0, |
| 978 | } |
| 979 | model._la_flash_sparse_plan_stats = stats |
| 980 | stats["calls"] += 1 |
| 981 | stats["ranges"] += len(q_ranges) |
| 982 | for (q_start, q_end), (k_start, k_end), attn_type in zip(q_ranges, k_ranges, types): |
| 983 | q_len = int(q_end) - int(q_start) |
| 984 | k_len = int(k_end) - int(k_start) |
| 985 | stats["q_tokens"] += q_len |
| 986 | stats["k_tokens"] += k_len |
| 987 | stats["max_q_len"] = max(stats["max_q_len"], q_len) |
| 988 | stats["max_k_len"] = max(stats["max_k_len"], k_len) |
| 989 | attn_type = int(attn_type) |
| 990 | if attn_type == 0: |
| 991 | stats["full_ranges"] += 1 |
| 992 | elif attn_type == 1: |
| 993 | stats["causal_ranges"] += 1 |
| 994 | else: |
| 995 | stats["other_ranges"] += 1 |
| 996 | |
| 997 | |
| 998 | def build_magi_scheduler_ranges(model, attention_mask_2d, input_ids, past_len, mtp_window=False): |
| 999 | """Build batched Magi ranges directly from the hybrid scheduler mask. |
| 1000 | |
| 1001 | The official Qwen2 SDPA dispatcher may optimize an all-valid 2D mask to |
| 1002 | ``None`` before decoder layers see it. That is correct for plain causal |
| 1003 | attention but loses LocateAnything's MTP generation-window rule. Building |
| 1004 | ranges here keeps Magi batch inference exact and avoids per-layer dense |
| 1005 | mask conversion. |
| 1006 | """ |
| 1007 | requested_attn = getattr(model, "_la_flash_requested_attn", ATTN_MODE) |
| 1008 | if requested_attn not in {"magi", "la_flash"}: |
| 1009 | return None |
| 1010 | if attention_mask_2d is None or not hasattr(attention_mask_2d, "dim") or attention_mask_2d.dim() != 2: |
| 1011 | return None |
| 1012 | |
| 1013 | bsz, q_len = int(input_ids.shape[0]), int(input_ids.shape[1]) |
| 1014 | key_len = int(attention_mask_2d.shape[1]) |
| 1015 | dev = input_ids.device |
| 1016 | llm = model.language_model.model |
| 1017 | block = int(getattr(llm, "block_size", N_FUTURE)) |
| 1018 | causal_attn = bool(getattr(llm, "causal_attn", False)) |
| 1019 | use_mtp_window = bool(mtp_window and q_len >= block and key_len >= block) |
| 1020 | q0 = max(0, q_len - block) |
| 1021 | k0 = max(0, key_len - block) |
| 1022 | blocked_k = k0 - 1 |
| 1023 | past_len = int(past_len) |
| 1024 | |
| 1025 | key_valid = attention_mask_2d.detach().to(device="cpu", dtype=torch.bool).contiguous().numpy() |
| 1026 | key_idx = np.arange(key_len) |
| 1027 | q_ranges, k_ranges, types = [], [], [] |
| 1028 | if not use_mtp_window: |
| 1029 | causal_q_ranges, causal_k_ranges, causal_types = [], [], [] |
| 1030 | causal_fast_path = True |
| 1031 | packed_flash = True |
| 1032 | for b in range(bsz): |
| 1033 | valid = np.flatnonzero(key_valid[b]) |
| 1034 | if valid.size == 0: |
| 1035 | causal_fast_path = False |
| 1036 | break |
| 1037 | valid_len = int(valid[-1]) + 1 |
| 1038 | if valid_len < q_len or not bool(key_valid[b, :valid_len].all()) or bool(key_valid[b, valid_len:].any()): |
| 1039 | causal_fast_path = False |
| 1040 | break |
| 1041 | packed_flash = packed_flash and valid_len == key_len |
| 1042 | q_base = b * q_len |
| 1043 | k_base = b * key_len |
| 1044 | causal_q_ranges.append([q_base, q_base + q_len]) |
| 1045 | causal_k_ranges.append([k_base, k_base + valid_len]) |
| 1046 | causal_types.append(1) |
| 1047 | if causal_fast_path: |
| 1048 | plan = { |
| 1049 | "q_ranges": torch.tensor(causal_q_ranges, dtype=torch.int32, device=dev).contiguous(), |
| 1050 | "k_ranges": torch.tensor(causal_k_ranges, dtype=torch.int32, device=dev).contiguous(), |
| 1051 | "attn_type_map": torch.tensor(causal_types, dtype=torch.int32, device=dev).contiguous(), |
| 1052 | "max_q_len": q_len, |
| 1053 | "max_k_len": max((end - start for start, end in causal_k_ranges), default=0), |
| 1054 | "_la_flash_batched": True, |
| 1055 | "_la_flash_disjoint_q_ranges": True, |
| 1056 | } |
| 1057 | if packed_flash: |
| 1058 | plan.update( |
| 1059 | { |
| 1060 | "flash_cu_seqlens_q": torch.arange( |
| 1061 | 0, |
| 1062 | (bsz + 1) * q_len, |
| 1063 | q_len, |
| 1064 | dtype=torch.int32, |
| 1065 | device=dev, |
| 1066 | ), |
| 1067 | "flash_cu_seqlens_k": torch.arange( |
| 1068 | 0, |
| 1069 | (bsz + 1) * key_len, |
| 1070 | key_len, |
| 1071 | dtype=torch.int32, |
| 1072 | device=dev, |
| 1073 | ), |
| 1074 | "flash_causal": True, |
| 1075 | } |
| 1076 | ) |
| 1077 | plan.update(_la_flash_group_plan_tensors(causal_q_ranges, causal_types, dev)) |
| 1078 | _record_sparse_plan_stats(model, causal_q_ranges, causal_k_ranges, causal_types) |
| 1079 | return plan |
| 1080 | |
| 1081 | def row_segments(row): |
| 1082 | idx = np.flatnonzero(row) |
| 1083 | if idx.size == 0: |
| 1084 | return ((0, 1),) |
| 1085 | split = np.flatnonzero(np.diff(idx) > 1) + 1 |
| 1086 | starts = np.concatenate((idx[:1], idx[split])) |
| 1087 | ends = np.concatenate((idx[split - 1], idx[-1:])) + 1 |
| 1088 | return tuple((int(s), int(e)) for s, e in zip(starts, ends)) |
| 1089 | |
| 1090 | for b in range(bsz): |
| 1091 | q_base = b * q_len |
| 1092 | k_base = b * key_len |
| 1093 | run_start = 0 |
| 1094 | run_segments = None |
| 1095 | if use_mtp_window and not causal_attn: |
| 1096 | prefix_q_len = q0 |
| 1097 | prefix_k_end = past_len + prefix_q_len |
| 1098 | prefix_ok = ( |
| 1099 | prefix_q_len > 0 |
| 1100 | and prefix_k_end <= key_len |
| 1101 | and bool(key_valid[b, :prefix_k_end].all()) |
| 1102 | ) |
| 1103 | window_prefix_ok = blocked_k <= 0 or bool(key_valid[b, :blocked_k].all()) |
| 1104 | window_ok = bool(key_valid[b, k0:key_len].all()) |
| 1105 | if prefix_ok: |
| 1106 | q_ranges.append([q_base, q_base + prefix_q_len]) |
| 1107 | k_ranges.append([k_base, k_base + prefix_k_end]) |
| 1108 | types.append(1) |
| 1109 | run_start = prefix_q_len |
| 1110 | if run_start == prefix_q_len and prefix_q_len < q_len and window_prefix_ok and window_ok: |
| 1111 | if blocked_k > 0: |
| 1112 | q_ranges.append([q_base + prefix_q_len, q_base + q_len]) |
| 1113 | k_ranges.append([k_base, k_base + blocked_k]) |
| 1114 | types.append(0) |
| 1115 | q_ranges.append([q_base + prefix_q_len, q_base + q_len]) |
| 1116 | k_ranges.append([k_base + k0, k_base + key_len]) |
| 1117 | types.append(0) |
| 1118 | continue |
| 1119 | |
| 1120 | for q in range(run_start, q_len): |
| 1121 | visible = key_valid[b] & (key_idx <= q + past_len) |
| 1122 | if use_mtp_window and q >= q0: |
| 1123 | if not causal_attn: |
| 1124 | visible = visible.copy() |
| 1125 | visible[k0:key_len] = key_valid[b, k0:key_len] |
| 1126 | if blocked_k >= 0: |
| 1127 | if visible.base is None: |
| 1128 | visible[blocked_k] = False |
| 1129 | else: |
| 1130 | visible = visible.copy() |
| 1131 | visible[blocked_k] = False |
| 1132 | segments = row_segments(visible) |
| 1133 | if run_segments is None: |
| 1134 | run_segments = segments |
| 1135 | continue |
| 1136 | if segments == run_segments: |
| 1137 | continue |
| 1138 | for start, end in run_segments: |
| 1139 | q_ranges.append([q_base + run_start, q_base + q]) |
| 1140 | k_ranges.append([k_base + start, k_base + end]) |
| 1141 | types.append(0) |
| 1142 | run_start = q |
| 1143 | run_segments = segments |
| 1144 | for start, end in run_segments: |
| 1145 | q_ranges.append([q_base + run_start, q_base + q_len]) |
| 1146 | k_ranges.append([k_base + start, k_base + end]) |
| 1147 | types.append(0) |
| 1148 | |
| 1149 | seen_q_ranges = set() |
| 1150 | disjoint_q_ranges = True |
| 1151 | for start, end in q_ranges: |
| 1152 | key = (int(start), int(end)) |
| 1153 | if key in seen_q_ranges: |
| 1154 | disjoint_q_ranges = False |
| 1155 | break |
| 1156 | seen_q_ranges.add(key) |
| 1157 | |
| 1158 | plan = { |
| 1159 | "q_ranges": torch.tensor(q_ranges, dtype=torch.int32, device=dev).contiguous(), |
| 1160 | "k_ranges": torch.tensor(k_ranges, dtype=torch.int32, device=dev).contiguous(), |
| 1161 | "attn_type_map": torch.tensor(types, dtype=torch.int32, device=dev).contiguous(), |
| 1162 | "max_q_len": max((end - start for start, end in q_ranges), default=0), |
| 1163 | "max_k_len": max((end - start for start, end in k_ranges), default=0), |
| 1164 | "_la_flash_batched": True, |
| 1165 | "_la_flash_disjoint_q_ranges": disjoint_q_ranges, |
| 1166 | } |
| 1167 | plan.update(_la_flash_group_plan_tensors(q_ranges, types, dev)) |
| 1168 | _record_sparse_plan_stats(model, q_ranges, k_ranges, types) |
| 1169 | return plan |
| 1170 | |
| 1171 | |
| 1172 | def _direct_base_forward( |
| 1173 | base, |
| 1174 | input_ids=None, |
| 1175 | visual_features=None, |
| 1176 | image_token_index=None, |
| 1177 | attention_mask=None, |
| 1178 | position_ids=None, |
| 1179 | past_key_values=None, |
| 1180 | inputs_embeds=None, |
| 1181 | use_cache=None, |
| 1182 | output_attentions=None, |
| 1183 | output_hidden_states=None, |
| 1184 | return_dict=None, |
| 1185 | ): |
| 1186 | mod = importlib.import_module(type(base).__module__) |
| 1187 | output_attentions = output_attentions if output_attentions is not None else base.config.output_attentions |
| 1188 | output_hidden_states = ( |
| 1189 | output_hidden_states if output_hidden_states is not None else base.config.output_hidden_states |
| 1190 | ) |
| 1191 | use_cache = use_cache if use_cache is not None else base.config.use_cache |
| 1192 | |
| 1193 | if input_ids is not None and inputs_embeds is not None: |
| 1194 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") |
| 1195 | if input_ids is not None: |
| 1196 | batch_size, seq_length = input_ids.shape |
| 1197 | elif inputs_embeds is not None: |
| 1198 | batch_size, seq_length, _ = inputs_embeds.shape |
| 1199 | else: |
| 1200 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") |
| 1201 | |
| 1202 | past_key_values_length = 0 |
| 1203 | use_legacy_cache = False |
| 1204 | if use_cache: |
| 1205 | Cache = getattr(mod, "Cache") |
| 1206 | DynamicCache = getattr(mod, "DynamicCache") |
| 1207 | use_legacy_cache = not isinstance(past_key_values, Cache) |
| 1208 | if use_legacy_cache: |
| 1209 | if past_key_values is None: |
| 1210 | past_key_values = DynamicCache() |
| 1211 | else: |
| 1212 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) |
| 1213 | past_key_values_length = past_key_values.get_seq_length() |
| 1214 | |
| 1215 | if position_ids is None: |
| 1216 | dev = input_ids.device if input_ids is not None else inputs_embeds.device |
| 1217 | position_ids = torch.arange( |
| 1218 | past_key_values_length, |
| 1219 | seq_length + past_key_values_length, |
| 1220 | dtype=torch.long, |
| 1221 | device=dev, |
| 1222 | ).unsqueeze(0).view(-1, seq_length) |
| 1223 | else: |
| 1224 | position_ids = position_ids.view(-1, seq_length).long() |
| 1225 | |
| 1226 | if inputs_embeds is None: |
| 1227 | inputs_embeds = base.image_processing(input_ids, visual_features, image_token_index) |
| 1228 | |
| 1229 | hidden_states = inputs_embeds |
| 1230 | all_hidden_states = () if output_hidden_states else None |
| 1231 | all_self_attns = () if output_attentions else None |
| 1232 | next_decoder_cache = None |
| 1233 | |
| 1234 | for decoder_layer in base.layers: |
| 1235 | if output_hidden_states: |
| 1236 | all_hidden_states += (hidden_states,) |
| 1237 | layer_outputs = decoder_layer( |
| 1238 | hidden_states, |
| 1239 | attention_mask=attention_mask, |
| 1240 | position_ids=position_ids, |
| 1241 | past_key_value=past_key_values, |
| 1242 | output_attentions=output_attentions, |
| 1243 | use_cache=use_cache, |
| 1244 | ) |
| 1245 | hidden_states = layer_outputs[0] |
| 1246 | if use_cache: |
| 1247 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] |
| 1248 | if output_attentions: |
| 1249 | all_self_attns += (layer_outputs[1],) |
| 1250 | |
| 1251 | hidden_states = base.norm(hidden_states) |
| 1252 | if output_hidden_states: |
| 1253 | all_hidden_states += (hidden_states,) |
| 1254 | next_cache = None |
| 1255 | if use_cache: |
| 1256 | next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache |
| 1257 | return SimpleNamespace( |
| 1258 | last_hidden_state=hidden_states, |
| 1259 | past_key_values=next_cache, |
| 1260 | hidden_states=all_hidden_states, |
| 1261 | attentions=all_self_attns, |
| 1262 | ) |
| 1263 | |
| 1264 | |
| 1265 | def language_model_forward(model, **kwargs): |
| 1266 | """Forward through the text LM, bypassing official dense-mask prep for sparse plans.""" |
| 1267 | lm = model.language_model |
| 1268 | return_logits = kwargs.pop("return_logits", True) |
| 1269 | logits_slice = kwargs.pop("logits_slice", None) |
| 1270 | attention_mask = kwargs.get("attention_mask") |
| 1271 | use_direct_sparse = ( |
| 1272 | getattr(model, "_la_flash_requested_attn", ATTN_MODE) in {"magi", "la_flash"} |
| 1273 | and _is_magi_plan(attention_mask) |
| 1274 | ) |
| 1275 | if not use_direct_sparse: |
| 1276 | return lm(**kwargs) |
| 1277 | |
| 1278 | labels = kwargs.pop("labels", None) |
| 1279 | if labels is not None: |
| 1280 | raise NotImplementedError("labels are not supported in the direct sparse-plan decode forward") |
| 1281 | output_attentions = kwargs.get("output_attentions", None) |
| 1282 | output_hidden_states = kwargs.get("output_hidden_states", None) |
| 1283 | base_out = _direct_base_forward(lm.model, **kwargs) |
| 1284 | logits = None |
| 1285 | if return_logits: |
| 1286 | hidden_states = base_out.last_hidden_state |
| 1287 | if logits_slice is not None: |
| 1288 | hidden_states = hidden_states[:, logits_slice, :] |
| 1289 | logits = lm.lm_head(hidden_states).float() |
| 1290 | return SimpleNamespace( |
| 1291 | logits=logits, |
| 1292 | past_key_values=base_out.past_key_values, |
| 1293 | hidden_states=base_out.hidden_states if output_hidden_states else None, |
| 1294 | attentions=base_out.attentions if output_attentions else None, |
| 1295 | ) |
| 1296 | |
| 1297 | |
| 1298 | _EagerCls = _SdpaCls = _LaFlashCls = _MagiCls = None |
| 1299 | def _attn_classes(mode=None): |
| 1300 | """Attention classes from the dynamic Qwen2 remote module. |
| 1301 | |
| 1302 | The official Qwen2Model mask dispatcher only implements ``sdpa`` and |
| 1303 | single-row ``magi``. Eager, LA Flash, and batched Magi inference |
| 1304 | therefore swap the layer class while keeping the model's mask dispatcher |
| 1305 | pinned to ``sdpa``. |
| 1306 | """ |
| 1307 | global _EagerCls, _SdpaCls, _LaFlashCls, _MagiCls |
| 1308 | mode = _normalize_attn_mode(mode) if mode is not None else None |
| 1309 | if _SdpaCls is None: |
| 1310 | mod = importlib.import_module(type(_model.language_model.model).__module__) |
| 1311 | _EagerCls = mod.Qwen2Attention |
| 1312 | _SdpaCls = mod.Qwen2SdpaAttention |
| 1313 | else: |
| 1314 | mod = importlib.import_module(type(_model.language_model.model).__module__) |
| 1315 | if (mode is None or mode == "la_flash") and _LaFlashCls is None: |
| 1316 | _LaFlashCls = build_la_flash_attention_class(mod) |
| 1317 | if (mode is None or mode == "magi") and _MagiCls is None: |
| 1318 | _MagiCls = build_batched_magi_attention_class(mod) if getattr(mod, "_MAGI_AVAILABLE", False) else None |
| 1319 | return _EagerCls, _SdpaCls, _LaFlashCls, _MagiCls |
| 1320 | |
| 1321 | def _set_llm_mode(model, mode): |
| 1322 | """Swap every Qwen2 decoder layer's attention class. |
| 1323 | |
| 1324 | Release backends keep ``Qwen2Model._attn_implementation='sdpa'`` so the |
| 1325 | official Qwen2 mask dispatcher stays available for dense-mask modes. The |
| 1326 | local ``la_flash`` and batched ``magi`` wrappers can also consume scheduler-built |
| 1327 | sparse plans directly, avoiding repeated per-layer dense mask conversion. |
| 1328 | """ |
| 1329 | mode = _normalize_attn_mode(mode) |
| 1330 | eager, sdpa, la_flash, magi = _attn_classes(mode) |
| 1331 | impl = "sdpa" |
| 1332 | if mode == "sdpa": |
| 1333 | cls = sdpa |
| 1334 | elif mode == "eager": |
| 1335 | cls = eager |
| 1336 | elif mode == "la_flash": |
| 1337 | cls = la_flash |
| 1338 | elif mode == "magi": |
| 1339 | if magi is None: |
| 1340 | raise RuntimeError("MagiAttention is unavailable in the current Python environment.") |
| 1341 | cls = magi |
| 1342 | else: |
| 1343 | raise ValueError(f"unknown LLM attention mode: {mode}") |
| 1344 | llm = model.language_model.model |
| 1345 | for lyr in llm.layers: |
| 1346 | lyr.self_attn.__class__ = cls |
| 1347 | if mode == "magi": |
| 1348 | lyr.self_attn.softmax_scale = lyr.self_attn.head_dim ** -0.5 |
| 1349 | llm._attn_implementation = impl |
| 1350 | llm.config._attn_implementation = llm._attn_implementation |
| 1351 | if hasattr(model.config, "text_config"): |
| 1352 | model.config.text_config._attn_implementation = llm._attn_implementation |
| 1353 | model.config._attn_implementation = llm._attn_implementation |
| 1354 | model._la_flash_requested_attn = mode |
| 1355 | |
| 1356 | _st = _hp = None |
| 1357 | def _helpers(): |
| 1358 | """The model's own sample_tokens / handle_pattern (the exact box decoders).""" |
| 1359 | global _st, _hp |
| 1360 | if _st is None: |
| 1361 | m = importlib.import_module(type(load()[2]).__module__) |
| 1362 | _st, _hp = m.sample_tokens, m.handle_pattern |
| 1363 | return _st, _hp |
| 1364 | |
| 1365 | |
| 1366 | _gu = None |
| 1367 | def _gen_utils(): |
| 1368 | """The model's generate_utils module (apply_repetition_penalty / top_p_logits / top_k_logits / |
| 1369 | decode_bbox_avg / decode_ref / dists) -- the pieces sample_tokens_batched reuses verbatim.""" |
| 1370 | global _gu |
| 1371 | if _gu is None: |
| 1372 | m = importlib.import_module(type(load()[2]).__module__) |
| 1373 | _gu = importlib.import_module(m.sample_tokens.__module__) |
| 1374 | return _gu |
| 1375 | |
| 1376 | |
| 1377 | def _env_float(name, default): |
| 1378 | val = os.environ.get(name) |
| 1379 | if val is None or val.strip() == "": |
| 1380 | return float(default) |
| 1381 | return float(val) |
| 1382 | |
| 1383 | |
| 1384 | def _coord_fallback_mode(): |
| 1385 | mode = os.environ.get("LA_FLASH_COORD_FALLBACK_MODE", "legacy").strip().lower().replace("-", "_") |
| 1386 | aliases = { |
| 1387 | "": "legacy", |
| 1388 | "official": "legacy", |
| 1389 | "range": "legacy", |
| 1390 | "spread": "legacy", |
| 1391 | "none": "off", |
| 1392 | "disable": "off", |
| 1393 | "disabled": "off", |
| 1394 | "entropy_variance": "uncertainty", |
| 1395 | "entropy_var": "uncertainty", |
| 1396 | "ent_var": "uncertainty", |
| 1397 | "entropy_std": "uncertainty", |
| 1398 | } |
| 1399 | mode = aliases.get(mode, mode) |
| 1400 | if mode not in {"legacy", "uncertainty", "off"}: |
| 1401 | raise ValueError( |
| 1402 | "LA_FLASH_COORD_FALLBACK_MODE must be one of legacy, uncertainty, off" |
| 1403 | ) |
| 1404 | return mode |
| 1405 | |
| 1406 | |
| 1407 | def _coord_uncertainty_threshold(coord_start_token_id, coord_end_token_id): |
| 1408 | """Return the coord uncertainty threshold in raw coord-token units. |
| 1409 | |
| 1410 | Backward-compatible behavior: |
| 1411 | - LA_FLASH_COORD_UNCERTAINTY_THRESH > 1 is treated as raw coord-token RMSE. |
| 1412 | - LA_FLASH_COORD_UNCERTAINTY_THRESH <= 1 is treated as normalized by coord span. |
| 1413 | - LA_FLASH_COORD_UNCERTAINTY_NORM_THRESH is an explicit normalized override. |
| 1414 | """ |
| 1415 | coord_span = max(float(coord_end_token_id - coord_start_token_id + 1), 1.0) |
| 1416 | norm_val = os.environ.get("LA_FLASH_COORD_UNCERTAINTY_NORM_THRESH") |
| 1417 | if norm_val is not None and norm_val.strip() != "": |
| 1418 | return float(norm_val) * coord_span |
| 1419 | |
| 1420 | val = os.environ.get("LA_FLASH_COORD_UNCERTAINTY_THRESH") |
| 1421 | if val is None or val.strip() == "": |
| 1422 | return 20.0 |
| 1423 | threshold = float(val) |
| 1424 | if 0.0 < threshold <= 1.0: |
| 1425 | return threshold * coord_span |
| 1426 | return threshold |
| 1427 | |
| 1428 | |
| 1429 | def _decode_bbox_with_uncertainty(logits, probs, token_ids, keep_k=4, generation_mode="hybrid"): |
| 1430 | """Decode an MTP box with configurable coord uncertainty fallback. |
| 1431 | |
| 1432 | The default mode is the official LocateAnything rule. ``uncertainty`` keeps |
| 1433 | the same frame checks and top-k coord selection, but uses one scalar |
| 1434 | criterion per coordinate: the posterior RMSE of committing to the current |
| 1435 | MAP coordinate among valid coord candidates. This is the Bayes risk under |
| 1436 | squared coordinate error, so probabilities and token distances are folded |
| 1437 | into one threshold in coordinate-token units. |
| 1438 | """ |
| 1439 | gu = _gen_utils() |
| 1440 | mode = _coord_fallback_mode() |
| 1441 | if mode == "legacy" or generation_mode != "hybrid": |
| 1442 | return gu.decode_bbox_avg(logits, probs, token_ids, keep_k=keep_k, generation_mode=generation_mode) |
| 1443 | |
| 1444 | coord_start_token_id = token_ids["coord_start_token_id"] |
| 1445 | coord_end_token_id = token_ids["coord_end_token_id"] |
| 1446 | box_start_token_id = token_ids["box_start_token_id"] |
| 1447 | box_end_token_id = token_ids["box_end_token_id"] |
| 1448 | none_token_id = token_ids["none_token_id"] |
| 1449 | null_token_id = token_ids["null_token_id"] |
| 1450 | device = logits.device |
| 1451 | |
| 1452 | box_type = gu.is_valid_box_frame( |
| 1453 | probs, |
| 1454 | token_ids, |
| 1455 | start_thresh=_env_float("LA_FLASH_COORD_BOX_START_THRESH", 0.7), |
| 1456 | end_thresh=_env_float("LA_FLASH_COORD_BOX_END_THRESH", 0.2), |
| 1457 | topk=keep_k, |
| 1458 | ) |
| 1459 | if box_type == "empty_box": |
| 1460 | return torch.tensor([ |
| 1461 | box_start_token_id, |
| 1462 | none_token_id, |
| 1463 | box_end_token_id, |
| 1464 | null_token_id, |
| 1465 | null_token_id, |
| 1466 | null_token_id, |
| 1467 | ], dtype=torch.long, device=device) |
| 1468 | if box_type == "illegal_box": |
| 1469 | return None |
| 1470 | |
| 1471 | pos_probs, pos_ids = torch.topk(probs[1:5], k=keep_k, dim=-1) |
| 1472 | valid = (pos_ids >= coord_start_token_id) & (pos_ids <= coord_end_token_id) |
| 1473 | has_valid = valid.any(dim=-1) |
| 1474 | if not has_valid.all(): |
| 1475 | return None |
| 1476 | |
| 1477 | first_valid_idx = valid.long().argmax(dim=-1, keepdim=True) |
| 1478 | first_valid_ids = pos_ids.gather(-1, first_valid_idx).squeeze(-1) |
| 1479 | if mode == "off": |
| 1480 | final_coords = first_valid_ids |
| 1481 | else: |
| 1482 | valid_counts = valid.sum(dim=-1) |
| 1483 | valid_probs = torch.where(valid, pos_probs, torch.zeros_like(pos_probs)) |
| 1484 | valid_mass = valid_probs.sum(dim=-1).clamp_min(1e-12) |
| 1485 | weights = valid_probs / valid_mass.unsqueeze(-1) |
| 1486 | coord_values = (pos_ids - coord_start_token_id).to(dtype=torch.float32) |
| 1487 | map_coord = (first_valid_ids - coord_start_token_id).to(dtype=torch.float32) |
| 1488 | uncertainty = (weights * (coord_values - map_coord.unsqueeze(-1)).pow(2)).sum(dim=-1).sqrt() |
| 1489 | is_abnormal = ( |
| 1490 | (valid_counts > 1) |
| 1491 | & (uncertainty > _coord_uncertainty_threshold(coord_start_token_id, coord_end_token_id)) |
| 1492 | ) |
| 1493 | final_coords = torch.where(is_abnormal, torch.tensor(0, device=device), first_valid_ids) |
| 1494 | |
| 1495 | start_t = torch.tensor([box_start_token_id], dtype=final_coords.dtype, device=device) |
| 1496 | end_t = torch.tensor([box_end_token_id], dtype=final_coords.dtype, device=device) |
| 1497 | return torch.cat([start_t, final_coords, end_t]) |
| 1498 | |
| 1499 | |
| 1500 | def _apply_repetition_penalty_lowmem(logits, generated, repetition_penalty): |
| 1501 | """Apply the stock repetition penalty without allocating a [B, S, V] mask.""" |
| 1502 | if repetition_penalty == 1.0: |
| 1503 | return logits |
| 1504 | _, _, vocab_size = logits.shape |
| 1505 | for row in range(logits.shape[0]): |
| 1506 | valid_tokens = generated[row].unique() |
| 1507 | valid_tokens = valid_tokens[(valid_tokens >= 0) & (valid_tokens < vocab_size)] |
| 1508 | if valid_tokens.numel() == 0: |
| 1509 | continue |
| 1510 | row_logits = logits[row, :, valid_tokens] |
| 1511 | logits[row, :, valid_tokens] = torch.where( |
| 1512 | row_logits > 0, |
| 1513 | row_logits / repetition_penalty, |
| 1514 | row_logits * repetition_penalty, |
| 1515 | ) |
| 1516 | return logits |
| 1517 | |
| 1518 | |
| 1519 | def _finite_logit_bounds(dtype): |
| 1520 | finfo = torch.finfo(dtype) |
| 1521 | return finfo.min, finfo.max |
| 1522 | |
| 1523 | |
| 1524 | def _finite_logits(logits): |
| 1525 | if not logits.dtype.is_floating_point: |
| 1526 | logits = logits.float() |
| 1527 | min_val, max_val = _finite_logit_bounds(logits.dtype) |
| 1528 | return torch.nan_to_num(logits, nan=min_val, posinf=max_val, neginf=min_val) |
| 1529 | |
| 1530 | |
| 1531 | def _finite_logits_(logits): |
| 1532 | if not logits.dtype.is_floating_point: |
| 1533 | return logits.float() |
| 1534 | min_val, max_val = _finite_logit_bounds(logits.dtype) |
| 1535 | return logits.nan_to_num_(nan=min_val, posinf=max_val, neginf=min_val) |
| 1536 | |
| 1537 | |
| 1538 | def _top_p_logits_slice_(logits, top_p): |
| 1539 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) |
| 1540 | cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) |
| 1541 | sorted_indices_to_remove = cumulative_probs > top_p |
| 1542 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
| 1543 | sorted_indices_to_remove[..., 0] = False |
| 1544 | |
| 1545 | remove = torch.zeros_like(logits, dtype=torch.bool, device=logits.device) |
| 1546 | remove.scatter_(-1, sorted_indices, sorted_indices_to_remove) |
| 1547 | logits.masked_fill_(remove, torch.finfo(logits.dtype).min) |
| 1548 | return logits |
| 1549 | |
| 1550 | |
| 1551 | def _top_p_logits_(logits, top_p): |
| 1552 | """In-place nucleus filtering with bounded sort workspace. |
| 1553 | |
| 1554 | The MTP sampler uses logits shaped ``[B, 6, V]``. Top-p is independent for |
| 1555 | each row and each future position, so filtering one position at a time keeps |
| 1556 | the expensive sorted-index workspace at ``[B, V]`` instead of ``[B, 6, V]``. |
| 1557 | """ |
| 1558 | if logits.dim() == 3 and logits.shape[1] > 1: |
| 1559 | for pos in range(logits.shape[1]): |
| 1560 | _top_p_logits_slice_(logits[:, pos, :], top_p) |
| 1561 | return logits |
| 1562 | return _top_p_logits_slice_(logits, top_p) |
| 1563 | |
| 1564 | |
| 1565 | def _top_k_logits_(logits, top_k): |
| 1566 | """In-place top-k filtering mirroring generate_utils.top_k_logits.""" |
| 1567 | top_k = min(int(top_k), logits.size(-1)) |
| 1568 | threshold = torch.topk(logits, top_k)[0][..., -1, None] |
| 1569 | logits.masked_fill_(logits < threshold, torch.finfo(logits.dtype).min) |
| 1570 | return logits |
| 1571 | |
| 1572 | |
| 1573 | def _safe_probs(filtered_logits): |
| 1574 | """Softmax with CUDA-multinomial-safe cleanup and row-wise argmax fallback.""" |
| 1575 | filtered_logits = _finite_logits(filtered_logits) |
| 1576 | probs = torch.softmax(filtered_logits, dim=-1, dtype=torch.float32) |
| 1577 | probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0).clamp_min_(0.0) |
| 1578 | row_sum = probs.sum(dim=-1, keepdim=True) |
| 1579 | bad = (~torch.isfinite(row_sum)) | (row_sum <= 0) |
| 1580 | if bool(bad.any().item()): |
| 1581 | fallback = torch.zeros_like(probs) |
| 1582 | fallback.scatter_(-1, filtered_logits.argmax(dim=-1, keepdim=True), 1.0) |
| 1583 | probs = torch.where(bad, fallback, probs) |
| 1584 | row_sum = probs.sum(dim=-1, keepdim=True) |
| 1585 | return probs / row_sum.clamp_min(1.0e-20) |
| 1586 | |
| 1587 | |
| 1588 | def _sample_top_p_sorted_tokens(logits, top_p): |
| 1589 | """Sample from top-p filtered logits without scattering back to vocab order.""" |
| 1590 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) |
| 1591 | cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) |
| 1592 | remove = cumulative_probs > top_p |
| 1593 | remove[..., 1:] = remove[..., :-1].clone() |
| 1594 | remove[..., 0] = False |
| 1595 | sorted_logits.masked_fill_(remove, torch.finfo(sorted_logits.dtype).min) |
| 1596 | sorted_probs = _safe_probs(sorted_logits) |
| 1597 | sample_idx = sorted_probs.argmax(dim=-1) |
| 1598 | try: |
| 1599 | sample_idx = torch.distributions.Categorical(probs=sorted_probs).sample() |
| 1600 | except Exception: |
| 1601 | pass |
| 1602 | return sorted_indices.gather(-1, sample_idx.unsqueeze(-1)).squeeze(-1) |
| 1603 | |
| 1604 | |
| 1605 | @torch.no_grad() |
| 1606 | def sample_tokens_batched(logits, generated, token_ids, per_row_temp, |
| 1607 | repetition_penalty=1.0, top_p=None, top_k=None, |
| 1608 | keep_k_avg=4, generation_mode='fast'): |
| 1609 | """Batched fork of generate_utils.sample_tokens for the MTP window [B,6,V]. The logits pipeline |
| 1610 | (rep-penalty / per-row temperature / top_p / top_k / softmax / sample) is ROW-INDEPENDENT, so run |
| 1611 | it ONCE over the whole batch instead of B times on [1,6,V] (the per-row san defeats batching by |
| 1612 | slicing wlogits[b:b+1]). Only the variable-length box ASSEMBLY (decode_bbox_avg -> ragged shapes, |
| 1613 | where sample_tokens' final torch.stack throws) stays per-row, returned as a LIST. |
| 1614 | |
| 1615 | Equivalence to per-row san: every pipeline op reduces on dim=-1 only (never crosses the row dim), |
| 1616 | so row b's processed logits/probs are bit-identical to slicing first -> greedy (per_row_temp==0, |
| 1617 | argmax branch, no RNG) is BIT-EXACT. Under sampling, one batched Categorical changes the global |
| 1618 | RNG consumption order vs B per-row draws -> box-size jitter (blessed; greedy is the exact gate). |
| 1619 | apply_repetition_penalty already loops per-row internally, so passing the full [B,M] `generated` |
| 1620 | is row-correct. keep_k_avg/generation_mode mirror sample_tokens' decode_bbox_avg call EXACTLY |
| 1621 | (note: the per-row san passes keep_k=5 but decode_bbox_avg reads keep_k_avg, default 4 -- so 5 is |
| 1622 | a no-op there; we replicate keep_k_avg=4). Returns (x0[B,6], boxes: list of B 1-D LongTensors).""" |
| 1623 | gu = _gen_utils() |
| 1624 | B, S, V = logits.shape # S = N_FUTURE = 6 |
| 1625 | if repetition_penalty != 1.0: |
| 1626 | logits = _apply_repetition_penalty_lowmem(logits, generated, repetition_penalty) |
| 1627 | t = per_row_temp.to(dtype=logits.dtype).view(B, 1, 1) |
| 1628 | sample_rows = per_row_temp > 0 |
| 1629 | if bool(sample_rows.all().item()): |
| 1630 | logits.div_(t.clamp(min=1e-8)) |
| 1631 | elif bool(sample_rows.any().item()): |
| 1632 | idx = sample_rows.nonzero(as_tuple=True)[0] |
| 1633 | logits[idx].div_(t[idx].clamp(min=1e-8)) |
| 1634 | logits = _finite_logits_(logits) |
| 1635 | if top_p is not None and top_p < 1: |
| 1636 | logits = _top_p_logits_(logits, top_p) |
| 1637 | if top_k is not None and top_k > 0: |
| 1638 | logits = _top_k_logits_(logits, top_k) |
| 1639 | probs = _safe_probs(logits) |
| 1640 | x0 = probs.argmax(dim=-1) # [B,6]; greedy rows are final here |
| 1641 | samp = per_row_temp > 0 |
| 1642 | if bool(samp.any()): # sampling rows: ONE batched Categorical draw |
| 1643 | idx = samp.nonzero(as_tuple=True)[0] |
| 1644 | try: |
| 1645 | x0[idx] = gu.dists.Categorical(probs=probs[idx]).sample() |
| 1646 | except Exception: |
| 1647 | pass # keep argmax (matches san's except: probs.max) |
| 1648 | boxes = [] |
| 1649 | fallback = torch.zeros(1, dtype=x0.dtype, device=x0.device) |
| 1650 | for b in range(B): # variable-length box assembly (per-row, exact) |
| 1651 | db = _decode_bbox_with_uncertainty( |
| 1652 | logits[b], probs[b], token_ids, |
| 1653 | keep_k=keep_k_avg, generation_mode=generation_mode) |
| 1654 | if db is not None: |
| 1655 | boxes.append(db) |
| 1656 | else: |
| 1657 | ref = gu.decode_ref(logits[b], probs[b], token_ids) |
| 1658 | if ref is None: |
| 1659 | boxes.append(fallback) |
| 1660 | elif torch.is_tensor(ref): |
| 1661 | boxes.append(ref.to(dtype=x0.dtype, device=x0.device)) |
| 1662 | else: |
| 1663 | boxes.append(torch.tensor(ref, dtype=x0.dtype, device=x0.device)) |
| 1664 | return x0, boxes |
| 1665 | |
| 1666 | |
| 1667 | @torch.no_grad() |
| 1668 | def sample_next_tokens_batched(logits, generated, per_row_temp, |
| 1669 | repetition_penalty=1.0, top_p=None, top_k=None): |
| 1670 | """Batched one-token sampler for AR repair rows. |
| 1671 | |
| 1672 | This mirrors the row-independent part of ``sample_tokens`` for logits shaped |
| 1673 | ``[B,1,V]``. It intentionally does not run bbox/ref assembly because AR mode |
| 1674 | only needs the next token before the state machine classifies it. |
| 1675 | """ |
| 1676 | gu = _gen_utils() |
| 1677 | if logits.dim() != 3 or logits.shape[1] != 1: |
| 1678 | raise ValueError(f"AR batched sampler expects logits [B,1,V], got {tuple(logits.shape)}") |
| 1679 | B = int(logits.shape[0]) |
| 1680 | if repetition_penalty != 1.0: |
| 1681 | logits = _apply_repetition_penalty_lowmem(logits, generated, repetition_penalty) |
| 1682 | t = per_row_temp.to(dtype=logits.dtype).view(B, 1, 1) |
| 1683 | sample_rows = per_row_temp > 0 |
| 1684 | if bool(sample_rows.all().item()): |
| 1685 | logits.div_(t.clamp(min=1e-8)) |
| 1686 | elif bool(sample_rows.any().item()): |
| 1687 | idx = sample_rows.nonzero(as_tuple=True)[0] |
| 1688 | logits[idx].div_(t[idx].clamp(min=1e-8)) |
| 1689 | logits = _finite_logits_(logits) |
| 1690 | sorted_top_p = os.environ.get("AR_SORTED_TOPP", "0") == "1" |
| 1691 | default_top_p = sorted_top_p and top_p is not None and top_p < 1 and (top_k is None or top_k <= 0) |
| 1692 | if default_top_p and bool(sample_rows.all().item()): |
| 1693 | return _sample_top_p_sorted_tokens(logits, top_p) |
| 1694 | if top_p is not None and top_p < 1: |
| 1695 | logits = _top_p_logits_(logits, top_p) |
| 1696 | if top_k is not None and top_k > 0: |
| 1697 | logits = _top_k_logits_(logits, top_k) |
| 1698 | probs = _safe_probs(logits) |
| 1699 | x0 = probs.argmax(dim=-1) |
| 1700 | if bool(sample_rows.any().item()): |
| 1701 | # Keep row-ordered sampling as the release default. A single batched |
| 1702 | # Categorical is faster, but it consumes RNG differently from stock AR |
| 1703 | # repair and can alter default-temperature termination behavior. |
| 1704 | for row in sample_rows.nonzero(as_tuple=True)[0].tolist(): |
| 1705 | try: |
| 1706 | x0[row : row + 1] = gu.dists.Categorical(probs=probs[row : row + 1]).sample() |
| 1707 | except Exception: |
| 1708 | pass |
| 1709 | return x0 |
| 1710 | |
| 1711 | |
| 1712 | def load_pil(p): |
| 1713 | from PIL import Image |
| 1714 | im = Image.open(p).convert("RGB"); w, h = im.size |
| 1715 | if max(w, h) > MAX_DIM: |
| 1716 | s = MAX_DIM / max(w, h); im = im.resize((max(1, round(w*s)), max(1, round(h*s))), Image.LANCZOS) |
| 1717 | return im |
| 1718 | |
| 1719 | def _preproc_one(im): |
| 1720 | """CPU-side processor for one image -> (pixel_values[bf16], grid[int32]). Split out of |
| 1721 | _encode_image so _encode_images can batch the GPU encode while preprocessing stays per-image.""" |
| 1722 | tok, proc, model = load() |
| 1723 | msg = [{"role": "user", "content": [{"type": "image", "image": im}, {"type": "text", "text": "x"}]}] |
| 1724 | text = proc.py_apply_chat_template(msg, tokenize=False, add_generation_prompt=True) |
| 1725 | imgs, vids = proc.process_vision_info(msg) |
| 1726 | inp = proc(text=[text], images=imgs, videos=vids, return_tensors="pt").to(DEV) |
| 1727 | grid = inp.get("image_grid_hws") |
| 1728 | if isinstance(grid, np.ndarray): grid = torch.from_numpy(grid).to(DEV, dtype=torch.int32) |
| 1729 | return inp["pixel_values"].to(DT), grid |
| 1730 | |
| 1731 | |
| 1732 | def _vision_is_flash(): |
| 1733 | """True iff MoonViT will actually run flash_attn_varlen (so cross-image packing is |
| 1734 | block-diagonal = exact AND a win). If the vision blocks are on sdpa/eager, OR the flash |
| 1735 | wheel is absent (multihead_attention falls back to the dense-mask sdpa path), packing is |
| 1736 | O(S^2) N^2 -> caller must stay per-image.""" |
| 1737 | vm = load()[2].vision_model |
| 1738 | mod = importlib.import_module(type(vm).__module__) |
| 1739 | if getattr(mod, "flash_attn_varlen_func", None) is None: |
| 1740 | return False |
| 1741 | try: |
| 1742 | return vm.encoder.blocks[0].attn_implementation == "flash_attention_2" |
| 1743 | except Exception: |
| 1744 | return False |
| 1745 | |
| 1746 | |
| 1747 | @torch.no_grad() |
| 1748 | def _encode_images(ims): |
| 1749 | """N images -> list of [n_img_tokens, C] mlp1-projected visual_features, one per image |
| 1750 | (row-order). Drop-in for [_encode_image(im) for im in ims]. |
| 1751 | |
| 1752 | With flash present (_vision_is_flash) and N>1, packs images into |
| 1753 | extract_feature micro-batches: MoonViT's varlen cu_seqlens path is |
| 1754 | block-diagonal by image. Without flash, the dense SDPA fallback would scale |
| 1755 | with the packed total sequence length, so this function falls back to |
| 1756 | per-image encode. MTP_BATCH_VISION=0 also forces per-image encode.""" |
| 1757 | tok, proc, model = load() |
| 1758 | pvs, grids = [], [] |
| 1759 | for im in ims: |
| 1760 | pv, g = _preproc_one(im) |
| 1761 | pvs.append(pv); grids.append(g) |
| 1762 | if BATCH_VISION and len(ims) > 1 and _vision_is_flash(): |
| 1763 | if VISION_ENCODE_BATCH_SIZE <= 0 or VISION_ENCODE_BATCH_SIZE >= len(ims): |
| 1764 | vit_list = model.extract_feature(torch.cat(pvs, dim=0), torch.cat(grids, dim=0)) |
| 1765 | else: |
| 1766 | vit_list = [] |
| 1767 | for start in range(0, len(ims), VISION_ENCODE_BATCH_SIZE): |
| 1768 | end = min(start + VISION_ENCODE_BATCH_SIZE, len(ims)) |
| 1769 | vit_list.extend( |
| 1770 | model.extract_feature( |
| 1771 | torch.cat(pvs[start:end], dim=0), |
| 1772 | torch.cat(grids[start:end], dim=0), |
| 1773 | ) |
| 1774 | ) |
| 1775 | return [model.mlp1(v) for v in vit_list] # one [P_i, C] per image (patch_merger split) |
| 1776 | return [model.mlp1(torch.cat(model.extract_feature(pv, g), dim=0)) |
| 1777 | for pv, g in zip(pvs, grids)] # per-image (flash absent / N==1 / forced off) |
| 1778 | |
| 1779 | |
| 1780 | @torch.no_grad() |
| 1781 | def _encode_image(im): |
| 1782 | """Single-image convenience wrapper (single-image callers); = _encode_images([im])[0] |
| 1783 | (takes the per-image path inside _encode_images, so bit-identical to the original).""" |
| 1784 | return _encode_images([im])[0] |
| 1785 | |
| 1786 | @torch.no_grad() |
| 1787 | def _tokenize(im, query): |
| 1788 | """1-D prompt token ids for (image, query). Uses the model's own chat template.""" |
| 1789 | tok, proc, model = load() |
| 1790 | msg = [{"role": "user", "content": [{"type": "image", "image": im}, |
| 1791 | {"type": "text", "text": _PROMPT + query + "."}]}] |
| 1792 | text = proc.py_apply_chat_template(msg, tokenize=False, add_generation_prompt=True) |
| 1793 | imgs, vids = proc.process_vision_info(msg) |
| 1794 | return proc(text=[text], images=imgs, videos=vids, return_tensors="pt").to(DEV)["input_ids"][0] |
| 1795 | |
| 1796 | |
| 1797 | @torch.no_grad() |
| 1798 | def _tokenize_cached_image(query, image_token_count, im=None): |
| 1799 | """Tokenize a prompt when the image token count is already known. |
| 1800 | |
| 1801 | This keeps the processor's chat template, but directly expands ``<image-1>`` |
| 1802 | from the cached visual feature length. It avoids re-running the CPU image |
| 1803 | processor for every category prompt that shares the same image. |
| 1804 | """ |
| 1805 | tok, proc, model = load() |
| 1806 | msg = [{"role": "user", "content": [{"type": "image", "image": im}, |
| 1807 | {"type": "text", "text": _PROMPT + query + "."}]}] |
| 1808 | text = proc.py_apply_chat_template(msg, tokenize=False, add_generation_prompt=True) |
| 1809 | placeholder = f"<{getattr(proc, 'image_placeholder', 'image')}-1>" |
| 1810 | image_token = getattr(proc, "image_token", "<IMG_CONTEXT>") |
| 1811 | image_start = getattr(proc, "image_start_token", "<img>") |
| 1812 | image_end = getattr(proc, "image_end_token", "</img>") |
| 1813 | replacement = f"<image 1>{image_start}{image_token * int(image_token_count)}{image_end}" |
| 1814 | if placeholder not in text: |
| 1815 | raise ValueError(f"cached image placeholder {placeholder!r} was not found in chat template") |
| 1816 | text = text.replace(placeholder, replacement, 1) |
| 1817 | return tok([text], return_tensors="pt").to(DEV)["input_ids"][0] |
| 1818 | |
| 1819 | |
| 1820 | def _proc_full(im, query): |
| 1821 | """Full processor dict (input_ids, attention_mask, pixel_values, image_grid_hws) — |
| 1822 | used by the bench to drive the STOCK generate for the equivalence check.""" |
| 1823 | tok, proc, model = load() |
| 1824 | msg = [{"role": "user", "content": [{"type": "image", "image": im}, |
| 1825 | {"type": "text", "text": _PROMPT + query + "."}]}] |
| 1826 | text = proc.py_apply_chat_template(msg, tokenize=False, add_generation_prompt=True) |
| 1827 | imgs, vids = proc.process_vision_info(msg) |
| 1828 | inp = proc(text=[text], images=imgs, videos=vids, return_tensors="pt").to(DEV) |
| 1829 | grid = inp.get("image_grid_hws") |
| 1830 | if isinstance(grid, np.ndarray): grid = torch.from_numpy(grid).to(DEV, dtype=torch.int32) |
| 1831 | inp["image_grid_hws"] = grid |
| 1832 | return inp |
| 1833 | |
| 1834 | def _pad_generated(prompt_ids, gen_ids, img_tok, dev): |
| 1835 | """Per-row [prompt + accepted] left-padded with the image token (already in every |
| 1836 | prompt -> .unique() unchanged -> repetition penalty identical to single-run).""" |
| 1837 | rows = [list(prompt_ids[b].tolist()) + gen_ids[b] for b in range(len(prompt_ids))] |
| 1838 | M = max(len(r) for r in rows) |
| 1839 | out = torch.full((len(rows), M), img_tok, dtype=torch.long, device=dev) |
| 1840 | for b, r in enumerate(rows): |
| 1841 | out[b, M - len(r):] = torch.tensor(r, dtype=torch.long, device=dev) |
| 1842 | return out |
| 1843 | |