modeling_deepseek.py
74.0 KB · 1849 lines · python Raw
1 # coding=utf-8
2 # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
3 #
4 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 # and OPT implementations in this library. It has been modified from its
6 # original forms to accommodate minor architectural differences compared
7 # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 #
9 # Licensed under the Apache License, Version 2.0 (the "License");
10 # you may not use this file except in compliance with the License.
11 # You may obtain a copy of the License at
12 #
13 # http://www.apache.org/licenses/LICENSE-2.0
14 #
15 # Unless required by applicable law or agreed to in writing, software
16 # distributed under the License is distributed on an "AS IS" BASIS,
17 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 # See the License for the specific language governing permissions and
19 # limitations under the License.
20 """ PyTorch DeepSeek model."""
21 import math
22 import warnings
23 from typing import List, Optional, Tuple, Union
24
25 import torch
26 import torch.nn.functional as F
27 import torch.utils.checkpoint
28 from torch import nn
29 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
31 from transformers.activations import ACT2FN
32 from transformers.cache_utils import Cache, DynamicCache
33 from transformers.modeling_attn_mask_utils import (
34 AttentionMaskConverter,
35 _prepare_4d_attention_mask,
36 _prepare_4d_causal_attention_mask,
37 )
38 from transformers.modeling_outputs import (
39 BaseModelOutputWithPast,
40 CausalLMOutputWithPast,
41 SequenceClassifierOutputWithPast,
42 )
43 from transformers.modeling_utils import PreTrainedModel
44 from transformers.pytorch_utils import (
45 ALL_LAYERNORM_LAYERS,
46 is_torch_greater_or_equal_than_1_13,
47 )
48 from transformers.utils import (
49 add_start_docstrings,
50 add_start_docstrings_to_model_forward,
51 is_flash_attn_2_available,
52 is_flash_attn_greater_or_equal_2_10,
53 logging,
54 replace_return_docstrings,
55 )
56 from transformers.utils.import_utils import is_torch_fx_available
57 from .configuration_deepseek import DeepseekV3Config
58 import torch.distributed as dist
59 import numpy as np
60
61 if is_flash_attn_2_available():
62 from flash_attn import flash_attn_func, flash_attn_varlen_func
63 from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
64
65
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
    # `torch.fx` is imported explicitly only when torch is older than 1.13.
    # NOTE(review): presumably newer torch versions expose `torch.fx` after a
    # plain `import torch` - confirm against the supported torch range.
    if not is_torch_greater_or_equal_than_1_13:
        import torch.fx

    # Rebind the module-level name to the wrapped function so later uses in this
    # module go through the FX leaf wrapper.
    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
73
74
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

# Name of the configuration class as a string.
# NOTE(review): presumably consumed by the docstring decorators imported at the
# top of the file; the use site is not visible in this part of the file.
_CONFIG_FOR_DOC = "DeepseekV3Config"
78
79
def _get_unpad_data(attention_mask):
    """Summarise a padding mask as flat token indices and cumulative lengths.

    Args:
        attention_mask: Mask tensor whose last dimension is reduced to obtain
            one length per row; non-zero entries are counted as kept tokens.
            (Presumably shaped ``(batch, seq_len)`` - the code only relies on
            the last axis being the one to sum over.)

    Returns:
        A tuple ``(indices, cu_seqlens, max_seqlen_in_batch)`` where

        * ``indices`` are the positions of the non-zero entries in the
          flattened mask,
        * ``cu_seqlens`` is the int32 cumulative sum of the per-row lengths
          with a leading ``0`` prepended,
        * ``max_seqlen_in_batch`` is the largest per-row length as a Python
          number.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Left-pad with a single zero so that cu_seqlens[i]:cu_seqlens[i + 1]
    # delimits row i inside the flattened token stream.
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
92
93
class DeepseekV3RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.variance_epsilon = eps
        # Per-channel gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` by the RMS of its last dimension.

        The statistics are computed in float32; the normalised tensor is cast
        back to the input dtype before being multiplied by ``self.weight``.
        """
        orig_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(orig_dtype)
109
110
# Add the RMS norm class to the shared `ALL_LAYERNORM_LAYERS` list imported from
# `transformers.pytorch_utils`.
# NOTE(review): presumably consulted elsewhere (e.g. to exclude norm weights from
# weight decay) - confirm against the transformers version in use.
ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)
112
113
class DeepseekV3RotaryEmbedding(nn.Module):
    """Rotary position embedding tables.

    Holds the inverse frequencies ``1 / base ** (2i / dim)`` and cached
    ``cos`` / ``sin`` tables with one row per position and ``dim`` columns
    (the frequency vector concatenated with itself).

    Args:
        dim: Rotary dimension; frequencies are generated for ``range(0, dim, 2)``.
        max_position_embeddings: Number of positions used for the table built
            in the constructor.
        base: Base of the geometric frequency progression.
        device: Device on which the inverse frequencies are created.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        # Mark the table as stale: the first forward() call rebuilds it using
        # the device and dtype of its actual input.
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the ``cos_cached`` / ``sin_cached`` buffers for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables truncated to ``seq_len`` rows.

        Args:
            x: Tensor used only for its device, dtype and (when ``seq_len`` is
                omitted) its second-to-last dimension.
            seq_len: Number of positions required. Defaults to ``x.shape[-2]``.

        Returns:
            Tuple ``(cos, sin)``, each with ``seq_len`` rows, cast to ``x.dtype``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The signature advertises an optional argument; without this
            # fallback the comparison / table build below fails on None.
            seq_len = x.shape[-2]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
155
156
157 # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3
class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be assigned before the parent constructor runs, because that
        # constructor calls _set_cos_sin_cache(), which reads scaling_factor.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables with positions divided by ``scaling_factor``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        t = t / self.scaling_factor

        # Align inv_freq with the positions' device, as the base class does, so
        # the outer product cannot mix devices.
        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
184
185
186 # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3
class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be assigned before the parent constructor runs, because that
        # constructor calls _set_cos_sin_cache(), which reads scaling_factor.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables, enlarging the base for long sequences.

        When ``seq_len`` exceeds ``max_position_embeddings`` the frequency base
        is rescaled and the ``inv_freq`` buffer is replaced before the tables
        are computed; otherwise the existing ``inv_freq`` is used.
        """
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        # Align inv_freq with the positions' device, as the base class does, so
        # the outer product cannot mix devices when the base was not rescaled.
        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
223
224
225 # Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Return the (fractional) dimension index matching ``num_rotations``.

    Evaluates ``dim * ln(max_position_embeddings / (2 * pi * num_rotations))
    / (2 * ln(base))``.
    """
    ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    numerator = dim * math.log(ratio)
    denominator = 2 * math.log(base)
    return numerator / denominator
232
233
234 # Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return integer ``(low, high)`` dimension bounds for two rotation counts.

    The lower bound is the floor of the correction dim for ``low_rot`` and the
    upper bound is the ceiling of the correction dim for ``high_rot``; both are
    then clamped into ``[0, dim - 1]``.
    """
    lower_bound = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    upper_bound = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    # Clamp values just in case
    clamped_low = max(lower_bound, 0)
    clamped_high = min(upper_bound, dim - 1)
    return clamped_low, clamped_high
245
246
def yarn_get_mscale(scale=1, mscale=1):
    """Return ``0.1 * mscale * ln(scale) + 1`` for ``scale > 1``, else ``1.0``."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
251
252
def yarn_linear_ramp_mask(min, max, dim):
    """Return a float32 ramp of length ``dim`` rising from 0 at ``min`` to 1 at ``max``.

    Values outside the ``[min, max]`` interval are clamped to 0 or 1.
    """
    if min == max:
        max += 0.001  # Prevent singularity

    positions = torch.arange(dim, dtype=torch.float32)
    span = max - min
    return torch.clamp((positions - min) / span, 0, 1)
260
261
class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """Rotary embedding whose frequencies blend scaled and unscaled values (YaRN).

    Per frequency index, the inverse frequency is a mix of the plain value
    and the value divided by ``scaling_factor``; the mix weights come from a
    linear ramp between the correction bounds derived from ``beta_fast`` and
    ``beta_slow``. The cos/sin tables are additionally multiplied by a
    magnitude factor derived from ``mscale`` and ``mscale_all_dim``.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # All YaRN settings are stored before delegating to the parent
        # constructor, since it builds the cache via _set_cos_sin_cache().
        self.mscale_all_dim = mscale_all_dim
        self.mscale = mscale
        self.beta_slow = beta_slow
        self.beta_fast = beta_fast
        self.original_max_position_embeddings = original_max_position_embeddings
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild ``inv_freq`` and the scaled cos/sin tables for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        rot_dim = self.dim

        # Unscaled ("extrapolation") and scaled ("interpolation") frequencies.
        base_powers = self.base ** (
            torch.arange(0, rot_dim, 2, dtype=torch.float32, device=device) / rot_dim
        )
        freq_extra = 1.0 / base_powers
        freq_inter = 1.0 / (self.scaling_factor * base_powers)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            rot_dim,
            self.base,
            self.original_max_position_embeddings,
        )
        ramp = yarn_linear_ramp_mask(low, high, rot_dim // 2).to(
            device=device, dtype=torch.float32
        )
        inv_freq_mask = 1.0 - ramp
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)

        # Ratio of the two magnitude corrections, applied to both tables.
        numerator = yarn_get_mscale(self.scaling_factor, self.mscale)
        denominator = yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        _mscale = float(numerator / denominator)

        emb = torch.cat((freqs, freqs), dim=-1)
        cos_table = (emb.cos() * _mscale).to(dtype)
        self.register_buffer("cos_cached", cos_table, persistent=False)
        sin_table = (emb.sin() * _mscale).to(dtype)
        self.register_buffer("sin_cached", sin_table, persistent=False)
328
329
330 # Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    # The negated back half comes first, followed by the front half.
    return torch.cat((-back, front), dim=-1)
336
337
338 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Before rotating, the last dimension of ``q`` and ``k`` is rearranged from
    interleaved pairs ``(a0, b0, a1, b1, ...)`` into two contiguous halves
    ``(a0, a1, ..., b0, b1, ...)``.

    Args:
        q (`torch.Tensor`): The query tensor; must be 4-dimensional.
        k (`torch.Tensor`): The key tensor; must be 4-dimensional.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            Indices used to select rows of `cos` and `sin` for each token, e.g. offset positions when a
            KV-cache is in use.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos[position_ids]` and `sin[position_ids]` so they broadcast against `q`
            and `k`. Use 1 when q/k are laid out as [batch_size, heads, seq_len, head_dim] and 2 when they are
            laid out as [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """

    def _pairs_to_halves(t):
        # (..., d) viewed as (d/2, 2) pairs, transposed to (2, d/2), flattened back.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    q = _pairs_to_halves(q)
    k = _pairs_to_halves(k)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
372
373
class DeepseekV3MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    ``hidden_size`` and ``intermediate_size`` default to the values found on
    ``config`` when not given explicitly. All three projections are bias-free.
    """

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        self.intermediate_size = (
            intermediate_size
            if intermediate_size is not None
            else config.intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Activation is looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated feed-forward transformation to ``x``."""
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
391
392
class MoEGate(nn.Module):
    """Router that picks ``top_k`` experts per token and their mixing weights.

    Only ``scoring_func == "sigmoid"`` and ``topk_method == "noaux_tc"`` are
    implemented; any other value raises ``NotImplementedError`` in ``forward``.

    Config attributes read: ``num_experts_per_tok``, ``n_routed_experts``,
    ``routed_scaling_factor``, ``scoring_func``, ``topk_method``, ``n_group``,
    ``topk_group``, ``norm_topk_prob`` and ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise the routing weight and, when present, the selection bias."""
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # The bias is allocated with torch.empty(); zero it so that a gate
        # which is not subsequently loaded from a checkpoint does not route on
        # arbitrary memory contents. Loading a checkpoint overwrites it.
        correction_bias = getattr(self, "e_score_correction_bias", None)
        if correction_bias is not None:
            init.zeros_(correction_bias)

    def forward(self, hidden_states):
        """Route every token to ``top_k`` experts.

        Args:
            hidden_states: Tensor of shape ``(bsz, seq_len, hidden)``.

        Returns:
            ``(topk_idx, topk_weight)``, both of shape
            ``(bsz * seq_len, top_k)``: the selected expert indices and their
            weights (float32, multiplied by ``routed_scaling_factor``).

        Raises:
            NotImplementedError: for an unsupported scoring or top-k method.
            AssertionError: if called while the module is in training mode.
        """
        bsz, seq_len, h = hidden_states.shape
        num_tokens = bsz * seq_len
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        # Routing logits are always computed in float32.
        logits = F.linear(
            hidden_states.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        ### select top-k experts
        if self.topk_method == "noaux_tc":
            assert not self.training
            # The correction bias influences which experts are chosen, but the
            # returned weights are gathered from the unbiased scores below.
            scores_for_choice = scores.view(
                num_tokens, -1
            ) + self.e_score_correction_bias.unsqueeze(0)
            # A group's score is the sum of its two best biased expert scores.
            group_scores = (
                scores_for_choice.view(num_tokens, self.n_group, -1)
                .topk(2, dim=-1)[0]
                .sum(dim=-1)
            )  # [n, n_group]
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[
                1
            ]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(num_tokens, self.n_group, self.n_routed_experts // self.n_group)
                .reshape(num_tokens, -1)
            )  # [n, e]
            # Experts outside the selected groups can never win the top-k.
            tmp_scores = scores_for_choice.masked_fill(
                ~score_mask.bool(), float("-inf")
            )  # [n, e]
            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
            topk_weight = scores.gather(1, topk_idx)
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = (
            topk_weight * self.routed_scaling_factor
        )  # must multiply the scaling factor

        return topk_idx, topk_weight
474
class DeepseekV3MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Tokens are routed by a ``MoEGate`` to ``config.num_experts_per_tok`` routed
    experts; when ``config.n_shared_experts`` is not ``None`` the output of a
    shared MLP applied to the unrouted input is added. Only inference is
    supported: the expert dispatch runs under ``torch.no_grad()``.

    With ``config.ep_size > 1`` (expert parallelism) each rank instantiates
    only its own contiguous slice of experts and leaves ``None`` in the other
    ``self.experts`` slots; ``torch.distributed`` must already be initialised
    with a world size equal to ``ep_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        if hasattr(config, "ep_size") and config.ep_size > 1:
            assert config.ep_size == dist.get_world_size()
            self.ep_size = config.ep_size
            self.experts_per_rank = config.n_routed_experts // config.ep_size
            self.ep_rank = dist.get_rank()
            # Experts owned by this rank: [first_local, last_local).
            first_local = self.ep_rank * self.experts_per_rank
            last_local = (self.ep_rank + 1) * self.experts_per_rank
            self.experts = nn.ModuleList(
                [
                    (
                        DeepseekV3MLP(
                            config, intermediate_size=config.moe_intermediate_size
                        )
                        if first_local <= i < last_local
                        else None
                    )
                    for i in range(config.n_routed_experts)
                ]
            )
        else:
            self.ep_size = 1
            self.experts_per_rank = config.n_routed_experts
            self.ep_rank = 0
            self.experts = nn.ModuleList(
                [
                    DeepseekV3MLP(
                        config, intermediate_size=config.moe_intermediate_size
                    )
                    for _ in range(config.n_routed_experts)
                ]
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV3MLP(
                config=config, intermediate_size=intermediate_size
            )

    def forward(self, hidden_states):
        """Apply routed (and optionally shared) experts to ``hidden_states``.

        Args:
            hidden_states: Input tensor; its last dimension is the feature
                dimension and all leading dimensions are flattened into tokens.

        Returns:
            Tensor with the same shape as ``hidden_states``.

        Raises:
            NotImplementedError: if the module is in training mode, for which
                no dispatch path exists.
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        if self.training:
            # Only the no-grad inference dispatch exists; previously this path
            # fell through to an unbound local instead of a clear error.
            raise NotImplementedError(
                f"{self.__class__.__name__} only supports inference; call .eval() before forward."
            )
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        """Dispatch tokens to experts, run them, and combine the weighted outputs.

        Args:
            x: Flattened tokens, one row per token.
            topk_ids: Expert indices per token, one column per selected expert.
            topk_weight: Mixing weights aligned with ``topk_ids``.

        Returns:
            One row per input token: the weight-summed expert outputs, cast
            back to the dtype of the expert outputs.
        """
        # Count how many (token, slot) assignments each expert receives.
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Sort the flattened assignments by expert id; integer-dividing by the
        # number of slots per token recovers the source token row.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        sorted_tokens_shape = sorted_tokens.shape
        if self.ep_size > 1:
            # Exchange per-expert counts, then the tokens themselves, so every
            # rank receives the tokens destined for its local experts.
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
            output_splits = (
                tokens_per_expert_group.view(self.ep_size, -1)
                .sum(1)
                .cpu()
                .numpy()
                .tolist()
            )
            gathered_tokens = sorted_tokens.new_empty(
                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
            )
            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
            dist.all_to_all(
                list(gathered_tokens.split(output_splits)),
                list(sorted_tokens.split(input_split_sizes)),
            )
            tokens_per_expert_post_gather = tokens_per_expert_group.view(
                self.ep_size, self.experts_per_rank
            ).sum(dim=0)
            # Received tokens arrive grouped by sending rank; label each with
            # its local expert index and sort so every expert's tokens are
            # contiguous.
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        # Run each expert on its contiguous slice of the sorted tokens.
        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if outputs else sorted_tokens.new_empty(0)
        if self.ep_size > 1:
            # Undo the local per-expert sort and send the results back to the
            # ranks the tokens came from.
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
            dist.all_to_all(
                list(gathered_tokens.split(input_split_sizes)),
                list(new_x.split(output_splits)),
            )
            outs = gathered_tokens

        # Scatter back to the original (token, slot) order, then take the
        # weighted sum over each token's slots.
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
609
610
611 # Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head dimension, matching
    torch.repeat_interleave(x, dim=1, repeats=n_rep). The input of shape (batch,
    num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep,
    seqlen, head_dim). With ``n_rep == 1`` the input tensor itself is returned.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep != 1:
        # Insert a singleton repeat axis after the head axis, broadcast it to
        # n_rep, then fold it into the head axis.
        expanded = hidden_states.unsqueeze(2).expand(
            batch, num_key_value_heads, n_rep, slen, head_dim
        )
        return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    return hidden_states
624
625
626 # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3
class DeepseekV3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Queries are produced either by a single projection (`q_proj`) or, when
    `config.q_lora_rank` is set, by a low-rank pair (`q_a_proj` -> RMSNorm ->
    `q_b_proj`). Keys and values come from a compressed projection
    (`kv_a_proj_with_mqa`) that also yields a single rotary key head shared by all
    attention heads; `kv_b_proj` expands the normalised compressed part into the
    per-head non-rotary key and the value. Each query/key head is the concatenation
    of a non-rotary part (`qk_nope_head_dim`) and a rotary part (`qk_rope_head_dim`).
    """

    def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # Full per-head query/key width: non-rotary part followed by rotary part.
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            # Low-rank query path: down-project, normalise, up-project.
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        # Emits the compressed KV latent (kv_lora_rank) plus one rotary key head
        # (qk_rope_head_dim); the two are split apart in forward().
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
        # Expands the latent to, per head, the non-rotary key (q_head_dim minus the
        # rotary width, i.e. qk_nope_head_dim) and the value (v_head_dim).
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        # Standard 1/sqrt(d) scaling; when rope_scaling supplies a non-zero
        # `mscale_all_dim` the scale is additionally multiplied by mscale squared.
        self.softmax_scale = self.q_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        """Create `self.rotary_emb` according to `config.rope_scaling`.

        No scaling config selects the plain rotary embedding; otherwise the
        `"type"` entry selects the linear, dynamic-NTK or YaRN variant. For YaRN,
        only the optional keys actually present in `rope_scaling` are forwarded.

        Raises:
            ValueError: if the scaling type is not one of the three known values.
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = DeepseekV3RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `tensor` to a contiguous (bsz, num_heads, seq_len, v_head_dim) layout.

        Not called by the other methods of this class.
        """
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute eager (matmul + softmax) attention for one layer.

        Args:
            hidden_states: Input of shape (bsz, q_len, hidden).
            attention_mask: Additive mask that must have shape
                (bsz, 1, q_len, kv_seq_len). Although typed optional, passing
                `None` trips the assertion below.
            position_ids: Indices used to select rotary cos/sin rows.
            past_key_value: Cache object updated in place with this layer's keys
                and values; requires `layer_idx` to have been set.
            output_attentions: When false, the returned attention weights are `None`.
            use_cache: Accepted for interface compatibility; not read in this method.
            **kwargs: Only inspected for the deprecated `padding_mask` key, which
                triggers a warning.

        Returns:
            `(attn_output, attn_weights, past_key_value)` where `attn_output` has
            shape (bsz, q_len, hidden_size).

        Raises:
            ValueError: on a missing layer index with a cache, or on unexpected
                attention-weight, mask or output shapes.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        # -> (bsz, num_heads, q_len, q_head_dim), then split into non-rotary / rotary parts.
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        # The rotary key has a single head: (bsz, 1, q_len, qk_rope_head_dim).
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            # Total key length = new tokens + what the cache already holds for this layer.
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Assemble full-width queries: [non-rotary | rotary] along the last dim.
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        # Same layout for keys; the single-head k_pe broadcasts across all heads.
        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
        )

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        # A mask is mandatory on this path; the `if` that follows is therefore
        # always true whenever assertions are enabled.
        assert attention_mask is not None
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (bsz, num_heads, q_len, v_head_dim) -> (bsz, q_len, num_heads * v_head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
857
858
859 # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3
class DeepseekV3FlashAttention2(DeepseekV3Attention):
    """
    DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run the attention block through the Flash Attention kernels.

        Args:
            hidden_states (`torch.Tensor`):
                Layer input; unpacked as `(batch_size, q_len, hidden)`.
            attention_mask (`torch.LongTensor`, *optional*):
                Padding mask forwarded to `_flash_attention_forward`; `None` selects the dense (non-varlen) kernel.
            position_ids (`torch.LongTensor`, *optional*):
                Positions handed to `apply_rotary_pos_emb` for the rotary parts of the query and key.
            past_key_value (`Cache`, *optional*):
                Cache that is updated in place with this step's key/value states.
            output_attentions (`bool`):
                Ignored. This implementation never returns attention weights.
            use_cache (`bool`):
                Accepted for interface compatibility; not read by this method.
            **kwargs:
                Only the deprecated `padding_mask` key is honoured; it overrides `attention_mask`.

        Returns:
            `(attn_output, None, past_key_value)` where `attn_output` has been passed through `o_proj`.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        bsz, q_len, _ = hidden_states.size()

        # Query projection, either direct or through the low-rank (a -> norm -> b) path.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # The joint KV projection yields the compressed KV latent plus a single rotary key
        # that is shared by every head (it is viewed with a head dimension of 1).
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )

        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Re-assemble full-width query/key heads as [non-rotary | rotary].
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        # k_pe has a singleton head dimension and broadcasts across all heads here.
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        # Values are zero-padded up to the query/key head width before the kernel call;
        # the padding is sliced off the output again further down.
        if self.q_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (DeepseekV3RMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = (
                    self.q_proj.weight.dtype
                    if self.q_lora_rank is None
                    else self.q_a_proj.weight.dtype
                )

            logger.warning_once(
                "The input hidden states seems to be silently casted in float32, this might be related to"
                " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            softmax_scale=self.softmax_scale,
        )
        # Drop the columns that only exist because the values were padded above.
        if self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(
            bsz, q_len, self.num_heads * self.v_head_dim
        ).contiguous()
        attn_output = self.o_proj(attn_output)

        # DeepseekV3FlashAttention2 attention does not support output_attentions,
        # so the attention-weights slot of the return value is always None.
        return attn_output, None, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                Number of query positions; used to re-pad the output and, with flash_attn<2.1, to decide whether
                the causal flag can be applied.
            dropout (`float`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(
                attn_output_unpad, indices_q, batch_size, query_length
            )
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    def _upad_input(
        self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
        """
        Prepare query/key/value for `flash_attn_varlen_func` using `attention_mask`.

        Returns a 6-tuple: `(query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k))`.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # Batch and sequence axes are flattened together before selecting rows with `indices_k`.
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            # Queries cover the same positions as the keys: reuse the key-side metadata.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # One query per sequence: every sequence has length 1 on the query side.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
1134
1135
# Registry of attention implementations, keyed by the value of `config._attn_implementation`.
ATTENTION_CLASSES = {
    "eager": DeepseekV3Attention,
    "flash_attention_2": DeepseekV3FlashAttention2,
}
1140
1141
class DeepseekV3DecoderLayer(nn.Module):
    """
    One decoder block: RMS-normed self-attention followed by an RMS-normed feed-forward module (MoE or dense MLP),
    each wrapped in a residual connection.
    """

    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # The attention implementation is picked from the registry by name.
        attention_cls = ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)

        # A layer is MoE only if routed experts are configured, the layer is past the
        # initial dense layers, and its index falls on the MoE frequency.
        is_moe_layer = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        if is_moe_layer:
            self.mlp = DeepseekV3MoE(config)
        else:
            self.mlp = DeepseekV3MLP(config)

        self.input_layernorm = DeepseekV3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = DeepseekV3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*):
                forwarded unchanged to the self-attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).

        Returns:
            A tuple starting with the new hidden states, followed by the attention weights when
            `output_attentions` is set, followed by the present key/value when `use_cache` is set.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # Self Attention (pre-norm, then residual add)
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Fully Connected (pre-norm, then residual add)
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        results = [hidden_states]
        if output_attentions:
            results.append(self_attn_weights)
        if use_cache:
            results.append(present_key_value)
        return tuple(results)
1228
1229
# Shared class-level documentation preamble; it is passed to `add_start_docstrings` on the model classes below.
DeepseekV3_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DeepseekV3Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
1245
1246
@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3PreTrainedModel(PreTrainedModel):
    # Class-level switches read by the `PreTrainedModel` machinery.
    config_class = DeepseekV3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """
        Initialise `nn.Linear` and `nn.Embedding` weights from N(0, `config.initializer_range`).

        Linear biases and the embedding row at `padding_idx` are zeroed; any other module type is left untouched.
        """
        std = self.config.initializer_range

        handles_linear = isinstance(module, nn.Linear)
        if not handles_linear and not isinstance(module, nn.Embedding):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if handles_linear:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
1270
1271
# Shared documentation for the `forward` arguments; it is passed to
# `add_start_docstrings_to_model_forward` on the model classes' `forward` methods.
DeepseekV3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
1340
1341
@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3Model(DeepseekV3PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]

    Args:
        config: DeepseekV3Config
    """

    def __init__(self, config: DeepseekV3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        decoder_layers = [
            DeepseekV3DecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        # Decides in `forward` whether a 2d or a 4d attention mask is handed to the layers.
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be supplied.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        primary_input = input_ids if input_ids is not None else inputs_embeds
        batch_size, seq_length = primary_input.shape[:2]

        # Legacy (tuple) caches are wrapped in a DynamicCache for the duration of the call.
        past_key_values_length = 0
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            # Default positions continue right after the cached tokens.
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=primary_input.device,
            ).unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers, and only when it actually masks something
            if attention_mask is None or 0 not in attention_mask:
                attention_mask = None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        hidden_states = inputs_embeds

        # decoder layers
        hidden_state_log = [] if output_hidden_states else None
        attention_log = [] if output_attentions else None
        next_decoder_cache = None
        # Position of the cache inside a layer's output tuple.
        cache_slot = 2 if output_attentions else 1

        for decoder_layer in self.layers:
            if output_hidden_states:
                hidden_state_log.append(hidden_states)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[cache_slot]
            if output_attentions:
                attention_log.append(layer_outputs[1])

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            hidden_state_log.append(hidden_states)

        all_hidden_states = (
            tuple(hidden_state_log) if hidden_state_log is not None else None
        )
        all_self_attns = tuple(attention_log) if attention_log is not None else None

        # Hand the cache back in the same format it came in.
        next_cache = None
        if use_cache:
            if use_legacy_cache:
                next_cache = next_decoder_cache.to_legacy_cache()
            else:
                next_cache = next_decoder_cache

        if not return_dict:
            candidates = [hidden_states, next_cache, all_hidden_states, all_self_attns]
            return tuple(item for item in candidates if item is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
1512
1513
class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
    # Causal language model: a `DeepseekV3Model` backbone with a bias-free linear head over the vocabulary.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DeepseekV3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM

        >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Logits are upcast to float32 before the loss is computed / returned.
        logits = self.lm_head(outputs[0]).float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n, then flatten the tokens
            shift_logits = (
                logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            )
            shift_labels = labels[..., 1:].contiguous().view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = CrossEntropyLoss()(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            if loss is None:
                return output
            return (loss,) + output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """
        Build the keyword arguments for one generation step.

        With a cache present, `input_ids` is trimmed to the tokens not yet processed and `attention_mask` is
        cropped if the cache would exceed its maximum length. `position_ids` are derived from `attention_mask`
        unless supplied in `kwargs`. Returns a dict holding `inputs_embeds` or `input_ids`, plus `position_ids`,
        `past_key_values`, `use_cache` and `attention_mask`.
        """
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                # Legacy tuple cache: the length is read from axis 2 of the first layer's first tensor.
                past_length = past_key_values[0][0].shape[2]
                cache_length = past_length
                max_cache_length = None

            # Keep only the unprocessed tokens:
            mask_is_longer = (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            )
            if mask_is_longer:
                # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting
                # where some of the inputs are exclusively passed as part of the cache (e.g. when passing
                # input_embeds as input)
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            elif past_length < input_ids.shape[1]:
                # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can
                # discard input_ids based on the past_length.
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        first_step_with_embeds = inputs_embeds is not None and past_key_values is None
        if first_step_with_embeds:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs["position_ids"] = position_ids
        model_inputs["past_key_values"] = past_key_values
        model_inputs["use_cache"] = kwargs.get("use_cache")
        model_inputs["attention_mask"] = attention_mask
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Select entries `beam_idx` along dim 0 of every cached tensor, keeping the tuple-of-tuples layout."""
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )
1717
1718
@add_start_docstrings(
    """
    The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).

    [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeepseekV3Model(config)
        # Bias-free linear head mapping each hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped base model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped base model."""
        self.model.embed_tokens = value

    @staticmethod
    def _last_non_pad_index(input_ids, pad_token_id, device):
        """Return, per row, the index of the rightmost token that is not `pad_token_id`.

        Taking the rightmost non-pad position (instead of "first pad minus one")
        gives the right answer for left padding, right padding, rows without any
        padding, and rows where the pad id also occurs before the real end of
        the sequence. A row made only of pad tokens yields index 0.
        """
        non_pad_mask = (input_ids != pad_token_id).to(device=device, dtype=torch.long)
        positions = torch.arange(
            input_ids.shape[-1], device=device, dtype=torch.long
        )
        # Pad positions are zeroed out, so the max is the last real position.
        return (positions * non_pad_mask).argmax(-1)

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:
            A [`SequenceClassifierOutputWithPast`] when `return_dict` is truthy, otherwise a tuple
            `(loss, pooled_logits, *rest_of_model_outputs)` where `loss` is omitted if no `labels` were given.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Score every position; the position to keep per row is chosen below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )

        if pad_token_id is None or input_ids is None:
            # Without a pad id, or with only embeddings, padding cannot be
            # detected: fall back to the last position of every row.
            last_token_index = -1
        else:
            last_token_index = self._last_non_pad_index(
                input_ids, pad_token_id, logits.device
            )

        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), last_token_index
        ]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Infer the task once from the head size and label dtype; the
                # result is stored on the config and reused by later calls.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
1849