modeling_deepseekv2.py
80.3 KB · 1993 lines · python Raw
1 # coding=utf-8
2 # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
3 #
4 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 # and OPT implementations in this library. It has been modified from its
6 # original forms to accommodate minor architectural differences compared
7 # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 #
9 # Licensed under the Apache License, Version 2.0 (the "License");
10 # you may not use this file except in compliance with the License.
11 # You may obtain a copy of the License at
12 #
13 # http://www.apache.org/licenses/LICENSE-2.0
14 #
15 # Unless required by applicable law or agreed to in writing, software
16 # distributed under the License is distributed on an "AS IS" BASIS,
17 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 # See the License for the specific language governing permissions and
19 # limitations under the License.
20 """ PyTorch DeepSeek model and compatible with both DeepSeekV2 and DeepSeekV3"""
21 import math
22 import warnings
23 from typing import List, Optional, Tuple, Union
24 import numpy as np
25
26 import torch
27 import torch.nn.functional as F
28 import torch.utils.checkpoint
29 import torch.distributed as dist
30 from einops import repeat
31 from torch import nn
32 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
33
34 from transformers.activations import ACT2FN
35 from transformers.cache_utils import Cache, DynamicCache
36 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
37 from transformers.models.llama.modeling_llama import (
38 LlamaAttention,
39 LlamaFlashAttention2
40 )
41 from transformers.modeling_outputs import (
42 BaseModelOutputWithPast,
43 CausalLMOutputWithPast,
44 SequenceClassifierOutputWithPast,
45 )
46 from transformers.modeling_utils import PreTrainedModel
47 from transformers.pytorch_utils import (
48 ALL_LAYERNORM_LAYERS,
49 is_torch_greater_or_equal_than_1_13,
50 )
51 from transformers.utils import (
52 add_start_docstrings,
53 add_start_docstrings_to_model_forward,
54 is_flash_attn_2_available,
55 is_flash_attn_greater_or_equal_2_10,
56 logging,
57 replace_return_docstrings,
58 )
59 from transformers.utils.import_utils import is_torch_fx_available
60
61 from .configuration_deepseek_v2 import DeepseekV2Config
62
63 if is_flash_attn_2_available():
64 from flash_attn import flash_attn_func, flash_attn_varlen_func
65 from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
66
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        # `torch.fx` is imported explicitly only for torch < 1.13.
        # NOTE(review): presumably `import torch` alone does not expose it there — confirm.
        import torch.fx

    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)

# Module-level logger.
logger = logging.get_logger(__name__)

# Name of the config class; presumably consumed by docstring decorators later in the file.
_CONFIG_FOR_DOC = "DeepseekV2Config"
78
79
80 def _get_unpad_data(attention_mask):
81 seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
82 indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
83 max_seqlen_in_batch = seqlens_in_batch.max().item()
84 cu_seqlens = F.pad(
85 torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
86 )
87 return (
88 indices,
89 cu_seqlens,
90 max_seqlen_in_batch,
91 )
92
93
class DeepseekV2RMSNorm(nn.Module):
    """Root-mean-square normalisation (equivalent to T5LayerNorm).

    No mean-centering and no bias. The statistics are computed in float32 and
    the normalised values are cast back to the input dtype before being
    multiplied by the learned per-channel ``weight``.

    Args:
        hidden_size: size of the last (normalised) dimension.
        eps: value added to the mean square for numerical stability.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        return self.weight * (x * inv_rms).to(orig_dtype)
109
110
# Register DeepseekV2RMSNorm in transformers' ALL_LAYERNORM_LAYERS list.
# NOTE(review): presumably so transformers utilities that special-case normalisation
# layers (e.g. weight-decay grouping) treat it like nn.LayerNorm — confirm.
ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm)
112
113
114
115
class DeepseekV2RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables.

    The ``dim // 2`` inverse frequencies are duplicated (``cat((freqs, freqs))``),
    so each table row has ``dim`` entries and lines up with ``rotate_half``.

    Args:
        dim: rotary dimension (the inverse frequencies use ``arange(0, dim, 2)``).
        max_position_embeddings: length of the table built at construction time.
        base: RoPE base (theta).
        device: device on which ``inv_freq`` is created.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )
        # Mark the cache as stale: the first forward() rebuilds it on the input's
        # device/dtype for the requested length.
        self.max_seq_len_cached = None

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``cos_cached``/``sin_cached`` of shape ``[seq_len, dim]``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions, in ``x.dtype``.

        Args:
            x: tensor whose device/dtype the tables are built for. The expected
                layout is ``[bs, num_attention_heads, seq_len, head_size]``.
            seq_len: number of positions. When None, ``x.shape[-2]`` is used
                (previously None raised a ``TypeError``).
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
157
158
# Adapted from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding (Llama->DeepseekV2);
# additionally moves inv_freq to the table's device, like the DeepseekV2RotaryEmbedding base class.
class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
    """DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev

    Positions are divided by ``scaling_factor`` before the angles are computed.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__(), which builds the first table
        # through the overridden _set_cos_sin_cache().
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for linearly scaled positions ``t / scaling_factor``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        t = t / self.scaling_factor

        # Align inv_freq with t's device so a buffer/input device mismatch cannot
        # break torch.outer.
        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
186
187
# Adapted from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding (Llama->DeepseekV2);
# additionally moves inv_freq to the table's device, like the DeepseekV2RotaryEmbedding base class.
class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
    """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla

    When a table longer than ``max_position_embeddings`` is requested, the RoPE
    base is enlarged and ``inv_freq`` is recomputed before building the table.
    NOTE: the enlarged ``inv_freq`` buffer is kept afterwards. A later, shorter
    request reuses the cached table because no rebuild is triggered.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__(), which builds the first table.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables, rescaling the base beyond the trained length."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        # Align inv_freq with t's device (it is only moved when rescaled above).
        freqs = torch.outer(t, self.inv_freq.to(t.device))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
225
226
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Return the fractional rotary dimension index that completes ``num_rotations``
    full turns over ``max_position_embeddings`` positions.

    Solves ``max_position_embeddings / (2 * pi * base ** (2 * d / dim)) == num_rotations``
    for ``d``.
    """
    wavelength_ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    numerator = dim * math.log(wavelength_ratio)
    return numerator / (2 * math.log(base))
234
235
# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return integer ``(low, high)`` dimension bounds for the YaRN ramp.

    ``low`` is the floored dimension for ``low_rot`` rotations and ``high`` is
    the ceiled dimension for ``high_rot`` rotations. The results are clamped
    to ``[0, dim - 1]``.
    """
    low_dim, high_dim = (
        yarn_find_correction_dim(rot, dim, base, max_position_embeddings)
        for rot in (low_rot, high_rot)
    )
    low = max(math.floor(low_dim), 0)  # Clamp values just in case
    high = min(math.ceil(high_dim), dim - 1)
    return low, high
247
248
def yarn_get_mscale(scale=1, mscale=1):
    """YaRN magnitude scale: ``0.1 * mscale * ln(scale) + 1`` when ``scale > 1``.

    Returns exactly ``1.0`` whenever ``scale <= 1``.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
253
254
def yarn_linear_ramp_mask(min, max, dim):
    """Return a float32 ramp of length ``dim`` that rises from 0 to 1 between indices ``min`` and ``max``.

    Each entry is ``(i - min) / (max - min)`` clamped to ``[0, 1]``. When
    ``min == max``, the upper bound is nudged by 0.001 to avoid dividing by zero.
    """
    lo = min
    hi = max + 0.001 if min == max else max  # Prevent singularity
    positions = torch.arange(dim, dtype=torch.float32)
    return ((positions - lo) / (hi - lo)).clamp(0, 1)
262
263
class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding):
    """RoPE tables with YaRN context extension.

    Each inverse frequency is a blend of the original frequency (``freq_extra``)
    and the frequency divided by ``scaling_factor`` (``freq_inter``). A linear
    ramp between the ``beta_fast``/``beta_slow`` correction bounds selects the
    blend: dimensions below the low bound keep the original frequency, and
    dimensions above the high bound use the divided one. The cos/sin tables are
    multiplied by ``yarn_get_mscale(scaling_factor, mscale) /
    yarn_get_mscale(scaling_factor, mscale_all_dim)``.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # These must exist before super().__init__() builds the first table.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        exponents = torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim
        pos_freqs = self.base ** exponents
        freq_extra = 1.0 / pos_freqs
        freq_inter = 1.0 / (self.scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.dim,
            self.base,
            self.original_max_position_embeddings,
        )
        ramp = yarn_linear_ramp_mask(low, high, self.dim // 2).to(
            device=device, dtype=torch.float32
        )
        # 1 -> keep the original frequency, 0 -> use the interpolated one.
        extrapolation_weight = 1.0 - ramp
        inv_freq = freq_inter * (1 - extrapolation_weight) + freq_extra * extrapolation_weight
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)

        numerator = yarn_get_mscale(self.scaling_factor, self.mscale)
        denominator = yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        magnitude = float(numerator / denominator)

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", (emb.cos() * magnitude).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (emb.sin() * magnitude).to(dtype), persistent=False
        )
330
331
# Equivalent to transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last dim at ``size // 2`` into ``x1`` and ``x2`` and returns
    ``cat(-x2, x1)``.
    """
    half = x.shape[-1] // 2
    first, second = torch.split(x, [half, x.shape[-1] - half], dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
338
339
# Adapted from transformers.models.llama.modeling_llama.apply_rotary_pos_emb; unlike the
# Llama version it first regroups adjacent rotary pairs (see the docstring).
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Unlike the Llama implementation, the last dim of ``q`` and ``k`` is treated
    as adjacent ``(even, odd)`` rotary pairs. It is first permuted from
    ``[x0, x1, x2, x3, ...]`` to ``[x0, x2, ..., x1, x3, ...]`` (the half-split
    layout that ``rotate_half`` works on) and then rotated. The returned
    tensors stay in that permuted layout. Queries and keys receive the same
    permutation, so their dot products are unaffected.

    Args:
        q (`torch.Tensor`): The query tensor. Its last dim must be the (even) rotary dim.
        k (`torch.Tensor`): The key tensor. Its last dim must be the (even) rotary dim.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    # [..., d] -> [..., d/2, 2] -> [..., 2, d/2] -> [..., d]:
    # even-indexed elements first, then the odd-indexed ones.
    q = q.reshape(*q.shape[:-1], -1, 2).transpose(-1, -2).reshape(q.shape)
    k = k.reshape(*k.shape[:-1], -1, 2).transpose(-1, -2).reshape(k.shape)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
379
380
class DeepseekV2MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    Args:
        config: model config. ``hidden_size``, ``intermediate_size`` and
            ``hidden_act`` are read from it, but the sizes only when not overridden.
        hidden_size: overrides ``config.hidden_size`` when not None.
        intermediate_size: overrides ``config.intermediate_size`` when not None.
    """

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        self.intermediate_size = (
            intermediate_size if intermediate_size is not None else config.intermediate_size
        )

        d_model, d_ff = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
398
399
class MoEGate(nn.Module):
    """Token-to-expert router.

    Every token is scored against ``n_routed_experts`` gating vectors
    (``scoring_func``: ``"softmax"`` or ``"sigmoid"``). ``num_experts_per_tok``
    experts are then selected according to ``topk_method``:

    * ``"greedy"``: plain top-k over all experts.
    * ``"group_limited_greedy"``: experts are split into ``n_group`` groups.
      Only the ``topk_group`` groups with the highest max score are eligible.
    * ``"noaux_tc"``: selection uses ``scores + e_score_correction_bias``, and a
      group's score is the sum of its two best biased scores. The returned
      weights are the un-biased scores. Inference only.

    The weights are optionally renormalised to sum to 1 and then multiplied
    by ``routed_scaling_factor``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            # Per-expert bias used only for *selecting* experts (not for weights).
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.topk_method == "noaux_tc":
            # torch.empty leaves garbage; start from a neutral bias
            # (presumably overwritten when a checkpoint is loaded).
            init.zeros_(self.e_score_correction_bias)

    def _group_score_mask(self, group_scores):
        """Return a ``[n_tokens, n_routed_experts]`` bool mask of the experts in each
        token's ``topk_group`` highest-scoring groups.

        Args:
            group_scores: ``[n_tokens, n_group]`` per-group scores.
        """
        n_tokens = group_scores.shape[0]
        group_idx = torch.topk(
            group_scores, k=self.topk_group, dim=-1, sorted=False
        )[1]  # [n, top_k_group]
        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
        return (
            group_mask.unsqueeze(-1)
            .expand(n_tokens, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(n_tokens, -1)
            .bool()
        )  # [n, e]

    def forward(self, hidden_states):
        """Route each token to its top-k experts.

        Args:
            hidden_states: ``[bsz, seq_len, hidden_size]`` tensor.

        Returns:
            tuple: ``(topk_idx, topk_weight, aux_loss)``.
                - ``topk_idx``: ``[bsz * seq_len, top_k]`` expert ids.
                - ``topk_weight``: ``[bsz * seq_len, top_k]`` float32 weights.
                - ``aux_loss``: scalar load-balancing loss, or None unless
                  training with ``aux_loss_alpha > 0``.

        Raises:
            NotImplementedError: for an unknown ``scoring_func`` or ``topk_method``.
        """
        bsz, seq_len, h = hidden_states.shape
        n_tokens = bsz * seq_len
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(
            hidden_states.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1, dtype=torch.float32)
        elif self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        ### select top-k experts
        if self.topk_method == "greedy":
            topk_weight, topk_idx = torch.topk(
                scores, k=self.top_k, dim=-1, sorted=False
            )
        elif self.topk_method == "group_limited_greedy":
            group_scores = (
                scores.view(n_tokens, self.n_group, -1).max(dim=-1).values
            )  # [n, n_group]
            score_mask = self._group_score_mask(group_scores)
            # Weights are read from tmp_scores. Both scoring functions give
            # non-negative scores, so masked experts get weight 0.0.
            tmp_scores = scores.masked_fill(~score_mask, 0.0)  # [n, e]
            topk_weight, topk_idx = torch.topk(
                tmp_scores, k=self.top_k, dim=-1, sorted=False
            )
        elif self.topk_method == "noaux_tc":
            assert not self.training
            scores_for_choice = scores.view(n_tokens, -1) + self.e_score_correction_bias.unsqueeze(0)
            group_scores = (
                scores_for_choice.view(n_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
            )  # [n, n_group]
            score_mask = self._group_score_mask(group_scores)
            # The correction bias can make eligible scores negative. Masked experts
            # must therefore be -inf, not 0.0, or they could outrank eligible ones.
            tmp_scores = scores_for_choice.masked_fill(~score_mask, float("-inf"))  # [n, e]
            _, topk_idx = torch.topk(
                tmp_scores, k=self.top_k, dim=-1, sorted=False
            )
            topk_weight = scores.gather(1, topk_idx)
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator * self.routed_scaling_factor
        else:
            topk_weight = topk_weight * self.routed_scaling_factor

        ### expert-level computation auxiliary loss
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            # Based on the experts actually selected above.
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(
                    bsz, self.n_routed_experts, device=hidden_states.device
                )
                ce.scatter_add_(
                    1,
                    topk_idx_for_aux_loss,
                    torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
                ).div_(seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
                    dim=1
                ).mean() * self.alpha
            else:
                mask_ce = F.one_hot(
                    topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
                )
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = None
        return topk_idx, topk_weight, aux_loss
536
537
class AddAuxiliaryLoss(torch.autograd.Function):
    """Identity on ``x`` that also back-propagates into an auxiliary loss.

    The forward pass returns ``x`` unchanged. In the backward pass, ``x``
    receives the incoming gradient, and ``loss`` (which must hold exactly one
    element) receives a gradient of ones when it requires grad. This is how the
    gradient of the aux loss is included during backpropagation without
    adding it to the output.
    """

    @staticmethod
    def forward(ctx, x, loss):
        assert loss.numel() == 1
        ctx.required_aux_loss = loss.requires_grad
        ctx.dtype = loss.dtype
        return x

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.required_aux_loss:
            return grad_output, None
        unit_grad = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, unit_grad
557
558
class DeepseekV2MoE(nn.Module):
    """
    A mixture-of-experts module with optional shared experts.

    ``MoEGate`` routes each token to ``config.num_experts_per_tok`` of the
    ``config.n_routed_experts`` routed experts. Their outputs are combined with
    the gate weights. When ``config.n_shared_experts`` is set, the result is
    added to the output of a shared MLP that is applied to every token.

    With ``config.ep_size > 1`` the routed experts are sharded across
    ``torch.distributed`` ranks. Only this rank's slice is instantiated (the
    other slots are ``None``), and inference exchanges tokens with all-to-all.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        if hasattr(config, "ep_size") and config.ep_size > 1:
            assert config.ep_size == dist.get_world_size()
            self.ep_size = config.ep_size
            self.experts_per_rank = config.n_routed_experts // config.ep_size
            self.ep_rank = dist.get_rank()
            # Keep one slot per global expert id. Experts owned by other ranks are None.
            self.experts = nn.ModuleList(
                [
                    (
                        DeepseekV2MLP(
                            config, intermediate_size=config.moe_intermediate_size
                        )
                        if i >= self.ep_rank * self.experts_per_rank
                        and i < (self.ep_rank + 1) * self.experts_per_rank
                        else None
                    )
                    for i in range(config.n_routed_experts)
                ]
            )
        else:
            self.ep_size = 1
            self.experts_per_rank = config.n_routed_experts
            self.ep_rank = 0
            self.experts = nn.ModuleList(
                [
                    DeepseekV2MLP(
                        config, intermediate_size=config.moe_intermediate_size
                    )
                    for i in range(config.n_routed_experts)
                ]
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekV2MLP(
                config=config, intermediate_size=intermediate_size
            )

    def forward(self, hidden_states):
        """Apply the routed (and shared) experts.

        Args:
            hidden_states: token states whose last dim is the hidden size
                (the gate expects ``[bsz, seq_len, hidden_size]``).

        Returns:
            Tensor of the same shape as ``hidden_states``.
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # One copy of each token per selected expert, so that row i pairs
            # with flat_topk_idx[i].
            hidden_states = hidden_states.repeat_interleave(
                self.num_experts_per_tok, dim=0
            )
            y = torch.empty_like(hidden_states)
            # NOTE: calls every entry of self.experts, so this path assumes ep_size == 1.
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.to(hidden_states.dtype).view(*orig_shape)
            # MoEGate returns aux_loss=None when aux_loss_alpha == 0. Passing None
            # to AddAuxiliaryLoss would crash, so only attach a real loss.
            if aux_loss is not None:
                y = AddAuxiliaryLoss.apply(y, aux_loss)
        else:
            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, topk_ids, topk_weight):
        """Inference-time expert dispatch.

        The ``(token, expert)`` assignments are sorted by expert id, and each
        local expert runs once on its contiguous slice. The outputs are
        scattered back and each token's ``top_k`` outputs are combined with
        ``topk_weight``. With ``ep_size > 1``, tokens are first sent to the
        ranks that own their experts, and the outputs are sent back afterwards.

        Args:
            x: ``[n_tokens, hidden]`` token states.
            topk_ids: ``[n_tokens, top_k]`` global expert ids.
            topk_weight: ``[n_tokens, top_k]`` combination weights.

        Returns:
            ``[n_tokens, hidden]`` combined output, in the experts' output dtype.
        """
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        # Flat assignment indices ordered by expert id; idx // top_k is the token row.
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        sorted_tokens_shape = sorted_tokens.shape
        if self.ep_size > 1:
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            # Exchange per-expert counts so each rank knows how many tokens it receives.
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
            output_splits = (
                tokens_per_expert_group.view(self.ep_size, -1)
                .sum(1)
                .cpu()
                .numpy()
                .tolist()
            )
            gathered_tokens = sorted_tokens.new_empty(
                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
            )
            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
            dist.all_to_all(
                list(gathered_tokens.split(output_splits)),
                list(sorted_tokens.split(input_split_sizes)),
            )
            tokens_per_expert_post_gather = tokens_per_expert_group.view(
                self.ep_size, self.experts_per_rank
            ).sum(dim=0)
            # Received tokens arrive grouped by source rank; re-sort them by local expert.
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
        if self.ep_size > 1:
            # Undo the local re-sort, then return outputs to their source ranks.
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
            dist.all_to_all(
                list(gathered_tokens.split(input_split_sizes)),
                list(new_x.split(output_splits)),
            )
            outs = gathered_tokens

        # Undo the expert sort, then take the weighted sum over each token's top_k outputs.
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
703
704
# Equivalent to transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dim 1.

    Same result as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. When ``n_rep == 1``
    the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
718
719
# Multi-head latent attention (MLA). Derived from
# transformers.models.llama.modeling_llama.LlamaAttention but substantially modified
# (compressed KV, decoupled RoPE key, absorbed projections); not a verbatim copy.
class DeepseekV2Attention(nn.Module):
    """Multi-head latent attention with absorbed key/value up-projections.

    ``kv_a_proj_with_mqa`` compresses the keys and values into a
    ``kv_lora_rank``-dim latent (``compressed_kv``) plus one rotary key
    (``k_pe``) shared by all heads. Only these two tensors are cached. At
    attention time, the per-head no-RoPE key projection (``q_absorb``, taken
    from ``kv_b_proj``) is folded into the query. The value projection
    (``out_absorb``) is applied after attending over the latent. Per-head keys
    and values are therefore never materialised.
    """

    def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            # Low-rank query path: hidden -> q_lora_rank -> RMSNorm -> all heads.
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        # One matmul yields the KV latent (kv_lora_rank) and the shared rotary key.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
        # Per head: qk_nope_head_dim key rows followed by v_head_dim value rows.
        # forward() uses only its weight, split into q_absorb / out_absorb.
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                # YaRN: scale the softmax temperature by mscale(factor, mscale_all_dim) ** 2.
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        """Create ``self.rotary_emb`` for ``qk_rope_head_dim`` from ``config.rope_scaling``.

        Raises:
            ValueError: for an unknown ``rope_scaling["type"]``.
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = DeepseekV2RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                # Forward only the YaRN options present in the config.
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                self.rotary_emb = DeepseekV2YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``[bsz, seq_len, num_heads * v_head_dim]`` to ``[bsz, num_heads, seq_len, v_head_dim]``.

        NOTE: not used by forward() below.
        """
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager MLA forward pass.

        Args:
            hidden_states: ``[bsz, q_len, hidden_size]`` input.
            attention_mask: required additive mask of shape ``[bsz, 1, q_len, kv_seq_len]``.
            position_ids: positions used to index the RoPE tables. When None, the
                new tokens are assumed to directly follow the cached ones.
            past_key_value: cache. ``k_pe`` is stored as the key and ``compressed_kv``
                as the value.
            output_attentions: also return the attention probabilities.
            use_cache: not read here; caching is driven by ``past_key_value``.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.

        Raises:
            ValueError: if the attention mask is missing or mis-shaped, if a cache
                is used without ``layer_idx``, or on unexpected intermediate shapes.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        # Validate before touching the cache; an assert would vanish under `python -O`.
        if attention_mask is None:
            raise ValueError(
                f"{self.__class__.__name__} requires an `attention_mask` of shape (bsz, 1, q_len, kv_seq_len)."
            )
        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        compressed_kv = self.kv_a_layernorm(compressed_kv)
        # Single rotary key head shared by all query heads.
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)

        kv_seq_len = k_pe.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        if position_ids is None:
            # New tokens occupy the last q_len positions of the (cached + new) sequence.
            position_ids = torch.arange(
                kv_seq_len - q_len, kv_seq_len, device=hidden_states.device
            ).unsqueeze(0)

        cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            compressed_kv = compressed_kv.unsqueeze(1)
            k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
            compressed_kv = compressed_kv.squeeze(1)

        # [num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
        kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
        q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :]
        out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]

        # Project the no-RoPE query into latent space so it can attend to compressed_kv directly.
        q_nope = torch.matmul(q_nope, q_absorb)
        attn_weights = (torch.matmul(q_pe, k_pe.mT) +
                        torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(q_pe.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        # Attend over the latent, then apply the per-head value projection.
        attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)

        attn_output = torch.matmul(attn_output, out_absorb.mT)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
950
951
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2 (Llama -> DeepseekV2).
# Unlike the Llama version, `forward` expands the low-rank query / key-value projections into full
# per-head keys and values before calling the flash-attn kernels.
class DeepseekV2FlashAttention2(DeepseekV2Attention):
    """
    DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stay
    untouched. The only required change is in the forward pass, which needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.

    Note: this module stores the expanded per-head key/value states (not the compressed latent) in the KV cache.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, which became the default for flash_attn>=2.1. This attribute is used to handle the difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a
        # wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute self-attention with the flash-attn kernels.

        Args:
            hidden_states: Layer input, unpacked below as `(bsz, q_len, hidden_size)`.
            attention_mask: Optional 2D padding mask (1 = token, 0 = padding) handed to
                `_flash_attention_forward`; `None` means the batch contains no padding.
            position_ids: Positions used to apply rotary embeddings to the rope part of q/k.
            past_key_value: Optional `Cache`. When given, this layer's expanded key/value states
                are appended to it and the concatenated states are attended over.
            output_attentions: Ignored. Flash attention never materialises the attention
                weights, so `None` is always returned in their place.
            use_cache: Unused here; caching is driven by `past_key_value`.
            kwargs: A deprecated `padding_mask` entry, if present, replaces `attention_mask`.

        Returns:
            `(attn_output, None, past_key_value)` where `attn_output` is `(bsz, q_len, hidden_size)`
            after `o_proj`.

        Raises:
            ValueError: If a cache is supplied but the module was built without a `layer_idx`.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")

        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim; the states are built as
        # batch_size x num_heads x seq_length x head_dim and transposed right before the kernel call.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        # A single rope key head; it is broadcast to every head when key_states is assembled below.
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )

        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        # The flash kernels need q, k and v to share one head_dim, so v is zero-padded
        # up to q_head_dim and the padding is sliced off the output again below.
        if self.q_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])

        # TODO: support compressed_kv for kv_cache (instead of key_states, value_states) in flash_attention version
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout
        # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (DeepseekV2RMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = (
                    self.q_proj.weight.dtype
                    if self.q_lora_rank is None
                    else self.q_a_proj.weight.dtype
                )

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            softmax_scale=self.softmax_scale,
        )
        if self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(
            bsz, q_len, self.num_heads * self.v_head_dim
        ).contiguous()
        attn_output = self.o_proj(attn_output)

        # DeepseekV2FlashAttention2 does not support output_attentions: the weights are never materialised.
        return attn_output, None, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention. If the input contains at least one padding token, the input is
        first unpadded, attention is computed on the packed sequences, and the output is padded back.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API, laid out by the caller as
                `(batch_size, query_length, num_heads, head_dim)`.
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens. `None` means no padding,
                in which case the dense `flash_attn_func` kernel is used.
            query_length (`int`):
                Number of query positions; selects the unpadding strategy and the length the output is re-padded to.
            dropout (`float`, *optional*):
                Attention dropout probability.
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)

        Returns:
            The attention output, re-padded to `(batch_size, query_length, ...)` when unpadding was used.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
            # For details, please see the comment in DeepseekV2FlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(
                attn_output_unpad, indices_q, batch_size, query_length
            )
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    def _upad_input(
        self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
        """Strip padding tokens so the varlen flash kernel sees packed sequences.

        Keys and values are gathered with the indices returned by `_get_unpad_data(attention_mask)`.
        Queries are handled in one of three ways:
          * `query_length == kv_seq_len`: reuse the key indices and cumulative lengths;
          * `query_length == 1` (single decoding step): one query per batch row;
          * otherwise: unpad with the last `query_length` mask columns, which assumes left padding.

        Returns:
            `(query, key, value, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k))`.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
1228
1229
# Registry of attention module classes. DeepseekV2DecoderLayer looks up
# "mla_<backend>" when config.use_mla is truthy and "mha_<backend>" otherwise,
# where <backend> is config._attn_implementation.
ATTENTION_CLASSES = dict(
    # Unprefixed backend keys map to the MLA modules; DeepseekV2DecoderLayer always
    # adds a prefix, so these are presumably kept for external callers.
    eager=DeepseekV2Attention,
    flash_attention_2=DeepseekV2FlashAttention2,
    # Multi-head latent attention (config.use_mla truthy).
    mla_eager=DeepseekV2Attention,
    mla_flash_attention_2=DeepseekV2FlashAttention2,
    # Plain multi-head attention reused from the Llama implementation.
    mha_eager=LlamaAttention,
    mha_flash_attention_2=LlamaFlashAttention2,
)
1240
1241
class DeepseekV2DecoderLayer(nn.Module):
    # One decoder block: RMSNorm -> self-attention -> residual add, followed by
    # RMSNorm -> feed-forward (dense MLP or mixture-of-experts) -> residual add.

    def __init__(self, config: DeepseekV2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # The attention class is keyed "<family>_<backend>": "mla_" for latent
        # attention when config.use_mla is truthy, "mha_" otherwise.
        family = "mla_" if config.use_mla else "mha_"
        attention_cls = ATTENTION_CLASSES[family + config._attn_implementation]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)

        # Routed experts are used only when experts are configured, the layer index is
        # at least `first_k_dense_replace`, and it is a multiple of `moe_layer_freq`.
        is_moe_layer = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        if is_moe_layer:
            self.mlp = DeepseekV2MoE(config)
        else:
            self.mlp = DeepseekV2MLP(config)

        norm_eps = config.rms_norm_eps
        self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Run one pre-norm decoder block.

        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): a `(batch_size, sequence_length)` padding mask when
                flash attention is used, otherwise a `(batch_size, 1, query_sequence_length, key_sequence_length)`
                mask that the eager attention adds to its scores.
            position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to the attention module.
            past_key_value (*optional*): cache of past key/value projection states, forwarded to the attention module.
            output_attentions (`bool`, *optional*): when true, the attention weights are appended to the output.
            use_cache (`bool`, *optional*): when true, the cache returned by the attention module is appended to the
                output so it can be reused for faster decoding.
            kwargs: forwarded to the attention module; a `padding_mask` entry triggers a deprecation warning.

        Returns:
            `(hidden_states,)`, extended with the attention weights if `output_attentions` and then the cache if
            `use_cache`.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # Self-attention sub-block: normalise, attend, add back onto the residual stream.
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block: normalise, MLP/MoE, add back onto the residual stream.
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + ffn_out

        outputs = [hidden_states]
        if output_attentions:
            outputs.append(self_attn_weights)
        if use_cache:
            outputs.append(present_key_value)
        return tuple(outputs)
1334
1335
# Class-docstring prefix passed to `add_start_docstrings` for the DeepseekV2 model classes.
DeepseekV2_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`DeepseekV2Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
1351
1352
@add_start_docstrings(
    "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV2_START_DOCSTRING,
)
class DeepseekV2PreTrainedModel(PreTrainedModel):
    # Shared base for the DeepseekV2 models: binds the config class, declares the
    # PreTrainedModel capability flags, and defines how fresh weights are initialised.
    config_class = DeepseekV2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DeepseekV2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, initializer_range**2); zero Linear biases and the padding row."""
        init_std = self.config.initializer_range
        is_linear = isinstance(module, nn.Linear)
        if not is_linear and not isinstance(module, nn.Embedding):
            # Any other module type keeps whatever initialisation it already has.
            return

        module.weight.data.normal_(mean=0.0, std=init_std)
        if is_linear:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
1376
1377
# Argument documentation prepended to the model `forward` docstrings via `add_start_docstrings_to_model_forward`.
DeepseekV2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2
            tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as
            the legacy cache format. The exact per-layer tensors depend on the attention implementation in use.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
1446
1447
@add_start_docstrings(
    "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.",
    DeepseekV2_START_DOCSTRING,
)
class DeepseekV2Model(DeepseekV2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`].

    Args:
        config: DeepseekV2Config
    """

    def __init__(self, config: DeepseekV2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        # Flash attention consumes the raw 2D padding mask; every other backend gets
        # the expanded 4D causal mask (see `forward`).
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Embeddings -> decoder layers -> final RMSNorm.
        # Returns a `BaseModelOutputWithPast`, or, when `return_dict` is falsy, a tuple of the
        # non-None entries among (last_hidden_state, next_cache, all_hidden_states, all_self_attns).
        # `cache_position` is accepted for API compatibility but is not used.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
                )
                use_cache = False

        past_key_values_length = 0
        use_legacy_cache = False
        if use_cache:
            # Anything that is not a `Cache` (legacy tuples, or None) is wrapped in a DynamicCache
            # for the layers and converted back to the legacy format on return.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            # Default positions continue from the number of already-cached tokens.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers; it is dropped when nothing is padded
            attention_mask = (
                attention_mask
                if (attention_mask is not None and 0 in attention_mask)
                else None
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache follows the (optional) attention weights in the layer's output tuple.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if use_legacy_cache
                else next_decoder_cache
            )
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
1638
1639
1640 class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1641 _tied_weights_keys = ["lm_head.weight"]
1642
def __init__(self, config):
    """Build the `DeepseekV2Model` backbone plus a bias-free linear head mapping hidden states to vocabulary logits."""
    super().__init__(config)
    hidden_size, vocab_size = config.hidden_size, config.vocab_size
    self.model = DeepseekV2Model(config)
    self.vocab_size = vocab_size
    self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    # Weight initialisation and final processing provided by the PreTrainedModel base.
    self.post_init()
1651
def get_input_embeddings(self):
    """Return the token-embedding module of the wrapped decoder."""
    decoder = self.model
    return decoder.embed_tokens
1654
def set_input_embeddings(self, value):
    """Install `value` as the wrapped decoder's token-embedding module."""
    decoder = self.model
    decoder.embed_tokens = value
1657
def get_output_embeddings(self):
    """Return the vocabulary projection (`lm_head`)."""
    head = self.lm_head
    return head
1660
def set_output_embeddings(self, new_embeddings):
    """Replace the vocabulary projection (`lm_head`) with `new_embeddings`."""
    setattr(self, "lm_head", new_embeddings)
1663
def set_decoder(self, decoder):
    """Swap in `decoder` as the backbone stored on `self.model`."""
    setattr(self, "model", decoder)
1666
def get_decoder(self):
    """Return the decoder backbone stored on `self.model`."""
    backbone = self.model
    return backbone
1669
@add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the causal language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size - 1]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
            The one-position shift is applied inside the model (the logits at position `n - 1` are scored against
            `labels[..., n]`), so `labels` should be aligned with `input_ids`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM

    >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)
    # Logits are upcast to float32 before the loss and before being returned.
    logits = logits.float()

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
1769
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    """Build the keyword arguments for a single forward pass during generation.

    Drops the tokens of ``input_ids`` that are already held in the KV cache,
    derives ``position_ids`` from ``attention_mask`` when none were passed,
    and builds the matching ``cache_position`` index.

    Args:
        input_ids: Token ids accumulated so far, indexed as ``[batch, seq]``.
        past_key_values: ``None`` on the first step, a ``Cache`` instance, or a
            legacy per-layer sequence whose ``[0][0]`` tensor stores the cached
            length on dim 2.
        attention_mask: Optional mask indexed as ``[batch, total_seq]``.
        inputs_embeds: Optional embeddings; only forwarded while
            ``past_key_values`` is ``None`` (the first step).
        **kwargs: ``position_ids``, ``use_cache`` and, for a static cache,
            ``cache_position`` are read from here.

    Returns:
        dict with ``input_ids`` (or ``inputs_embeds``), ``position_ids``,
        ``cache_position``, ``past_key_values``, ``use_cache`` and
        ``attention_mask``. ``position_ids`` is ``None`` when neither an
        attention mask nor explicit position ids were supplied.
    """
    past_length = 0
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # Legacy per-layer cache: dim 2 of the first layer's first tensor
            # is the number of cached positions.
            cache_length = past_length = past_key_values[0][0].shape[2]
            max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the attention_mask is longer than input_ids, some inputs were
        #     passed exclusively through the cache (e.g. inputs_embeds input).
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
        # 2 - If past_length is smaller than input_ids' length, input_ids holds
        #     all tokens, so drop the already-cached prefix.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), assume input_ids
        #     only contains unprocessed tokens.

        # About to go beyond the maximum cache length: crop the attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # Create position_ids on the fly for batch generation; masked-out
        # slots get a dummy position of 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1]:]

    if self.generation_config.cache_implementation == "static":
        # Generation with a static cache: the offset comes from cache_position.
        cache_position = kwargs.get("cache_position", None)
        if cache_position is None:
            past_length = 0
        else:
            past_length = cache_position[-1] + 1
        input_ids = input_ids[:, past_length:]
        if position_ids is not None:
            position_ids = position_ids[:, past_length:]

    # inputs_embeds are only used on the first generation step.
    use_embeds = inputs_embeds is not None and past_key_values is None

    # Number of positions processed in this step. position_ids can still be
    # None here (no attention mask and no explicit ids), so fall back to the
    # length of the tensor actually fed to the model instead of crashing.
    if position_ids is not None:
        step_length, step_device = position_ids.shape[-1], position_ids.device
    else:
        fed = inputs_embeds if use_embeds else input_ids
        step_length, step_device = fed.shape[1], fed.device

    # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
    # same goes for position ids. Could also help with continued generation.
    cache_position = torch.arange(past_length, past_length + step_length, device=step_device)

    if use_embeds:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
        # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
        # TODO: use `next_tokens` directly instead.
        model_inputs = {"input_ids": input_ids.contiguous()}

    model_inputs.update(
        {
            "position_ids": position_ids.contiguous() if position_ids is not None else None,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
1849
1850 @staticmethod
1851 def _reorder_cache(past_key_values, beam_idx):
1852 reordered_past = ()
1853 for layer_past in past_key_values:
1854 reordered_past += (
1855 tuple(
1856 past_state.index_select(0, beam_idx.to(past_state.device))
1857 for past_state in layer_past
1858 ),
1859 )
1860 return reordered_past
1861
1862
@add_start_docstrings(
    """
    The DeepseekV2 Model transformer with a sequence classification head on top (linear layer).

    [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it needs to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the rightmost token that is not a padding token in each
    row (this works for both left and right padding). If no `pad_token_id` is defined, it simply takes the last value in
    each row of the batch. Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of
    `input_ids`, it does the same (take the last value in each row of the batch).
    """,
    DeepseekV2_START_DOCSTRING,
)
class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
    # Backbone (`self.model`) plus a bias-free linear head (`self.score`) that maps
    # each hidden state to `config.num_labels` scores; one token per row is pooled.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeepseekV2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # The token embedding module lives on the wrapped backbone.
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        # Replace the backbone's token embedding module.
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Per-token class scores; assumes hidden_states is (batch, seq, hidden) — TODO confirm.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        elif input_ids is not None:
            # Index of the rightmost non-padding token in each row. Weighting the
            # position index by the non-pad mask makes argmax land on the last
            # real token for both left and right padding, and stays correct when
            # the pad id also occurs mid-sequence (e.g. pad_token_id == eos).
            # "first pad index - 1" would pick the wrong token in that case.
            non_pad_mask = (input_ids != self.config.pad_token_id).long()
            token_indices = torch.arange(input_ids.shape[-1], device=input_ids.device)
            sequence_lengths = (token_indices * non_pad_mask).argmax(-1).to(logits.device)
        else:
            warnings.warn(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds`."
            )
            sequence_lengths = -1

        # One score vector per row, taken at the selected token position.
        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from num_labels and the label dtype;
            # the result is cached on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
1993