inference/kernel.py
import torch
import tilelang
import tilelang.language as T
from typing import Tuple, Optional


tilelang.set_log_level("WARNING")

pass_configs = {
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
}

FP8 = "float8_e4m3"
BF16 = "bfloat16"
FP32 = "float32"


def fast_log2_ceil(x):
    bits_x = T.reinterpret("uint32", x)
    exp_x = (bits_x >> 23) & 0xFF
    man_bits = bits_x & ((1 << 23) - 1)
    return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))


def fast_pow2(x):
    bits_x = (x + 127) << 23
    return T.reinterpret("float32", bits_x)


def fast_round_scale(amax, fp8_max_inv):
    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))

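# Reference sketch, not part of the original kernels: the three bit-twiddling helpers
# above round a scale up to the next power of two, i.e. 2 ** ceil(log2(amax / fp8_max)).
# Assuming plain PyTorch ops are acceptable for validation, the same result can be
# computed as follows (the helper name `_round_scale_ref` is illustrative only):
def _round_scale_ref(amax: torch.Tensor, fp8_max: float = 448.0) -> torch.Tensor:
    # 2 ** ceil(log2(amax / fp8_max)); matches fast_round_scale for positive amax.
    return torch.exp2(torch.ceil(torch.log2(amax / fp8_max)))
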
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(
    N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
):
    M = T.symbolic("M")
    fp8_min = -448.0
    fp8_max = 448.0
    fp8_max_inv = 1 / fp8_max
    num_stages = 0 if round_scale else 2
    blk_m = 32
    group_size = 128

    @T.prim_func
    def act_quant_kernel_(
        X: T.Tensor[(M, N), in_dtype],
        Y: T.Tensor[(M, N), out_dtype],
        S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
    ):
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
            pid_m,
            pid_n,
        ):
            x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
            x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
            amax_local = T.alloc_fragment((blk_m,), scale_dtype)
            s_local = T.alloc_fragment((blk_m,), scale_dtype)
            y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
            y_shared = T.alloc_shared((blk_m, group_size), out_dtype)

            for _ in T.Pipelined(1, num_stages=num_stages):
                T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
                T.copy(x_shared, x_local)
                # Per-row absolute maximum over the group, clamped away from zero.
                T.reduce_absmax(x_local, amax_local, dim=1)
                for i in T.Parallel(blk_m):
                    amax_local[i] = T.max(amax_local[i], 1e-4)
                    if round_scale:
                        s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
                    else:
                        s_local[i] = amax_local[i] * fp8_max_inv
                # Quantize: divide by the per-row scale and clamp to the FP8 range.
                for i, j in T.Parallel(blk_m, group_size):
                    y_local[i, j] = T.clamp(
                        x_local[i, j] / s_local[i], fp8_min, fp8_max
                    )
                for i in T.Parallel(blk_m):
                    S[pid_m * blk_m + i, pid_n] = s_local[i]
                T.copy(y_local, y_shared)
                T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])

    return act_quant_kernel_

def act_quant(
    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
        scale_fmt (Optional[str], optional): The scale format; when not None, scales are rounded up to powers of two. Default is None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.
    """
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert x.size(-1) % block_size == 0, (
        f"Last dimension size must be divisible by block_size (block_size={block_size})"
    )
    N = x.size(-1)
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
    kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
    kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
    return y, s

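# Reference sketch (assumption, for validation only): a pure-PyTorch equivalent of
# act_quant on small inputs. The helper name `_act_quant_ref` and its explicit
# `round_scale` flag are illustrative and not part of the original module.
def _act_quant_ref(
    x: torch.Tensor, block_size: int = 128, round_scale: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Group the last dimension into blocks of `block_size` and take a per-block amax.
    xg = x.float().view(*x.shape[:-1], -1, block_size)
    amax = xg.abs().amax(dim=-1).clamp_min(1e-4)
    s = amax / 448.0
    if round_scale:
        # Round the scale up to the next power of two, mirroring fast_round_scale.
        s = torch.exp2(torch.ceil(torch.log2(s)))
    y = (xg / s.unsqueeze(-1)).clamp(-448.0, 448.0).reshape(x.shape)
    return y.to(torch.float8_e4m3fn), s
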
@tilelang.jit(pass_configs=pass_configs)
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
    assert out_dtype in [BF16, "float32"]

    M = T.symbolic("M")
    group_size = 128
    block_M = 32
    block_N = 128
    block_K = 128

    @T.prim_func
    def fp8_gemm_kernel_(
        A: T.Tensor[(M, K), FP8],
        B: T.Tensor[(N, K), FP8],
        C: T.Tensor[(M, N), out_dtype],
        scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32],
        scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32],
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            A_shared = T.alloc_shared((block_M, block_K), FP8)
            B_shared = T.alloc_shared((block_N, block_K), FP8)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            Scale_C_shared = T.alloc_shared((block_M), FP32)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=4):
                # Load A into shared memory
                T.copy(A[by * block_M, k * block_K], A_shared)
                # Load B into shared memory
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Load scale into shared memory
                Scale_B = scales_b[bx * block_N // group_size, k]
                for i in T.Parallel(block_M):
                    Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B

                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Promote to enable 2xAcc
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
                T.clear(C_local)
            # TMA store
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return fp8_gemm_kernel_

def fp8_gemm(
    a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor
) -> torch.Tensor:
    """
    Perform a matrix multiplication using FP8 precision.

    Args:
        a (torch.Tensor): The first input matrix, must be contiguous.
        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
        b (torch.Tensor): The second input matrix of shape (N, K), must be contiguous.
        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.

    Returns:
        torch.Tensor: The result of the matrix multiplication.
    """
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    assert a_s.is_contiguous() and b_s.is_contiguous(), (
        "Scaling factor tensors must be contiguous"
    )
    K = a.size(-1)
    M = a.numel() // K
    N = b.size(0)
    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    kernel = fp8_gemm_kernel(N, K)
    kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
    return c

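# Reference sketch (assumption, for validation only): dequantize both operands with
# their block-wise scales and fall back to a plain matmul. `_fp8_gemm_ref` and its
# `group_size` argument are illustrative and not part of the original module.
def _fp8_gemm_ref(
    a: torch.Tensor,
    a_s: torch.Tensor,
    b: torch.Tensor,
    b_s: torch.Tensor,
    group_size: int = 128,
) -> torch.Tensor:
    K = a.size(-1)
    # Activations carry one scale per (row, K-group).
    a_dq = a.float() * a_s.repeat_interleave(group_size, dim=-1)[..., :K]
    # Weights carry one scale per (N-group, K-group) tile.
    b_scale = b_s.repeat_interleave(group_size, dim=0)[: b.size(0)]
    b_scale = b_scale.repeat_interleave(group_size, dim=1)[:, :K]
    b_dq = b.float() * b_scale
    return (a_dq @ b_dq.t()).to(torch.get_default_dtype())

# Hedged usage sketch: `y, s = act_quant(x)` followed by `fp8_gemm(y, s, w_q, w_s)`,
# where `w_q` / `w_s` are FP8 weights and their block scales produced offline.
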
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
def fp8_index_kernel(h: int, d: int):
    b = T.symbolic("b")
    m = T.symbolic("m")
    n = T.symbolic("n")

    blk_n1 = 512
    blk_n2 = 128

    @T.prim_func
    def fp8_index_kernel_(
        q: T.Tensor[(b, m, h, d), FP8],
        q_s: T.Tensor[(b, m, h), FP32],
        k: T.Tensor[(b, n, d), FP8],
        k_s: T.Tensor[(b, n), FP32],
        o: T.Tensor[(b, m, n), FP32],
    ) -> None:
        with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
            q_smem = T.alloc_shared((h, d), FP8)
            T.copy(q[i_b, i_m, 0, 0], q_smem)

            q_s_frag = T.alloc_fragment(h, FP32)
            T.copy(q_s[i_b, i_m, 0], q_s_frag)

            for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
                k_smem = T.alloc_shared((blk_n2, d), FP8)
                T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)

                k_s_frag = T.alloc_fragment(blk_n2, FP32)
                T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)

                # logits[j, i] = k[j, :] . q[i, :] for each key j in the tile and head i
                logits = T.alloc_fragment((blk_n2, h), FP32)
                T.gemm(
                    k_smem,
                    q_smem,
                    logits,
                    transpose_A=False,
                    transpose_B=True,
                    clear_accum=True,
                )

                # ReLU, then weight each head by its query scale
                for i_h, i3_n in T.Parallel(h, blk_n2):
                    logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]

                # Sum over heads and apply the key scale
                logits_sum = T.alloc_fragment(blk_n2, FP32)
                T.reduce_sum(logits, logits_sum, dim=1)

                for i3_n in T.Parallel(blk_n2):
                    logits_sum[i3_n] *= k_s_frag[i3_n]

                T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])

    return fp8_index_kernel_

def fp8_index(
    q: torch.Tensor,
    q_s: torch.Tensor,
    k: torch.Tensor,
    k_s: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the index score using FP8 precision.

    Args:
        q (torch.Tensor): The Q tensor, must be contiguous.
        q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
        k (torch.Tensor): The K tensor, must be contiguous.
        k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.

    Computation flow:
        fp8 q @ fp8 k -> fp32 logits
        relu(fp32 logits) * q_s (weights) -> fp32 logits
        fp32 logits -> fp32 logits_sum
        fp32 logits_sum * k_s (e8m0) -> fp32 index_score

    Returns:
        torch.Tensor: The index scores of shape (b, m, n) with dtype `torch.float32`.
    """
    return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)

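# Reference sketch (assumption, for validation only): the same index score computed
# with a dense einsum on dequantized inputs. `_fp8_index_ref` is an illustrative
# helper, not part of the original module.
def _fp8_index_ref(
    q: torch.Tensor, q_s: torch.Tensor, k: torch.Tensor, k_s: torch.Tensor
) -> torch.Tensor:
    # logits[b, m, h, n] = q[b, m, h, :] . k[b, n, :]
    logits = torch.einsum("bmhd,bnd->bmhn", q.float(), k.float())
    # ReLU, weight by per-head query scales, sum over heads, apply key scales.
    weighted = torch.relu(logits) * q_s.unsqueeze(-1)
    return weighted.sum(dim=2) * k_s.unsqueeze(1)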