mask_sdpa_utils.py
| 1 | # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. |
| 2 | # |
| 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property |
| 4 | # and proprietary rights in and to this software, related documentation |
| 5 | # and any modifications thereto. Any use, reproduction, disclosure or |
| 6 | # distribution of this software and related documentation without an express |
| 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. |
| 8 | |
| 9 | import torch |
| 10 | |
| 11 | |
def find_prefix_seq_length_by_pe(
    pe: torch.Tensor
) -> torch.Tensor:
    """
    Find the index of the first position-encoding drop in each sequence.

    A drop is an index `i` with `pe[b, i] < pe[b, i - 1]`; the first such index
    is reported as the prefix boundary of sequence `b`.

    Args:
        pe: Position encoding tensor of shape [batch_size, seq_len]
            containing position indices for each token in the sequence.

    Returns:
        torch.Tensor: A long tensor of shape [batch_size], allocated on the
        same device as `pe`, containing:
            - The index of the first drop for each sequence
            - -1 if no drop occurs in the sequence
    """
    batch_size, num_positions = pe.shape
    no_drop = torch.full((batch_size,), -1, dtype=torch.long, device=pe.device)

    # With fewer than two positions there is nothing to compare (and argmax
    # over an empty dimension would raise).
    if num_positions < 2:
        return no_drop

    drop_mask = pe[:, 1:] < pe[:, :-1]  # [batch_size, seq_len - 1]

    # argmax returns the first maximal element, i.e. the first True per row.
    # +1 because drop_mask[:, j] compares positions j and j + 1.
    first_drop = drop_mask.to(torch.long).argmax(dim=1) + 1

    # Rows without any drop have an all-zero mask (argmax == 0); report -1 there.
    return torch.where(drop_mask.any(dim=1), first_drop, no_drop)
| 39 | |
| 40 | |
| 41 | |
def update_causal_mask_with_pad_non_visible_2d(
    input_ids: torch.Tensor,
    attn_mask_2d: torch.Tensor,
    text_mask_token_id: int,
    block_size: int = 4,
    causal_attn: bool = False
) -> torch.Tensor:
    """
    Rewrite a 2D attention mask around the "hole" positions found in `input_ids`.

    A hole column is a position that holds `text_mask_token_id` or that sits
    immediately before such a position. For every hole column `c`:
        - rows at or after the next non-hole position are set to -inf, provided
          `c` is greater than the last non-hole index seen up to `c` (0 when
          none has been seen);
        - unless `causal_attn` is True, rows `r` with
          `last non-hole index < r < c` are set to 0.0.

    Args:
        input_ids: 1D tensor of token IDs; its first dimension gives seq_len.
        attn_mask_2d: 2D attention mask matrix of shape [seq_len, seq_len] where:
            - 0.0 indicates allowed attention
            - -inf indicates masked attention
            It is modified in place.
        text_mask_token_id: ID representing masked tokens
        block_size: Not used by the current implementation.
        causal_attn: If True, skip the step that re-opens attention to later
            hole columns.

    Returns:
        The same `attn_mask_2d` tensor, updated in place.
    """
    length = input_ids.shape[0]
    positions = torch.arange(length, device=input_ids.device)

    # Hole columns: masked tokens plus the token right before each of them.
    is_masked = input_ids.eq(text_mask_token_id)
    precedes_masked = torch.zeros_like(is_masked)
    precedes_masked[:-1] = is_masked[1:]
    hole_cols = is_masked | precedes_masked
    kept = ~hole_cols

    # Largest non-hole index at or before each position (0 if there is none).
    last_kept = torch.cummax(positions * kept, dim=0).values

    # Smallest non-hole index at or after each position (dtype max if none),
    # computed as a running minimum over the reversed sequence.
    sentinel = torch.iinfo(positions.dtype).max
    kept_or_sentinel = torch.where(kept, positions, torch.full_like(positions, sentinel))
    next_kept = kept_or_sentinel.flip(0).cummin(dim=0).values.flip(0)

    row_idx = positions.unsqueeze(1)      # [seq_len, 1]
    hole_row = hole_cols.unsqueeze(0)     # [1, seq_len]

    hidden = (
        hole_row &
        (positions > last_kept) &
        (row_idx >= next_kept.unsqueeze(0))
    )
    attn_mask_2d.masked_fill_(hidden, -float('inf'))

    if causal_attn:
        return attn_mask_2d

    reopened = (
        hole_row &
        (row_idx > last_kept.unsqueeze(0)) &
        (row_idx < positions)
    )
    attn_mask_2d.masked_fill_(reopened, 0.0)
    return attn_mask_2d
| 102 | |
| 103 | |
def update_causal_mask_for_one_gen_window_2d(
    input_ids: torch.Tensor,
    attn_mask_2d: torch.Tensor,
    block_size: int = 4,
    use_cache: bool = True,
    causal_attn: bool = False
) -> torch.Tensor:
    """
    Updates a 2D attention mask for a diffusion window in transformer inference.

    Args:
        input_ids: Input token IDs (unused in current implementation)
        attn_mask_2d: 2D attention mask matrix where:
            - 0.0 indicates allowed attention
            - -inf indicates masked attention
            It is modified in place.
        block_size: Size of the diffusion window (the last `block_size` rows/columns)
        use_cache: Whether key-value cache is being used. When True, the column
            immediately before the window is hidden from the window rows.
        causal_attn: If True, the window is not opened up for bidirectional attention

    Returns:
        The same `attn_mask_2d` tensor, updated in place.
    """
    if not causal_attn:
        # Make the diffusion window (last block_size tokens) fully visible to itself.
        # This allows bidirectional attention within the diffusion window.
        attn_mask_2d[-block_size:, -block_size:] = 0.0

    # NOTE(review): this step is applied regardless of `causal_attn` — confirm
    # that it is not meant to be limited to the non-causal case.
    # The column at -block_size - 1 only exists when the mask is wider than the
    # window; without the guard a mask of exactly block_size columns raises
    # IndexError.
    if use_cache and attn_mask_2d.shape[1] > block_size:
        # Mask the last token from previous round to prevent recomputation and
        # maintain generation consistency.
        attn_mask_2d[-block_size:, -block_size - 1] = -float('inf')

    return attn_mask_2d
| 136 | |
| 137 | |
def create_block_diff_mask_by_pe_4d(
    block_size: int,
    x0_len_list: torch.Tensor,
    position_ids: torch.Tensor,
    causal_attn: bool = False,
    dtype: torch.dtype = torch.bfloat16
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generates a 4D attention mask for block-difference attention patterns.

    The mask consists of three regions:
    1. Causal block (top-left): Standard causal attention for `x0` tokens.
    2. Mutual block (bottom-right): Non-causal attention within the same block for non-`x0` tokens.
    3. Prefix block (bottom-left): Non-`x0` tokens can attend to a prefix of `x0` tokens. The
       prefix length for a query is the position ID found at the start of the query's block.

    Args:
        block_size (int): Size of processing blocks for non-`x0` tokens.
        x0_len_list (torch.Tensor): Integer tensor of shape [B] containing lengths of `x0`
            segments per batch. It is moved to the device of `position_ids` and cast to long.
        position_ids (torch.Tensor): Tensor of shape [B, seq_len] containing position IDs.
        causal_attn (bool, optional): If True, enforces causal masking in mutual blocks. Defaults to False.
        dtype (torch.dtype, optional): Floating dtype of the returned additive mask.
            Defaults to torch.bfloat16.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - A float mask of shape [batch_size, 1, seq_len, seq_len] with 0.0 for allowed
              positions and `-inf` for masked (non-visible) positions.
            - A boolean mask of shape [batch_size, 1, seq_len, seq_len] indicating allowed attention positions.
    """
    batch_size, seq_len = position_ids.shape
    device = position_ids.device

    q_idx = torch.arange(seq_len, device=device).view(1, seq_len, 1)   # [1, seq_len, 1]
    kv_idx = torch.arange(seq_len, device=device).view(1, 1, seq_len)  # [1, 1, seq_len]

    # gather() below needs int64 indices on the same device as position_ids.
    x0_len = x0_len_list.to(device=device, dtype=torch.long).reshape(batch_size, 1, 1)

    q_is_x0 = q_idx < x0_len    # [batch_size, seq_len, 1]
    kv_is_x0 = kv_idx < x0_len  # [batch_size, 1, seq_len]

    # Block index of each token relative to the end of its x0 segment.
    q_block_idx = torch.div(q_idx - x0_len, block_size, rounding_mode='floor')    # [batch_size, seq_len, 1]
    kv_block_idx = torch.div(kv_idx - x0_len, block_size, rounding_mode='floor')  # [batch_size, 1, seq_len]

    # Causal block (top-left)
    block_causal = q_is_x0 & kv_is_x0 & (q_idx >= kv_idx)

    # Mutual block (bottom-right)
    block_mutual = ~q_is_x0 & ~kv_is_x0 & (q_block_idx == kv_block_idx)
    if causal_attn:
        block_mutual = block_mutual & (q_idx >= kv_idx)

    # Prefix block (bottom-left): look up the position ID at the first token of each
    # query's block; the index is clamped so gather() stays in range for every row.
    q_blk_start = (
        x0_len.view(batch_size, 1) + q_block_idx.squeeze(2) * block_size
    ).clamp(min=0, max=seq_len - 1)
    prefix_len = position_ids.gather(1, q_blk_start).unsqueeze(2)  # [batch_size, seq_len, 1]
    block_prefix = ~q_is_x0 & kv_is_x0 & (kv_idx < prefix_len)

    final_mask = block_causal | block_mutual | block_prefix
    customized_mask = torch.full(final_mask.shape, float('-inf'), dtype=dtype, device=device)
    customized_mask.masked_fill_(final_mask, 0.0)

    return customized_mask.unsqueeze(1), final_mask.unsqueeze(1)
| 199 | |
| 200 | |
def find_pred_pos_from_input_ids(
    input_ids: torch.LongTensor = None,
    text_mask_token_id: int = None,
) -> torch.Tensor:
    """Compute the relative prediction positions for masked tokens in a sequence.

    For non-masked positions, the output is 0. For masked positions, the value increments
    by 1 for each consecutive mask token, indicating how many steps ahead the prediction is.
    Index 0 is never counted: it is always 0 and acts as a reset point even when it holds
    the mask token.

    Args:
        input_ids (torch.LongTensor): Input token IDs of shape [batch_size, seq_len]. Required.
        text_mask_token_id (int): Token ID representing masked positions. Required.

    Returns:
        torch.Tensor: An int8 tensor of shape [batch_size, seq_len] on the device of
        `input_ids` where:
            - 0 indicates a non-masked token (or index 0).
            - n > 0 indicates the nth consecutive masked token (e.g., 1 = first mask, 2 = second mask, etc.).
        Because the result is int8, runs longer than 127 masked tokens do not fit.
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    if seq_len == 0:
        return torch.zeros((batch_size, seq_len), dtype=torch.int8, device=device)

    # `==` yields a fresh tensor, so it is safe to edit in place.
    counted = (input_ids == text_mask_token_id)
    counted[:, :1] = False  # index 0 never contributes to a run

    positions = torch.arange(seq_len, device=device)

    # Index of the most recent non-counted position at or before each token:
    # counted positions contribute 0, so the running max skips over them.
    last_reset = (positions * ~counted).cummax(dim=1).values

    # Distance to that reset point equals the length of the current mask run.
    return (positions - last_reset).to(torch.int8)
| 233 | |