mha.py
# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/Dao-AILab/flash-attention/pull/556

import math
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None

        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
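

# Usage sketch (illustrative, not part of the upstream module): exercises both
# FlashSelfAttention code paths. Assumes a CUDA device with flash-attn
# installed; the batch sizes, lengths, and head dims below are arbitrary.
def _example_flash_self_attention():
    attn = FlashSelfAttention(causal=True)
    # Padded path: qkv is (batch, seqlen, 3, nheads, head_dim).
    qkv = torch.randn(2, 128, 3, 8, 64, device="cuda", dtype=torch.float16)
    out = attn(qkv)  # (2, 128, 8, 64)
    # Varlen path: sequences packed into one (total, 3, nheads, head_dim) tensor,
    # indexed by int32 cumulative lengths.
    lengths = torch.tensor([100, 72], device="cuda", dtype=torch.int32)
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(lengths, 0, dtype=torch.int32), (1, 0))
    qkv_packed = torch.randn(172, 3, 8, 64, device="cuda", dtype=torch.float16)
    out_packed = attn(qkv_packed, cu_seqlens=cu_seqlens, max_seqlen=100)  # (172, 8, 64)
    return out, out_packed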


class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None

        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
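

# Usage sketch (illustrative, not part of the upstream module): cross-attention
# with more kv timesteps than query timesteps, plus a GQA variant where 8 query
# heads share 2 kv heads (flash_attn_kvpacked_func broadcasts kv heads when H_k
# divides H). Assumes CUDA and flash-attn; shapes are arbitrary.
def _example_flash_cross_attention():
    xattn = FlashCrossAttention()
    q = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)
    kv = torch.randn(2, 128, 2, 8, 64, device="cuda", dtype=torch.float16)
    out = xattn(q, kv)  # (2, 16, 8, 64)
    # GQA: only 2 kv heads for 8 query heads.
    kv_gqa = torch.randn(2, 128, 2, 2, 64, device="cuda", dtype=torch.float16)
    out_gqa = xattn(q, kv_gqa)  # still (2, 16, 8, 64)
    return out, out_gqa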


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
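

# Usage sketch (illustrative, not part of the upstream module): the pure
# PyTorch fallback runs on CPU in fp32, which makes it handy for correctness
# checks against the flash path. Shapes are arbitrary.
def _example_self_attention():
    attn = SelfAttention(causal=True)
    qkv = torch.randn(2, 10, 3, 4, 16)  # (B, S, 3, H, D)
    mask = torch.ones(2, 10, dtype=torch.bool)
    mask[1, 6:] = False  # second sequence is padding beyond position 6
    out = attn(qkv, key_padding_mask=mask)  # (2, 10, 4, 16)
    return out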


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
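

# Usage sketch (illustrative, not part of the upstream module): GQA-style
# cross-attention on CPU; the 2 kv heads are expanded to the 8 query heads
# inside forward, and the causal mask is aligned to the bottom-right corner
# when Sq != Sk. Shapes are arbitrary.
def _example_cross_attention():
    xattn = CrossAttention(causal=True)
    q = torch.randn(2, 6, 8, 16)  # (B, Sq, H, D)
    kv = torch.randn(2, 10, 2, 2, 16)  # (B, Sk, 2, H_k, D) with H_k = 2
    out = xattn(q, kv)  # (2, 6, 8, 16)
    return out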


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        return super().forward(input), input
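

# Usage sketch (illustrative, not part of the upstream module): the layer
# returns (output, input) so a caller can fuse the residual add with the
# linear backward, mirroring FusedDense(return_residual=True).
def _example_linear_residual():
    lin = LinearResidual(16, 32)
    x = torch.randn(4, 16)
    y, residual = lin(x)  # y: (4, 32); residual is x, unchanged
    return y, residual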


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]
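

# Usage sketch (illustrative, not part of the upstream module): a two-step
# prefill-then-decode cache update. SimpleNamespace is a stand-in for the real
# InferenceParams object; only the attributes this helper reads are provided.
def _example_update_kv_cache():
    from types import SimpleNamespace

    params = SimpleNamespace(
        max_batch_size=2,
        max_seqlen=32,
        key_value_memory_dict={},
        batch_size_offset=0,
        seqlen_offset=0,
    )
    kv_prompt = torch.randn(2, 4, 2, 8, 64)  # prompt of length 4
    cache = _update_kv_cache(kv_prompt, params, layer_idx=0)  # (2, 4, 2, 8, 64)
    params.seqlen_offset += 4
    kv_step = torch.randn(2, 1, 2, 8, 64)  # one decoded token
    cache = _update_kv_cache(kv_step, params, layer_idx=0)  # (2, 5, 2, 8, 64)
    return cache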


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for post-norm architectures, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        # This adaptation does not construct rotary embeddings, but the decode
        # paths in forward() still check self.rotary_emb_dim, so define it
        # explicitly to avoid an AttributeError during generation.
        self.rotary_emb_dim = 0

        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines three steps: apply rotary to Q and K, update the kv cache,
        and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=False,
            alibi_slopes=None,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=None,
            )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g. ViT, where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
                https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        # Note: seqlen_offset and rotary_max_seqlen are kept from the rotary
        # code path this module was adapted from; they are unused below.
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = (
            inference_params.max_sequence_len if inference_params is not None else max_seqlen
        )
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
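

# Usage sketch (illustrative, not part of the upstream module): MHA with the
# pure-PyTorch inner attention (use_flash_attn=False) runs on CPU; the GQA
# variant routes through CrossAttention internally. Dimensions are arbitrary.
def _example_mha():
    mha = MHA(embed_dim=512, num_heads=8, causal=True, layer_idx=0)
    x = torch.randn(2, 16, 512)
    out = mha(x)  # (2, 16, 512)
    # GQA: 8 query heads share 2 kv heads, shrinking the Wqkv projection.
    gqa = MHA(embed_dim=512, num_heads=8, num_heads_kv=2, causal=True)
    out_gqa = gqa(x)  # (2, 16, 512)
    return out, out_gqa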