inference/model.py
import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

from kernel import act_quant, fp8_gemm, fp8_index


world_size = 1
rank = 0
block_size = 128

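# NOTE: `world_size` and `rank` are placeholder defaults; Transformer.__init__
# overwrites them from torch.distributed once a process group is initialized.
# `block_size` is the quantization tile size shared with the fp8 kernels
# (`act_quant`, `fp8_gemm`, `fp8_index`) and the per-block weight scales below.
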
@dataclass
class ModelArgs:
    """
    Data class for defining model arguments and hyperparameters.

    Attributes:
        max_batch_size (int): Maximum batch size.
        max_seq_len (int): Maximum sequence length.
        dtype (Literal["bf16", "fp8"]): Data type for computations.
        scale_fmt (Optional[str]): Format for quantization scale.
        vocab_size (int): Vocabulary size.
        dim (int): Model dimension.
        inter_dim (int): Intermediate dimension for MLP layers.
        moe_inter_dim (int): Intermediate dimension for MoE layers.
        n_layers (int): Number of transformer layers.
        n_dense_layers (int): Number of dense layers in the model.
        n_heads (int): Number of attention heads.
        n_routed_experts (int): Number of routed experts for MoE layers.
        n_shared_experts (int): Number of shared experts for MoE layers.
        n_activated_experts (int): Number of activated experts in MoE layers.
        n_expert_groups (int): Number of expert groups.
        n_limited_groups (int): Number of limited groups for MoE routing.
        score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
        route_scale (float): Scaling factor for routing scores.
        q_lora_rank (int): LoRA rank for query projections.
        kv_lora_rank (int): LoRA rank for key-value projections.
        qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
        qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
        v_head_dim (int): Dimension for value projections.
        original_seq_len (int): Original sequence length.
        rope_theta (float): Base for rotary positional encoding.
        rope_factor (float): Scaling factor for extended sequence lengths.
        beta_fast (int): Fast beta correction factor.
        beta_slow (int): Slow beta correction factor.
        mscale (float): Scaling factor for extended attention.
        index_n_heads (int): Number of attention heads for the indexer.
        index_head_dim (int): Head dimension for the indexer.
        index_topk (int): Number of tokens the indexer keeps per query.
    """
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    dtype: Literal["bf16", "fp8"] = "bf16"
    scale_fmt: Optional[str] = None
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 27
    n_dense_layers: int = 1
    n_heads: int = 16
    # moe
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.
    # mla
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    # yarn
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.
    # index
    index_n_heads: int = 64
    index_head_dim: int = 128
    index_topk: int = 2048


class ParallelEmbedding(nn.Module):
    """
    Embedding layer with parallelism support across distributed processes.

    Args:
        vocab_size (int): Vocabulary size.
        dim (int): Embedding dimension.
    """
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
        self.part_vocab_size = (vocab_size // world_size)
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for parallel embedding layer.

        Each rank holds a contiguous shard of the vocabulary. Tokens owned by
        other ranks are looked up as zeros locally; the all-reduce then sums the
        partial embeddings so every rank ends up with the full result.

        Args:
            x (torch.Tensor): Input tensor containing token indices.

        Returns:
            torch.Tensor: Embedded representations.
        """
        if world_size > 1:
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            x = x - self.vocab_start_idx
            x[mask] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:
            y[mask] = 0
            dist.all_reduce(y)
        return y


def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
           scale_fmt: Optional[str] = None) -> torch.Tensor:
    """
    Applies a linear transformation to the incoming data: y = xA^T + b.
    The implementation dispatches on the weight dtype to support fp8-quantized weights.

    Args:
        x (torch.Tensor): The input tensor.
        weight (torch.Tensor): The weight tensor, either bf16 or fp8
            (`torch.float8_e4m3fn`) with a per-block `scale` attribute.
        bias (Optional[torch.Tensor]): Must be None (asserted); bias, when present,
            is added by the caller (see `RowParallelLinear`).
        scale_fmt (Optional[str]): The format of scaling factors for activation quantization.

    Returns:
        torch.Tensor: The result of the linear transformation.

    Notes:
        - For non-fp8 weights, the computation falls back to `F.linear`.
        - For fp8 weights, `x` is block-quantized with `act_quant` and the product
          is computed by `fp8_gemm` using the activation and weight scales.
    """
    assert bias is None

    if weight.dtype != torch.float8_e4m3fn:
        return F.linear(x, weight)
    else:
        x, scale = act_quant(x, block_size, scale_fmt)
        return fp8_gemm(x, scale, weight, weight.scale)


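# Dispatch sketch for `linear` (hypothetical shapes; the fp8 branch assumes the
# weight carries the per-block `scale` attribute attached by `Linear.__init__`):
#
#     x = torch.randn(4, 128, dtype=torch.bfloat16)
#     w = torch.randn(256, 128, dtype=torch.bfloat16)
#     y = linear(x, w)        # bf16 weight -> F.linear, y.shape == (4, 256)
#     # fp8 weight -> act_quant(x, block_size) followed by
#     # fp8_gemm(x_fp8, x_scale, w_fp8, w.scale)
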
class Linear(nn.Module):
    """
    Custom linear layer with support for quantized weights and optional bias.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        dtype (optional): Data type for the layer. Defaults to `Linear.dtype`.
    """
    # Class-level defaults; Transformer.__init__ sets them from ModelArgs.
    dtype = torch.bfloat16
    scale_fmt: Optional[str] = None

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
        if self.weight.element_size() == 1:
            # fp8 weights carry one float32 scale per (block_size x block_size) tile;
            # the scale is also attached to the weight so `linear` can reach it.
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
        else:
            self.register_parameter("scale", None)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the custom linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor after linear computation.
        """
        return linear(x, self.weight, self.bias, self.scale_fmt)


class ColumnParallelLinear(Linear):
    """
    Linear layer with column parallelism, splitting output features across distributed processes.

    Args:
        in_features (int): Number of input features.
        out_features (int): Total number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for column parallel linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor with column-parallel computation.
        """
        y = linear(x, self.weight, self.bias, self.scale_fmt)
        return y


class RowParallelLinear(Linear):
    """
    Linear layer with row parallelism, splitting input features across distributed processes.

    Args:
        in_features (int): Total number of input features.
        out_features (int): Number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        reduce_output (bool): Whether to all-reduce the partial outputs across ranks.
            Defaults to True.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, reduce_output = True, dtype = None):
        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
        self.part_in_features = in_features // world_size
        self.reduce_output = reduce_output
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for row parallel linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor with row-parallel computation.
        """
        y = linear(x, self.weight, None, self.scale_fmt)
        if self.reduce_output and world_size > 1:
            # reduce in fp32 for accuracy, then add the bias once, after the reduction
            y = y.float()
            dist.all_reduce(y)
        if self.bias is not None:
            y += self.bias
        return y.type_as(x)


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (RMSNorm).

    Args:
        dim (int): Dimension of the input tensor.
        eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        """
        Forward pass for RMSNorm, optionally fused with a residual addition.

        Args:
            x (torch.Tensor): Input tensor.
            residual (Optional[torch.Tensor]): If given, `x + residual` is computed
                in float32 first and also returned as the new residual stream.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as the input, or a
            tuple `(normalized, residual)` when `residual` is provided.
        """
        dtype = x.dtype
        if residual is None:
            x = x.float()
            var = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(var + self.eps)
            return (self.weight * x).to(dtype)
        else:
            x = residual = x.float() + residual.float()
            var = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(var + self.eps)
            return (self.weight * x).to(dtype), residual.to(dtype)


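# RMSNorm call pattern (as used in Block.forward and Transformer.forward below):
#
#     h = norm(x)                      # plain normalization
#     h, residual = norm(x, residual)  # fused residual add + normalization
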
class LayerNorm(nn.Module):
    """
    Layer Normalization computed in float32.

    Args:
        dim (int): Dimension of the input tensor.
        eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x)


def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    """
    Precomputes frequency-based complex exponential values for rotary positional embeddings.

    Args:
        args (ModelArgs): Model arguments containing positional embedding parameters.

    Returns:
        torch.Tensor: Precomputed complex exponential values for positional embeddings.
    """
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        """
        Computes the dimension index at which the rotary frequency completes the given
        number of rotations over the maximum sequence length. It inverts
        rotations(i) = max_seq_len / (2 * pi * base^(2i / dim)) for the index i.

        Args:
            num_rotations (float): Number of rotations to compute the correction for.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            float: The correction dimension based on the input parameters.
        """
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        """
        Computes the range of correction dimensions for rotary positional embeddings.

        Args:
            low_rot (float): Lower bound for the number of rotations.
            high_rot (float): Upper bound for the number of rotations.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
        """
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min, max, dim):
        """
        Computes a linear ramp function used to smooth values between a minimum and maximum range.

        Args:
            min (float): Minimum value for the ramp function.
            max (float): Maximum value for the ramp function.
            dim (int): Dimensionality of the ramp tensor.

        Returns:
            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
                clamped to the range [0, 1].
        """
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


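# YaRN recap for precompute_freqs_cis: `smooth` is 1 for dimensions below `low`
# (high-frequency components completing at least beta_fast rotations within
# original_seq_len), so their frequencies are kept; it is 0 above `high` (fewer
# than beta_slow rotations), so those frequencies are divided by rope_factor
# (position interpolation); dimensions in between blend the two linearly.
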
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True) -> torch.Tensor:
    """
    Applies rotary positional embeddings to the input tensor.

    Args:
        x (torch.Tensor): Input tensor with positional embeddings to be applied.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
        interleaved (bool): Whether the real/imaginary components are interleaved in the
            last dimension ([r0, i0, r1, i1, ...]) or split into halves
            ([r0, r1, ..., i0, i1, ...]). Defaults to True.

    Returns:
        torch.Tensor: Tensor with rotary embeddings applied.
    """
    dtype = x.dtype
    shape = x.shape
    if not interleaved:
        # regroup the half-split layout into adjacent (real, imag) pairs
        x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
    x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    if not interleaved:
        # restore the half-split layout
        y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
    return y.to(dtype)


def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """
    Applies an orthonormal Hadamard transform along the last dimension.
    """
    assert x.dtype == torch.bfloat16
    # imported lazily so the dependency is only required when the indexer runs
    from fast_hadamard_transform import hadamard_transform
    hidden_size = x.size(-1)
    return hadamard_transform(x, scale=hidden_size ** -0.5)


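# `rotate_activation` uses the hidden_size ** -0.5 scale so the transform is
# orthonormal (norm-preserving); rotating activations this way is a common trick
# to spread outlier values across channels so the fp8 quantization in
# Indexer.forward loses less precision.
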
class Indexer(torch.nn.Module):
    """
    Selects, for each query position, the top-k cached token positions using a
    lightweight fp8 scoring head; MLA uses the returned indices to sparsify attention.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim: int = args.dim
        self.n_heads: int = args.index_n_heads
        self.n_local_heads = args.index_n_heads // world_size
        self.head_dim: int = args.index_head_dim
        self.rope_head_dim: int = args.qk_rope_head_dim
        self.index_topk: int = args.index_topk
        self.q_lora_rank: int = args.q_lora_rank
        self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
        self.wk = Linear(self.dim, self.head_dim)
        self.k_norm = LayerNorm(self.head_dim)
        # weights_proj in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenience.
        self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.float32)
        self.softmax_scale = self.head_dim ** -0.5
        self.scale_fmt = args.scale_fmt

        self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn), persistent=False)
        self.register_buffer("k_scale_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim // block_size, dtype=torch.float32), persistent=False)

    def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        q = self.wq_b(qr)
        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
        q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
        # rope in indexer is not interleaved
        q_pe = apply_rotary_emb(q_pe, freqs_cis, False)
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = self.wk(x)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
        # rope in indexer is not interleaved
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, False).squeeze(2)
        k = torch.cat([k_pe, k_nope], dim=-1)
        # rotate activations so they quantize well, then cache keys in fp8
        q = rotate_activation(q)
        k = rotate_activation(k)
        q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
        k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
        self.k_cache[:bsz, start_pos:end_pos] = k_fp8
        self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
        # per-head gating weights; fold the query scales and softmax scale into them
        weights = self.weights_proj(x.float()) * self.n_heads ** -0.5
        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
        index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
        if mask is not None:
            index_score += mask
        topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
        if world_size > 1:
            # every rank computes the same scores; verify the selections agree
            topk_indices_ = topk_indices.clone()
            dist.broadcast(topk_indices_, src=0)
            assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
        return topk_indices


def weight_dequant(weight, scale):
    """
    Dequantizes a block-quantized fp8 weight: each (block_size x block_size) tile
    is multiplied by its scalar scale and the result is cast to the default dtype.
    """
    shape = weight.shape
    assert weight.dim() == 2
    # (out, in) -> one flattened row per (block_size x block_size) tile, matching one scale entry each
    weight = weight.view(shape[0] // block_size, block_size, shape[1] // block_size, block_size).transpose(1, 2).contiguous().view(-1, block_size * block_size)
    weight = (weight.float() * scale.view(-1, 1).float()).to(torch.get_default_dtype()).view(shape[0] // block_size, shape[1] // block_size, block_size, block_size).transpose(1, 2).contiguous().view(shape)
    return weight


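# weight_dequant shape sketch: a (256, 512) fp8 weight with block_size == 128
# pairs with a (2, 4) scale tensor; tile (i, j) of the weight is multiplied by
# scale[i, j].
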
class MLA(nn.Module):
    """
    Multi-Head Latent Attention (MLA) Layer.

    Attributes:
        dim (int): Dimensionality of the input features.
        n_heads (int): Number of attention heads.
        n_local_heads (int): Number of local attention heads for distributed systems.
        q_lora_rank (int): Rank for low-rank query projection.
        kv_lora_rank (int): Rank for low-rank key/value projection.
        qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
        qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
        qk_head_dim (int): Total dimensionality of query/key projections.
        v_head_dim (int): Dimensionality of value projections.
        softmax_scale (float): Scaling factor for softmax in attention computation.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim

        self.wq_a = Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5
        self.scale_fmt = args.scale_fmt
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale

        self.indexer = Indexer(args)

        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
        self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
        self.dequant_wkv_b = None

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """
        Forward pass for the Multi-Head Latent Attention (MLA) Layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            start_pos (int): Starting position in the sequence for caching.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        qr = self.q_norm(self.wq_a(x))
        q = self.wq_b(qr)
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv = self.kv_norm(kv)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        # we use fp8 kv cache in actual deployment, so here we simulate the precision by casting kv to fp8 and then back to bf16.
        kv_fp8, kv_scale = act_quant(kv, block_size, self.scale_fmt)
        kv = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
        self.kv_cache[:bsz, start_pos:end_pos] = kv
        self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
        if mask is not None:  # MHA prefill
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(kv)
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            scores = torch.einsum("bshd,bthd->bsht", q, k).mul_(self.softmax_scale)

            # indexer: keep only the top-k scored positions, on top of the causal mask
            topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
            index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
            index_mask += mask
            scores += index_mask.unsqueeze(2)

            scores = scores.softmax(dim=-1)
            x = torch.einsum("bsht,bthd->bshd", scores, v)
        else:  # MQA decode
            # absorbed-weight decode: project q_nope through the K-part of wkv_b so
            # attention runs directly against the low-rank kv_cache shared by all heads
            if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
                self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
            wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale

            # indexer: restrict attention to the top-k cached positions
            topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
            index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
            scores += index_mask.unsqueeze(2)

            scores = scores.softmax(dim=-1)
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x


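# MLA decode-path recap (sketch): per head, the non-positional score is
# q_nope . k_nope_t with k_nope_t = W_uk @ c_t, where c_t is the kv_lora_rank
# latent stored in kv_cache. Absorbing W_uk into the query, (W_uk^T q_nope) . c_t,
# lets every head share the single latent cache (MQA-style); values are then
# reconstructed through the V-part of wkv_b after the softmax.
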
class MLP(nn.Module):
    """
    Multi-Layer Perceptron (MLP) used as a feed-forward layer.

    Attributes:
        w1 (nn.Module): Linear layer for input-to-hidden transformation.
        w2 (nn.Module): Linear layer for hidden-to-output transformation.
        w3 (nn.Module): Additional linear layer for feature transformation.
    """
    def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True):
        """
        Initializes the MLP layer.

        Args:
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
            reduce_output (bool): Whether w2 all-reduces its output across ranks.
                Defaults to True.
        """
        super().__init__()
        self.w1 = ColumnParallelLinear(dim, inter_dim)
        self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output)
        self.w3 = ColumnParallelLinear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MLP layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after SwiGLU computation, w2(silu(w1(x)) * w3(x)).
        """
        return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))


class Gate(nn.Module):
    """
    Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.

    Attributes:
        dim (int): Dimensionality of input features.
        topk (int): Number of top experts activated for each input.
        n_groups (int): Number of groups for routing.
        topk_groups (int): Number of groups to route inputs to.
        score_func (str): Scoring function ('softmax' or 'sigmoid').
        route_scale (float): Scaling factor for routing weights.
        weight (torch.nn.Parameter): Learnable weights for the gate.
        bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the Gate module.

        Args:
            args (ModelArgs): Model arguments containing gating parameters.
        """
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        # routing bias used for selection balancing; only present in the 7168-dim
        # configuration, matching the released checkpoints
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the gating mechanism.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
        """
        scores = linear(x.float(), self.weight.float())
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1)
        else:
            scores = scores.sigmoid()
        original_scores = scores
        if self.bias is not None:
            scores = scores + self.bias
        if self.n_groups > 1:
            # group-limited routing: only experts inside the top `topk_groups` groups stay eligible
            scores = scores.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            indices = group_scores.topk(self.topk_groups, dim=-1)[1]
            mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
            scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
        indices = scores.topk(self.topk, dim=-1)[1]
        # gate values come from the pre-bias scores; the bias only influences selection
        weights = original_scores.gather(1, indices)
        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights, indices


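# Gate routing shapes (sketch): for x of shape (tokens, dim), `scores` is
# (tokens, n_routed_experts), `indices` is (tokens, n_activated_experts), and
# `weights` gathers the matching pre-bias scores (normalized under sigmoid
# scoring, then scaled by route_scale).
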
class Expert(nn.Module):
    """
    Expert layer for Mixture-of-Experts (MoE) models.

    Attributes:
        w1 (nn.Module): Linear layer for input-to-hidden transformation.
        w2 (nn.Module): Linear layer for hidden-to-output transformation.
        w3 (nn.Module): Additional linear layer for feature transformation.
    """
    def __init__(self, dim: int, inter_dim: int):
        """
        Initializes the Expert layer.

        Args:
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the Expert layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert computation.
        """
        return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))


class MoE(nn.Module):
    """
    Mixture-of-Experts (MoE) module.

    Attributes:
        dim (int): Dimensionality of input features.
        n_routed_experts (int): Total number of experts in the model.
        n_local_experts (int): Number of experts handled locally in distributed systems.
        n_activated_experts (int): Number of experts activated for each input.
        gate (nn.Module): Gating mechanism to route inputs to experts.
        experts (nn.ModuleList): List of expert modules.
        shared_experts (nn.Module): Shared experts applied to all inputs.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the MoE module.

        Args:
            args (ModelArgs): Model arguments containing MoE parameters.
        """
        super().__init__()
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(args)
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        # reduce_output=False: the shared-expert output joins the routed sum in the single all_reduce below
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim, reduce_output=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MoE module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert routing and computation.
        """
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)
        y = torch.zeros_like(x, dtype=torch.float32)
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            # rows routed to expert i, and which of their top-k slots selected it
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, top, None]
        y += self.shared_experts(x)
        if world_size > 1:
            dist.all_reduce(y)
        return y.type_as(x).view(shape)


class Block(nn.Module):
    """
    Transformer block combining attention and feed-forward layers.

    Attributes:
        attn (nn.Module): Attention layer (MLA).
        ffn (nn.Module): Feed-forward network (MLP or MoE).
        attn_norm (nn.Module): Layer normalization for attention.
        ffn_norm (nn.Module): Layer normalization for feed-forward network.
    """
    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initializes the Transformer block.

        Args:
            layer_id (int): Layer index in the transformer.
            args (ModelArgs): Model arguments containing block parameters.
        """
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, residual: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass for the Transformer block.

        Args:
            x (torch.Tensor): Input tensor.
            residual (Optional[torch.Tensor]): Residual stream from the previous block,
                or None for the first block.
            start_pos (int): Starting position in the sequence.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Block output and the updated residual stream.
        """
        if residual is None:
            x, residual = self.attn_norm(x), x
        else:
            x, residual = self.attn_norm(x, residual)
        x = self.attn(x, start_pos, freqs_cis, mask)
        x, residual = self.ffn_norm(x, residual)
        x = self.ffn(x)
        return x, residual


class Transformer(nn.Module):
    """
    Transformer model with positional embeddings, multiple layers, and output projection.

    Attributes:
        max_seq_len (int): Maximum sequence length for the transformer.
        embed (nn.Module): Embedding layer for input tokens.
        layers (torch.nn.ModuleList): List of transformer blocks.
        norm (nn.Module): Layer normalization applied after all blocks.
        head (nn.Module): Output projection layer mapping to vocabulary size.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the Transformer model.

        Args:
            args (ModelArgs): Model arguments containing transformer parameters.
        """
        global world_size, rank
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        Linear.scale_fmt = args.scale_fmt
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        # lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32)
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        """
        Forward pass for the Transformer model.

        Args:
            tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
            start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.

        Returns:
            torch.Tensor: Logits tensor of shape (batch_size, vocab_size) for the last position.
        """
        seqlen = tokens.size(1)
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        # the causal mask is only needed during prefill; single-token decode attends to everything cached
        mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None
        h, residual = self.embed(tokens), None
        for layer in self.layers:
            h, residual = layer(h, residual, start_pos, freqs_cis, mask)
        h, _ = self.norm(h, residual)
        logits = self.head(h[:, -1].float())
        if world_size > 1:
            # each rank holds a vocabulary shard of the head; gather the full logits
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits


if __name__ == "__main__":
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.manual_seed(0)
    args = ModelArgs()
    x = torch.randint(0, args.vocab_size, (2, 128))
    model = Transformer(args)
    print(model(x).size())
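    # Hypothetical follow-up decode step (sketch): the KV caches persist across
    # calls, so the next token can be scored with start_pos set to the prompt
    # length, e.g.
    #     next_tok = torch.randint(0, args.vocab_size, (2, 1))
    #     print(model(next_tok, start_pos=128).size())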