mask_magi_utils.py
| 1 | # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. |
| 2 | # |
| 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property |
| 4 | # and proprietary rights in and to this software, related documentation |
| 5 | # and any modifications thereto. Any use, reproduction, disclosure or |
| 6 | # distribution of this software and related documentation without an express |
| 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. |
| 8 | |
| 9 | import torch |
| 10 | |
# MagiAttention attn_type_map convention
FULL, CAUSAL = 0, 1


def _pack_ranges(q_ranges, k_ranges, types, device):
    """Pack parallel Python lists into the contiguous int32 tensors returned by build_magi_ranges."""
    return {
        "q_ranges": torch.tensor(q_ranges, dtype=torch.int32, device=device).contiguous(),
        "k_ranges": torch.tensor(k_ranges, dtype=torch.int32, device=device).contiguous(),
        "attn_type_map": torch.tensor(types, dtype=torch.int32, device=device).contiguous(),
    }


def build_magi_ranges(
    kv_len: int,
    q_len: int,
    block_size: int,
    ar_decode: bool = False,
    device: str = "cpu",
) -> dict:
    """
    Build MagiAttention ``q_ranges`` / ``k_ranges`` / ``attn_type_map`` for one attention step.

    Fixed strategy (these settings are hard-coded, not parameters):
    - use_cache=True: key column blocked_k = (kv_len - block_size - 1) is hidden from window rows
    - causal_attn=False: window interior is FULL (bidirectional)
    - If q_len == kv_len: coarse prefill version (one CAUSAL range for the recomputed rows)
    - Otherwise: general decode version (recomputed rows expand the visible keys row by row)

    Conventions:
    - K/V global length kv_len: keys are [0, kv_len)
    - The current queries are the last q_len tokens of the sequence
    - The first r = q_len - block_size query rows are recomputed; the last block_size rows
      are the window

    Args:
        kv_len: Number of key/value positions.
        q_len: Number of query positions; must satisfy 0 < q_len <= kv_len.
        block_size: Window size B; must satisfy 0 < B <= q_len unless ``ar_decode`` is set,
            in which case it is ignored.
        ar_decode: If True, emit a single CAUSAL range pair covering all queries and keys.
        device: Device on which the returned tensors are allocated.

    Returns:
        A dict of contiguous int32 tensors describing n (query range, key range) pairs:
        - "q_ranges": shape (n, 2), half-open [start, end) in query-local indices [0, q_len).
        - "k_ranges": shape (n, 2), half-open [start, end) in global key indices [0, kv_len).
        - "attn_type_map": shape (n,), FULL or CAUSAL for each pair.

    Raises:
        AssertionError: if the lengths are inconsistent. These are ``assert`` checks, so they
            are not enforced under ``python -O``.
    """
    assert 0 < q_len <= kv_len, f"need 0 < q_len <= kv_len, got q_len={q_len}, kv_len={kv_len}"

    if ar_decode:
        # Plain autoregressive step: one CAUSAL pair covering every query and every key.
        # NOTE(review): when q_len < kv_len this relies on MagiAttention aligning the causal
        # mask to the bottom-right (last query sees all keys) — confirm against the kernel docs.
        return _pack_ranges([[0, q_len]], [[0, kv_len]], [CAUSAL], device)

    assert 0 < block_size <= q_len, (
        f"need 0 < block_size <= q_len, got block_size={block_size}, q_len={q_len}"
    )

    B = block_size
    r = q_len - B                    # recomputed rows are local queries [0, r)
    q_global_start = kv_len - q_len  # global position of local query 0
    window_start_k = kv_len - B      # first key of the trailing window
    blocked_k = window_start_k - 1   # key column hidden from the window rows

    q_ranges, k_ranges, types = [], [], []

    # A) Recomputed rows: the query at global position g may see keys [0, g].
    if q_len == kv_len:
        # Prefill: q_global_start == 0, so rows [0, r) over keys [0, r) form a square causal
        # block; a single CAUSAL range replaces r single-row ranges.
        if r > 0:
            q_ranges.append([0, r])
            k_ranges.append([0, r])
            types.append(CAUSAL)
    else:
        # General case: one single-row FULL range per row gives the exact staircase shape.
        for i in range(r):
            q_ranges.append([i, i + 1])
            k_ranges.append([0, q_global_start + i + 1])  # keys [0, g] with g = q_global_start + i
            types.append(FULL)

    # B) Window rows [r, q_len): the prefix minus blocked_k, plus the whole window, both FULL.
    if blocked_k > 0:
        q_ranges.append([r, q_len])
        k_ranges.append([0, blocked_k])
        types.append(FULL)

    q_ranges.append([r, q_len])
    k_ranges.append([window_start_k, kv_len])
    types.append(FULL)

    return _pack_ranges(q_ranges, k_ranges, types, device)
| 101 | } |