modeling_deepseek.py
# coding=utf-8
# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch DeepSeek model."""
import math
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import \
    _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                           CausalLMOutputWithPast,
                                           SequenceClassifierOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import (ALL_LAYERNORM_LAYERS,
                                        is_torch_greater_or_equal_than_1_13)
from transformers.utils import (add_start_docstrings,
                                add_start_docstrings_to_model_forward,
                                is_flash_attn_2_available,
                                is_flash_attn_greater_or_equal_2_10, logging,
                                replace_return_docstrings)
from transformers.utils.import_utils import is_torch_fx_available

from .configuration_deepseek import DeepseekV3Config

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import pad_input  # noqa
    from flash_attn.bert_padding import index_first_axis, unpad_input

# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        import torch.fx

    _prepare_4d_causal_attention_mask = torch.fx.wrap(
        _prepare_4d_causal_attention_mask)

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DeepseekV3Config"


def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
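
# Illustrative note (not from the original source): for a padding mask like
#     [[1, 1, 1, 0],
#      [1, 1, 0, 0]]
# `_get_unpad_data` returns the flat indices of the real tokens
# ([0, 1, 2, 4, 5]), the cumulative sequence lengths ([0, 3, 5]) that
# flash-attn's varlen kernels use to delimit each sequence in the packed
# batch, and the longest per-sample length (3).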


# Adapted from transformers 4.48.3 to paper over cache-API changes in newer
# transformers versions.
def get_usable_length(past_key_value,
                      new_seq_length: int,
                      layer_idx: Optional[int] = 0) -> int:
    max_length = past_key_value.get_max_cache_shape()
    previous_seq_length = past_key_value.get_seq_length(layer_idx)
    if max_length is not None and max_length > 0 and previous_seq_length + new_seq_length > max_length:
        return max_length - new_seq_length
    return previous_seq_length
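
# Editorial note: this mirrors `Cache.get_usable_length` from earlier
# transformers releases. With an unbounded `DynamicCache`,
# `get_max_cache_shape()` is None and the function simply returns the number
# of tokens already cached for `layer_idx`; with a bounded cache it returns
# how many cached positions remain usable once `new_seq_length` tokens are
# appended.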


class DeepseekV3RMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance +
                                                    self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
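
# Editorial note: RMSNorm computes y = w * x / sqrt(mean(x^2) + eps), i.e.
# LayerNorm without mean subtraction or bias. The statistics are taken in
# float32 and the result is cast back to the input dtype, which keeps the
# normalization numerically stable under bf16/fp16.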


ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)


class DeepseekV3RotaryEmbedding(nn.Module):

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base**(
            torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        # Reset so the first forward call rebuilds the cache on the input's
        # actual device and dtype.
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached",
                             emb.cos().to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             emb.sin().to(dtype),
                             persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len,
                                    device=x.device,
                                    dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3
class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached",
                             emb.cos().to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             emb.sin().to(dtype),
                             persistent=False)


# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3
class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):
    """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * ((self.scaling_factor * seq_len /
                                 self.max_position_embeddings) -
                                (self.scaling_factor - 1))**(self.dim /
                                                             (self.dim - 2))
            inv_freq = 1.0 / (base**(
                torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached,
                         device=device,
                         dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached",
                             emb.cos().to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached",
                             emb.sin().to(dtype),
                             persistent=False)


# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(num_rotations,
                             dim,
                             base=10000,
                             max_position_embeddings=2048):
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 *
                                                              math.log(base))


# Find dim range bounds based on rotations
def yarn_find_correction_range(low_rot,
                               high_rot,
                               dim,
                               base=10000,
                               max_position_embeddings=2048):
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def yarn_linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func
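
# Editorial note: these helpers implement YaRN's "NTK-by-parts" schedule. A
# rotary dimension pair with frequency 1 / base^(2i/d) completes
# L * inv_freq / (2*pi) rotations over a context of length L; inverting that
# relation for a target rotation count gives `yarn_find_correction_dim`.
# Dimensions below `low` (fast, many rotations) keep the original
# frequencies, dimensions above `high` (slow) are fully interpolated by the
# scaling factor, and the linear ramp blends the two regimes in between.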


class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        dim = self.dim

        freq_extra = 1.0 / (self.base**(
            torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        freq_inter = 1.0 / (self.scaling_factor * self.base**(
            torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32)
        inv_freq = freq_inter * (1 -
                                 inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(seq_len, device=device, dtype=torch.float32)

        freqs = torch.outer(t, inv_freq)

        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale) /
            yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype),
                             persistent=False)
        self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype),
                             persistent=False)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    # De-interleave the rotary dims: viewing (..., d) as (..., d // 2, 2)
    # pairs and transposing regroups the head dim so the first half holds
    # even-indexed dims and the second half odd-indexed dims, matching the
    # layout `rotate_half` expects.
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class DeepseekV3MLP(nn.Module):

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        self.intermediate_size = (config.intermediate_size if intermediate_size
                                  is None else intermediate_size)

        self.gate_proj = nn.Linear(self.hidden_size,
                                   self.intermediate_size,
                                   bias=False)
        self.up_proj = nn.Linear(self.hidden_size,
                                 self.intermediate_size,
                                 bias=False)
        self.down_proj = nn.Linear(self.intermediate_size,
                                   self.hidden_size,
                                   bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(
            self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
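
# Editorial note: this is the standard gated (SwiGLU-style) MLP,
# down_proj(act(gate_proj(x)) * up_proj(x)). With `hidden_act = "silu"` it
# matches the Llama feed-forward block; the MoE expert MLPs below reuse this
# class with `intermediate_size = config.moe_intermediate_size`.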


class MoEGate(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim)))
        if self.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((self.n_routed_experts)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states.type(torch.float32),
                          self.weight.type(torch.float32), None)
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        else:
            raise NotImplementedError(
                f"unsupported scoring function for MoE gating: {self.scoring_func}"
            )

        ### select top-k experts
        if self.topk_method == "noaux_tc":
            assert not self.training
            scores_for_choice = scores.view(
                bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
            group_scores = (scores_for_choice.view(
                bsz * seq_len, self.n_group,
                -1).topk(2, dim=-1)[0].sum(dim=-1))  # [n, n_group]
            group_idx = torch.topk(group_scores,
                                   k=self.topk_group,
                                   dim=-1,
                                   sorted=False)[1]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            score_mask = (group_mask.unsqueeze(-1).expand(
                bsz * seq_len, self.n_group,
                self.n_routed_experts // self.n_group).reshape(
                    bsz * seq_len, -1))  # [n, e]
            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(),
                                                       0.0)  # [n, e]
            _, topk_idx = torch.topk(tmp_scores,
                                     k=self.top_k,
                                     dim=-1,
                                     sorted=False)
            topk_weight = scores.gather(1, topk_idx)
        else:
            raise NotImplementedError(
                f"unsupported TopK function for MoE gating: {self.topk_method}"
            )

        ### normalize gate weights to sum to 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor

        return topk_idx, topk_weight
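
# Editorial walkthrough of the "noaux_tc" (aux-loss-free, group-limited)
# routing above: experts are partitioned into `n_group` groups; each group is
# scored by the sum of its two highest bias-corrected expert scores; only the
# `topk_group` best groups stay eligible, and the final `top_k` experts are
# chosen among them. Note that the correction bias influences *which* experts
# are selected, while the returned `topk_weight` gathers the original,
# un-biased sigmoid scores.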


class DeepseekV3MoE(nn.Module):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        if hasattr(config, "ep_size") and config.ep_size > 1:
            assert config.ep_size == dist.get_world_size()
            self.ep_size = config.ep_size
            self.experts_per_rank = config.n_routed_experts // config.ep_size
            self.ep_rank = dist.get_rank()
            self.experts = nn.ModuleList([
                (DeepseekV3MLP(config,
                               intermediate_size=config.moe_intermediate_size)
                 if i >= self.ep_rank * self.experts_per_rank
                 and i < (self.ep_rank + 1) * self.experts_per_rank else None)
                for i in range(config.n_routed_experts)
            ])
        else:
            self.ep_size = 1
            self.experts_per_rank = config.n_routed_experts
            self.ep_rank = 0
            self.experts = nn.ModuleList([
                DeepseekV3MLP(config,
                              intermediate_size=config.moe_intermediate_size)
                for i in range(config.n_routed_experts)
            ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV3MLP(
                config=config, intermediate_size=intermediate_size)

    def forward(self, hidden_states):
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        # Only the inference dispatch is implemented in this port; the
        # training path was removed.
        if not self.training:
            y = self.moe_infer(hidden_states, topk_idx,
                               topk_weight).view(*orig_shape)
            if self.config.n_shared_experts is not None:
                y = y + self.shared_experts(identity)
            return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Sort token copies by their assigned expert so each expert sees one
        # contiguous slice.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        sorted_tokens_shape = sorted_tokens.shape
        if self.ep_size > 1:
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size,
                                                        -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
            output_splits = (tokens_per_expert_group.view(
                self.ep_size, -1).sum(1).cpu().numpy().tolist())
            gathered_tokens = sorted_tokens.new_empty(
                tokens_per_expert_group.sum(dim=0).cpu().item(),
                sorted_tokens.shape[1])
            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
            dist.all_to_all(
                list(gathered_tokens.split(output_splits)),
                list(sorted_tokens.split(input_split_sizes)),
            )
            tokens_per_expert_post_gather = tokens_per_expert_group.view(
                self.ep_size, self.experts_per_rank).sum(dim=0)
            gathered_idxs = np.zeros(shape=(gathered_tokens.shape[0], ),
                                     dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gathered_idxs[s:s + k] = i % self.experts_per_rank
                s += k
            gathered_idxs = gathered_idxs.argsort()
            sorted_tokens = gathered_tokens[gathered_idxs]
            tokens_per_expert = tokens_per_expert_post_gather
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs,
                         dim=0) if len(outputs) else sorted_tokens.new_empty(0)
        if self.ep_size > 1:
            new_x = torch.empty_like(outs)
            new_x[gathered_idxs] = outs
            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
            dist.all_to_all(
                list(gathered_tokens.split(input_split_sizes)),
                list(new_x.split(output_splits)),
            )
            outs = gathered_tokens

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (new_x.view(
            *topk_ids.shape, -1).type(topk_weight.dtype).mul_(
                topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
        return final_out
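
# Editorial note on the dispatch above: token copies are bucketed by expert
# with an argsort, run through each expert as one contiguous batch, then
# scattered back (`new_x[idxs] = outs`) and combined as a weighted sum over
# the top-k expert outputs. Under expert parallelism (`ep_size > 1`), two
# all_to_all exchanges move tokens to the ranks hosting their experts and
# back again, with `tokens_per_expert_group` carrying the split sizes for the
# exchange.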


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :,
                                  None, :, :].expand(batch,
                                                     num_key_value_heads,
                                                     n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
                                 head_dim)


# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3
class DeepseekV3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self,
                 config: DeepseekV3Config,
                 layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class.")

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(self.hidden_size,
                                    self.num_heads * self.q_head_dim,
                                    bias=False)
        else:
            self.q_a_proj = nn.Linear(self.hidden_size,
                                      config.q_lora_rank,
                                      bias=config.attention_bias)
            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank,
                                      self.num_heads * self.q_head_dim,
                                      bias=False)

        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads *
            (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim**(-0.5)
        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale
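
    # Editorial note on Multi-head Latent Attention (MLA): queries and
    # keys/values are produced through low-rank bottlenecks. For example,
    # with illustrative config values qk_nope_head_dim=128,
    # qk_rope_head_dim=64 and v_head_dim=128, each query/key head is
    # 192-dimensional, with only the 64 rope dims carrying position
    # information; kv_a_proj_with_mqa compresses the hidden state to
    # kv_lora_rank + 64 features, and kv_b_proj expands the normalized
    # kv_lora_rank part back into per-head keys and values. The softmax scale
    # is additionally corrected by YaRN's mscale^2 when `mscale_all_dim` rope
    # scaling is configured.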

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = DeepseekV3RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ] if key in self.config.rope_scaling
                }
                self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (tensor.view(bsz, seq_len, self.num_heads,
                            self.v_head_dim).transpose(1, 2).contiguous())

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
            )
        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
            bsz, q_len, self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index.")
            kv_seq_len += get_usable_length(past_key_value, kv_seq_len,
                                            self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        query_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                      self.q_head_dim)
        query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                    self.q_head_dim)
        key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) *
            self.softmax_scale)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}")
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(
                                                 query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.attention_dropout,
                                             training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len,
                                          self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3
class DeepseekV3FlashAttention2(DeepseekV3Attention):
    """
    DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stay
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10(
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        # DeepseekV3FlashAttention2 attention does not support output_attentions
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim,
        # so the projections below keep the original layout.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
            bsz, q_len, self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += get_usable_length(past_key_value, kv_seq_len,
                                            self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        query_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                      self.q_head_dim)
        query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len,
                                    self.q_head_dim)
        key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim:] = k_pe

        # flash-attn requires q, k and v to share a head dim, so pad values up
        # to q_head_dim here and slice the padding off the output below.
        if self.q_head_dim != self.v_head_dim:
            value_states = F.pad(value_states,
                                 [0, self.q_head_dim - self.v_head_dim])

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (DeepseekV3RMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = (self.q_proj.weight.dtype if self.q_lora_rank
                                is None else self.q_a_proj.weight.dtype)

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}.")

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            softmax_scale=self.softmax_scale,
        )
        if self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, :self.v_head_dim]

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads *
                                          self.v_head_dim).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
        first unpad the input, then compute the attention scores, then pad the final attention output back.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(query_states, key_states, value_states,
                                 attention_mask, query_length)

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
                                    query_length)
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
                    query_length):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
            attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
                              head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
                                head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
                                    head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
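
# Editorial note: the varlen path above packs all non-padding tokens into one
# flat batch, runs flash_attn_varlen_func with the `cu_seqlens` boundaries
# produced by `_get_unpad_data`, and then `pad_input` scatters the outputs
# back to the `(batch, seq_len, ...)` layout. The `query_length == 1` branch
# is the common decode step, where every sequence contributes exactly one
# query token.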


ATTENTION_CLASSES = {
    "eager": DeepseekV3Attention,
    "flash_attention_2": DeepseekV3FlashAttention2,
}


class DeepseekV3DecoderLayer(nn.Module):

    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](
            config=config, layer_idx=layer_idx)

        self.mlp = (DeepseekV3MoE(config) if
                    (config.n_routed_experts is not None
                     and layer_idx >= config.first_k_dense_replace
                     and layer_idx % config.moe_layer_freq == 0) else
                    DeepseekV3MLP(config))
        self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
        self.post_attention_layernorm = DeepseekV3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
                                                 torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
            )
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, )

        if output_attentions:
            outputs += (self_attn_weights, )

        if use_cache:
            outputs += (present_key_value, )

        return outputs
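
# Editorial note: a layer uses the MoE block only when routed experts are
# configured, the layer index is at least `first_k_dense_replace`, and the
# index is a multiple of `moe_layer_freq`; the first k layers stay dense,
# which matches DeepseekV3's layout of dense early layers followed by MoE
# layers.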


DeepseekV3_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`DeepseekV3Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3PreTrainedModel(PreTrainedModel):
    config_class = DeepseekV3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


DeepseekV3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV3_START_DOCSTRING,
)
class DeepseekV3Model(DeepseekV3PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]

    Args:
        config: DeepseekV3Config
    """

    def __init__(self, config: DeepseekV3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
                                         self.padding_idx)
        self.layers = nn.ModuleList([
            DeepseekV3DecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = DeepseekV3RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0
        if use_cache:
            # Convert a legacy tuple cache to `DynamicCache` so the rest of
            # the forward pass can rely on the `Cache` API.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(
                    past_key_values)
            past_key_values_length = get_usable_length(past_key_values,
                                                       seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = (attention_mask if
                              (attention_mask is not None
                               and 0 in attention_mask) else None)
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[
                    2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = None
        if use_cache:
            next_cache = (next_decoder_cache.to_legacy_cache()
                          if use_legacy_cache else next_decoder_cache)
        if not return_dict:
            return tuple(
                v for v in
                [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
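
# Editorial note (illustrative usage, not from the original source):
#
#     config = DeepseekV3Config(...)          # hypothetical config values
#     model = DeepseekV3Model(config)
#     out = model(input_ids, attention_mask=attention_mask, use_cache=True)
#     hidden = out.last_hidden_state          # [batch, seq_len, hidden_size]
#     cache = out.past_key_values             # reusable for the next step
#
# Legacy tuple caches passed in are converted to a `DynamicCache` internally
# and converted back on the way out, so callers can keep either format.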


class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DeepseekV3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model
1520
    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast,
                               config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
1537 r"""
1538 Args:
1539 labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1540 Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
1541 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1542 (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
1543
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM

        >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
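        # Up-cast to float32 so the loss (and any downstream sampling) is
        # computed in full precision even when the model runs in fp16/bf16.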
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
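            # CrossEntropyLoss skips label positions set to -100 by default
            # (ignore_index=-100), matching the masking convention documented
            # in the docstring above.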
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
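        # Called once per generation step: trims `input_ids` down to the
        # tokens the model has not yet processed and threads the cache,
        # attention mask and position ids through to `forward`.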
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                # `seen_tokens` may not exist in some transformers versions;
                # use getattr for safe access.
                past_length = getattr(past_key_values, 'seen_tokens',
                                      cache_length)
                # `get_max_length` was deprecated in newer transformers
                # releases in favour of `get_max_cache_shape`; support both.
                max_cache_length = (past_key_values.get_max_cache_shape()
                                    if hasattr(past_key_values,
                                               'get_max_cache_shape') else
                                    past_key_values.get_max_length())
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (attention_mask is not None
                    and attention_mask.shape[1] > input_ids.shape[1]):
                input_ids = input_ids[:, -(attention_mask.shape[1] -
                                           past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (max_cache_length is not None and attention_mask is not None
                    and cache_length + input_ids.shape[1] > max_cache_length):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
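            # Padded slots receive the placeholder index 1; they are masked
            # out by `attention_mask`, so any valid position works there.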
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        })
        return model_inputs

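    # Beam search reorders hypotheses between steps, so the cached key/value
    # states must be permuted along the batch dimension to stay aligned with
    # the surviving beams.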
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past), )
        return reordered_past


@add_start_docstrings(
1683 """
1684 The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).
1685
1686 [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1687 (e.g. GPT-2) do.
1688
1689 Since it does classification on the last token, it requires to know the position of the last token. If a
1690 `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1691 no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1692 padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1693 each row of the batch).
1694 """,
1695 DeepseekV3_START_DOCSTRING,
1696 )
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeepseekV3Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1728 r"""
1729 labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1730 Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers.,
1731 config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1732 `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1733 """
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.eq(
                    input_ids, self.config.pad_token_id).int().argmax(-1) -
                                    1).to(logits.device)
            else:
                sequence_lengths = -1

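        # Pool by gathering the hidden state of the last non-padding token in
        # each row; an index of -1 falls back to the final position (also the
        # right answer when a row contains no padding at all).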
        pooled_logits = logits[torch.arange(batch_size, device=logits.device),
                               sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
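            # Infer the problem type once from `num_labels` and the label
            # dtype, then cache the decision on the config: one label means
            # regression, integer labels mean single-label classification,
            # anything else is treated as multi-label classification.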
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long
                                              or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels),
                                labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits, ) + transformer_outputs[1:]
            return ((loss, ) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
