kernel_utils/range_attention.py
| 1 | """Sparse LocateAnything attention implemented with FlashAttention varlen. |
| 2 | |
| 3 | The public API accepts flattened query/key/value tensors: |
| 4 | |
| 5 | q: [total_q, num_q_heads, head_dim] |
| 6 | k: [total_k, num_kv_heads, head_dim] |
| 7 | v: [total_k, num_kv_heads, head_dim] |
| 8 | |
| 9 | and a Magi-style range plan: |
| 10 | |
| 11 | q_ranges: [num_ranges, 2] |
| 12 | k_ranges: [num_key_segments, 2] |
| 13 | segment_offsets: [num_query_groups + 1] |
| 14 | attn_type_map: |
| 15 | 0 = full attention over the listed key segment(s) |
| 16 | 1 = bottom-right causal attention |
| 17 | |
| 18 | For LocateAnything hybrid MTP decode, batch_utils represents the window as a |
| 19 | causal prefix plus full-attention sparse window segments. This module packs |
| 20 | those visible KV segments and calls FlashAttention varlen, avoiding dense masks. |
| 21 | """ |
| 22 | from __future__ import annotations |
| 23 | |
| 24 | import os |
| 25 | from typing import Optional |
| 26 | |
| 27 | import torch |
| 28 | |
| 29 | |
| 30 | _FLASH_ATTN_VARLEN = None |
| 31 | _FLASH_ATTN_ERROR: Optional[BaseException] = None |
| 32 | |
| 33 | |
| 34 | def _env_enabled(name: str, default: str = "auto") -> bool: |
| 35 | value = os.environ.get(name, default).strip().lower() |
| 36 | return value in {"", "auto", "1", "on", "true", "yes", "force"} |
| 37 | |
| 38 | |
def is_available() -> bool:
    """Report whether FlashAttention's varlen kernel can be loaded.

    Any ``Exception`` raised by the loader (for example a failed import of
    ``flash_attn``) is reported as ``False`` instead of propagating.
    """
    try:
        _load_flash_attn_varlen()
    except Exception:
        return False
    return True
| 45 | |
| 46 | |
def _flash_fastpath_enabled() -> bool:
    """Whether range_attention may use caller-supplied cu_seqlens directly.

    Controlled by the ``LA_FLASH_FASTPATH`` environment variable (default
    ``auto``, i.e. enabled).
    """
    return _env_enabled(name="LA_FLASH_FASTPATH", default="auto")
| 49 | |
| 50 | |
def _flash_segment_fastpath_enabled() -> bool:
    """Whether the per-segment FlashAttention evaluation path is allowed.

    Controlled by the ``LA_FLASH_SEGMENT_FASTPATH`` environment variable
    (default ``auto``, i.e. enabled).
    """
    return _env_enabled(name="LA_FLASH_SEGMENT_FASTPATH", default="auto")
| 53 | |
| 54 | |
def _load_flash_attn_varlen():
    """Return ``flash_attn.flash_attn_varlen_func``, importing it on first use.

    A successful import is cached in ``_FLASH_ATTN_VARLEN``. A failed import
    is cached in ``_FLASH_ATTN_ERROR`` and re-raised by every later call, so
    a missing or broken install is only probed once.

    Only ``Exception`` subclasses are cached. A ``KeyboardInterrupt`` or
    ``SystemExit`` raised while the (potentially slow) import is running
    propagates without poisoning the cache, so a later call can retry.

    Raises:
        Exception: whatever ``from flash_attn import flash_attn_varlen_func``
            raised, on this call or on an earlier one.
    """
    global _FLASH_ATTN_VARLEN, _FLASH_ATTN_ERROR
    if _FLASH_ATTN_VARLEN is not None:
        return _FLASH_ATTN_VARLEN
    if _FLASH_ATTN_ERROR is not None:
        raise _FLASH_ATTN_ERROR
    try:
        from flash_attn import flash_attn_varlen_func
    except Exception as exc:
        _FLASH_ATTN_ERROR = exc
        raise
    _FLASH_ATTN_VARLEN = flash_attn_varlen_func
    return _FLASH_ATTN_VARLEN
| 69 | |
| 70 | |
| 71 | def _coalesce_query_groups(q_ranges, k_ranges, attn_type_map): |
| 72 | """Group consecutive entries that share the same query span and mask type.""" |
| 73 | if q_ranges.numel() == 0: |
| 74 | segment_offsets = torch.zeros((1,), dtype=torch.int32, device=q_ranges.device) |
| 75 | return q_ranges, k_ranges, segment_offsets, attn_type_map, 0, 0 |
| 76 | |
| 77 | q_cpu = q_ranges.detach().to(device="cpu", dtype=torch.int32).contiguous() |
| 78 | t_cpu = attn_type_map.detach().to(device="cpu", dtype=torch.int32).contiguous() |
| 79 | grouped_q = [] |
| 80 | grouped_t = [] |
| 81 | offsets = [0] |
| 82 | max_q_len = 0 |
| 83 | last_q = None |
| 84 | last_t = None |
| 85 | for idx, (qr, attn_type) in enumerate(zip(q_cpu.tolist(), t_cpu.tolist())): |
| 86 | key = (int(qr[0]), int(qr[1])) |
| 87 | attn_type = int(attn_type) |
| 88 | if attn_type not in (0, 1): |
| 89 | raise RuntimeError( |
| 90 | "LA Flash path only supports FlashAttention-compatible attn_type 0/1. " |
| 91 | f"Got attn_type={attn_type}; regenerate a type 0/1 range plan." |
| 92 | ) |
| 93 | if last_q is None: |
| 94 | grouped_q.append([key[0], key[1]]) |
| 95 | grouped_t.append(attn_type) |
| 96 | max_q_len = max(max_q_len, key[1] - key[0]) |
| 97 | last_q = key |
| 98 | last_t = attn_type |
| 99 | continue |
| 100 | if key == last_q and attn_type == last_t: |
| 101 | continue |
| 102 | offsets.append(idx) |
| 103 | grouped_q.append([key[0], key[1]]) |
| 104 | grouped_t.append(attn_type) |
| 105 | max_q_len = max(max_q_len, key[1] - key[0]) |
| 106 | last_q = key |
| 107 | last_t = attn_type |
| 108 | offsets.append(int(q_ranges.shape[0])) |
| 109 | |
| 110 | k_cpu = k_ranges.detach().to(device="cpu", dtype=torch.int32).contiguous() |
| 111 | max_k_len = max((int(end) - int(start) for start, end in k_cpu.tolist()), default=0) |
| 112 | |
| 113 | return ( |
| 114 | torch.tensor(grouped_q, dtype=torch.int32, device=q_ranges.device).contiguous(), |
| 115 | k_ranges, |
| 116 | torch.tensor(offsets, dtype=torch.int32, device=q_ranges.device).contiguous(), |
| 117 | torch.tensor(grouped_t, dtype=torch.int32, device=q_ranges.device).contiguous(), |
| 118 | int(max_q_len), |
| 119 | int(max_k_len), |
| 120 | ) |
| 121 | |
| 122 | |
| 123 | def _flash_lse_to_tq_h(lse, total_q, q_lengths=None): |
| 124 | if lse is None: |
| 125 | return None |
| 126 | if lse.dim() != 2: |
| 127 | if lse.dim() == 3 and q_lengths is not None and lse.shape[0] == len(q_lengths): |
| 128 | chunks = [] |
| 129 | for idx, q_len in enumerate(q_lengths): |
| 130 | q_len = int(q_len) |
| 131 | if lse.shape[1] == 0 or q_len > lse.shape[2]: |
| 132 | return None |
| 133 | chunks.append(lse[idx, :, :q_len].transpose(0, 1).contiguous()) |
| 134 | merged = torch.cat(chunks, dim=0).float() |
| 135 | return merged if merged.shape[0] == total_q else None |
| 136 | return None |
| 137 | if lse.shape[0] == total_q: |
| 138 | return lse.float() |
| 139 | if lse.shape[1] == total_q: |
| 140 | return lse.transpose(0, 1).contiguous().float() |
| 141 | return None |
| 142 | |
| 143 | |
| 144 | def _make_cu_seqlens(lengths, device): |
| 145 | return torch.tensor([0] + list(torch.tensor(lengths).cumsum(0).tolist()), device=device, dtype=torch.int32) |
| 146 | |
| 147 | |
def _collect_segment_groups(k_ranges, segment_offsets, group_q_ranges, group_attn_type_map):
    """Decode a grouped plan into ``(q_start, q_end, attn_type, segments)`` tuples, or None if invalid."""
    group_rows = group_q_ranges.detach().to(device="cpu", dtype=torch.int32).contiguous().tolist()
    key_rows = k_ranges.detach().to(device="cpu", dtype=torch.int32).contiguous().tolist()
    offsets = segment_offsets.detach().to(device="cpu", dtype=torch.int32).contiguous().tolist()
    types = group_attn_type_map.detach().to(device="cpu", dtype=torch.int32).contiguous().tolist()

    groups = []
    for group_idx, (q_start, q_end) in enumerate(group_rows):
        attn_type = int(types[group_idx])
        if attn_type not in (0, 1):
            return None
        seg_start = int(offsets[group_idx])
        seg_end = int(offsets[group_idx + 1])
        if seg_end <= seg_start or q_end <= q_start:
            return None
        segments = [(int(a), int(b)) for a, b in key_rows[seg_start:seg_end]]
        if not segments:
            # Offsets point past the end of k_ranges; the group has no keys to attend to.
            return None
        groups.append((int(q_start), int(q_end), attn_type, segments))
    return groups


def _q_spans_disjoint(groups):
    """Return True when no query row belongs to more than one group."""
    spans = sorted((q_start, q_end) for q_start, q_end, _, _ in groups)
    return all(prev_end <= start for (_, prev_end), (start, _) in zip(spans, spans[1:]))


def _flash_varlen_over_groups(
    flash_attn_varlen, q, k, v, batch, softmax_scale, causal, return_attn_probs=False
):
    """Run one packed varlen call; each ``(q_start, q_end, segments)`` item in ``batch`` is one sequence.

    An item's queries attend to the concatenation of its key/value segments.
    Returns ``(result, q_lengths)``, where ``result`` is whatever
    ``flash_attn_varlen`` returned.
    """
    q_slices, k_slices, v_slices, q_lengths, k_lengths = [], [], [], [], []
    for q_start, q_end, segments in batch:
        q_slices.append(q[q_start:q_end])
        if len(segments) == 1:
            k_start, k_end = segments[0]
            k_slices.append(k[k_start:k_end])
            v_slices.append(v[k_start:k_end])
        else:
            k_slices.append(torch.cat([k[start:end] for start, end in segments], dim=0))
            v_slices.append(torch.cat([v[start:end] for start, end in segments], dim=0))
        q_lengths.append(q_end - q_start)
        k_lengths.append(sum(end - start for start, end in segments))

    extra = {"return_attn_probs": True} if return_attn_probs else {}
    result = flash_attn_varlen(
        torch.cat(q_slices, dim=0).contiguous(),
        torch.cat(k_slices, dim=0).contiguous(),
        torch.cat(v_slices, dim=0).contiguous(),
        _make_cu_seqlens(q_lengths, q.device),
        _make_cu_seqlens(k_lengths, q.device),
        int(max(q_lengths)),
        int(max(k_lengths)),
        dropout_p=0.0,
        softmax_scale=float(softmax_scale),
        causal=bool(causal),
        **extra,
    )
    return result, q_lengths


def _fold_partial_attention(merged, merged_lse, q_start, q_end, out_seg, lse_seg):
    """Fold one partial attention result into the running float32 (output, LSE) accumulators in place."""
    neg_inf = float("-inf")
    # FlashAttention may report a non-finite LSE (+inf or -inf depending on the
    # build) for rows that see no keys; such rows must contribute nothing
    # instead of producing inf - inf = NaN below.
    lse_seg = torch.where(torch.isfinite(lse_seg), lse_seg, torch.full_like(lse_seg, neg_inf))
    old_lse = merged_lse[q_start:q_end]
    new_lse = torch.maximum(old_lse, lse_seg)
    # When both sides are -inf, shift by 0 so exp() sees -inf rather than NaN.
    shift = torch.where(torch.isfinite(new_lse), new_lse, torch.zeros_like(new_lse))
    old_w = torch.exp(old_lse - shift)
    seg_w = torch.exp(lse_seg - shift)
    total_w = old_w + seg_w
    denom = total_w.clamp_min(1e-20)
    merged[q_start:q_end] = (
        merged[q_start:q_end] * old_w.unsqueeze(-1)
        + out_seg.float() * seg_w.unsqueeze(-1)
    ) / denom.unsqueeze(-1)
    merged_lse[q_start:q_end] = torch.where(
        total_w > 0, shift + torch.log(denom), torch.full_like(shift, neg_inf)
    )


def _packed_single_pass(flash_attn_varlen, q, k, v, groups, softmax_scale):
    """Evaluate every group with one varlen call per mask type (requires disjoint query spans).

    Returns the output in q.dtype, or None if some query row is not covered.
    """
    merged = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=q.dtype)
    covered = torch.zeros((q.shape[0],), device=q.device, dtype=torch.bool)
    for attn_type in (0, 1):
        batch = [
            (q_start, q_end, segments)
            for q_start, q_end, group_type, segments in groups
            if group_type == attn_type
        ]
        if not batch:
            continue
        out_pass, _ = _flash_varlen_over_groups(
            flash_attn_varlen, q, k, v, batch, softmax_scale, causal=attn_type == 1
        )
        if isinstance(out_pass, tuple):
            out_pass = out_pass[0]
        cursor = 0
        for q_start, q_end, _ in batch:
            q_len = q_end - q_start
            merged[q_start:q_end] = out_pass[cursor:cursor + q_len]
            covered[q_start:q_end] = True
            cursor += q_len
    return merged if bool(covered.all().item()) else None


def _lse_merged_passes(flash_attn_varlen, q, k, v, groups, softmax_scale):
    """Attend to key segments one at a time and merge the partial results by log-sum-exp.

    Pass ``i`` attends every group to its ``i``-th key segment (one varlen call
    per mask type). This is exact for query rows shared by several groups and
    for causal groups with several segments. Returns the output in q.dtype, or
    None when the LSE is unavailable/unreadable or a query row is never covered.
    """
    num_rows, num_heads, head_dim = q.shape[0], q.shape[1], q.shape[2]
    merged = torch.zeros((num_rows, num_heads, head_dim), device=q.device, dtype=torch.float32)
    merged_lse = torch.full((num_rows, num_heads), -float("inf"), device=q.device, dtype=torch.float32)
    covered = torch.zeros((num_rows,), device=q.device, dtype=torch.bool)
    max_segments = max(len(segments) for _, _, _, segments in groups)

    for segment_idx in range(max_segments):
        for attn_type in (0, 1):
            batch = []
            for q_start, q_end, group_type, segments in groups:
                if group_type != attn_type or segment_idx >= len(segments):
                    continue
                k_start, k_end = segments[segment_idx]
                if k_end <= k_start:
                    continue
                batch.append((q_start, q_end, [(k_start, k_end)]))
            if not batch:
                continue

            result, q_lengths = _flash_varlen_over_groups(
                flash_attn_varlen,
                q,
                k,
                v,
                batch,
                softmax_scale,
                causal=attn_type == 1,
                return_attn_probs=True,
            )
            if not isinstance(result, tuple) or len(result) < 2:
                return None
            out_pass = result[0]
            lse_pass = _flash_lse_to_tq_h(result[1], out_pass.shape[0], q_lengths)
            if lse_pass is None:
                return None

            cursor = 0
            for q_start, q_end, _ in batch:
                q_len = q_end - q_start
                _fold_partial_attention(
                    merged,
                    merged_lse,
                    q_start,
                    q_end,
                    out_pass[cursor:cursor + q_len],
                    lse_pass[cursor:cursor + q_len],
                )
                covered[q_start:q_end] = True
                cursor += q_len

    if not bool(covered.all().item()):
        return None
    return merged.to(dtype=q.dtype)


def _try_flash_segment_merge(
    q,
    k,
    v,
    k_ranges,
    segment_offsets,
    group_q_ranges,
    group_attn_type_map,
    softmax_scale,
):
    """Evaluate a grouped range plan with FlashAttention varlen, or return None.

    Group ``g`` has query rows ``group_q_ranges[g]`` and mask
    ``group_attn_type_map[g]`` (0 = full, 1 = flash_attn ``causal=True``).
    It attends to the key segments
    ``k_ranges[segment_offsets[g]:segment_offsets[g + 1]]``.

    Choice of strategy:
      * When every query row belongs to exactly one group and every causal
        group has a single key segment, all groups are packed into one varlen
        call per mask type.
      * Otherwise, including the hybrid causal-prefix + full-window plan where
        two groups share the same query rows, each key segment is attended
        separately and the partial outputs are combined by log-sum-exp.
        Packing overlapping groups would let the later pass overwrite the
        earlier one.

    Returns:
        Output of shape ``q.shape`` in ``q.dtype``, or None when the path is
        disabled via ``LA_FLASH_SEGMENT_FASTPATH``, the tensors are not
        fp16/bf16 with matching dtypes, the plan is missing or invalid, or
        FlashAttention's result cannot be merged.

    Raises:
        Exception: whatever ``_load_flash_attn_varlen`` raises when flash_attn
            cannot be imported.
    """
    if not _flash_segment_fastpath_enabled():
        return None
    if q.dtype not in (torch.float16, torch.bfloat16) or k.dtype != q.dtype or v.dtype != q.dtype:
        return None
    if group_q_ranges is None or segment_offsets is None or group_attn_type_map is None:
        return None

    flash_attn_varlen = _load_flash_attn_varlen()
    groups = _collect_segment_groups(k_ranges, segment_offsets, group_q_ranges, group_attn_type_map)
    if not groups:
        return None

    packable = _q_spans_disjoint(groups) and all(
        attn_type == 0 or len(segments) == 1 for _, _, attn_type, segments in groups
    )
    if packable:
        packed = _packed_single_pass(flash_attn_varlen, q, k, v, groups, softmax_scale)
        if packed is not None:
            return packed

    return _lse_merged_passes(flash_attn_varlen, q, k, v, groups, softmax_scale)
| 310 | |
| 311 | |
def _max_seqlen_from_cu_seqlens(cu_seqlens):
    """Return the longest sequence length described by a cumulative-length tensor (0 if none)."""
    if cu_seqlens.numel() < 2:
        return 0
    return int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())


def range_attention(
    q,
    k,
    v,
    q_ranges,
    k_ranges,
    attn_type_map,
    softmax_scale: float,
    *,
    segment_offsets=None,
    group_q_ranges=None,
    group_attn_type_map=None,
    max_q_len=None,
    max_k_len=None,
    flash_cu_seqlens_q=None,
    flash_cu_seqlens_k=None,
    flash_causal=None,
    disjoint_q_ranges=None,
):
    """Run sparse range attention through FlashAttention varlen.

    Args:
        q, k, v: flattened ``[total, heads, head_dim]`` CUDA tensors.
        q_ranges, k_ranges, attn_type_map: the per-entry range plan
            (mask type 0 = full, 1 = causal).
        softmax_scale: scale applied to QK^T.
        segment_offsets, group_q_ranges, group_attn_type_map: optional
            pre-grouped plan. If any is missing, the plan is regrouped with
            ``_coalesce_query_groups``.
        max_q_len, max_k_len: ``max_seqlen_q/k`` for the cu_seqlens fast path.
            When omitted they are derived from ``flash_cu_seqlens_q/k``, the
            lengths FlashAttention actually sees, not from the range plan.
        flash_cu_seqlens_q, flash_cu_seqlens_k, flash_causal: when all are
            given, ``LA_FLASH_FASTPATH`` is enabled and q/k/v share an fp16/bf16
            dtype, q/k/v are passed straight to one varlen call with these
            cu_seqlens.
        disjoint_q_ranges: accepted for interface compatibility; unused.

    Returns:
        The attention output tensor from FlashAttention.

    Raises:
        RuntimeError: for non-CUDA inputs, an invalid plan, or a plan the
            segment path cannot express.
    """
    del disjoint_q_ranges
    if not q.is_cuda:
        raise RuntimeError("LA Flash range_attention requires CUDA tensors")
    if segment_offsets is None or group_q_ranges is None or group_attn_type_map is None:
        # Also validates the plan (attn types 0/1 only) before any kernel runs.
        (
            group_q_ranges,
            k_ranges,
            segment_offsets,
            group_attn_type_map,
            _,
            _,
        ) = _coalesce_query_groups(q_ranges, k_ranges, attn_type_map)

    if (
        flash_cu_seqlens_q is not None
        and flash_cu_seqlens_k is not None
        and flash_causal is not None
        and _flash_fastpath_enabled()
        and q.dtype in (torch.float16, torch.bfloat16)
        and k.dtype == q.dtype
        and v.dtype == q.dtype
    ):
        flash_attn_varlen = _load_flash_attn_varlen()
        cu_seqlens_q = flash_cu_seqlens_q.to(device=q.device, dtype=torch.int32).contiguous()
        cu_seqlens_k = flash_cu_seqlens_k.to(device=q.device, dtype=torch.int32).contiguous()
        # flash_attn requires max_seqlen_* to bound the lengths in cu_seqlens_*.
        if max_q_len is None:
            max_q_len = _max_seqlen_from_cu_seqlens(cu_seqlens_q)
        if max_k_len is None:
            max_k_len = _max_seqlen_from_cu_seqlens(cu_seqlens_k)
        out = flash_attn_varlen(
            q.contiguous(),
            k.contiguous(),
            v.contiguous(),
            cu_seqlens_q,
            cu_seqlens_k,
            int(max_q_len),
            int(max_k_len),
            dropout_p=0.0,
            softmax_scale=float(softmax_scale),
            causal=bool(flash_causal),
        )
        # Some flash_attn builds return (out, lse, ...); keep only the output.
        return out[0] if isinstance(out, tuple) else out

    segment_out = _try_flash_segment_merge(
        q,
        k,
        v,
        k_ranges,
        segment_offsets,
        group_q_ranges,
        group_attn_type_map,
        softmax_scale,
    )
    if segment_out is not None:
        return segment_out

    raise RuntimeError(
        "LA Flash could not express this range plan with FlashAttention varlen. "
        "Only attn_type 0/1 range plans are supported in the release path."
    )
| 394 | ) |
| 395 | |