# LA Flash Utils

This folder contains the sparse attention utilities used when
`LA_FLASH_ATTN=la_flash` is set. The release path runs FlashAttention varlen
kernels over LocateAnything range plans; it does not include or build a local
C++/CUDA extension.

## Features

- Supports batched LocateAnything hybrid MTP inference on A100, RTX 4090, and H100.
- Consumes Magi-style `q_ranges`, `k_ranges`, `segment_offsets`, and
  `attn_type_map` plans generated by `batch_utils.hybrid_runtime`.
- Uses FlashAttention varlen for packed causal/full plans.
- Packs LocateAnything MTP full-window key segments before calling
  FlashAttention, avoiding dense `[B,H,Q,K]` masks.
- Supports log-sum-exp merging for compatible non-packed multi-segment plans.

## Attention Types

The release path intentionally supports only FlashAttention-compatible plan
types:

| Value | Meaning |
| --- | --- |
| `0` | Full attention over the listed key segment or packed key segments. |
| `1` | Bottom-right causal attention: the causal diagonal is aligned so that the last query token sees every key in the segment. |
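To make the two types concrete, here is a minimal sketch of the visibility rule for a single `(q_range, k_range)` pair, using local 0-based offsets. `visible` and `mask_rows` are hypothetical helpers, and the sketch assumes type `1` uses the same bottom-right alignment as FlashAttention 2's `causal=True` when the query block is shorter than the key block (the decode case). The dense rendering is only for inspection; LA Flash itself does not build such a mask.

```python
def visible(i: int, j: int, q_len: int, k_len: int, attn_type: int) -> bool:
    """Can local query i attend local key j within one (q_range, k_range) pair?"""
    if attn_type == 0:  # full: every key in the segment is visible
        return True
    if attn_type == 1:  # bottom-right causal: last query row lines up with last key
        return j <= i + (k_len - q_len)
    raise ValueError(f"unsupported attn_type: {attn_type}")


def mask_rows(q_len: int, k_len: int, attn_type: int) -> list[str]:
    """Render the pattern as strings: '1' = visible, '.' = masked (inspection only)."""
    return [
        "".join("1" if visible(i, j, q_len, k_len, attn_type) else "." for j in range(k_len))
        for i in range(q_len)
    ]


if __name__ == "__main__":
    # Decode-like case: 2 new queries against 4 keys.
    for row in mask_rows(q_len=2, k_len=4, attn_type=1):
        print(row)
```

With `q_len=2` and `k_len=4` this prints `111.` and `1111`: the first query sees keys 0 to 2 and the second sees all four. If `q_len > k_len`, the leading query rows see no keys at all under this rule.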

## How It Works

`batch_utils.hybrid_runtime` builds sparse range plans for the text decoder.
Each plan describes which query token intervals attend to which key/value token
intervals. `kernel_utils.range_attention` executes those plans with
FlashAttention instead of materializing dense SDPA masks.
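A hypothetical in-memory form of such a plan, grouped by query interval, is sketched below. `RangePlan`, `group_by_query`, and `classify` are illustrative names, not the `batch_utils` API; the sketch assumes half-open `[start, end)` ranges, leaves out `segment_offsets`, and uses a simplified rule for choosing between the packed-simple, packed-full-window, and LSE-merge paths.

```python
from collections import defaultdict
from dataclasses import dataclass

Range = tuple[int, int]  # half-open [start, end) token interval (assumed convention)


@dataclass
class RangePlan:
    """Hypothetical plan container; the real batch_utils plans may differ."""
    q_ranges: list[Range]
    k_ranges: list[Range]
    attn_type_map: list[int]  # 0 = full, 1 = bottom-right causal

    def __post_init__(self) -> None:
        if not (len(self.q_ranges) == len(self.k_ranges) == len(self.attn_type_map)):
            raise ValueError("q_ranges, k_ranges and attn_type_map must have equal length")


def group_by_query(plan: RangePlan) -> dict[Range, list[tuple[Range, int]]]:
    """Collect every (k_range, attn_type) that each query range attends to."""
    groups: dict[Range, list[tuple[Range, int]]] = defaultdict(list)
    for q, k, t in zip(plan.q_ranges, plan.k_ranges, plan.attn_type_map):
        groups[q].append((k, t))
    return dict(groups)


def classify(segments: list[tuple[Range, int]]) -> str:
    """Illustrative path choice for one query range."""
    if len(segments) == 1:
        return "packed_simple"
    if all(t == 0 for _, t in segments):
        return "packed_full_window"  # full windows can be concatenated
    return "lse_merge"  # mixed types: evaluate segments separately, then merge


if __name__ == "__main__":
    plan = RangePlan(
        q_ranges=[(0, 4), (4, 6), (4, 6)],
        k_ranges=[(0, 4), (0, 2), (4, 6)],
        attn_type_map=[1, 0, 0],
    )
    for q_range, segments in group_by_query(plan).items():
        print(q_range, classify(segments))
```

Grouping by query interval first is what makes the path decision local: each query block only needs to know how many key segments it has and of which types.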

The runtime has three execution paths:

- **Packed simple plans:** when each query range maps to one contiguous
  key/value range, LA Flash flattens the selected ranges, builds FlashAttention
  `cu_seqlens_q` / `cu_seqlens_k`, and calls `flash_attn_varlen_func` directly.
- **Packed MTP full-window plans:** for hybrid MTP decode, multiple full
  key/value windows for the same query block are concatenated into one packed
  key/value sequence before the FlashAttention call. This keeps the sparse
  memory profile without constructing a `[B,H,Q,K]` attention mask.
- **Compatible multi-segment plans:** when a query range attends to multiple
  segments that cannot be packed into one sequence, each segment is evaluated
  with FlashAttention and the partial outputs are merged with the standard
  log-sum-exp softmax composition.
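The two packed paths come down to building FlashAttention varlen inputs. The sketch below is a pure-Python illustration under assumed conventions (half-open ranges, one packed sequence per query range, a hypothetical `pack_varlen` helper). It gathers token indices and computes `cu_seqlens_q` / `cu_seqlens_k` as cumulative sums of per-sequence lengths. In a real call the offsets would be `torch.int32` tensors, and `q[q_index]`, `k[k_index]`, `v[k_index]` would be the packed inputs to `flash_attn_varlen_func`.

```python
from itertools import accumulate

Range = tuple[int, int]  # half-open [start, end) token interval (assumed convention)


def pack_varlen(groups: list[tuple[Range, list[Range]]], causal: bool = False) -> dict:
    """Flatten (q_range, [k_range, ...]) groups into varlen-style inputs.

    Each group becomes one packed sequence: its query tokens plus the
    concatenation of all of its key/value windows. Concatenating windows is
    only valid for full (type 0) attention, so causal groups need one k_range.
    """
    q_index: list[int] = []
    k_index: list[int] = []
    q_lens: list[int] = []
    k_lens: list[int] = []
    for (q_start, q_end), k_ranges in groups:
        if causal and len(k_ranges) != 1:
            raise ValueError("causal groups must map to exactly one key range")
        q_index.extend(range(q_start, q_end))
        q_lens.append(q_end - q_start)
        k_len = 0
        for k_start, k_end in k_ranges:
            k_index.extend(range(k_start, k_end))
            k_len += k_end - k_start
        k_lens.append(k_len)
    return {
        "q_index": q_index,  # rows to gather from the flat query tensor
        "k_index": k_index,  # rows to gather from the flat key/value tensors
        "cu_seqlens_q": [0, *accumulate(q_lens)],
        "cu_seqlens_k": [0, *accumulate(k_lens)],
        "max_seqlen_q": max(q_lens, default=0),
        "max_seqlen_k": max(k_lens, default=0),
        "causal": causal,
    }


if __name__ == "__main__":
    groups = [
        ((0, 3), [(0, 3)]),          # simple plan: one key range
        ((3, 5), [(0, 2), (6, 8)]),  # MTP-style: two full windows, concatenated
    ]
    print(pack_varlen(groups))
```

For this example the offsets come out as `cu_seqlens_q = [0, 3, 5]` and `cu_seqlens_k = [0, 3, 7]`. Keeping each query block's windows in one packed key sequence lets a single varlen call stand in for a masked dense call, at the cost of an index gather.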

The output tensor shape and dtype match the decoder attention output expected
by the model. This path is inference-oriented and depends on FlashAttention's
forward kernels; it is not a custom autograd training backend.
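The log-sum-exp composition used by the multi-segment path can be checked on a toy example. This is a minimal pure-Python sketch for one query with already-scaled scores; `attend` and `merge` are hypothetical helpers standing in for per-segment FlashAttention calls (FlashAttention 2's `flash_attn_varlen_func`, for example, can return the per-query `softmax_lse` when called with `return_attn_probs=True`). Each segment's normalized output is weighted by `exp(lse_s - lse)`, where `lse` is the log-sum-exp of the segment LSEs.

```python
import math

Vector = list[float]


def attend(scores: list[float], values: list[Vector]) -> tuple[Vector, float]:
    """Softmax attention of one query over one non-empty key segment.

    Returns the normalized output and the log-sum-exp of the segment's scores.
    Scores are assumed to already include the softmax scale.
    """
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, m + math.log(total)


def merge(partials: list[tuple[Vector, float]]) -> tuple[Vector, float]:
    """Combine per-segment (output, lse) pairs into the full-softmax result."""
    m = max(seg_lse for _, seg_lse in partials)
    lse = m + math.log(sum(math.exp(seg_lse - m) for _, seg_lse in partials))
    dim = len(partials[0][0])
    # Each segment's weight is its share of the total softmax mass: exp(lse_s - lse).
    out = [sum(math.exp(seg_lse - lse) * o[d] for o, seg_lse in partials) for d in range(dim)]
    return out, lse


if __name__ == "__main__":
    scores = [0.5, -1.0, 2.0, 0.1, 1.5]
    values = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 3.0], [0.5, 0.5]]
    full = attend(scores, values)
    split = merge([attend(scores[:2], values[:2]), attend(scores[2:], values[2:])])
    print(full)
    print(split)  # equal to `full` up to floating-point rounding
```

The merge is exact because `exp(lse_s - lse)` equals the segment's partition sum divided by the full partition sum. A segment with no visible keys (LSE of minus infinity) is not handled here and would need to be skipped.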

## Runtime Knobs

| Variable | Default | Meaning |
| --- | --- | --- |
| `LA_FLASH_ATTN` | `sdpa` | Set to `la_flash` to enable this backend through `batch_utils`. |
| `LA_FLASH_FASTPATH` | `auto` | Use FlashAttention varlen for packed simple plans. |
| `LA_FLASH_SEGMENT_FASTPATH` | `auto` | Use FlashAttention varlen for multi-segment sparse plans. Full segments are packed first; other compatible segments use LSE merging. |
| `LA_FLASH_PLAN_STATS` | `0` | Record sparse plan statistics in inference summaries. |
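An in-process setup for these knobs might look like the sketch below. It assumes `batch_utils` reads the variables from the environment, so it would need to run before the runtime builds its first plan; `enable_la_flash` and `LA_FLASH_ENV` are hypothetical names, and the `LA_FLASH_DENSE_BACKEND=sdpa` entry comes from the Notes section rather than this table. Exporting the same variables in the shell should be equivalent.

```python
import os
from collections.abc import MutableMapping

# Knob names are the ones documented in this README; values are what we request.
LA_FLASH_ENV = {
    "LA_FLASH_ATTN": "la_flash",          # enable the LA Flash backend
    "LA_FLASH_FASTPATH": "auto",          # documented default
    "LA_FLASH_SEGMENT_FASTPATH": "auto",  # documented default
    "LA_FLASH_PLAN_STATS": "1",           # record plan statistics (default is 0)
    "LA_FLASH_DENSE_BACKEND": "sdpa",     # keep dense prefill on SDPA, per the Notes
}


def enable_la_flash(env: MutableMapping[str, str] = os.environ) -> None:
    """Fill in the knobs without overriding values the user already exported."""
    for name, value in LA_FLASH_ENV.items():
        env.setdefault(name, value)


if __name__ == "__main__":
    enable_la_flash()
    print({name: os.environ[name] for name in LA_FLASH_ENV})
```

Using `setdefault` means an explicit `export LA_FLASH_PLAN_STATS=0` in the shell still wins over the values requested here.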

## Notes

Dense prefill and stock worker-style generation should keep
`LA_FLASH_DENSE_BACKEND=sdpa`; LA Flash is used for sparse range plans
produced by `batch_utils`.

This package is for inference and evaluation. Training remains on the
MagiAttention backend; the batched sparse-plan decode runtime does not support
the `labels` training path.

## Source Layout

- `range_attention.py`: FlashAttention varlen dispatch, sparse KV packing, LSE
  merge fallback, and availability checks.
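A rough availability probe in the spirit of those checks is sketched below. `flash_attn_varlen_available` is a hypothetical name, and the real checks in `range_attention.py` may also test for a CUDA device or a supported GPU architecture; this version only confirms that `flash_attn_varlen_func` can be imported from the `flash_attn` package.

```python
def flash_attn_varlen_available() -> bool:
    """Return True if flash_attn's varlen entry point can be imported."""
    try:
        from flash_attn import flash_attn_varlen_func  # noqa: F401
    except (ImportError, OSError):
        # Missing package, or a build whose CUDA libraries fail to load.
        return False
    return True


if __name__ == "__main__":
    print("FlashAttention varlen available:", flash_attn_varlen_available())
```

Importing inside the function keeps the module importable on machines without FlashAttention, so a caller can fall back to SDPA instead of failing at import time.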
| 76 | |