# modeling_openelm.py
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#

from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)

# this import has to be relative, otherwise, when setting trust_remote_code=True
# huggingface transformers won't be able to load the module correctly
from .configuration_openelm import OpenELMConfig, make_divisible


class OpenELMRMSNorm(nn.Module):
    def __init__(self, num_features: int, eps: float = 1e-6):
        """
        Initialize the OpenELMRMSNorm normalization layer.

        Args:
            num_features (int): The number of features in the last dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.num_features = num_features

    def _norm(self, x: Tensor) -> Tensor:
        """
        Apply the OpenELMRMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the OpenELMRMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying OpenELMRMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    def extra_repr(self) -> str:
        return (
            super().extra_repr() + f"num_features={self.num_features}, eps={self.eps}"
        )
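    # Note (added for clarity): OpenELMRMSNorm implements RMSNorm (Zhang & Sennrich, 2019):
    # the input is scaled by 1/sqrt(mean(x^2) + eps) along the last dimension and then
    # multiplied by a learnable per-feature weight; unlike LayerNorm there is no mean
    # subtraction and no bias. A minimal illustrative usage (hypothetical shapes, not part
    # of the original file):
    #
    #   norm = OpenELMRMSNorm(num_features=64)
    #   y = norm(torch.randn(2, 10, 64))   # y has the same shape, [2, 10, 64]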


class OpenELMPreTrainedModel(PreTrainedModel):
    config_class = OpenELMConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OpenELMDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs) -> None:
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, OpenELMRMSNorm):
            module.weight.data.fill_(1.0)


def _rotate_half(x: Tensor) -> Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def _apply_rotary_pos_emb(x: Tensor, pos_sin: Tensor, pos_cos: Tensor) -> Tensor:
    return (x * pos_cos) + (_rotate_half(x) * pos_sin)
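# Note (added for clarity): `_apply_rotary_pos_emb` is the standard "rotate-half"
# formulation of RoPE. For a feature vector split into halves (x1, x2), it computes
# x * cos + rotate_half(x) * sin, which is equivalent to rotating each paired feature
# (x1_i, x2_i) by the position-dependent angle encoded in `pos_sin`/`pos_cos`. The
# sin/cos tensors are expected to broadcast against x; here they are cached with shape
# [1, 1, seq_len, head_dim] by OpenELMRotaryEmbedding below.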


class OpenELMRotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings (aka RoPE) from `RoFormer <https://arxiv.org/abs/2104.09864>`_.

    RoPE encodes the position information of tokens using a rotation matrix, and is able to capture
    explicit relative positional dependencies.

    Args:
        model_dim: The dimensionality of the model's hidden state.
        max_seq_length: Maximum sequence length.
        freq_constant: A constant used for computing frequencies.
    """

    def __init__(
        self, model_dim: int, max_seq_length: int, freq_constant: int = 10000
    ) -> None:
        inv_freq = 1.0 / (
            freq_constant
            ** (torch.arange(0, model_dim, 2, dtype=torch.float32) / model_dim)
        )
        super().__init__()

        self.model_dim = model_dim
        self.freq_constant = freq_constant
        self.max_seq_length = max_seq_length

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._cached_cos = None
        self._cached_sin = None
        self._cached_seq_length = max_seq_length
        self._compute_sin_cos_embeddings(max_seq_length)

    def extra_repr(self) -> str:
        return f"\tmodel_dim={self.model_dim}, max_seq_length={self.max_seq_length}, freq_constant={self.freq_constant}"

    def _compute_sin_cos_embeddings(
        self,
        key_len: int,
        key_device: torch.device = torch.device("cpu"),
        key_dtype: torch.dtype = torch.float32,
    ) -> None:
        """
        Compute the sine and cosine embeddings.

        Args:
            key_len: Number of tokens in the key embeddings in the transformer model.
            key_device: Device where the key embeddings are stored.
            key_dtype: Data type of the key embeddings.

        Returns:
            None

        .. note::
            We recalculate the sine and cosine embeddings if any of the following conditions are met:
            1. The number of tokens in the key embeddings is greater than the cached sequence length.
            2. The sine and cosine caches are empty.
            3. The device or data type of the cached sine and cosine embeddings does not match that of the key embeddings.
        """
        if (
            key_len > self._cached_seq_length
            or self._cached_cos is None
            or (self._cached_cos is not None and self._cached_cos.device != key_device)
            or (self._cached_cos is not None and self._cached_cos.dtype != key_dtype)
            or self._cached_sin is None
            or (self._cached_sin is not None and self._cached_sin.device != key_device)
            or (self._cached_sin is not None and self._cached_sin.dtype != key_dtype)
        ):
            self._cached_seq_length = max(key_len, self._cached_seq_length)

            # The shape of 'pos_index' is [number of key tokens]
            pos_index = torch.arange(
                self._cached_seq_length,
                dtype=torch.float32,
                device=self.inv_freq.device,
            )
            # The shape of 'pos_index_theta' is [number of key tokens, model dimension // 2]
            pos_index_theta = torch.einsum("i,j->ij", pos_index, self.inv_freq)
            # The shape of 'emb' is [number of key tokens, model dimension]
            emb = torch.cat((pos_index_theta, pos_index_theta), dim=-1)

            # the shape of cos and sin embeddings is [number of key tokens, model_dim]
            cos_emb = emb.cos().to(dtype=key_dtype, device=key_device)
            sin_emb = emb.sin().to(dtype=key_dtype, device=key_device)

            # the shape of cached cos and sin embeddings is [1, 1, number of key tokens, model_dim]
            self._cached_cos = cos_emb[None, None, :, :]
            self._cached_sin = sin_emb[None, None, :, :]
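        # Note (added for clarity): the cached tensors keep two leading singleton
        # dimensions so they broadcast directly against query/key tensors of shape
        # [batch, num_heads, seq_len, head_dim] inside forward(); the cache only grows
        # (it is resized to max(key_len, cached length)) and is rebuilt whenever the
        # requested device or dtype changes.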

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        The forward function of RoPE embeddings.

        Args:
            query: Query embeddings in the transformer model. The shape of query embeddings is
                [Batch, number of query heads, number of query tokens, model dimension].
            key: Key embeddings in the transformer model. The shape of key embeddings is
                [Batch, number of key heads, number of key tokens, model dimension].

        Returns:
            A tuple containing the query and key embeddings with positional information. The shape of the returned query
            and key embeddings is the same as the input query and key embeddings respectively.

        .. note::
            The RoPE embedding computation is done in full precision. After the computation, the input query and key tensors
            are cast back to their original input datatype.
        """
        dim = key.shape[-1]
        key_len = key.shape[2]
        query_len = query.shape[2]

        assert dim == self.model_dim
        assert key.device == query.device
        assert key.dtype == query.dtype

        # In the context of self-attention, the lengths of keys and queries are equal.
        # However, in generation tasks, such as predicting the next token in a sequence, the lengths of keys and queries
        # can differ. For instance, when employing key-value (KV) caching for sequence prediction, the keys
        # represent embeddings of previous tokens and the current token, while the query corresponds
        # to the embedding of the current token only.
        assert (
            key_len >= query_len
        ), "Number of keys has to be greater than or equal to number of queries."

        query_float = query.float()
        key_float = key.float()

        self._compute_sin_cos_embeddings(
            key_len, key_device=key_float.device, key_dtype=key_float.dtype
        )
        query_float = _apply_rotary_pos_emb(
            x=query_float,
            pos_sin=self._cached_sin[..., key_len - query_len : key_len, :],
            pos_cos=self._cached_cos[..., key_len - query_len : key_len, :],
        )
        key_float = _apply_rotary_pos_emb(
            x=key_float,
            pos_sin=self._cached_sin[..., :key_len, :],
            pos_cos=self._cached_cos[..., :key_len, :],
        )

        return query_float.type_as(query), key_float.type_as(key)
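    # Note (added for clarity): queries are rotated with the *last* `query_len`
    # positions of the cache ([key_len - query_len : key_len]), so during incremental
    # decoding the single new query token receives the same position index as its
    # corresponding (last) key. A minimal illustrative call (hypothetical shapes, not
    # part of the original file):
    #
    #   rope = OpenELMRotaryEmbedding(model_dim=64, max_seq_length=2048)
    #   q = torch.randn(1, 12, 1, 64)    # one new query token
    #   k = torch.randn(1, 3, 10, 64)    # 10 cached + current key tokens
    #   q_rot, k_rot = rope(q, k)        # shapes are preserved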


class OpenELMMultiHeadCausalAttention(nn.Module):
    def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        head_dim = config.head_dim
        q_heads = config.num_query_heads[layer_idx]
        k_heads = config.num_kv_heads[layer_idx]
        v_heads = config.num_kv_heads[layer_idx]

        self.qkv_proj = nn.Linear(
            in_features=config.model_dim,
            out_features=(q_heads + k_heads + v_heads) * head_dim,
            bias=False,
        )

        self.pos_embedding = OpenELMRotaryEmbedding(
            model_dim=config.head_dim,
            max_seq_length=config.rope_max_length,
            freq_constant=config.rope_freq_constant,
        )

        if config.normalize_qk_projections:
            self.q_norm = OpenELMRMSNorm(
                num_features=config.head_dim,
            )
            self.k_norm = OpenELMRMSNorm(
                num_features=config.head_dim,
            )
        else:
            self.q_norm = None
            self.k_norm = None

        self.out_proj = nn.Linear(
            in_features=q_heads * head_dim,
            out_features=config.model_dim,
            bias=False,
        )

        self.head_dim = config.head_dim
        self.num_q_heads = q_heads
        self.num_k_heads = k_heads
        self.num_v_heads = v_heads
        self.transformer_dim = config.model_dim
        self.num_groups = self.num_q_heads // self.num_k_heads

    def extra_repr(self) -> str:
        return (
            super().extra_repr()
            + f"query_heads={self.num_q_heads}, key_heads={self.num_k_heads}, value_heads={self.num_v_heads}"
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Forward pass of multi-head self-attention.

        Args:
            hidden_states: Input tensor of the shape [batch size, sequence length, model dimension].
            past_key_value: Tensor storing the cached keys and values.
            output_attentions: Whether to return the attention weights.
            use_cache: Specifies whether to use the kv-cache for generation.
            cache_position: Positions used for updating the kv-cache.

        Returns:
            The output of the same shape as the input, optionally with a tensor containing cached keys and values.
        """

        # scaled_dot_product_attention does not return attention weights, so set output_attentions to False
        output_attentions = False
        batch_size, seq_length, d_model = hidden_states.size()

        # [B, S, d] --> [B, S, (q_h + k_h + v_h) * h]
        qkv = self.qkv_proj(hidden_states)
        # [B, S, (q_h + k_h + v_h) * h] --> [B, S, (q_h + k_h + v_h), h]
        qkv = qkv.reshape(
            batch_size,
            seq_length,
            self.num_q_heads + self.num_k_heads + self.num_v_heads,
            self.head_dim,
        )
        # [B, S, (q_h + k_h + v_h), h] --> [B, (q_h + k_h + v_h), S, h]
        qkv = qkv.transpose(1, 2)
        # [B, (q_h + k_h + v_h), S, h] --> [B, q_h, S, h], [B, k_h, S, h], [B, v_h, S, h]
        queries, keys, values = qkv.split(
            [self.num_q_heads, self.num_k_heads, self.num_v_heads], dim=1
        )

        if self.q_norm is not None:
            queries = self.q_norm(queries)

        if self.k_norm is not None:
            keys = self.k_norm(keys)

        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; position_ids needed for the static cache
            # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            cache_kwargs = {"cache_position": cache_position}
            keys, values = past_key_value.update(
                keys, values, self.layer_idx, cache_kwargs
            )

        # Add positional embedding
        queries, keys = self.pos_embedding(queries, keys)

        if self.num_groups != 1:
            # GQA
            # [B, k_h, S, h] --> [B, q_h, S, h]
            keys = keys.repeat_interleave(self.num_groups, dim=1)
            # [B, v_h, S, h] --> [B, q_h, S, h]
            values = values.repeat_interleave(self.num_groups, dim=1)

        causal_mask = attention_mask
        if attention_mask is not None and cache_position is not None:
            causal_mask = causal_mask[:, :, cache_position, : keys.shape[-2]]

        attn_output = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=causal_mask,
            dropout_p=0,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(
            batch_size, seq_length, self.num_q_heads * self.head_dim
        )
        attn_output = self.out_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
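    # Note (added for clarity): OpenELM uses grouped-query attention (GQA) in layers
    # where num_query_heads > num_kv_heads. Each key/value head is shared by
    # `num_groups = num_q_heads // num_k_heads` query heads, so keys and values are
    # expanded with repeat_interleave before F.scaled_dot_product_attention, which
    # expects matching head counts. For example (illustrative numbers only), a layer
    # with 12 query heads and 3 KV heads has num_groups == 4, and each KV head serves
    # 4 query heads.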


class OpenELMFeedForwardNetwork(nn.Module):
    def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
        super().__init__()
        ffn_multiplier = config.ffn_multipliers[layer_idx]
        intermediate_dim = int(
            make_divisible(
                ffn_multiplier * config.model_dim,
                divisor=config.ffn_dim_divisor,
            )
        )
        if config.ffn_with_glu:
            # FFN with Gated linear unit, as described in https://arxiv.org/abs/2002.05202v1.
            self.proj_1 = nn.Linear(
                in_features=config.model_dim,
                out_features=2 * intermediate_dim,
                bias=False,
            )
            self.proj_2 = nn.Linear(
                in_features=intermediate_dim,
                out_features=config.model_dim,
                bias=False,
            )
            self.ffn_with_glu = True
        else:
            # Standard FFN, as described in https://arxiv.org/abs/1706.03762
            self.proj_1 = nn.Linear(
                in_features=config.model_dim,
                out_features=intermediate_dim,
                bias=False,
            )
            self.proj_2 = nn.Linear(
                in_features=intermediate_dim,
                out_features=config.model_dim,
                bias=False,
            )
            self.ffn_with_glu = False

        self.act = ACT2FN[config.activation_fn_name]

    def extra_repr(self) -> str:
        return super().extra_repr() + f"(ffn_with_glu) : {self.ffn_with_glu}"

    def forward(self, x: Tensor) -> Tensor:
        """Forward function of FFN layer.

        Args:
            x: Input tensor of the shape [batch size, sequence length, model dimension].

        Returns:
            A tensor of the same shape as the input.
        """
        if self.ffn_with_glu:
            y_12 = self.proj_1(x)
            y_1, y_2 = y_12.chunk(2, dim=-1)
            y = self.act(y_1) * y_2
            return self.proj_2(y)
        else:
            return self.proj_2(self.act(self.proj_1(x)))
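    # Note (added for clarity): in the GLU branch, proj_1 produces 2 * intermediate_dim
    # features that are split into a gate (y_1) and a value (y_2); the output is
    # act(y_1) * y_2 followed by proj_2, i.e. the GLU-variant FFN of
    # https://arxiv.org/abs/2002.05202 (SwiGLU when the activation is SiLU). The
    # per-layer `ffn_multipliers` let OpenELM widen the FFN non-uniformly across depth,
    # with `make_divisible` rounding intermediate_dim to a multiple of ffn_dim_divisor.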


class OpenELMDecoderLayer(nn.Module):
    def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
        super().__init__()
        self.attn = OpenELMMultiHeadCausalAttention(config=config, layer_idx=layer_idx)
        self.ffn = OpenELMFeedForwardNetwork(config=config, layer_idx=layer_idx)
        self.ffn_norm = OpenELMRMSNorm(
            num_features=config.model_dim,
        )
        self.attn_norm = OpenELMRMSNorm(
            num_features=config.model_dim,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.ffn_norm(hidden_states)
        hidden_states = self.ffn(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
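    # Note (added for clarity): this is a standard pre-norm transformer block:
    #   x = x + Attention(RMSNorm(x))
    #   x = x + FFN(RMSNorm(x))
    # Normalization is applied to the sub-layer *input* and the residual carries the
    # un-normalized stream, which is why OpenELMModel.forward applies one final RMSNorm
    # after the last decoder layer.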


class OpenELMModel(OpenELMPreTrainedModel):
    config_class = OpenELMConfig

    def __init__(self, config: OpenELMConfig):
        super().__init__(config)
        self.config = config

        self.token_embeddings = nn.Embedding(
            embedding_dim=config.model_dim,
            num_embeddings=config.vocab_size,
        )

        self.layers = nn.ModuleList(
            OpenELMDecoderLayer(config=config, layer_idx=layer_idx)
            for layer_idx in range(config.num_transformer_layers)
        )
        self.norm = OpenELMRMSNorm(num_features=config.model_dim)
        if config.share_input_output_layers:
            self.classifier = None
        else:
            self.classifier = nn.Linear(
                in_features=config.model_dim,
                out_features=config.vocab_size,
                bias=False,
            )
        self.num_transformer_layers = config.num_transformer_layers
        self.gradient_checkpointing = False

        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_context_length`.
        causal_mask = torch.full(
            (config.max_context_length, config.max_context_length),
            fill_value=True,
            dtype=torch.bool,
        )
        self.register_buffer(
            "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
        )

        # Initialize weights and apply final processing
        self.post_init()
        self.reset_parameters(config=config)

    def get_input_embeddings(self):
        return self.token_embeddings

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.token_embeddings = new_embeddings

    def reset_parameters(self, config: OpenELMConfig) -> None:
        """Initialize the layers in the language model.

        The initialization scheme follows `OPT <https://arxiv.org/pdf/2205.01068.pdf>`_.

        Args:
            config: Model configuration; used to derive the depth-scaled standard deviation
                applied to the output projections.

        Returns:
            None
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                std = module.in_features**-0.5
                torch.nn.init.normal_(module.weight, mean=0.0, std=std)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                std = module.embedding_dim**-0.5
                torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            elif isinstance(module, OpenELMRMSNorm):
                if module.weight is not None:
                    torch.nn.init.ones_(module.weight)
                if hasattr(module, "bias") and module.bias is not None:
                    torch.nn.init.zeros_(module.bias)

        model_dim = config.model_dim
        n_layers = config.num_transformer_layers
        std = (model_dim**-0.5) * ((2 * n_layers) ** -0.5)
        for param_name, param in self.named_parameters():
            if param_name.endswith("out_proj.weight") or param_name.endswith(
                "ffn.proj_2.weight"
            ):
                torch.nn.init.normal_(param, mean=0.0, std=std)
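        # Note (added for clarity): the second pass re-initializes only the output
        # projections of the residual branches (attention out_proj and FFN proj_2) with
        # std = 1 / sqrt(model_dim * 2 * n_layers), a depth-scaled variant of the
        # 1/sqrt(d) init that keeps the residual-stream variance roughly constant as
        # layers are stacked. For example (illustrative numbers only), model_dim=1280
        # and n_layers=20 would give std = (1280 ** -0.5) * (40 ** -0.5) ≈ 0.0044.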

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.token_embeddings(input_ids)

        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()

        if cache_position is None:
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if isinstance(next_decoder_cache, Cache)
                else next_decoder_cache
            )
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
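    # Note (added for clarity): when `use_cache=True`, a legacy tuple-style
    # `past_key_values` is converted to a `DynamicCache` before the decoder loop and
    # converted back with `to_legacy_cache()` on the way out, so callers can keep
    # passing either format. `cache_position` defaults to
    # `arange(past_seen_tokens, past_seen_tokens + new_tokens)`; it is forwarded to the
    # attention layers, where it selects the rows of the registered causal-mask buffer
    # and tells the Cache object where to write the new key/value entries.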

    def _update_causal_mask(self, attention_mask, input_tensor):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        batch_size, seq_length = input_tensor.shape[:2]
        dtype = input_tensor.dtype
        device = input_tensor.device

        # support going beyond cached `max_position_embedding`
        if seq_length > self.causal_mask.shape[-1]:
            causal_mask = torch.full(
                (2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]),
                fill_value=1,
            )
            self.register_buffer(
                "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
            )

        # We use the current dtype to avoid any overflows
        min_dtype = torch.finfo(dtype).min
        causal_mask = (
            self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype)
            * min_dtype
        )

        causal_mask = causal_mask.to(dtype=dtype, device=device)
        if attention_mask is not None and attention_mask.dim() == 2:
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
                :, None, None, :
            ].eq(0.0)
            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
                padding_mask, min_dtype
            )

        if self.config._attn_implementation == "sdpa" and attention_mask is not None:
            # For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
            is_tracing = (
                torch.jit.is_tracing()
                or isinstance(input_tensor, torch.fx.Proxy)
                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
            )
            if not is_tracing and torch.any(attention_mask != 1):
                # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
                # Details: https://github.com/pytorch/pytorch/issues/110213
                causal_mask = causal_mask.mul(
                    ~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)
                ).to(dtype)

        return causal_mask
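    # Note (added for clarity): the returned mask is an *additive* float mask of shape
    # [batch, 1, max_context_length, max_context_length]: positions that may be attended
    # to hold 0.0 and disallowed positions hold torch.finfo(dtype).min, so adding it to
    # the attention scores effectively zeroes those probabilities after softmax. A 2-D
    # padding `attention_mask` (1 = keep, 0 = pad) is folded into the same tensor before
    # it is sliced per decoding step inside the attention module.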


class OpenELMForCausalLM(OpenELMPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: OpenELMConfig):
        super().__init__(config)
        self.transformer = OpenELMModel(config)
        self.vocab_size = config.vocab_size
        if config.share_input_output_layers:
            self.lm_head = None
        else:
            self.lm_head = nn.Linear(config.model_dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.token_embeddings

    def set_input_embeddings(self, value):
        self.transformer.token_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.transformer = decoder

    def get_decoder(self):
        return self.transformer

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.lm_head is None:
            # weights are shared with the input embeddings
            logits = F.linear(
                hidden_states, weight=self.transformer.token_embeddings.weight
            )
        else:
            logits = self.lm_head(hidden_states)
        logits = logits[..., : self.config.vocab_size]
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
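    # Note (added for clarity): when `config.share_input_output_layers` is True,
    # `lm_head` is None and logits are computed with
    # F.linear(hidden_states, self.transformer.token_embeddings.weight), i.e. the input
    # embedding matrix is reused as the output projection (weight tying); otherwise a
    # separate `lm_head` Linear layer is used. When `labels` are provided, the loss is
    # the usual next-token cross-entropy over logits shifted by one position.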

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        if self.generation_config.cache_implementation == "static":
            # generation with static cache
            cache_position = kwargs.get("cache_position", None)
            if cache_position is None:
                past_length = 0
            else:
                past_length = cache_position[-1] + 1
            input_ids = input_ids[:, past_length:]
            position_ids = position_ids[:, past_length:]

        # we should only keep a `cache_position` in generate, and do +=1.
        # same goes for position ids. Could also help with continued generation.
        cache_position = torch.arange(
            past_length,
            past_length + position_ids.shape[-1],
            device=position_ids.device,
        )

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # We could use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        model_inputs.update(
            {
                "position_ids": position_ids.contiguous(),
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past
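

# Example usage (added for illustration; not part of the original file). Loading the
# model through transformers with trust_remote_code=True routes to the classes above.
# The checkpoint and tokenizer names and the generation settings below are assumptions,
# shown only as a sketch:
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # OpenELM reuses an external tokenizer
#   model = AutoModelForCausalLM.from_pretrained(
#       "apple/OpenELM-270M", trust_remote_code=True
#   )
#   inputs = tokenizer("Once upon a time", return_tensors="pt")
#   output_ids = model.generate(**inputs, max_new_tokens=32, use_cache=True)
#   print(tokenizer.decode(output_ids[0]))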