inference_utils.py
| 1 | import torch |
| 2 | import torchaudio |
| 3 | import torch.nn.functional as F |
| 4 | from typing import Optional, List, Tuple |
| 5 | from tqdm import tqdm |
| 6 | |
| 7 | |
def apply_top_k(logits, top_k):
    """Keep only the ``top_k`` largest logits along the last dimension.

    Args:
        logits: Tensor of logits. The last dimension is the vocabulary; every
            leading dimension is treated as an independent row.
        top_k: Number of entries to keep per row; clamped to the vocabulary
            size. (``top_k <= 0`` keeps nothing, i.e. every entry is ``-inf``.)

    Returns:
        A new tensor with the same shape as ``logits`` in which every entry
        outside the per-row top-k is ``-inf``. ``logits`` is not modified.
    """
    top_k = min(top_k, logits.size(-1))
    top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1)
    # Scattering along the last dim needs no explicit batch index, so it stays
    # on logits' device and works for any number of leading dimensions.
    return torch.full_like(logits, float("-inf")).scatter(-1, top_k_indices, top_k_values)
| 16 | |
| 17 | |
def apply_top_p(logits, top_p):
    """Nucleus (top-p) filtering along the last dimension.

    For each row, tokens are ranked by probability and kept until the
    cumulative probability first exceeds ``top_p``. The token that crosses the
    threshold is kept, and the most likely token is always kept. Every other
    logit is set to ``-inf``.

    Args:
        logits: Tensor of logits; the last dimension is the vocabulary and any
            leading dimensions are treated as independent rows.
        top_p: Cumulative-probability threshold.

    Returns:
        A new filtered tensor with the same shape as ``logits``; the input is
        not modified.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Shift the "past the threshold" mask right by one so that the token which
    # pushes the cumulative mass over top_p is still kept.
    sorted_to_remove = cumulative_probs > top_p
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Map the mask from sorted order back to vocabulary order in one scatter
    # instead of a per-row Python loop.
    to_remove = torch.zeros_like(sorted_to_remove).scatter(-1, sorted_indices, sorted_to_remove)
    return logits.masked_fill(to_remove, float("-inf"))
| 31 | |
| 32 | |
def apply_top_p_optimized(logits, top_p):
    """Vectorised nucleus (top-p) filtering along the last dimension.

    Tokens are ranked by probability and kept until the cumulative probability
    first exceeds ``top_p``. The crossing token is kept, and the most likely
    token is always kept. All other logits become ``-inf``.

    Args:
        logits: Tensor of logits; the last dimension is the vocabulary.
        top_p: Cumulative-probability threshold.

    Returns:
        A new filtered tensor with the same shape as ``logits``. The input is
        left untouched, so callers passing a view of a larger tensor (as
        ``sample_token`` does via ``reshape``) do not see it modified.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Shift right by one so the token that crosses the threshold survives.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Scatter the sorted-order mask back to vocabulary order.
    indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )

    # Out-of-place fill: the previous in-place write mutated the caller's tensor.
    return logits.masked_fill(indices_to_remove, float("-inf"))
| 49 | |
| 50 | |
def _penalize_logit_values(values: torch.Tensor, penalty: float) -> torch.Tensor:
    """Divide positive ``values`` by ``penalty`` and multiply the rest by it."""
    return torch.where(values > 0, values / penalty, values * penalty)


def apply_repetition_penalty_delay_pattern(
    logits: torch.Tensor,
    prev_tokens: torch.LongTensor,
    penalty: float,
):
    """Apply a repetition penalty to logits of previously generated tokens.

    For every penalized token, a positive logit is divided by ``penalty`` and a
    non-positive logit is multiplied by it.

    Args:
        logits: ``[N, V]`` (text layer) or ``[B, H, V]`` (delay-pattern audio,
            one slice per VQ head ``H``).
        prev_tokens: Previously generated token ids. Expected shapes are
            ``[N, T]`` for 2-D logits and ``[B, T, H]`` or ``[B, H]`` for 3-D
            logits (the last dim indexes the head).
        penalty: Penalty factor; ``1.0`` is a no-op.

    Returns:
        ``logits``, modified in place (the same tensor object is returned).

    Raises:
        ValueError: if ``logits`` is neither 2-D nor 3-D.

    NOTE(review): the penalized token set is the union over *all* leading
    dimensions of ``prev_tokens`` (batch and time), and it is applied to every
    row of ``logits``. With batch size > 1 this shares penalties across batch
    items; confirm that is intended.
    """
    if penalty == 1.0 or prev_tokens is None:
        return logits

    # Case 1: regular [N, V] (text layer).
    if logits.dim() == 2:
        unique_tokens = torch.unique(prev_tokens.reshape(-1))
        logits[:, unique_tokens] = _penalize_logit_values(logits[:, unique_tokens], penalty)
        return logits

    # Case 2: delay-pattern audio [B, H, V]; each head has its own history.
    if logits.dim() != 3:
        raise ValueError("Delay Pattern audio logits must be [B, H, V]")

    for h in range(logits.size(1)):
        # prev_tokens[..., h] is [B, T] or [B]; flatten to a 1-D id list.
        unique_tokens = torch.unique(prev_tokens[..., h].reshape(-1))
        if unique_tokens.numel() == 0:
            continue
        logits[:, h, unique_tokens] = _penalize_logit_values(
            logits[:, h, unique_tokens], penalty
        )

    return logits
| 98 | |
| 99 | |
def sample_token(
    logits,
    prev_tokens: Optional[torch.LongTensor] = None,
    repetition_penalty: float = 1.0,
    top_p=None,
    top_k=None,
    do_sample=True,
):
    """Select next-token ids from ``logits``.

    Args:
        logits: Tensor whose last dimension is the vocabulary; every leading
            dimension is an independent distribution.
        prev_tokens: Previously generated ids, used only for the repetition
            penalty (see ``apply_repetition_penalty_delay_pattern`` for the
            accepted shapes). ``None`` disables the penalty.
        repetition_penalty: Penalty factor; ``1.0`` disables it.
        top_p: Nucleus threshold; ignored when ``None`` or ``>= 1.0``.
        top_k: Keep only the ``top_k`` highest logits; ignored when ``None``
            or ``<= 0``.
        do_sample: If ``False``, return the greedy argmax; ``top_k``/``top_p``
            are not applied in that case.

    Returns:
        Tensor of token ids with shape ``logits.shape[:-1]``.
    """
    vocab_size = logits.size(-1)

    # Repetition penalty runs before flattening so the helper still sees the
    # per-head [B, H, V] layout. The helper writes into `logits` in place.
    if prev_tokens is not None and repetition_penalty != 1.0:
        logits = apply_repetition_penalty_delay_pattern(
            logits,
            prev_tokens,
            repetition_penalty,
        )

    if not do_sample:
        return torch.argmax(logits, dim=-1)

    # Flatten for top-k / top-p / multinomial. `reshape` (not `view`) because
    # logits may be a non-contiguous slice or transpose of a larger tensor.
    original_shape = logits.shape
    flat_logits = logits.reshape(-1, vocab_size)

    if top_k is not None and top_k > 0:
        flat_logits = apply_top_k(flat_logits, top_k)

    if top_p is not None and top_p < 1.0:
        flat_logits = apply_top_p_optimized(flat_logits, top_p)

    probs = F.softmax(flat_logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)

    return next_tokens.reshape(original_shape[:-1])
| 135 | |
| 136 | |
def find_last_equal_C(tensor, C):
    """Return, for each row, the index of the last element equal to ``C``.

    Args:
        tensor: 2-D tensor of shape ``[batch_size, seq_len]``.
        C: Scalar value to match.

    Returns:
        ``torch.long`` tensor of shape ``[batch_size]`` holding the last
        matching column index per row, or ``-1`` for rows with no match
        (including when ``seq_len == 0``).
    """
    batch_size, seq_len = tensor.shape
    if seq_len == 0:
        # Nothing can match; a reduction over an empty dim would raise.
        return torch.full((batch_size,), -1, dtype=torch.long, device=tensor.device)

    positions = torch.arange(seq_len, device=tensor.device)
    # Non-matching positions become -1, so each row's max is the last match,
    # or -1 when the row contains no match at all.
    candidates = torch.where(tensor == C, positions, torch.full_like(positions, -1))
    return candidates.max(dim=1).values
| 155 | |