# kernel_utils/range_attention.py
"""Sparse LocateAnything attention implemented with FlashAttention varlen.

The public API accepts flattened query/key/value tensors:

    q: [total_q, num_q_heads, head_dim]
    k: [total_k, num_kv_heads, head_dim]
    v: [total_k, num_kv_heads, head_dim]

and a Magi-style range plan:

    q_ranges:        [num_ranges, 2]
    k_ranges:        [num_key_segments, 2]
    segment_offsets: [num_query_groups + 1]
    attn_type_map:   [num_ranges], one mask type per range:
        0 = full attention over the listed key segment(s)
        1 = bottom-right causal attention

For LocateAnything hybrid MTP decode, batch_utils represents the window as a
causal prefix plus full-attention sparse window segments. This module packs
those visible KV segments and calls FlashAttention varlen, avoiding dense masks.
"""
22 from __future__ import annotations
23
24 import os
25 from typing import Optional
26
27 import torch
28
29
# Cached ``flash_attn.flash_attn_varlen_func``; stays None until the first
# successful import in _load_flash_attn_varlen().
_FLASH_ATTN_VARLEN = None
# Cached import failure; later load attempts re-raise it instead of retrying
# the import.
_FLASH_ATTN_ERROR: Optional[BaseException] = None
32
33
34 def _env_enabled(name: str, default: str = "auto") -> bool:
35 value = os.environ.get(name, default).strip().lower()
36 return value in {"", "auto", "1", "on", "true", "yes", "force"}
37
38
def is_available() -> bool:
    """Report whether ``flash_attn_varlen_func`` can be loaded.

    Any ``Exception`` raised by the loader (for example a missing or broken
    flash_attn install) is reported as ``False`` instead of propagating.
    """
    try:
        _load_flash_attn_varlen()
    except Exception:
        return False
    else:
        return True
45
46
def _flash_fastpath_enabled() -> bool:
    """Whether the single-call cu_seqlens fast path may be used.

    Controlled by the ``LA_FLASH_FASTPATH`` environment variable ("auto" when
    unset).
    """
    flag_name = "LA_FLASH_FASTPATH"
    return _env_enabled(flag_name, default="auto")
49
50
def _flash_segment_fastpath_enabled() -> bool:
    """Whether the grouped key-segment FlashAttention path may be used.

    Controlled by the ``LA_FLASH_SEGMENT_FASTPATH`` environment variable
    ("auto" when unset).
    """
    flag_name = "LA_FLASH_SEGMENT_FASTPATH"
    return _env_enabled(flag_name, default="auto")
53
54
def _load_flash_attn_varlen():
    """Import ``flash_attn.flash_attn_varlen_func`` once and cache the outcome.

    A successful import is cached in ``_FLASH_ATTN_VARLEN`` and returned on
    every later call. A failed import is cached in ``_FLASH_ATTN_ERROR`` and
    re-raised on every later call, so a broken install is not re-imported
    repeatedly.

    Returns:
        The ``flash_attn_varlen_func`` callable.

    Raises:
        Exception: whatever the ``flash_attn`` import raised (typically
            ``ImportError``).
    """
    global _FLASH_ATTN_VARLEN, _FLASH_ATTN_ERROR
    if _FLASH_ATTN_VARLEN is not None:
        return _FLASH_ATTN_VARLEN
    if _FLASH_ATTN_ERROR is not None:
        raise _FLASH_ATTN_ERROR
    try:
        from flash_attn import flash_attn_varlen_func
    except Exception as exc:
        # Cache only genuine import failures. A KeyboardInterrupt/SystemExit
        # delivered while the CUDA extension loads must propagate without
        # permanently poisoning every later call.
        _FLASH_ATTN_ERROR = exc
        raise
    _FLASH_ATTN_VARLEN = flash_attn_varlen_func
    return _FLASH_ATTN_VARLEN
69
70
71 def _coalesce_query_groups(q_ranges, k_ranges, attn_type_map):
72 """Group consecutive entries that share the same query span and mask type."""
73 if q_ranges.numel() == 0:
74 segment_offsets = torch.zeros((1,), dtype=torch.int32, device=q_ranges.device)
75 return q_ranges, k_ranges, segment_offsets, attn_type_map, 0, 0
76
77 q_cpu = q_ranges.detach().to(device="cpu", dtype=torch.int32).contiguous()
78 t_cpu = attn_type_map.detach().to(device="cpu", dtype=torch.int32).contiguous()
79 grouped_q = []
80 grouped_t = []
81 offsets = [0]
82 max_q_len = 0
83 last_q = None
84 last_t = None
85 for idx, (qr, attn_type) in enumerate(zip(q_cpu.tolist(), t_cpu.tolist())):
86 key = (int(qr[0]), int(qr[1]))
87 attn_type = int(attn_type)
88 if attn_type not in (0, 1):
89 raise RuntimeError(
90 "LA Flash path only supports FlashAttention-compatible attn_type 0/1. "
91 f"Got attn_type={attn_type}; regenerate a type 0/1 range plan."
92 )
93 if last_q is None:
94 grouped_q.append([key[0], key[1]])
95 grouped_t.append(attn_type)
96 max_q_len = max(max_q_len, key[1] - key[0])
97 last_q = key
98 last_t = attn_type
99 continue
100 if key == last_q and attn_type == last_t:
101 continue
102 offsets.append(idx)
103 grouped_q.append([key[0], key[1]])
104 grouped_t.append(attn_type)
105 max_q_len = max(max_q_len, key[1] - key[0])
106 last_q = key
107 last_t = attn_type
108 offsets.append(int(q_ranges.shape[0]))
109
110 k_cpu = k_ranges.detach().to(device="cpu", dtype=torch.int32).contiguous()
111 max_k_len = max((int(end) - int(start) for start, end in k_cpu.tolist()), default=0)
112
113 return (
114 torch.tensor(grouped_q, dtype=torch.int32, device=q_ranges.device).contiguous(),
115 k_ranges,
116 torch.tensor(offsets, dtype=torch.int32, device=q_ranges.device).contiguous(),
117 torch.tensor(grouped_t, dtype=torch.int32, device=q_ranges.device).contiguous(),
118 int(max_q_len),
119 int(max_k_len),
120 )
121
122
123 def _flash_lse_to_tq_h(lse, total_q, q_lengths=None):
124 if lse is None:
125 return None
126 if lse.dim() != 2:
127 if lse.dim() == 3 and q_lengths is not None and lse.shape[0] == len(q_lengths):
128 chunks = []
129 for idx, q_len in enumerate(q_lengths):
130 q_len = int(q_len)
131 if lse.shape[1] == 0 or q_len > lse.shape[2]:
132 return None
133 chunks.append(lse[idx, :, :q_len].transpose(0, 1).contiguous())
134 merged = torch.cat(chunks, dim=0).float()
135 return merged if merged.shape[0] == total_q else None
136 return None
137 if lse.shape[0] == total_q:
138 return lse.float()
139 if lse.shape[1] == total_q:
140 return lse.transpose(0, 1).contiguous().float()
141 return None
142
143
144 def _make_cu_seqlens(lengths, device):
145 return torch.tensor([0] + list(torch.tensor(lengths).cumsum(0).tolist()), device=device, dtype=torch.int32)
146
147
def _collect_segment_groups(k_ranges, segment_offsets, group_q_ranges, group_attn_type_map):
    """Copy the grouped plan to host as ``(q_start, q_end, attn_type, segments)`` tuples.

    ``segments`` is the group's list of ``(k_start, k_end)`` key spans. Returns
    None when a group cannot be handled by the segment path: an attn_type other
    than 0/1, an empty query span, or no key segments.
    """
    q_rows = group_q_ranges.detach().to(device="cpu", dtype=torch.int32).tolist()
    k_rows = k_ranges.detach().to(device="cpu", dtype=torch.int32).tolist()
    offsets = segment_offsets.detach().to(device="cpu", dtype=torch.int32).reshape(-1).tolist()
    types = group_attn_type_map.detach().to(device="cpu", dtype=torch.int32).reshape(-1).tolist()

    groups = []
    for group_idx, (q_start, q_end) in enumerate(q_rows):
        attn_type = int(types[group_idx])
        if attn_type not in (0, 1):
            return None
        seg_start = int(offsets[group_idx])
        seg_end = int(offsets[group_idx + 1])
        if seg_end <= seg_start or q_end <= q_start:
            return None
        segments = [(int(start), int(end)) for start, end in k_rows[seg_start:seg_end]]
        if not segments:
            return None
        groups.append((int(q_start), int(q_end), attn_type, segments))
    return groups


def _groups_have_disjoint_queries(groups):
    """Return True when no query row belongs to more than one group."""
    spans = sorted((q_start, q_end) for q_start, q_end, _, _ in groups)
    return all(prev[1] <= nxt[0] for prev, nxt in zip(spans, spans[1:]))


def _flash_varlen_pass(flash_attn_varlen, q, k, v, entries, *, causal, softmax_scale, with_lse):
    """Run one packed FlashAttention varlen call over ``entries``.

    ``entries`` is ``[(q_start, q_end, [(k_start, k_end), ...]), ...]``. Each
    entry becomes one varlen sequence whose keys/values are its segments
    concatenated in order.

    Returns ``(out, lse)``:
      * ``out`` holds the packed rows of all entries in entry order.
      * ``lse`` is the ``[rows, heads]`` float32 log-sum-exp when ``with_lse``
        is set, or None if FlashAttention did not return a usable one.
      * ``lse`` is always None when ``with_lse`` is not set.
    """
    q_parts, k_parts, v_parts, q_lengths, k_lengths = [], [], [], [], []
    for q_start, q_end, segments in entries:
        q_parts.append(q[q_start:q_end])
        k_parts.extend(k[start:end] for start, end in segments)
        v_parts.extend(v[start:end] for start, end in segments)
        q_lengths.append(q_end - q_start)
        k_lengths.append(sum(max(end - start, 0) for start, end in segments))

    extra = {"return_attn_probs": True} if with_lse else {}
    result = flash_attn_varlen(
        torch.cat(q_parts, dim=0).contiguous(),
        torch.cat(k_parts, dim=0).contiguous(),
        torch.cat(v_parts, dim=0).contiguous(),
        _make_cu_seqlens(q_lengths, q.device),
        _make_cu_seqlens(k_lengths, q.device),
        int(max(q_lengths)),
        int(max(k_lengths)),
        dropout_p=0.0,
        softmax_scale=float(softmax_scale),
        causal=bool(causal),
        **extra,
    )
    if not with_lse:
        return (result[0] if isinstance(result, tuple) else result), None
    if not isinstance(result, tuple) or len(result) < 2:
        return result, None
    out = result[0]
    return out, _flash_lse_to_tq_h(result[1], out.shape[0], q_lengths)


def _merge_partial_attention(merged, merged_lse, q_start, q_end, out_seg, lse_seg):
    """Fold one partial result into the running float32 (output, LSE) accumulators in place."""
    # A non-finite partial LSE marks a row that saw no keys (flash_attn 2.x is
    # believed to report +inf there; confirm for the installed version). Treat
    # such rows as "no contribution" so the merge cannot produce NaN.
    lse_seg = lse_seg.masked_fill(~torch.isfinite(lse_seg), float("-inf"))
    old_lse = merged_lse[q_start:q_end]
    new_lse = torch.maximum(old_lse, lse_seg)
    # Shift by a finite pivot so exp() never evaluates (-inf) - (-inf) = NaN.
    pivot = torch.where(torch.isfinite(new_lse), new_lse, torch.zeros_like(new_lse))
    old_w = torch.exp(old_lse - pivot)
    seg_w = torch.exp(lse_seg - pivot)
    total_w = old_w + seg_w
    denom = total_w.clamp_min(1e-20).unsqueeze(-1)
    merged[q_start:q_end] = (
        merged[q_start:q_end] * old_w.unsqueeze(-1) + out_seg * seg_w.unsqueeze(-1)
    ) / denom
    merged_lse[q_start:q_end] = pivot + torch.log(total_w)


def _try_flash_segment_merge(
    q,
    k,
    v,
    k_ranges,
    segment_offsets,
    group_q_ranges,
    group_attn_type_map,
    softmax_scale,
):
    """Compute grouped range attention with packed FlashAttention varlen calls.

    Two strategies are used.

    Single pass per mask type:
      * Used when query groups are pairwise disjoint and every causal group has
        exactly one key segment.
      * Full-attention groups attend to their key segments concatenated.
      * Each group's output rows are written directly.

    Segment-wise LSE merge:
      * Used otherwise, for example when several groups cover the same query
        rows, or when a causal group has several segments.
      * One call runs per (segment index, mask type), with causal masking
        applied per segment.
      * The partial results are combined with log-sum-exp weights in float32.

    Args:
        q, k, v: packed query/key/value tensors (``[rows, heads, head_dim]``).
        k_ranges: ``[num_segments, 2]`` key spans.
        segment_offsets: group ``g`` owns ``k_ranges[offsets[g]:offsets[g + 1]]``.
        group_q_ranges: ``[num_groups, 2]`` query spans.
        group_attn_type_map: per-group mask type (0 = full, 1 = causal).
        softmax_scale: scale applied to ``q @ k^T``.

    Returns:
        The attention output shaped like ``q`` in ``q.dtype``, or None in any
        of these cases:
          * the path is disabled, or dtypes are not fp16/bf16;
          * the plan is missing or not expressible;
          * FlashAttention did not return a usable LSE;
          * some query row is left uncovered.

    Raises:
        Whatever ``_load_flash_attn_varlen`` raises if flash_attn cannot be
        imported.
    """
    if not _flash_segment_fastpath_enabled():
        return None
    if q.dtype not in (torch.float16, torch.bfloat16) or k.dtype != q.dtype or v.dtype != q.dtype:
        return None
    if group_q_ranges is None or segment_offsets is None or group_attn_type_map is None:
        return None

    flash_attn_varlen = _load_flash_attn_varlen()
    groups = _collect_segment_groups(k_ranges, segment_offsets, group_q_ranges, group_attn_type_map)
    if not groups:
        return None

    total_q = q.shape[0]
    # Writing rows directly is only correct when each query row belongs to one
    # group. Overlapping groups must be LSE-merged, or a later pass would
    # overwrite the earlier one.
    single_pass_ok = _groups_have_disjoint_queries(groups) and all(
        attn_type == 0 or len(segments) == 1 for _, _, attn_type, segments in groups
    )
    if single_pass_ok:
        merged = torch.empty(q.shape, device=q.device, dtype=q.dtype)
        covered = torch.zeros((total_q,), device=q.device, dtype=torch.bool)
        for attn_type in (0, 1):
            entries = [
                (q_start, q_end, segments)
                for q_start, q_end, group_type, segments in groups
                if group_type == attn_type
            ]
            if not entries:
                continue
            out_pass, _ = _flash_varlen_pass(
                flash_attn_varlen, q, k, v, entries,
                causal=attn_type == 1, softmax_scale=softmax_scale, with_lse=False,
            )
            cursor = 0
            for q_start, q_end, _ in entries:
                q_len = q_end - q_start
                merged[q_start:q_end] = out_pass[cursor:cursor + q_len]
                covered[q_start:q_end] = True
                cursor += q_len
        # Every group was written above. The merge path could only cover a
        # subset of these rows, so uncovered rows mean the plan is not expressible.
        return merged if bool(covered.all().item()) else None

    merged = torch.zeros(q.shape, device=q.device, dtype=torch.float32)
    merged_lse = torch.full((total_q, q.shape[1]), float("-inf"), device=q.device, dtype=torch.float32)
    covered = torch.zeros((total_q,), device=q.device, dtype=torch.bool)
    max_segments = max(len(segments) for _, _, _, segments in groups)

    for segment_idx in range(max_segments):
        for attn_type in (0, 1):
            entries = []
            for q_start, q_end, group_type, segments in groups:
                if group_type != attn_type or segment_idx >= len(segments):
                    continue
                k_start, k_end = segments[segment_idx]
                if k_end > k_start:
                    entries.append((q_start, q_end, [(k_start, k_end)]))
            if not entries:
                continue
            out_pass, lse_pass = _flash_varlen_pass(
                flash_attn_varlen, q, k, v, entries,
                causal=attn_type == 1, softmax_scale=softmax_scale, with_lse=True,
            )
            if lse_pass is None:
                return None
            cursor = 0
            for q_start, q_end, _ in entries:
                q_len = q_end - q_start
                _merge_partial_attention(
                    merged,
                    merged_lse,
                    q_start,
                    q_end,
                    out_pass[cursor:cursor + q_len].float(),
                    lse_pass[cursor:cursor + q_len],
                )
                covered[q_start:q_end] = True
                cursor += q_len

    if not bool(covered.all().item()):
        return None
    return merged.to(dtype=q.dtype)
310
311
def _max_seqlen_from_cu_seqlens(cu_seqlens) -> int:
    """Return the largest ``cu_seqlens[i + 1] - cu_seqlens[i]`` as an int (0 if no sequences)."""
    if cu_seqlens.numel() < 2:
        return 0
    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).detach()
    return int(lengths.max().item())


def range_attention(
    q,
    k,
    v,
    q_ranges,
    k_ranges,
    attn_type_map,
    softmax_scale: float,
    *,
    segment_offsets=None,
    group_q_ranges=None,
    group_attn_type_map=None,
    max_q_len=None,
    max_k_len=None,
    flash_cu_seqlens_q=None,
    flash_cu_seqlens_k=None,
    flash_causal=None,
    disjoint_q_ranges=None,
):
    """Run sparse range attention through FlashAttention varlen.

    Args:
        q, k, v: packed CUDA tensors ``[total, heads, head_dim]``.
        q_ranges, k_ranges, attn_type_map: per-entry range plan. It is
            coalesced into query groups when any of ``segment_offsets`` /
            ``group_q_ranges`` / ``group_attn_type_map`` is missing.
        softmax_scale: scale applied to ``q @ k^T``.
        segment_offsets, group_q_ranges, group_attn_type_map: optional
            precomputed grouped plan.
        max_q_len, max_k_len: max sequence lengths for the cu_seqlens fast path.
            When None, they are derived from the cu_seqlens themselves. They
            must be true upper bounds, since FlashAttention relies on them when
            sizing its launch.
        flash_cu_seqlens_q, flash_cu_seqlens_k, flash_causal: a ready-made
            varlen layout. When all three are given, ``LA_FLASH_FASTPATH`` is
            enabled and q/k/v share an fp16/bf16 dtype, a single FlashAttention
            call runs with them.
        disjoint_q_ranges: accepted for API compatibility; unused.

    Returns:
        Attention output ``[total_q, num_q_heads, head_dim]`` in ``q.dtype``.

    Raises:
        RuntimeError: for non-CUDA inputs, for attn_type values other than
            0/1, or when the plan cannot be expressed with FlashAttention varlen.
    """
    del disjoint_q_ranges
    if not q.is_cuda:
        raise RuntimeError("LA Flash range_attention requires CUDA tensors")
    if segment_offsets is None or group_q_ranges is None or group_attn_type_map is None:
        (
            group_q_ranges,
            k_ranges,
            segment_offsets,
            group_attn_type_map,
            _,
            _,
        ) = _coalesce_query_groups(q_ranges, k_ranges, attn_type_map)

    use_fastpath = (
        flash_cu_seqlens_q is not None
        and flash_cu_seqlens_k is not None
        and flash_causal is not None
        and _flash_fastpath_enabled()
        and q.dtype in (torch.float16, torch.bfloat16)
        and k.dtype == q.dtype
        and v.dtype == q.dtype
    )
    if use_fastpath:
        flash_attn_varlen = _load_flash_attn_varlen()
        # Derive the length hints from the cu_seqlens actually handed to
        # FlashAttention. Per-range maxima can underestimate sequences that
        # pack several key segments.
        if max_q_len is None:
            max_q_len = _max_seqlen_from_cu_seqlens(flash_cu_seqlens_q)
        if max_k_len is None:
            max_k_len = _max_seqlen_from_cu_seqlens(flash_cu_seqlens_k)
        out = flash_attn_varlen(
            q.contiguous(),
            k.contiguous(),
            v.contiguous(),
            flash_cu_seqlens_q.to(device=q.device, dtype=torch.int32).contiguous(),
            flash_cu_seqlens_k.to(device=q.device, dtype=torch.int32).contiguous(),
            int(max_q_len),
            int(max_k_len),
            dropout_p=0.0,
            softmax_scale=float(softmax_scale),
            causal=bool(flash_causal),
        )
        # Some flash_attn builds return (out, lse, ...); callers expect a tensor.
        return out[0] if isinstance(out, tuple) else out

    segment_out = _try_flash_segment_merge(
        q,
        k,
        v,
        k_ranges,
        segment_offsets,
        group_q_ranges,
        group_attn_type_map,
        softmax_scale,
    )
    if segment_out is not None:
        return segment_out

    raise RuntimeError(
        "LA Flash could not express this range plan with FlashAttention varlen. "
        "Only attn_type 0/1 range plans are supported in the release path."
    )
395