# mha.py
# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/Dao-AILab/flash-attention/pull/556

import math
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns
        -------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None

        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )

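# A minimal usage sketch, assuming flash-attn is installed and tensors live on a
# CUDA device: how FlashSelfAttention is called on padded vs. variable-length
# batches. The concrete sizes are illustrative assumptions; shapes follow the
# forward() docstring above.
#
#     attn = FlashSelfAttention(causal=True)
#     qkv = torch.randn(2, 128, 3, 8, 64, device="cuda", dtype=torch.float16)
#     out = attn(qkv)  # (2, 128, 8, 64)
#
#     # Varlen path: sequences of lengths 100 and 28 packed into one tensor;
#     # cu_seqlens is the prefix sum of the lengths with a leading 0.
#     lengths = torch.tensor([100, 28], dtype=torch.int32, device="cuda")
#     cu_seqlens = torch.nn.functional.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))
#     qkv_packed = torch.randn(128, 3, 8, 64, device="cuda", dtype=torch.float16)
#     out = attn(qkv_packed, cu_seqlens=cu_seqlens, max_seqlen=100)  # (128, 8, 64)
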
class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None

        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size = q.shape[0]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )

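# A minimal usage sketch for the MQA/GQA case, assuming flash-attn is installed:
# kv may carry fewer heads (H_k) than q (H); the kernel broadcasts the kv heads
# as long as H is a multiple of H_k. Sizes below are illustrative assumptions.
#
#     xattn = FlashCrossAttention()
#     q = torch.randn(2, 128, 32, 64, device="cuda", dtype=torch.bfloat16)     # (B, Sq, H, D)
#     kv = torch.randn(2, 512, 2, 8, 64, device="cuda", dtype=torch.bfloat16)  # (B, Sk, 2, H_k, D)
#     out = xattn(q, kv)  # (2, 128, 32, 64)
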
class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output

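# A small sketch of the key_padding_mask semantics on this reference (non-flash)
# path: True keeps a position, False masks it out. Sizes are illustrative; unlike
# the flash classes above, this path also runs on CPU and in float32.
#
#     attn = SelfAttention(causal=False)
#     qkv = torch.randn(2, 6, 3, 4, 16)
#     mask = torch.ones(2, 6, dtype=torch.bool)
#     mask[1, 4:] = False  # second sequence has only 4 valid tokens
#     out = attn(qkv, key_padding_mask=mask)  # (2, 6, 4, 16)
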
class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output

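# Worked example of the causal-mask alignment above (assuming no padding): with
# seqlen_q = 2 and seqlen_k = 5, sk = 5, so query row t may attend to key
# columns j with j <= t + sk - seqlen_q = t + 3. The mask (True = blocked) is
#
#     col j:   0      1      2      3      4
#     t = 0 [False, False, False, False,  True]
#     t = 1 [False, False, False, False, False]
#
# i.e. the causal diagonal is aligned to the bottom-right corner of the score
# matrix, matching FlashAttention's convention when seqlen_q != seqlen_k.
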
class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor):
        return super().forward(input), input

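# Sketch: the (output, input) return lets a post-norm block fuse the residual
# add with the matmul backward when the FusedDense variant is used instead.
#
#     proj = LinearResidual(64, 64)
#     y, residual = proj(torch.randn(2, 8, 64))
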
def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert kv_cache is not None
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]

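# A minimal sketch of the inference_params object this helper reads. The real
# class is flash_attn.utils.generation.InferenceParams; this SimpleNamespace
# stand-in (an assumption for illustration) only carries the attributes used here.
#
#     from types import SimpleNamespace
#     params = SimpleNamespace(
#         max_batch_size=2,
#         max_seqlen=256,
#         seqlen_offset=0,           # number of tokens already cached
#         batch_size_offset=0,
#         key_value_memory_dict={},  # layer_idx -> preallocated (B, S, 2, H_k, D) cache
#         lengths_per_sample=None,
#     )
#     kv_step = torch.randn(2, 1, 2, 8, 64)
#     kv_so_far = _update_kv_cache(kv_step, params, layer_idx=0)  # (2, 1, 2, 8, 64)
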
class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for a post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        # Rotary embeddings are not wired up in this adaptation; keep the attribute
        # at 0 so the rotary fast-path checks in forward() stay valid.
        self.rotary_emb_dim = 0

        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

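    # Dimension bookkeeping sketch for MQA/GQA (illustrative numbers): with
    # embed_dim=4096, num_heads=32, num_heads_kv=8 we get head_dim=128,
    # qkv_dim = 128 * (32 + 2 * 8) = 6144 and kv_dim = 2 * 128 * 8 = 2048, so
    # the packed Wqkv output is laid out as [q (4096) | k (1024) | v (1024)]:
    #
    #     mha = MHA(embed_dim=4096, num_heads=32, num_heads_kv=8, causal=True)
    #     assert mha.Wqkv.out_features == 6144
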
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=False,
            alibi_slopes=None,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=None,
            )

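    # Decode-step sketch (hypothetical shapes, assuming an `mha` and `params` as
    # in the sketches above): once seqlen_offset > 0 and flash-attn is available,
    # flash_attn_with_kvcache appends the new k/v at cache_seqlens and attends
    # against the cache in a single kernel, so no separate cache write is needed;
    # otherwise we fall back to _update_kv_cache plus inner_cross_attn over the
    # materialized cache.
    #
    #     q_step = torch.randn(2, 1, 32, 128, device="cuda", dtype=torch.float16)
    #     kv_step = torch.randn(2, 1, 2, 8, 128, device="cuda", dtype=torch.float16)
    #     ctx = mha._update_kvcache_attention(q_step, kv_step, params)  # (2, 1, 32, 128)
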
    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
                https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        # Kept for parity with the upstream rotary code path; unused in this adaptation.
        rotary_max_seqlen = (
            inference_params.max_seqlen if inference_params is not None else max_seqlen
        )
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            # The depthwise convs operate on (b, s, d) tensors, so apply them
            # before splitting the head dimension out.
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
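
if __name__ == "__main__":
    # Minimal smoke test: a sketch exercising the pure-PyTorch reference paths,
    # which need neither CUDA nor a flash-attn install. Sizes are illustrative
    # assumptions.
    torch.manual_seed(0)

    # Causal self-attention with equal q/kv heads.
    mha = MHA(embed_dim=256, num_heads=8, causal=True)
    x = torch.randn(2, 16, 256)
    print(mha(x).shape)  # torch.Size([2, 16, 256])

    # Cross-attention with grouped kv heads (GQA) on the reference path.
    xattn = MHA(embed_dim=256, num_heads=8, num_heads_kv=2, cross_attn=True)
    print(xattn(x, x_kv=torch.randn(2, 32, 256)).shape)  # torch.Size([2, 16, 256])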