mask_sdpa_utils.py
9.0 KB · 233 lines · python Raw
1 # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2 #
3 # NVIDIA CORPORATION and its licensors retain all intellectual property
4 # and proprietary rights in and to this software, related documentation
5 # and any modifications thereto. Any use, reproduction, disclosure or
6 # distribution of this software and related documentation without an express
7 # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
9 import torch
10
11
def find_prefix_seq_length_by_pe(
    pe: torch.Tensor
) -> torch.Tensor:
    """
    Find, per sequence, the first index at which the position encoding drops.

    A "drop" at index ``i`` means ``pe[b, i] < pe[b, i - 1]``; the first such
    index is treated as the prefix boundary.

    Args:
        pe: Position encoding tensor of shape [batch_size, seq_len] containing
            the position index of each token.

    Returns:
        torch.Tensor: A ``torch.long`` tensor of shape [batch_size] holding
            - the index of the first drop for each sequence, or
            - -1 if the sequence has no drop (including seq_len < 2).
        The result is allocated without an explicit device (i.e. on the default
        device), independent of ``pe.device``.
    """
    batch_size, num_positions = pe.shape

    # Allocated without a device on purpose: callers have always received the
    # result on the default device, regardless of where `pe` lives.
    result = torch.full((batch_size,), -1, dtype=torch.long)
    if num_positions < 2:
        # No adjacent pair to compare, hence no drop can exist.
        return result

    # drop_mask[b, j] is True when position j + 1 is smaller than position j.
    drop_mask = pe[:, 1:] < pe[:, :-1]  # [batch_size, seq_len - 1]

    # Candidate boundary index for every comparison (+1 because the comparison
    # is between shifted views); non-drops get the out-of-range sentinel
    # `num_positions` so that a row-wise min yields the first real drop.
    candidate_idx = torch.arange(1, num_positions, device=pe.device).expand(batch_size, -1)
    first_drop = candidate_idx.masked_fill(~drop_mask, num_positions).min(dim=1).values

    # Rows that still hold the sentinel had no drop at all.
    first_drop = first_drop.masked_fill(first_drop == num_positions, -1)

    # copy_ handles the (possible) device transfer in a single step.
    result.copy_(first_drop)
    return result
39
40
41
def update_causal_mask_with_pad_non_visible_2d(
    input_ids: torch.Tensor,
    attn_mask_2d: torch.Tensor,
    text_mask_token_id: int,
    block_size: int = 4,
    causal_attn: bool = False
) -> torch.Tensor:
    """
    Adjust a 2D attention mask around the mask-token "holes" found in input_ids.

    A "hole column" is a key position that either holds ``text_mask_token_id``
    or sits immediately before such a token. ``attn_mask_2d`` is modified in
    place and also returned.

    Args:
        input_ids: 1D tensor of token IDs; its length defines seq_len.
        attn_mask_2d: 2D attention mask matrix of shape [seq_len, seq_len] where:
            - 0.0 indicates allowed attention
            - -inf indicates masked attention
        text_mask_token_id: ID representing masked tokens
        block_size: Size of the diffusion window (not used by this function)
        causal_attn: If True, only the -inf update is applied and no extra
            positions are opened up.

    Returns:
        The same ``attn_mask_2d`` tensor after the in-place updates.
    """
    length = input_ids.shape[0]
    positions = torch.arange(length, device=input_ids.device)  # key index, [seq_len]
    query_pos = positions.unsqueeze(1)                          # query index, [seq_len, 1]

    # Hole columns: mask tokens plus the token right before each mask token.
    is_mask_token = input_ids.eq(text_mask_token_id)
    precedes_mask = torch.zeros_like(is_mask_token)
    precedes_mask[:-1] = is_mask_token[1:]
    hole_cols = is_mask_token | precedes_mask

    # For every column: index of the closest non-hole column at or before it
    # (0 when there is none) ...
    last_kept = positions.masked_fill(hole_cols, 0).cummax(dim=0).values

    # ... and index of the closest non-hole column at or after it
    # (dtype max when there is none), via a reversed running minimum.
    sentinel = torch.iinfo(positions.dtype).max
    next_kept = (
        positions.masked_fill(hole_cols, sentinel)
        .flip(0)
        .cummin(dim=0)
        .values
        .flip(0)
    )

    # Hide a hole column from every query located at or after the next
    # non-hole position that follows the hole.
    hidden = (
        hole_cols[None, :]
        & (positions > last_kept)
        & (query_pos >= next_kept[None, :])
    )
    attn_mask_2d.masked_fill_(hidden, -float('inf'))

    if causal_attn:
        return attn_mask_2d

    # Non-causal mode: queries strictly between the last non-hole position and
    # the hole column itself may look ahead at that column.
    opened = (
        hole_cols[None, :]
        & (query_pos > last_kept[None, :])
        & (query_pos < positions)
    )
    attn_mask_2d.masked_fill_(opened, 0.0)
    return attn_mask_2d
102
103
def update_causal_mask_for_one_gen_window_2d(
    input_ids: torch.Tensor,
    attn_mask_2d: torch.Tensor,
    block_size: int = 4,
    use_cache: bool = True,
    causal_attn: bool = False
) -> torch.Tensor:
    """
    Updates a 2D attention mask for a diffusion window in transformer inference.

    The diffusion window is the trailing ``block_size`` rows/columns of the
    mask. ``attn_mask_2d`` is modified in place and also returned.

    Args:
        input_ids: Input token IDs (unused in current implementation)
        attn_mask_2d: 2D attention mask matrix where:
            - 0.0 indicates allowed attention
            - -inf indicates masked attention
        block_size: Size of the diffusion window
        use_cache: Whether key-value cache is being used. When True, the column
            immediately before the window is masked for all window rows.
        causal_attn: If True, the window is not opened up for bidirectional
            attention.

    Returns:
        The same ``attn_mask_2d`` tensor after the in-place updates.
    """
    if not causal_attn:
        # Make the diffusion window (last block_size tokens) fully visible to
        # itself, i.e. bidirectional attention inside the window.
        attn_mask_2d[-block_size:, -block_size:] = 0.0

    # NOTE(review): this step is applied regardless of `causal_attn`; confirm
    # that this matches the intended behaviour for causal decoding.
    # The column right before the window only exists when the mask is wider
    # than the window; without this guard a mask that consists solely of the
    # window would raise IndexError.
    if use_cache and attn_mask_2d.shape[-1] > block_size:
        # Mask the last token from the previous round to prevent recomputation
        # and maintain generation consistency.
        attn_mask_2d[-block_size:, -block_size - 1] = -float('inf')

    return attn_mask_2d
136
137
def create_block_diff_mask_by_pe_4d(
    block_size: int,
    x0_len_list: torch.Tensor,
    position_ids: torch.Tensor,
    causal_attn: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generates a 4D attention mask for block-difference attention patterns.

    The mask consists of three regions:
    1. Causal block (top-left): Standard causal attention for `x0` tokens.
    2. Mutual block (bottom-right): Attention within the same block for non-`x0` tokens.
    3. Prefix block (bottom-left): Non-`x0` tokens can attend to a prefix of `x0` tokens.
       The prefix length for a query is the position ID stored at the first
       index of that query's block.

    Args:
        block_size (int): Size of processing blocks for non-`x0` tokens.
        x0_len_list (torch.Tensor): Lengths of the `x0` segments, one per batch
            element (shape [B]). A plain Python sequence of ints is accepted as
            well; the values are moved to `position_ids.device` if necessary.
        position_ids (torch.Tensor): Tensor of shape [B, seq_len] containing position IDs.
        causal_attn (bool, optional): If True, enforces causal masking in mutual blocks.
            Defaults to False.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - A bfloat16 mask of shape [batch_size, 1, seq_len, seq_len] with 0.0 for
              allowed positions and `-inf` for masked (non-visible) positions.
            - A boolean mask of shape [batch_size, 1, seq_len, seq_len] indicating
              allowed attention positions.
    """
    batch_size, seq_len = position_ids.shape
    device = position_ids.device

    # Align the lengths with position_ids so that the comparisons below never
    # mix devices (e.g. CPU lengths with CUDA position IDs).
    x0_len = torch.as_tensor(x0_len_list, device=device).view(batch_size, 1, 1)

    # Query / key indices, shaped for broadcasting to [B, seq_len, seq_len].
    q_idx = torch.arange(seq_len, device=device).view(1, seq_len, 1)
    kv_idx = torch.arange(seq_len, device=device).view(1, 1, seq_len)

    # True where the query / key belongs to the x0 segment.
    x0_flag_q = q_idx < x0_len    # [B, seq_len, 1]
    x0_flag_kv = kv_idx < x0_len  # [B, 1, seq_len]

    # Block number of each index counted from the end of x0 (floor division,
    # so indices inside x0 get negative block numbers).
    q_block_idx = torch.div(q_idx - x0_len, block_size, rounding_mode='floor')
    kv_block_idx = torch.div(kv_idx - x0_len, block_size, rounding_mode='floor')

    # 1. Causal block (top-left).
    block_causal = x0_flag_q & x0_flag_kv & (q_idx >= kv_idx)

    # 2. Mutual block (bottom-right): same block, optionally causal inside it.
    block_mutual = ~x0_flag_q & ~x0_flag_kv & (q_block_idx == kv_block_idx)
    if causal_attn:
        block_mutual = block_mutual & (q_idx >= kv_idx)

    # 3. Prefix block (bottom-left): look up the position ID at the first index
    # of each query's block; clamping keeps the gather index in range.
    q_blk_start = (x0_len + q_block_idx * block_size).clamp(min=0, max=seq_len - 1)
    prefix_len = position_ids.gather(1, q_blk_start.squeeze(2)).unsqueeze(2)  # [B, seq_len, 1]
    block_prefix = (~x0_flag_q & x0_flag_kv) & (kv_idx < prefix_len)

    final_mask = block_causal | block_mutual | block_prefix
    customized_mask = torch.full_like(final_mask, float('-inf'), dtype=torch.bfloat16)
    customized_mask.masked_fill_(final_mask, 0.0)

    return customized_mask.unsqueeze(1), final_mask.unsqueeze(1)
199
200
def find_pred_pos_from_input_ids(
    input_ids: torch.LongTensor = None,
    text_mask_token_id: int = None,
) -> torch.Tensor:
    """Compute the relative prediction positions for masked tokens in a sequence.

    For non-masked positions, the output is 0. For masked positions, the value increments
    by 1 for each consecutive mask token, indicating how many steps ahead the prediction is.
    Position 0 of every sequence is always reported as 0, even if it holds the mask token;
    a mask run starting there is therefore counted from position 1.

    Args:
        input_ids (torch.LongTensor): Input token IDs of shape [batch_size, seq_len].
        text_mask_token_id (int): Token ID representing masked positions.

    Returns:
        torch.Tensor: An int8 tensor of shape [batch_size, seq_len] on `input_ids.device` where:
            - 0 indicates a non-masked token (or position 0).
            - n > 0 indicates the nth consecutive masked token (e.g., 1 = first mask, 2 = second mask, etc.).
        Because the result is int8, runs longer than 127 mask tokens do not fit the value range.
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    if seq_len == 0:
        return torch.zeros((batch_size, seq_len), dtype=torch.int8, device=device)

    is_mask = input_ids.eq(text_mask_token_id)
    # Position 0 has no predecessor to count from, so it never carries a count
    # and acts as the anchor for a run that starts at the sequence beginning.
    is_mask[:, 0] = False

    # For every position, find the index of the closest non-masked position at
    # or before it; the distance to that anchor is the length of the mask run
    # so far (0 for non-masked positions themselves).
    positions = torch.arange(seq_len, device=device).expand(batch_size, seq_len)
    run_anchor = positions.masked_fill(is_mask, 0).cummax(dim=1).values

    return (positions - run_anchor).to(torch.int8)
233