# modeling_nemotron_h.py
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch NemotronH model."""

import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

from transformers.activations import ACT2FN
from transformers.cache_utils import DynamicCache  # we need __iter__ and __len__ of pkv
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from transformers.utils.import_utils import (
    is_causal_conv1d_available,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_mamba_2_ssm_available,
)
from .configuration_nemotron_h import NemotronHConfig


logger = logging.get_logger(__name__)


# Copied from transformers.models.mamba2.modeling_mamba2 with MAMBA2->NEMOTRONH,Mamba2->NemotronH
# For Mamba2 components Mamba2->NemotronHMamba2
if is_mamba_2_ssm_available():
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
    mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None

try:
    from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn
except ImportError:
    raise ImportError("mamba-ssm is required by the NemotronH model but cannot be imported")

if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_update, causal_conv1d_fn = None, None

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward

is_fast_path_available = all(
    (
        selective_state_update,
        mamba_chunk_scan_combined,
        mamba_split_conv1d_scan_combined,
        causal_conv1d_fn,
        causal_conv1d_update,
    )
)


_CHECKPOINT_FOR_DOC = "nvidia/Nemotron-H-56B-Base-8K"
_CONFIG_FOR_DOC = "NemotronHConfig"


# Helper methods for segment sum computation


def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
    """
    Pads the input tensor with `pad_size` zeros along the seq_len dimension (dim=1).

    Assumes that we only have tensors of rank 3 or 4.
    """
    pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)

    return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)


def reshape_into_chunks(input_tensor, pad_size, chunk_size):
    """
    Pads `input_tensor` with `pad_size` on the seq_len dim (dim=1) and
    simultaneously splits it into chunk sequences.

    Assumes that we only have tensors of rank 3 or 4.
    """
    # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
    input_tensor = pad_tensor_by_size(input_tensor, pad_size)

    if len(input_tensor.shape) == 3:
        # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
        return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
    else:
        # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
        return input_tensor.reshape(
            input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
        )

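# A minimal shape sketch (illustrative numbers, not from the original file): with
# `chunk_size=4`, a rank-3 tensor of shape [bsz=2, seq_len=6, num_heads=8] is
# padded by `pad_size=2` to seq_len 8, then split into 2 chunks of 4, e.g.
#
#     x = torch.randn(2, 6, 8)
#     reshape_into_chunks(x, pad_size=2, chunk_size=4).shape
#     # torch.Size([2, 2, 4, 8])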

def segment_sum(input_tensor):
    """
    More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
    """
    chunk_size = input_tensor.size(-1)
    # 1. expand input tensor to have an additional dimension and repeat along that dimension
    # [..., chunk_size] -> [..., chunk_size, chunk_size]
    input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
    # 2. create a lower triangular mask with the diagonal set to 0, to zero out elements above the diagonal
    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
    input_tensor = input_tensor.masked_fill(~mask, 0)
    # 3. compute actual cumsum
    tensor_segsum = torch.cumsum(input_tensor, dim=-2)

    # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
    tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
    return tensor_segsum

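# A worked example (illustrative, not from the original file): entry [i, j] of
# the result holds x[j+1] + ... + x[i] for i >= j and -inf above the diagonal,
# so exp(segment_sum(A)) below acts as a causal decay matrix, e.g.
#
#     x = torch.tensor([1.0, 2.0, 3.0])
#     segment_sum(x)
#     # tensor([[0., -inf, -inf],
#     #         [2., 0., -inf],
#     #         [5., 3., 0.]])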

def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
    """
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

    return hidden_states

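# A minimal sketch (assumed shapes, not from the original file): for
# hidden_states of shape [bsz, seq_len, hidden] and a 0/1 attention_mask of
# shape [bsz, seq_len], padded positions are zeroed out, e.g.
#
#     h = torch.ones(2, 3, 4)
#     mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
#     apply_mask_to_padding_states(h, mask)[0, 2].sum()  # tensor(0.)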
# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py
class HybridMambaAttentionDynamicCache(DynamicCache):
    """
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors, with the following expected shapes.
    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
    """

    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        super().__init__()
        self.dtype = dtype
        self.hybrid_override_pattern = config.hybrid_override_pattern
        self.has_previous_state = False  # only used by mamba
        intermediate_size = config.mamba_num_heads * config.mamba_head_dim
        ssm_state_size = config.ssm_state_size
        # stored on self as well: the mamba mixer reads `cache_params.conv_kernel_size`
        conv_kernel_size = self.conv_kernel_size = config.conv_kernel
        self.conv_states = []
        self.ssm_states = []
        self.transformer_layers = []
        for i in range(config.num_hidden_layers):
            if self.hybrid_override_pattern[i] == "M":
                # Mamba layer
                self.conv_states += [
                    torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
                ]
                self.ssm_states += [
                    torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
                ]
            else:
                # Attention or MLP layer
                self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
                self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
                self.transformer_layers.append(i)

        self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
        self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Update the cache
        if self.key_cache[layer_idx].shape[-1] == 0:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

            device = self.conv_states[layer_idx].device
            self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
            device = self.ssm_states[layer_idx].device
            self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # take any layer that contains cache and not empty tensor
        layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

    # Copied from modeling_mamba2.py
    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
    ) -> torch.Tensor:
        # `conv_states` is a list of per-layer tensors, so index it before reading `.device`
        if cache_init:
            self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[layer_idx].device)
        else:
            self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
            self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device)
        return self.conv_states[layer_idx]

    def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
        return self.ssm_states[layer_idx]

    def reset(self):
        # `conv_states` and `ssm_states` are lists of per-layer tensors; zero them layer by layer
        for layer_idx in range(len(self.conv_states)):
            self.conv_states[layer_idx].zero_()
            self.ssm_states[layer_idx].zero_()

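# Hedged usage sketch (argument names taken from the signatures above; the
# surrounding model classes are outside this excerpt, so the call shape is an
# assumption): one cache is created up front and threaded through generation, e.g.
#
#     cache = HybridMambaAttentionDynamicCache(config, batch_size=2, dtype=torch.bfloat16, device="cuda")
#     out = model(input_ids, past_key_values=cache, use_cache=True)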
class MambaRMSNormGated(torch.nn.Module):
    def __init__(self, hidden_size, group_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.group_size = group_size

    def forward(self, hidden_states, gate=None):
        return rmsnorm_fn(
            x=hidden_states,
            weight=self.weight,
            bias=None,  # no bias
            z=gate,
            eps=self.variance_epsilon,
            group_size=self.group_size,
            norm_before_gate=False,
        )

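# Reference semantics (a hedged pure-PyTorch sketch of what the fused Triton
# `rmsnorm_fn` with `norm_before_gate=False` computes, assuming a single group;
# the kernel itself is the actual implementation):
#
#     def gated_rmsnorm_ref(x, weight, z, eps):
#         x = x * torch.nn.functional.silu(z)                          # gate first ...
#         x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)   # ... then RMS-normalize
#         return x * weight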
class NemotronHMamba2Mixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    """

    def __init__(self, config: NemotronHConfig, layer_idx: int):
        super().__init__()
        self.num_heads = config.mamba_num_heads
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.ssm_state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = config.mamba_num_heads * config.mamba_head_dim
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        self.activation = config.mamba_hidden_act
        self.act = ACT2FN[config.mamba_hidden_act]

        self.layer_norm_epsilon = config.layer_norm_epsilon

        self.n_groups = config.n_groups
        self.head_dim = config.mamba_head_dim
        self.chunk_size = config.chunk_size

        self.time_step_limit = config.time_step_limit
        self.time_step_min = config.time_step_min
        self.time_step_max = config.time_step_max

        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.conv_dim,
            padding=config.conv_kernel - 1,
        )

        # projection of the input hidden states
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.in_proj = nn.Linear(
            self.hidden_size,
            projection_size,
            bias=config.use_bias,
        )
        # selective projection used to make dt, B and C input dependent

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.num_heads + 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        self.norm = MambaRMSNormGated(
            self.intermediate_size, eps=self.layer_norm_epsilon, group_size=self.intermediate_size // self.n_groups
        )
        self.D = nn.Parameter(torch.ones(self.num_heads))
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
                " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

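    # Layout sketch of the `in_proj` output (illustrative numbers, not from the
    # original file): the projection splits into
    # [d_mlp, d_mlp, intermediate_size (gate), conv_dim (hidden_states, B, C), num_heads (dt)].
    # E.g. with num_heads=8, head_dim=64, n_groups=8, ssm_state_size=128:
    #
    #     intermediate_size = 8 * 64        # 512
    #     conv_dim = 512 + 2 * 8 * 128      # 2560
    #     projection_size = 512 + 2560 + 8  # 3080, so d_mlp works out to 0 here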
    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # 1. Gated MLP's linear projection
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
        projected_states = self.in_proj(hidden_states)

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size
        d_mlp = (
            projected_states.shape[-1]
            - 2 * self.intermediate_size
            - 2 * self.n_groups * self.ssm_state_size
            - self.num_heads
        ) // 2

        # Single step calculations via cache
        if cache_params is not None and cache_position is not None and cache_position[0] > 0:
            _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
                [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
            )

            # 2. Convolution sequence transformation
            hidden_states_B_C = causal_conv1d_update(
                hidden_states_B_C,
                cache_params.conv_states[self.layer_idx],
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )

            hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                dim=-1,
            )

            # 3. SSM transformation
            A = -torch.exp(self.A_log.float())  # (nheads,)
            A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
            C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
            hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
            hidden_states = selective_state_update(
                cache_params.ssm_states[self.layer_idx],
                hidden_states_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
            hidden_states = self.norm(hidden_states, gate)

            # 4. Final linear projection
            out = self.out_proj(hidden_states)[:, None, ...]

        # Fused calculations or step by step if no initialized cache is found
        else:
            A = -torch.exp(self.A_log.float())  # (num_heads) or (intermediate_size, state_size)
            dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}

            # 2-4. Fused kernel for conv1d, SSM, and the final projection
            if self.training and cache_params is None:
                out = mamba_split_conv1d_scan_combined(
                    projected_states,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.chunk_size,
                    seq_idx=None,
                    activation=self.activation,
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.variance_epsilon,
                    outproj_weight=self.out_proj.weight,
                    outproj_bias=self.out_proj.bias,
                    headdim=self.head_dim,
                    ngroups=self.n_groups,
                    norm_before_gate=False,
                    return_final_states=False,
                    **dt_limit_kwargs,
                )

            else:
                _, _, gate, hidden_states_B_C, dt = projected_states.split(
                    [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
                )

                # 2. Convolution sequence transformation
                # Init cache
                if cache_params is not None:
                    hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
                    conv_states = nn.functional.pad(
                        hidden_states_B_C_transposed,
                        (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
                    )
                    cache_params.update_conv_state(
                        layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True
                    )

                if self.activation not in ["silu", "swish"]:
                    hidden_states_B_C = self.act(
                        self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
                    )
                else:
                    hidden_states_B_C = causal_conv1d_fn(
                        x=hidden_states_B_C.transpose(1, 2),
                        weight=self.conv1d.weight.squeeze(1),
                        bias=self.conv1d.bias,
                        activation=self.activation,
                    ).transpose(1, 2)
                hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
                hidden_states, B, C = torch.split(
                    hidden_states_B_C,
                    [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                    dim=-1,
                )

                # 3. SSM transformation
                scan_output, ssm_state = mamba_chunk_scan_combined(
                    hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                    dt,
                    A,
                    B.view(batch_size, seq_len, self.n_groups, -1),
                    C.view(batch_size, seq_len, self.n_groups, -1),
                    chunk_size=self.chunk_size,
                    D=self.D,
                    z=None,
                    seq_idx=None,
                    return_final_states=True,
                    dt_bias=self.dt_bias,
                    dt_softplus=True,
                    **dt_limit_kwargs,
                )

                # Init cache
                if ssm_state is not None and cache_params is not None:
                    cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)

                scan_output = scan_output.view(batch_size, seq_len, -1)

                # Multiply "gate" branch and apply extra normalization layer
                scan_output = self.norm(scan_output, gate)

                # 4. Final linear projection
                out = self.out_proj(scan_output)
        return out

    # fmt: off
    def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None):
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype

        # 1. Gated MLP's linear projection
        input_states = apply_mask_to_padding_states(input_states, attention_mask)
        projected_states = self.in_proj(input_states)
        d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2
        _, _, gate, hidden_states_B_C, dt = projected_states.split(
            [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
        )

        # 2. Convolution sequence transformation
        if cache_params is not None and cache_position is not None and cache_position[0] > 0:
            cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False)

            # We need to guarantee that anything regarding the cache is on the same device
            conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)

            hidden_states_B_C = torch.sum(
                conv_states * self.conv1d.weight.squeeze(1), dim=-1
            )
            if self.use_conv_bias:
                hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
            hidden_states_B_C = self.act(hidden_states_B_C)
        else:
            # Init cache
            if cache_params is not None:
                hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
                conv_states = nn.functional.pad(
                    hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
                )
                cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True)

            hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))

        hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
        hidden_states, B, C = torch.split(
            hidden_states_B_C,
            [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
            dim=-1
        )

        # 3. SSM transformation
        A = -torch.exp(self.A_log.float())  # [num_heads]
        if cache_params is not None and cache_position is not None and cache_position[0] > 0:
            # We need to guarantee that anything regarding the cache is on the same device
            # (`ssm_states` is a list of per-layer tensors, so index it first)
            cache_device = cache_params.ssm_states[self.layer_idx].device

            # Note: there is no need to pad parameter matrices here, as there is just one new token
            # for batched generation
            dt = dt[:, 0, :][:, None, ...]
            dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
            # [num_heads] -> [num_heads, head_dim]
            dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)

            dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
            dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
            A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            # [bsz, num_heads, head_dim, state_size]
            dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)

            # Discretize B
            # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
            # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
            B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
            B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
            B = B.reshape(batch_size, -1, B.shape[-1])
            # [bsz, num_heads, head_dim, state_size]
            dB = dt[..., None] * B[..., None, :]

            # Discretize x into dB
            # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
            hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
            dBx = (dB * hidden_states[..., None]).to(device=cache_device)

            # State calculation
            cache_params.update_ssm_state(
                layer_idx=self.layer_idx,
                new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx
            )

            # Subsequent output
            # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
            C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
            C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
            C = C.reshape(batch_size, -1, C.shape[-1])
            # [bsz, num_heads, head_dim]

            ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype)  # Shape: [b, h, d, n]
            # Reshape ssm_states to merge the first two dimensions
            ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size)  # Shape: [b*h, d, n]
            C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1)  # Shape: [b*h, n, 1]
            y = torch.bmm(ssm_states_reshaped, C_reshaped)
            y = y.view(batch_size, self.num_heads, self.head_dim)

            # D skip connection
            # [num_heads] -> [num_heads, head_dim]
            D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
            y = (y + hidden_states * D).to(y.dtype)

            # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
            y = y.reshape(batch_size, -1)[:, None, ...]
        else:
            # begin ssd naive implementation without einsums
            dt = nn.functional.softplus(dt + self.dt_bias)
            dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
            hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
            B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
            C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
            pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

            D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

            # Discretize x and A
            hidden_states = hidden_states * dt[..., None]
            A = A.to(hidden_states.dtype) * dt

            # Rearrange into blocks/chunks
            hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]

            # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
            A = A.permute(0, 3, 1, 2)
            A_cumsum = torch.cumsum(A, dim=-1)

            # 1. Compute the output for each intra-chunk (diagonal blocks)
            # This is the analog of a causal mask
            L = torch.exp(segment_sum(A))

            # Contraction of C and B to get G (attention-weights like)
            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]  # shape: (b, c, l, s, h, n)
            G = G_intermediate.sum(dim=-1)  # shape: (b, c, l, s, h)

            # Compute M, equivalent to applying attention mask to weights
            M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
            M = M_intermediate.sum(dim=-1)

            # Compute Y_diag (apply to values)
            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)

            # 2. Compute the state for each intra-chunk
            # (right term of low-rank factorization of off-diagonal blocks; B terms)
            decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
            B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
            states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)

            # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
            # (middle term of factorization of off-diag blocks; A terms)
            if cache_params is not None and cache_position is not None and cache_position[0] > 0:
                previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
            else:
                previous_states = torch.zeros_like(states[:, :1])
            states = torch.cat([previous_states, states], dim=1)
            decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
            decay_chunk = decay_chunk.transpose(1, 3)
            new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
            states, ssm_state = new_states[:, :-1], new_states[:, -1]

            # 4. Compute state -> output conversion per chunk
            # (left term of low-rank factorization of off-diagonal blocks; C terms)
            state_decay_out = torch.exp(A_cumsum)
            C_times_states = (C[..., None, :] * states[:, :, None, ...])
            state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
            Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])

            # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
            y = Y_diag + Y_off
            # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
            y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)

            y = y + D_residual
            # Cutting off padded chunks
            if pad_size > 0:
                y = y[:, :seq_len, :, :]
            y = y.reshape(batch_size, seq_len, -1)

            # Init cache
            if ssm_state is not None and cache_params is not None:
                cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)

        scan_output = self.norm(y, gate)

        # end ssd naive

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output.to(dtype))  # [batch, seq_len, hidden_size]
        return contextualized_states
    # fmt: on

    def forward(
        self,
        hidden_states,
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
        dtype = hidden_states.dtype
        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

        return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)


class NemotronHRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        NemotronHRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Weights are in float32
        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)

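# In formula form (a sketch of the forward above, computed in float32):
#
#     y = w * x / sqrt(mean(x**2, dim=-1) + eps)
#
# i.e. RMS normalization with a learned per-channel scale and no mean subtraction.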
class NemotronHBlock(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # M: Mamba2, *: Attention, -: MLP
        self.block_type = config.layers_block_type[layer_idx]
        if self.block_type == "mamba":
            self.mixer = NemotronHMamba2Mixer(config, layer_idx=layer_idx)
        elif self.block_type == "attention":
            self.mixer = NEMOTRONH_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
        elif self.block_type == "mlp":
            self.mixer = NemotronHMLP(config, layer_idx=layer_idx)
        elif self.block_type == "moe":
            self.mixer = NemotronHMOE(config, layer_idx=layer_idx)
        else:
            raise ValueError(f"Invalid layer pattern {config.hybrid_override_pattern[layer_idx]}")

    def forward(
        self,
        hidden_states,
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)):
            # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
            residual = hidden_states
            hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

            if self.block_type == "mamba":
                hidden_states = self.mixer(
                    hidden_states, cache_params=cache_params, cache_position=cache_position
                )
            elif self.block_type == "attention":
                hidden_states = self.mixer(
                    hidden_states, cache_position=cache_position
                )
                hidden_states = hidden_states[0]
            elif self.block_type in ["mlp", "moe"]:
                hidden_states = self.mixer(
                    hidden_states
                )
            else:
                raise ValueError(f"Invalid block_type: {self.block_type}")

            hidden_states = residual + hidden_states
            return hidden_states


# Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
class NemotronHMLP(nn.Module):
    def __init__(self, config, intermediate_size=None, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size or config.intermediate_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.mlp_hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.up_proj(x)))


class NemotronHMOE(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList(
            [
                NemotronHMLP(config, intermediate_size=config.moe_intermediate_size, layer_idx=layer_idx)
                for _ in range(config.n_routed_experts)
            ]
        )
        self.gate = NemotronHTopkRouter(config)
        self.shared_experts = NemotronHMLP(
            config=config, intermediate_size=config.moe_shared_expert_intermediate_size, layer_idx=layer_idx
        )

    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
        r"""
        CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
        to not have to do a loop here (deepseek has 256 experts soooo yeah).
        """
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
        expert_mask = expert_mask.permute(2, 0, 1)

        for expert_idx in range(len(self.experts)):
            expert = self.experts[expert_idx]
            mask = expert_mask[expert_idx]
            token_indices, weight_indices = torch.where(mask)

            if token_indices.numel() > 0:
                expert_weights = topk_weights[token_indices, weight_indices]
                expert_input = hidden_states[token_indices]
                expert_output = expert(expert_input)
                weighted_output = expert_output * expert_weights.unsqueeze(-1)
                final_hidden_states.index_add_(0, token_indices, weighted_output)
            else:
                # Local empty expert: no-op compute that still marks params as used.
                expert_dtype = expert.down_proj.weight.dtype
                dummy_out = expert(torch.zeros_like(hidden_states[0]).unsqueeze(0).to(expert_dtype))
                final_hidden_states = final_hidden_states + dummy_out

        # in original deepseek, the output of the experts are gathered once we leave this module
        # thus the moe module is itself an IsolatedParallel module
        # and all experts are "local" meaning we shard but we don't gather
        return final_hidden_states.type(hidden_states.dtype)

    def forward(self, hidden_states):
        residuals = hidden_states
        orig_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states

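# Dispatch sketch (illustrative values, not from the original file): with 4
# experts, 3 tokens and top_k=2, `one_hot(topk_indices, 4)` has shape [3, 2, 4];
# after `permute(2, 0, 1)` the mask is [4, 3, 2], so `expert_mask[e]` marks
# which (token, slot) pairs routed to expert `e`, e.g.
#
#     topk_indices = torch.tensor([[0, 2], [1, 2], [0, 3]])
#     mask = F.one_hot(topk_indices, num_classes=4).permute(2, 0, 1)
#     torch.where(mask[2])  # tokens 0 and 1 picked expert 2 (each in slot 1)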

class NemotronHTopkRouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob

        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size), dtype=torch.float32))
        self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts, dtype=torch.float32))

    @torch.no_grad()
    def get_topk_indices(self, scores):
        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
        group_scores = (
            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        return topk_indices

    def forward(self, hidden_states):
        hidden_states = hidden_states.view(-1, self.config.hidden_size)
        router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
        scores = router_logits.sigmoid()
        topk_indices = self.get_topk_indices(scores)
        topk_weights = scores.gather(1, topk_indices)
        if self.norm_topk_prob:
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights

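# Routing sketch (shapes assumed for illustration): the router first keeps the
# `topk_group` expert groups whose two best sigmoid scores sum highest, masks
# the remaining groups to 0, then takes the global top_k, e.g.
#
#     topk_indices, topk_weights = router(hidden_states)
#     # topk_indices: [num_tokens, top_k] expert ids
#     # topk_weights: [num_tokens, top_k] gate values, rescaled by routed_scaling_factor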
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

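# Shape sketch (illustrative): with 2 KV heads repeated 3x for 6 query heads,
#
#     kv = torch.randn(1, 2, 5, 8)   # [batch, num_kv_heads, seq, head_dim]
#     repeat_kv(kv, 3).shape         # torch.Size([1, 6, 5, 8])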

class NemotronHAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: NemotronHConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        if hasattr(config, "head_dim") and config.head_dim is not None:
            self.head_dim = config.head_dim
        else:
            self.head_dim = config.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.is_causal = True

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        # position_embeddings: Tuple[torch.Tensor, torch.Tensor],  # TODO
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        # `head_dim` may differ from `hidden_size // num_heads`, so flatten heads explicitly
        attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->NemotronH
class NemotronHFlashAttention2(NemotronHAttention):
    """
    NemotronH flash attention module. This module inherits from `NemotronHAttention` as the weights of the module stay
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need
        # to cast them back to float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to be silently cast to float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            sliding_window=getattr(self.config, "sliding_window", None),
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        # `head_dim` may differ from `hidden_size // num_heads`, so flatten heads explicitly
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->NemotronH
class NemotronHSdpaAttention(NemotronHAttention):
    """
    NemotronH attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `NemotronHAttention` as the weights of the module stay untouched. The only changes are on the forward pass to
    adapt to the SDPA API.
    """

    # Adapted from NemotronHAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "NemotronHModel is using NemotronHSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        # `head_dim` may differ from `hidden_size // num_heads`, so flatten heads explicitly
        attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


NEMOTRONH_ATTENTION_CLASSES = {
    "eager": NemotronHAttention,
    "flash_attention_2": NemotronHFlashAttention2,
    "sdpa": NemotronHSdpaAttention,
}

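# Selection sketch (hedged; mirrors how NemotronHBlock picks its mixer above):
# the config's `_attn_implementation` string indexes this mapping, e.g.
#
#     attn_cls = NEMOTRONH_ATTENTION_CLASSES[config._attn_implementation]
#     attn = attn_cls(config, layer_idx=3)   # config._attn_implementation is e.g. "sdpa"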
# Copied from transformers.models.mamba2.modeling_mamba2.Mamba2PreTrainedModel
class NemotronHPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NemotronHConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["NemotronHBlock"]
    supports_gradient_checkpointing = True
    _is_stateful = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, NemotronHMamba2Mixer):
            if getattr(module.dt_bias, "_is_hf_initialized", False):
                return
            module.A_log._no_weight_decay = True
            module.D._no_weight_decay = True

            dt = torch.exp(
                torch.rand(self.config.mamba_num_heads)
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
                + math.log(self.config.time_step_min)
            ).clamp(min=self.config.time_step_floor)

            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                module.dt_bias.copy_(inv_dt)
            module.dt_bias._no_reinit = True

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.initializer_range)

        # TODO: Check
        if self.config.rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if getattr(p, "_is_hf_initialized", False):
                    continue
                if name in ["out_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(self.config.num_hidden_layers)

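# Worked identity check for the dt initialization above (softplus and its
# inverse; illustrative, not from the original file):
#
#     dt = torch.tensor([0.01, 0.1, 1.0])
#     inv_dt = dt + torch.log(-torch.expm1(-dt))   # softplus^{-1}(dt)
#     torch.allclose(F.softplus(inv_dt), dt)       # True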
1266
1267 @dataclass
1268 # Copied from transformers.models.mamba.modeling_mamba2.Mamba2Output with MAMBA2->NemotronH,Mamba2->NemotronH
1269 class NemotronHOutput(ModelOutput):
1270 """
1271 Class for the NemotronH model outputs.
1272
1273 Args:
1274 last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1275 Sequence of hidden-states at the output of the last layer of the model.
1276 cache_params (`HybridMambaAttentionDynamicCache`):
1277 The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
1278 avoid providing the old `input_ids`.
1279
1280 Includes both the State space model state matrices after the selective scan, and the Convolutional states
1281 hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
1282 Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
1283 one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
1284
1285 Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
1286 """
1287
1288 last_hidden_state: Optional[torch.FloatTensor] = None
1289 cache_params: Optional[HybridMambaAttentionDynamicCache] = None
1290 hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1291 attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
# Copied from transformers.models.mamba2.modeling_mamba2.Mamba2CausalLMOutput with Mamba2->NemotronH
class NemotronHCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`HybridMambaAttentionDynamicCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the state-space model state matrices after the selective scan and the convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[HybridMambaAttentionDynamicCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


NEMOTRONH_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.).

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`NemotronHConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

NEMOTRONH_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary.

            If `cache_params.seqlen_offset > 0`, only the `input_ids` that do not have their past calculated should be
            passed as `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings.
        cache_params (`HybridMambaAttentionDynamicCache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model received `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length,)`, *optional*):
            The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
            If `cache_params` is passed, `cache_position` should also be passed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
"""


@add_start_docstrings(
    "The bare NemotronH Model transformer outputting raw hidden-states without any specific head on top.",
    NEMOTRONH_START_DOCSTRING,
)
class NemotronHModel(NemotronHPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([NemotronHBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False
        self.norm_f = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Initialize weights and apply final processing
        self._register_load_state_dict_pre_hook(self.load_hook)
        self.post_init()

    def load_hook(self, state_dict, prefix, *args):
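        # Backward compatibility: older checkpoints store the embedding weights under an
        # `embedding.` key; rename it to `embeddings.`. The dict is mutated in place, so
        # we break out of the loop after the first (and only expected) match.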
        for k in state_dict:
            if "embedding." in k:
                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
                break

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=NemotronHOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, NemotronHOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Caching is disabled during training
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):  # `^` is Python's XOR operator
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # From zamba_modeling.py
        if use_cache and cache_params is None:
            logger.warning_once(
                "NemotronH requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
                "provided, so no cache will be returned."
            )

        hidden_states = inputs_embeds

        if cache_position is None:
            cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
        mamba_mask = self._update_mamba_mask(attention_mask, cache_position)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for layer_idx, mixer_block in enumerate(self.layers):
            # Depending on the layer type we opt for a 2D base attention mask (Mamba) or a 4D causal mask (Attention)
            if mixer_block.block_type == "mamba":
                layer_mask = mamba_mask
            elif mixer_block.block_type == "attention":
                layer_mask = causal_mask
            elif mixer_block.block_type in ["mlp", "moe"]:
                layer_mask = None
            else:
                raise ValueError(f"Invalid block_type: {mixer_block.block_type}")

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    mixer_block.__call__, hidden_states, cache_params, cache_position, layer_mask
                )
            else:
                hidden_states = mixer_block(
                    hidden_states,
                    cache_params=cache_params,
                    cache_position=cache_position,
                    attention_mask=layer_mask,
                )

            # TODO: Store attentions
            # if output_attentions:
            #     if layer_outputs[1] is not None:
            #         # Append attentions only of attention layers. Mamba layers return `None` as the attention weights
            #         all_self_attns += (layer_outputs[1],)

        # TODO (check): should this happen before the forward pass?
        # if output_hidden_states:
        #     all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)

        return NemotronHOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        target_length = cache_position[-1] + 1

        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
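        # Illustration (informal): with no cache and no padding, for sequence_length == 3
        # the mask at this point has shape (batch_size, 1, 3, 3), with 0.0 on and below
        # the diagonal and the dtype's most negative value (m = min_dtype) above it:
        #   [[0, m, m],
        #    [0, 0, m],
        #    [0, 0, 0]]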
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    def _update_mamba_mask(self, attention_mask, cache_position):
        """
        No need for zeroing states when
            1. Cached forward
            2. Attending to all inputs
        """
        mamba_mask = attention_mask
        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
            mamba_mask = None
        return mamba_mask


@add_start_docstrings(
    """
    The NemotronH Model transformer with a language modeling head on top (linear layer with weights not tied to the
    input embeddings).
    """,
    NEMOTRONH_START_DOCSTRING,
)
class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.backbone = NemotronHModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_decoder(self):
        return self.backbone

    def set_decoder(self, decoder):
        self.backbone = decoder

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        # Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py
        # Overwritten -- uses `cache_params` as opposed to `past_key_values`
        empty_past_kv = past_key_values is None

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want a dummy token in that case
        # (we can't check exception 3 while compiling)
        if not empty_past_kv:
            if (
                inputs_embeds is not None  # Exception 1
                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
            ):
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no-op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        else:
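            # First generation step: allocate a fresh hybrid cache (convolutional and SSM
            # states for the Mamba blocks, key/value states for the attention blocks).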
            past_key_values = HybridMambaAttentionDynamicCache(
                self.config, input_ids.shape[0], self.dtype, device=self.device
            )

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if not empty_past_kv:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and empty_past_kv:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "logits_to_keep": self.config.num_logits_to_keep,
                "cache_position": cache_position,
            }
        )
        return model_inputs
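
    # Minimal generation sketch (illustrative, not part of the modeling code; assumes a
    # checkpoint compatible with this class and the standard `transformers` AutoTokenizer):
    #
    #   from transformers import AutoTokenizer
    #   tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_FOR_DOC)
    #   model = NemotronHForCausalLM.from_pretrained(_CHECKPOINT_FOR_DOC)
    #   inputs = tokenizer("The capital of France is", return_tensors="pt")
    #   outputs = model.generate(**inputs, max_new_tokens=16)
    #   print(tokenizer.decode(outputs[0], skip_special_tokens=True))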

    @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=NemotronHCausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,  # for now we need this for generation
    ) -> Union[Tuple, NemotronHCausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        nemotron_h_outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
            cache_position=cache_position,
            attention_mask=attention_mask,
        )
        hidden_states = nemotron_h_outputs[0]

        # TODO: Check zamba_modeling.py: https://github.com/huggingface/transformers/blob/d7188ba600e36d3fd191b12e19f1b3bb81a8404f/src/transformers/models/zamba/modeling_zamba.py#L1284C1-L1286C2
        logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + nemotron_h_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return NemotronHCausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=nemotron_h_outputs.cache_params,
            hidden_states=nemotron_h_outputs.hidden_states,
            attentions=nemotron_h_outputs.attentions,
        )