modeling_moss_tts.py
24.7 KB · 526 lines · python Raw
1 # coding=utf-8
2 # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 """ Modeling classes for MossTTSDelay. """
16
17 from dataclasses import dataclass
18 from typing import List, Optional, Tuple, Union
19 from tqdm import tqdm
20
21 import torch
22 import torch.nn as nn
23 from torch.nn import CrossEntropyLoss
24
25 from transformers.modeling_utils import PreTrainedModel
26 from transformers.modeling_outputs import ModelOutput
27 from transformers.utils import (
28 add_start_docstrings,
29 add_start_docstrings_to_model_forward,
30 logging,
31 replace_return_docstrings,
32 )
33 from transformers.cache_utils import Cache
34 from transformers.models.qwen3 import Qwen3Model
35 from transformers import initialization as init
36
37 from .configuration_moss_tts import MossTTSDelayConfig
38 from .inference_utils import sample_token, find_last_equal_C
39
40 try:
41 from .processing_moss_tts import UserMessage, AssistantMessage, MossTTSDelayProcessor
42 except Exception:
43 UserMessage = None
44 AssistantMessage = None
45 MossTTSDelayProcessor = None
46
47 logger = logging.get_logger(__name__)
48
49 _CONFIG_FOR_DOC = "MossTTSDelayConfig"
50
51
52 @dataclass
53 class MossTTSDelayOutputWithPast(ModelOutput):
54 """
55 Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
56
57 Args:
58 loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
59 Weighted sum of channel losses.
60 all_sum_losses (`torch.FloatTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
61 Sum of losses for each sample and each channel before averaging.
62 all_token_nums (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
63 Number of non-masked tokens per sample.
64 sample_losses (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
65 Loss per sample.
66 channel_losses (`torch.FloatTensor` of shape `(n_vq + 1,)`, *optional*):
67 Loss per channel (text head + vq heads).
68 logits (`List[torch.FloatTensor]`, *optional*):
69 List of prediction scores from each head.
70 past_key_values (`Cache`, *optional*):
71 Pre-computed hidden-states (key and values in the self-attention blocks).
72 hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
73 Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, +
74 one for the output of each layer).
75 attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
76 Tuple of torch.FloatTensor (one for each layer) of the attention weights.
77 """
78 loss: Optional[torch.FloatTensor] = None
79 all_sum_losses: Optional[torch.FloatTensor] = None
80 all_token_nums: Optional[torch.LongTensor] = None
81 sample_losses: Optional[torch.FloatTensor] = None
82 channel_losses: Optional[torch.FloatTensor] = None
83 logits: Optional[List[torch.FloatTensor]] = None
84 past_key_values: Optional[Cache] = None
85 hidden_states: Optional[Tuple[torch.FloatTensor]] = None
86 attentions: Optional[Tuple[torch.FloatTensor]] = None
87
88
89
90
91 class MossTTSDelayPreTrainedModel(PreTrainedModel):
92 config_class = MossTTSDelayConfig
93 base_model_prefix = "model"
94 supports_gradient_checkpointing = True
95 _no_split_modules = ["Qwen3DecoderLayer"]
96 _skip_keys_device_placement = "past_key_values"
97 _supports_flash_attn = True
98 _supports_flash_attn_2 = True
99 _supports_sdpa = True
100 _supports_flex_attn = True
101
102 def _init_weights(self, module):
103 """
104 Transformers 5.0+ safe init:
105 - MUST use transformers.initialization helpers
106 - MUST respect param._is_hf_initialized to avoid overwriting ckpt-loaded params
107 """
108 # Let HF handle its standard modules first (LayerNorm, Linear, Embedding, etc.)
109 super()._init_weights(module)
110
111 # Pick a std consistent with HF conventions
112 # Prefer model/text config initializer_range if present.
113 std = None
114 if hasattr(self.config, "initializer_range"):
115 std = self.config.initializer_range
116 elif hasattr(self.config, "language_config") and hasattr(self.config.language_config, "initializer_range"):
117 std = self.config.language_config.initializer_range
118 else:
119 std = 0.02
120
121 # Initialize extra audio embeddings
122 if isinstance(module, nn.Embedding):
123 # Only touch our extra embeddings (avoid double touching LM's embeddings if not desired)
124 # If you prefer, you can skip this check and rely on super()._init_weights for all embeddings.
125 if getattr(module, "num_embeddings", None) == self.config.audio_vocab_size + 1:
126 init.normal_(module.weight, mean=0.0, std=std)
127 # If you later set padding_idx, you must explicitly zero it (and respect _is_hf_initialized!)
128 # init.zeros_ will internally check param flags, but slicing needs manual care.
129
130 # Initialize multi-head projections you added
131 if isinstance(module, nn.Linear):
132 # For your lm_heads, super()._init_weights already covers typical Linear.
133 # This block is only needed if you have custom Linear variants later.
134 pass
135
136
137
138 MOSSTTS_START_DOCSTRING = r"""
139 This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
140 library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
141 etc.)
142
143 This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
144 Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
145 and behavior.
146
147 Parameters:
148 config ([`MossTTSDelayConfig`]):
149 Model configuration class with all the parameters of the model. Initializing with a config file does not
150 load the weights associated with the model, only the configuration. Check out the
151 [`~PreTrainedModel.from_pretrained`] method to load the model weights.
152 """
153
154
155 @add_start_docstrings(
156 "The MossTTSDelay Model architecture tailored for Text-to-Speech generation with multi-head VQ prediction.",
157 MOSSTTS_START_DOCSTRING,
158 )
159 class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
160 UserMessage = UserMessage
161 AssistantMessage = AssistantMessage
162 Processor = MossTTSDelayProcessor
163
164 def __init__(self, config: MossTTSDelayConfig):
165 super().__init__(config)
166 self.config = config
167
168 config.language_config.torch_dtype = config.torch_dtype
169
170 self.language_model = Qwen3Model(config.language_config)
171
172 # Audio VQ Embeddings (Extra channels)
173 # Note: input_ids[..., 0] uses Qwen's embedding.
174 # input_ids[..., 1:] use these extensions.
175 self.emb_ext = nn.ModuleList()
176 for vq_idx in range(self.config.n_vq):
177 # Add +1 for potential padding/special tokens logic if strictly required by upstream data prep
178 self.emb_ext.append(
179 nn.Embedding(self.config.audio_vocab_size + 1, config.language_config.hidden_size, padding_idx=None)
180 )
181
182 # Multi-Head Prediction Layers
183 # Head 0: Main language head
184 # Head 1..N: Audio VQ heads
185 self.lm_heads = nn.ModuleList([
186 nn.Linear(config.language_config.hidden_size, config.language_config.vocab_size, bias=False)
187 ])
188 for vq_idx in range(self.config.n_vq):
189 self.lm_heads.append(
190 nn.Linear(config.language_config.hidden_size, self.config.audio_vocab_size + 1, bias=False)
191 )
192
193 # Initialize weights and apply final processing
194 self.post_init()
195
196 def get_input_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
197 """
198 Computes the combined embeddings from text and multiple audio VQ channels.
199
200 Args:
201 input_ids: Shape (Batch, Seq_Len, 1 + n_vq)
202 """
203 # Base Text/Content Embedding
204 # input_ids[..., 0] is standard text or semantic tokens
205 inputs_embeds = self.language_model.get_input_embeddings()(input_ids[..., 0])
206
207 # Add VQ Embeddings
208 for i, embed_layer in enumerate(self.emb_ext):
209 # i corresponds to channel i+1 in input_ids
210 # We assume the data pipeline ensures indices are within range
211 inputs_embeds = inputs_embeds + embed_layer(input_ids[..., i + 1])
212
213 return inputs_embeds
214
215 def set_input_embeddings(self, value):
216 self.language_model.embed_tokens = value
217
218 def get_output_embeddings(self):
219 # Returning a list of heads might break some HF utilities expecting a single head.
220 # However, for custom models, this is acceptable.
221 return self.lm_heads
222
223 @add_start_docstrings_to_model_forward(MOSSTTS_START_DOCSTRING)
224 @replace_return_docstrings(output_type=MossTTSDelayOutputWithPast, config_class=_CONFIG_FOR_DOC)
225 def forward(
226 self,
227 input_ids: Optional[torch.LongTensor] = None,
228 attention_mask: Optional[torch.Tensor] = None,
229 position_ids: Optional[torch.LongTensor] = None,
230 past_key_values: Optional[List[torch.FloatTensor]] = None,
231 inputs_embeds: Optional[torch.FloatTensor] = None,
232 labels: Optional[torch.LongTensor] = None,
233 use_cache: Optional[bool] = None,
234 output_attentions: Optional[bool] = None,
235 cache_position: Optional[torch.LongTensor] = None,
236 hidden_out_layers: Optional[List[int]] = None,
237 channelwise_loss_weight: Optional[List[float]] = None,
238 **kwargs,
239 ) -> Union[Tuple, MossTTSDelayOutputWithPast]:
240 r"""
241 Args:
242 input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`):
243 Indices of input sequence tokens in the vocabulary.
244 Dimension 2 contains: [Text/Semantics, VQ_0, VQ_1, ..., VQ_N].
245 labels (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`, *optional*):
246 Labels for computing the masked language modeling loss.
247 channelwise_loss_weight (`List[float]`, *optional*):
248 Manual weights for summing losses across different heads (Text vs Audio channels).
249
250 Returns:
251 """
252
253 if len(input_ids.shape) != 3 or input_ids.shape[-1] != self.config.n_vq + 1:
254 raise ValueError("`Input_ids`'s shape should be exactly (batch_size, sequence_length, 1 + n_vq).")
255
256 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
257
258 # 1. Prepare Embeddings
259 if inputs_embeds is None:
260 inputs_embeds = self.get_input_embeddings(input_ids)
261
262 # 2. Backbone Forward
263 # Qwen3Model outputs standard CausalLMOutputWithPast or similar
264 outputs = self.language_model(
265 input_ids=None, # Passed via inputs_embeds
266 position_ids=position_ids,
267 attention_mask=attention_mask,
268 past_key_values=past_key_values,
269 inputs_embeds=inputs_embeds,
270 use_cache=use_cache,
271 output_attentions=output_attentions,
272 output_hidden_states=True, # Always need hidden states for multi-head projection
273 return_dict=True,
274 cache_position=cache_position,
275 **kwargs,
276 )
277
278 # 3. Handle specific layer outputs if requested (Delay Pattern often requires features from specific layers)
279 last_hidden_state = outputs.last_hidden_state
280 if hidden_out_layers is None:
281 # Default to using the last layer for all heads
282 # In some architectures (like MusicGen), different codebooks come from different transformer layers.
283 # Here we default to the final layer as per original code behavior [-1] * (n + 1).
284 hidden_states_for_heads = [last_hidden_state] * (len(self.lm_heads))
285 else:
286 # If hidden_out_layers is provided (e.g. [-1, -2, -3...]), fetch them from all_hidden_states
287 # Note: outputs.hidden_states includes embedding output at index 0 usually.
288 all_hs = outputs.hidden_states
289 hidden_states_for_heads = [all_hs[idx] for idx in hidden_out_layers]
290
291 # 4. Project to Logits (Multi-Head)
292 layer_logits = []
293 for i, (hs, head) in enumerate(zip(hidden_states_for_heads, self.lm_heads)):
294 logits = head(hs)
295 # Original code logic: Mask the last token index for audio heads (indices > 0)
296 # This implies the vocab size is (N+1) but the model shouldn't predict the (N+1)-th token
297 # (perhaps reserved for padding in the input but invalid for prediction).
298 if i > 0:
299 logits[..., -1] = float("-inf")
300 layer_logits.append(logits)
301
302 # 5. Loss Calculation
303 loss = None
304 all_sum_losses = None
305 all_token_nums = None
306 sample_losses = None
307 channel_losses = None
308
309 if labels is not None:
310 # Ensure labels match input shape rank (B, S, C)
311 if labels.dim() != 3:
312 raise ValueError(f"Labels must have rank 3 (B, S, C), got {labels.shape}")
313
314 batch_size = labels.size(0)
315 n_heads = len(layer_logits)
316
317 # Container for per-sample, per-channel losses
318 # Shape: [Batch, n_heads]
319 all_sum_losses_list = []
320
321 # Count valid tokens (not -100) per sample.
322 # Note: Assuming mask is consistent across channels or we take sum over dim 1 (seq)
323 # Usually strict masking means checking one channel or all.
324 # Original code: torch.sum(labels != -100, dim=1) -> [B, C]
325 all_token_nums = torch.sum(labels != -100, dim=1)
326
327 for i, logits in enumerate(layer_logits):
328 # logits: [B, S, V]
329 # cur_labels: [B, S]
330 cur_labels = labels[..., i]
331
332 # Flatten for CrossEntropy
333 # logits: [B*S, V], labels: [B*S]
334 loss_fct = CrossEntropyLoss(reduction='none')
335 vocab_size = logits.size(-1)
336
337 reshaped_logits = logits.view(-1, vocab_size)
338 reshaped_labels = cur_labels.contiguous().view(-1)
339
340 # Calculate loss per token
341 per_token_loss = loss_fct(reshaped_logits, reshaped_labels)
342
343 # Reshape back to [B, S] and sum over Sequence dimension to get per-sample loss
344 per_token_loss = per_token_loss.view(batch_size, -1)
345 per_sample_loss = torch.sum(per_token_loss, dim=-1) # [B]
346
347 all_sum_losses_list.append(per_sample_loss)
348
349 # Stack to [B, n_heads]
350 all_sum_losses = torch.stack(all_sum_losses_list, dim=1)
351
352 # Weighted Loss Aggregation
353 if channelwise_loss_weight is not None:
354 if len(channelwise_loss_weight) != n_heads:
355 raise ValueError(f"channelwise_loss_weight length {len(channelwise_loss_weight)} != {n_heads}")
356
357 w_tensor = torch.tensor(channelwise_loss_weight, device=all_sum_losses.device, dtype=all_sum_losses.dtype)
358
359 # Sample losses: Weighted sum over channels per sample / Total weight
360 # Normalize by token count per channel
361 # Avoid division by zero with epsilon or mask
362 token_counts_safe = all_token_nums.float().clamp(min=1.0)
363
364 normalized_losses = all_sum_losses / token_counts_safe
365 sample_losses = (normalized_losses * w_tensor).sum(dim=1) / w_tensor.sum()
366
367 # Channel losses: Sum over batch / Sum tokens over batch
368 total_loss_per_channel = all_sum_losses.sum(dim=0)
369 total_tokens_per_channel = all_token_nums.sum(dim=0).float().clamp(min=1.0)
370 channel_losses = total_loss_per_channel / total_tokens_per_channel
371
372 # Final scalar loss
373 loss = (channel_losses * w_tensor).sum() / w_tensor.sum()
374 else:
375 # Default average if no weights provided
376 total_tokens = all_token_nums.sum().float().clamp(min=1.0)
377 loss = all_sum_losses.sum() / total_tokens
378 channel_losses = all_sum_losses.sum(dim=0) / all_token_nums.sum(dim=0).clamp(min=1.0)
379
380 return MossTTSDelayOutputWithPast(
381 loss=loss,
382 all_sum_losses=all_sum_losses,
383 all_token_nums=all_token_nums,
384 sample_losses=sample_losses,
385 channel_losses=channel_losses,
386 logits=layer_logits,
387 past_key_values=outputs.past_key_values,
388 hidden_states=outputs.hidden_states,
389 attentions=outputs.attentions,
390 )
391
392 @torch.inference_mode()
393 def generate(
394 self,
395 input_ids: torch.LongTensor,
396 attention_mask: Optional[torch.Tensor] = None,
397 max_new_tokens: int = 1000,
398 text_temperature: float = 1.5,
399 text_top_p: float = 1.0,
400 text_top_k: int = 50,
401 audio_temperature: float = 1.7,
402 audio_top_p: float = 0.8,
403 audio_top_k: int = 25,
404 audio_repetition_penalty: float = 1.0,
405 ):
406 if text_temperature > 0:
407 text_do_sample = True
408 else:
409 text_temperature = 1
410 text_do_sample = False
411 if audio_temperature > 0:
412 audio_do_sample = True
413 else:
414 audio_temperature = 1
415 audio_do_sample = False
416
417 past_key_values = None
418 device = input_ids.device
419 current_input_ids = input_ids
420 current_attention_mask = attention_mask
421 batch_size, seq_len, n_vq = input_ids.shape
422 n_vq -= 1
423
424 generation_ids = input_ids[:]
425 is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
426
427 audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device)
428 torch_int64_max = torch.iinfo(torch.int64).max
429 delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)
430
431 is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
432 audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
433 audio_start_mask = is_continuation & (audio_start_indices != -1)
434 audio_lengths[audio_start_mask] = seq_len - audio_start_indices[audio_start_mask]
435
436 is_audio = audio_start_mask.clone()
437
438 pre_exclude_mask0 = torch.tensor([self.config.pad_token_id, self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id, self.config.audio_end_token_id], device=device)
439 pre_exclude_mask1 = torch.ones(self.config.language_config.vocab_size, device=device).bool()
440 pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False
441
442 for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
443 outputs = self(
444 input_ids=current_input_ids,
445 attention_mask=current_attention_mask,
446 past_key_values=past_key_values,
447 use_cache=True,
448 )
449 past_key_values = outputs.past_key_values
450
451 next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)] # List, len=n_vq+1, [batch_size, 1, vocab_size];
452 next_token_logits[0] = next_token_logits[0].clone()
453 next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
454 next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
455 is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
456 next_text_token[is_audio_eos] = self.config.audio_end_token_id
457 is_audio[is_audio_eos] = False
458 sampling_text_mask = ~is_stopping & (delayed_lengths > n_vq)
459 next_token_logits[0][~is_audio] = next_token_logits[0][~is_audio].index_fill(-1, pre_exclude_mask0, float('-inf'))
460 next_token_logits[0][is_audio] = next_token_logits[0][is_audio].masked_fill(pre_exclude_mask1, float('-inf'))
461 if time_step == 0:
462 next_token_logits[0][..., 151662] = float('-inf')
463 if time_step <= n_vq:
464 next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')
465
466 next_text_token[sampling_text_mask] = sample_token(
467 logits=next_token_logits[0][sampling_text_mask],
468 top_p=text_top_p,
469 top_k=text_top_k,
470 do_sample=text_do_sample
471 )
472 is_audio[next_text_token == self.config.audio_start_token_id] = True
473 is_stopping[next_text_token == self.config.im_end_token_id] = True
474
475 next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)
476
477 pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
478 post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
479 post_audio_mask[delayed_lengths == torch_int64_max] = True
480 sampling_audio_mask = pre_audio_mask & post_audio_mask
481 next_audio_tokens[~sampling_audio_mask] = self.config.audio_pad_code
482
483 if sampling_audio_mask.sum() > 0:
484 audio_ch0_logits = next_token_logits[1][sampling_audio_mask[:, 0]]
485 audio_logits = torch.stack(next_token_logits[2:], dim=1)[sampling_audio_mask[:, 1:]]
486 audio_ch0_logits[..., self.config.audio_pad_code] = float('-inf')
487 audio_logits[..., self.config.audio_pad_code] = float('-inf')
488 next_audio_tokens[:, 0][sampling_audio_mask[:, 0]] = sample_token(
489 logits=audio_ch0_logits,
490 prev_tokens=generation_ids[:, :, 1],
491 repetition_penalty=audio_repetition_penalty,
492 top_p=audio_top_p,
493 top_k=audio_top_k,
494 do_sample=audio_do_sample
495 )
496 next_audio_tokens[:, 1:][sampling_audio_mask[:, 1:]] = sample_token(
497 logits=audio_logits,
498 prev_tokens=generation_ids[:, :, 2:],
499 repetition_penalty=audio_repetition_penalty,
500 top_p=audio_top_p,
501 top_k=audio_top_k,
502 do_sample=audio_do_sample
503 )
504
505 audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
506 audio_lengths[next_text_token == self.config.audio_end_token_id] = 0
507 delayed_lengths[(delayed_lengths == torch_int64_max) & (next_text_token == self.config.audio_assistant_delay_slot_token_id)] = 0
508 delayed_lengths[delayed_lengths != torch_int64_max] += 1
509 delayed_lengths[delayed_lengths > n_vq] = torch_int64_max
510
511 current_input_ids = torch.cat([next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2)
512 current_attention_mask = torch.cat([current_attention_mask, (~is_stopping).unsqueeze(-1)], dim=-1)
513 generation_ids = torch.cat([generation_ids, current_input_ids], dim=1)
514
515 if is_stopping.sum() == batch_size:
516 break
517
518 start_indices = find_last_equal_C(input_ids[..., 0], self.config.im_start_token_id) + 3
519 start_lengths = seq_len - start_indices
520
521 output = []
522 for start_idx, start_length, cur_generation_ids in zip(start_indices, start_lengths, generation_ids):
523 output.append((start_length, cur_generation_ids[start_idx:]))
524
525 return output
526