xlm_padding.py
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
# Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b

# Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply

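
# The example below is an illustrative sketch and is not part of the adapted upstream module:
# it shows how index_first_axis gathers a subset of rows along the first axis while keeping
# gradients flowing only into the selected rows. The helper name and shapes are made up.
def _example_index_first_axis():
    x = torch.randn(6, 4, requires_grad=True)  # (first_axis_dim, hidden)
    idx = torch.tensor([0, 2, 5])  # rows to keep
    out = index_first_axis(x, idx)  # (3, 4), equivalent to x[idx] but implemented with torch.gather
    out.sum().backward()  # x.grad is nonzero only in rows 0, 2 and 5
    return out, x.grad
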

class IndexPutFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(
            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply

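
# Illustrative sketch (not part of the upstream module; helper name and shapes are made up):
# index_put_first_axis is the inverse operation of index_first_axis -- it scatters rows back
# into a zero-initialized tensor of the original first-axis size.
def _example_index_put_first_axis():
    values = torch.randn(3, 4, requires_grad=True)  # rows destined for positions 0, 2 and 5
    idx = torch.tensor([0, 2, 5])
    out = index_put_first_axis(values, idx, 6)  # (6, 4); rows 1, 3 and 4 stay zero
    out.sum().backward()  # values.grad is all ones
    return out, values.grad
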

class IndexFirstAxisResidual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        output = input[indices]
        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
        # memory format to channel_first. In other words, input might not be contiguous.
        # If we don't detach, PyTorch complains that output is a view that is being modified in place.
        return output, input.detach()

    @staticmethod
    def backward(ctx, grad_output, grad_residual):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
        grad_input = grad_residual
        # grad_input[indices] += grad_output
        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
        indices = indices.expand_as(grad_output)
        grad_input.scatter_add_(0, indices, grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply

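
# Illustrative sketch (not part of the upstream module; helper name and shapes are made up):
# the residual variant returns both the gathered rows and the full (detached) input, so a
# caller can keep a residual copy of the whole tensor without a second indexing pass.
def _example_index_first_axis_residual():
    x = torch.randn(6, 4, requires_grad=True)
    idx = torch.tensor([0, 2, 5])
    gathered, residual = index_first_axis_residual(x, idx)  # (3, 4) and the full (6, 4) tensor
    (gathered.sum() + residual.sum()).backward()  # gradients from both outputs are accumulated
    return gathered, residual, x.grad
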

def unpad_input(hidden_states, attention_mask):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )

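
# Illustrative sketch (not part of the upstream module; helper name and batch/seqlen/hidden
# sizes are made up): unpad_input packs only the valid tokens into a (total_nnz, hidden) tensor
# and returns the bookkeeping needed by varlen attention kernels and by pad_input below.
def _example_unpad_input():
    batch, seqlen, hidden = 2, 4, 8
    hidden_states = torch.randn(batch, seqlen, hidden)
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 3 and 2 valid tokens
    packed, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
    # packed: (5, 8); indices -> [0, 1, 2, 4, 5]; cu_seqlens -> [0, 3, 5]; max_seqlen -> 3
    return packed, indices, cu_seqlens, max_seqlen
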

def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
    """
    Supports concatenating short samples into one sequence. The attention_mask_in_length is used
    to mask out the other short samples packed into the same sequence. This enables efficient
    training on samples of varying lengths (e.g., the supervised fine-tuning task for large
    language models).
    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).

    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
    ```
    [
        [2, 3, 0, 0, 0, 0],
        [3, 2, 0, 0, 0, 0],
        [6, 0, 0, 0, 0, 0]
    ]
    ```
    which refers to the 3D attention mask:
    ```
    [
        [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
        ],
        [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
        ],
        [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1]
        ]
    ]
    ```

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask_in_length: (batch, seqlen), int. A nonzero value (e.g., 1, 2, 3) is the
            length of a concatenated sub-sequence in the b-th batch entry, and 0 means none.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
    attention_mask_2d = torch.arange(
        seqlen, device=length.device, dtype=length.dtype
    ).expand(len(length), seqlen) < length.unsqueeze(1)
    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )

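
# Illustrative sketch (not part of the upstream module; helper name and shapes are made up),
# reusing the docstring example above: row 0 packs sub-sequences of length 2 and 3, row 1 packs
# 3 and 2, and row 2 is a single sub-sequence of length 6.
def _example_unpad_concatenated_sequences():
    hidden_states = torch.randn(3, 6, 8)
    attention_mask_in_length = torch.tensor([
        [2, 3, 0, 0, 0, 0],
        [3, 2, 0, 0, 0, 0],
        [6, 0, 0, 0, 0, 0],
    ])
    packed, indices, cu_seqlens, max_seqlen = unpad_input_for_concatenated_sequences(
        hidden_states, attention_mask_in_length
    )
    # packed: (16, 8); cu_seqlens -> [0, 2, 5, 8, 10, 16]; max_seqlen -> 6
    return packed, indices, cu_seqlens, max_seqlen
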

def pad_input(hidden_states, indices, batch, seqlen):
    """
    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[-1]
    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
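

# Illustrative round-trip sketch (not part of the upstream module; helper name and shapes are
# made up): pad_input restores the (batch, seqlen, hidden) layout that unpad_input removed,
# with zeros at the previously padded positions.
def _example_unpad_pad_roundtrip():
    batch, seqlen, hidden = 2, 4, 8
    hidden_states = torch.randn(batch, seqlen, hidden)
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    packed, indices, _, _ = unpad_input(hidden_states, attention_mask)
    repadded = pad_input(packed, indices, batch, seqlen)  # (2, 4, 8)
    # Valid positions match the original input exactly; padded positions are zeroed out.
    assert torch.equal(repadded[attention_mask.bool()], hidden_states[attention_mask.bool()])
    return repadded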