# LA Flash Utils

This folder contains the sparse attention utilities used when
`LA_FLASH_ATTN=la_flash` is set. The release path runs FlashAttention's varlen
kernels over LocateAnything range plans; it does not include or build a local
C++/CUDA extension.

## Features

- Supports batched LocateAnything hybrid MTP inference on A100, RTX 4090, and H100.
- Consumes Magi-style `q_ranges`, `k_ranges`, `segment_offsets`, and
  `attn_type_map` plans generated by `batch_utils.hybrid_runtime`.
- Uses FlashAttention varlen for packed causal/full plans.
- Packs LocateAnything MTP full-window key segments before calling
  FlashAttention, avoiding dense `[B,H,Q,K]` masks.
- Supports log-sum-exp merging for compatible non-packed multi-segment plans.
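
The varlen packing in the features above can be sketched in plain Python. This is a minimal sketch, not the code in `range_attention.py`: the `pack_ranges` helper, the half-open `(start, end)` range convention, and passing one list of key segments per query range are all assumptions. Each query range becomes one varlen sequence, and all of its key windows are concatenated into the matching packed key sequence.

```python
from typing import List, Sequence, Tuple

Range = Tuple[int, int]  # assumed half-open [start, end) token interval


def pack_ranges(
    q_ranges: Sequence[Range],
    k_segments: Sequence[Sequence[Range]],
) -> Tuple[List[int], List[int], List[int], List[int], int, int]:
    """Flatten a plan into packed token indices plus varlen boundaries.

    Returns (q_index, k_index, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k).
    """
    if len(q_ranges) != len(k_segments):
        raise ValueError("expected one list of key segments per query range")
    q_index: List[int] = []
    k_index: List[int] = []
    cu_q, cu_k = [0], [0]
    for (q_start, q_end), segments in zip(q_ranges, k_segments):
        q_index.extend(range(q_start, q_end))
        for k_start, k_end in segments:
            # Several full windows for one query block become one key sequence.
            k_index.extend(range(k_start, k_end))
        cu_q.append(len(q_index))
        cu_k.append(len(k_index))
    max_q = max((b - a for a, b in zip(cu_q, cu_q[1:])), default=0)
    max_k = max((b - a for a, b in zip(cu_k, cu_k[1:])), default=0)
    return q_index, k_index, cu_q, cu_k, max_q, max_k


if __name__ == "__main__":
    # Block 1 sees keys [0, 4); block 2 sees two full windows, [0, 3) and [6, 8).
    print(pack_ranges([(0, 2), (2, 5)], [[(0, 4)], [(0, 3), (6, 8)]]))
    # ([0, 1, 2, 3, 4], [0, 1, 2, 3, 0, 1, 2, 6, 7], [0, 2, 5], [0, 4, 9], 3, 5)
```

With PyTorch, the index lists would gather the packed rows and the boundary lists would become `torch.int32` tensors passed as `cu_seqlens_q` / `cu_seqlens_k` to `flash_attn_varlen_func`. Concatenating windows is only safe for full (type `0`) segments: full attention does not depend on key order, whereas a causal mask would shift when windows are concatenated.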
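
## Attention Types

The release path intentionally supports only FlashAttention-compatible plan
types in `attn_type_map`:

| `attn_type_map` value | Meaning |
| --- | --- |
| `0` | Full attention over the listed key segment or packed key segments. |
| `1` | Bottom-right causal attention: the causal diagonal is aligned to the last query and last key of the block, matching FlashAttention's `causal=True` behavior when query and key lengths differ. |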
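
The bottom-right rule for type `1` is the one FlashAttention applies for `causal=True` when query and key lengths differ: query row `i` may see key `j` only if `j <= i + (seqlen_k - seqlen_q)`, so the last query row lines up with the last key. The hypothetical helper below only spells that rule out as a boolean mask for illustration; the backend itself never builds such a mask.

```python
from typing import List


def bottom_right_causal_mask(seqlen_q: int, seqlen_k: int) -> List[List[bool]]:
    """mask[i][j] is True when query row i may attend to key j."""
    offset = seqlen_k - seqlen_q  # shift that aligns the diagonal to the bottom-right
    return [[j <= i + offset for j in range(seqlen_k)] for i in range(seqlen_q)]


if __name__ == "__main__":
    # Two new query tokens over five keys: the last query sees every key.
    for row in bottom_right_causal_mask(2, 5):
        print([int(x) for x in row])
    # [1, 1, 1, 1, 0]
    # [1, 1, 1, 1, 1]
```

This alignment is what lets a short block of decode queries see the whole key/value prefix in front of it.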

## How It Works

`batch_utils.hybrid_runtime` builds sparse range plans for the text decoder.
Each plan describes which query token intervals attend to which key/value token
intervals. `kernel_utils.range_attention` executes those plans with
FlashAttention instead of materializing dense SDPA masks.
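
To make the plan semantics concrete, the sketch below expands a plan into the dense boolean mask that LA Flash avoids building. It is a hypothetical reference for checking small plans by hand, not part of the package: it assumes `q_ranges` and `k_ranges` are parallel lists of half-open `(start, end)` token intervals with one `attn_type_map` entry per pair, and it ignores batching and `segment_offsets`.

```python
from typing import List, Sequence, Tuple

Range = Tuple[int, int]  # assumed half-open [start, end) token interval


def plan_to_dense_mask(
    q_ranges: Sequence[Range],
    k_ranges: Sequence[Range],
    attn_type_map: Sequence[int],
    total_q: int,
    total_k: int,
) -> List[List[bool]]:
    """Expand a range plan into a dense [total_q][total_k] boolean mask (reference only)."""
    if not (len(q_ranges) == len(k_ranges) == len(attn_type_map)):
        raise ValueError("q_ranges, k_ranges and attn_type_map must be parallel lists")
    mask = [[False] * total_k for _ in range(total_q)]
    for (qs, qe), (ks, ke), attn_type in zip(q_ranges, k_ranges, attn_type_map):
        if attn_type not in (0, 1):
            raise ValueError(f"unsupported attn_type {attn_type}")
        offset = (ke - ks) - (qe - qs)  # bottom-right alignment used by type 1
        for i in range(qs, qe):
            for j in range(ks, ke):
                # Type 0 keeps every pair; type 1 keeps the causal lower-right part.
                if attn_type == 0 or (j - ks) <= (i - qs) + offset:
                    mask[i][j] = True
    return mask


if __name__ == "__main__":
    # A causal entry for queries [0, 2) plus a full entry for query 2
    # together reproduce an ordinary 3x3 causal mask.
    for row in plan_to_dense_mask([(0, 2), (2, 3)], [(0, 2), (0, 3)], [1, 0], 3, 3):
        print([int(x) for x in row])
    # [1, 0, 0]
    # [1, 1, 0]
    # [1, 1, 1]
```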

The runtime follows three paths:

- **Packed simple plans:** when each query range maps to one contiguous
  key/value range, LA Flash flattens the selected ranges, builds FlashAttention
  `cu_seqlens_q` / `cu_seqlens_k`, and calls `flash_attn_varlen_func` directly.
- **Packed MTP full-window plans:** for hybrid MTP decode, multiple full
  key/value windows for the same query block are concatenated into one packed
  key/value sequence before the FlashAttention call. This keeps the sparse
  memory profile without constructing a `[B,H,Q,K]` attention mask.
- **Compatible multi-segment plans:** when a query range attends to multiple
  segments that cannot be packed as one sequence, each segment is evaluated with
  FlashAttention and the partial outputs are merged with the standard
  log-sum-exp softmax composition.
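
The log-sum-exp composition in the third path works because each segment's softmax can be renormalized once the global normalizer is known: with per-segment outputs `o_s` and log-normalizers `lse_s`, the merged row is `sum_s exp(lse_s - lse) * o_s`, where `lse = log(sum_s exp(lse_s))`. The pure-Python sketch below checks this for a single query row. The function names are hypothetical, and a real implementation would operate on batched tensors rather than lists; in the flash-attn 2.x Python API, `flash_attn_varlen_func(..., return_attn_probs=True)` also returns the per-row `softmax_lse` that such a merge needs.

```python
import math
from typing import List, Sequence, Tuple

Partial = Tuple[List[float], float]  # (normalized output row, log-sum-exp of scores)


def partial_attention(scores: Sequence[float], values: Sequence[Sequence[float]]) -> Partial:
    """Softmax-attend one query row over one key segment and return (output, lse)."""
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]
    return out, m + math.log(denom)


def merge_partials(partials: Sequence[Partial]) -> Partial:
    """Combine per-segment (output, lse) pairs into the full-softmax result."""
    m = max(lse for _, lse in partials)
    total_lse = m + math.log(sum(math.exp(lse - m) for _, lse in partials))
    out = [0.0] * len(partials[0][0])
    for seg_out, lse in partials:
        share = math.exp(lse - total_lse)  # fraction of softmax mass in this segment
        for d, x in enumerate(seg_out):
            out[d] += share * x
    return out, total_lse


if __name__ == "__main__":
    scores = [0.5, 1.0, -0.3, 2.0]
    values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
    merged, _ = merge_partials(
        [partial_attention(scores[:2], values[:2]), partial_attention(scores[2:], values[2:])]
    )
    full, _ = partial_attention(scores, values)
    print(all(math.isclose(a, b) for a, b in zip(merged, full)))  # True
```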

The output tensor shape and dtype match the decoder attention output expected
by the model. This path is inference-oriented and depends on FlashAttention's
forward kernels; it is not a custom autograd training backend.

## Runtime Knobs

| Variable | Default | Meaning |
| --- | --- | --- |
| `LA_FLASH_ATTN` | `sdpa` | Set to `la_flash` to enable this backend through `batch_utils`. |
| `LA_FLASH_FASTPATH` | `auto` | Use FlashAttention varlen for packed simple plans. |
| `LA_FLASH_SEGMENT_FASTPATH` | `auto` | Use FlashAttention varlen for multi-segment sparse plans. Full segments are packed first; other compatible segments use LSE merging. |
| `LA_FLASH_PLAN_STATS` | `0` | Record sparse plan statistics in inference summaries. |
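
A minimal launch-side sketch of these knobs in Python, assuming they are read from the environment when `batch_utils` configures its attention backend, so they must be set before that import (or exported before the process starts). Using `"1"` as the "on" value for `LA_FLASH_PLAN_STATS` is an assumption; `LA_FLASH_DENSE_BACKEND=sdpa` comes from the Notes section.

```python
import os

# Assumed launch settings; "1" for LA_FLASH_PLAN_STATS is an illustrative guess.
LA_FLASH_SETTINGS = {
    "LA_FLASH_ATTN": "la_flash",          # default "sdpa"; enables this backend
    "LA_FLASH_FASTPATH": "auto",          # varlen for packed simple plans
    "LA_FLASH_SEGMENT_FASTPATH": "auto",  # packing / LSE merge for multi-segment plans
    "LA_FLASH_PLAN_STATS": "1",           # assumed "on" value; default is "0"
    "LA_FLASH_DENSE_BACKEND": "sdpa",     # keep dense prefill on SDPA (see Notes)
}

for name, value in LA_FLASH_SETTINGS.items():
    # setdefault keeps any value already exported in the shell.
    os.environ.setdefault(name, value)

# Import batch_utils only after this point so the settings are visible to it.
```

Using `setdefault` lets an explicit `LA_FLASH_ATTN=sdpa` on the command line still override the script, which is convenient when comparing against the dense path.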

## Notes

Dense prefill and stock worker-style generation should keep
`LA_FLASH_DENSE_BACKEND=sdpa`; LA Flash is used for sparse range plans
produced by `batch_utils`.

This package is for inference and evaluation. Training remains on the
MagiAttention backend; the batched sparse-plan decode runtime does not support
the `labels` training path.

## Source Layout

- `range_attention.py`: FlashAttention varlen dispatch, sparse KV packing, LSE
  merge fallback, and availability checks.
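
The availability checks in `range_attention.py` are not spelled out in this README. A comparable check might look like the sketch below, where `flash_attn_available` and the fall-back-to-SDPA policy in `resolve_backend` are hypothetical. It relies only on `importlib.util.find_spec`, which locates the `flash_attn` package without importing it, and on `torch.cuda.is_available()`.

```python
import importlib.util


def flash_attn_available() -> bool:
    """Return True if flash_attn is installed and a CUDA device is visible."""
    if importlib.util.find_spec("flash_attn") is None:
        return False  # package not installed; nothing to import
    try:
        import torch
    except ImportError:
        return False
    return torch.cuda.is_available()


def resolve_backend(requested: str) -> str:
    """Hypothetical policy: keep "la_flash" only when FlashAttention is usable."""
    if requested == "la_flash" and not flash_attn_available():
        return "sdpa"
    return requested


if __name__ == "__main__":
    print(resolve_backend("la_flash"))  # "la_flash" on a CUDA machine with flash-attn, else "sdpa"
```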