mask_magi_utils.py
3.8 KB · 101 lines · python Raw
1 # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2 #
3 # NVIDIA CORPORATION and its licensors retain all intellectual property
4 # and proprietary rights in and to this software, related documentation
5 # and any modifications thereto. Any use, reproduction, disclosure or
6 # distribution of this software and related documentation without an express
7 # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
9 import torch
10
# Codes stored in `attn_type_map`, following the MagiAttention convention:
# one code per (q_range, k_range) slice.
FULL = 0  # slice is attended bidirectionally
CAUSAL = 1  # slice is attended causally
13
def build_magi_ranges(kv_len: int, q_len: int, block_size: int, ar_decode: bool = False, device: str = "cpu"):
    """Describe the attention mask as MagiAttention (q_range, k_range, type) slices.

    Conventions:
      - Keys/values occupy global positions ``[0, kv_len)``.
      - The ``q_len`` query rows are the *last* ``q_len`` positions, so local
        row ``i`` sits at global position ``kv_len - q_len + i``.
      - The last ``block_size`` query rows are the "window"; the
        ``q_len - block_size`` rows before them are "recomputed" rows.

    Mask that is produced (fixed strategy: ``use_cache=True`` semantics with a
    blocked column, ``causal_attn=False`` semantics inside the window):
      - ``ar_decode=True``: a single CAUSAL slice of all query rows over all
        keys. ``block_size`` is not validated or used on this path.
      - Otherwise, each recomputed row sees keys ``[0, own global position]``.
        Window rows see keys ``[0, blocked_k)`` and ``[window_start, kv_len)``
        with FULL attention, where ``window_start = kv_len - block_size`` and
        ``blocked_k = window_start - 1``; i.e. only column ``blocked_k`` is
        hidden from the window.
      - When ``q_len == kv_len`` the recomputed rows are emitted as one CAUSAL
        square slice (fewer ranges); when ``q_len < kv_len`` they are emitted
        as one single-row FULL slice per row.

    Returns:
        A dict of contiguous ``torch.int32`` tensors on ``device``:
        ``"q_ranges"`` and ``"k_ranges"`` (one ``[start, end]`` pair per
        slice) and ``"attn_type_map"`` (one FULL/CAUSAL code per slice).

    Raises:
        AssertionError: if ``0 < q_len <= kv_len`` does not hold, or (when
            ``ar_decode`` is false) ``0 < block_size <= q_len`` does not hold.
    """
    assert 0 < q_len <= kv_len

    def pack(slices):
        # Transpose the (q_range, k_range, type) triples into three columns.
        # `slices` is never empty: every caller below supplies at least one.
        q_spans, k_spans, kinds = (list(column) for column in zip(*slices))
        return {
            "q_ranges": torch.tensor(q_spans, dtype=torch.int32, device=device).contiguous(),
            "k_ranges": torch.tensor(k_spans, dtype=torch.int32, device=device).contiguous(),
            "attn_type_map": torch.tensor(kinds, dtype=torch.int32, device=device).contiguous(),
        }

    if ar_decode:
        return pack([([0, q_len], [0, kv_len], CAUSAL)])

    assert 0 < block_size <= q_len <= kv_len

    n_lead = q_len - block_size      # number of recomputed (non-window) query rows
    win_k_lo = kv_len - block_size   # first key position belonging to the window
    hidden_col = win_k_lo - 1        # the one key column the window must not see

    if q_len == kv_len:
        # Queries cover the whole sequence, so n_lead == win_k_lo and the
        # recomputed rows form a square block that one CAUSAL slice covers.
        slices = [([0, n_lead], [0, n_lead], CAUSAL)] if n_lead > 0 else []
    else:
        # General decode: one single-row FULL slice per recomputed row, each
        # allowing keys up to and including that row's own global position.
        q_offset = kv_len - q_len
        slices = [
            ([row, row + 1], [0, q_offset + row + 1], FULL)
            for row in range(n_lead)
        ]

    win_q = [n_lead, q_len]

    # Window rows -> keys before the hidden column (skipped when that span is empty).
    if hidden_col > 0:
        slices.append((win_q, [0, hidden_col], FULL))

    # Window rows -> window keys, bidirectional.
    slices.append((win_q, [win_k_lo, kv_len], FULL))

    return pack(slices)