block.py
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
# Commit id: abbc1311731867310635f9edc2a9ec18317c8c48

# Copyright (c) 2024, Tri Dao.

from functools import partial
from typing import Optional

import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from .mha import MHA
from .mlp import Mlp

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None


def stochastic_depth(
    input: Tensor, p: float, mode: str, training: bool = True
) -> Tensor:
    """
    Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
    <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
    branches of residual architectures.
    Args:
        input (Tensor[N, ...]): The input tensor of arbitrary dimensions with the first one
            being its batch, i.e. a batch with ``N`` rows.
        p (float): probability of the input to be zeroed.
        mode (str): ``"batch"`` or ``"row"``.
            ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
            randomly selected rows from the batch.
        training: apply stochastic depth if ``True``. Default: ``True``
    Returns:
        Tensor[N, ...]: The randomly zeroed tensor.
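
    Example (illustrative; the shape below is arbitrary, any tensor with a leading
    batch dimension works):

        >>> x = torch.randn(4, 16, 8)
        >>> out = stochastic_depth(x, p=0.2, mode="row", training=True)
        >>> out.shape == x.shape  # surviving rows are rescaled by 1 / (1 - p)
        True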
| 41 | """ |
| 42 | if p < 0.0 or p > 1.0: |
| 43 | raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") |
| 44 | if mode not in ["batch", "row"]: |
| 45 | raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}") |
| 46 | if not training or p == 0.0: |
| 47 | return input |
| 48 | |
| 49 | survival_rate = 1.0 - p |
| 50 | if mode == "row": |
| 51 | size = [input.shape[0]] + [1] * (input.ndim - 1) |
| 52 | else: |
| 53 | size = [1] * input.ndim |
| 54 | noise = torch.empty(size, dtype=input.dtype, device=input.device) |
| 55 | noise = noise.bernoulli_(survival_rate) |
| 56 | if survival_rate > 0.0: |
| 57 | noise.div_(survival_rate) |
| 58 | return input * noise |
| 59 | |
| 60 | |
| 61 | torch.fx.wrap("stochastic_depth") |
| 62 | |
| 63 | |
| 64 | class StochasticDepth(nn.Module): |
| 65 | """ |
| 66 | See :func:`stochastic_depth`. |
| 67 | """ |
| 68 | |
| 69 | def __init__(self, p: float, mode: str) -> None: |
| 70 | super().__init__() |
| 71 | self.p = p |
| 72 | self.mode = mode |
| 73 | |
| 74 | def forward(self, input: Tensor) -> Tensor: |
| 75 | return stochastic_depth(input, self.p, self.mode, self.training) |
| 76 | |
| 77 | def __repr__(self) -> str: |
| 78 | s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})" |
| 79 | return s |
| 80 | |
| 81 | |
| 82 | class Block(nn.Module): |
| 83 | def __init__( |
| 84 | self, |
| 85 | dim, |
| 86 | mixer_cls=None, |
| 87 | mlp_cls=None, |
| 88 | norm_cls=nn.LayerNorm, |
| 89 | dropout_cls=nn.Dropout, |
| 90 | prenorm=True, |
| 91 | resid_dropout1=0.0, |
| 92 | resid_dropout2=0.0, |
| 93 | drop_path1=0.0, |
| 94 | drop_path2=0.0, |
| 95 | fused_dropout_add_ln=False, |
| 96 | return_residual=False, |
| 97 | residual_in_fp32=False, |
| 98 | sequence_parallel=False, |
| 99 | mark_shared_params=False, |
| 100 | ): |
| 101 | """ |
| 102 | For prenorm=True, this Block has a slightly different structure compared to a regular |
| 103 | prenorm Transformer block. |
| 104 | The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. |
| 105 | [Ref: https://arxiv.org/abs/2002.04745] |
| 106 | Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both |
| 107 | the hidden_states (output of the MLP) and the residual. |
| 108 | This is for performance reasons, as we can fuse the dropout, add and LayerNorm. |
| 109 | The residual needs to be provided (except for the very first block). |
| 110 | |
| 111 | For prenorm=False, this Block has the same structure as a regular postnorm Transformer |
| 112 | block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN. |
| 113 | |
| 114 | return_residual: whether each of the sub-layers (mixer and mlp) will return the residual. |
        This is for performance reasons: for the post-norm architecture, returning the input
        allows us to fuse the backward of nn.Linear with the residual connection.
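
        A rough usage sketch (illustrative only, with a hypothetical dim/shape and the
        default MHA/Mlp sub-modules; with prenorm=True each call returns
        ``(hidden_states, residual)``, which are fed to the next block)::

            block = Block(dim=256)
            x = torch.randn(2, 128, 256)                # (batch, seqlen, dim)
            hidden, residual = block(x)                 # very first block: residual=None
            hidden, residual = block(hidden, residual)  # later blocks reuse the residual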
| 117 | """ |
| 118 | super().__init__() |
| 119 | self.prenorm = prenorm |
| 120 | self.fused_dropout_add_ln = fused_dropout_add_ln |
| 121 | self.return_residual = return_residual |
| 122 | self.residual_in_fp32 = residual_in_fp32 |
| 123 | if self.residual_in_fp32: |
| 124 | assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True" |
| 125 | if mixer_cls is None: |
| 126 | mixer_cls = partial(MHA, num_heads=dim // 64) |
| 127 | if mlp_cls is None: |
| 128 | mlp_cls = partial(Mlp, hidden_features=4 * dim) |
| 129 | self.mixer = mixer_cls(dim) |
| 130 | self.dropout1 = dropout_cls(resid_dropout1) |
| 131 | self.drop_path1 = StochasticDepth(drop_path1, mode="row") |
| 132 | self.norm1 = norm_cls(dim) |
| 133 | self.mlp = mlp_cls(dim) |
| 134 | if not isinstance(self.mlp, nn.Identity): |
| 135 | self.dropout2 = dropout_cls(resid_dropout2) |
| 136 | self.drop_path2 = StochasticDepth(drop_path2, mode="row") |
| 137 | self.norm2 = norm_cls(dim) |
| 138 | |
| 139 | if self.fused_dropout_add_ln: |
| 140 | assert layer_norm_fn is not None, "Triton is not installed" |
| 141 | assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance( |
| 142 | self.dropout1, nn.Dropout |
| 143 | ) |
| 144 | |
| 145 | # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, |
| 146 | # then the input to each worker in the tensor parallel group will be different. |
| 147 | # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. |
| 148 | # For now this is not an issue because we always use sequence_parallel=True during training |
| 149 | # and only use sequence_parallel=False during inference. |
| 150 | |
| 151 | # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. |
| 152 | if sequence_parallel: |
| 153 | for p in self.norm1.parameters(): |
| 154 | p._sequence_parallel = True |
| 155 | if hasattr(self, "norm2"): |
| 156 | for p in self.norm2.parameters(): |
| 157 | p._sequence_parallel = True |
| 158 | # Mark the norm parameters as "shared_params" so that we sync their values at init. |
| 159 | if mark_shared_params: |
| 160 | for p in self.norm1.parameters(): |
| 161 | p._shared_params = True |
| 162 | if hasattr(self, "norm2"): |
| 163 | for p in self.norm2.parameters(): |
| 164 | p._shared_params = True |
| 165 | |
| 166 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): |
| 167 | return self.mixer.allocate_inference_cache( |
| 168 | batch_size, max_seqlen, dtype=dtype, **kwargs |
| 169 | ) |
| 170 | |
| 171 | def forward( |
| 172 | self, |
| 173 | hidden_states: Tensor, |
| 174 | residual: Optional[Tensor] = None, |
| 175 | mixer_subset=None, |
| 176 | mixer_kwargs=None, |
| 177 | ): |
| 178 | r"""Pass the input through the encoder layer. |
| 179 | |
| 180 | Args: |
| 181 | hidden_states: the sequence to the encoder layer (required). |
            residual: if postnorm, residual=None. If prenorm, hidden_states = Attn/MLP(LN(residual)).
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
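                # Fused path: dropout, residual add, and the norm run in a single Triton kernel.
                # Stochastic depth is folded in as a per-row scale ("rowscale"), so drop_path1 is
                # applied to a tensor of ones only to sample which rows survive (and their scale).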
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(
                        residual.to(dtype=self.norm2.weight.dtype)
                    )
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1],
                            device=mixer_out.device,
                            dtype=mixer_out.dtype,
                        )
                    )
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1],
                                device=mlp_out.device,
                                dtype=mlp_out.dtype,
                            )
                        )
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
            return hidden_states


class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (output1 of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
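
        A rough usage sketch (illustrative only, hypothetical dim/shape, default MHA/Mlp
        sub-modules; each call returns ``(hidden_states1, hidden_states2, residual)``,
        which are passed straight into the next parallel block)::

            block = ParallelBlock(dim=256)
            x = torch.randn(2, 128, 256)                # (batch, seqlen, dim)
            h1, h2, residual = block(x)                 # very first block: h2 and residual are None
            h1, h2, residual = block(h1, h2, residual)  # later blocks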
| 357 | """ |
| 358 | super().__init__() |
| 359 | self.tied_norm = tied_norm |
| 360 | self.fused_dropout_add_ln = fused_dropout_add_ln |
| 361 | self.residual_in_fp32 = residual_in_fp32 |
| 362 | if mixer_cls is None: |
| 363 | mixer_cls = partial(MHA, num_heads=dim // 64) |
| 364 | if mlp_cls is None: |
| 365 | mlp_cls = partial(Mlp, hidden_features=4 * dim) |
| 366 | self.mixer = mixer_cls(dim) |
| 367 | self.dropout1 = dropout_cls(resid_dropout1) |
| 368 | self.norm1 = norm_cls(dim) |
| 369 | self.mlp = mlp_cls(dim) |
| 370 | self.dropout2 = dropout_cls(resid_dropout2) |
| 371 | if not self.tied_norm: |
| 372 | self.norm2 = norm_cls(dim) |
| 373 | |
| 374 | if self.fused_dropout_add_ln: |
| 375 | assert layer_norm_fn is not None, "Triton is not installed" |
| 376 | assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance( |
| 377 | self.dropout1, nn.Dropout |
| 378 | ) |
| 379 | |
| 380 | # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, |
| 381 | # then the input to each worker in the tensor parallel group will be different. |
| 382 | # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. |
| 383 | # For now this is not an issue because we always use sequence_parallel=True during training |
| 384 | # and only use sequence_parallel=False during inference. |
| 385 | |
| 386 | # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. |
| 387 | if sequence_parallel: |
| 388 | for p in self.norm1.parameters(): |
| 389 | p._sequence_parallel = True |
| 390 | if hasattr(self, "norm2"): |
| 391 | for p in self.norm2.parameters(): |
| 392 | p._sequence_parallel = True |
| 393 | # Mark the norm parameters as "shared_params" so that we sync their values at init. |
| 394 | if mark_shared_params: |
| 395 | for p in self.norm1.parameters(): |
| 396 | p._shared_params = True |
| 397 | if hasattr(self, "norm2"): |
| 398 | for p in self.norm2.parameters(): |
| 399 | p._shared_params = True |
| 400 | |
| 401 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): |
| 402 | return self.mixer.allocate_inference_cache( |
| 403 | batch_size, max_seqlen, dtype=dtype, **kwargs |
| 404 | ) |
| 405 | |
| 406 | def forward( |
| 407 | self, |
| 408 | hidden_states1: Tensor, |
| 409 | hidden_states2: Optional[Tensor] = None, |
| 410 | residual: Optional[Tensor] = None, |
| 411 | mixer_kwargs=None, |
| 412 | ): |
| 413 | r"""Pass the input through the encoder layer. |
| 414 | |
| 415 | Args: |
| 416 | hidden_states1: the output of the previous attention (mixer) or embedding layer. |
| 417 | hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1). |
            residual: the residual from the previous block (None for the very first block).
| 419 | """ |
| 420 | # TODO: Ideally we should only do the allgather / allreduce once for |
| 421 | # the Linear to MLP & Attention |
| 422 | if not self.fused_dropout_add_ln: |
| 423 | dropped1 = self.dropout1(hidden_states1) |
| 424 | # For the very 1st block, we only want 1 dropout, not two different dropouts |
| 425 | if hidden_states2 is not None: |
| 426 | dropped2 = self.dropout2(hidden_states2) |
| 427 | residual = ( |
| 428 | (residual + dropped1 + dropped2) |
| 429 | if residual is not None |
| 430 | else dropped1 + dropped2 |
| 431 | ) |
| 432 | else: |
| 433 | residual = (residual + dropped1) if residual is not None else dropped1 |
| 434 | hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) |
| 435 | hidden_states2 = ( |
| 436 | self.norm2(residual.to(dtype=self.norm2.weight.dtype)) |
| 437 | if not self.tied_norm |
| 438 | else hidden_states1 |
| 439 | ) |
| 440 | if self.residual_in_fp32: |
| 441 | residual = residual.to(torch.float32) |
| 442 | else: |
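            # Fused path: mirrors the unfused branch above in a single layer_norm_fn call,
            # which applies dropout to hidden_states1 (and hidden_states2 when given), adds
            # the residual, and normalizes with norm1 (and with norm2 via weight1/bias1 when
            # the norms are not tied).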
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias)
                if not self.tied_norm
                else (None, None)
            )
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm),
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                (hidden_states2,) = rest
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual