inference_utils.py
5.0 KB · 155 lines · python Raw
1 import torch
2 import torchaudio
3 import torch.nn.functional as F
4 from typing import Optional, List, Tuple
5 from tqdm import tqdm
6
7
def apply_top_k(logits, top_k):
    """Keep only the ``top_k`` largest logits along the last dimension.

    Args:
        logits: Tensor of shape ``[..., V]`` (any number of leading dims).
        top_k: Number of entries to keep per row; clamped to ``V``.

    Returns:
        A new tensor of the same shape where every entry outside the top-k of
        its row is ``-inf``. The input tensor is not modified.
    """
    vocab_size = logits.size(-1)
    top_k = min(top_k, vocab_size)
    top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1)
    # scatter along the last dim: works for any leading shape and keeps every
    # index on the same device as ``logits`` (no CPU arange index needed).
    filtered_logits = torch.full_like(logits, float("-inf"))
    return filtered_logits.scatter(-1, top_k_indices, top_k_values)
16
17
def apply_top_p(logits, top_p):
    """Nucleus (top-p) filtering along the last dimension.

    Tokens are ranked by softmax probability. A token is kept while the
    cumulative probability of the tokens ranked strictly above it does not
    exceed ``top_p``. The highest-probability token is always kept.

    Args:
        logits: Tensor of shape ``[..., V]``.
        top_p: Cumulative-probability threshold.

    Returns:
        A new tensor with removed tokens set to ``-inf``; the input is not
        modified.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    sorted_to_remove = cumulative_probs > top_p
    # Shift right by one so the first token that crosses the threshold is kept.
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Map the removal mask from sorted order back to vocabulary order in one
    # vectorized scatter (no per-row Python loop / data-dependent shapes).
    to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter(
        -1, sorted_indices, sorted_to_remove
    )
    return logits.masked_fill(to_remove, float("-inf"))
31
32
def apply_top_p_optimized(logits, top_p):
    """Vectorized nucleus (top-p) filtering along the last dimension.

    Tokens are ranked by softmax probability. A token is kept while the
    cumulative probability of the tokens ranked strictly above it does not
    exceed ``top_p``. The highest-probability token is always kept.

    Args:
        logits: Tensor of shape ``[..., V]``.
        top_p: Cumulative-probability threshold.

    Returns:
        A new tensor with removed tokens set to ``-inf``. The input tensor is
        NOT modified: callers may pass a view of their own logits (e.g. from
        ``reshape``), so filtering in place would overwrite their data.
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right by one so the first token that crosses the threshold is kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Scatter the sorted-order mask back to vocabulary order.
    indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )

    return logits.masked_fill(indices_to_remove, float("-inf"))
49
50
def apply_repetition_penalty_delay_pattern(
    logits: torch.Tensor,
    prev_tokens: torch.LongTensor,
    penalty: float,
):
    """Penalize logits of tokens that already appear in ``prev_tokens``.

    logits: [B, H, V] or [N, V]
    prev_tokens: [B, T, H] or [N, T] or [B, H]

    For every distinct token id found in ``prev_tokens``, positive logits are
    divided by ``penalty`` and non-positive logits are multiplied by it.

    - 2-D logits ([N, V], text layer): all ids in ``prev_tokens`` are pooled
      and the penalty is applied to every row of ``logits``.
    - 3-D logits ([B, H, V], delay-pattern audio): the penalty is applied
      independently for each H (VQ head), using only ``prev_tokens[..., h]``.
      Ids are pooled over the batch dimension and applied to every batch row.
      NOTE(review): with B > 1, one sample's history also penalizes other
      samples. Confirm this is intended.

    Token ids outside ``[0, V)`` are ignored. Negative ids would otherwise
    wrap around to the end of the vocabulary, and ids >= V would raise.

    Returns:
        ``logits`` unchanged (same object) when ``penalty == 1.0`` or
        ``prev_tokens`` is None; otherwise a new penalized tensor. The input
        tensor is never modified.

    Raises:
        ValueError: if ``logits`` is neither 2-D nor 3-D.
    """
    if penalty == 1.0 or prev_tokens is None:
        return logits

    if logits.dim() not in (2, 3):
        raise ValueError("Delay Pattern audio logits must be [B, H, V]")

    vocab_size = logits.size(-1)

    def _penalized_ids(tokens: torch.Tensor) -> torch.Tensor:
        # Distinct ids (each penalized once), restricted to valid vocab indices.
        ids = torch.unique(tokens.reshape(-1))
        return ids[(ids >= 0) & (ids < vocab_size)]

    def _penalize(scores: torch.Tensor) -> torch.Tensor:
        # Positive scores shrink toward 0; non-positive scores become more negative.
        return torch.where(scores > 0, scores / penalty, scores * penalty)

    # Work on a copy so the caller's logits tensor is left untouched.
    logits = logits.clone()

    # Case 1: regular [N, V] (text layer)
    if logits.dim() == 2:
        ids = _penalized_ids(prev_tokens)
        logits[:, ids] = _penalize(logits[:, ids])
        return logits

    # Case 2: Delay Pattern audio [B, H, V]
    num_heads = logits.size(1)
    for h in range(num_heads):
        # prev_tokens[..., h]: [B, T] or [B]
        ids = _penalized_ids(prev_tokens[..., h])
        if ids.numel() == 0:
            continue
        logits[:, h, ids] = _penalize(logits[:, h, ids])

    return logits
98
99
def sample_token(
    logits,
    prev_tokens: Optional[torch.LongTensor] = None,
    repetition_penalty: float = 1.0,
    top_p=None,
    top_k=None,
    do_sample=True,
):
    """Pick next tokens from ``logits`` (shape ``[..., V]``).

    Args:
        logits: Scores over the vocabulary in the last dimension, e.g.
            ``[N, V]`` or delay-pattern ``[B, H, V]``. May be non-contiguous.
        prev_tokens: Previously generated tokens used for the repetition
            penalty; ignored when None.
        repetition_penalty: Penalty factor; ``1.0`` disables it.
        top_p: Nucleus threshold; applied only when not None and ``< 1.0``.
        top_k: Top-k cutoff; applied only when not None and ``> 0``.
        do_sample: If False, return the greedy argmax; top-k / top-p are then
            not applied.

    Returns:
        Integer token ids with shape ``logits.shape[:-1]``.
    """
    vocab_size = logits.size(-1)

    # Repetition penalty runs before flattening: the delay-pattern variant
    # needs the per-head [B, H, V] layout.
    if prev_tokens is not None and repetition_penalty != 1.0:
        logits = apply_repetition_penalty_delay_pattern(
            logits,
            prev_tokens,
            repetition_penalty,
        )

    if not do_sample:
        return torch.argmax(logits, dim=-1)

    # Flatten to [M, V] for top-k / top-p / multinomial. ``reshape`` rather
    # than ``view``: logits sliced from a larger tensor (e.g. ``out[:, -1]``)
    # are often non-contiguous, and ``view`` would raise on them.
    leading_shape = logits.shape[:-1]
    flat_logits = logits.reshape(-1, vocab_size)

    if top_k is not None and top_k > 0:
        flat_logits = apply_top_k(flat_logits, top_k)

    if top_p is not None and top_p < 1.0:
        flat_logits = apply_top_p_optimized(flat_logits, top_p)

    probs = F.softmax(flat_logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)

    return next_tokens.reshape(leading_shape)
135
136
def find_last_equal_C(tensor, C):
    """Return, for each row, the index of the last element equal to ``C``.

    Args:
        tensor: torch.Tensor of shape [batch_size, seq_len].
        C: Scalar value to match.

    Returns:
        LongTensor of shape [batch_size] on ``tensor``'s device holding the
        last matching index per row, or -1 for rows with no match (including
        when ``seq_len == 0``).
    """
    batch_size, seq_len = tensor.shape[0], tensor.shape[1]
    if seq_len == 0:
        # A reduction over an empty dim would raise; no row can contain C.
        return torch.full((batch_size,), -1, dtype=torch.long, device=tensor.device)

    matches = (tensor == C).long()
    positions = torch.arange(1, seq_len + 1, device=tensor.device)
    # Matching positions are encoded as index + 1 and non-matches as 0, so the
    # row max minus 1 is the last match index, or -1 when nothing matched.
    return (matches * positions).max(dim=1).values - 1
155