# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
# Commit id: abbc1311731867310635f9edc2a9ec18317c8c48

# Copyright (c) 2024, Tri Dao.

from functools import partial
from typing import Optional

import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from .mha import MHA
from .mlp import Mlp

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None


def stochastic_depth(
    input: Tensor, p: float, mode: str, training: bool = True
) -> Tensor:
    """
    Implements Stochastic Depth from `"Deep Networks with Stochastic Depth"
    <https://arxiv.org/abs/1603.09382>`_, used for randomly dropping residual
    branches of residual architectures.

    Args:
        input (Tensor[N, ...]): The input tensor of arbitrary dimensions with the first one
            being its batch i.e. a batch with ``N`` rows.
        p (float): probability of the input to be zeroed.
        mode (str): ``"batch"`` or ``"row"``.
            ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
            randomly selected rows from the batch.
        training: apply stochastic depth if ``True``. Default: ``True``

    Returns:
        Tensor[N, ...]: The randomly zeroed tensor.
    """
    if p < 0.0 or p > 1.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
    if mode not in ["batch", "row"]:
        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
    if not training or p == 0.0:
        return input

    survival_rate = 1.0 - p
    if mode == "row":
        size = [input.shape[0]] + [1] * (input.ndim - 1)
    else:
        size = [1] * input.ndim
    noise = torch.empty(size, dtype=input.dtype, device=input.device)
    noise = noise.bernoulli_(survival_rate)
    if survival_rate > 0.0:
        noise.div_(survival_rate)
    return input * noise


torch.fx.wrap("stochastic_depth")
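
# Illustrative usage sketch (hypothetical values, not part of the adapted module).
# With mode="row", each batch row survives with probability 1 - p and survivors
# are rescaled by 1 / (1 - p), so the output equals the input in expectation:
#
#     x = torch.ones(4, 3)
#     out = stochastic_depth(x, p=0.5, mode="row", training=True)
#     # each row of `out` is either all 0.0 or all 2.0
#     out = stochastic_depth(x, p=0.5, mode="row", training=False)
#     # at eval time the input is returned unchanged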


class StochasticDepth(nn.Module):
    """
    See :func:`stochastic_depth`.
    """

    def __init__(self, p: float, mode: str) -> None:
        super().__init__()
        self.p = p
        self.mode = mode

    def forward(self, input: Tensor) -> Tensor:
        return stochastic_depth(input, self.p, self.mode, self.training)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
        return s


class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
101 """
102 For prenorm=True, this Block has a slightly different structure compared to a regular
103 prenorm Transformer block.
104 The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
105 [Ref: https://arxiv.org/abs/2002.04745]
106 Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
107 the hidden_states (output of the MLP) and the residual.
108 This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
109 The residual needs to be provided (except for the very first block).
110
111 For prenorm=False, this Block has the same structure as a regular postnorm Transformer
112 block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
113
114 return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
115 This is for performance reason: for post-norm architecture, returning the input allows us
116 to fuse the backward of nn.Linear with the residual connection.
117 """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
178 r"""Pass the input through the encoder layer.
179
180 Args:
181 hidden_states: the sequence to the encoder layer (required).
182 residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
183 mixer_subset: for cross-attention only. If not None, will take a subset of x
184 before applying the query projection. Useful for e.g., ViT where we only care
185 about the CLS token in the last layer.
186 """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
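                # The fused path folds drop_path into layer_norm_fn as a per-row
                # scale: passing a ones tensor through drop_path samples which rows
                # survive (scaled by 1 / (1 - p)), and the result is supplied to
                # the kernel via the rowscale argument.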
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1],
                            device=mixer_out.device,
                            dtype=mixer_out.dtype,
                        )
                    )
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1],
                                device=mlp_out.device,
                                dtype=mlp_out.dtype,
                            )
                        )
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
            return hidden_states
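
    # Illustrative usage sketch (dim and shapes are made-up values; assumes the
    # default MHA mixer and Mlp from this repo). With prenorm=True the block
    # threads a residual stream alongside the hidden states:
    #
    #     block = Block(dim=256, prenorm=True)
    #     x = torch.randn(2, 128, 256)  # (batch, seqlen, dim)
    #     h, residual = block(x)            # very first block: residual is None
    #     h, residual = block(h, residual)  # later blocks pass both along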


class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (outputs of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
413 r"""Pass the input through the encoder layer.
414
415 Args:
416 hidden_states1: the output of the previous attention (mixer) or embedding layer.
417 hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
418 residual.
419 """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
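            # Single fused call: hidden_states2 is passed as x1 so that both
            # branches' dropouts are added into the same residual, which is then
            # normalized with norm1's weights and (unless the norms are tied) a
            # second time with norm2's weights passed as weight1/bias1.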
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias)
                if not self.tied_norm
                else (None, None)
            )
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                (hidden_states2,) = rest
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual
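
    # Illustrative usage sketch (dim and shapes are made-up values). The mixer and
    # MLP both read the freshly normalized residual, and their two outputs are
    # carried forward to be added into the residual by the next block:
    #
    #     block = ParallelBlock(dim=256)
    #     x = torch.randn(2, 128, 256)  # (batch, seqlen, dim)
    #     h1, h2, residual = block(x)                 # very first block
    #     h1, h2, residual = block(h1, h2, residual)  # later blocks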