configuration_openelm.py
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#

"""Implements HF OpenELMConfig based on PretrainedConfig"""
from numbers import Number
from typing import List, Optional, Union

import numpy as np
from transformers import PretrainedConfig


def make_divisible(
    v: Union[float, int],
    divisor: Optional[int] = 8,
    min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by the divisor.
    It can be seen at:
    https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62

    Args:
        v: input value
        divisor: the output will be a multiple of this value; defaults to 8
        min_value: lower bound on the output; defaults to `divisor`
    Returns:
        new_v: `v` rounded to a multiple of `divisor`, never below `min_value`
            and never more than 10% below `v`
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
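
# Illustrative examples of the rounding rule above (values follow directly
# from the arithmetic; they are not part of the original file's tests):
#   make_divisible(100)              -> 104  (nearest multiple of 8, rounding half up)
#   make_divisible(10)               -> 16   (8 would undershoot 10 by more than 10%)
#   make_divisible(640, divisor=256) -> 768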


def compute_heads(model_dim: int, head_dim: int) -> int:
    """Compute the number of heads.

    Args:
        model_dim: Model dimension.
        head_dim: Head dimension.

    Returns:
        The number of heads in multi-head attention.

    Raises:
        ValueError: if model dimension is not divisible by head dimension.
    """
    if model_dim % head_dim == 0:
        return model_dim // head_dim
    else:
        raise ValueError(
            f"Model dimension should be divisible by head dimension. Got: {model_dim} and {head_dim}."
        )
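
# Illustrative examples:
#   compute_heads(model_dim=1280, head_dim=64) -> 20
#   compute_heads(model_dim=1280, head_dim=96) -> ValueError (1280 % 96 != 0)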


OpenELM_CONFIGS = {
    "OpenELM-270M": dict(
        num_transformer_layers=16,
        model_dim=1280,
        head_dim=64,
        num_gqa_groups=4,
        normalize_qk_projections=True,
        share_input_output_layers=True,
        # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
        ffn_multipliers=(0.5, 4.0),
        qkv_multipliers=(0.5, 1.0),
    ),
    "OpenELM-450M": dict(
        num_transformer_layers=20,
        model_dim=1536,
        head_dim=64,
        num_gqa_groups=4,
        normalize_qk_projections=True,
        share_input_output_layers=True,
        # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
        ffn_multipliers=(0.5, 4.0),
        qkv_multipliers=(0.5, 1.0),
    ),
    "OpenELM-1_1B": dict(
        num_transformer_layers=28,
        model_dim=2048,
        head_dim=64,
        num_gqa_groups=4,
        normalize_qk_projections=True,
        share_input_output_layers=True,
        # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
        ffn_multipliers=(0.5, 4.0),
        qkv_multipliers=(0.5, 1.0),
    ),
    "OpenELM-3B": dict(
        num_transformer_layers=36,
        model_dim=3072,
        head_dim=128,
        num_gqa_groups=4,
        normalize_qk_projections=True,
        share_input_output_layers=True,
        # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
        ffn_multipliers=(0.5, 4.0),
        qkv_multipliers=(0.5, 1.0),
    ),
}
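
# These presets can be unpacked directly into the config class below, e.g.
# (illustrative): config = OpenELMConfig(**OpenELM_CONFIGS["OpenELM-270M"])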


class OpenELMConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`OpenELMModel`]. It is used to instantiate an OpenELM model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the OpenELM model.
        max_context_length (`int`, *optional*, defaults to 2048):
            Maximum number of input tokens.
        num_transformer_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer decoder.
        model_dim (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        head_dim (`int`, *optional*, defaults to 128):
            The attention head dimension.
        qkv_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 1.0):
            If qkv_multipliers is a Number, then all attention layers have the same latent dimensions,
            resulting in uniform allocation of parameters.
            If qkv_multipliers is a List of two Numbers, then each attention layer has different latent dimensions
            (assuming qkv_multipliers[0] != qkv_multipliers[1]), interpolated linearly across layers. This results
            in variable allocation of parameters in attention layers.
            This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
        num_query_heads (`Union[int, None]`, *optional*, defaults to None):
            The number of query heads. If None, it is computed as `compute_heads(model_dim=model_dim, head_dim=head_dim)`.
        num_gqa_groups (`int`, *optional*, defaults to 1):
            This variable allows switching between multi-head attention, group query attention, and multi-query attention.
            When num_gqa_groups == 1, it is multi-head attention.
            When 1 < num_gqa_groups < num_heads and num_heads is divisible by num_gqa_groups, it is group query attention.
            When num_gqa_groups == num_heads, it is multi-query attention.
        ffn_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 4.0):
            Feed-forward network (FFN) multipliers.
            If ffn_multipliers is a Number, then all FFN layers have the same latent dimensions,
            resulting in uniform allocation of parameters.
            If ffn_multipliers is a List of two Numbers, then each FFN layer has different latent dimensions
            (assuming ffn_multipliers[0] != ffn_multipliers[1]), interpolated linearly across layers. This results
            in variable allocation of parameters in FFN layers.
            This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
        ffn_with_glu (`bool`, *optional*, defaults to True):
            Whether to use an FFN with Gated Linear Unit (GLU).
        ffn_dim_divisor (`int`, *optional*, defaults to 256):
            The FFN layer dimension divisor.
        activation_fn_name (`str` or `function`, *optional*, defaults to `"swish"`):
            The non-linear activation function (function or string) in the decoder.
        normalization_layer_name (`str` or `function`, *optional*, defaults to `"rms_norm"`):
            Type of normalization layer.
        normalize_qk_projections (`bool`, *optional*, defaults to False):
            Whether to normalize queries and keys after projections.
        share_input_output_layers (`bool`, *optional*, defaults to False):
            Whether to share the embedding between the input and output linear layers.
        rope_freq_constant (`int`, *optional*, defaults to 10000):
            The base period of the RoPE embeddings.
        rope_max_length (`int`, *optional*, defaults to 4096):
            Note that rope_max_length is set to twice max_context_length.
            This allows flexibility in token lengths during training or fine-tuning.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
    """

    model_type = "openelm"

    def __init__(
        self,
        vocab_size: int = 32000,
        max_context_length: int = 2048,
        num_transformer_layers: int = 12,
        model_dim: int = 2048,
        head_dim: int = 128,
        qkv_multipliers: Union[Number, List[Number]] = 1.0,
        num_query_heads: Union[int, None] = None,
        num_gqa_groups: int = 1,
        ffn_multipliers: Union[Number, List[Number]] = 4.0,
        ffn_with_glu: bool = True,
        ffn_dim_divisor: int = 256,
        activation_fn_name: str = "swish",
        normalization_layer_name: str = "rms_norm",
        normalize_qk_projections: bool = False,
        share_input_output_layers: bool = False,
        rope_freq_constant: int = 10000,
        rope_max_length: int = 4096,
        initializer_range: float = 0.02,
        use_cache: bool = True,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ) -> None:
        self.vocab_size = vocab_size
        self.max_context_length = max_context_length
        self.num_transformer_layers = num_transformer_layers
        self.model_dim = model_dim
        self.head_dim = head_dim
        self.qkv_multipliers = qkv_multipliers
        self.num_gqa_groups = num_gqa_groups
        self.ffn_multipliers = ffn_multipliers
        self.ffn_with_glu = ffn_with_glu
        self.ffn_dim_divisor = ffn_dim_divisor
        self.activation_fn_name = activation_fn_name
        self.normalization_layer_name = normalization_layer_name
        self.normalize_qk_projections = normalize_qk_projections
        self.share_input_output_layers = share_input_output_layers
        self.rope_freq_constant = rope_freq_constant
        self.rope_max_length = rope_max_length
        # Derive num_query_heads from model_dim and head_dim unless given explicitly.
        self.num_query_heads = (
            compute_heads(model_dim=model_dim, head_dim=head_dim)
            if num_query_heads is None
            else num_query_heads
        )
        self.initializer_range = initializer_range

        self.__post_init__()
        super().__init__(
            use_cache=use_cache,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    def __post_init__(self) -> None:
        if self.num_gqa_groups is not None:
            head_multiple_of = self.num_gqa_groups
        else:
            head_multiple_of = 2

        if isinstance(self.qkv_multipliers, Number):
            # All attention layers have the same latent dimensions, resulting in uniform allocation of parameters.
            qkv_dim = make_divisible(
                self.model_dim * self.qkv_multipliers,
                divisor=self.head_dim * head_multiple_of,
            )
            query_dims = [int(qkv_dim)] * self.num_transformer_layers

        elif (
            isinstance(self.qkv_multipliers, (tuple, list))
            and len(self.qkv_multipliers) == 2
        ):
            # Each attention layer has different latent dimensions assuming qkv_multipliers[0] != qkv_multipliers[1].
            # This results in variable allocation of parameters in attention layers.
            # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
            qkv_multipliers = [
                round(v, 2)
                for v in np.linspace(
                    self.qkv_multipliers[0],
                    self.qkv_multipliers[1],
                    num=self.num_transformer_layers,
                    dtype=float,
                )
            ]
            # Make sure that the scaled model dimension is divisible by the scaled head dimension.
            query_dims = [
                int(
                    make_divisible(
                        self.model_dim * m, divisor=self.head_dim * head_multiple_of
                    )
                )
                for m in qkv_multipliers
            ]
        else:
            raise NotImplementedError(
                f"QKV multipliers should be a single number or a list containing exactly two numbers. Got: {self.qkv_multipliers}."
            )

        # Compute the number of query, key, and value heads.
        # For multi-head and multi-query attention, the number of heads for query, key, and value are the same.
        # For group query attention, the number of key and value heads are the same.
        self.num_query_heads = [
            int(compute_heads(q_dim, self.head_dim)) for q_dim in query_dims
        ]
        self.num_kv_heads = [
            q_heads // self.num_gqa_groups for q_heads in self.num_query_heads
        ]
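
        # Worked example (OpenELM-270M preset, illustrative): with model_dim=1280,
        # head_dim=64, num_gqa_groups=4, and qkv_multipliers=(0.5, 1.0) over 16 layers,
        # the divisor is head_dim * head_multiple_of = 256. Layer 0 gets
        # query_dim = make_divisible(1280 * 0.5, 256) = 768, i.e. 12 query heads and
        # 12 // 4 = 3 KV heads, while the last layer gets query_dim = 1280, i.e.
        # 20 query heads and 5 KV heads.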

        # Feed-forward network (FFN) multipliers.
        if isinstance(self.ffn_multipliers, Number):
            # All FFN layers have the same latent dimensions, resulting in uniform allocation of parameters.
            self.ffn_multipliers = [self.ffn_multipliers] * self.num_transformer_layers
        elif isinstance(self.ffn_multipliers, (tuple, list)):
            # Each FFN layer has different latent dimensions assuming ffn_multipliers[0] != ffn_multipliers[1].
            # This results in variable allocation of parameters in FFN layers.
            # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
            if len(self.ffn_multipliers) == 2:
                self.ffn_multipliers = [
                    round(v, 2)
                    for v in np.linspace(
                        self.ffn_multipliers[0],
                        self.ffn_multipliers[1],
                        num=self.num_transformer_layers,
                        dtype=float,
                    )
                ]
            else:
                assert (
                    len(self.ffn_multipliers) == self.num_transformer_layers
                ), f"{len(self.ffn_multipliers)=}!={self.num_transformer_layers=}"
        else:
            raise NotImplementedError(
                f"FFN multipliers should be a single number or a list containing exactly two numbers. Got: {self.ffn_multipliers}."
            )
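
        # Worked example (OpenELM-270M preset, illustrative): ffn_multipliers=(0.5, 4.0)
        # over 16 layers expands to [0.5, 0.73, 0.97, ..., 3.77, 4.0]; the modeling code
        # can then scale model_dim by each multiplier (rounded via ffn_dim_divisor) to
        # size each layer's FFN.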

        # Check that num_query_heads is divisible by num_kv_heads for every layer.
        for layer_idx in range(len(query_dims)):
            assert self.num_query_heads[layer_idx] % self.num_kv_heads[layer_idx] == 0
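

if __name__ == "__main__":
    # Minimal usage sketch: build the 270M preset and inspect the per-layer
    # allocations produced by __post_init__ via block-wise scaling. For the
    # OpenELM-270M preset, the query heads should ramp from 12 to 20 across
    # the 16 layers.
    config = OpenELMConfig(**OpenELM_CONFIGS["OpenELM-270M"])
    print(config.num_query_heads)  # per-layer query heads, e.g. [12, ..., 20]
    print(config.num_kv_heads)  # per-layer KV heads = query heads // num_gqa_groups
    print(config.ffn_multipliers)  # linspace(0.5, 4.0) rounded to 2 decimals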