modeling_phi3_v.py
1 # coding=utf-8
2 # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 """ PyTorch Phi-3-V model."""
17
18 import inspect
19 import math
20 import warnings
21 from typing import List, Optional, Tuple, Union
22
23 import torch
24 import torch.nn.functional as F
25 import torch.utils.checkpoint
26 from torch import nn
27 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
29 from transformers.activations import ACT2FN
30 from transformers.cache_utils import Cache, DynamicCache
31 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
32 from transformers.modeling_outputs import (
33 BaseModelOutputWithPast,
34 CausalLMOutputWithPast,
35 SequenceClassifierOutputWithPast,
36 TokenClassifierOutput,
37 )
38 from transformers.modeling_utils import PreTrainedModel
39 from transformers.utils import (
40 add_code_sample_docstrings,
41 add_start_docstrings,
42 add_start_docstrings_to_model_forward,
43 is_flash_attn_greater_or_equal_2_10,
44 logging,
45 replace_return_docstrings,
46 )
47 from .configuration_phi3_v import Phi3VConfig
48
49 try:
50 from flash_attn import flash_attn_func, flash_attn_varlen_func
51 from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
52
53 _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
54 except ImportError:
55 pass
56
57 import torch
58 from torch import nn
59 from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
60 from transformers.models.clip.modeling_clip import CLIPAttention
61 from transformers.utils import logging
62
63 logger = logging.get_logger(__name__)
64
65
66 MAX_INPUT_ID = int(1e9)
67
68 CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
69 attention_dropout=0.0,
70 dropout=0.0,
71 hidden_act="quick_gelu",
72 hidden_size=1024,
73 image_size=336,
74 initializer_factor=1.0,
75 initializer_range=0.02,
76 intermediate_size=4096,
77 layer_norm_eps=1e-05,
78 num_attention_heads=16,
79 num_channels=3,
80 num_hidden_layers=24,
81 patch_size=14,
82 projection_dim=768
83 )
84
85 class CLIPAttentionFA2(CLIPAttention):
86 """Add flash attention 2 to CLIPAttention. (This is only used in the vision encoder)"""
87
88 def forward(self,
89 hidden_states,
90 attention_mask=None,
91 causal_attention_mask=None,
92 output_attentions=False,
93 ):
94 """Input shape: Batch x Time x Channel"""
95
96 assert attention_mask is None, "CLIPAttentionFA2 does not support attention_mask"
97 assert causal_attention_mask is None, "CLIPAttentionFA2 does not support causal_attention_mask"
98 assert output_attentions is False, "CLIPAttentionFA2 does not support output_attentions"
99
100 bsz, tgt_len, embed_dim = hidden_states.size()
101 query_states = self.q_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
102 key_states = self.k_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
103 value_states = self.v_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
104
105 attn_output = flash_attn_func(
106 query_states,
107 key_states,
108 value_states,
109 dropout_p=self.dropout if self.training else 0.0,
110 softmax_scale=self.scale,
111 causal=False,
112 ).reshape(bsz, tgt_len, embed_dim)
113
114 attn_output = self.out_proj(attn_output)
115 return attn_output, None
116
117
118 class Phi3ImageEmbedding(nn.Module):
119 """Phi3 Image embedding."""
120
121 def __init__(self, config: PretrainedConfig, wte=None, **kwargs) -> None:
122 super().__init__()
123
124         # n_embd or hidden_size
125 hidden_size = config.n_embd if hasattr(config, 'n_embd') else config.hidden_size
126 if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
127 embd_drop = config.embd_pdrop if hasattr(config, 'embd_pdrop') else config.embed_pdrop
128 self.drop = nn.Dropout(embd_drop)
129 else:
130 self.drop = None
131
132 self.wte = wte
133
134 if isinstance(config.img_processor, dict) and config.img_processor.get('name', None) == 'clip_vision_model':
135 assert 'model_name' in config.img_processor, 'model_name must be provided for CLIPVisionModel'
136 assert 'image_dim_out' in config.img_processor, 'image_dim_out must be provided for CLIPVisionModel'
137 assert 'num_img_tokens' in config.img_processor, 'num_img_tokens must be provided for CLIPVisionModel'
138 assert config.img_processor['model_name'] == 'openai/clip-vit-large-patch14-336'
139 clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
140 self.img_processor = CLIPVisionModel(clip_config)
141 image_dim_out = config.img_processor['image_dim_out']
142 self.num_img_tokens = config.img_processor['num_img_tokens']
143
144 # FA2 in CLIP
145 if config._attn_implementation == 'flash_attention_2':
146 for layer in self.img_processor.vision_model.encoder.layers:
147 clip_fa2 = CLIPAttentionFA2(clip_config)
148 del layer.self_attn
149 layer.self_attn = clip_fa2
150 else:
151 raise NotImplementedError(f'img_processor = {config.img_processor}, not implemented')
152
153 self.image_dim_out = image_dim_out
154 self.img_sizes = None
155
156         # glb_GN and sub_GN for hd transform, serving as separators
157 self.use_hd_transform = kwargs.get('use_hd_transform', False)
158 self.with_learnable_separator = kwargs.get('with_learnable_separator', False)
159 self.hd_transform_order = kwargs.get('hd_transform_order', 'glb_sub')
160         # use_hd_transform and with_learnable_separator should have the same value
161         assert self.use_hd_transform == self.with_learnable_separator, 'use_hd_transform and with_learnable_separator should have the same value'
162 if self.with_learnable_separator:
163 assert self.use_hd_transform, 'learnable separator is only for hd transform'
164 # 1024 * 4, merge spatial to channel dimension
165 self.glb_GN = nn.Parameter(torch.zeros([1, 1, self.image_dim_out * 4]))
166 self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out * 4]))
167 logger.info(f'learnable separator enabled for hd transform, hd_transform_order = {self.hd_transform_order}')
168
169 projection_cls = kwargs.get('projection_cls', 'linear')
170 if projection_cls == 'linear':
171 self.img_projection = nn.Linear(image_dim_out, hidden_size)
172 elif projection_cls == 'mlp' and self.use_hd_transform:
173 dim_projection = hidden_size
174 depth = 2
175 layers = [nn.Linear(image_dim_out * 4, dim_projection)]
176 for _ in range(1, depth):
177 layers.extend([nn.GELU(),
178 nn.Linear(dim_projection, dim_projection)])
179 self.img_projection = nn.Sequential(*layers)
180 elif projection_cls == 'mlp':
181 dim_projection = hidden_size
182 depth = 2
183 layers = [nn.Linear(image_dim_out, dim_projection)]
184 for _ in range(1, depth):
185 layers.extend([nn.GELU(),
186 nn.Linear(dim_projection, dim_projection)])
187 self.img_projection = nn.Sequential(*layers)
188 else:
189 raise NotImplementedError(f'projection_cls = {projection_cls}, not implemented')
190
191 self.vocab_size = config.vocab_size
192 self.img_features = None
193
194 if isinstance(config.img_processor, dict):
195 self.layer_idx = config.img_processor.get('layer_idx', -2)
196 self.type_feature = config.img_processor.get('type_feature', 'patch')
197 else:
198 self.layer_idx = -2
199 self.type_feature = 'patch'
200
201
202 def set_img_features(self, img_features: torch.FloatTensor) -> None:
203 self.img_features = img_features
204
205 def set_img_sizes(self, img_sizes: torch.LongTensor) -> None:
206 self.img_sizes = img_sizes
207
208 def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
209 LAYER_IDX = self.layer_idx
210 TYPE_FEATURE = self.type_feature
211
212 img_processor_output = self.img_processor(img_embeds, output_hidden_states=True)
213 img_feature = img_processor_output.hidden_states[LAYER_IDX]
214
215 if TYPE_FEATURE == "patch":
216 patch_feature = img_feature[:, 1:]
217 return patch_feature
218
219 raise NotImplementedError
220
221 def forward(
222 self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None
223 ) -> torch.FloatTensor:
224 input_shape = input_ids.size()
225 input_ids = input_ids.view(-1, input_shape[-1])
226
227 # positions for image tokens
228 positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)
229 has_image = len(positions[0].tolist()) > 0
230 input_ids = input_ids.clamp_min(0).clamp_max(self.vocab_size).detach()
231 hidden_states = self.wte(input_ids)
232
233 if has_image:
234 assert self.use_hd_transform
235 num_images, num_crops, c, h, w = pixel_values.shape
236 assert c == 3 and h == w == 336
237 img_features = self.get_img_features(pixel_values.flatten(0, 1)).reshape(
238 num_images, num_crops, -1, self.image_dim_out
239 )
240 image_features_proj = self.hd_feature_transform(img_features, image_sizes)
241 hidden_states = hidden_states.index_put(
242 positions, image_features_proj, accumulate=False
243 )
244
245 if self.drop is not None:
246 hidden_states = self.drop(hidden_states)
247
248 return hidden_states
249
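# Illustrative sketch (not part of the original file): how the placeholder convention above is
# consumed. The processor marks image positions with negative input_ids; `forward` finds those
# positions, clamps the ids so the text embedding lookup stays valid, and writes the projected
# image features into the placeholder slots. Hypothetical shapes for a single image:
#
#     # input_ids:    (1, seq_len) with N negative entries for the image tokens
#     # pixel_values: (1, num_crops, 3, 336, 336)
#     # image_sizes:  (1, 2) as (height, width) in pixels, multiples of 336
#     hidden_states = embedder(input_ids, pixel_values, image_sizes=image_sizes)
#     # hidden_states: (1, seq_len, hidden_size); the N placeholder rows now hold image features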
250 def hd_feature_transform(self, image_features, image_sizes):
251 """
252 image_features: (num_images, num_crops+1, 24*24, 1024)
253 """
254 assert (
255 self.hd_transform_order == 'sub_glb'
256 ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
257 if isinstance(self.img_projection, nn.Sequential):
258 target_device = self.img_projection[0].bias.device
259 target_dtype = self.img_projection[0].bias.dtype
260 else: # It's a single nn.Linear layer
261 target_device = self.img_projection.bias.device
262 target_dtype = self.img_projection.bias.dtype
263
264 global_image_features = image_features[:, 0] # (num_images, 24*24, 1024)
265 # global feature can be viewed as a special HD case with num_crops 1x1
266 global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1)
267 global_image_features_hd_newline = self.add_image_newline(global_image_features_hd)
268
269 all_image_embeddings = []
270 # need a for loop to process each image because of different image sizes
271 # (patch arrangement is different for each image)
272 for i, img_size in enumerate(image_sizes):
273 h, w = img_size
274 h_crop = h // 336
275 w_crop = w // 336
276 num_crops = h_crop * w_crop
277
278 # NOTE: real num_crops is padded
279 # (num_crops, 24*24, 1024)
280 sub_image_features = image_features[i, 1 : 1 + num_crops]
281 sub_image_features_hd = self.reshape_hd_patches_2x2merge(
282 sub_image_features, h_crop, w_crop
283 )
284 sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd)
285
286 # [sub features, separator, global features]
287 all_image_embeddings.extend(
288 [
289 sub_image_features_hd_newline.squeeze(0), # (h_crop*12*(w_crop*12+1), 4096)
290 self.glb_GN.squeeze(0),
291 global_image_features_hd_newline[i],
292 ]
293 )
294
295 image_features_proj = self.img_projection(
296 torch.cat(all_image_embeddings, dim=0).to(target_device).to(target_dtype)
297 )
298
299 return image_features_proj
300
301 def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
302 """
303 image_features: (num_images*num_crops, 24*24, 1024)
304 output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops
305 """
306 N, L, C = image_features.shape
307 assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0
308 num_images = N // (h_crop * w_crop)
309 H = int(L**0.5)
310 image_features_hd = (
311 image_features.reshape(N, H, H, C) # N, 24, 24, 1024
312 .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
313 .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
314 .reshape(N, -1, 4 * C) # N, 144, 4096
315 .reshape(
316 num_images, h_crop, w_crop, H // 2, H // 2, -1
317 ) # n_img, h_crop, w_crop, 12, 12, 4096
318 .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
319 .reshape(
320 num_images, h_crop * H // 2, w_crop * H // 2, 4 * C
321 ) # n_img, h_crop*12, w_crop*12, 4096
322 )
323
324 # alternative implementation using einops
325 # from einops import rearrange
326 # image_features_nhwc = rearrange(
327 # image_features,
328 # 'N (H W) c -> N H W c',
329 # H=H,
330 # W=H,
331 # )
332 # image_features_2x2merge = rearrange(
333 # image_features_nhwc,
334 # 'N (h h_pool) (w w_pool) c -> N h w (h_pool w_pool c)',
335 # h_pool=2,
336 # w_pool=2,
337 # )
338 # image_features_hd = rearrange(
339 # image_features_2x2merge,
340 # '(n_img h_crop w_crop) h w C -> n_img (h_crop h) (w_crop w) C',
341 # h_crop=h_crop,
342 # w_crop=w_crop,
343 # )
344
345 return image_features_hd
346
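# Illustrative sketch (not part of the original file): the 2x2 merge folds each crop's 24x24 CLIP
# patch grid into a 12x12 grid with 4x wider channels (576 tokens of width 1024 become 144 tokens
# of width 4096), then tiles the crops back into one grid. Assuming a dummy tensor `feats` and a
# hypothetical `embedder` instance, for one image split into 2x2 crops:
#
#     feats = torch.randn(4, 24 * 24, 1024)                   # (num_crops, 24*24, 1024)
#     hd = embedder.reshape_hd_patches_2x2merge(feats, h_crop=2, w_crop=2)
#     assert hd.shape == (1, 2 * 12, 2 * 12, 4 * 1024)         # (num_images, 24, 24, 4096)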
347 def add_image_newline(self, image_features_hd):
348 """
349 image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
350 output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
351 """
352 num_images, h, w, hid_dim = image_features_hd.shape
353 # add the newline token to the HD image feature patches
354 newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1) # (n_img, h, 1, hid_dim)
355 image_features_hd_newline = torch.cat(
356 [image_features_hd, newline_embeddings], dim=2
357 ).reshape(num_images, -1, hid_dim)
358 return image_features_hd_newline
359
360
361 logger = logging.get_logger(__name__)
362
363 _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-vision-128k-instruct"
364 _CONFIG_FOR_DOC = "Phi3VConfig"
365
366 PHI3V_PRETRAINED_MODEL_ARCHIVE_LIST = [
367 "microsoft/Phi-3-vision-128k-instruct",
368 # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3
369 ]
370
371
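# Illustrative sketch (not part of the original file): this module is normally pulled in through the
# released checkpoint with `trust_remote_code`, following the public model card; arguments beyond the
# names shown here are assumptions.
#
#     from transformers import AutoModelForCausalLM, AutoProcessor
#     model_id = "microsoft/Phi-3-vision-128k-instruct"
#     model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype="auto")
#     processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)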
372 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
373 class Phi3RMSNorm(nn.Module):
374 def __init__(self, hidden_size, eps=1e-6):
375 """
376 Phi3RMSNorm is equivalent to T5LayerNorm
377 """
378 super().__init__()
379 self.weight = nn.Parameter(torch.ones(hidden_size))
380 self.variance_epsilon = eps
381
382 def forward(self, hidden_states):
383 input_dtype = hidden_states.dtype
384 hidden_states = hidden_states.to(torch.float32)
385 variance = hidden_states.pow(2).mean(-1, keepdim=True)
386 hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
387 return self.weight * hidden_states.to(input_dtype)
388
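# Illustrative sketch (not part of the original file): Phi3RMSNorm divides by the root-mean-square of
# the hidden vector (no mean subtraction, unlike LayerNorm) and then applies a learned per-channel
# scale. A quick equivalence check with dummy data:
#
#     norm = Phi3RMSNorm(hidden_size=8)
#     x = torch.randn(2, 3, 8)
#     ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + norm.variance_epsilon)
#     assert torch.allclose(norm(x), norm.weight * ref, atol=1e-6)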
389
390 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
391 def _get_unpad_data(attention_mask):
392 seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
393 indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
394 max_seqlen_in_batch = seqlens_in_batch.max().item()
395 cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
396 return (
397 indices,
398 cu_seqlens,
399 max_seqlen_in_batch,
400 )
401
402
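# Illustrative sketch (not part of the original file): what _get_unpad_data produces for a small
# right-padded batch. `indices` selects the non-padding positions of the flattened mask, and
# `cu_seqlens` is the cumulative-length vector expected by flash-attn's varlen kernels.
#
#     mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
#     indices, cu_seqlens, max_len = _get_unpad_data(mask)
#     # indices    -> tensor([0, 1, 2, 4, 5])
#     # cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)
#     # max_len    -> 3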
403 # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
404 class Phi3RotaryEmbedding(nn.Module):
405 def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
406 super().__init__()
407
408 self.dim = dim
409 self.max_position_embeddings = max_position_embeddings
410 self.base = base
411 self.register_buffer("inv_freq", None, persistent=False)
412
413 @torch.no_grad()
414 def forward(self, x, position_ids, seq_len=None):
415 # x: [bs, num_attention_heads, seq_len, head_size]
416 if self.inv_freq is None:
417 self.inv_freq = 1.0 / (
418 self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
419 )
420 inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
421 position_ids_expanded = position_ids[:, None, :].float()
422 # Force float32 since bfloat16 loses precision on long contexts
423 # See https://github.com/huggingface/transformers/pull/29285
424 device_type = x.device.type
425 device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
426 with torch.autocast(device_type=device_type, enabled=False):
427 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
428 emb = torch.cat((freqs, freqs), dim=-1)
429 cos = emb.cos()
430 sin = emb.sin()
431 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
432
433
434 class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
435 def __init__(self, dim, config, device=None):
436 super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
437
438 self.short_factor = config.rope_scaling["short_factor"]
439 self.long_factor = config.rope_scaling["long_factor"]
440 self.original_max_position_embeddings = config.original_max_position_embeddings
441
442 @torch.no_grad()
443 def forward(self, x, position_ids, seq_len=None):
444 seq_len = seq_len or torch.max(position_ids) + 1
445 if seq_len > self.original_max_position_embeddings:
446 ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
447 else:
448 ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
449
450 inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
451 self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
452
453 inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
454 position_ids_expanded = position_ids[:, None, :].float()
455
456 # Force float32 since bfloat16 loses precision on long contexts
457 # See https://github.com/huggingface/transformers/pull/29285
458 device_type = x.device.type
459 device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
460 with torch.autocast(device_type=device_type, enabled=False):
461 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
462 emb = torch.cat((freqs, freqs), dim=-1)
463
464 scale = self.max_position_embeddings / self.original_max_position_embeddings
465 if scale <= 1.0:
466 scaling_factor = 1.0
467 else:
468 scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
469
470 cos = emb.cos() * scaling_factor
471 sin = emb.sin() * scaling_factor
472 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
473
474
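# Illustrative note (not part of the original file): with the values typically used by the 128k
# checkpoint (max_position_embeddings=131072, original_max_position_embeddings=4096),
# scale = 131072 / 4096 = 32, so
# scaling_factor = sqrt(1 + log(32) / log(4096)) ≈ sqrt(1.417) ≈ 1.19,
# i.e. cos/sin are mildly amplified to compensate for the rescaled long-context frequencies.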
475 class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
476 def __init__(self, dim, config, device=None):
477 super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
478
479 self.short_factor = config.rope_scaling["short_factor"]
480 self.long_factor = config.rope_scaling["long_factor"]
481 self.original_max_position_embeddings = config.original_max_position_embeddings
482
483 @torch.no_grad()
484 def forward(self, x, position_ids, seq_len=None):
485 seq_len = torch.max(position_ids) + 1
486 if seq_len > self.original_max_position_embeddings:
487 ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
488 else:
489 ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
490
491 inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
492 self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
493
494 inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
495 position_ids_expanded = position_ids[:, None, :].float()
496
497 # Force float32 since bfloat16 loses precision on long contexts
498 # See https://github.com/huggingface/transformers/pull/29285
499 device_type = x.device.type
500 device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
501 with torch.autocast(device_type=device_type, enabled=False):
502 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
503 emb = torch.cat((freqs, freqs), dim=-1)
504
505 scale = self.max_position_embeddings / self.original_max_position_embeddings
506 if scale <= 1.0:
507 scaling_factor = 1.0
508 else:
509 scaling_factor = 0.1 * math.log(scale) + 1.0
510
511 cos = emb.cos() * scaling_factor
512 sin = emb.sin() * scaling_factor
513 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
514
515
516 # Copied from transformers.models.llama.modeling_llama.rotate_half
517 def rotate_half(x):
518 """Rotates half the hidden dims of the input."""
519 x1 = x[..., : x.shape[-1] // 2]
520 x2 = x[..., x.shape[-1] // 2 :]
521 return torch.cat((-x2, x1), dim=-1)
522
523
524 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
525 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
526 """Applies Rotary Position Embedding to the query and key tensors.
527
528 Args:
529 q (`torch.Tensor`): The query tensor.
530 k (`torch.Tensor`): The key tensor.
531 cos (`torch.Tensor`): The cosine part of the rotary embedding.
532 sin (`torch.Tensor`): The sine part of the rotary embedding.
533 position_ids (`torch.Tensor`, *optional*):
534 Deprecated and unused.
535 unsqueeze_dim (`int`, *optional*, defaults to 1):
536 The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
537 sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
538 that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
539 k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
540 cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
541 the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
542 Returns:
543 `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
544 """
545 cos = cos.unsqueeze(unsqueeze_dim)
546 sin = sin.unsqueeze(unsqueeze_dim)
547 q_embed = (q * cos) + (rotate_half(q) * sin)
548 k_embed = (k * cos) + (rotate_half(k) * sin)
549 return q_embed, k_embed
550
551
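# Illustrative sketch (not part of the original file): applying RoPE to query/key tensors shaped
# (batch, heads, seq_len, head_dim), mirroring the attention modules below. Dummy sizes:
#
#     b, h, s, d = 1, 4, 16, 32
#     q, k = torch.randn(b, h, s, d), torch.randn(b, h, s, d)
#     position_ids = torch.arange(s).unsqueeze(0)
#     rope = Phi3RotaryEmbedding(dim=d)
#     cos, sin = rope(q, position_ids)                        # each (b, s, d)
#     q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)     # shapes unchanged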
552 class Phi3MLP(nn.Module):
553 def __init__(self, config):
554 super().__init__()
555
556 self.config = config
557 self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
558 self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
559
560 self.activation_fn = ACT2FN[config.hidden_act]
561
562 def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
563 up_states = self.gate_up_proj(hidden_states)
564
565 gate, up_states = up_states.chunk(2, dim=-1)
566 up_states = up_states * self.activation_fn(gate)
567
568 return self.down_proj(up_states)
569
570
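# Illustrative sketch (not part of the original file): the fused gate/up projection above.
# `gate_up_proj` emits 2*intermediate_size features that are split in half; the activated first
# half gates the second half (SwiGLU-style; Phi-3 configs use hidden_act="silu"):
#
#     up_states = torch.randn(1, 5, 2 * 8192)                 # stand-in for gate_up_proj output
#     gate, up = up_states.chunk(2, dim=-1)                   # two (1, 5, 8192) halves
#     hidden = up * torch.nn.functional.silu(gate)            # fed to down_proj in forward()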
571 # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
572 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
573 """
574 This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
575 num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
576 """
577 batch, num_key_value_heads, slen, head_dim = hidden_states.shape
578 if n_rep == 1:
579 return hidden_states
580 hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
581 return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
582
583
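# Illustrative sketch (not part of the original file): expanding grouped KV heads, e.g. for a
# hypothetical config with 32 query heads and 8 key/value heads (n_rep = 4):
#
#     kv = torch.randn(2, 8, 128, 96)          # (batch, num_key_value_heads, seq, head_dim)
#     expanded = repeat_kv(kv, n_rep=4)
#     assert expanded.shape == (2, 32, 128, 96)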
584 class Phi3Attention(nn.Module):
585 """Multi-headed attention from 'Attention Is All You Need' paper"""
586
587 def __init__(self, config: Phi3VConfig, layer_idx: Optional[int] = None):
588 super().__init__()
589 self.config = config
590 self.layer_idx = layer_idx
591 if layer_idx is None:
592 logger.warning_once(
593 f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
594 "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
595 "when creating this class."
596 )
597
598 self.attention_dropout = config.attention_dropout
599 self.hidden_size = config.hidden_size
600 self.num_heads = config.num_attention_heads
601 self.head_dim = self.hidden_size // self.num_heads
602 self.num_key_value_heads = config.num_key_value_heads
603 self.num_key_value_groups = self.num_heads // self.num_key_value_heads
604 self.max_position_embeddings = config.max_position_embeddings
605 self.original_max_position_embeddings = config.original_max_position_embeddings
606 self.rope_theta = config.rope_theta
607 self.rope_scaling = config.rope_scaling
608 self.is_causal = True
609
610 if (self.head_dim * self.num_heads) != self.hidden_size:
611 raise ValueError(
612 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
613 f" and `num_heads`: {self.num_heads})."
614 )
615
616 op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
617 self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
618 self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
619 self._init_rope()
620
621 def _init_rope(self):
622 if self.rope_scaling is None:
623 self.rotary_emb = Phi3RotaryEmbedding(
624 self.head_dim,
625 max_position_embeddings=self.max_position_embeddings,
626 base=self.rope_theta,
627 )
628 else:
629 scaling_type = self.config.rope_scaling["type"]
630 if scaling_type == "su":
631 self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
632 elif scaling_type == "yarn":
633 self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
634 else:
635 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
636
637 def forward(
638 self,
639 hidden_states: torch.Tensor,
640 attention_mask: Optional[torch.Tensor] = None,
641 position_ids: Optional[torch.LongTensor] = None,
642 past_key_value: Optional[Cache] = None,
643 output_attentions: bool = False,
644 use_cache: bool = False,
645 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
646 logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.")
647
648 bsz, q_len, _ = hidden_states.size()
649
650 qkv = self.qkv_proj(hidden_states)
651 query_pos = self.num_heads * self.head_dim
652 query_states = qkv[..., :query_pos]
653 key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
654 value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
655
656 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
657 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
658 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
659
660 kv_seq_len = key_states.shape[-2]
661 if past_key_value is not None:
662 if self.layer_idx is None:
663 raise ValueError(
664 f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
665 "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
666 "with a layer index."
667 )
668 kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
669 cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
670
671 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
672
673 if past_key_value is not None:
674 cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
675 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
676
677 # repeat k/v heads if n_kv_heads < n_heads
678 key_states = repeat_kv(key_states, self.num_key_value_groups)
679 value_states = repeat_kv(value_states, self.num_key_value_groups)
680
681 attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
682
683 if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
684 raise ValueError(
685 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
686 f" {attn_weights.size()}"
687 )
688
689 if attention_mask is not None:
690 if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
691 raise ValueError(
692 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
693 )
694 attn_weights = attn_weights + attention_mask
695
696 # upcast attention to fp32
697 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
698 attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
699
700 attn_output = torch.matmul(attn_weights, value_states)
701
702 if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
703 raise ValueError(
704 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
705 f" {attn_output.size()}"
706 )
707
708 attn_output = attn_output.transpose(1, 2).contiguous()
709 attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
710
711 attn_output = self.o_proj(attn_output)
712
713 if not output_attentions:
714 attn_weights = None
715
716 return attn_output, attn_weights, past_key_value
717
718
719 class Phi3FlashAttention2(Phi3Attention):
720 """
721     Phi-3 flash attention module. This module inherits from `Phi3Attention`, as the weights of the module stay
722     untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
723 flash attention and deal with padding tokens in case the input contains any of them.
724 """
725
726 # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
727 def __init__(self, *args, **kwargs):
728 super().__init__(*args, **kwargs)
729
730 # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
731         # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
732 # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
733 self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
734
735 def forward(
736 self,
737 hidden_states: torch.Tensor,
738 attention_mask: Optional[torch.LongTensor] = None,
739 position_ids: Optional[torch.LongTensor] = None,
740 past_key_value: Optional[Cache] = None,
741 output_attentions: bool = False,
742 use_cache: bool = False,
743 **kwargs,
744 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
745 # Phi3FlashAttention2 attention does not support output_attentions
746
747 if not _flash_supports_window_size:
748 logger.warning_once(
749 "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library."
750 )
751 raise ValueError("The current flash attention version does not support sliding window attention.")
752
753 output_attentions = False
754
755 if "padding_mask" in kwargs:
756 warnings.warn(
757 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
758 )
759
760 # overwrite attention_mask with padding_mask
761 attention_mask = kwargs.pop("padding_mask")
762
763 bsz, q_len, _ = hidden_states.size()
764
765 qkv = self.qkv_proj(hidden_states)
766 query_pos = self.num_heads * self.head_dim
767 query_states = qkv[..., :query_pos]
768 key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
769 value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
770
771 # Flash attention requires the input to have the shape
772         # batch_size x seq_length x num_heads x head_dim
773 # therefore we just need to keep the original shape
774 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
775 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
776 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
777
778 kv_seq_len = key_states.shape[-2]
779 if past_key_value is not None:
780 if self.layer_idx is None:
781 raise ValueError(
782 f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
783 "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
784 "with a layer index."
785 )
786 kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
787
788 # Because the input can be padded, the absolute sequence length depends on the max position id.
789 rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
790 cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
791
792 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
793
794 use_sliding_windows = (
795 _flash_supports_window_size
796 and getattr(self.config, "sliding_window", None) is not None
797 and kv_seq_len > self.config.sliding_window
798 )
799
800 if past_key_value is not None:
801             # Activate cache slicing only if the config has a `sliding_window` attribute
802 cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
803 if (
804 getattr(self.config, "sliding_window", None) is not None
805 and kv_seq_len > self.config.sliding_window
806 and cache_has_contents
807 ):
808 slicing_tokens = 1 - self.config.sliding_window
809
810 past_key = past_key_value[self.layer_idx][0]
811 past_value = past_key_value[self.layer_idx][1]
812
813 past_key = past_key[:, :, slicing_tokens:, :].contiguous()
814 past_value = past_value[:, :, slicing_tokens:, :].contiguous()
815
816 if past_key.shape[-2] != self.config.sliding_window - 1:
817 raise ValueError(
818 f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
819 f" {past_key.shape}"
820 )
821
822 if attention_mask is not None:
823 attention_mask = attention_mask[:, slicing_tokens:]
824 attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
825
826 cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
827 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
828
829 # repeat k/v heads if n_kv_heads < n_heads
830 key_states = repeat_kv(key_states, self.num_key_value_groups)
831 value_states = repeat_kv(value_states, self.num_key_value_groups)
832
833 attn_dropout = self.attention_dropout if self.training else 0.0
834
835         # In PEFT, we usually cast the layer norms to float32 for training stability reasons;
836         # therefore the input hidden states get silently cast to float32. Hence, we need to
837         # cast them back to the correct dtype just to be sure everything works as expected.
838         # This might slow down training & inference, so it is recommended not to cast the LayerNorms
839 # in fp32.
840
841 if query_states.dtype == torch.float32:
842 if torch.is_autocast_enabled():
843 target_dtype = torch.get_autocast_gpu_dtype()
844 # Handle the case where the model is quantized
845 elif hasattr(self.config, "_pre_quantization_dtype"):
846 target_dtype = self.config._pre_quantization_dtype
847 else:
848 target_dtype = self.qkv_proj.weight.dtype
849
850 logger.warning_once(
851 f"The input hidden states seems to be silently casted in float32, this might be related to"
852 f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
853 f" {target_dtype}."
854 )
855
856 query_states = query_states.to(target_dtype)
857 key_states = key_states.to(target_dtype)
858 value_states = value_states.to(target_dtype)
859
860         # Reshape to the expected shape for Flash Attention
861 query_states = query_states.transpose(1, 2)
862 key_states = key_states.transpose(1, 2)
863 value_states = value_states.transpose(1, 2)
864
865 attn_output = self._flash_attention_forward(
866 query_states,
867 key_states,
868 value_states,
869 attention_mask,
870 q_len,
871 dropout=attn_dropout,
872 use_sliding_windows=use_sliding_windows,
873 )
874
875 attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
876 attn_output = self.o_proj(attn_output)
877
878 if not output_attentions:
879 attn_weights = None
880
881 return attn_output, attn_weights, past_key_value
882
883 # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
884 def _flash_attention_forward(
885 self,
886 query_states,
887 key_states,
888 value_states,
889 attention_mask,
890 query_length,
891 dropout=0.0,
892 softmax_scale=None,
893 use_sliding_windows=False,
894 ):
895 """
896         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
897         first unpad the input, then compute the attention scores, and finally re-pad the attention scores.
898
899 Args:
900 query_states (`torch.Tensor`):
901 Input query states to be passed to Flash Attention API
902 key_states (`torch.Tensor`):
903 Input key states to be passed to Flash Attention API
904 value_states (`torch.Tensor`):
905 Input value states to be passed to Flash Attention API
906 attention_mask (`torch.Tensor`):
907 The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
908 position of padding tokens and 1 for the position of non-padding tokens.
909 dropout (`float`):
910 Attention dropout
911 softmax_scale (`float`, *optional*):
912                 The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
913 use_sliding_windows (`bool`, *optional*):
914 Whether to activate sliding window attention.
915 """
916 if not self._flash_attn_uses_top_left_mask:
917 causal = self.is_causal
918 else:
919 # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
920 causal = self.is_causal and query_length != 1
921
922 # Contains at least one padding token in the sequence
923 if attention_mask is not None:
924 batch_size = query_states.shape[0]
925 query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
926 query_states, key_states, value_states, attention_mask, query_length
927 )
928
929 cu_seqlens_q, cu_seqlens_k = cu_seq_lens
930 max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
931
932 if not use_sliding_windows:
933 attn_output_unpad = flash_attn_varlen_func(
934 query_states,
935 key_states,
936 value_states,
937 cu_seqlens_q=cu_seqlens_q,
938 cu_seqlens_k=cu_seqlens_k,
939 max_seqlen_q=max_seqlen_in_batch_q,
940 max_seqlen_k=max_seqlen_in_batch_k,
941 dropout_p=dropout,
942 softmax_scale=softmax_scale,
943 causal=causal,
944 )
945 else:
946 attn_output_unpad = flash_attn_varlen_func(
947 query_states,
948 key_states,
949 value_states,
950 cu_seqlens_q=cu_seqlens_q,
951 cu_seqlens_k=cu_seqlens_k,
952 max_seqlen_q=max_seqlen_in_batch_q,
953 max_seqlen_k=max_seqlen_in_batch_k,
954 dropout_p=dropout,
955 softmax_scale=softmax_scale,
956 causal=causal,
957 window_size=(self.config.sliding_window, self.config.sliding_window),
958 )
959
960 attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
961 else:
962 if not use_sliding_windows:
963 attn_output = flash_attn_func(
964 query_states,
965 key_states,
966 value_states,
967 dropout,
968 softmax_scale=softmax_scale,
969 causal=causal,
970 )
971 else:
972 attn_output = flash_attn_func(
973 query_states,
974 key_states,
975 value_states,
976 dropout,
977 softmax_scale=softmax_scale,
978 causal=causal,
979 window_size=(self.config.sliding_window, self.config.sliding_window),
980 )
981
982 return attn_output
983
984 # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
985 def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
986 batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
987
988 # On the first iteration we need to properly re-create the padding mask
989         # by slicing it at the proper place
990 if kv_seq_len != attention_mask.shape[-1]:
991 attention_mask_num_tokens = attention_mask.shape[-1]
992 attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
993
994 indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
995
996 key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
997 value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
998
999 if query_length == kv_seq_len:
1000 query_layer = index_first_axis(
1001 query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
1002 )
1003 cu_seqlens_q = cu_seqlens_k
1004 max_seqlen_in_batch_q = max_seqlen_in_batch_k
1005 indices_q = indices_k
1006 elif query_length == 1:
1007 max_seqlen_in_batch_q = 1
1008 cu_seqlens_q = torch.arange(
1009 batch_size + 1, dtype=torch.int32, device=query_layer.device
1010 ) # There is a memcpy here, that is very bad.
1011 indices_q = cu_seqlens_q[:-1]
1012 query_layer = query_layer.squeeze(1)
1013 else:
1014 # The -q_len: slice assumes left padding.
1015 attention_mask = attention_mask[:, -query_length:]
1016 query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
1017
1018 return (
1019 query_layer,
1020 key_layer,
1021 value_layer,
1022 indices_q,
1023 (cu_seqlens_q, cu_seqlens_k),
1024 (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
1025 )
1026
1027
1028 # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
1029 # TODO @Arthur no longer copied from LLama after static cache
1030 class Phi3SdpaAttention(Phi3Attention):
1031 """
1032 Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
1033     `Phi3Attention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to
1034 SDPA API.
1035 """
1036
1037 # Adapted from Phi3Attention.forward
1038 def forward(
1039 self,
1040 hidden_states: torch.Tensor,
1041 attention_mask: Optional[torch.Tensor] = None,
1042 position_ids: Optional[torch.LongTensor] = None,
1043 past_key_value: Optional[Cache] = None,
1044 output_attentions: bool = False,
1045 use_cache: bool = False,
1046 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1047 if output_attentions:
1048 # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
1049 logger.warning_once(
1050 "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
1051 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
1052 )
1053 return super().forward(
1054 hidden_states=hidden_states,
1055 attention_mask=attention_mask,
1056 position_ids=position_ids,
1057 past_key_value=past_key_value,
1058 output_attentions=output_attentions,
1059 use_cache=use_cache,
1060 )
1061
1062 bsz, q_len, _ = hidden_states.size()
1063
1064 qkv = self.qkv_proj(hidden_states)
1065 query_pos = self.num_heads * self.head_dim
1066 query_states = qkv[..., :query_pos]
1067 key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
1068 value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
1069
1070 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
1071 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1072 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1073
1074 kv_seq_len = key_states.shape[-2]
1075 if past_key_value is not None:
1076 kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1077 cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
1078
1079 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
1080
1081 if past_key_value is not None:
1082 cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
1083 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1084
1085 key_states = repeat_kv(key_states, self.num_key_value_groups)
1086 value_states = repeat_kv(value_states, self.num_key_value_groups)
1087
1088 if attention_mask is not None:
1089 if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
1090 raise ValueError(
1091 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
1092 )
1093
1094 # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
1095 # Reference: https://github.com/pytorch/pytorch/issues/112577.
1096 if query_states.device.type == "cuda" and attention_mask is not None:
1097 query_states = query_states.contiguous()
1098 key_states = key_states.contiguous()
1099 value_states = value_states.contiguous()
1100
1101 attn_output = torch.nn.functional.scaled_dot_product_attention(
1102 query_states,
1103 key_states,
1104 value_states,
1105 attn_mask=attention_mask,
1106 dropout_p=self.attention_dropout if self.training else 0.0,
1107 # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
1108 is_causal=self.is_causal and attention_mask is None and q_len > 1,
1109 )
1110
1111 attn_output = attn_output.transpose(1, 2).contiguous()
1112 attn_output = attn_output.view(bsz, q_len, self.hidden_size)
1113
1114 attn_output = self.o_proj(attn_output)
1115
1116 return attn_output, None, past_key_value
1117
1118
1119 PHI3_ATTENTION_CLASSES = {
1120 "eager": Phi3Attention,
1121 "flash_attention_2": Phi3FlashAttention2,
1122 "sdpa": Phi3SdpaAttention,
1123 }
1124
1125
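# Illustrative sketch (not part of the original file): the decoder layer below selects its attention
# implementation from this mapping via `config._attn_implementation`, which is what the
# `attn_implementation` argument to `from_pretrained` ends up setting:
#
#     attn_cls = PHI3_ATTENTION_CLASSES[config._attn_implementation]
#     self_attn = attn_cls(config, layer_idx=0)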
1126 class Phi3DecoderLayer(nn.Module):
1127 def __init__(self, config: Phi3VConfig, layer_idx: int):
1128 super().__init__()
1129
1130 self.config = config
1131 self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
1132
1133 self.mlp = Phi3MLP(config)
1134 self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1135
1136 self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
1137 self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
1138 self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1139
1140 def forward(
1141 self,
1142 hidden_states: torch.Tensor,
1143 attention_mask: Optional[torch.Tensor] = None,
1144 position_ids: Optional[torch.LongTensor] = None,
1145 past_key_value: Optional[Tuple[torch.Tensor]] = None,
1146 output_attentions: Optional[bool] = False,
1147 use_cache: Optional[bool] = False,
1148 **kwargs,
1149 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1150 if "padding_mask" in kwargs:
1151 warnings.warn(
1152 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
1153 )
1154 """
1155 Args:
1156 hidden_states (`torch.FloatTensor`):
1157 input to the layer of shape `(batch, seq_len, embed_dim)`
1158 attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1159 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
1160 position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
1161                 Indices of positions of each input sequence token in the position embeddings. Selected in the range
1162 `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
1163 output_attentions (`bool`, *optional*):
1164 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1165 returned tensors for more detail.
1166 use_cache (`bool`, *optional*):
1167 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1168 (see `past_key_values`).
1169 past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1170 """
1171
1172 residual = hidden_states
1173
1174 hidden_states = self.input_layernorm(hidden_states)
1175
1176 # Self Attention
1177 attn_outputs, self_attn_weights, present_key_value = self.self_attn(
1178 hidden_states=hidden_states,
1179 attention_mask=attention_mask,
1180 position_ids=position_ids,
1181 past_key_value=past_key_value,
1182 output_attentions=output_attentions,
1183 use_cache=use_cache,
1184 )
1185
1186 hidden_states = residual + self.resid_attn_dropout(attn_outputs)
1187
1188 residual = hidden_states
1189 hidden_states = self.post_attention_layernorm(hidden_states)
1190 hidden_states = self.mlp(hidden_states)
1191 hidden_states = residual + self.resid_mlp_dropout(hidden_states)
1192
1193 outputs = (hidden_states,)
1194
1195 if output_attentions:
1196 outputs += (self_attn_weights,)
1197
1198 if use_cache:
1199 outputs += (present_key_value,)
1200
1201 return outputs
1202
1203
1204 PHI3V_START_DOCSTRING = r"""
1205 This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1206 library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1207 etc.)
1208
1209 This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1210 Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1211 and behavior.
1212
1213 Parameters:
1214 config ([`Phi3VConfig`]):
1215 Model configuration class with all the parameters of the model. Initializing with a config file does not
1216 load the weights associated with the model, only the configuration. Check out the
1217 [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1218 """
1219
1220
1221 @add_start_docstrings(
1222 "The bare Phi-3-V model outputting raw hidden-states without any specific head on top.",
1223 PHI3V_START_DOCSTRING,
1224 )
1225 class Phi3VPreTrainedModel(PreTrainedModel):
1226 config_class = Phi3VConfig
1227 base_model_prefix = "model"
1228 supports_gradient_checkpointing = True
1229 _no_split_modules = ["Phi3DecoderLayer"]
1230 _skip_keys_device_placement = "past_key_values"
1231 _supports_flash_attn_2 = True
1232 _supports_sdpa = False
1233 _supports_cache_class = True
1234
1235 _version = "0.0.5"
1236
1237 def _init_weights(self, module):
1238 std = self.config.initializer_range
1239 if isinstance(module, nn.Linear):
1240 module.weight.data.normal_(mean=0.0, std=std)
1241 if module.bias is not None:
1242 module.bias.data.zero_()
1243 elif isinstance(module, nn.Embedding):
1244 module.weight.data.normal_(mean=0.0, std=std)
1245 if module.padding_idx is not None:
1246 module.weight.data[module.padding_idx].zero_()
1247
1248
1249 PHI3V_INPUTS_DOCSTRING = r"""
1250 Args:
1251 input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1252 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1253 it.
1254
1255 Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1256 [`PreTrainedTokenizer.__call__`] for details.
1257
1258 [What are input IDs?](../glossary#input-ids)
1259 attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1260 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1261
1262 - 1 for tokens that are **not masked**,
1263 - 0 for tokens that are **masked**.
1264
1265 [What are attention masks?](../glossary#attention-mask)
1266
1267 Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1268 [`PreTrainedTokenizer.__call__`] for details.
1269
1270 If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1271 `past_key_values`).
1272
1273 If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1274 and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1275 information on the default strategy.
1276
1277 - 1 indicates the head is **not masked**,
1278 - 0 indicates the head is **masked**.
1279 position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1280             Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
1281 config.n_positions - 1]`.
1282
1283 [What are position IDs?](../glossary#position-ids)
1284 past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1285 Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1286             blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
1287 returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1288
1289 Two formats are allowed:
1290 - a [`~cache_utils.Cache`] instance;
1291 - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1292 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1293 cache format.
1294
1295 The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1296 legacy cache format will be returned.
1297
1298 If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1299 have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1300 of shape `(batch_size, sequence_length)`.
1301 inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1302 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1303 is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1304 model's internal embedding lookup matrix.
1305         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1306 The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`].
1307 See [`Phi3ImageProcessor.__call__`] for details.
1308 image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
1309 The sizes of the images in the batch, being (height, width) for each image.
1310 use_cache (`bool`, *optional*):
1311 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1312 `past_key_values`).
1313 output_attentions (`bool`, *optional*):
1314 Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1315 tensors for more detail.
1316 output_hidden_states (`bool`, *optional*):
1317 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1318 more detail.
1319 return_dict (`bool`, *optional*):
1320 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1321 """
1322
1323
1324 @add_start_docstrings(
1325 "The bare Phi-3-V model outputting raw hidden-states without any specific head on top.",
1326 PHI3V_START_DOCSTRING,
1327 )
1328 class Phi3VModel(Phi3VPreTrainedModel):
1329 """
1330 Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
1331
1332 Args:
1333         config: Phi3VConfig
1334 """
1335
1336 def __init__(self, config: Phi3VConfig):
1337 super().__init__(config)
1338 self.padding_idx = config.pad_token_id
1339 self.vocab_size = config.vocab_size
1340
1341 self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1342 self.embed_dropout = nn.Dropout(config.embd_pdrop)
1343
1344 self.vision_embed_tokens = None
1345 if isinstance(config.embd_layer, dict):
1346 # vision embedding layer
1347 embedding_config = {
1348 'embedding_cls': config.embd_layer['embedding_cls'],
1349 **config.embd_layer
1350 }
1351 self.vision_embed_tokens = Phi3ImageEmbedding(config, wte=self.embed_tokens, **embedding_config)
1352 # # set wte the same for vision embedding
1353 # self.vision_embed_tokens.wte.weight = self.embed_tokens.weight
1354
1355 self.layers = nn.ModuleList(
1356 [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1357 )
1358 self._attn_implementation = config._attn_implementation
1359 self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1360
1361 self.gradient_checkpointing = False
1362 # Initialize weights and apply final processing
1363 self.post_init()
1364
1365 def get_input_embeddings(self):
1366 return self.embed_tokens
1367
1368 def set_input_embeddings(self, value):
1369 self.embed_tokens = value
1370
1371 @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING)
1372 def forward(
1373 self,
1374 input_ids: torch.LongTensor = None,
1375 attention_mask: Optional[torch.Tensor] = None,
1376 position_ids: Optional[torch.LongTensor] = None,
1377 past_key_values: Optional[List[torch.FloatTensor]] = None,
1378 inputs_embeds: Optional[torch.FloatTensor] = None,
1379 pixel_values: Optional[torch.FloatTensor] = None,
1380 image_sizes: Optional[torch.LongTensor] = None,
1381 use_cache: Optional[bool] = None,
1382 output_attentions: Optional[bool] = None,
1383 output_hidden_states: Optional[bool] = None,
1384 return_dict: Optional[bool] = None,
1385 ) -> Union[Tuple, BaseModelOutputWithPast]:
1386 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1387 output_hidden_states = (
1388 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1389 )
1390 use_cache = use_cache if use_cache is not None else self.config.use_cache
1391
1392 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1393
1394 # retrieve input_ids and inputs_embeds
1395 if input_ids is not None and inputs_embeds is not None:
1396 raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1397 elif input_ids is not None:
1398 batch_size, seq_length = input_ids.shape[:2]
1399 elif inputs_embeds is not None:
1400 batch_size, seq_length = inputs_embeds.shape[:2]
1401 else:
1402 raise ValueError("You have to specify either input_ids or inputs_embeds")
1403
1404 past_key_values_length = 0
1405
1406 if self.gradient_checkpointing and self.training:
1407 if use_cache:
1408 logger.warning_once(
1409 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1410 )
1411 use_cache = False
1412
1413 if use_cache:
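            # Both cache formats documented above are accepted: a `Cache` instance is used as-is, while the legacy
            # tuple-of-tuples format is wrapped in a `DynamicCache` before its usable length is queried.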
1414 use_legacy_cache = not isinstance(past_key_values, Cache)
1415 if use_legacy_cache:
1416 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1417 past_key_values_length = past_key_values.get_usable_length(seq_length)
1418
1419 if position_ids is None:
1420 device = input_ids.device if input_ids is not None else inputs_embeds.device
1421 position_ids = torch.arange(
1422 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1423 )
1424 position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
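            # e.g. with past_key_values_length=10 and seq_length=3 this yields position_ids [[10, 11, 12]]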
1425 else:
1426 position_ids = position_ids.view(-1, seq_length).long()
1427
1428 if inputs_embeds is None:
1429 if pixel_values is not None and image_sizes is not None:
1430 assert self.vision_embed_tokens is not None, "Vision embedding layer is not defined"
1431 inputs_embeds = self.vision_embed_tokens(input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
1432 else:
1433 inputs_embeds = self.embed_tokens(input_ids)
1434
1435 if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1436 is_padding_right = attention_mask[:, -1].sum().item() != batch_size
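            # With left padding the last mask column is all ones, so it sums to batch_size;
            # a smaller sum means at least one sequence is padded on the right.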
1437 if is_padding_right:
1438 raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right', which may lead"
                    " to unexpected behaviour with the Flash Attention version of Phi-3-V. Make sure to call"
                    " `tokenizer.padding_side = 'left'` before tokenizing the input."
1442 )
1443
1444 if self._attn_implementation == "flash_attention_2":
1445 # 2d mask is passed through the layers
1446 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1447 else:
1448 # 4d mask is passed through the layers
1449 attention_mask = _prepare_4d_causal_attention_mask(
1450 attention_mask,
1451 (batch_size, seq_length),
1452 inputs_embeds,
1453 past_key_values_length,
1454 sliding_window=self.config.sliding_window,
1455 )
1456
1457 hidden_states = inputs_embeds
1458
1459 # decoder layers
1460 all_hidden_states = () if output_hidden_states else None
1461 all_self_attns = () if output_attentions else None
1462 next_decoder_cache = None
1463
1464 for decoder_layer in self.layers:
1465 if output_hidden_states:
1466 all_hidden_states += (hidden_states,)
1467
1468 if self.gradient_checkpointing and self.training:
1469 layer_outputs = self._gradient_checkpointing_func(
1470 decoder_layer.__call__,
1471 hidden_states,
1472 attention_mask,
1473 position_ids,
1474 past_key_values,
1475 output_attentions,
1476 use_cache,
1477 )
1478 else:
1479 layer_outputs = decoder_layer(
1480 hidden_states,
1481 attention_mask=attention_mask,
1482 position_ids=position_ids,
1483 past_key_value=past_key_values,
1484 output_attentions=output_attentions,
1485 use_cache=use_cache,
1486 )
1487
1488 hidden_states = layer_outputs[0]
1489
1490 if use_cache:
1491 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
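                # Each layer returns (hidden_states, [self_attn_weights,] present_key_value), so the cache sits at
                # index 2 when attention weights are also returned and at index 1 otherwise.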
1492
1493 if output_attentions:
1494 all_self_attns += (layer_outputs[1],)
1495
1496 hidden_states = self.norm(hidden_states)
1497
1498 # add hidden states from the last decoder layer
1499 if output_hidden_states:
1500 all_hidden_states += (hidden_states,)
1501
1502 next_cache = None
1503 if use_cache:
1504 next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1505 if not return_dict:
1506 return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1507 return BaseModelOutputWithPast(
1508 last_hidden_state=hidden_states,
1509 past_key_values=next_cache,
1510 hidden_states=all_hidden_states,
1511 attentions=all_self_attns,
1512 )
1513
1514
1515 class Phi3VForCausalLM(Phi3VPreTrainedModel):
1516 _tied_weights_keys = ["lm_head.weight"]
1517
1518 # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
1519 def __init__(self, config):
1520 super().__init__(config)
1521 self.model = Phi3VModel(config)
1522 self.vocab_size = config.vocab_size
1523 self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1524
1525 # Initialize weights and apply final processing
1526 self.post_init()
1527
1528 # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1529 def get_input_embeddings(self):
1530 return self.model.embed_tokens
1531
1532 # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1533 def set_input_embeddings(self, value):
1534 self.model.embed_tokens = value
1535
1536 # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1537 def get_output_embeddings(self):
1538 return self.lm_head
1539
1540 # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1541 def set_output_embeddings(self, new_embeddings):
1542 self.lm_head = new_embeddings
1543
1544 # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1545 def set_decoder(self, decoder):
1546 self.model = decoder
1547
1548 # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1549 def get_decoder(self):
1550 return self.model
1551
1552 # Ignore copy
1553 @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING)
1554 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1555 def forward(
1556 self,
1557 input_ids: torch.LongTensor = None,
1558 attention_mask: Optional[torch.Tensor] = None,
1559 position_ids: Optional[torch.LongTensor] = None,
1560 past_key_values: Optional[List[torch.FloatTensor]] = None,
1561 inputs_embeds: Optional[torch.FloatTensor] = None,
1562 pixel_values: Optional[torch.FloatTensor] = None,
1563 image_sizes: Optional[torch.LongTensor] = None,
1564 labels: Optional[torch.LongTensor] = None,
1565 use_cache: Optional[bool] = None,
1566 output_attentions: Optional[bool] = None,
1567 output_hidden_states: Optional[bool] = None,
1568 return_dict: Optional[bool] = None,
1569 ) -> Union[Tuple, CausalLMOutputWithPast]:
1570 r"""
1571 Args:
1572 labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1573 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1574 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1575 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1576
1577 Returns:
1578
1579 Example:
1580
1581 ```python
1582 >>> from transformers import AutoTokenizer, Phi3ForCausalLM
1583
1584 >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1585 >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1586
1587 >>> prompt = "This is an example script ."
1588 >>> inputs = tokenizer(prompt, return_tensors="pt")
1589
1590 >>> # Generate
1591 >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1592 >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1593 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
1594 ```"""
1595
1596 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1597 output_hidden_states = (
1598 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1599 )
1600 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1601
1602 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1603 outputs = self.model(
1604 input_ids=input_ids,
1605 attention_mask=attention_mask,
1606 position_ids=position_ids,
1607 past_key_values=past_key_values,
1608 inputs_embeds=inputs_embeds,
1609 pixel_values=pixel_values,
1610 image_sizes=image_sizes,
1611 use_cache=use_cache,
1612 output_attentions=output_attentions,
1613 output_hidden_states=output_hidden_states,
1614 return_dict=return_dict,
1615 )
1616
1617 hidden_states = outputs[0]
1618 logits = self.lm_head(hidden_states)
1619 logits = logits.float()
1620
1621 loss = None
1622 if labels is not None:
1623 # Shift so that tokens < n predict n
1624 shift_logits = logits[..., :-1, :].contiguous()
1625 shift_labels = labels[..., 1:].contiguous()
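            # e.g. for labels [t0, t1, t2, t3]: the logits at position 0 are scored against t1, position 1 against t2,
            # position 2 against t3, and the prediction at the final position is dropped.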
1626 # Flatten the tokens
1627 loss_fct = CrossEntropyLoss()
1628 shift_logits = shift_logits.view(-1, self.config.vocab_size)
1629 shift_labels = shift_labels.view(-1)
1630 # Enable model parallelism
1631 shift_labels = shift_labels.to(shift_logits.device)
1632 loss = loss_fct(shift_logits, shift_labels)
1633
1634 if not return_dict:
1635 output = (logits,) + outputs[1:]
1636 return (loss,) + output if loss is not None else output
1637
1638 return CausalLMOutputWithPast(
1639 loss=loss,
1640 logits=logits,
1641 past_key_values=outputs.past_key_values,
1642 hidden_states=outputs.hidden_states,
1643 attentions=outputs.attentions,
1644 )
1645
1646 # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
1647 def prepare_inputs_for_generation(
1648 self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, pixel_values=None, image_sizes=None, **kwargs
1649 ):
        # The first time the input length crosses the long/short RoPE factor switching point, force the cache to be recomputed.
        # This makes that single decoding step slower, but it is better than failing outright.
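        # (e.g. with original_max_position_embeddings=4096, the first call whose total length reaches 4097 drops the
        # cache so that it is rebuilt with the long-context RoPE scaling factors.)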
1652 if past_key_values and self.config.rope_scaling and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1:
1653 past_length = past_key_values.seen_tokens if isinstance(past_key_values, Cache) else past_key_values[0][0].shape[2]
1654 if past_length <= self.config.original_max_position_embeddings:
1655 past_key_values = None
1656
1657 if past_key_values is not None:
1658 if isinstance(past_key_values, Cache):
1659 cache_length = past_key_values.get_seq_length()
1660 past_length = past_key_values.seen_tokens
1661 max_cache_length = past_key_values.get_max_length()
1662 else:
1663 cache_length = past_length = past_key_values[0][0].shape[2]
1664 max_cache_length = None
1665
1666 # Keep only the unprocessed tokens:
1667 # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1668 # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1669 # input)
1670 if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1671 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1672 # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1673 # input_ids based on the past_length.
1674 elif past_length < input_ids.shape[1]:
1675 input_ids = input_ids[:, past_length:]
1676 # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1677
1678 # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1679 if (
1680 max_cache_length is not None
1681 and attention_mask is not None
1682 and cache_length + input_ids.shape[1] > max_cache_length
1683 ):
1684 attention_mask = attention_mask[:, -max_cache_length:]
1685
1686 position_ids = kwargs.get("position_ids", None)
1687 if attention_mask is not None and position_ids is None:
1688 # create position_ids on the fly for batch generation
1689 position_ids = attention_mask.long().cumsum(-1) - 1
1690 position_ids.masked_fill_(attention_mask == 0, 1)
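            # e.g. attention_mask [[0, 1, 1, 1]] -> cumsum - 1 gives [[-1, 0, 1, 2]] -> padding slots filled with 1 -> [[1, 0, 1, 2]]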
1691 if past_key_values:
1692 position_ids = position_ids[:, -input_ids.shape[1] :]
1693
1694 # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1695 if inputs_embeds is not None and past_key_values is None:
1696 model_inputs = {"inputs_embeds": inputs_embeds}
1697 else:
1698 model_inputs = {"input_ids": input_ids}
1699
1700 model_inputs.update(
1701 {
1702 "position_ids": position_ids,
1703 "past_key_values": past_key_values,
1704 "use_cache": kwargs.get("use_cache"),
1705 "attention_mask": attention_mask,
1706 "pixel_values": pixel_values,
1707 "image_sizes": image_sizes,
1708 }
1709 )
1710 return model_inputs
1711
1712 @staticmethod
1713 # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1714 def _reorder_cache(past_key_values, beam_idx):
1715 reordered_past = ()
1716 for layer_past in past_key_values:
1717 reordered_past += (
1718 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1719 )
1720 return reordered_past
1721
1722
1723 @add_start_docstrings(
1724 """
1725 The [`Phi3VModel`] with a sequence classification head on top (linear layer).
1726
1727 [`Phi3VForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1728 (e.g. GPT-2) do.
1729
    Since it does classification on the last token, it needs to know the position of the last token. If a
1731 `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1732 no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1733 padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1734 each row of the batch).
1735 """,
1736 PHI3V_START_DOCSTRING,
1737 )
1738 # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
1739 class Phi3VForSequenceClassification(Phi3VPreTrainedModel):
1740 def __init__(self, config):
1741 super().__init__(config)
1742 self.num_labels = config.num_labels
1743 self.model = Phi3VModel(config)
1744 self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1745
1746 # Initialize weights and apply final processing
1747 self.post_init()
1748
1749 def get_input_embeddings(self):
1750 return self.model.embed_tokens
1751
1752 def set_input_embeddings(self, value):
1753 self.model.embed_tokens = value
1754
1755 @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING)
1756 def forward(
1757 self,
1758 input_ids: torch.LongTensor = None,
1759 attention_mask: Optional[torch.Tensor] = None,
1760 position_ids: Optional[torch.LongTensor] = None,
1761 past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1762 inputs_embeds: Optional[torch.FloatTensor] = None,
1763 pixel_values: Optional[torch.FloatTensor] = None,
1764 image_sizes: Optional[torch.LongTensor] = None,
1765 labels: Optional[torch.LongTensor] = None,
1766 use_cache: Optional[bool] = None,
1767 output_attentions: Optional[bool] = None,
1768 output_hidden_states: Optional[bool] = None,
1769 return_dict: Optional[bool] = None,
1770 ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1771 r"""
1772 labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1773 Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
1775 `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1776 """
1777 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1778
1779 model_outputs = self.model(
1780 input_ids,
1781 attention_mask=attention_mask,
1782 position_ids=position_ids,
1783 past_key_values=past_key_values,
1784 inputs_embeds=inputs_embeds,
1785 pixel_values=pixel_values,
1786 image_sizes=image_sizes,
1787 use_cache=use_cache,
1788 output_attentions=output_attentions,
1789 output_hidden_states=output_hidden_states,
1790 return_dict=return_dict,
1791 )
1792 hidden_states = model_outputs[0]
1793 logits = self.score(hidden_states)
1794
1795 if input_ids is not None:
1796 batch_size = input_ids.shape[0]
1797 else:
1798 batch_size = inputs_embeds.shape[0]
1799
1800 if self.config.pad_token_id is None and batch_size != 1:
1801 raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1802 if self.config.pad_token_id is None:
1803 sequence_lengths = -1
1804 else:
1805 if input_ids is not None:
1806 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1807 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1808 sequence_lengths = sequence_lengths % input_ids.shape[-1]
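                # e.g. input_ids [[5, 6, PAD, PAD]] -> first PAD at index 2 -> last real token at index 1;
                # when a row contains no PAD, argmax returns 0 and (0 - 1) % seq_len selects the final position.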
1809 sequence_lengths = sequence_lengths.to(logits.device)
1810 else:
1811 sequence_lengths = -1
1812
1813 pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1814
1815 loss = None
1816 if labels is not None:
1817 labels = labels.to(logits.device)
1818 if self.config.problem_type is None:
1819 if self.num_labels == 1:
1820 self.config.problem_type = "regression"
1821 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1822 self.config.problem_type = "single_label_classification"
1823 else:
1824 self.config.problem_type = "multi_label_classification"
1825
1826 if self.config.problem_type == "regression":
1827 loss_fct = MSELoss()
1828 if self.num_labels == 1:
1829 loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1830 else:
1831 loss = loss_fct(pooled_logits, labels)
1832 elif self.config.problem_type == "single_label_classification":
1833 loss_fct = CrossEntropyLoss()
1834 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1835 elif self.config.problem_type == "multi_label_classification":
1836 loss_fct = BCEWithLogitsLoss()
1837 loss = loss_fct(pooled_logits, labels)
1838 if not return_dict:
1839 output = (pooled_logits,) + model_outputs[1:]
1840 return ((loss,) + output) if loss is not None else output
1841
1842 return SequenceClassifierOutputWithPast(
1843 loss=loss,
1844 logits=pooled_logits,
1845 past_key_values=model_outputs.past_key_values,
1846 hidden_states=model_outputs.hidden_states,
1847 attentions=model_outputs.attentions,
1848 )
1849
1850
1851 @add_start_docstrings(
1852 """
1853 [`Phi3VModel`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1854 Named-Entity-Recognition (NER) tasks.
1855 """,
1856 PHI3V_START_DOCSTRING,
1857 )
1858 # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
1859 class Phi3VForTokenClassification(Phi3VPreTrainedModel):
1860 def __init__(self, config: Phi3VConfig):
1861 super().__init__(config)
1862 self.num_labels = config.num_labels
1863
1864 self.model = Phi3VModel(config)
1865 if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1866 classifier_dropout = config.classifier_dropout
1867 elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1868 classifier_dropout = config.hidden_dropout
1869 else:
1870 classifier_dropout = 0.1
1871 self.dropout = nn.Dropout(classifier_dropout)
1872 self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1873
1874 # Initialize weights and apply final processing
1875 self.post_init()
1876
1877 @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING)
1878 @add_code_sample_docstrings(
1879 checkpoint=_CHECKPOINT_FOR_DOC,
1880 output_type=TokenClassifierOutput,
1881 config_class=_CONFIG_FOR_DOC,
1882 )
1883 def forward(
1884 self,
1885 input_ids: Optional[torch.LongTensor] = None,
1886 past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1887 attention_mask: Optional[torch.Tensor] = None,
1888 inputs_embeds: Optional[torch.Tensor] = None,
1889 pixel_values: Optional[torch.FloatTensor] = None,
1890 image_sizes: Optional[torch.LongTensor] = None,
1891 labels: Optional[torch.Tensor] = None,
1892 use_cache: Optional[bool] = None,
1893 output_attentions: Optional[bool] = None,
1894 output_hidden_states: Optional[bool] = None,
1895 return_dict: Optional[bool] = None,
1896 **deprecated_arguments,
1897 ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1898 r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
1903 """
1904 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1905
1906 model_outputs = self.model(
1907 input_ids,
1908 past_key_values=past_key_values,
1909 attention_mask=attention_mask,
1910 inputs_embeds=inputs_embeds,
1911 pixel_values=pixel_values,
1912 image_sizes=image_sizes,
1913 use_cache=use_cache,
1914 output_attentions=output_attentions,
1915 output_hidden_states=output_hidden_states,
1916 return_dict=return_dict,
1917 )
1918
1919 hidden_states = model_outputs[0]
1920 hidden_states = self.dropout(hidden_states)
1921 logits = self.classifier(hidden_states)
1922
1923 loss = None
1924 if labels is not None:
1925 # move labels to correct device to enable model parallelism
1926 labels = labels.to(logits.device)
1927 batch_size, seq_length = labels.shape
1928 loss_fct = CrossEntropyLoss()
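            # Flatten to (batch_size * seq_length, num_labels) logits vs (batch_size * seq_length,) labels so that
            # CrossEntropyLoss scores every token position independently.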
1929 loss = loss_fct(
1930 logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1931 )
1932
1933 if not return_dict:
1934 output = (logits,) + model_outputs[2:]
1935 return ((loss,) + output) if loss is not None else output
1936
1937 return TokenClassifierOutput(
1938 loss=loss,
1939 logits=logits,
1940 hidden_states=model_outputs.hidden_states,
1941 attentions=model_outputs.attentions,
1942 )
1943