inference/model.py
import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

from kernel import act_quant, fp8_gemm, fp8_index


world_size = 1
rank = 0
block_size = 128
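# Defaults for single-process execution; Transformer.__init__ overwrites world_size/rank from
# torch.distributed when a process group is initialized. block_size is the tile width used for
# fp8 block quantization by act_quant / fp8_gemm / weight_dequant below.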

@dataclass
class ModelArgs:
    """
    Data class for defining model arguments and hyperparameters.

    Attributes:
        max_batch_size (int): Maximum batch size.
        max_seq_len (int): Maximum sequence length.
        dtype (Literal["bf16", "fp8"]): Data type for computations.
        scale_fmt (Optional[str]): Format for quantization scale.
        vocab_size (int): Vocabulary size.
        dim (int): Model dimension.
        inter_dim (int): Intermediate dimension for MLP layers.
        moe_inter_dim (int): Intermediate dimension for MoE layers.
        n_layers (int): Number of transformer layers.
        n_dense_layers (int): Number of dense layers in the model.
        n_heads (int): Number of attention heads.
        n_routed_experts (int): Number of routed experts for MoE layers.
        n_shared_experts (int): Number of shared experts for MoE layers.
        n_activated_experts (int): Number of activated experts in MoE layers.
        n_expert_groups (int): Number of expert groups.
        n_limited_groups (int): Number of limited groups for MoE routing.
        score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
        route_scale (float): Scaling factor for routing scores.
        q_lora_rank (int): LoRA rank for query projections.
        kv_lora_rank (int): LoRA rank for key-value projections.
        qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
        qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
        v_head_dim (int): Dimension for value projections.
        original_seq_len (int): Original sequence length.
        rope_theta (float): Base for rotary positional encoding.
        rope_factor (float): Scaling factor for extended sequence lengths.
        beta_fast (int): Fast beta correction factor.
        beta_slow (int): Slow beta correction factor.
        mscale (float): Scaling factor for extended attention.
        index_n_heads (int): Number of heads for the indexer.
        index_head_dim (int): Dimension for index head.
        index_topk (int): Top-k for index head.
    """
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    dtype: Literal["bf16", "fp8"] = "bf16"
    scale_fmt: Optional[str] = None
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 27
    n_dense_layers: int = 1
    n_heads: int = 16
    # moe
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.
    # mla
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    # yarn
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.
    # index
    index_n_heads: int = 64
    index_head_dim: int = 128
    index_topk: int = 2048

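# A few quantities implied by these defaults (sanity check only, not used directly): the per-head
# query/key width is qk_nope_head_dim + qk_rope_head_dim = 128 + 64 = 192, while the MLA caches
# store only kv_lora_rank + qk_rope_head_dim = 512 + 64 values per token. With
# max_seq_len = 4096 * 4 = 16384 > original_seq_len = 4096, the YaRN rescaling branch in
# precompute_freqs_cis is active by default.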
class ParallelEmbedding(nn.Module):
    """
    Embedding layer with parallelism support across distributed processes.

    Args:
        vocab_size (int): Vocabulary size.
        dim (int): Embedding dimension.
    """
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
        self.part_vocab_size = (vocab_size // world_size)
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for parallel embedding layer.

        Args:
            x (torch.Tensor): Input tensor containing token indices.

        Returns:
            torch.Tensor: Embedded representations, summed across processes via all-reduce
            when running distributed.
        """
        if world_size > 1:
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            x = x - self.vocab_start_idx
            x[mask] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:
            y[mask] = 0
            dist.all_reduce(y)
        return y


def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
           scale_fmt: Optional[str] = None) -> torch.Tensor:
    """
    Applies a linear transformation to the incoming data: y = xA^T.
    The implementation is chosen based on the weight dtype, so bf16 and fp8
    checkpoints share the same call site.

    Args:
        x (torch.Tensor): The input tensor.
        weight (torch.Tensor): The weight tensor, either bf16/fp32 or fp8
            (`torch.float8_e4m3fn`) with a per-block `scale` attribute attached.
        bias (Optional[torch.Tensor]): Must be None; bias is handled by the calling layer.
        scale_fmt (Optional[str]): The format of the activation scaling factors.

    Returns:
        torch.Tensor: The result of the linear transformation.

    Notes:
        - If `weight` is not quantized, the computation falls back to `F.linear`.
        - If `weight` is fp8, `x` is block-quantized with `act_quant` and the product is
          computed with `fp8_gemm` using the stored weight scales.
    """
    assert bias is None

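    # Dispatch on the weight dtype: bf16/fp32 weights (Linear.dtype == torch.bfloat16) take the
    # F.linear path; fp8 weights (torch.float8_e4m3fn) quantize the activations per 128-wide
    # block and multiply with fp8_gemm against the block-wise weight scales.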
    if weight.dtype != torch.float8_e4m3fn:
        return F.linear(x, weight)
    else:
        x, scale = act_quant(x, block_size, scale_fmt)
        return fp8_gemm(x, scale, weight, weight.scale)


class Linear(nn.Module):
    """
    Custom linear layer with support for quantized weights and optional bias.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    dtype = torch.bfloat16
    scale_fmt: Optional[str] = None

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
        if self.weight.element_size() == 1:
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
        else:
            self.register_parameter("scale", None)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the custom linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor after linear computation.
        """
        return linear(x, self.weight, self.bias, self.scale_fmt)


class ColumnParallelLinear(Linear):
    """
    Linear layer with column parallelism, splitting output features across distributed processes.

    Args:
        in_features (int): Number of input features.
        out_features (int): Total number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for column parallel linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor with column-parallel computation.
        """
        y = linear(x, self.weight, self.bias, self.scale_fmt)
        return y


class RowParallelLinear(Linear):
    """
    Linear layer with row parallelism, splitting input features across distributed processes.

    Args:
        in_features (int): Total number of input features.
        out_features (int): Number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        reduce_output (bool): Whether to all-reduce the partial outputs across processes. Defaults to True.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, reduce_output = True, dtype = None):
        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
        self.part_in_features = in_features // world_size
        self.reduce_output = reduce_output
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for row parallel linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor with row-parallel computation.
        """
        y = linear(x, self.weight, None, self.scale_fmt)
        if self.reduce_output and world_size > 1:
            y = y.float()
            dist.all_reduce(y)
        if self.bias is not None:
            y += self.bias
        return y.type_as(x)


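# In this file, column-parallel layers (wq_b, wkv_b, w1, w3, head) produce per-rank output shards,
# and the row-parallel layers that consume them (wo, w2) all-reduce the partial sums, so a
# column-parallel -> row-parallel pair needs no extra communication in between.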
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (RMSNorm).

    Args:
        dim (int): Dimension of the input tensor.
        eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        """
        Forward pass for RMSNorm.

        Args:
            x (torch.Tensor): Input tensor.
            residual (Optional[torch.Tensor]): Optional residual tensor added to `x`
                before normalization (fused residual-add + norm).

        Returns:
            torch.Tensor: Normalized tensor with the same shape as the input, or a tuple of
            (normalized tensor, updated residual) when `residual` is given.
        """
        dtype = x.dtype
        if residual is None:
            x = x.float()
            var = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(var + self.eps)
            return (self.weight * x).to(dtype)
        else:
            x = residual = x.float() + residual.float()
            var = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(var + self.eps)
            return (self.weight * x).to(dtype), residual.to(dtype)


class LayerNorm(nn.Module):
    """
    Layer Normalization.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x)


def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    """
    Precomputes frequency-based complex exponential values for rotary positional embeddings.

    Args:
        args (ModelArgs): Model arguments containing positional embedding parameters.

    Returns:
        torch.Tensor: Precomputed complex exponential values for positional embeddings.
    """
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        """
        Computes the correction dimension for a given number of rotations in the rotary positional embedding.

        Args:
            num_rotations (float): Number of rotations to compute the correction for.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            float: The correction dimension based on the input parameters.
        """
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        """
        Computes the range of correction dimensions for rotary positional embeddings.

        Args:
            low_rot (float): Lower bound for the number of rotations.
            high_rot (float): Upper bound for the number of rotations.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
        """
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim-1)

    def linear_ramp_factor(min, max, dim):
        """
        Computes a linear ramp function used to smooth values between a minimum and maximum range.

        Args:
            min (float): Minimum value for the ramp function.
            max (float): Maximum value for the ramp function.
            dim (int): Dimensionality of the ramp tensor.

        Returns:
            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
                clamped to the range [0, 1].
        """
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

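    # With the defaults above (original_seq_len=4096, rope_factor=40, beta_fast=32, beta_slow=1)
    # and max_seq_len=16384, the branch below rescales the low-frequency components by 1/factor
    # while leaving the high-frequency ones untouched, blending between the two correction
    # dimensions with the linear ramp (YaRN-style length extension).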
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True) -> torch.Tensor:
    """
    Applies rotary positional embeddings to the input tensor.

    Args:
        x (torch.Tensor): Input tensor with positional embeddings to be applied.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
        interleaved (bool): Whether the rotary dimensions are stored as interleaved real/imag
            pairs rather than split into two contiguous halves. Defaults to True.

    Returns:
        torch.Tensor: Tensor with rotary embeddings applied.
    """
    dtype = x.dtype
    shape = x.shape
    if not interleaved:
        x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
    x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    if not interleaved:
        y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
    return y.to(dtype)


def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """
    Applies a scaled Hadamard transform along the last dimension
    (requires the optional `fast_hadamard_transform` package).
    """
    assert x.dtype == torch.bfloat16
    from fast_hadamard_transform import hadamard_transform
    hidden_size = x.size(-1)
    return hadamard_transform(x, scale=hidden_size ** -0.5)


class Indexer(torch.nn.Module):
    """
    Index head that scores past key positions for each query token and returns the top-k
    positions to attend to; MLA uses these indices to sparsify its attention.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim: int = args.dim
        self.n_heads: int = args.index_n_heads
        self.n_local_heads = args.index_n_heads // world_size
        self.head_dim: int = args.index_head_dim
        self.rope_head_dim: int = args.qk_rope_head_dim
        self.index_topk: int = args.index_topk
        self.q_lora_rank: int = args.q_lora_rank
        self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
        self.wk = Linear(self.dim, self.head_dim)
        self.k_norm = LayerNorm(self.head_dim)
        # weights_proj in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenience.
        self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.float32)
        self.softmax_scale = self.head_dim ** -0.5
        self.scale_fmt = args.scale_fmt

        self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn), persistent=False)
        self.register_buffer("k_scale_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim // block_size, dtype=torch.float32), persistent=False)

    def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        q = self.wq_b(qr)
        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
        q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
        # rope in indexer is not interleaved
        q_pe = apply_rotary_emb(q_pe, freqs_cis, False)
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = self.wk(x)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
        # rope in indexer is not interleaved
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, False).squeeze(2)
        k = torch.cat([k_pe, k_nope], dim=-1)
        q = rotate_activation(q)
        k = rotate_activation(k)
        q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
        k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
        self.k_cache[:bsz, start_pos:end_pos] = k_fp8
        self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
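        # The per-head gating weights from weights_proj are folded into the query scales and the
        # softmax scale, so fp8_index can produce a single relevance score per (query token, key
        # position) from the fp8 q/k caches; the top-k positions below define the sparse attention set.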
        weights = self.weights_proj(x.float()) * self.n_heads ** -0.5
        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
        index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
        if mask is not None:
            index_score += mask
        topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
        if world_size > 1:
            # Sanity check in distributed runs: all ranks must select the same positions, otherwise
            # the sparse attention below would attend to different tokens on different ranks.
            topk_indices_ = topk_indices.clone()
            dist.broadcast(topk_indices_, src=0)
            assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
        return topk_indices


def weight_dequant(weight, scale):
    """
    Dequantizes an fp8 weight matrix back to the default dtype using its
    per-block (block_size x block_size) scaling factors.
    """
    shape = weight.shape
    assert weight.dim() == 2
    weight = weight.view(shape[0] // block_size, block_size, shape[1] // block_size, block_size).transpose(1, 2).contiguous().view(-1, block_size * block_size)
    weight = (weight.float() * scale.view(-1, 1).float()).to(torch.get_default_dtype()).view(shape[0] // block_size, shape[1] // block_size, block_size, block_size).transpose(1, 2).contiguous().view(shape)
    return weight


class MLA(nn.Module):
    """
    Multi-Head Latent Attention (MLA) Layer.

    Attributes:
        dim (int): Dimensionality of the input features.
        n_heads (int): Number of attention heads.
        n_local_heads (int): Number of local attention heads for distributed systems.
        q_lora_rank (int): Rank for low-rank query projection.
        kv_lora_rank (int): Rank for low-rank key/value projection.
        qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
        qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
        qk_head_dim (int): Total dimensionality of query/key projections.
        v_head_dim (int): Dimensionality of value projections.
        softmax_scale (float): Scaling factor for softmax in attention computation.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim

        self.wq_a = Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5
        self.scale_fmt = args.scale_fmt
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale

        self.indexer = Indexer(args)

        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
        self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
        self.dequant_wkv_b = None

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """
        Forward pass for the Multi-Head Latent Attention (MLA) Layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            start_pos (int): Starting position in the sequence for caching.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        qr = self.q_norm(self.wq_a(x))
        q = self.wq_b(qr)
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv = self.kv_norm(kv)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        # we use fp8 kv cache in actual deployment, so here we simulate the precision by casting kv to fp8 and then back to bf16.
        kv_fp8, kv_scale = act_quant(kv, block_size, self.scale_fmt)
        kv = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
        self.kv_cache[:bsz, start_pos:end_pos] = kv
        self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
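        # Two paths follow: with a causal mask (prefill) the per-head keys/values are reconstructed
        # and full multi-head attention runs over the current chunk; without a mask (decode)
        # attention runs directly against the compressed latent caches, MQA-style, with wkv_b
        # absorbed into the query and output projections.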
        if mask is not None: # MHA prefill
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(kv)
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            scores = torch.einsum("bshd,bthd->bsht", q, k).mul_(self.softmax_scale)

            # indexer
            topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
            index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
            index_mask += mask
            scores += index_mask.unsqueeze(2)

            scores = scores.softmax(dim=-1)
            x = torch.einsum("bsht,bthd->bshd", scores, v)
        else: # MQA decode
            if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
                self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
            wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
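            # Absorption trick: project q_nope into the latent space with the k_nope half of wkv_b,
            # score against the cached latents and rope keys, and only apply the value half of
            # wkv_b after the attention-weighted sum.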
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale

            # indexer
            topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
            index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
            scores += index_mask.unsqueeze(2)

            scores = scores.softmax(dim=-1)
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x


class MLP(nn.Module):
    """
    Multi-Layer Perceptron (MLP) used as a feed-forward layer.

    Attributes:
        w1 (nn.Module): Linear layer for input-to-hidden transformation.
        w2 (nn.Module): Linear layer for hidden-to-output transformation.
        w3 (nn.Module): Additional linear layer for feature transformation.
    """
    def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True):
        """
        Initializes the MLP layer.

        Args:
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
            reduce_output (bool): Whether the output projection all-reduces its result
                across processes. Defaults to True.
        """
        super().__init__()
        self.w1 = ColumnParallelLinear(dim, inter_dim)
        self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output)
        self.w3 = ColumnParallelLinear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MLP layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after MLP computation.
        """
        return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))


class Gate(nn.Module):
    """
    Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.

    Attributes:
        dim (int): Dimensionality of input features.
        topk (int): Number of top experts activated for each input.
        n_groups (int): Number of groups for routing.
        topk_groups (int): Number of groups to route inputs to.
        score_func (str): Scoring function ('softmax' or 'sigmoid').
        route_scale (float): Scaling factor for routing weights.
        weight (torch.nn.Parameter): Learnable weights for the gate.
        bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the Gate module.

        Args:
            args (ModelArgs): Model arguments containing gating parameters.
        """
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the gating mechanism.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
        """
        scores = linear(x.float(), self.weight.float())
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1)
        else:
            scores = scores.sigmoid()
        original_scores = scores
        if self.bias is not None:
            scores = scores + self.bias
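        # Group-limited routing: experts are split into n_groups, only the best topk_groups groups
        # stay eligible (the rest are masked to -inf), and the final top-k experts are chosen from
        # the surviving groups. The returned weights come from the pre-bias scores.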
        if self.n_groups > 1:
            scores = scores.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            indices = group_scores.topk(self.topk_groups, dim=-1)[1]
            mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
            scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
        indices = scores.topk(self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights, indices


class Expert(nn.Module):
    """
    Expert layer for Mixture-of-Experts (MoE) models.

    Attributes:
        w1 (nn.Module): Linear layer for input-to-hidden transformation.
        w2 (nn.Module): Linear layer for hidden-to-output transformation.
        w3 (nn.Module): Additional linear layer for feature transformation.
    """
    def __init__(self, dim: int, inter_dim: int):
        """
        Initializes the Expert layer.

        Args:
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the Expert layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert computation.
        """
        return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))


class MoE(nn.Module):
    """
    Mixture-of-Experts (MoE) module.

    Attributes:
        dim (int): Dimensionality of input features.
        n_routed_experts (int): Total number of experts in the model.
        n_local_experts (int): Number of experts handled locally in distributed systems.
        n_activated_experts (int): Number of experts activated for each input.
        gate (nn.Module): Gating mechanism to route inputs to experts.
        experts (nn.ModuleList): List of expert modules.
        shared_experts (nn.Module): Shared experts applied to all inputs.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the MoE module.

        Args:
            args (ModelArgs): Model arguments containing MoE parameters.
        """
        super().__init__()
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(args)
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim, reduce_output=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MoE module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert routing and computation.
        """
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)
        y = torch.zeros_like(x, dtype=torch.float32)
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
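        # Each rank only instantiates its local slice of experts; tokens routed to a local expert
        # are gathered, processed, and accumulated back weighted by their gate scores. The shared
        # experts run on every token, and partial sums are combined across ranks by the all-reduce.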
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, top, None]
        y += self.shared_experts(x)
        if world_size > 1:
            dist.all_reduce(y)
        return y.type_as(x).view(shape)


class Block(nn.Module):
    """
    Transformer block combining attention and feed-forward layers.

    Attributes:
        attn (nn.Module): Attention layer (MLA).
        ffn (nn.Module): Feed-forward network (MLP or MoE).
        attn_norm (nn.Module): Layer normalization for attention.
        ffn_norm (nn.Module): Layer normalization for feed-forward network.
    """
    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initializes the Transformer block.

        Args:
            layer_id (int): Layer index in the transformer.
            args (ModelArgs): Model arguments containing block parameters.
        """
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor], start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the Transformer block.

        Args:
            x (torch.Tensor): Input tensor.
            residual (Optional[torch.Tensor]): Running residual stream from the previous block,
                or None for the first block.
            start_pos (int): Starting position in the sequence.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Block output and the updated residual stream.
        """
        if residual is None:
            x, residual = self.attn_norm(x), x
        else:
            x, residual = self.attn_norm(x, residual)
        x = self.attn(x, start_pos, freqs_cis, mask)
        x, residual = self.ffn_norm(x, residual)
        x = self.ffn(x)
        return x, residual


class Transformer(nn.Module):
    """
    Transformer model with positional embeddings, multiple layers, and output projection.

    Attributes:
        max_seq_len (int): Maximum sequence length for the transformer.
        embed (nn.Module): Embedding layer for input tokens.
        layers (torch.nn.ModuleList): List of transformer blocks.
        norm (nn.Module): Layer normalization applied after all blocks.
        head (nn.Module): Output projection layer mapping to vocabulary size.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the Transformer model.

        Args:
            args (ModelArgs): Model arguments containing transformer parameters.
        """
        global world_size, rank
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        Linear.scale_fmt = args.scale_fmt
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        # lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32)
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        """
        Forward pass for the Transformer model.

        Args:
            tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
            start_pos (int, optional): Starting position in the sequence, used for the rotary
                embeddings and the attention caches. Defaults to 0.

        Returns:
            torch.Tensor: Logits tensor of shape (batch_size, vocab_size) for the last position.
        """
        seqlen = tokens.size(1)
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
        mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None
        h, residual = self.embed(tokens), None
        for layer in self.layers:
            h, residual = layer(h, residual, start_pos, freqs_cis, mask)
        h, _ = self.norm(h, residual)
        logits = self.head(h[:, -1].float())
        if world_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits


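# A minimal incremental-decoding sketch built on the demo below (assumptions: single process,
# CUDA available, and the randomly initialized weights used here; sampling and tokenization are
# out of scope for this file):
#
#     model = Transformer(ModelArgs())
#     prompt = torch.randint(0, 102400, (1, 16))
#     logits = model(prompt, start_pos=0)                  # prefill fills the KV / index caches
#     next_tok = logits.argmax(dim=-1, keepdim=True)
#     logits = model(next_tok, start_pos=prompt.size(1))   # decode one token, reusing the caches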
if __name__ == "__main__":
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.manual_seed(0)
    args = ModelArgs()
    x = torch.randint(0, args.vocab_size, (2, 128))
    model = Transformer(args)
    print(model(x).size())