batch_utils/engine_hybrid.py
| 1 | """Batched hybrid-mode generation for LocateAnything-3B. |
| 2 | |
| 3 | This module keeps the stock hybrid state machine: |
| 4 | |
| 5 | MTP -> error_box -> AR |
| 6 | AR -> box_end_ar -> MTP |
| 7 | |
| 8 | Rows in a batch may be in different modes. The decode loop therefore stores |
| 9 | per-row KV caches, packs rows with the same mode for one forward call, then |
| 10 | unpacks the clean KV back per row. |
| 11 | """ |
| 12 | import copy |
| 13 | import importlib |
| 14 | import os |
| 15 | |
| 16 | import torch |
| 17 | |
| 18 | |
| 19 | from .hybrid_runtime import ( |
| 20 | ATTN_MODE, |
| 21 | AR_BATCH_SAN, |
| 22 | BATCH_SAN, |
| 23 | DEV, |
| 24 | N_FUTURE, |
| 25 | _encode_images, |
| 26 | _helpers, |
| 27 | _pad_generated, |
| 28 | _set_llm_mode, |
| 29 | _tokenize, |
| 30 | _tokenize_cached_image, |
| 31 | build_magi_scheduler_ranges, |
| 32 | language_model_forward, |
| 33 | load, |
| 34 | sample_next_tokens_batched, |
| 35 | sample_tokens_batched, |
| 36 | ) |
| 37 | |
| 38 | |
# Default sampling settings for hybrid generation.
# NOTE(review): the README_ prefix suggests these mirror the model README's
# recommended values -- confirm before changing them.
README_MAX_NEW_TOKENS: int = 2048
README_TEMPERATURE: float = 0.7
README_TOP_P: float = 0.9
README_REPETITION_PENALTY: float = 1.1

# Deep-copied stats of the most recent hybrid batch (None until one is stored);
# written by _set_last_hybrid_stats and read by get_last_hybrid_stats.
_LAST_HYBRID_STATS = None
| 45 | |
| 46 | |
def _row_len(kv):
    """Return the cached sequence length of a per-row legacy KV cache.

    ``kv`` is indexed as ``kv[layer] -> (key, value)``; the length is read
    from dim 2 of the first layer's key tensor (the axis that the packing
    helpers pad and concatenate along).
    """
    first_layer_keys = kv[0][0]
    return first_layer_keys.shape[2]
| 49 | |
| 50 | |
def _pack_stock_kv_rows(kv_rows, rows, dev):
    """Left-pad per-row real-token KV caches for stock-style decoding.

    Each selected row's cache is left-padded with zeros up to the longest
    cached length ``kmax``, and the rows are stacked along dim 0, layer by
    layer. A row whose cache is ``None`` (or empty) becomes an all-zero entry
    whose key-valid mask is all 0.

    Args:
        kv_rows: per-row legacy caches indexed ``kv[layer] -> (k, v)``. Each
            tensor is presumably (1, heads, seq, head_dim); the code pads dim 2
            and concatenates along dim 0.
        rows: indices into ``kv_rows`` to pack, in output order.
        dev: device for the returned key-valid mask.

    Returns:
        ``(packed, key_valid, lengths, kmax)``:
        - ``packed``: tuple of per-layer ``(k, v)`` batches, or ``None`` when
          no selected row has any cached tokens.
        - ``key_valid``: ``(len(rows), kmax)`` long mask, 1 on real positions.
        - ``lengths``: the per-row cached lengths.
        - ``kmax``: the maximum of ``lengths``.
    """
    lengths = [_row_len(kv_rows[r]) if kv_rows[r] is not None else 0 for r in rows]
    kmax = max(lengths, default=0)
    if not kmax:
        return None, torch.zeros((len(rows), 0), dtype=torch.long, device=dev), lengths, 0

    template = next(kv_rows[r] for r in rows if kv_rows[r] is not None)

    def layer_entry(row, length, layer, ref_k, ref_v):
        # A row without cache contributes a full-width block of zeros shaped
        # like the template row's tensors.
        if length == 0:
            return (
                ref_k.new_zeros((1, ref_k.shape[1], kmax, ref_k.shape[3])),
                ref_v.new_zeros((1, ref_v.shape[1], kmax, ref_v.shape[3])),
            )
        k, v = kv_rows[row][layer]
        if length >= kmax:
            return k, v
        # Left padding keeps real tokens right-aligned next to the new queries.
        pad = (1, k.shape[1], kmax - length, k.shape[3])
        return (
            torch.cat([k.new_zeros(pad), k], dim=2),
            torch.cat([v.new_zeros(pad), v], dim=2),
        )

    packed = []
    for layer in range(len(template)):
        ref_k, ref_v = template[layer]
        pairs = [layer_entry(r, n, layer, ref_k, ref_v) for r, n in zip(rows, lengths)]
        packed.append((
            torch.cat([k for k, _ in pairs], dim=0),
            torch.cat([v for _, v in pairs], dim=0),
        ))

    key_valid = torch.zeros((len(rows), kmax), dtype=torch.long, device=dev)
    for slot, n in enumerate(lengths):
        if n > 0:
            key_valid[slot, kmax - n:] = 1
    return tuple(packed), key_valid, lengths, kmax
| 82 | |
| 83 | |
def _unpack_stock_after_forward(out_kv, local_row, old_len, uncached_len, kmax, umax):
    """Keep old real KV plus the right-aligned uncached real tokens; drop pads/window.

    The batched forward produced, per layer, keys/values laid out as
    ``[left-padded old cache (kmax) | left-padded new tokens (umax) | ...]``.
    For batch row ``local_row`` this keeps:

    - the last ``old_len`` positions of the old-cache region, and
    - the last ``uncached_len`` positions of the new-token region,

    and discards padding plus anything after ``kmax + umax`` (e.g. an MTP
    window).

    Args:
        out_kv: iterable of per-layer ``(k, v)`` tensors with sequence on dim 2.
        local_row: batch index of the row to extract.
        old_len: number of real tokens previously cached for this row.
        uncached_len: number of real tokens forwarded for this row.
        kmax: width of the packed old-cache region.
        umax: width of the packed new-token region.

    Returns:
        Tuple of per-layer contiguous ``(k, v)`` with batch size 1. When both
        lengths are 0, each tensor has a zero-length sequence axis instead of
        raising from ``torch.cat([])``.
    """
    row = slice(local_row, local_row + 1)
    old_cols = slice(kmax - old_len, kmax)
    # New tokens are right-aligned inside their umax-wide region.
    new_cols = slice(kmax + (umax - uncached_len), kmax + umax)
    out = []
    for k, v in out_kv:
        parts_k, parts_v = [], []
        if old_len:
            parts_k.append(k[row, :, old_cols, :])
            parts_v.append(v[row, :, old_cols, :])
        if uncached_len:
            parts_k.append(k[row, :, new_cols, :])
            parts_v.append(v[row, :, new_cols, :])
        if not parts_k:
            # Nothing real to keep: emit an empty but correctly shaped slice.
            empty = slice(0, 0)
            parts_k.append(k[row, :, empty, :])
            parts_v.append(v[row, :, empty, :])
        out.append((torch.cat(parts_k, dim=2).contiguous(),
                    torch.cat(parts_v, dim=2).contiguous()))
    return tuple(out)
| 100 | |
| 101 | |
def _mk_generate_kwargs(temperature, top_p, top_k, repetition_penalty, row_temp=None):
    """Build the ``generate`` keyword arguments for hybrid decoding.

    ``row_temp`` (when not None) overrides ``temperature``. The sampling knobs
    ``temperature``, ``top_p`` and ``top_k`` are emitted only when the
    effective temperature is positive; ``None`` knobs are omitted.
    """
    effective_temp = row_temp if row_temp is not None else temperature
    kwargs = {"repetition_penalty": repetition_penalty, "generation_mode": "hybrid"}
    sampling = bool(effective_temp) and effective_temp > 0
    if not sampling:
        return kwargs
    kwargs["temperature"] = effective_temp
    for key, value in (("top_p", top_p), ("top_k", top_k)):
        if value is not None:
            kwargs[key] = value
    return kwargs
| 112 | |
| 113 | |
def _classify_ar_token(token_val, tids):
    """Map an AR-sampled token id to the next hybrid state.

    Returns ``"box_end_ar"`` for the box-end token, ``"coord_ar"`` for any
    token in the inclusive coordinate range or the "none" token, and
    ``"im_end"`` for everything else.
    """
    if token_val == tids["box_end_token_id"]:
        return "box_end_ar"
    in_coord_range = tids["coord_start_token_id"] <= token_val <= tids["coord_end_token_id"]
    if in_coord_range or token_val == tids["none_token_id"]:
        return "coord_ar"
    return "im_end"
| 122 | |
| 123 | |
def _env_flag(name, default=False):
    """Read a boolean flag from the environment.

    Returns ``default`` when ``name`` is unset. Otherwise the value is
    stripped and lower-cased: "0", "false", "no", "off" and the empty string
    mean False, and anything else means True. Stripping prevents stray
    whitespace (e.g. ``FOO=" 0"``) from turning an "off" value into True.
    """
    val = os.environ.get(name)
    if val is None:
        return default
    return val.strip().lower() not in {"0", "false", "no", "off", ""}
| 129 | |
| 130 | |
def _env_int(name, default):
    """Read an integer from the environment.

    Returns ``default`` when ``name`` is unset, empty, or whitespace-only.
    Raises ValueError if the stripped value is not an integer literal.
    """
    val = os.environ.get(name)
    if val is None:
        return default
    val = val.strip()
    if not val:
        return default
    return int(val)
| 136 | |
| 137 | |
def _kv_pack_token_budget():
    """Return LA_FLASH_KV_PACK_TOKEN_BUDGET clamped to >= 0 (0 disables KV-pack splitting)."""
    budget = _env_int("LA_FLASH_KV_PACK_TOKEN_BUDGET", 0)
    return budget if budget > 0 else 0
| 140 | |
| 141 | |
def _debug_enabled(debug):
    """Resolve the debug switch: an explicit ``debug`` wins, else the LA_FLASH_DEBUG env flag."""
    if debug is not None:
        return bool(debug)
    return _env_flag("LA_FLASH_DEBUG", False)
| 144 | |
| 145 | |
def _new_hybrid_stats(total_rows, scheduler, group_size, hold_max_steps, adaptive_hold_mtp_max=0):
    """Create a fresh statistics dict for one hybrid generation batch.

    Configuration fields are coerced to int, every counter starts at 0, the
    per-kind forward-size histograms start empty, and ``prompt_prefill_mode``
    is read from the environment via ``_hybrid_prefill_mode()``.
    """
    stats = {
        "scheduler": scheduler,
        "requested_group_size": int(group_size or 0),
        "hold_max_steps": int(hold_max_steps),
        "adaptive_hold_mtp_max": int(adaptive_hold_mtp_max),
        "input_batches": 1,
        "input_rows": int(total_rows),
        "groups": 0,
        "group_sizes": [],
    }
    # Scheduler-cycle and decode-forward counters.
    stats.update(dict.fromkeys((
        "decode_loops",
        "mixed_mode_cycles",
        "eager_mtp_then_ar_cycles",
        "ar_first_cycles",
        "pipeline_ar_after_mtp_cycles",
        "adaptive_hold_cycles",
        "adaptive_ar_first_cycles",
        "hold_ar_steps",
        "hold_ar_held_mtp_rows",
        "hold_ar_limit_mtp_forwards",
        "mtp_forwards",
        "ar_forwards",
        "mtp_forward_rows",
        "ar_forward_rows",
        "mtp_forward_query_tokens",
        "ar_forward_query_tokens",
        "max_mtp_forward_rows",
        "max_ar_forward_rows",
        "mtp_max_uncached_len",
        "ar_max_uncached_len",
    ), 0))
    # Histograms keyed by the stringified row count of each forward.
    stats["mtp_forward_row_hist"] = {}
    stats["ar_forward_row_hist"] = {}
    stats["prompt_prefill_mode"] = _hybrid_prefill_mode()
    # Prompt-prefill and KV-bucketing counters.
    stats.update(dict.fromkeys((
        "prompt_prefill_forwards",
        "prompt_prefill_forward_rows",
        "prompt_prefill_forward_query_tokens",
        "prompt_prefill_real_tokens",
        "prompt_prefill_shared_groups",
        "prompt_prefill_shared_rows",
        "prompt_prefill_shared_saved_tokens",
        "kv_bucket_splits",
        "kv_bucket_groups",
        "kv_bucket_max_packed_tokens",
    ), 0))
    return stats
| 190 | |
| 191 | |
def _set_last_hybrid_stats(stats):
    """Store a deep copy of ``stats`` as the module-level last-batch snapshot (None clears it)."""
    global _LAST_HYBRID_STATS
    if stats is None:
        _LAST_HYBRID_STATS = None
    else:
        _LAST_HYBRID_STATS = copy.deepcopy(stats)
| 195 | |
| 196 | |
def get_last_hybrid_stats():
    """Return scheduler/forward statistics from the most recent hybrid batch.

    The result is a deep copy, so callers may mutate it freely; it is ``None``
    when no snapshot has been stored.
    """
    snapshot = _LAST_HYBRID_STATS
    return copy.deepcopy(snapshot)
| 200 | |
| 201 | |
def _record_group_stats(stats, bsz):
    """Count one scheduled group and remember its size (no-op when ``stats`` is None)."""
    if stats is not None:
        stats["groups"] += 1
        stats["group_sizes"].append(int(bsz))
| 207 | |
| 208 | |
def _bump_hist(hist, val):
    """Increment ``hist`` at key ``str(int(val))`` (string keys keep the dict JSON-friendly)."""
    bucket = str(int(val))
    previous = hist.get(bucket, 0)
    hist[bucket] = int(previous) + 1
| 212 | |
| 213 | |
def _record_forward_stats(stats, kind, rows, q_len, uncached_lens):
    """Accumulate counters for one decode forward (no-op when ``stats`` is None).

    ``kind == "mtp"`` updates the ``mtp_*`` counters; any other kind updates
    the ``ar_*`` counters. Query tokens are counted as ``len(rows) * q_len``
    (padding included), and the running maxima and row-count histogram are
    refreshed.
    """
    if stats is None:
        return
    tag = "mtp" if kind == "mtp" else "ar"
    row_count = int(len(rows))
    query_len = int(q_len)
    stats[tag + "_forwards"] += 1
    stats[tag + "_forward_rows"] += row_count
    stats[tag + "_forward_query_tokens"] += row_count * query_len
    peak_rows_key = "max_" + tag + "_forward_rows"
    stats[peak_rows_key] = max(stats[peak_rows_key], row_count)
    peak_uncached_key = tag + "_max_uncached_len"
    stats[peak_uncached_key] = max(
        stats[peak_uncached_key],
        max((int(x) for x in uncached_lens), default=0),
    )
    _bump_hist(stats[tag + "_forward_row_hist"], row_count)
| 229 | |
| 230 | |
def _record_prefill_stats(stats, rows, q_len, real_tokens, shared_groups=0, shared_rows=0, saved_tokens=0):
    """Accumulate counters for one prompt-prefill forward (no-op when ``stats`` is None).

    ``rows * q_len`` is the padded query volume, ``real_tokens`` the unpadded
    count; the ``shared_*`` / ``saved_tokens`` arguments describe
    shared-prefix prefill.
    """
    if stats is None:
        return
    row_count = int(rows)
    increments = (
        ("prompt_prefill_forwards", 1),
        ("prompt_prefill_forward_rows", row_count),
        ("prompt_prefill_forward_query_tokens", row_count * int(q_len)),
        ("prompt_prefill_real_tokens", int(real_tokens)),
        ("prompt_prefill_shared_groups", int(shared_groups)),
        ("prompt_prefill_shared_rows", int(shared_rows)),
        ("prompt_prefill_shared_saved_tokens", int(saved_tokens)),
    )
    for key, delta in increments:
        stats[key] += delta
| 242 | |
| 243 | |
def _split_rows_by_kv_budget(rows, kv_rows):
    """Keep dense left-padded KV packs bounded when a few rows become long tails.

    Returns ``[rows]`` unchanged when LA_FLASH_KV_PACK_TOKEN_BUDGET is 0,
    when there is at most one row, or when ``max_len * len(rows)`` already
    fits the budget. Otherwise rows are stably sorted by cached length and
    packed greedily: a row starts a new bucket whenever adding it would push
    ``bucket_max_len * bucket_size`` above the budget. A single row longer
    than the budget still gets its own bucket.
    """
    budget = _kv_pack_token_budget()
    if len(rows) <= 1 or budget <= 0:
        return [rows]
    row_lens = [_row_len(kv_rows[r]) if kv_rows[r] is not None else 0 for r in rows]
    if not row_lens or len(rows) * max(row_lens) <= budget:
        return [rows]

    buckets = []
    bucket = []
    bucket_peak = 0
    for row, length in sorted(zip(rows, row_lens), key=lambda pair: pair[1]):
        length = int(length)
        grown_peak = max(bucket_peak, length)
        if bucket and grown_peak * (len(bucket) + 1) > budget:
            buckets.append(bucket)
            bucket, bucket_peak = [row], length
            continue
        bucket.append(row)
        bucket_peak = grown_peak
    if bucket:
        buckets.append(bucket)
    return buckets if buckets else [rows]
| 268 | |
| 269 | |
def _record_kv_bucket_stats(stats, groups, kv_rows):
    """Record how a KV-pack split turned out (no-op when ``stats`` is None).

    Tracks the largest dense pack (``group_kmax * group_size``) seen so far.
    When the rows were split into more than one group, it also bumps the
    split counter and adds the number of groups.
    """
    if stats is None:
        return
    largest_pack = 0
    for group in filter(None, groups):
        group_kmax = max(_row_len(kv_rows[r]) if kv_rows[r] is not None else 0 for r in group)
        largest_pack = max(largest_pack, int(group_kmax) * len(group))
    stats["kv_bucket_max_packed_tokens"] = max(stats["kv_bucket_max_packed_tokens"], largest_pack)
    if len(groups) > 1:
        stats["kv_bucket_splits"] += 1
        stats["kv_bucket_groups"] += len(groups)
| 283 | |
| 284 | |
def _hybrid_scheduler(scheduler):
    """Normalize the scheduler name to eager / hold_ar / ar_first / pipeline / adaptive.

    When ``scheduler`` is None, LA_FLASH_HYBRID_SCHEDULER is used (default
    ``"eager"``). Matching is case-insensitive and whitespace-tolerant, and
    common aliases are accepted. Raises ValueError for unknown names.
    """
    raw = scheduler if scheduler is not None else os.environ.get("LA_FLASH_HYBRID_SCHEDULER", "eager")
    name = str(raw).strip().lower()
    spellings = {
        "eager": ("", "default", "normal"),
        "hold_ar": ("hold", "hold-ar", "hold_mtp", "hold-mtp"),
        "ar_first": ("repair_first", "repair-first", "ar-first"),
        "pipeline": (),
        "adaptive": (),
    }
    for canonical, aliases in spellings.items():
        if name == canonical or name in aliases:
            return canonical
    raise ValueError("scheduler must be one of: eager, hold_ar, ar_first, pipeline, adaptive")
| 304 | |
| 305 | |
def _hybrid_group_size(group_size):
    """Resolve the hybrid group size as a non-negative int.

    An explicit ``group_size`` wins over LA_FLASH_HYBRID_GROUP_SIZE (default
    0); negative values clamp to 0. NOTE(review): 0 presumably means "no
    grouping" at the call site -- confirm there.
    """
    requested = _env_int("LA_FLASH_HYBRID_GROUP_SIZE", 0) if group_size is None else int(group_size)
    return requested if requested > 0 else 0
| 310 | |
| 311 | |
def _hybrid_prefill_mode():
    """Resolve LA_FLASH_HYBRID_PREFILL to one of none / per_row / batch / shared.

    Unset defaults to ``"shared"``. Matching is case-insensitive and
    whitespace-tolerant, and aliases are accepted. Raises ValueError for
    unknown values.
    """
    canonical = {
        "none": ("0", "false", "off", "legacy"),
        "per_row": ("1", "true", "on", "single", "row", "rows"),
        "batch": ("batched",),
        "shared": ("prefix", "shared_prefix", "shared-image", "shared_image", "vision"),
    }
    requested = os.environ.get("LA_FLASH_HYBRID_PREFILL", "shared").strip().lower()
    for mode, aliases in canonical.items():
        if requested == mode or requested in aliases:
            return mode
    raise ValueError("LA_FLASH_HYBRID_PREFILL must be one of none, per_row, batch, shared")
| 336 | |
| 337 | |
def _tolist(t):
    """Detach ``t``, move it to the CPU and return it as (nested) Python lists."""
    host = t.detach().cpu()
    return host.tolist()
| 340 | |
| 341 | |
def _safe_decode_rows(tok, input_ids):
    """Decode each row of ``input_ids`` (iterated along dim 0), keeping special tokens.

    A row whose decode raises becomes ``"<decode failed>"``.
    """
    return [_safe_decode_row(tok, row) for row in _tolist(input_ids)]
| 350 | |
| 351 | |
def _safe_decode_row(tok, row):
    """Decode one id sequence with special tokens kept; return "<decode failed>" on any error."""
    try:
        text = tok.decode(torch.tensor(row), skip_special_tokens=False)
    except Exception:
        text = "<decode failed>"
    return text
| 357 | |
| 358 | |
def _effective_allowed_mask(mask2d, q_len, past_len, mtp_window=False):
    """Readable 1/0 q-by-k mask derived from the 2D key-valid mask.

    Query ``qi`` sits at absolute position ``past_len + qi`` and may see key
    ``ki`` when ``mask2d[ki]`` is truthy and ``ki <= past_len + qi`` (causal +
    padding columns). With ``mtp_window`` (and at least ``N_FUTURE`` queries
    and keys), the MTP window update is applied on top:
    attn[-block:, -block:] = visible and attn[-block:, -block-1] = masked.

    Returns a list of ``q_len`` rows, each a list of 0/1 ints over all keys.
    """
    valid = mask2d.detach().cpu().bool()
    n_keys = int(valid.numel())
    grid = [
        [1 if bool(valid[ki]) and ki <= past_len + qi else 0 for ki in range(n_keys)]
        for qi in range(q_len)
    ]

    if mtp_window and q_len >= N_FUTURE and n_keys >= N_FUTURE:
        win_k = n_keys - N_FUTURE
        for qi in range(q_len - N_FUTURE, q_len):
            # Window queries see the whole window ...
            grid[qi][win_k:] = [1] * N_FUTURE
            # ... but never the key immediately before it.
            if win_k >= 1:
                grid[qi][win_k - 1] = 0
    return grid
| 385 | |
| 386 | |
def _tail_matrix(mat, rows=None, cols=None):
    """Return the bottom-right corner of a list-of-lists matrix.

    Keeps the last ``rows`` rows and, within each kept row, the last ``cols``
    entries; ``None`` leaves that axis untouched. A count of 0 (or less)
    yields an empty selection. This is handled explicitly because a plain
    ``mat[-0:]`` would return the whole axis, since ``-0 == 0``.
    """
    if rows is not None:
        mat = mat[-rows:] if rows > 0 else mat[:0]
    if cols is not None:
        mat = [row[-cols:] if cols > 0 else row[:0] for row in mat]
    return mat
| 393 | |
| 394 | |
def _format_01_matrix(mat):
    """Render a 0/1 matrix as text: one line per row, each entry preceded by a single space."""
    lines = []
    for row in mat:
        cells = [str(int(v)) for v in row]
        lines.append(" " + " ".join(cells))
    return "\n".join(lines)
| 397 | |
| 398 | |
def _safe_sdpa_mask_enabled():
    """Whether the NaN-safe 4D SDPA mask may be used (env LA_FLASH_SDPA_SAFE_4D_MASK, default on)."""
    enabled = _env_flag("LA_FLASH_SDPA_SAFE_4D_MASK", True)
    return enabled
| 401 | |
| 402 | |
def _build_safe_sdpa_visible_mask(attention_mask_2d, input_ids, past_len, mtp_window=False):
    """Build a 4D 1/0 visible mask, with harmless visibility for all-masked pad queries.

    The remote Qwen2 SDPA path uses a 2D key-valid mask and can create fully
    masked query rows for left-padded, no-cache prefill. Those rows can produce
    NaNs inside SDPA and later contaminate real tokens through masked K columns.
    This 4D mask keeps real-token visibility identical, and only gives otherwise
    all-masked query rows one valid fallback key so their activations stay finite.

    Args:
        attention_mask_2d: (bsz, key_len) key-valid mask; nonzero marks a real key.
        input_ids: (bsz, q_len) query ids; only their shape and device are used.
        past_len: number of keys preceding the queries, so query ``i`` sits at
            absolute key position ``past_len + i``.
        mtp_window: when True (and q_len, key_len >= N_FUTURE), the last
            N_FUTURE queries see every real key of the last N_FUTURE positions
            and never the key immediately before that window.

    Returns:
        ``(mask, fallback_rows)``: ``mask`` is a (bsz, 1, q_len, key_len)
        bfloat16 tensor with 1 = visible / 0 = hidden, tagged with the
        ``_la_flash_visible_mask`` attribute. ``fallback_rows`` is how many
        query rows had no visible key and received a fallback key.
    """
    bsz, q_len = int(input_ids.shape[0]), int(input_ids.shape[1])
    key_len = int(attention_mask_2d.shape[1])
    dev = input_ids.device
    key_valid = attention_mask_2d.to(dtype=torch.bool, device=dev)
    key_idx = torch.arange(key_len, device=dev).view(1, 1, key_len)
    # Absolute key position of each query (cached keys come first).
    q_abs = (past_len + torch.arange(q_len, device=dev)).view(1, q_len, 1)
    # (bsz, q_len, key_len): key is real AND not in the query's future.
    visible = key_valid[:, None, :] & (key_idx <= q_abs)

    if mtp_window and q_len >= N_FUTURE and key_len >= N_FUTURE:
        k0 = key_len - N_FUTURE
        # Window queries see every real window key, bidirectionally ...
        visible[:, -N_FUTURE:, k0:key_len] = key_valid[:, None, k0:key_len]
        blocked_k = k0 - 1
        if blocked_k >= 0:
            # ... but never the key immediately preceding the window.
            visible[:, -N_FUTURE:, blocked_k] = False

    row_has_key = visible.any(dim=-1)
    fallback_rows = int((~row_has_key).sum().item())
    if fallback_rows:
        for b in range(bsz):
            valid = torch.nonzero(key_valid[b], as_tuple=False).flatten()
            # First real key of the row, or key 0 when the row has none.
            fallback = int(valid[0].item()) if valid.numel() else 0
            missing = torch.nonzero(~row_has_key[b], as_tuple=False).flatten()
            if missing.numel():
                visible[b, missing, fallback] = True

    mask = visible[:, None, :, :].to(dtype=torch.bfloat16)
    try:
        # Tag the tensor so downstream code can tell a 1/0 visible mask from
        # an additive one; the consumer of this tag is outside this file.
        mask._la_flash_visible_mask = True
    except Exception:
        pass
    return mask, fallback_rows
| 443 | |
| 444 | |
def _mask_desc(mask):
    """Short label for the kind of attention mask passed to a forward (for debug logs)."""
    if mask is None:
        return "none"
    if isinstance(mask, dict):
        return "magi_ranges"
    if not hasattr(mask, "dim"):
        return type(mask).__name__
    return "4d_safe_sdpa" if mask.dim() == 4 else "2d_key_valid"
| 453 | |
| 454 | |
def _forward_attention_mask(model, input_ids, attention_mask_2d, past_len, mtp_window=False, range_plan=False):
    """Choose the attention-mask object for one language-model forward.

    Resolution order:

    1. If the model requested the ``magi`` / ``la_flash`` attention path
       (``model._la_flash_requested_attn``, falling back to ``ATTN_MODE``),
       build a range plan with ``build_magi_scheduler_ranges`` and return it
       when one is produced.
    2. On an SDPA backend with the safe mask enabled, for a multi-row,
       no-cache (``past_len == 0``) forward with a 2D key-valid mask, return
       the NaN-safe 4D mask from ``_build_safe_sdpa_visible_mask``.
    3. Otherwise pass ``attention_mask_2d`` through unchanged.

    Args:
        model: wrapper exposing ``language_model.model``.
        input_ids: (bsz, q_len) query ids.
        attention_mask_2d: key-valid mask, or None.
        past_len: number of cached keys before the queries.
        mtp_window: whether the last N_FUTURE queries form the MTP window.
        range_plan: ignored; accepted only for backward compatibility. The
            plan is always (re)built in step 1.

    Returns:
        ``(mask, fallback_rows)``. ``fallback_rows`` is the number of query
        rows that received a fallback key, and is 0 outside step 2.
    """
    llm = model.language_model.model
    if getattr(model, "_la_flash_requested_attn", ATTN_MODE) in {"magi", "la_flash"}:
        plan = build_magi_scheduler_ranges(
            model, attention_mask_2d, input_ids, past_len, mtp_window=mtp_window)
        # Without a plan, fall through to the tensor-mask paths below.
        if plan is not None:
            return plan, 0

    # Left-padded multi-row prefill can leave pad queries with no visible key,
    # which SDPA turns into NaNs; only that case needs the safe 4D mask.
    needs_safe_pad = (
        past_len == 0
        and attention_mask_2d is not None
        and attention_mask_2d.dim() == 2
        and input_ids.shape[0] > 1
    )
    if (
        getattr(llm, "_attn_implementation", None) == "sdpa"
        and _safe_sdpa_mask_enabled()
        and needs_safe_pad
    ):
        return _build_safe_sdpa_visible_mask(attention_mask_2d, input_ids, past_len, mtp_window)
    return attention_mask_2d, 0
| 477 | |
| 478 | |
def _actual_sdpa_allowed_masks(model, input_ids, attention_mask, past_len):
    """Recreate the remote Qwen2 SDPA 4D additive mask and return a 0/1 view.

    The mask helpers (``_prepare_4d_causal_attention_mask`` and
    ``update_causal_mask_for_one_gen_window_2d``) are looked up on the module
    that defines the decoder class. This therefore raises if that module
    lacks them.

    Args:
        model: wrapper exposing ``language_model.model``.
        input_ids: (bsz, q_len) query ids.
        attention_mask: mask handed to ``_prepare_4d_causal_attention_mask``;
            presumably the 2D key-valid mask -- TODO confirm at call sites.
        past_len: number of cached keys before the queries.

    Returns:
        ``(allowed, shape, remote_ar_decode)``:
        - ``allowed``: nested list indexed ``[batch][query][key]``, with 1
          where the additive mask entry is >= 0.
        - ``shape``: the 4D mask shape.
        - ``remote_ar_decode``: True when the window update was skipped.
    """
    llm = model.language_model.model
    mod = importlib.import_module(type(llm).__module__)
    bsz, q_len = int(input_ids.shape[0]), int(input_ids.shape[1])
    # Placeholder passed where the mask builder expects an embeddings-like
    # tensor -- presumably only its dtype/device/shape matter; verify upstream.
    dummy = torch.empty(
        (bsz, q_len, 1),
        dtype=torch.bfloat16,
        device=input_ids.device,
    )
    mask4 = mod._prepare_4d_causal_attention_mask(
        attention_mask,
        (bsz, q_len),
        dummy,
        past_len,
        sliding_window=getattr(llm.config, "sliding_window", None),
    )
    # Decode-vs-window heuristic: a single-token forward, or row 0's last
    # token not being the text mask token, is treated as AR decode.
    # NOTE(review): only row 0 is inspected. ``input_ids is not None`` is
    # always true here because input_ids.shape was already read above.
    remote_ar_decode = q_len == 1 or (
        input_ids is not None and int(input_ids[0, -1].item()) != int(llm.text_mask_token_id)
    )
    if not remote_ar_decode and mask4 is not None and mask4.dim() == 4:
        # Apply the remote per-row MTP window update to each row's (q, k) mask.
        rows = []
        for b in range(bsz):
            rows.append(
                mod.update_causal_mask_for_one_gen_window_2d(
                    input_ids[b],
                    mask4[b][0].clone(),
                    block_size=int(llm.block_size),
                    use_cache=True,
                    causal_attn=bool(getattr(llm, "causal_attn", False)),
                ).unsqueeze(0)
            )
        mask4 = torch.stack(rows, dim=0)
    # Additive mask: 0 means allowed, negative (large) means blocked.
    allowed = (mask4[:, 0] >= 0).to(torch.int8).detach().cpu().tolist()
    return allowed, tuple(mask4.shape), remote_ar_decode
| 514 | |
| 515 | |
def _debug_magi_ranges(q_len, past_len, mtp_window=False):
    """Describe one row's attention as magi-style q/k range lists (debug only).

    Ranges are half-open ``[start, end)``. Query ranges are local to the
    ``q_len`` forwarded tokens; key ranges index the full
    ``past_len + q_len`` key sequence. NOTE(review): this is a readable
    template for debug output and is not checked against
    ``build_magi_scheduler_ranges`` here.

    Args:
        q_len: number of query tokens in the forward.
        past_len: number of already-cached keys.
        mtp_window: False -> AR decode (one CAUSAL range over all keys);
            True -> the last ``N_FUTURE`` queries form the MTP window.

    Returns:
        ``{"q_ranges", "k_ranges", "attn_type_map"}``, or ``{"error": ...}``
        when the MTP window does not fit into ``q_len`` / ``kv_len``.
    """
    kv_len = past_len + q_len
    ar_decode = not mtp_window
    if ar_decode:
        return {
            "q_ranges": [[0, q_len]],
            "k_ranges": [[0, kv_len]],
            "attn_type_map": ["CAUSAL"],
        }

    block = N_FUTURE
    if not (0 < block <= q_len <= kv_len):
        return {"error": f"invalid magi MTP shape: block={block}, q_len={q_len}, kv_len={kv_len}"}

    # The window occupies the last `block` keys; the key just before it
    # (blocked_k) is hidden from the window queries.
    prefix_len = kv_len - block
    blocked_k = prefix_len - 1
    q_ranges, k_ranges, attn_types = [], [], []
    if q_len == kv_len:
        # No cache: every position is a query, so local == global indices.
        if prefix_len > 0:
            # Prefix tokens use plain causal attention among themselves.
            q_ranges.append([0, prefix_len])
            k_ranges.append([0, prefix_len])
            attn_types.append("CAUSAL")
        if prefix_len > 0 and blocked_k > 0:
            # Window sees prefix keys [0, blocked_k), excluding blocked_k.
            q_ranges.append([prefix_len, kv_len])
            k_ranges.append([0, blocked_k])
            attn_types.append("FULL")
        # Window attends to itself bidirectionally.
        q_ranges.append([prefix_len, kv_len])
        k_ranges.append([prefix_len, kv_len])
        attn_types.append("FULL")
    else:
        # With cache: the first `recompute` queries are re-forwarded tokens.
        # Each gets its own causal row, written as a FULL range over keys [0, g].
        recompute = q_len - block
        q_global_start = kv_len - q_len
        for i in range(recompute):
            g = q_global_start + i
            q_ranges.append([i, i + 1])
            k_ranges.append([0, g + 1])
            attn_types.append("FULL")
        q_win = [recompute, q_len]
        if blocked_k > 0:
            # Window sees keys [0, blocked_k), excluding blocked_k.
            q_ranges.append(q_win)
            k_ranges.append([0, blocked_k])
            attn_types.append("FULL")
        # Window attends to itself bidirectionally.
        q_ranges.append(q_win)
        k_ranges.append([prefix_len, kv_len])
        attn_types.append("FULL")

    return {"q_ranges": q_ranges, "k_ranges": k_ranges, "attn_type_map": attn_types}
| 563 | |
| 564 | |
def _print_debug_forward(label, model, tok, input_ids, attention_mask, position_ids,
                         past_len, mtp_window=False, extra=None, attention_impl="sdpa"):
    """Print a diagnostic dump of one forward's inputs and attention masks.

    For each batch row this prints the tail (last LA_FLASH_DEBUG_TAIL
    entries, default 15) of the input ids, decoded text and position ids,
    together with the expected 1/0 mask from ``_effective_allowed_mask``.
    For ``attention_impl`` in {"sdpa", "eager", "la_flash"} it also rebuilds
    the remote SDPA mask via ``_actual_sdpa_allowed_masks`` and prints it
    with a mismatch count; a failure there is printed, not raised. For
    ``attention_impl == "magi"`` a single-row range template from
    ``_debug_magi_ranges`` is printed.

    Args:
        label: heading text for this dump.
        model: wrapper exposing ``language_model.model``.
        tok: tokenizer used to decode the id tails.
        input_ids: (bsz, q_len) query ids.
        attention_mask: 2D key-valid mask (indexed ``[b]`` and ``.shape[1]``);
            a dict range plan would fail here -- TODO confirm callers pass the
            2D mask.
        position_ids: position ids, printed per row.
        past_len: number of cached keys before the queries.
        mtp_window: whether the last N_FUTURE queries form the MTP window.
        extra: optional mapping printed as ``key: value`` lines first.
        attention_impl: backend label that selects the extra checks.
    """
    print(f"\n========== LA Flash DEBUG {label} ==========", flush=True)
    if extra:
        for k, v in extra.items():
            print(f"{k}: {v}", flush=True)
    # NOTE(review): LA_FLASH_DEBUG_TAIL=0 makes the ``[-tail:]`` slices below
    # select whole rows, because -0 == 0.
    tail = int(os.environ.get("LA_FLASH_DEBUG_TAIL", "15"))
    bsz, q_len = int(input_ids.shape[0]), int(input_ids.shape[1])
    key_len = int(attention_mask.shape[1])
    q_tail, k_tail = min(tail, q_len), min(tail, key_len)
    print(
        "shapes: "
        f"input_ids={tuple(input_ids.shape)} "
        f"position_ids={tuple(position_ids.shape)} "
        f"attention_mask_key_valid={tuple(attention_mask.shape)} "
        f"mask_2d_q_by_k=({bsz}, {q_len}, {key_len}) "
        f"mask_2d_tail=({bsz}, {q_tail}, {k_tail}) "
        f"past_len={past_len} q_len={q_len} "
        f"mtp_window={mtp_window} ar_decode={not mtp_window}",
        flush=True,
    )
    print(f"dtypes/devices: input_ids={input_ids.dtype}@{input_ids.device} position_ids={position_ids.dtype}@{position_ids.device} attention_mask={attention_mask.dtype}@{attention_mask.device}", flush=True)
    print(f"attention_impl={attention_impl}", flush=True)
    input_rows = _tolist(input_ids)
    pos_rows = _tolist(position_ids)
    print(f"tail_window_last={tail}", flush=True)
    print(f"input_ids_tail.shape=({bsz}, {q_tail})", flush=True)
    print(f"position_ids_tail.shape=({bsz}, {q_tail})", flush=True)
    actual_sdpa = None
    if attention_impl in {"sdpa", "eager", "la_flash"}:
        # Best effort: the remote mask helpers may be missing or fail.
        try:
            actual_sdpa = _actual_sdpa_allowed_masks(model, input_ids, attention_mask, past_len)
            print(
                f"actual_sdpa_4d_mask_shape={actual_sdpa[1]} "
                f"remote_ar_decode={actual_sdpa[2]}",
                flush=True,
            )
        except Exception as e:
            print(f"actual_sdpa_4d_mask_debug_failed={type(e).__name__}: {e}", flush=True)

    for b in range(input_ids.shape[0]):
        ids_tail = input_rows[b][-tail:]
        pos_tail = pos_rows[b][-tail:]
        allowed = _effective_allowed_mask(attention_mask[b], input_ids.shape[1], past_len, mtp_window)
        # Per-row tail sizes; these shadow the batch-level values printed above.
        q_tail = min(tail, len(allowed))
        k_tail = min(tail, len(allowed[0]) if allowed else 0)
        allowed_tail = _tail_matrix(allowed, rows=q_tail, cols=k_tail)
        print(f"batch_row={b} ar_decode={not mtp_window}", flush=True)
        print(f"input_ids_tail[-{tail}:]: {ids_tail}", flush=True)
        print(f"decoded_tail[-{tail}:]: {_safe_decode_row(tok, ids_tail)}", flush=True)
        print(f"position_ids_tail[-{tail}:]: {pos_tail}", flush=True)
        print(f"expected_mask_2d_tail[-{q_tail}:,-{k_tail}:].shape=({q_tail}, {k_tail})", flush=True)
        print(_format_01_matrix(allowed_tail), flush=True)
        if actual_sdpa is not None:
            actual = actual_sdpa[0][b]
            actual_tail = _tail_matrix(actual, rows=q_tail, cols=k_tail)
            # Full-matrix comparison. NOTE(review): this assumes the actual
            # mask is at least as large as the expected one; a smaller one
            # would raise IndexError outside the try above.
            mismatch = sum(
                int(allowed[qi][ki] != actual[qi][ki])
                for qi in range(len(allowed))
                for ki in range(len(allowed[qi]))
            )
            print(
                f"actual_sdpa_mask_2d_tail[-{q_tail}:,-{k_tail}:].shape=({q_tail}, {k_tail})",
                flush=True,
            )
            print(_format_01_matrix(actual_tail), flush=True)
            print(f"expected_vs_actual_sdpa_mismatch_count={mismatch}", flush=True)

    if _env_flag("LA_FLASH_DEBUG_FULL_MASK", False):
        # Opt-in: dump every row's full expected mask (can be very large).
        masks = [
            _effective_allowed_mask(attention_mask[b], input_ids.shape[1], past_len, mtp_window)
            for b in range(input_ids.shape[0])
        ]
        print("effective_allowed_mask_q_by_k_FULL:", masks, flush=True)
    if attention_impl == "magi":
        if bsz == 1:
            print(
                "magi_ranges:",
                _debug_magi_ranges(input_ids.shape[1], past_len, mtp_window),
                flush=True,
            )
        else:
            print(
                "magi_ranges: built once per forward from the batched scheduler mask",
                flush=True,
            )
            print(
                "magi_ranges_single_row_template:",
                _debug_magi_ranges(input_ids.shape[1], past_len, mtp_window),
                flush=True,
            )
| 656 | |
| 657 | |
def _common_prefix_len(prompt_ids, rows):
    """Return the length of the token prefix shared by ``prompt_ids[r]`` for all r in ``rows``.

    ``prompt_ids`` holds 1-D token id tensors. The comparison is vectorized:
    each row's first ``max_len`` tokens are compared element-wise with the
    first row's, instead of one ``.item()`` host sync per token per row.

    Args:
        prompt_ids: sequence of 1-D integer tensors, indexed by row.
        rows: row indices to compare; empty yields 0.

    Returns:
        Number of leading positions where every listed row has the same
        token. This is at most the shortest listed row's length, and equals
        the full length when ``rows`` has a single element.
    """
    if not rows:
        return 0
    first = prompt_ids[rows[0]]
    max_len = min(int(prompt_ids[r].numel()) for r in rows)
    if max_len == 0:
        return 0
    ref = first[:max_len]
    same = torch.ones(max_len, dtype=torch.bool, device=ref.device)
    for r in rows[1:]:
        # Align devices so rows living elsewhere still compare correctly.
        same &= prompt_ids[r][:max_len].to(ref.device) == ref
    mismatches = torch.nonzero(~same, as_tuple=False).flatten()
    return int(mismatches[0].item()) if mismatches.numel() else max_len
| 671 | |
| 672 | |
def _prefill_shared_prefix_kv_rows(model, prompt_ids, vit_list, img_tok, pad, dev, stats=None, debug=False):
    """Cache one common prompt prefix per image-feature group.

    Multi-category split repeats the same image feature tensor for each
    category prompt. Token ids are identical through the image tokens and the
    fixed prompt prefix, so we prefill that shared prefix once and let each
    category row forward only its text suffix.

    Rows are grouped by the identity (``id()``) of their ``vit_list`` entry,
    so only rows sharing the very same feature tensor object are merged. A
    group is prefilled only when:

    - it has at least two rows,
    - its common token prefix is at least LA_FLASH_SHARED_PREFILL_MIN_PREFIX
      tokens (default 64, never below 1), and
    - that prefix contains exactly as many ``img_tok`` tokens as the group's
      feature tensor has rows along dim 0.

    All qualifying groups are prefilled together in one left-padded forward.

    Args:
        model: wrapper exposing ``language_model.model``.
        prompt_ids: per-row token id tensors -- assumed 1-D, TODO confirm.
        vit_list: per-row visual feature tensors.
        img_tok: image placeholder token id.
        pad: token id used for left padding.
        dev: device for the batched tensors.
        stats: optional hybrid stats dict, updated via ``_record_prefill_stats``.
        debug: print grouping and fallback diagnostics when True.

    Returns:
        ``(kv_rows, cached_lens)``. Rows of a prefilled group get the sliced
        prefix KV cache and its length; every other row gets ``None`` / ``0``.
        Rows of one group share the same KV tuple object.
    """
    bsz = len(prompt_ids)
    kv_rows = [None] * bsz
    cached_lens = [0] * bsz
    groups = {}
    for row, vit in enumerate(vit_list):
        # Identity grouping: rows reusing one feature tensor object.
        groups.setdefault(id(vit), []).append(row)

    items = []
    min_prefix_len = max(1, _env_int("LA_FLASH_SHARED_PREFILL_MIN_PREFIX", 64))
    for rows in groups.values():
        if len(rows) < 2:
            continue
        prefix_len = _common_prefix_len(prompt_ids, rows)
        if prefix_len < min_prefix_len:
            continue
        prefix_ids = prompt_ids[rows[0]][:prefix_len]
        # All image tokens must fall inside the shared prefix; otherwise the
        # visual features could not be consumed by this prefill alone.
        image_token_count = int((prefix_ids == img_tok).sum().item())
        if image_token_count != int(vit_list[rows[0]].shape[0]):
            if debug:
                print(
                    "LA Flash shared prefill skip group: "
                    f"rows={rows} prefix_len={prefix_len} "
                    f"image_tokens={image_token_count} visual_rows={int(vit_list[rows[0]].shape[0])}",
                    flush=True,
                )
            continue
        items.append((rows, prefix_ids, vit_list[rows[0]]))

    if not items:
        return kv_rows, cached_lens

    # Left-pad every group's prefix to pmax. Pad slots get token `pad`,
    # mask 0 and position 1; real tokens get positions 0..length-1.
    lengths = [int(ids.numel()) for _rows, ids, _vit in items]
    pmax = max(lengths)
    input_ids = torch.full((len(items), pmax), pad, dtype=torch.long, device=dev)
    amask = torch.zeros((len(items), pmax), dtype=torch.long, device=dev)
    pos = torch.ones((len(items), pmax), dtype=torch.long, device=dev)
    for item_idx, (_rows, ids, _vit) in enumerate(items):
        length = lengths[item_idx]
        left = pmax - length
        input_ids[item_idx, left:] = ids.to(dev)
        amask[item_idx, left:] = 1
        pos[item_idx, left:] = torch.arange(length, dtype=torch.long, device=dev)

    # One feature tensor per group, concatenated in item order.
    visual_features = torch.cat([vit for _rows, _ids, vit in items], dim=0)
    assert int((input_ids == img_tok).sum().item()) == visual_features.shape[0], \
        "shared-prefix image-token count != supplied visual_features rows"

    if debug:
        group_sizes = [len(rows) for rows, _ids, _vit in items]
        print(
            "LA Flash hybrid shared prompt prefill "
            f"groups={len(items)} group_sizes={group_sizes} prefix_lens={lengths}",
            flush=True,
        )

    forward_mask, fallback_rows = _forward_attention_mask(
        model, input_ids, amask, 0, mtp_window=False)
    if debug and fallback_rows:
        print(
            "LA Flash hybrid shared prefill safe SDPA fallback "
            f"query_rows={fallback_rows}",
            flush=True,
        )
    forward_kwargs = dict(
        input_ids=input_ids,
        visual_features=visual_features,
        image_token_index=img_tok,
        attention_mask=forward_mask,
        position_ids=pos,
        past_key_values=None,
        use_cache=True,
    )
    # A dict mask is a magi range plan and goes through the runtime wrapper;
    # tensor masks go straight to the decoder.
    if isinstance(forward_mask, dict):
        out = language_model_forward(model, **forward_kwargs, return_logits=False)
    else:
        out = model.language_model.model(**forward_kwargs)

    real_tokens = sum(lengths)
    shared_rows = sum(len(rows) for rows, _ids, _vit in items)
    # Each extra row in a group would otherwise have re-forwarded the prefix.
    saved_tokens = sum((len(rows) - 1) * length for (rows, _ids, _vit), length in zip(items, lengths))
    _record_prefill_stats(
        stats,
        rows=len(items),
        q_len=pmax,
        real_tokens=real_tokens,
        shared_groups=len(items),
        shared_rows=shared_rows,
        saved_tokens=saved_tokens,
    )

    for item_idx, (rows, _ids, _vit) in enumerate(items):
        prefix_len = lengths[item_idx]
        # kmax=0, umax=pmax: keep only this item's right-aligned real prefix.
        prefix_kv = _unpack_stock_after_forward(out.past_key_values, item_idx, 0, prefix_len, 0, pmax)
        for row in rows:
            # The same tuple is shared by every row of the group; downstream
            # code presumably must not mutate it in place -- verify.
            kv_rows[row] = prefix_kv
            cached_lens[row] = prefix_len

    return kv_rows, cached_lens
| 779 | |
| 780 | |
| 781 | @torch.no_grad() |
| 782 | def _prefill_prompt_kv_rows(model, prompt_ids, vit_list, img_tok, pad, dev, mode, debug=False, stats=None): |
| 783 | """Return per-row prompt KV caches and cached lengths. |
| 784 | |
| 785 | ``mode='none'`` preserves the legacy stock-like first MTP forward where the |
| 786 | whole prompt and the 6-token MTP window are forwarded together. The split |
| 787 | prefill modes keep prompt KV clean before the scheduler batches only short |
| 788 | suffix/window forwards, which avoids ragged prompt+window masking in the |
| 789 | first decode step. |
| 790 | """ |
| 791 | bsz = len(prompt_ids) |
| 792 | lengths = [int(p.numel()) for p in prompt_ids] |
| 793 | if mode == "none": |
| 794 | return [None] * bsz, [0] * bsz |
| 795 | |
| 796 | base = model.language_model.model |
| 797 | if debug: |
| 798 | print(f"LA Flash hybrid prompt prefill mode={mode} rows={bsz} lengths={lengths}", flush=True) |
| 799 | |
| 800 | if mode == "shared": |
| 801 | return _prefill_shared_prefix_kv_rows( |
| 802 | model, prompt_ids, vit_list, img_tok, pad, dev, stats=stats, debug=debug) |
| 803 | |
| 804 | if mode == "per_row": |
| 805 | kv_rows = [] |
| 806 | for b, ids in enumerate(prompt_ids): |
| 807 | ids = ids.to(dev).unsqueeze(0) |
| 808 | pos = torch.arange(ids.shape[1], dtype=torch.long, device=dev).unsqueeze(0) |
| 809 | out = base( |
| 810 | input_ids=ids, |
| 811 | visual_features=vit_list[b], |
| 812 | image_token_index=img_tok, |
| 813 | attention_mask=None, |
| 814 | position_ids=pos, |
| 815 | past_key_values=None, |
| 816 | use_cache=True, |
| 817 | ) |
| 818 | kv_rows.append(out.past_key_values) |
| 819 | _record_prefill_stats(stats, rows=1, q_len=ids.shape[1], real_tokens=ids.shape[1]) |
| 820 | return kv_rows, lengths |
| 821 | |
| 822 | pmax = max(lengths) |
| 823 | input_ids = torch.full((bsz, pmax), pad, dtype=torch.long, device=dev) |
| 824 | amask = torch.zeros((bsz, pmax), dtype=torch.long, device=dev) |
| 825 | pos = torch.ones((bsz, pmax), dtype=torch.long, device=dev) |
| 826 | for b, ids in enumerate(prompt_ids): |
| 827 | left = pmax - lengths[b] |
| 828 | input_ids[b, left:] = ids.to(dev) |
| 829 | amask[b, left:] = 1 |
| 830 | pos[b, left:] = torch.arange(lengths[b], dtype=torch.long, device=dev) |
| 831 | |
| 832 | visual_features = torch.cat(vit_list, dim=0) |
| 833 | assert int((input_ids == img_tok).sum().item()) == visual_features.shape[0], \ |
| 834 | "image-token count != supplied visual_features rows" |
| 835 | forward_mask, fallback_rows = _forward_attention_mask( |
| 836 | model, input_ids, amask, 0, mtp_window=False) |
| 837 | if debug and fallback_rows: |
| 838 | print( |
| 839 | "LA Flash hybrid batch prefill safe SDPA fallback " |
| 840 | f"query_rows={fallback_rows}", |
| 841 | flush=True, |
| 842 | ) |
| 843 | forward_kwargs = dict( |
| 844 | input_ids=input_ids, |
| 845 | visual_features=visual_features, |
| 846 | image_token_index=img_tok, |
| 847 | attention_mask=forward_mask, |
| 848 | position_ids=pos, |
| 849 | past_key_values=None, |
| 850 | use_cache=True, |
| 851 | ) |
| 852 | if isinstance(forward_mask, dict): |
| 853 | out = language_model_forward(model, **forward_kwargs, return_logits=False) |
| 854 | else: |
| 855 | out = base(**forward_kwargs) |
| 856 | _record_prefill_stats(stats, rows=bsz, q_len=pmax, real_tokens=sum(lengths)) |
| 857 | kv_rows = [ |
| 858 | _unpack_stock_after_forward(out.past_key_values, b, 0, lengths[b], 0, pmax) |
| 859 | for b in range(bsz) |
| 860 | ] |
| 861 | return kv_rows, lengths |
| 862 | |
| 863 | |
@torch.no_grad()
def generate_batch_hybrid(pairs, temperature=README_TEMPERATURE, top_p=README_TOP_P, top_k=None,
                          repetition_penalty=README_REPETITION_PENALTY,
                          max_new_tokens=README_MAX_NEW_TOKENS, temps=None,
                          debug=None, scheduler=None, group_size=None,
                          vision_features=None, _stats=None):
    """Batched stock-style LocateAnything-3B hybrid generation.

    This mirrors ``model.generate(..., generation_mode='hybrid')``: each row
    owns a full ``generated`` token stream plus a KV cache truncated to real
    generated tokens before sampling. MTP forwards
    ``generated[cached_len:] + duplicate-last + mask*5``; AR forwards
    ``generated[cached_len:]``.

    Scheduling per decode loop (``scheduler`` after ``_hybrid_scheduler``):

    * ``'hold_ar'`` (with ``hold_max_steps > 0``): keep stepping AR rows for
      up to ``hold_max_steps`` loops while MTP rows wait, then run one MTP step.
    * ``'ar_first'`` / ``'pipeline'`` / ``'adaptive'``: step AR rows, then all
      MTP rows (including rows that just left AR). ``'pipeline'`` additionally
      steps rows that MTP just switched into AR; ``'adaptive'`` holds like
      ``'hold_ar'`` when at most ``adaptive_hold_mtp_max`` MTP rows wait.
    * anything else (eager): step MTP rows first, then AR rows.

    Args:
        pairs: sequence of ``(image, query)`` tuples, one per output row.
        temperature: sampling temperature used when ``temps`` is None.
        top_p, top_k, repetition_penalty: sampling controls shared by all rows.
        max_new_tokens: a row stops once it has generated at least this many
            tokens; one MTP step may append several tokens, so a row can
            overshoot slightly.
        temps: optional per-row temperatures (same length as ``pairs``).
        debug, scheduler, group_size: resolved via ``_debug_enabled``,
            ``_hybrid_scheduler`` and ``_hybrid_group_size``.
        vision_features: optional precomputed per-row visual features; when
            given, ``_encode_images`` is skipped.
        _stats: internal stats accumulator shared with nested group calls.

    Returns:
        A list of decoded strings (special tokens kept), one per pair; rows
        that generated nothing decode to ``""``.

    Raises:
        ValueError: if ``temps`` or ``vision_features`` length differs from
            ``len(pairs)``.
    """
    tok, _, model = load()
    san, hpat = _helpers()
    tids = model.token_ids
    img_tok = model.config.image_token_index
    mask_tok = tids["default_mask_token_id"]
    im_end = tids["im_end_token_id"]
    pad = tok.pad_token_id if tok.pad_token_id is not None else im_end
    dev = DEV

    if not pairs:
        return []
    if temps is not None and len(temps) != len(pairs):
        raise ValueError("temps must have the same length as pairs")
    if vision_features is not None and len(vision_features) != len(pairs):
        raise ValueError("vision_features must have the same length as pairs")
    debug = _debug_enabled(debug)
    scheduler = _hybrid_scheduler(scheduler)
    group_size = _hybrid_group_size(group_size)
    requested_attn = getattr(model, "_la_flash_requested_attn", ATTN_MODE)
    use_magi = requested_attn == "magi"
    prefill_mode = _hybrid_prefill_mode()
    hold_max_steps = max(0, _env_int("LA_FLASH_HYBRID_HOLD_MAX_STEPS", 5))
    adaptive_hold_mtp_max = max(0, _env_int("LA_FLASH_HYBRID_ADAPTIVE_HOLD_MTP_MAX", 3))
    top_level_stats = _stats is None
    if top_level_stats:
        _stats = _new_hybrid_stats(
            len(pairs), scheduler, group_size, hold_max_steps, adaptive_hold_mtp_max)
        if os.environ.get("LA_FLASH_PLAN_STATS", "0") == "1":
            model._la_flash_sparse_plan_stats = None

    def publish_stats():
        # Only the outermost call owns ``_stats``. Both the grouped and the
        # single-batch exit go through here so sparse-plan stats are captured
        # in either case (the grouped path previously dropped them).
        if os.environ.get("LA_FLASH_PLAN_STATS", "0") == "1":
            _stats["sparse_plan_stats"] = copy.deepcopy(
                getattr(model, "_la_flash_sparse_plan_stats", None) or {}
            )
        _set_last_hybrid_stats(_stats)

    if group_size and len(pairs) > group_size:
        outs = []
        if debug:
            print(
                f"LA Flash hybrid grouped scheduling: total_rows={len(pairs)} "
                f"group_size={group_size} scheduler={scheduler} hold_max_steps={hold_max_steps} "
                f"adaptive_hold_mtp_max={adaptive_hold_mtp_max}",
                flush=True,
            )
        for start in range(0, len(pairs), group_size):
            end = min(start + group_size, len(pairs))
            chunk_temps = temps[start:end] if temps is not None else None
            chunk_vision_features = (
                vision_features[start:end] if vision_features is not None else None
            )
            if debug:
                print(f"LA Flash hybrid group rows=[{start}:{end}]", flush=True)
            # group_size=0 presumably disables grouping in the nested call
            # (via _hybrid_group_size); the shared _stats keeps it non-top-level.
            outs.extend(generate_batch_hybrid(
                pairs[start:end],
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                max_new_tokens=max_new_tokens,
                temps=chunk_temps,
                debug=debug,
                scheduler=scheduler,
                group_size=0,
                vision_features=chunk_vision_features,
                _stats=_stats,
            ))
        if top_level_stats:
            publish_stats()
        return outs

    use_cached_tokenize = (
        vision_features is not None
        and os.environ.get("LA_FLASH_CACHE_TOKENIZE", "1") != "0"
    )
    if use_cached_tokenize:
        try:
            prompt_ids = [
                _tokenize_cached_image(q, int(v.shape[0]), im=im)
                for (im, q), v in zip(pairs, vision_features)
            ]
        except Exception as exc:
            # Deliberate best-effort fast path: fall back to full tokenization
            # unless strict mode asks for the failure to surface.
            if os.environ.get("LA_FLASH_CACHE_TOKENIZE_STRICT", "0") == "1":
                raise
            if debug:
                print(f"LA Flash cached tokenize fallback: {exc}", flush=True)
            prompt_ids = [_tokenize(im, q) for im, q in pairs]
    else:
        prompt_ids = [_tokenize(im, q) for im, q in pairs]
    vit_list = (
        list(vision_features)
        if vision_features is not None
        else _encode_images([im for im, _ in pairs])
    )
    lengths = [int(p.numel()) for p in prompt_ids]
    bsz = len(pairs)
    _record_group_stats(_stats, bsz)

    _set_llm_mode(model, requested_attn)

    modes = ["mtp"] * bsz
    finished = [False] * bsz
    gen_ids = [[] for _ in range(bsz)]
    full_ids = [list(ids.detach().cpu().tolist()) for ids in prompt_ids]
    kv_rows, cached_lens = _prefill_prompt_kv_rows(
        model, prompt_ids, vit_list, img_tok, pad, dev, prefill_mode, debug=debug, stats=_stats)
    total_limits = [lengths[b] + max_new_tokens for b in range(bsz)]

    row_temps = [float(temperature or 0.0)] * bsz if temps is None else [float(t or 0.0) for t in temps]

    def run_ar(ar_rows, step_idx):
        row_groups = _split_rows_by_kv_budget(ar_rows, kv_rows)
        _record_kv_bucket_stats(_stats, row_groups, kv_rows)
        for row_group in row_groups:
            _step_stock_ar_rows(
                model, san, tids, prompt_ids, kv_rows, row_group,
                cached_lens, full_ids, gen_ids, modes, finished, total_limits,
                pad, img_tok, row_temps, temperature, top_p, top_k,
                repetition_penalty, dev, tok, debug, step_idx, use_magi, _stats,
            )

    def run_mtp(mtp_rows, step_idx):
        # _step_stock_mtp_rows rejects a mix of rows with and without cached
        # KV (only first-step rows carry visual features), so run them apart.
        first_rows = [r for r in mtp_rows if cached_lens[r] == 0]
        cached_rows = [r for r in mtp_rows if cached_lens[r] > 0]
        if first_rows and cached_rows:
            run_mtp(first_rows, step_idx)
            run_mtp(cached_rows, step_idx)
            return
        # Record bucket stats once per row, exactly like run_ar (previously a
        # multi-group split re-split and re-recorded every group recursively).
        row_groups = _split_rows_by_kv_budget(mtp_rows, kv_rows)
        _record_kv_bucket_stats(_stats, row_groups, kv_rows)
        for row_group in row_groups:
            _step_stock_mtp_rows(
                model, san, hpat, tids, prompt_ids, kv_rows, row_group,
                cached_lens, full_ids, gen_ids, modes, finished, total_limits,
                vit_list, pad, mask_tok, img_tok, row_temps, top_p, top_k,
                repetition_penalty, dev, tok, debug, step_idx, use_magi, _stats,
            )

    def live_rows(mode):
        return [b for b in range(bsz) if not finished[b] and modes[b] == mode]

    step = 0
    hold_steps = 0  # consecutive AR-only steps since the last MTP step
    while not all(finished) and step <= max_new_tokens:
        step += 1
        if _stats is not None:
            _stats["decode_loops"] += 1
        if scheduler == "hold_ar" and hold_max_steps > 0:
            ar_rows = live_rows("ar")
            mtp_rows = live_rows("mtp")
            if ar_rows and mtp_rows and _stats is not None:
                _stats["mixed_mode_cycles"] += 1
            if ar_rows and (hold_steps < hold_max_steps or not mtp_rows):
                if mtp_rows and _stats is not None:
                    _stats["hold_ar_steps"] += 1
                    _stats["hold_ar_held_mtp_rows"] += len(mtp_rows)
                run_ar(ar_rows, step)
                hold_steps += 1
                continue
            if mtp_rows:
                if ar_rows and _stats is not None:
                    _stats["hold_ar_limit_mtp_forwards"] += 1
                run_mtp(mtp_rows, step)
                hold_steps = 0
            continue

        if scheduler in {"ar_first", "pipeline", "adaptive"}:
            ar_rows_at_loop_start = live_rows("ar")
            mtp_rows_at_loop_start = live_rows("mtp")
            mixed = bool(ar_rows_at_loop_start and mtp_rows_at_loop_start)
            if mixed and _stats is not None:
                _stats["mixed_mode_cycles"] += 1

            if scheduler == "adaptive" and mixed and hold_max_steps > 0:
                should_hold = len(mtp_rows_at_loop_start) <= adaptive_hold_mtp_max
                if should_hold and hold_steps < hold_max_steps:
                    if _stats is not None:
                        _stats["adaptive_hold_cycles"] += 1
                        _stats["hold_ar_steps"] += 1
                        _stats["hold_ar_held_mtp_rows"] += len(mtp_rows_at_loop_start)
                    run_ar(ar_rows_at_loop_start, step)
                    hold_steps += 1
                    continue

            if ar_rows_at_loop_start:
                if mixed and _stats is not None:
                    if scheduler == "adaptive":
                        _stats["adaptive_ar_first_cycles"] += 1
                    else:
                        _stats["ar_first_cycles"] += 1
                run_ar(ar_rows_at_loop_start, step)

            # Re-query: AR rows that just emitted box_end_ar are MTP rows now.
            mtp_rows = live_rows("mtp")
            if mtp_rows:
                run_mtp(mtp_rows, step)
                hold_steps = 0

            if scheduler == "pipeline" and mtp_rows:
                old_ar = set(ar_rows_at_loop_start)
                new_ar_rows = [b for b in live_rows("ar") if b not in old_ar]
                if new_ar_rows:
                    if _stats is not None:
                        _stats["pipeline_ar_after_mtp_cycles"] += 1
                    run_ar(new_ar_rows, step)
            continue

        # Eager default: MTP first, then every row that is in AR afterwards.
        mtp_rows = live_rows("mtp")
        ar_rows_before = live_rows("ar")
        if mtp_rows and ar_rows_before and _stats is not None:
            _stats["mixed_mode_cycles"] += 1
        if mtp_rows:
            run_mtp(mtp_rows, step)

        ar_rows = live_rows("ar")
        if mtp_rows and ar_rows and _stats is not None:
            _stats["eager_mtp_then_ar_cycles"] += 1
        if ar_rows:
            run_ar(ar_rows, step)

    outs = [
        tok.decode(torch.tensor(gen_ids[b], dtype=torch.long, device=dev),
                   skip_special_tokens=False) if gen_ids[b] else ""
        for b in range(bsz)
    ]
    if top_level_stats:
        publish_stats()
    return outs
| 1107 | |
| 1108 | |
@torch.no_grad()
def _step_stock_mtp_rows(model, san, hpat, tids, prompt_ids, kv_rows, rows,
                         cached_lens, full_ids, gen_ids, modes, finished, total_limits,
                         vit_list, pad, mask_tok, img_tok, row_temps, top_p, top_k,
                         repetition_penalty, dev, tok, debug, step_idx, use_magi, stats=None):
    """Run one packed MTP forward over ``rows`` and apply the sampled ops.

    Query layout per row: its uncached tokens right-aligned to a shared width,
    then a copy of its last real token followed by ``N_FUTURE - 1`` mask
    tokens. After the forward each row's KV is trimmed back to real tokens
    and its cached length advanced. The window logits are then sampled
    (batched when ``BATCH_SAN`` is set, otherwise per row through ``san``).
    ``hpat`` turns the sample into an op whose tokens are appended. ``im_end``
    finishes the row and ``error_box`` switches it to AR mode. A row whose
    total length reaches ``total_limits`` is finished as well.

    Mutates ``kv_rows``, ``cached_lens``, ``full_ids``, ``gen_ids``,
    ``modes`` and ``finished`` in place.

    Raises:
        RuntimeError: if ``rows`` mixes rows with and without cached KV.
    """
    n_slots = len(rows)
    kv, kvalid, old_lens, kmax = _pack_stock_kv_rows(kv_rows, rows, dev)
    uncached_lens = [len(full_ids[r]) - cached_lens[r] for r in rows]
    width = max(uncached_lens)
    seq_len = width + N_FUTURE
    _record_forward_stats(stats, "mtp", rows, seq_len, uncached_lens)

    suf_ids = torch.full((n_slots, seq_len), pad, dtype=torch.long, device=dev)
    suf_pos = torch.ones((n_slots, seq_len), dtype=torch.long, device=dev)
    q_valid = torch.zeros((n_slots, seq_len), dtype=torch.long, device=dev)

    for slot, r in enumerate(rows):
        history = full_ids[r]
        done = cached_lens[r]
        tail = history[done:]
        cur_len = len(history)
        if tail:
            start = width - len(tail)
            suf_ids[slot, start:width] = torch.tensor(tail, dtype=torch.long, device=dev)
            suf_pos[slot, start:width] = torch.arange(done, cur_len, dtype=torch.long, device=dev)
            q_valid[slot, start:width] = 1
        # Window: duplicate of the last real token, then N_FUTURE - 1 masks.
        suf_ids[slot, width] = history[-1]
        suf_pos[slot, width] = cur_len - 1
        suf_ids[slot, width + 1:seq_len] = mask_tok
        suf_pos[slot, width + 1:seq_len] = torch.arange(
            cur_len, cur_len + N_FUTURE - 1, dtype=torch.long, device=dev)
        q_valid[slot, width:seq_len] = 1

    full_mask = torch.cat([kvalid, q_valid], dim=1)
    forward_mask, fallback_rows = _forward_attention_mask(
        model, suf_ids, full_mask, kmax, mtp_window=True, range_plan=True)
    if debug:
        _print_debug_forward(
            f"MTP step={step_idx}",
            model,
            tok,
            suf_ids,
            full_mask,
            suf_pos,
            past_len=kmax,
            mtp_window=True,
            extra={
                "global_rows": rows,
                "old_kv_lens": old_lens,
                "cached_lens": [cached_lens[r] for r in rows],
                "full_lens": [len(full_ids[r]) for r in rows],
                "uncached_lens": uncached_lens,
                "forward_attention_mask": _mask_desc(forward_mask),
                "safe_sdpa_fallback_query_rows": fallback_rows,
            },
            attention_impl="magi" if use_magi else ATTN_MODE,
        )

    # Rows without cached KV forward the whole prompt, so they need the
    # visual features; such rows must not be mixed with cached rows.
    n_first = sum(1 for r in rows if cached_lens[r] == 0)
    visual_features = None
    if n_first:
        if n_first != len(rows):
            raise RuntimeError("mixed first/non-first MTP rows are not supported")
        visual_features = torch.cat([vit_list[r] for r in rows], dim=0)
        assert int((suf_ids == img_tok).sum().item()) == visual_features.shape[0], \
            "image-token count != supplied visual_features rows"
    out = language_model_forward(
        model, input_ids=suf_ids, attention_mask=forward_mask,
        position_ids=suf_pos, past_key_values=kv, use_cache=True,
        visual_features=visual_features,
        image_token_index=None if visual_features is None else img_tok,
        logits_slice=slice(-N_FUTURE, None))

    for slot, r in enumerate(rows):
        kv_rows[r] = _unpack_stock_after_forward(
            out.past_key_values, slot, old_lens[slot], uncached_lens[slot], kmax, width)
        cached_lens[r] = len(full_ids[r])

    window_logits = out.logits[:, -N_FUTURE:, :]
    gen_pad = _pad_generated(
        [prompt_ids[r] for r in rows], [gen_ids[r] for r in rows], img_tok, dev)
    per_row_temp = torch.tensor([row_temps[r] for r in rows], dtype=torch.float32, device=dev)

    if BATCH_SAN:
        x0_all, boxes_all = sample_tokens_batched(
            window_logits, gen_pad, tids, per_row_temp,
            repetition_penalty=repetition_penalty, top_p=top_p, top_k=top_k,
            keep_k_avg=4, generation_mode="hybrid")

    for slot, r in enumerate(rows):
        if finished[r]:
            continue
        if BATCH_SAN:
            x0b, boxb = x0_all[slot], boxes_all[slot]
        else:
            gk = _mk_generate_kwargs(row_temps[r], top_p, top_k, repetition_penalty)
            _, _, x0, box_avg = san(
                window_logits[slot:slot + 1], gen_pad[slot:slot + 1], tids, keep_k=5, **gk)
            x0b, boxb = x0[0], box_avg[0]
        # An all-zero box average means "no box": fall back to the raw sample.
        chosen = x0b if bool((boxb == 0).all()) else boxb
        op = hpat(chosen, tids, "hybrid")

        new_tokens = [int(t) for t in op["tokens"]]
        gen_ids[r].extend(new_tokens)
        full_ids[r].extend(new_tokens)

        op_type = op["type"]
        if op_type == "im_end":
            finished[r] = True
        elif op_type == "error_box":
            modes[r] = "ar"
        if len(full_ids[r]) >= total_limits[r]:
            finished[r] = True
| 1226 | |
| 1227 | |
@torch.no_grad()
def _step_stock_ar_rows(model, san, tids, prompt_ids, kv_rows, rows,
                        cached_lens, full_ids, gen_ids, modes, finished, total_limits,
                        pad, img_tok, row_temps, temperature, top_p, top_k,
                        repetition_penalty, dev, tok, debug, step_idx, use_magi, stats=None):
    """Run one packed AR forward over ``rows`` and append one token per row.

    Each row forwards only its uncached tokens, right-aligned to a shared
    width. After the forward its KV is unpacked back to real tokens and its
    cached length advanced. The last-position logits are then sampled,
    batched when ``AR_BATCH_SAN`` is set and otherwise per row through ``san``.
    A sampled ``im_end`` finishes the row, and ``box_end_ar`` switches it back
    to MTP mode. A row whose total length reaches ``total_limits`` is finished.

    Mutates ``kv_rows``, ``cached_lens``, ``full_ids``, ``gen_ids``,
    ``modes`` and ``finished`` in place.

    Raises:
        RuntimeError: if any row has no uncached tokens to forward.
    """
    n_slots = len(rows)
    kv, kvalid, old_lens, kmax = _pack_stock_kv_rows(kv_rows, rows, dev)
    uncached_lens = [len(full_ids[r]) - cached_lens[r] for r in rows]
    if not all(n_tok > 0 for n_tok in uncached_lens):
        raise RuntimeError(f"AR rows have no uncached tokens: {rows}")
    width = max(uncached_lens)
    _record_forward_stats(stats, "ar", rows, width, uncached_lens)

    suf_ids = torch.full((n_slots, width), pad, dtype=torch.long, device=dev)
    suf_pos = torch.ones((n_slots, width), dtype=torch.long, device=dev)
    q_valid = torch.zeros((n_slots, width), dtype=torch.long, device=dev)

    for slot, r in enumerate(rows):
        done = cached_lens[r]
        tail = full_ids[r][done:]
        start = width - len(tail)
        suf_ids[slot, start:] = torch.tensor(tail, dtype=torch.long, device=dev)
        suf_pos[slot, start:] = torch.arange(done, len(full_ids[r]), dtype=torch.long, device=dev)
        q_valid[slot, start:] = 1

    full_mask = torch.cat([kvalid, q_valid], dim=1)
    forward_mask, fallback_rows = _forward_attention_mask(
        model, suf_ids, full_mask, kmax, mtp_window=False, range_plan=True)
    if debug:
        _print_debug_forward(
            f"AR step={step_idx}",
            model,
            tok,
            suf_ids,
            full_mask,
            suf_pos,
            past_len=kmax,
            mtp_window=False,
            extra={
                "global_rows": rows,
                "old_kv_lens": old_lens,
                "cached_lens": [cached_lens[r] for r in rows],
                "full_lens": [len(full_ids[r]) for r in rows],
                "uncached_lens": uncached_lens,
                "forward_attention_mask": _mask_desc(forward_mask),
                "safe_sdpa_fallback_query_rows": fallback_rows,
            },
            attention_impl="magi" if use_magi else ATTN_MODE,
        )

    out = language_model_forward(
        model, input_ids=suf_ids, attention_mask=forward_mask,
        position_ids=suf_pos, past_key_values=kv, use_cache=True,
        logits_slice=slice(-1, None))

    for slot, r in enumerate(rows):
        kv_rows[r] = _unpack_stock_after_forward(
            out.past_key_values, slot, old_lens[slot], uncached_lens[slot], kmax, width)
        cached_lens[r] = len(full_ids[r])

    if AR_BATCH_SAN:
        batch_pad = _pad_generated(
            [prompt_ids[r] for r in rows], [gen_ids[r] for r in rows], img_tok, dev)
        per_row_temp = torch.tensor([row_temps[r] for r in rows], dtype=torch.float32, device=dev)
        x0_all = sample_next_tokens_batched(
            out.logits[:, -1:, :],
            batch_pad,
            per_row_temp,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
            top_k=top_k,
        )

    for slot, r in enumerate(rows):
        if AR_BATCH_SAN:
            token_val = int(x0_all[slot, 0].item())
        else:
            last_logits = out.logits[slot:slot + 1, -1:, :]
            row_pad = _pad_generated([prompt_ids[r]], [gen_ids[r]], img_tok, dev)
            gk = _mk_generate_kwargs(temperature, top_p, top_k, repetition_penalty, row_temp=row_temps[r])
            _, _, x0, _ = san(last_logits, row_pad, tids, **gk)
            token_val = int(x0[0, 0].item())
        token_kind = _classify_ar_token(token_val, tids)

        gen_ids[r].append(token_val)
        full_ids[r].append(token_val)

        if token_kind == "im_end":
            finished[r] = True
        elif token_kind == "box_end_ar":
            modes[r] = "mtp"

        if len(full_ids[r]) >= total_limits[r]:
            finished[r] = True
| 1324 | |
| 1325 | |
def generate_batch_grouped_hybrid(groups, temperature=README_TEMPERATURE, top_p=README_TOP_P,
                                  top_k=None, repetition_penalty=README_REPETITION_PENALTY,
                                  max_new_tokens=README_MAX_NEW_TOKENS, temps=None,
                                  debug=None, scheduler=None, group_size=None,
                                  vision_features=None):
    """Hybrid grouped API shape.

    This preserves grouped return shape, but intentionally uses the generic
    hybrid decoder rather than the fast engine's shared-prefix optimization.

    Args:
        groups: iterable of ``(image, queries)`` pairs; every query in a group
            shares that group's image.
        vision_features: optional precomputed visual features, one entry per
            group; each entry is reused for every query of its group.
        temps: optional per-query temperatures, in flattened query order (it
            is forwarded as-is to ``generate_batch_hybrid``, which requires
            one entry per flattened query).
        Remaining arguments are forwarded unchanged to ``generate_batch_hybrid``.

    Returns:
        A list with one list of decoded strings per group, in query order.

    Raises:
        ValueError: if ``vision_features`` does not have one entry per group.
    """
    groups = list(groups)
    # Fail loudly instead of raising an opaque IndexError (too few entries) or
    # silently ignoring trailing entries (too many), which would misalign features.
    if vision_features is not None and len(vision_features) != len(groups):
        raise ValueError("vision_features must have the same length as groups")

    flat = []
    flat_vision_features = [] if vision_features is not None else None
    counts = []
    for group_idx, (im, queries) in enumerate(groups):
        counts.append(len(queries))
        flat.extend((im, q) for q in queries)
        if flat_vision_features is not None:
            flat_vision_features.extend([vision_features[group_idx]] * len(queries))

    outs = generate_batch_hybrid(
        flat, temperature=temperature, top_p=top_p, top_k=top_k,
        repetition_penalty=repetition_penalty, max_new_tokens=max_new_tokens,
        temps=temps, debug=debug, scheduler=scheduler, group_size=group_size,
        vision_features=flat_vision_features)

    # Re-split the flat outputs back into per-group lists.
    res, offset = [], 0
    for n in counts:
        res.append(outs[offset : offset + n])
        offset += n
    return res
| 1355 | |
| 1356 | |
# Names exported by ``from ... import *``; the underscore-prefixed helpers are internal.
__all__ = ["generate_batch_hybrid", "generate_batch_grouped_hybrid", "get_last_hybrid_stats"]
| 1358 | |