# xlm_padding.py
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
# Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b

# Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply
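

# Usage sketch (illustrative, not part of the upstream module): index_first_axis
# behaves like `input[indices]` along the first axis, with a custom backward that
# scatters gradients only into the gathered rows.
def _example_index_first_axis():
    x = torch.randn(6, 4, requires_grad=True)
    idx = torch.tensor([0, 2, 5])
    out = index_first_axis(x, idx)  # (3, 4): the rows x[0], x[2], x[5]
    assert torch.equal(out, x[idx])
    out.sum().backward()  # x.grad is 1 at the gathered rows, 0 elsewhere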


class IndexPutFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(
            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply
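

# Usage sketch (illustrative, not part of the upstream module): index_put_first_axis
# is the inverse of index_first_axis -- it scatters rows back into a zero-filled
# buffer of the original first-axis size.
def _example_index_put_first_axis():
    values = torch.randn(3, 4)
    idx = torch.tensor([0, 2, 5])
    out = index_put_first_axis(values, idx, 6)  # (6, 4); rows 1, 3, 4 stay zero
    assert torch.equal(out[idx], values)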


class IndexFirstAxisResidual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        output = input[indices]
        # We don't want to reshape input (b ... -> b (...)) since it could change the channels_last
        # memory format to channels_first. In other words, input might not be contiguous.
        # If we don't detach, PyTorch complains that output is a view and is being modified in place.
        return output, input.detach()

    @staticmethod
    def backward(ctx, grad_output, grad_residual):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
        grad_input = grad_residual
        # grad_input[indices] += grad_output
        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
        indices = indices.expand_as(grad_output)
        grad_input.scatter_add_(0, indices, grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply
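

# Usage sketch (illustrative, not part of the upstream module): the residual
# variant also returns the (detached) full input, so a caller can keep a skip
# connection around without gathering twice.
def _example_index_first_axis_residual():
    x = torch.randn(6, 4, requires_grad=True)
    idx = torch.tensor([0, 2, 5])
    out, residual = index_first_axis_residual(x, idx)
    assert torch.equal(out, x[idx])
    assert residual.shape == x.shape  # full input, detached from the graph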


def unpad_input(hidden_states, attention_mask):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because PyTorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices are
    # `dim` times larger than they need to be, wasting memory. It's faster and more memory-efficient
    # to index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
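

# Usage sketch (illustrative, not part of the upstream module): a batch of two
# sequences with lengths 3 and 1 packs into 4 token rows.
def _example_unpad_input():
    hidden = torch.randn(2, 4, 8)
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 0, 0, 0]])
    packed, indices, cu_seqlens, max_len = unpad_input(hidden, mask)
    # packed: (4, 8); indices: [0, 1, 2, 4]; cu_seqlens: [0, 3, 4]; max_len: 3
    assert packed.shape == (4, 8) and max_len == 3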


def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
    """
    Supports concatenating short samples into one sequence. attention_mask_in_length marks the
    boundaries between the short samples so that attention stays within each sample, enabling
    efficient training on variable-length samples (e.g., the supervised fine-tuning task for
    large language models).
    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).

    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length:
    ```
    [
        [2, 3, 0, 0, 0, 0],
        [3, 2, 0, 0, 0, 0],
        [6, 0, 0, 0, 0, 0]
    ]
    ```
    corresponds to the 3D attention mask:
    ```
    [
        [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
        ],
        [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
        ],
        [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1]
        ]
    ]
    ```

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask_in_length: (batch, seqlen), int. A nonzero value (e.g., 1, 2, 3) gives the
            length of a concatenated sample in the b-th batch row, and 0 means none.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask_in_length.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
    attention_mask_2d = torch.arange(
        seqlen, device=length.device, dtype=length.dtype
    ).expand(len(length), seqlen) < length.unsqueeze(1)
    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because PyTorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices are
    # `dim` times larger than they need to be, wasting memory. It's faster and more memory-efficient
    # to index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
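

# Usage sketch (illustrative, not part of the upstream module): row 0 packs two
# samples of lengths 2 and 3 (plus one pad token); row 1 is a single sample of
# length 6, so 11 token rows survive and cu_seqlens gets one entry per sample.
def _example_unpad_concatenated():
    hidden = torch.randn(2, 6, 8)
    mask_in_length = torch.tensor([[2, 3, 0, 0, 0, 0],
                                   [6, 0, 0, 0, 0, 0]])
    packed, indices, cu_seqlens, max_len = unpad_input_for_concatenated_sequences(
        hidden, mask_in_length
    )
    # packed: (11, 8); cu_seqlens: [0, 2, 5, 11]; max_len: 6
    assert packed.shape == (11, 8) and max_len == 6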


def pad_input(hidden_states, indices, batch, seqlen):
    """
    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[-1]
    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
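

# Usage sketch (illustrative, not part of the upstream module): pad_input inverts
# unpad_input -- valid positions round-trip exactly; padded positions come back as zeros.
def _example_pad_roundtrip():
    hidden = torch.randn(2, 4, 8)
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 0, 0, 0]])
    packed, indices, _, _ = unpad_input(hidden, mask)
    restored = pad_input(packed, indices, batch=2, seqlen=4)
    assert torch.equal(restored[mask.bool()], hidden[mask.bool()])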