modeling_qwen2.py
76.6 KB · 1739 lines · python Raw
1 # coding=utf-8
2 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3 #
4 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 # and OPT implementations in this library. It has been modified from its
6 # original forms to accommodate minor architectural differences compared
7 # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 #
9 # Licensed under the Apache License, Version 2.0 (the "License");
10 # you may not use this file except in compliance with the License.
11 # You may obtain a copy of the License at
12 #
13 # http://www.apache.org/licenses/LICENSE-2.0
14 #
15 # Unless required by applicable law or agreed to in writing, software
16 # distributed under the License is distributed on an "AS IS" BASIS,
17 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 # See the License for the specific language governing permissions and
19 # limitations under the License.
20 """ PyTorch Qwen2 model."""
21 import inspect
22 import math
23 import copy
24 import warnings
25 from functools import partial
26 from typing import List, Optional, Tuple, Union
27
28 import torch
29 import torch.nn.functional as F
30 import torch.utils.checkpoint
31 from torch import nn
32 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
33
34 from transformers.activations import ACT2FN
35 from transformers.cache_utils import Cache, DynamicCache
36 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
37 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
38 from transformers.modeling_utils import PreTrainedModel
39 from transformers.utils import (
40 add_start_docstrings,
41 add_start_docstrings_to_model_forward,
42 is_flash_attn_2_available,
43 is_flash_attn_greater_or_equal_2_10,
44 logging,
45 replace_return_docstrings,
46 )
47 from .configuration_qwen2 import Qwen2Config
48
49 if is_flash_attn_2_available():
50 from flash_attn import flash_attn_func, flash_attn_varlen_func
51 from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
52
53 _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
54
55
# Module-level logger obtained through the `logging` helper imported above.
logger = logging.get_logger(__name__)

# Magi Attention Supported
# The backend is optional: probe for it once at import time and record the outcome.
try:
    from magi_attention.functional.flex_flash_attn import flex_flash_attn_func
except ImportError:
    # Keep the name bound so that later references do not raise `NameError`.
    flex_flash_attn_func = None
    _MAGI_AVAILABLE = False
else:
    _MAGI_AVAILABLE = True


# NOTE(review): the `*_FOR_DOC` names are presumably consumed by the docstring decorators
# imported at the top of the file -- their use is not visible in this part of the module.
_CONFIG_FOR_DOC = "Qwen2Config"
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"

# See all Qwen2 models at https://huggingface.co/models?filter=qwen2
QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = ["Qwen/Qwen2-7B-beta"]
74
75 from .mask_sdpa_utils import (
76 find_prefix_seq_length_by_pe,
77 update_causal_mask_with_pad_non_visible_2d,
78 update_causal_mask_for_one_gen_window_2d,
79 create_block_diff_mask_by_pe_4d,
80 find_pred_pos_from_input_ids
81 )
82
83 from .mask_magi_utils import build_magi_ranges
84
85 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
86 def _get_unpad_data(attention_mask):
87 seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
88 indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
89 max_seqlen_in_batch = seqlens_in_batch.max().item()
90 cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
91 return (
92 indices,
93 cu_seqlens,
94 max_seqlen_in_batch,
95 )
96
97
98 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
class Qwen2RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned per-channel scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: length of the learned scale vector (initialised to ones).
            eps: constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise `hidden_states` in float32, cast back to the input dtype, then apply the scale."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        # Mean of squares along the last axis; `keepdim` lets it broadcast against `upcast`.
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)
114
115
116 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    """Rotary position embedding table with a cos/sin cache that grows on demand.

    Args:
        dim: size of the rotary dimension; `dim / 2` inverse frequencies are created
            (for even `dim`).
        max_position_embeddings: number of positions pre-computed at construction time.
        base: base of the geometric progression used for the inverse frequencies.
        device: device on which the inverse frequencies are created.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        # Non-persistent: the buffer is recomputed at construction and is not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the `cos_cached` / `sin_cached` buffers for positions `0 .. seq_len - 1`."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the cached `(cos, sin)` tables truncated to `seq_len` rows, cast to `x.dtype`.

        Args:
            x: tensor used for its dtype and device; per the comment below it is laid out as
                `[bs, num_attention_heads, seq_len, head_size]`.
            seq_len: number of positions required. When omitted, it is taken from `x.shape[-2]`.

        Returns:
            `tuple(torch.Tensor, torch.Tensor)`: the cos and sin tables, one row per position.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The signature advertises `None` as the default, but comparing `None` with an int raises
            # `TypeError`; fall back to the sequence axis of `x` instead.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
151
152
153 # Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front."""
    split_at = x.shape[-1] // 2
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat((-trailing, leading), dim=-1)
159
160
161 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate the query and key tensors with the rotary position embedding tables.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine table; it is indexed along its first axis by `position_ids`.
        sin (`torch.Tensor`): The sine table; it is indexed along its first axis by `position_ids`.
        position_ids (`torch.Tensor`):
            Position index of every token in `q` / `k`. Offsetted ids can be passed here, for instance
            when decoding with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos[position_ids]` and `sin[position_ids]` so that they broadcast against
            `q` and `k`. With tables gathered to `[batch_size, seq_len, head_dim]`, use 1 when `q` / `k` are
            `[batch_size, heads, seq_len, head_dim]` and 2 when they are
            `[batch_size, seq_len, heads, head_dim]`.

    Returns:
        `tuple(torch.Tensor)`: the rotated query tensor followed by the rotated key tensor.
    """
    # Gather one table row per token, then add the broadcast axis.
    cos_at_pos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_at_pos = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(states):
        return (states * cos_at_pos) + (rotate_half(states) * sin_at_pos)

    return _rotate(q), _rotate(k)
188
189
190 # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
class Qwen2MLP(nn.Module):
    """Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        """Create the three bias-free projections and look up the activation named by `config.hidden_act`."""
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        model_width, inner_width = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_width, inner_width, bias=False)
        self.up_proj = nn.Linear(model_width, inner_width, bias=False)
        self.down_proj = nn.Linear(inner_width, model_width, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to `x` and map the result back to the model width."""
        gate = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gate * lifted)
204
205
206 # Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis, matching
    torch.repeat_interleave(x, dim=1, repeats=n_rep). The input of shape (batch, num_key_value_heads, seqlen,
    head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim); with `n_rep == 1` the input
    tensor itself is returned.
    """
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)
217
218
class Qwen2Attention(nn.Module):
    """
    Eager (matmul + softmax) multi-head attention with rotary position embeddings, an optional key/value
    cache, and key/value heads that are repeated to match the number of query heads.
    """

    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        """
        Args:
            config: model configuration; the sizes, `rope_theta` and `attention_dropout` are read from it.
            layer_idx: index of this layer, used to address the key/value cache.

        Raises:
            ValueError: if `hidden_size` is not a multiple of the number of attention heads.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        query_width = self.head_dim * self.num_heads
        key_value_width = self.num_key_value_heads * self.head_dim
        if query_width != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Query/key/value projections carry a bias, the output projection does not.
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, key_value_width, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, key_value_width, bias=True)
        self.o_proj = nn.Linear(query_width, self.hidden_size, bias=False)

        self.rotary_emb = Qwen2RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run attention over `hidden_states`.

        Args:
            hidden_states: input of shape `(bsz, q_len, hidden_size)`.
            attention_mask: optional additive mask; it must have shape `(bsz, 1, q_len, kv_seq_len)`.
            position_ids: positions used to index the rotary tables.
            past_key_value: optional cache; it is updated with the new keys/values of this layer.
            output_attentions: whether to return the attention probabilities.
            use_cache: accepted for interface compatibility; not read by this method.
            **kwargs: only inspected for the deprecated `padding_mask` key.

        Returns:
            `(attn_output, attn_weights or None, past_key_value)`.

        Raises:
            ValueError: when a cache is given without `layer_idx`, or when the scores, the mask or the
                attention output do not have the expected shape.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        q_flat = self.q_proj(hidden_states)
        k_flat = self.k_proj(hidden_states)
        v_flat = self.v_proj(hidden_states)

        def _to_heads(flat, head_count):
            # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
            return flat.view(bsz, q_len, head_count, self.head_dim).transpose(1, 2)

        query_states = _to_heads(q_flat, self.num_heads)
        key_states = _to_heads(k_flat, self.num_key_value_heads)
        value_states = _to_heads(v_flat, self.num_key_value_heads)

        has_cache = past_key_value is not None

        # Total key length = new tokens + whatever the cache already holds for this layer.
        kv_seq_len = key_states.shape[-2]
        if has_cache:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len = kv_seq_len + past_key_value.get_seq_length(self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if has_cache:
            # Specific to RoPE models
            rope_tables = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, rope_tables)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        scale = math.sqrt(self.head_dim)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / scale

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        merged = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(merged)

        returned_weights = attn_weights if output_attentions else None
        return attn_output, returned_weights, past_key_value
342
343
class Qwen2FlashAttention2(Qwen2Attention):
    """
    Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
    as the weights of the module stays untouched. The only required change would be on the forward pass
    where it needs to correctly call the public API of flash attention and deal with padding tokens
    in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
    config.max_window_layers layers.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Run attention over `hidden_states` through the flash attention kernels.

        Args:
            hidden_states: input of shape `(bsz, q_len, hidden_size)`.
            attention_mask: optional padding mask forwarded to `_flash_attention_forward`; `None` means
                that no unpadding is performed.
            position_ids: positions used to index the rotary tables. It is indexed directly
                (`position_ids[:, -1]`), so it must not be `None` here.
            past_key_value: optional cache; it is updated with the new keys/values of this layer.
            output_attentions: attention probabilities are never computed by this implementation, so
                the second element of the result is always `None`; a warning is logged when this is set.
            use_cache: accepted for interface compatibility; not read by this method.
            **kwargs: a deprecated `padding_mask` entry, when present, replaces `attention_mask`.

        Returns:
            `(attn_output, None, past_key_value)`.

        Raises:
            ValueError: when a cache is given without `layer_idx`, or when the sliced cache does not
                have `sliding_window - 1` positions.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        use_sliding_windows = (
            _flash_supports_window_size
            and getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
            and self.config.use_sliding_window
        )

        if not _flash_supports_window_size:
            logger.warning_once(
                "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
                " make sure to upgrade flash-attn library."
            )

        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
            if (
                getattr(self.config, "sliding_window", None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents
            ):
                # Negative start index: keeps the last `sliding_window - 1` cached positions.
                slicing_tokens = 1 - self.config.sliding_window

                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                # NOTE(review): the sliced tensors are only used for the shape check below; they are
                # not written back to the cache in this method -- confirm that this is intended.
                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    # Drop the mask columns that fell out of the window and add one for the new token.
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the layout expected by Flash Attention: (bsz, seq_len, heads, head_dim)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # The flash attention kernels never hand back the attention probabilities. The variable used
        # to be assigned only when `output_attentions` was False, so `output_attentions=True` raised
        # `UnboundLocalError` at the return statement; always return `None` instead.
        if output_attentions:
            logger.warning_once(
                "Qwen2FlashAttention2 does not support `output_attentions=True`; attention weights are returned as `None`."
            )
        attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                Number of query positions; used to re-pad the output and to pick the unpadding strategy.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
            use_sliding_windows (`bool`, *optional*):
                Whether to activate sliding window attention.

        Returns:
            `torch.Tensor`: the attention output, re-padded to `query_length` when a mask was given.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Decide whether to use SWA or not by layer index.
        if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
            use_sliding_windows = False

        # `window_size` is only passed when sliding windows are active, so flash-attn builds whose
        # functions do not accept that keyword keep working.
        window_kwargs = {}
        if use_sliding_windows:
            window_kwargs["window_size"] = (self.config.sliding_window, self.config.sliding_window)

        # No mask: every sequence is assumed to be free of padding.
        if attention_mask is None:
            return flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
                **window_kwargs,
            )

        # Contains at least one padding token in the sequence
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
            query_states, key_states, value_states, attention_mask, query_length
        )

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **window_kwargs,
        )

        # Scatter the packed rows back into a padded `(batch_size, query_length, ...)` tensor.
        return pad_input(attn_output_unpad, indices_q, batch_size, query_length)

    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """
        Remove the padded positions from the query/key/value tensors.

        Args:
            query_layer: query states; flattened over the first two axes when `query_length == kv_seq_len`.
            key_layer: key states of shape `(batch_size, kv_seq_len, num_heads, head_dim)`.
            value_layer: value states, reshaped with the same dimensions as `key_layer`.
            attention_mask: padding mask; only its last `kv_seq_len` columns are used.
            query_length: number of query positions.

        Returns:
            `(query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k))`.
        """
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

        # On the first iteration we need to properly re-create the padding mask
        # by slicing it on the proper place
        if kv_seq_len != attention_mask.shape[-1]:
            attention_mask_num_tokens = attention_mask.shape[-1]
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)

        if query_length == kv_seq_len:
            # Queries and keys cover the same positions: reuse the key indices.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # One query token per sequence: each sequence contributes exactly one row.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
644
645
646 # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2
class Qwen2SdpaAttention(Qwen2Attention):
    """
    Variant of `Qwen2Attention` whose forward pass goes through
    `torch.nn.functional.scaled_dot_product_attention`. The parameters are inherited unchanged; only the
    forward pass differs.
    """

    # Adapted from Qwen2Attention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run attention over `hidden_states` with the fused SDPA kernel.

        Args:
            hidden_states: input of shape `(bsz, q_len, hidden_size)`.
            attention_mask: optional mask handed to SDPA as `attn_mask`; it must have shape
                `(bsz, 1, q_len, kv_seq_len)`.
            position_ids: positions used to index the rotary tables.
            past_key_value: optional cache; it is updated with the new keys/values of this layer.
            output_attentions: when set, the call is delegated to the parent (eager) implementation.
            use_cache: only forwarded to the parent implementation on the fallback path.

        Returns:
            `(attn_output, None, past_key_value)` on the SDPA path, or whatever the parent returns on
            the fallback path.

        Raises:
            ValueError: when `attention_mask` does not have the expected shape.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        q_flat = self.q_proj(hidden_states)
        k_flat = self.k_proj(hidden_states)
        v_flat = self.v_proj(hidden_states)

        def _to_heads(flat, head_count):
            # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
            return flat.view(bsz, q_len, head_count, self.head_dim).transpose(1, 2)

        query_states = _to_heads(q_flat, self.num_heads)
        key_states = _to_heads(k_flat, self.num_key_value_heads)
        value_states = _to_heads(v_flat, self.num_key_value_heads)

        has_cache = past_key_value is not None

        # Total key length = new tokens + whatever the cache already holds for this layer.
        kv_seq_len = key_states.shape[-2]
        if has_cache:
            kv_seq_len = kv_seq_len + past_key_value.get_seq_length(self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if has_cache:
            # Specific to RoPE models
            rope_tables = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, rope_tables)

        # Repeat the key/value heads so that they line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        has_mask = attention_mask is not None
        if has_mask and attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if has_mask and query_states.device.type == "cuda":
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # Dropout is only active in training mode.
        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=dropout_p,
            is_causal=False,
        )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        merged = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(merged)

        return attn_output, None, past_key_value
731
732
class Qwen2SdpaAttentionGqa(Qwen2Attention):
    """
    Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.

    Key/value heads are not expanded with `repeat_kv` here: the tensors are handed to SDPA with
    `num_key_value_heads` heads and `enable_gqa=True`.
    """

    # Adapted from Qwen2Attention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run self-attention through SDPA with `enable_gqa=True`.

        Args:
            hidden_states: input whose `.size()` unpacks as `(bsz, q_len, _)`.
            attention_mask: optional mask; when given it must have size `(bsz, 1, q_len, kv_seq_len)`.
            position_ids: forwarded to `apply_rotary_pos_emb`.
            past_key_value: optional cache; queried with `get_seq_length(self.layer_idx)` and extended with
                `update(...)`.
            output_attentions: when True, a warning is logged and the call is delegated to
                `Qwen2Attention.forward`.
            use_cache: only forwarded to the fallback path; unused on the SDPA path.

        Returns:
            `(attn_output, None, past_key_value)`; attention weights are never returned on the SDPA path.

        Raises:
            ValueError: if `attention_mask` is given with an unexpected size.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        # Imported lazily so the eager fallback above keeps working on torch builds without
        # `torch.nn.attention`; the SDPA path already needs a torch that understands `enable_gqa`.
        from torch.nn.attention import SDPBackend, sdpa_kernel

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # No `repeat_kv` here: K/V keep `num_key_value_heads` heads and SDPA is asked to broadcast them
        # over the query heads via `enable_gqa=True`.

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        # NOTE(review): the memory-efficient backend is disabled below, so this is probably no longer
        # required; kept because it is harmless.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # Replaces the deprecated `torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True,
        # enable_mem_efficient=False)`: memory-efficient attention stays off, every other backend that
        # the old call left enabled stays on.
        allowed_backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.CUDNN_ATTENTION]
        with sdpa_kernel(allowed_backends):
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                enable_gqa=True,
                dropout_p=self.attention_dropout if self.training else 0.0,
                is_causal=False,
            )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
822
823
class Qwen2MagiAttention(Qwen2Attention):
    """
    Qwen2 attention using MagiAttention for efficient training with MTP packing support.

    MagiAttention uses range-based sparse attention patterns:
    - q_ranges/k_ranges define which query/key ranges attend to each other
    - attn_type_map specifies causal(1) or full(0) attention for each range pair
    """

    def __init__(self, *args, **kwargs):
        """
        Build the underlying `Qwen2Attention` and precompute the softmax scale.

        Raises:
            ImportError: if `_MAGI_AVAILABLE` is false.
        """
        super().__init__(*args, **kwargs)
        if not _MAGI_AVAILABLE:
            raise ImportError(
                "magi_attention is not installed. Install with: pip install magi-attention"
            )
        self.softmax_scale = self.head_dim ** -0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[dict] = None,  # magi_plan dict
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run attention through `flex_flash_attn_func` on a single packed sequence.

        Args:
            hidden_states: input whose `.size()` unpacks as `(bsz, q_len, _)`; `bsz` must be 1.
            attention_mask: magi plan dict; the keys `"q_ranges"`, `"k_ranges"` and `"attn_type_map"` are read.
                Required despite the `None` default kept for signature compatibility.
            position_ids: forwarded to `apply_rotary_pos_emb`.
            past_key_value: optional cache, updated with this step's key/value states.
            output_attentions: must be False.
            use_cache: unused by this method.

        Returns:
            `(attn_output, None, past_key_value)` with `attn_output` viewed as `(1, q_len, hidden_size)`.

        Raises:
            NotImplementedError: if `output_attentions` is True.
            ValueError: if `attention_mask` is None, or a cache is passed while `self.layer_idx` is None.
        """
        if output_attentions:
            raise NotImplementedError('MagiAttention does not support output_attentions=True')

        # Fail early with a clear message instead of an opaque "'NoneType' object is not subscriptable"
        # at the kernel call below.
        if attention_mask is None:
            raise ValueError(
                "Qwen2MagiAttention requires `attention_mask` to be a magi plan dict with "
                "`q_ranges`, `k_ranges` and `attn_type_map`, but got None"
            )

        bsz, q_len, _ = hidden_states.size()
        assert bsz == 1, "MagiAttention only supports batch_size=1 (use packing instead)"

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Magi expects [T, H, D] format (no batch dimension)
        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)

        kv_seq_len = q_len
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)

        # The rotary module is given the values in the same [1, Hkv, L, D] layout used for RoPE below.
        cos, sin = self.rotary_emb(value_states.unsqueeze(0).transpose(1, 2), seq_len=kv_seq_len)

        # Apply RoPE: need [B, H, L, D] format for apply_rotary_pos_emb
        q_for_rope = query_states.unsqueeze(0).transpose(1, 2)  # [1, H, L, D]
        k_for_rope = key_states.unsqueeze(0).transpose(1, 2)  # [1, Hkv, L, D]
        q_for_rope, k_for_rope = apply_rotary_pos_emb(q_for_rope, k_for_rope, cos, sin, position_ids)

        # Back to [T, H, D]
        query_states = q_for_rope.squeeze(0).transpose(0, 1).contiguous()  # [L, H, D]
        key_states = k_for_rope.squeeze(0).transpose(0, 1).contiguous()  # [L, Hkv, D]

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            # Note: Magi doesn't support KV cache in training, this is for potential future use
            # The cache is fed [1, Hkv, L, D] tensors, so round-trip through that layout.
            key_states_4d = key_states.unsqueeze(0).transpose(1, 2)
            value_states_4d = value_states.unsqueeze(0).transpose(1, 2)
            key_states_4d, value_states_4d = past_key_value.update(
                key_states_4d, value_states_4d, self.layer_idx, cache_kwargs
            )
            key_states = key_states_4d.squeeze(0).transpose(0, 1).contiguous()
            value_states = value_states_4d.squeeze(0).transpose(0, 1).contiguous()

        # Run Magi Attention
        # attention_mask is a magi_plan dict with q_ranges, k_ranges, attn_type_map, etc.
        attn_output, _ = flex_flash_attn_func(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            q_ranges=attention_mask["q_ranges"],
            k_ranges=attention_mask["k_ranges"],
            attn_type_map=attention_mask["attn_type_map"],
            softmax_scale=self.softmax_scale,
            softcap=0.0,
            deterministic=False,
        )  # [T, H, D]

        # Reshape to [B, L, H*D]
        attn_output = attn_output.view(1, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
917
918
# Registry mapping `config._attn_implementation` to the attention class that
# `Qwen2DecoderLayer` instantiates for its `self_attn` module.
# NOTE(review): `Qwen2SdpaAttentionGqa` is defined in this file but has no entry here, so it is
# never selected through this table — confirm whether that is intentional.
QWEN2_ATTENTION_CLASSES = {
    "eager": Qwen2Attention,
    "flash_attention_2": Qwen2FlashAttention2,
    "sdpa": Qwen2SdpaAttention,
    "magi": Qwen2MagiAttention,
}
925
926
class Qwen2DecoderLayer(nn.Module):
    """One Qwen2 decoder block: RMS-normed self-attention and RMS-normed MLP, each added back to its input."""

    def __init__(self, config: Qwen2Config, layer_idx: int):
        """
        Build the attention module, MLP and the two RMS norms for layer `layer_idx`.

        The attention class is looked up in `QWEN2_ATTENTION_CLASSES` by `config._attn_implementation`.
        If the requested backend is not installed, `config._attn_implementation` is overwritten in
        place with the fallback, so code that reads the config afterwards sees the backend actually used.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        if config._attn_implementation == 'magi' and not _MAGI_AVAILABLE:
            # Fall back straight to sdpa. `Qwen2Model.forward` only prepares masks for "magi" and
            # "sdpa" and raises NotImplementedError for every other implementation, so a
            # flash_attention_2 fallback would build a model that can never run a forward pass.
            logger.warning_once(
                'magi_attention not available, falling back to sdpa'
            )
            config._attn_implementation = 'sdpa'
        if config._attn_implementation == 'flash_attention_2' and not is_flash_attn_2_available():
            logger.warning_once(
                'flash_attn is not available, falling back to sdpa'
            )
            config._attn_implementation = 'sdpa'

        self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[Union[torch.Tensor, dict]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.Tensor` or `dict`, *optional*): passed unchanged to the attention module.
                In this file `Qwen2Model` supplies either the block mask it prepared for the sdpa path or
                the plan returned by `build_magi_ranges` for the magi path.
            position_ids (`torch.LongTensor`, *optional*): passed unchanged to the attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states

        Returns:
            A tuple starting with the layer output, followed by the attention weights when
            `output_attentions` is set and by the present key/value state when `use_cache` is set.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. "
                "Please make sure use `attention_mask` instead.`"
            )

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
1014
1015
# Shared class-level documentation intro; passed to `add_start_docstrings` on
# `Qwen2PreTrainedModel` and `Qwen2Model`.
QWEN2_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Qwen2Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
1031
1032
@add_start_docstrings(
    "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
    QWEN2_START_DOCSTRING,
)
class Qwen2PreTrainedModel(PreTrainedModel):
    # Class-level settings read by the `PreTrainedModel` base class.
    config_class = Qwen2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    @classmethod
    def _autoset_attn_implementation(cls, config, *args, **kwargs):
        """Return `config` untouched when it already requests "magi"; otherwise defer to the base class."""
        requested = getattr(config, '_attn_implementation', None)
        if requested != 'magi':
            return super()._autoset_attn_implementation(config, *args, **kwargs)
        return config

    def _check_and_adjust_attn_implementation(self, attn_implementation, is_init_check=False):
        """Accept "magi" as-is; hand every other value to the base-class check."""
        if attn_implementation != "magi":
            return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)
        return "magi"

    def _init_weights(self, module):
        """Normal-initialise `nn.Linear` / `nn.Embedding` weights, zeroing biases and the padding row."""
        init_std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()
1068
1069
# Argument documentation attached to the `forward` methods in this file through
# `add_start_docstrings_to_model_forward`.
# NOTE(review): the `forward` signatures here also take `visual_features` and `image_token_index`,
# which this text does not describe; the two "head is masked" bullets at the end of the
# `attention_mask` entry look like leftovers from a `head_mask` entry — confirm before editing.
QWEN2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
1138
1139
@add_start_docstrings(
    "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
    QWEN2_START_DOCSTRING,
)
class Qwen2Model(Qwen2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]

    Args:
        config: Qwen2Config
    """

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Read only after the layers exist: Qwen2DecoderLayer may overwrite
        # `config._attn_implementation` with a fallback backend.
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

        # Block-mask settings, read from the config with these fallbacks when absent.
        self.block_size = getattr(config, 'block_size', 6)
        self.causal_attn = getattr(config, 'causal_attn', False)
        self.text_mask_token_id = getattr(config, 'text_mask_token_id', 151676)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    def image_processing(self, input_ids, visual_features, image_token_index):
        """
        Embed `input_ids`, overwriting the positions equal to `image_token_index` with `visual_features`.

        Args:
            input_ids: token ids; their embedding unpacks as `(B, N, C)`.
            visual_features: features reshaped to `(-1, C)` and written row-wise into the selected
                positions, or `None` to return the plain token embeddings.
            image_token_index: token id marking the positions to overwrite.

        Returns:
            Embeddings of shape `(B, N, C)`.

        Raises:
            AssertionError: if `visual_features` is given but no position matches `image_token_index`.
        """
        input_embeds = self.get_input_embeddings()(input_ids)
        if visual_features is None:
            return input_embeds

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        selected = input_ids.reshape(B * N) == image_token_index
        assert selected.sum() != 0
        # Cast as well as move: the indexed assignment requires matching dtypes, and the visual
        # features may come from a module running in a different precision.
        input_embeds[selected] = visual_features.reshape(-1, C).to(
            device=input_embeds.device, dtype=input_embeds.dtype
        )
        return input_embeds.reshape(B, N, C)

    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        visual_features: Optional[torch.FloatTensor] = None,
        image_token_index: int = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Documentation for this method is injected by the decorator above. In short: resolve the
        # flags, embed the inputs, build the backend-specific attention mask, run the layers.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        past_key_values_length = 0

        if use_cache:
            # Legacy (tuple / None) caches are wrapped in a DynamicCache for the layers and
            # converted back before returning.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                if past_key_values is None:
                    past_key_values = DynamicCache()
                else:
                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_seq_length()

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.image_processing(input_ids, visual_features, image_token_index)

        if attention_mask is not None and self._attn_implementation == "magi" and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for MagiAttention version of Qwen2. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )

        def _is_ar_step():
            # Plain autoregressive step: a single new token, or an input whose first row does not end
            # with the text mask token. Evaluated lazily because `.item()` reads a value from the tensor.
            return seq_length == 1 or (
                input_ids is not None and input_ids[0][-1].item() != self.text_mask_token_id
            )

        def _prepare_block_mask_for_inference(attention_mask):
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )
            # switch to ar mode
            if _is_ar_step():
                return attention_mask

            if attention_mask is None or len(attention_mask.shape) != 4:
                return attention_mask

            # For SDLM, the generation window should set to bidirectional attention
            if use_cache:
                update_mask_func = partial(
                    update_causal_mask_for_one_gen_window_2d,
                    block_size=self.block_size,
                    use_cache=use_cache,
                    causal_attn=self.causal_attn,
                )
            else:
                update_mask_func = partial(
                    update_causal_mask_with_pad_non_visible_2d,
                    block_size=self.block_size,
                    text_mask_token_id=self.text_mask_token_id,
                    causal_attn=self.causal_attn,
                )

            # NOTE(review): this indexes `input_ids`, so a multi-token call made with `inputs_embeds`
            # only fails here — confirm whether that input combination should be supported.
            new_attention_mask = [
                update_mask_func(input_ids[b], attention_mask[b][0]).unsqueeze(0)
                for b in range(attention_mask.shape[0])
            ]
            return torch.stack(new_attention_mask, dim=0)

        def _prepare_block_mask_for_training():
            # The prefix lengths are only needed for the training mask, so they are computed here
            # rather than on every forward pass.
            x0_len = find_prefix_seq_length_by_pe(position_ids).to(device=device)
            block_mask, _ = create_block_diff_mask_by_pe_4d(
                block_size=self.block_size,
                x0_len_list=x0_len,
                position_ids=position_ids,
                causal_attn=self.causal_attn,
            )
            return block_mask

        if self._attn_implementation == "magi":
            # For magi, `attention_mask` becomes the object returned by `build_magi_ranges`;
            # Qwen2MagiAttention reads its "q_ranges", "k_ranges" and "attn_type_map" entries.
            attention_mask = build_magi_ranges(
                kv_len=seq_length + past_key_values_length,
                q_len=seq_length,
                block_size=self.block_size,
                ar_decode=_is_ar_step(),
                device=device
            )

        elif self._attn_implementation == "sdpa":
            attention_mask = _prepare_block_mask_for_training() if self.training else _prepare_block_mask_for_inference(attention_mask)

        else:
            raise NotImplementedError(f'{self._attn_implementation=}')

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache follows the attention weights in the layer's output tuple when those are returned.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
1395
1396
class Qwen2ForCausalLM(Qwen2PreTrainedModel):
    """Qwen2 decoder (`Qwen2Model`) followed by a bias-free linear head projecting to the vocabulary."""

    # Parameter names declared as tied for the `PreTrainedModel` base class; the tying logic
    # itself is not part of this file.
    _tied_weights_keys = ["lm_head.weight"]
1399
def __init__(self, config):
    """Create the `Qwen2Model` backbone and the vocabulary projection head, then run `post_init`."""
    super().__init__(config)

    # Backbone first, then the head that maps hidden states to vocabulary logits.
    self.model = Qwen2Model(config)
    vocab_size = config.vocab_size
    self.vocab_size = vocab_size
    self.lm_head = nn.Linear(config.hidden_size, vocab_size, bias=False)

    # Token id of the text mask token; 151676 is used when the config does not define it.
    self.text_mask_token_id = getattr(config, 'text_mask_token_id', 151676)

    # Initialize weights and apply final processing
    self.post_init()
1410
1411
def get_input_embeddings(self):
    """Return the backbone's token embedding module (`self.model.embed_tokens`)."""
    backbone = self.model
    return backbone.embed_tokens
1414
def set_input_embeddings(self, value):
    """Replace the backbone's token embedding module with `value`."""
    backbone = self.model
    backbone.embed_tokens = value
1417
def get_output_embeddings(self):
    """Return the language-model head (`self.lm_head`)."""
    head = self.lm_head
    return head
1420
def set_output_embeddings(self, new_embeddings):
    """Replace the language-model head with `new_embeddings`."""
    self.lm_head = new_embeddings
    return None
1423
def set_decoder(self, decoder):
    """Replace the backbone (`self.model`) with `decoder`."""
    self.model = decoder
    return None
1426
def get_decoder(self):
    """Return the backbone (`self.model`)."""
    backbone = self.model
    return backbone
1429
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    visual_features: Optional[torch.FloatTensor] = None,
    image_token_index: int = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Qwen2ForCausalLM

    >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        visual_features=visual_features,
        image_token_index=image_token_index,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)
    logits = logits.float()

    loss = None
    # Bound up front so the training-mode return below never hits an unbound name when no labels
    # (or no `input_ids`) are supplied; it stays None in those cases.
    pos_loss_list = None
    if labels is not None:

        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)

        shift_labels = shift_labels.view(-1)
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

        # Per-position loss breakdown: entry `ix` is the mean cross-entropy over the shifted positions
        # that `find_pred_pos_from_input_ids` labels with value `ix`. It needs the token ids, so it is
        # skipped when only `inputs_embeds` were given.
        if input_ids is not None:
            pos_masks = find_pred_pos_from_input_ids(input_ids, text_mask_token_id=self.text_mask_token_id)
            shift_pos_masks = pos_masks[:, :-1].reshape(-1)
            max_n_future_tokens = min(4, self.model.block_size)
            pos_loss_list = torch.zeros(max_n_future_tokens, device=shift_logits.device)

            for ix in range(max_n_future_tokens):
                at_offset = shift_pos_masks == ix
                pos_loss_list[ix] = F.cross_entropy(
                    shift_logits[at_offset],
                    shift_labels[at_offset],
                    reduction='mean'
                )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    lm_output = CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

    # In training mode the per-position losses are returned alongside the usual output object.
    if self.training:
        return lm_output, pos_loss_list

    return lm_output
1550
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """
    Assemble the keyword arguments for one generation step.

    Tokens already covered by `past_key_values` are dropped from `input_ids`, the attention mask is
    cropped when a bounded cache would overflow, and `position_ids` are derived from the attention
    mask when the caller did not pass them.

    Returns:
        A dict with `inputs_embeds` (only when given and there is no cache yet) or `input_ids`,
        followed by `position_ids`, `past_key_values`, `use_cache` and `attention_mask`.
    """
    # Omit tokens covered by past_key_values
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # Legacy tuple cache: the length is read from dimension 2 of the first layer's keys.
            past_length = past_key_values[0][0].shape[2]
            cache_length = past_length
            max_cache_length = None

        # Keep only the unprocessed tokens.
        n_given = input_ids.shape[1]
        if attention_mask is not None and attention_mask.shape[1] > n_given:
            # 1 - The mask is longer than input_ids: part of the input lives only in the cache
            #     (e.g. when input_embeds were passed), so keep the tail not yet processed.
            n_pending = attention_mask.shape[1] - past_length
            input_ids = input_ids[:, -n_pending:]
        elif past_length < n_given:
            # 2 - input_ids holds every token so far: discard the processed prefix.
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]) input_ids is assumed to hold only new tokens.

        # If we are about to go beyond the maximum cache length, crop the input attention mask.
        would_overflow = (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        )
        if would_overflow:
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids")
    if position_ids is None and attention_mask is not None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1]:]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    first_step_with_embeds = inputs_embeds is not None and past_key_values is None
    model_inputs = {"inputs_embeds": inputs_embeds} if first_step_with_embeds else {"input_ids": input_ids}

    model_inputs["position_ids"] = position_ids
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    return model_inputs
1607
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
    """Re-index every cached tensor along dimension 0 with `beam_idx`.

    Args:
        past_key_values: Iterable of per-layer iterables of tensors.
        beam_idx: Index tensor passed to `index_select`; it is moved to each
            cached tensor's device before use.

    Returns:
        A tuple with one inner tuple per layer, holding the re-indexed
        tensors in their original order.
    """
    return tuple(
        tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_states)
        for layer_states in past_key_values
    )
1616
1617
@add_start_docstrings(
    """
    The Qwen2 Model transformer with a sequence classification head on top (linear layer).

    [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    QWEN2_START_DOCSTRING,
)
class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
    def __init__(self, config):
        """Build the `Qwen2Model` backbone and a bias-free linear scoring head."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Weight initialisation and any final processing.
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's `embed_tokens` module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's `embed_tokens` module with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Score every position; one position per row is selected below.
        logits = self.score(transformer_outputs[0])

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is None or input_ids is None:
            # No way to locate padding: pool the final position of each row.
            sequence_lengths = -1
        else:
            # Index just before the first pad token. The modulo (instead of
            # reverse indexing) maps the "no pad found" result of -1 onto the
            # last position, for ONNX compatibility.
            first_pad = torch.eq(input_ids, pad_token_id).int().argmax(-1)
            sequence_lengths = ((first_pad - 1) % input_ids.shape[-1]).to(logits.device)

        row_index = torch.arange(batch_size, device=logits.device)
        pooled_logits = logits[row_index, sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            # Infer the problem type once and remember it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    inferred_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred_type = "single_label_classification"
                else:
                    inferred_type = "multi_label_classification"
                self.config.problem_type = inferred_type

            problem_type = self.config.problem_type
            if problem_type == "regression":
                mse = MSELoss()
                if self.num_labels == 1:
                    loss = mse(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = mse(pooled_logits, labels)
            elif problem_type == "single_label_classification":
                cross_entropy = CrossEntropyLoss()
                loss = cross_entropy(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                bce = BCEWithLogitsLoss()
                loss = bce(pooled_logits, labels)

        if return_dict:
            return SequenceClassifierOutputWithPast(
                loss=loss,
                logits=pooled_logits,
                past_key_values=transformer_outputs.past_key_values,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )

        # Tuple output: pooled logits followed by the backbone's extra outputs,
        # with the loss prepended when one was computed.
        output = (pooled_logits,) + transformer_outputs[1:]
        return output if loss is None else ((loss,) + output)
1739