modeling_qwen.py
1 # coding=utf-8
2 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3 #
4 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 # and OPT implementations in this library. It has been modified from its
6 # original forms to accommodate minor architectural differences compared
7 # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 #
9 # Licensed under the Apache License, Version 2.0 (the "License");
10 # you may not use this file except in compliance with the License.
11 # You may obtain a copy of the License at
12 #
13 # http://www.apache.org/licenses/LICENSE-2.0
14 #
15 # Unless required by applicable law or agreed to in writing, software
16 # distributed under the License is distributed on an "AS IS" BASIS,
17 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 # See the License for the specific language governing permissions and
19 # limitations under the License.
20 """ PyTorch Qwen2 model."""
21 from transformers import Qwen2Config
22 import inspect
23 import math
24 import os
25 import warnings
26 from typing import List, Optional, Tuple, Union
27
28 import torch
29 import torch.nn.functional as F
30 import torch.utils.checkpoint
31 from torch import nn
32 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
33
34 from transformers.activations import ACT2FN
35 from transformers.cache_utils import Cache, DynamicCache
36 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
37 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
38 from transformers.modeling_utils import PreTrainedModel
39 from transformers.utils import (
40 add_start_docstrings,
41 add_start_docstrings_to_model_forward,
42 is_flash_attn_2_available,
43 is_flash_attn_greater_or_equal_2_10,
44 logging,
45 replace_return_docstrings,
46 )
47
48
49 if is_flash_attn_2_available():
50     from flash_attn import flash_attn_func, flash_attn_varlen_func
51     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
52
53     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
54
55
56 logger = logging.get_logger(__name__)
57
58
59 _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
60 _CONFIG_FOR_DOC = "Qwen2Config"
61
62 QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [
63 "Qwen/Qwen2-7B-beta",
64 # See all Qwen2 models at https://huggingface.co/models?filter=qwen2
65 ]
66
67
68 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
69 def _get_unpad_data(attention_mask):
70 seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
71 indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
72 max_seqlen_in_batch = seqlens_in_batch.max().item()
73 cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
74 return (
75 indices,
76 cu_seqlens,
77 max_seqlen_in_batch,
78 )
79
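# Example: for attention_mask = [[1, 1, 0], [1, 1, 1]] this returns indices = [0, 1, 3, 4, 5],
# cu_seqlens = [0, 2, 5] and max_seqlen_in_batch = 3, i.e. the flattened positions of real tokens
# and the cumulative per-sequence lengths in the layout expected by the varlen flash-attn kernels.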
80
81 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
82 class Qwen2RMSNorm(nn.Module):
83 def __init__(self, hidden_size, eps=1e-6):
84 """
85 Qwen2RMSNorm is equivalent to T5LayerNorm
86 """
87 super().__init__()
88 self.weight = nn.Parameter(torch.ones(hidden_size))
89 self.variance_epsilon = eps
90
91 def forward(self, hidden_states):
92 input_dtype = hidden_states.dtype
93 hidden_states = hidden_states.to(torch.float32)
94 variance = hidden_states.pow(2).mean(-1, keepdim=True)
95 hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
96 return self.weight * hidden_states.to(input_dtype)
97
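# In formula form: y = weight * x / sqrt(mean(x ** 2, dim=-1) + eps); the statistics are computed
# in float32 and the result is cast back to the input dtype, matching the forward pass above.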
98
99 # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
100 class Qwen2RotaryEmbedding(nn.Module):
101 def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
102 super().__init__()
103
104 self.dim = dim
105 self.max_position_embeddings = max_position_embeddings
106 self.base = base
107 inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
108 self.register_buffer("inv_freq", inv_freq, persistent=False)
109
110 # Build here to make `torch.jit.trace` work.
111 self._set_cos_sin_cache(
112 seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
113 )
114
115 def _set_cos_sin_cache(self, seq_len, device, dtype):
116 self.max_seq_len_cached = seq_len
117 t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
118
119 freqs = torch.outer(t, self.inv_freq)
120 # Different from paper, but it uses a different permutation in order to obtain the same calculation
121 emb = torch.cat((freqs, freqs), dim=-1)
122 self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
123 self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
124
125 def forward(self, x, seq_len=None):
126 # x: [bs, num_attention_heads, seq_len, head_size]
127 if seq_len > self.max_seq_len_cached:
128 self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
129
130 return (
131 self.cos_cached[:seq_len].to(dtype=x.dtype),
132 self.sin_cached[:seq_len].to(dtype=x.dtype),
133 )
134
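# The cached cos/sin tables have shape [max_seq_len_cached, dim] (freqs duplicated along the last
# axis); apply_rotary_pos_emb below gathers rows with position_ids and broadcasts them over heads.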
135
136 # Copied from transformers.models.llama.modeling_llama.rotate_half
137 def rotate_half(x):
138 """Rotates half the hidden dims of the input."""
139 x1 = x[..., : x.shape[-1] // 2]
140 x2 = x[..., x.shape[-1] // 2 :]
141 return torch.cat((-x2, x1), dim=-1)
142
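# Example: with a head dimension of 4, rotate_half([x0, x1, x2, x3]) = [-x2, -x3, x0, x1], so that
# (x * cos) + (rotate_half(x) * sin) in apply_rotary_pos_emb realizes the pairwise RoPE rotation.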
143
144 # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
145 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
146 """Applies Rotary Position Embedding to the query and key tensors.
147
148 Args:
149 q (`torch.Tensor`): The query tensor.
150 k (`torch.Tensor`): The key tensor.
151 cos (`torch.Tensor`): The cosine part of the rotary embedding.
152 sin (`torch.Tensor`): The sine part of the rotary embedding.
153 position_ids (`torch.Tensor`):
154 The position indices of the tokens corresponding to the query and key tensors. For example, this can be
155 used to pass offsetted position ids when working with a KV-cache.
156 unsqueeze_dim (`int`, *optional*, defaults to 1):
157 The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
158 sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
159 that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
160 k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
161 cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
162 the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
163 Returns:
164 `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
165 """
166 cos = cos[position_ids].unsqueeze(unsqueeze_dim)
167 sin = sin[position_ids].unsqueeze(unsqueeze_dim)
168 q_embed = (q * cos) + (rotate_half(q) * sin)
169 k_embed = (k * cos) + (rotate_half(k) * sin)
170 return q_embed, k_embed
171
172
173 # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
174 class Qwen2MLP(nn.Module):
175 def __init__(self, config):
176 super().__init__()
177 self.config = config
178 self.hidden_size = config.hidden_size
179 self.intermediate_size = config.intermediate_size
180 self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
181 self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
182 self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
183 self.act_fn = ACT2FN[config.hidden_act]
184
185 def forward(self, x):
186 return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
187
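# Qwen2MLP is a gated (SwiGLU-style) feed-forward block: down_proj(act_fn(gate_proj(x)) * up_proj(x)),
# where act_fn is taken from config.hidden_act (typically "silu" for Qwen2 checkpoints).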
188
189 # Copied from transformers.models.llama.modeling_llama.repeat_kv
190 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
191 """
192 This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
193 num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
194 """
195 batch, num_key_value_heads, slen, head_dim = hidden_states.shape
196 if n_rep == 1:
197 return hidden_states
198 hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
199 return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
200
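# Example: with 28 query heads and 4 key/value heads (grouped-query attention), n_rep = 7 and each
# K/V head is repeated 7 times so the key/value tensors match the query head count before the matmul.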
201
202 class Qwen2Attention(nn.Module):
203 """
204 Multi-headed attention from the 'Attention Is All You Need' paper. Modified to use sliding window attention, as in Longformer
205 and "Generating Long Sequences with Sparse Transformers".
206 """
207
208 def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
209 super().__init__()
210 self.config = config
211 self.layer_idx = layer_idx
212 if layer_idx is None:
213 logger.warning_once(
214 f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
215 "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
216 "when creating this class."
217 )
218
219 self.hidden_size = config.hidden_size
220 self.num_heads = config.num_attention_heads
221 self.head_dim = self.hidden_size // self.num_heads
222 self.num_key_value_heads = config.num_key_value_heads
223 self.num_key_value_groups = self.num_heads // self.num_key_value_heads
224 self.max_position_embeddings = config.max_position_embeddings
225 self.rope_theta = config.rope_theta
226 self.is_causal = True
227 self.attention_dropout = config.attention_dropout
228
229 if (self.head_dim * self.num_heads) != self.hidden_size:
230 raise ValueError(
231 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
232 f" and `num_heads`: {self.num_heads})."
233 )
234 self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
235 self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
236 self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
237 self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
238
239 self.rotary_emb = Qwen2RotaryEmbedding(
240 self.head_dim,
241 max_position_embeddings=self.max_position_embeddings,
242 base=self.rope_theta,
243 )
244
245 def forward(
246 self,
247 hidden_states: torch.Tensor,
248 attention_mask: Optional[torch.Tensor] = None,
249 position_ids: Optional[torch.LongTensor] = None,
250 past_key_value: Optional[Cache] = None,
251 output_attentions: bool = False,
252 use_cache: bool = False,
253 **kwargs,
254 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
255 if "padding_mask" in kwargs:
256 warnings.warn(
257 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
258 )
259 bsz, q_len, _ = hidden_states.size()
260
261 query_states = self.q_proj(hidden_states)
262 key_states = self.k_proj(hidden_states)
263 value_states = self.v_proj(hidden_states)
264
265 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
266 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
267 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
268
269 kv_seq_len = key_states.shape[-2]
270 if past_key_value is not None:
271 if self.layer_idx is None:
272 raise ValueError(
273 f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
274 "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
275 "with a layer index."
276 )
277 kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
278 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
279 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
280
281 if past_key_value is not None:
282 cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
283 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
284
285 # repeat k/v heads if n_kv_heads < n_heads
286 key_states = repeat_kv(key_states, self.num_key_value_groups)
287 value_states = repeat_kv(value_states, self.num_key_value_groups)
288
289 attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
290
291 if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
292 raise ValueError(
293 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
294 f" {attn_weights.size()}"
295 )
296
297 if attention_mask is not None:
298 if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
299 raise ValueError(
300 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
301 )
302
303 attn_weights = attn_weights + attention_mask
304
305 # upcast attention to fp32
306 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
307 attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
308 attn_output = torch.matmul(attn_weights, value_states)
309
310 if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
311 raise ValueError(
312 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
313 f" {attn_output.size()}"
314 )
315
316 attn_output = attn_output.transpose(1, 2).contiguous()
317 attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
318
319 attn_output = self.o_proj(attn_output)
320
321 if not output_attentions:
322 attn_weights = None
323
324 return attn_output, attn_weights, past_key_value
325
326
327 class Qwen2FlashAttention2(Qwen2Attention):
328 """
329 Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
330 as the weights of the module stay untouched. The only required change would be on the forward pass
331 where it needs to correctly call the public API of flash attention and deal with padding tokens
332 in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
333 config.max_window_layers layers.
334 """
335
336 # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
337 def __init__(self, *args, **kwargs):
338 super().__init__(*args, **kwargs)
339
340 # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
341 # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
342 # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
343 self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
344
345 def forward(
346 self,
347 hidden_states: torch.Tensor,
348 attention_mask: Optional[torch.Tensor] = None,
349 position_ids: Optional[torch.LongTensor] = None,
350 past_key_value: Optional[Cache] = None,
351 output_attentions: bool = False,
352 use_cache: bool = False,
353 is_causal: bool = False,
354 **kwargs,
355 ):
356 if "padding_mask" in kwargs:
357 warnings.warn(
358 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
359 )
360
361 # overwrite attention_mask with padding_mask
362 attention_mask = kwargs.pop("padding_mask")
363 bsz, q_len, _ = hidden_states.size()
364
365 query_states = self.q_proj(hidden_states)
366 key_states = self.k_proj(hidden_states)
367 value_states = self.v_proj(hidden_states)
368
369 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
370 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
371 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
372
373 kv_seq_len = key_states.shape[-2]
374 if past_key_value is not None:
375 if self.layer_idx is None:
376 raise ValueError(
377 f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
378 "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
379 "with a layer index."
380 )
381 kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
382
383 # Because the input can be padded, the absolute sequence length depends on the max position id.
384 rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
385 cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
386
387 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
388
389 use_sliding_windows = (
390 _flash_supports_window_size
391 and getattr(self.config, "sliding_window", None) is not None
392 and kv_seq_len > self.config.sliding_window
393 and self.config.use_sliding_window
394 )
395
396 if not _flash_supports_window_size:
397 logger.warning_once(
398 "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
399 " make sure to upgrade flash-attn library."
400 )
401
402 if past_key_value is not None:
403 # Activate cache slicing only if the config has a `sliding_window` attribute
404 cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
405 if (
406 getattr(self.config, "sliding_window", None) is not None
407 and kv_seq_len > self.config.sliding_window
408 and cache_has_contents
409 ):
410 slicing_tokens = 1 - self.config.sliding_window
411
412 past_key = past_key_value[self.layer_idx][0]
413 past_value = past_key_value[self.layer_idx][1]
414
415 past_key = past_key[:, :, slicing_tokens:, :].contiguous()
416 past_value = past_value[:, :, slicing_tokens:, :].contiguous()
417
418 if past_key.shape[-2] != self.config.sliding_window - 1:
419 raise ValueError(
420 f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
421 f" {past_key.shape}"
422 )
423
424 if attention_mask is not None:
425 attention_mask = attention_mask[:, slicing_tokens:]
426 attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
427
428 cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
429 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
430
431 # repeat k/v heads if n_kv_heads < n_heads
432 key_states = repeat_kv(key_states, self.num_key_value_groups)
433 value_states = repeat_kv(value_states, self.num_key_value_groups)
434 dropout_rate = 0.0 if not self.training else self.attention_dropout
435
436 # In PEFT, we usually cast the layer norms to float32 for training stability reasons,
437 # so the input hidden states get silently cast to float32. Hence, we need to
438 # cast them back to float16 just to be sure everything works as expected.
439 input_dtype = query_states.dtype
440 if input_dtype == torch.float32:
441 if torch.is_autocast_enabled():
442 target_dtype = torch.get_autocast_gpu_dtype()
443 # Handle the case where the model is quantized
444 elif hasattr(self.config, "_pre_quantization_dtype"):
445 target_dtype = self.config._pre_quantization_dtype
446 else:
447 target_dtype = self.q_proj.weight.dtype
448
449 logger.warning_once(
450 f"The input hidden states seems to be silently casted in float32, this might be related to"
451 f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
452 f" {target_dtype}."
453 )
454
455 query_states = query_states.to(target_dtype)
456 key_states = key_states.to(target_dtype)
457 value_states = value_states.to(target_dtype)
458
459 # Reshape to the expected shape for Flash Attention
460 query_states = query_states.transpose(1, 2)
461 key_states = key_states.transpose(1, 2)
462 value_states = value_states.transpose(1, 2)
463
464 attn_output = self._flash_attention_forward(
465 query_states,
466 key_states,
467 value_states,
468 attention_mask,
469 q_len,
470 dropout=dropout_rate,
471 use_sliding_windows=use_sliding_windows,
472 is_causal=is_causal
473 )
474
475 attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
476 attn_output = self.o_proj(attn_output)
477
478 if not output_attentions:
479 attn_weights = None
480
481 return attn_output, attn_weights, past_key_value
482
483 def _flash_attention_forward(
484 self,
485 query_states,
486 key_states,
487 value_states,
488 attention_mask,
489 query_length,
490 dropout=0.0,
491 softmax_scale=None,
492 use_sliding_windows=False,
493 is_causal=True,
494 ):
495 """
496 Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
497 the input is first unpadded, then the attention scores are computed, and finally the output is padded back.
498
499 Args:
500 query_states (`torch.Tensor`):
501 Input query states to be passed to Flash Attention API
502 key_states (`torch.Tensor`):
503 Input key states to be passed to Flash Attention API
504 value_states (`torch.Tensor`):
505 Input value states to be passed to Flash Attention API
506 attention_mask (`torch.Tensor`):
507 The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
508 position of padding tokens and 1 for the position of non-padding tokens.
509 dropout (`float`, *optional*):
510 Attention dropout
511 softmax_scale (`float`, *optional*):
512 The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
513 use_sliding_windows (`bool`, *optional*):
514 Whether to activate sliding window attention.
515 """
516 if not self._flash_attn_uses_top_left_mask:
517 causal = is_causal
518 else:
519 # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
520 causal = is_causal and query_length != 1
521
522 # Decide whether to use SWA or not by layer index.
523 if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
524 use_sliding_windows = False
525
526 # Contains at least one padding token in the sequence
527 if attention_mask is not None:
528 batch_size = query_states.shape[0]
529 query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
530 query_states, key_states, value_states, attention_mask, query_length
531 )
532
533 cu_seqlens_q, cu_seqlens_k = cu_seq_lens
534 max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
535
536 if not use_sliding_windows:
537 attn_output_unpad = flash_attn_varlen_func(
538 query_states,
539 key_states,
540 value_states,
541 cu_seqlens_q=cu_seqlens_q,
542 cu_seqlens_k=cu_seqlens_k,
543 max_seqlen_q=max_seqlen_in_batch_q,
544 max_seqlen_k=max_seqlen_in_batch_k,
545 dropout_p=dropout,
546 softmax_scale=softmax_scale,
547 causal=causal,
548 )
549 else:
550 attn_output_unpad = flash_attn_varlen_func(
551 query_states,
552 key_states,
553 value_states,
554 cu_seqlens_q=cu_seqlens_q,
555 cu_seqlens_k=cu_seqlens_k,
556 max_seqlen_q=max_seqlen_in_batch_q,
557 max_seqlen_k=max_seqlen_in_batch_k,
558 dropout_p=dropout,
559 softmax_scale=softmax_scale,
560 causal=causal,
561 window_size=(self.config.sliding_window, self.config.sliding_window),
562 )
563
564 attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
565 else:
566 if not use_sliding_windows:
567 attn_output = flash_attn_func(
568 query_states,
569 key_states,
570 value_states,
571 dropout,
572 softmax_scale=softmax_scale,
573 causal=causal,
574 )
575 else:
576 attn_output = flash_attn_func(
577 query_states,
578 key_states,
579 value_states,
580 dropout,
581 softmax_scale=softmax_scale,
582 causal=causal,
583 window_size=(self.config.sliding_window, self.config.sliding_window),
584 )
585
586 return attn_output
587
588 # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
589 def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
590 batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
591
592 # On the first iteration we need to properly re-create the padding mask
593 # by slicing it on the proper place
594 if kv_seq_len != attention_mask.shape[-1]:
595 attention_mask_num_tokens = attention_mask.shape[-1]
596 attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
597
598 indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
599
600 key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
601 value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
602
603 if query_length == kv_seq_len:
604 query_layer = index_first_axis(
605 query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
606 )
607 cu_seqlens_q = cu_seqlens_k
608 max_seqlen_in_batch_q = max_seqlen_in_batch_k
609 indices_q = indices_k
610 elif query_length == 1:
611 max_seqlen_in_batch_q = 1
612 cu_seqlens_q = torch.arange(
613 batch_size + 1, dtype=torch.int32, device=query_layer.device
614 ) # There is a memcpy here, that is very bad.
615 indices_q = cu_seqlens_q[:-1]
616 query_layer = query_layer.squeeze(1)
617 else:
618 # The -q_len: slice assumes left padding.
619 attention_mask = attention_mask[:, -query_length:]
620 query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
621
622 return (
623 query_layer,
624 key_layer,
625 value_layer,
626 indices_q,
627 (cu_seqlens_q, cu_seqlens_k),
628 (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
629 )
630
631
632 # Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2
633 class Qwen2SdpaAttention(Qwen2Attention):
634 """
635 Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
636 `Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
637 SDPA API.
638 """
639
640 # Adapted from Qwen2Attention.forward
641 def forward(
642 self,
643 hidden_states: torch.Tensor,
644 attention_mask: Optional[torch.Tensor] = None,
645 position_ids: Optional[torch.LongTensor] = None,
646 past_key_value: Optional[Cache] = None,
647 output_attentions: bool = False,
648 use_cache: bool = False,
649 is_causal: bool = True,
650 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
651 if output_attentions:
652 # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
653 logger.warning_once(
654 "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
655 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
656 )
657 return super().forward(
658 hidden_states=hidden_states,
659 attention_mask=attention_mask,
660 position_ids=position_ids,
661 past_key_value=past_key_value,
662 output_attentions=output_attentions,
663 use_cache=use_cache,
664 is_causal=is_causal
665 )
666
667 bsz, q_len, _ = hidden_states.size()
668
669 query_states = self.q_proj(hidden_states)
670 key_states = self.k_proj(hidden_states)
671 value_states = self.v_proj(hidden_states)
672
673 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
674 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
675 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
676
677 kv_seq_len = key_states.shape[-2]
678 if past_key_value is not None:
679 kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
680 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
681
682 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
683
684 if past_key_value is not None:
685 cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
686 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
687
688 key_states = repeat_kv(key_states, self.num_key_value_groups)
689 value_states = repeat_kv(value_states, self.num_key_value_groups)
690
691 if attention_mask is not None:
692 if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
693 raise ValueError(
694 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
695 )
696
697 # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
698 # Reference: https://github.com/pytorch/pytorch/issues/112577.
699 if query_states.device.type == "cuda" and attention_mask is not None:
700 query_states = query_states.contiguous()
701 key_states = key_states.contiguous()
702 value_states = value_states.contiguous()
703
704 attn_output = torch.nn.functional.scaled_dot_product_attention(
705 query_states,
706 key_states,
707 value_states,
708 attn_mask=attention_mask,
709 dropout_p=self.attention_dropout if self.training else 0.0,
710 # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
711 is_causal=is_causal and attention_mask is None and q_len > 1,
712 )
713
714 attn_output = attn_output.transpose(1, 2).contiguous()
715 attn_output = attn_output.view(bsz, q_len, self.hidden_size)
716
717 attn_output = self.o_proj(attn_output)
718
719 return attn_output, None, past_key_value
720
721
722 QWEN2_ATTENTION_CLASSES = {
723 "eager": Qwen2Attention,
724 "flash_attention_2": Qwen2FlashAttention2,
725 "sdpa": Qwen2SdpaAttention,
726 }
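# The concrete attention implementation is chosen via config._attn_implementation (e.g. the
# `attn_implementation` argument of `from_pretrained`) when each decoder layer is constructed below.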
727
728
729 class Qwen2DecoderLayer(nn.Module):
730 def __init__(self, config: Qwen2Config, layer_idx: int):
731 super().__init__()
732 self.hidden_size = config.hidden_size
733
734 if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
735 logger.warning_once(
736 f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
737 "unexpected results may be encountered."
738 )
739 self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
740
741 self.mlp = Qwen2MLP(config)
742 self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
743 self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
744
745 def forward(
746 self,
747 hidden_states: torch.Tensor,
748 attention_mask: Optional[torch.Tensor] = None,
749 position_ids: Optional[torch.LongTensor] = None,
750 past_key_value: Optional[Tuple[torch.Tensor]] = None,
751 output_attentions: Optional[bool] = False,
752 use_cache: Optional[bool] = False,
753 is_causal: Optional[bool] = True,
754 **kwargs,
755 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
756 if "padding_mask" in kwargs:
757 warnings.warn(
758 "Passing `padding_mask` is deprecated and will be removed in v4.37. "
759 "Please make sure use `attention_mask` instead.`"
760 )
761 """
762 Args:
763 hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
764 attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
765 `(batch, sequence_length)` where padding elements are indicated by 0.
766 output_attentions (`bool`, *optional*):
767 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
768 returned tensors for more detail.
769 use_cache (`bool`, *optional*):
770 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
771 (see `past_key_values`).
772 past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
773 """
774
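# Pre-norm residual block: x = x + Attn(RMSNorm(x)), then x = x + MLP(RMSNorm(x)).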
775 residual = hidden_states
776
777 hidden_states = self.input_layernorm(hidden_states)
778
779 # Self Attention
780 hidden_states, self_attn_weights, present_key_value = self.self_attn(
781 hidden_states=hidden_states,
782 attention_mask=attention_mask,
783 position_ids=position_ids,
784 past_key_value=past_key_value,
785 output_attentions=output_attentions,
786 use_cache=use_cache,
787 is_causal=is_causal,
788 )
789 hidden_states = residual + hidden_states
790
791 # Fully Connected
792 residual = hidden_states
793 hidden_states = self.post_attention_layernorm(hidden_states)
794 hidden_states = self.mlp(hidden_states)
795 hidden_states = residual + hidden_states
796
797 outputs = (hidden_states,)
798
799 if output_attentions:
800 outputs += (self_attn_weights,)
801
802 if use_cache:
803 outputs += (present_key_value,)
804
805 return outputs
806
807
808 QWEN2_START_DOCSTRING = r"""
809 This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
810 library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
811 etc.)
812
813 This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
814 Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
815 and behavior.
816
817 Parameters:
818 config ([`Qwen2Config`]):
819 Model configuration class with all the parameters of the model. Initializing with a config file does not
820 load the weights associated with the model, only the configuration. Check out the
821 [`~PreTrainedModel.from_pretrained`] method to load the model weights.
822 """
823
824
825 @add_start_docstrings(
826 "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
827 QWEN2_START_DOCSTRING,
828 )
829 class Qwen2PreTrainedModel(PreTrainedModel):
830 config_class = Qwen2Config
831 base_model_prefix = "model"
832 supports_gradient_checkpointing = True
833 _no_split_modules = ["Qwen2DecoderLayer"]
834 _skip_keys_device_placement = "past_key_values"
835 _supports_flash_attn_2 = True
836 _supports_sdpa = True
837 _supports_cache_class = True
838
839 def _init_weights(self, module):
840 std = self.config.initializer_range
841 if isinstance(module, nn.Linear):
842 module.weight.data.normal_(mean=0.0, std=std)
843 if module.bias is not None:
844 module.bias.data.zero_()
845 elif isinstance(module, nn.Embedding):
846 module.weight.data.normal_(mean=0.0, std=std)
847 if module.padding_idx is not None:
848 module.weight.data[module.padding_idx].zero_()
849
850
851 QWEN2_INPUTS_DOCSTRING = r"""
852 Args:
853 input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
854 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
855 it.
856
857 Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
858 [`PreTrainedTokenizer.__call__`] for details.
859
860 [What are input IDs?](../glossary#input-ids)
861 attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
862 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
863
864 - 1 for tokens that are **not masked**,
865 - 0 for tokens that are **masked**.
866
867 [What are attention masks?](../glossary#attention-mask)
868
869 Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
870 [`PreTrainedTokenizer.__call__`] for details.
871
872 If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
873 `past_key_values`).
874
875 If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
876 and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
877 information on the default strategy.
878
879 - 1 indicates the head is **not masked**,
880 - 0 indicates the head is **masked**.
881 position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
882 Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
883 config.n_positions - 1]`.
884
885 [What are position IDs?](../glossary#position-ids)
886 past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
887 Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
888 blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
889 returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
890
891 Two formats are allowed:
892 - a [`~cache_utils.Cache`] instance;
893 - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
894 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
895 cache format.
896
897 The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
898 legacy cache format will be returned.
899
900 If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
901 have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
902 of shape `(batch_size, sequence_length)`.
903 inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
904 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
905 is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
906 model's internal embedding lookup matrix.
907 use_cache (`bool`, *optional*):
908 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
909 `past_key_values`).
910 output_attentions (`bool`, *optional*):
911 Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
912 tensors for more detail.
913 output_hidden_states (`bool`, *optional*):
914 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
915 more detail.
916 return_dict (`bool`, *optional*):
917 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
918 """
919
920
921 @add_start_docstrings(
922 "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
923 QWEN2_START_DOCSTRING,
924 )
925 class Qwen2Model(Qwen2PreTrainedModel):
926 """
927 Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
928
929 Args:
930 config: Qwen2Config
931 """
932
933 def __init__(self, config: Qwen2Config):
934 super().__init__(config)
935 self.padding_idx = config.pad_token_id
936 self.vocab_size = config.vocab_size
937
938 self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
939 self.layers = nn.ModuleList(
940 [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
941 )
942 self._attn_implementation = config._attn_implementation
943 self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
944
945 self.gradient_checkpointing = False
946 # Initialize weights and apply final processing
947 self.post_init()
948
949 def get_input_embeddings(self):
950 return self.embed_tokens
951
952 def set_input_embeddings(self, value):
953 self.embed_tokens = value
954
955 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
956 def forward(
957 self,
958 input_ids: torch.LongTensor = None,
959 attention_mask: Optional[torch.Tensor] = None,
960 position_ids: Optional[torch.LongTensor] = None,
961 past_key_values: Optional[List[torch.FloatTensor]] = None,
962 inputs_embeds: Optional[torch.FloatTensor] = None,
963 use_cache: Optional[bool] = None,
964 output_attentions: Optional[bool] = None,
965 output_hidden_states: Optional[bool] = None,
966 return_dict: Optional[bool] = None,
967 labels: Optional[torch.LongTensor] = None,
968 is_causal: Optional[bool] = False,
969 ) -> Union[Tuple, BaseModelOutputWithPast]:
970 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
971 output_hidden_states = (
972 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
973 )
974 use_cache = use_cache if use_cache is not None else self.config.use_cache
975
976 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
977
978 # retrieve input_ids and inputs_embeds
979 if input_ids is not None and inputs_embeds is not None:
980 raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
981 elif input_ids is not None:
982 batch_size, seq_length = input_ids.shape
983 elif inputs_embeds is not None:
984 batch_size, seq_length, _ = inputs_embeds.shape
985 else:
986 raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
987
988 if self.gradient_checkpointing and self.training:
989 if use_cache:
990 logger.warning_once(
991 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
992 )
993 use_cache = False
994
995 past_key_values_length = 0
996
997 if use_cache:
998 use_legacy_cache = not isinstance(past_key_values, Cache)
999 if use_legacy_cache:
1000 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1001 past_key_values_length = past_key_values.get_usable_length(seq_length)
1002
1003 if position_ids is None:
1004 device = input_ids.device if input_ids is not None else inputs_embeds.device
1005 position_ids = torch.arange(
1006 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1007 )
1008 position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1009 else:
1010 position_ids = position_ids.view(-1, seq_length).long()
1011
1012 if inputs_embeds is None:
1013 inputs_embeds = self.embed_tokens(input_ids)
1014
1015 if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1016 is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1017 if is_padding_right:
1018 raise ValueError(
1019 "You are attempting to perform batched generation with padding_side='right'"
1020 " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
1021 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1022 )
1023
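# `is_causal` selects between a causal (decoder-style) mask and a plain padding mask: when it is
# False, only padding positions are masked and all tokens attend to each other bidirectionally.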
1024 if self._attn_implementation == "flash_attention_2":
1025 # 2d mask is passed through the layers
1026 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1027 elif self._attn_implementation == "sdpa" and not output_attentions:
1028 # output_attentions=True can not be supported when using SDPA, and we fall back on
1029 # the manual implementation that requires a 4D causal mask in all cases.
1030 if is_causal:
1031 attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1032 attention_mask,
1033 (batch_size, seq_length),
1034 inputs_embeds,
1035 past_key_values_length,
1036 )
1037 else:
1038 attention_mask = _prepare_4d_attention_mask_for_sdpa(
1039 attention_mask, inputs_embeds.dtype
1040 )
1041 else:
1042 # 4d mask is passed through the layers
1043 if is_causal:
1044 # Causal mask filled with the dtype's minimum value (e.g. -3.3895e+38 for bfloat16) at positions where attention is not allowed
1045 attention_mask = _prepare_4d_causal_attention_mask(
1046 attention_mask,
1047 (batch_size, seq_length),
1048 inputs_embeds,
1049 past_key_values_length,
1050 sliding_window=self.config.sliding_window,
1051 )
1052 else:
1053 # Shape: batch_size, 1, query_length, key_value_length
1054 attention_mask = _prepare_4d_attention_mask(
1055 attention_mask, inputs_embeds.dtype
1056 )
1057
1058 hidden_states = inputs_embeds
1059
1060 # decoder layers
1061 all_hidden_states = () if output_hidden_states else None
1062 all_self_attns = () if output_attentions else None
1063 next_decoder_cache = None
1064
1065 for decoder_layer in self.layers:
1066 if output_hidden_states:
1067 all_hidden_states += (hidden_states,)
1068
1069 if self.gradient_checkpointing and self.training:
1070 layer_outputs = self._gradient_checkpointing_func(
1071 decoder_layer.__call__,
1072 hidden_states,
1073 attention_mask,
1074 position_ids,
1075 past_key_values,
1076 output_attentions,
1077 use_cache,
1078 is_causal,
1079 )
1080 else:
1081 layer_outputs = decoder_layer(
1082 hidden_states,
1083 attention_mask=attention_mask,
1084 position_ids=position_ids,
1085 past_key_value=past_key_values,
1086 output_attentions=output_attentions,
1087 use_cache=use_cache,
1088 is_causal=is_causal,
1089 )
1090
1091 hidden_states = layer_outputs[0]
1092
1093 if use_cache:
1094 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1095
1096 if output_attentions:
1097 all_self_attns += (layer_outputs[1],)
1098
1099 hidden_states = self.norm(hidden_states)
1100
1101 # add hidden states from the last decoder layer
1102 if output_hidden_states:
1103 all_hidden_states += (hidden_states,)
1104
1105 next_cache = None
1106 if use_cache:
1107 next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1108
1109 if not return_dict:
1110 return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1111 return BaseModelOutputWithPast(
1112 last_hidden_state=hidden_states,
1113 past_key_values=next_cache,
1114 hidden_states=all_hidden_states,
1115 attentions=all_self_attns,
1116 )
1117
1118
1119 class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1120 _tied_weights_keys = ["lm_head.weight"]
1121
1122 def __init__(self, config):
1123 super().__init__(config)
1124 self.model = Qwen2Model(config)
1125 self.vocab_size = config.vocab_size
1126 self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1127
1128 # Initialize weights and apply final processing
1129 self.post_init()
1130
1131 def get_input_embeddings(self):
1132 return self.model.embed_tokens
1133
1134 def set_input_embeddings(self, value):
1135 self.model.embed_tokens = value
1136
1137 def get_output_embeddings(self):
1138 return self.lm_head
1139
1140 def set_output_embeddings(self, new_embeddings):
1141 self.lm_head = new_embeddings
1142
1143 def set_decoder(self, decoder):
1144 self.model = decoder
1145
1146 def get_decoder(self):
1147 return self.model
1148
1149 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1150 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1151 def forward(
1152 self,
1153 input_ids: torch.LongTensor = None,
1154 attention_mask: Optional[torch.Tensor] = None,
1155 position_ids: Optional[torch.LongTensor] = None,
1156 past_key_values: Optional[List[torch.FloatTensor]] = None,
1157 inputs_embeds: Optional[torch.FloatTensor] = None,
1158 labels: Optional[torch.LongTensor] = None,
1159 use_cache: Optional[bool] = None,
1160 output_attentions: Optional[bool] = None,
1161 output_hidden_states: Optional[bool] = None,
1162 return_dict: Optional[bool] = None,
1163 is_causal: Optional[bool] = False,
1164 ) -> Union[Tuple, CausalLMOutputWithPast]:
1165 r"""
1166 Args:
1167 labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1168 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1169 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1170 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1171
1172 Returns:
1173
1174 Example:
1175
1176 ```python
1177 >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
1178
1179 >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1180 >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1181
1182 >>> prompt = "Hey, are you conscious? Can you talk to me?"
1183 >>> inputs = tokenizer(prompt, return_tensors="pt")
1184
1185 >>> # Generate
1186 >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1187 >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1188 "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1189 ```"""
1190
1191 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1192 output_hidden_states = (
1193 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1194 )
1195 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1196
1197 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1198 outputs = self.model(
1199 input_ids=input_ids,
1200 attention_mask=attention_mask,
1201 position_ids=position_ids,
1202 past_key_values=past_key_values,
1203 inputs_embeds=inputs_embeds,
1204 use_cache=use_cache,
1205 output_attentions=output_attentions,
1206 output_hidden_states=output_hidden_states,
1207 return_dict=return_dict,
1208 is_causal=is_causal,
1209 )
1210
1211 hidden_states = outputs[0]
1212 logits = self.lm_head(hidden_states)
1213 logits = logits.float()
1214
1215 loss = None
1216 if labels is not None:
1217 # Shift so that tokens < n predict n
1218 shift_logits = logits[..., :-1, :].contiguous()
1219 shift_labels = labels[..., 1:].contiguous()
1220 # Flatten the tokens
1221 loss_fct = CrossEntropyLoss()
1222 shift_logits = shift_logits.view(-1, self.config.vocab_size)
1223 shift_labels = shift_labels.view(-1)
1224 # Enable model parallelism
1225 shift_labels = shift_labels.to(shift_logits.device)
1226 loss = loss_fct(shift_logits, shift_labels)
1227
1228 if not return_dict:
1229 output = (logits,) + outputs[1:]
1230 return (loss,) + output if loss is not None else output
1231
1232 return CausalLMOutputWithPast(
1233 loss=loss,
1234 logits=logits,
1235 past_key_values=outputs.past_key_values,
1236 hidden_states=outputs.hidden_states,
1237 attentions=outputs.attentions,
1238 )
1239
1240 def prepare_inputs_for_generation(
1241 self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1242 ):
1243 # Omit tokens covered by past_key_values
1244 if past_key_values is not None:
1245 if isinstance(past_key_values, Cache):
1246 cache_length = past_key_values.get_seq_length()
1247 past_length = past_key_values.seen_tokens
1248 max_cache_length = past_key_values.get_max_length()
1249 else:
1250 cache_length = past_length = past_key_values[0][0].shape[2]
1251 max_cache_length = None
1252
1253 # Keep only the unprocessed tokens:
1254 # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1255 # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1256 # input)
1257 if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1258 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1259 # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1260 # input_ids based on the past_length.
1261 elif past_length < input_ids.shape[1]:
1262 input_ids = input_ids[:, past_length:]
1263 # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1264
1265 # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1266 if (
1267 max_cache_length is not None
1268 and attention_mask is not None
1269 and cache_length + input_ids.shape[1] > max_cache_length
1270 ):
1271 attention_mask = attention_mask[:, -max_cache_length:]
1272
1273 position_ids = kwargs.get("position_ids", None)
1274 if attention_mask is not None and position_ids is None:
1275 # create position_ids on the fly for batch generation
1276 position_ids = attention_mask.long().cumsum(-1) - 1
1277 position_ids.masked_fill_(attention_mask == 0, 1)
1278 if past_key_values:
1279 position_ids = position_ids[:, -input_ids.shape[1] :]
1280
1281 # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1282 if inputs_embeds is not None and past_key_values is None:
1283 model_inputs = {"inputs_embeds": inputs_embeds}
1284 else:
1285 model_inputs = {"input_ids": input_ids}
1286
1287 model_inputs.update(
1288 {
1289 "position_ids": position_ids,
1290 "past_key_values": past_key_values,
1291 "use_cache": kwargs.get("use_cache"),
1292 "attention_mask": attention_mask,
1293 }
1294 )
1295 return model_inputs
1296
1297 @staticmethod
1298 def _reorder_cache(past_key_values, beam_idx):
1299 reordered_past = ()
1300 for layer_past in past_key_values:
1301 reordered_past += (
1302 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1303 )
1304 return reordered_past
1305
1306
1307 @add_start_docstrings(
1308 """
1309 The Qwen2 Model transformer with a sequence classification head on top (linear layer).
1310
1311 [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1312 (e.g. GPT-2) do.
1313
1314 Since it does classification on the last token, it needs to know the position of the last token. If a
1315 `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1316 no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1317 padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1318 each row of the batch).
1319 """,
1320 QWEN2_START_DOCSTRING,
1321 )
1322 class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1323 def __init__(self, config):
1324 super().__init__(config)
1325 self.num_labels = config.num_labels
1326 self.model = Qwen2Model(config)
1327 self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1328
1329 # Initialize weights and apply final processing
1330 self.post_init()
1331
1332 def get_input_embeddings(self):
1333 return self.model.embed_tokens
1334
1335 def set_input_embeddings(self, value):
1336 self.model.embed_tokens = value
1337
1338 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1339 def forward(
1340 self,
1341 input_ids: torch.LongTensor = None,
1342 attention_mask: Optional[torch.Tensor] = None,
1343 position_ids: Optional[torch.LongTensor] = None,
1344 past_key_values: Optional[List[torch.FloatTensor]] = None,
1345 inputs_embeds: Optional[torch.FloatTensor] = None,
1346 labels: Optional[torch.LongTensor] = None,
1347 use_cache: Optional[bool] = None,
1348 output_attentions: Optional[bool] = None,
1349 output_hidden_states: Optional[bool] = None,
1350 return_dict: Optional[bool] = None,
1351 is_causal: Optional[bool] = True,
1352 ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1353 r"""
1354 labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1355 Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1356 config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1357 `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1358 """
1359 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1360
1361 transformer_outputs = self.model(
1362 input_ids,
1363 attention_mask=attention_mask,
1364 position_ids=position_ids,
1365 past_key_values=past_key_values,
1366 inputs_embeds=inputs_embeds,
1367 use_cache=use_cache,
1368 output_attentions=output_attentions,
1369 output_hidden_states=output_hidden_states,
1370 return_dict=return_dict,
1371 is_causal=is_causal,
1372 )
1373 hidden_states = transformer_outputs[0]
1374 logits = self.score(hidden_states)
1375
1376 if input_ids is not None:
1377 batch_size = input_ids.shape[0]
1378 else:
1379 batch_size = inputs_embeds.shape[0]
1380
1381 if self.config.pad_token_id is None and batch_size != 1:
1382 raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1383 if self.config.pad_token_id is None:
1384 sequence_lengths = -1
1385 else:
1386 if input_ids is not None:
1387 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1388 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1389 sequence_lengths = sequence_lengths % input_ids.shape[-1]
1390 sequence_lengths = sequence_lengths.to(logits.device)
1391 else:
1392 sequence_lengths = -1
1393
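# Example: with pad_token_id = 0 and input_ids = [5, 6, 0, 0], eq(...).argmax(-1) = 2, so
# sequence_lengths = 1, the index of the last non-padding token; with no padding at all the argmax
# is 0 and the modulo maps -1 back to the final position in the row.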
1394 pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1395
1396 loss = None
1397 if labels is not None:
1398 labels = labels.to(logits.device)
1399 if self.config.problem_type is None:
1400 if self.num_labels == 1:
1401 self.config.problem_type = "regression"
1402 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1403 self.config.problem_type = "single_label_classification"
1404 else:
1405 self.config.problem_type = "multi_label_classification"
1406
1407 if self.config.problem_type == "regression":
1408 loss_fct = MSELoss()
1409 if self.num_labels == 1:
1410 loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1411 else:
1412 loss = loss_fct(pooled_logits, labels)
1413 elif self.config.problem_type == "single_label_classification":
1414 loss_fct = CrossEntropyLoss()
1415 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1416 elif self.config.problem_type == "multi_label_classification":
1417 loss_fct = BCEWithLogitsLoss()
1418 loss = loss_fct(pooled_logits, labels)
1419 if not return_dict:
1420 output = (pooled_logits,) + transformer_outputs[1:]
1421 return ((loss,) + output) if loss is not None else output
1422
1423 return SequenceClassifierOutputWithPast(
1424 loss=loss,
1425 logits=pooled_logits,
1426 past_key_values=transformer_outputs.past_key_values,
1427 hidden_states=transformer_outputs.hidden_states,
1428 attentions=transformer_outputs.attentions,
1429 )
1430