configuration_nemotron_h.py
12.6 KB · 262 lines · python Raw
1 # coding=utf-8
2 # Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
3 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 # http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16 """NemotronH model configuration"""
17
18 import re
19
20 from transformers.configuration_utils import PretrainedConfig
21 from transformers.utils import logging
22
23
# Module-level logger from `transformers.utils.logging`; not referenced by the
# configuration class below.
logger = logging.get_logger(__name__)
25
26
class NemotronHConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`NemotronHModel`]. It is used to instantiate a
    NemotronH model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the NemotronH-v0.1 model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Note that several Mamba arguments are stored on the config under a different attribute name (given below as
    "Stored as ...").

    Args:
        vocab_size (`int`, *optional*, defaults to 131072):
            Vocabulary size of the NemotronH model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`NemotronHModel`]
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
            model has an output word embedding layer.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 21504):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 52):
            Number of hidden layers (blocks) in the model. Must equal `len(hybrid_override_pattern)`.
        hybrid_override_pattern (`str`, *optional*, defaults to `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
            The layer layout of the hybrid model, one character per layer: `M` = Mamba2, `*` = attention,
            `-` = MLP, `E` = mixture-of-experts (MoE). No other characters are accepted.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer.
        head_dim (`int`, *optional*, defaults to 128):
            Dimension of each attention head.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used.
            If `None`, it falls back to `num_attention_heads`.
        mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
            The non-linear activation function in the MLP layers.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in attention layers.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in MLP layers.
        use_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the model.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        residual_in_fp32 (`bool`, *optional*, defaults to `False`):
            Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
            integer value, only last `num_logits_to_keep` logits will be calculated.
        pad_token_id (`int`, *optional*, defaults to 0):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        sliding_window (`int`, *optional*, defaults to None):
            Sliding window attention window size.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        hidden_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the hidden states.
        use_mamba_kernels (`bool`, *optional*, defaults to `True`):
            Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
            `causal-conv1d` are installed, and the mamba modules are running on a CUDA device.
        ssm_state_size (`int`, *optional*, defaults to 128):
            The dimension of the mamba state space latents.
        mamba_num_heads (`int`, *optional*, defaults to 128):
            Number of heads in Mamba layers.
        mamba_n_groups (`int`, *optional*, defaults to 8):
            Number of groups in Mamba layers. Stored as `n_groups`.
        mamba_head_dim (`int`, *optional*, defaults to 64):
            Dimension of each Mamba head.
        mamba_d_conv (`int`, *optional*, defaults to 4):
            The size of the mamba convolution kernel. Stored as `conv_kernel`.
        mamba_expand (`int`, *optional*, defaults to 2):
            Expanding factor used to determine the mamba intermediate size. Stored as `expand`.
        mamba_hidden_act (`str`, *optional*, defaults to "silu"):
            The non-linear activation function in the Mamba layers.
        mamba_dt_min (`float`, *optional*, defaults to 0.001):
            Minimum value for the time step in Mamba. Stored as `time_step_min`.
        mamba_dt_max (`float`, *optional*, defaults to 0.1):
            Maximum value for the time step in Mamba. Stored as `time_step_max`.
        mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
            Limits for the time step in Mamba. Stored as `time_step_limit`.
        mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
            Floor value for time step initialization in Mamba. Stored as `time_step_floor`.
        mamba_conv_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the convolution layer of the mamba mixer block. Stored as `use_conv_bias`.
        mamba_proj_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the input and output projections of the mamba mixer block.
        mamba_chunk_size (`int`, *optional*, defaults to 128):
            Size of chunks for Mamba processing. Stored as `chunk_size`.
        rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
            Whether to rescale the pre-normalization residual connections.

        The MoE parameters below are stored unchanged for the modeling code to consume; their descriptions follow
        the conventional meaning of these names and should be confirmed against the modeling implementation.

        n_routed_experts (`int`, *optional*, defaults to 8):
            Number of routed experts in each MoE (`E`) layer.
        n_shared_experts (`int`, *optional*, defaults to 1):
            Number of shared experts in each MoE layer.
        moe_intermediate_size (`int`, *optional*, defaults to 7688):
            Intermediate size of each routed expert.
        moe_shared_expert_intermediate_size (`int`, *optional*, defaults to 7688):
            Intermediate size of the shared expert(s).
        num_experts_per_tok (`int`, *optional*, defaults to 2):
            Number of routed experts selected per token.
        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor applied to the routed experts' output.
        n_group (`int`, *optional*, defaults to 1):
            Number of expert groups used for group-limited routing.
        topk_group (`int`, *optional*, defaults to 1):
            Number of expert groups selected per token.
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the top-k routing weights.

    Raises:
        ValueError: If `hybrid_override_pattern` does not have exactly `num_hidden_layers` characters, or contains a
            character other than `M`, `*`, `-` or `E`.
    """

    model_type = "nemotron_h"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=131072,
        tie_word_embeddings=False,
        hidden_size=4096,
        intermediate_size=21504,
        num_hidden_layers=52,
        hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
        num_attention_heads=32,
        head_dim=128,
        num_key_value_heads=8,  # nemo: num_query_groups
        mlp_hidden_act="relu2",
        attention_bias=False,
        mlp_bias=False,
        use_bias=False,
        initializer_range=0.02,  # nemo: init_method_std
        layer_norm_epsilon=1e-5,  # nemo: layernorm_epsilon
        residual_in_fp32=False,  # Megatron Core default value
        use_cache=True,
        num_logits_to_keep=1,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        sliding_window=None,
        max_position_embeddings=4096,
        attention_dropout=0.0,
        hidden_dropout=0.0,
        use_mamba_kernels=True,
        ssm_state_size=128,  # mamba_state_size
        mamba_num_heads=128,
        mamba_n_groups=8,  # nemo: mamba_ssm_ngroups = num_heads
        mamba_head_dim=64,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_hidden_act="silu",
        mamba_dt_min=0.001,
        mamba_dt_max=0.1,
        mamba_dt_limit=(0.0, float("inf")),
        mamba_dt_init_floor=1e-4,
        mamba_conv_bias=True,
        mamba_proj_bias=False,
        mamba_chunk_size=128,
        rescale_prenorm_residual=True,
        n_routed_experts=8,
        n_shared_experts=1,
        moe_intermediate_size=7688,
        moe_shared_expert_intermediate_size=7688,
        num_experts_per_tok=2,
        routed_scaling_factor=1.0,
        n_group=1,
        topk_group=1,
        norm_topk_prob=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.tie_word_embeddings = tie_word_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.hybrid_override_pattern = hybrid_override_pattern
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.sliding_window = sliding_window
        self.max_position_embeddings = max_position_embeddings
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout

        # Fail fast on a malformed layer layout (explicit raise, so it survives `python -O`).
        self._check_hybrid_override_pattern()

        # for backward compatibility: a missing value means plain multi-head attention
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.mlp_hidden_act = mlp_hidden_act
        self.attention_bias = attention_bias
        self.mlp_bias = mlp_bias
        self.use_bias = use_bias
        self.initializer_range = initializer_range
        self.layer_norm_epsilon = layer_norm_epsilon
        self.residual_in_fp32 = residual_in_fp32

        self.use_cache = use_cache
        self.num_logits_to_keep = num_logits_to_keep

        # Mamba settings; several are stored under names that differ from the argument names.
        self.use_mamba_kernels = use_mamba_kernels
        self.n_groups = mamba_n_groups
        self.mamba_head_dim = mamba_head_dim
        self.ssm_state_size = ssm_state_size
        self.mamba_num_heads = mamba_num_heads
        self.conv_kernel = mamba_d_conv
        self.expand = mamba_expand
        self.mamba_hidden_act = mamba_hidden_act
        self.time_step_min = mamba_dt_min
        self.time_step_max = mamba_dt_max
        self.time_step_limit = mamba_dt_limit
        self.time_step_floor = mamba_dt_init_floor
        self.use_conv_bias = mamba_conv_bias
        self.mamba_proj_bias = mamba_proj_bias
        self.chunk_size = mamba_chunk_size
        self.rescale_prenorm_residual = rescale_prenorm_residual

        # MoE settings, stored as-is for the modeling code.
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.routed_scaling_factor = routed_scaling_factor
        self.n_group = n_group
        self.topk_group = topk_group
        self.norm_topk_prob = norm_topk_prob

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _check_hybrid_override_pattern(self):
        """Validate ``hybrid_override_pattern`` against ``num_hidden_layers``.

        Raises:
            ValueError: If the pattern length differs from ``num_hidden_layers`` or the
                pattern contains a character other than ``M``, ``*``, ``-`` or ``E``.
        """
        pattern = self.hybrid_override_pattern
        if len(pattern) != self.num_hidden_layers:
            raise ValueError(
                "hybrid_override_pattern must have the same length as num_hidden_layers "
                f"(got {len(pattern)} characters for {self.num_hidden_layers} layers)"
            )
        # `-` is escaped so it is a literal, not a range operator: the previous class
        # `[*-M]` silently accepted every character from '*' (0x2A) to 'M' (0x4D).
        # `fullmatch` also rejects a trailing newline, which `^...$` let through.
        if re.fullmatch(r"[M*\-E]+", pattern) is None:
            raise ValueError(
                "hybrid_override_pattern must only contain characters 'M', '*', '-', or 'E' "
                f"(got {pattern!r})"
            )

    @property
    def layers_block_type(self):
        """Return the block type of each of the ``num_hidden_layers`` layers.

        Each character of ``hybrid_override_pattern`` is mapped as ``M`` -> ``"mamba"``,
        ``*`` -> ``"attention"``, ``-`` -> ``"mlp"``; any other character (``E`` in a
        validated pattern) -> ``"moe"``.
        """
        block_type_by_char = {"M": "mamba", "*": "attention", "-": "mlp"}
        return [
            block_type_by_char.get(self.hybrid_override_pattern[i], "moe")
            for i in range(self.num_hidden_layers)
        ]