modeling_moss_tts.py
| 1 | # coding=utf-8 |
| 2 | # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved. |
| 3 | # |
| 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | # you may not use this file except in compliance with the License. |
| 6 | # You may obtain a copy of the License at |
| 7 | # |
| 8 | # http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | # |
| 10 | # Unless required by applicable law or agreed to in writing, software |
| 11 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | # See the License for the specific language governing permissions and |
| 14 | # limitations under the License. |
| 15 | """ Modeling classes for MossTTSDelay. """ |
| 16 | |
| 17 | from dataclasses import dataclass |
| 18 | from typing import List, Optional, Tuple, Union |
| 19 | from tqdm import tqdm |
| 20 | |
| 21 | import torch |
| 22 | import torch.nn as nn |
| 23 | from torch.nn import CrossEntropyLoss |
| 24 | |
| 25 | from transformers.modeling_utils import PreTrainedModel |
| 26 | from transformers.modeling_outputs import ModelOutput |
| 27 | from transformers.utils import ( |
| 28 | add_start_docstrings, |
| 29 | add_start_docstrings_to_model_forward, |
| 30 | logging, |
| 31 | replace_return_docstrings, |
| 32 | ) |
| 33 | from transformers.cache_utils import Cache |
| 34 | from transformers.models.qwen3 import Qwen3Model |
| 35 | from transformers import initialization as init |
| 36 | |
| 37 | from .configuration_moss_tts import MossTTSDelayConfig |
| 38 | from .inference_utils import sample_token, find_last_equal_C |
| 39 | |
# Module-level logger (transformers' logging wrapper). Defined before the
# optional import below so that an import failure can be reported.
logger = logging.get_logger(__name__)

# Name of the config class, used by the docstring decorators on `forward`.
_CONFIG_FOR_DOC = "MossTTSDelayConfig"

# The processing helpers are optional: the model can still be built and run
# without them. On any failure the names fall back to None (and are exposed as
# None through the class attributes on the model), but the reason is logged so
# that a later "NoneType is not callable" is not a mystery.
try:
    from .processing_moss_tts import UserMessage, AssistantMessage, MossTTSDelayProcessor
except Exception as exc:  # deliberately broad: best-effort optional import
    logger.warning(
        "Could not import `processing_moss_tts` (%s: %s); UserMessage, AssistantMessage and "
        "MossTTSDelayProcessor are unavailable.",
        type(exc).__name__,
        exc,
    )
    UserMessage = None
    AssistantMessage = None
    MossTTSDelayProcessor = None
| 50 | |
| 51 | |
@dataclass
class MossTTSDelayOutputWithPast(ModelOutput):
    """
    Output container for the MossTTSDelay model: per-head logits, the optional loss breakdown, and the
    backbone's cache / hidden states / attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Scalar training loss. With `channelwise_loss_weight` it is the weighted average of the
            per-channel losses; otherwise it is the summed token loss divided by the total number of
            non-ignored label tokens.
        all_sum_losses (`torch.FloatTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
            Sum of losses for each sample and each channel before averaging.
        all_token_nums (`torch.LongTensor` of shape `(batch_size, n_vq + 1)`, *optional*):
            Number of non-ignored (`!= -100`) label tokens per sample and per channel.
        sample_losses (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
            Loss per sample. Only populated when `channelwise_loss_weight` is passed to `forward`.
        channel_losses (`torch.FloatTensor` of shape `(n_vq + 1,)`, *optional*):
            Loss per channel (text head + vq heads).
        logits (`List[torch.FloatTensor]`, *optional*):
            List of prediction scores from each head (text head first, then one entry per VQ head).
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding
            layer, + one for the output of each layer).
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
            Tuple of torch.FloatTensor (one for each layer) of the attention weights.
    """

    # NOTE: field order is part of the interface (ModelOutput supports positional/tuple access).
    loss: Optional[torch.FloatTensor] = None
    all_sum_losses: Optional[torch.FloatTensor] = None
    all_token_nums: Optional[torch.LongTensor] = None
    sample_losses: Optional[torch.FloatTensor] = None
    channel_losses: Optional[torch.FloatTensor] = None
    logits: Optional[List[torch.FloatTensor]] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
| 87 | |
| 88 | |
| 89 | |
| 90 | |
class MossTTSDelayPreTrainedModel(PreTrainedModel):
    # Shared base for MossTTSDelay models: declares the config class, the
    # capability flags read by the transformers loading machinery, and the
    # weight-initialisation hook.
    config_class = MossTTSDelayConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """
        Initialise `module`, then re-draw the extra audio embeddings.

        The parent implementation runs first. Afterwards, any `nn.Embedding` whose number of rows equals
        `config.audio_vocab_size + 1` is initialised from a normal distribution through the
        `transformers.initialization` helpers.

        The standard deviation is `config.initializer_range` when that attribute exists, otherwise
        `config.language_config.initializer_range` when available, otherwise 0.02.
        """
        # Standard modules are handled by the parent class.
        super()._init_weights(module)

        cfg = self.config
        if hasattr(cfg, "initializer_range"):
            std = cfg.initializer_range
        else:
            # Fall back to the language sub-config, then to the conventional 0.02.
            std = getattr(getattr(cfg, "language_config", None), "initializer_range", 0.02)

        # Only the extra audio embeddings are touched here; they are recognised by their vocabulary size.
        is_audio_embedding = (
            isinstance(module, nn.Embedding)
            and getattr(module, "num_embeddings", None) == cfg.audio_vocab_size + 1
        )
        if is_audio_embedding:
            init.normal_(module.weight, mean=0.0, std=std)
| 135 | |
| 136 | |
| 137 | |
# Shared docstring preamble. It is passed to `add_start_docstrings` on
# `MossTTSDelayModel` and to `add_start_docstrings_to_model_forward` on its
# `forward` method.
# NOTE(review): the indentation inside the literal follows the usual
# transformers docstring layout; confirm against the original file.
MOSSTTS_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MossTTSDelayConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
| 153 | |
| 154 | |
@add_start_docstrings(
    "The MossTTSDelay Model architecture tailored for Text-to-Speech generation with multi-head VQ prediction.",
    MOSSTTS_START_DOCSTRING,
)
class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
    # Qwen3 backbone whose input is the sum of one text embedding and `n_vq`
    # audio-codebook embeddings, and whose output feeds `1 + n_vq` linear
    # heads (head 0: text vocabulary, heads 1..n_vq: audio codebooks).
    #
    # Convenience aliases to the processing helpers. They are None when the
    # optional `processing_moss_tts` import at module level failed.
    UserMessage = UserMessage
    AssistantMessage = AssistantMessage
    Processor = MossTTSDelayProcessor

    def __init__(self, config: MossTTSDelayConfig):
        super().__init__(config)
        self.config = config

        # Propagate the requested dtype to the backbone's sub-config.
        config.language_config.torch_dtype = config.torch_dtype

        self.language_model = Qwen3Model(config.language_config)

        hidden_size = config.language_config.hidden_size
        # One extra row/logit beyond `audio_vocab_size`; `forward` masks the
        # last logit of every audio head so it is never predicted.
        audio_vocab = self.config.audio_vocab_size + 1

        # Audio VQ embeddings: input_ids[..., 0] uses the backbone embedding,
        # input_ids[..., 1:] use these.
        self.emb_ext = nn.ModuleList(
            [nn.Embedding(audio_vocab, hidden_size, padding_idx=None) for _ in range(self.config.n_vq)]
        )

        # Prediction heads: index 0 is the text head, 1..n_vq the audio heads.
        self.lm_heads = nn.ModuleList(
            [nn.Linear(hidden_size, config.language_config.vocab_size, bias=False)]
        )
        self.lm_heads.extend(
            [nn.Linear(hidden_size, audio_vocab, bias=False) for _ in range(self.config.n_vq)]
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(
        self, input_ids: Optional[torch.LongTensor] = None
    ) -> Union[nn.Module, torch.Tensor]:
        """
        Return the text embedding module, or the combined embeddings for `input_ids`.

        Args:
            input_ids: Optional tensor of shape (Batch, Seq_Len, 1 + n_vq). When omitted, the backbone's
                text embedding module is returned, which is the no-argument accessor form that
                `PreTrainedModel` utilities call. When given, channel 0 is embedded with the backbone's
                embedding and channels 1..n_vq with `emb_ext`, and the sum is returned.
        """
        text_embedding = self.language_model.get_input_embeddings()
        if input_ids is None:
            return text_embedding

        inputs_embeds = text_embedding(input_ids[..., 0])
        # emb_ext[i] embeds channel i + 1; indices are assumed to be in range.
        for channel, embed_layer in enumerate(self.emb_ext, start=1):
            inputs_embeds = inputs_embeds + embed_layer(input_ids[..., channel])
        return inputs_embeds

    def set_input_embeddings(self, value):
        """Replace the backbone's text embedding module."""
        self.language_model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the `nn.ModuleList` of heads (text head first, then the audio heads)."""
        # Returning a list of heads might break some HF utilities expecting a single head.
        return self.lm_heads

    @add_start_docstrings_to_model_forward(MOSSTTS_START_DOCSTRING)
    @replace_return_docstrings(output_type=MossTTSDelayOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        hidden_out_layers: Optional[List[int]] = None,
        channelwise_loss_weight: Optional[List[float]] = None,
        **kwargs,
    ) -> Union[Tuple, MossTTSDelayOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`):
                Indices of input sequence tokens in the vocabulary.
                Dimension 2 contains: [Text/Semantics, VQ_0, VQ_1, ..., VQ_N].
                May be omitted when `inputs_embeds` is given.
            inputs_embeds (`torch.FloatTensor`, *optional*):
                Pre-computed input embeddings. When given they are used instead of embedding `input_ids`.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length, 1 + n_vq)`, *optional*):
                Labels for computing the loss, one channel per head. Positions equal to `-100` are ignored.
                Labels are compared position-by-position with the logits; no shift is applied here.
            hidden_out_layers (`List[int]`, *optional*):
                One index into the backbone's `hidden_states` tuple per head (`1 + n_vq` entries). Defaults
                to the last hidden state for every head.
            channelwise_loss_weight (`List[float]`, *optional*):
                Manual weights for summing losses across different heads (Text vs Audio channels).

        Returns:

        Raises:
            ValueError: if neither `input_ids` nor `inputs_embeds` is given, if `input_ids` or `labels`
                has the wrong shape, or if `hidden_out_layers` / `channelwise_loss_weight` does not
                have exactly one entry per head.
        """
        n_heads = len(self.lm_heads)

        if input_ids is None and inputs_embeds is None:
            raise ValueError("You must specify either `input_ids` or `inputs_embeds`.")
        if input_ids is not None and (input_ids.dim() != 3 or input_ids.shape[-1] != self.config.n_vq + 1):
            raise ValueError("`Input_ids`'s shape should be exactly (batch_size, sequence_length, 1 + n_vq).")
        # Fail fast instead of letting zip() below silently drop heads.
        if hidden_out_layers is not None and len(hidden_out_layers) != n_heads:
            raise ValueError(
                f"`hidden_out_layers` must list one layer index per head ({n_heads}), "
                f"got {len(hidden_out_layers)}"
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        # Both are fixed in the backbone call below; drop caller-supplied
        # duplicates so they cannot collide with the explicit keywords.
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("return_dict", None)

        # 1. Prepare Embeddings
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)

        # 2. Backbone Forward
        outputs = self.language_model(
            input_ids=None,  # Passed via inputs_embeds
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,  # hidden states are needed when hidden_out_layers is used
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        # 3. Select the hidden state each head reads from.
        if hidden_out_layers is None:
            hidden_states_for_heads = [outputs.last_hidden_state] * n_heads
        else:
            all_hs = outputs.hidden_states
            hidden_states_for_heads = [all_hs[idx] for idx in hidden_out_layers]

        # 4. Project to Logits (Multi-Head)
        layer_logits = []
        for head_idx, (hs, head) in enumerate(zip(hidden_states_for_heads, self.lm_heads)):
            logits = head(hs)
            # The last index of every audio head is never a valid prediction
            # target, so it is masked out.
            if head_idx > 0:
                logits[..., -1] = float("-inf")
            layer_logits.append(logits)

        # 5. Loss Calculation
        loss = None
        all_sum_losses = None
        all_token_nums = None
        sample_losses = None
        channel_losses = None

        if labels is not None:
            if labels.dim() != 3:
                raise ValueError(f"Labels must have rank 3 (B, S, C), got {labels.shape}")
            if labels.shape[-1] != n_heads:
                raise ValueError(
                    f"Labels must have {n_heads} channels in the last dimension, got {labels.shape[-1]}"
                )

            batch_size = labels.size(0)

            # Non-ignored label tokens per sample and channel: [B, C].
            all_token_nums = torch.sum(labels != -100, dim=1)

            loss_fct = CrossEntropyLoss(reduction="none")
            per_head_sums = []
            for head_idx, logits in enumerate(layer_logits):
                flat_logits = logits.reshape(-1, logits.size(-1))  # [B*S, V]
                flat_labels = labels[..., head_idx].reshape(-1)  # [B*S]
                # Per-token loss (0 at ignored positions), summed over the sequence -> [B].
                token_loss = loss_fct(flat_logits, flat_labels).view(batch_size, -1)
                per_head_sums.append(token_loss.sum(dim=-1))

            # [B, n_heads]
            all_sum_losses = torch.stack(per_head_sums, dim=1)

            # Clamp to 1 so a channel with no valid token divides by 1, not 0.
            tokens_per_channel = all_token_nums.sum(dim=0).float().clamp(min=1.0)

            if channelwise_loss_weight is not None:
                if len(channelwise_loss_weight) != n_heads:
                    raise ValueError(f"channelwise_loss_weight length {len(channelwise_loss_weight)} != {n_heads}")

                w_tensor = torch.tensor(
                    channelwise_loss_weight, device=all_sum_losses.device, dtype=all_sum_losses.dtype
                )

                # Per-sample loss: token-normalised per channel, then weighted mean over channels.
                normalized_losses = all_sum_losses / all_token_nums.float().clamp(min=1.0)
                sample_losses = (normalized_losses * w_tensor).sum(dim=1) / w_tensor.sum()

                # Per-channel loss: batch-summed loss over batch-summed token count.
                channel_losses = all_sum_losses.sum(dim=0) / tokens_per_channel

                # Final scalar loss: weighted mean of the channel losses.
                loss = (channel_losses * w_tensor).sum() / w_tensor.sum()
            else:
                # No weights: plain token-level average over all channels.
                total_tokens = all_token_nums.sum().float().clamp(min=1.0)
                loss = all_sum_losses.sum() / total_tokens
                channel_losses = all_sum_losses.sum(dim=0) / tokens_per_channel

        return MossTTSDelayOutputWithPast(
            loss=loss,
            all_sum_losses=all_sum_losses,
            all_token_nums=all_token_nums,
            sample_losses=sample_losses,
            channel_losses=channel_losses,
            logits=layer_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_new_tokens: int = 1000,
        text_temperature: float = 1.5,
        text_top_p: float = 1.0,
        text_top_k: int = 50,
        audio_temperature: float = 1.7,
        audio_top_p: float = 0.8,
        audio_top_k: int = 25,
        audio_repetition_penalty: float = 1.0,
    ):
        """
        Autoregressively generate text and delayed audio-codebook tokens.

        Args:
            input_ids: Prompt of shape (batch_size, seq_len, 1 + n_vq).
            attention_mask: Mask of shape (batch_size, seq_len). When omitted, every prompt position
                is attended to.
            max_new_tokens: Maximum number of decoding steps.
            text_temperature / text_top_p / text_top_k: Sampling settings for the text head. A
                temperature <= 0 disables sampling for that head (`do_sample=False`).
            audio_temperature / audio_top_p / audio_top_k: Same, for the audio heads.
            audio_repetition_penalty: Repetition penalty forwarded to `sample_token` for audio heads.

        Returns:
            A list with one `(start_length, ids)` tuple per sample. `ids` is the prompt plus generated
            tokens, sliced from `start = (last index of im_start_token_id in the prompt) + 3`, and
            `start_length = seq_len - start` is the number of prompt positions kept in `ids`.
        """
        text_do_sample = text_temperature > 0
        if not text_do_sample:
            text_temperature = 1
        audio_do_sample = audio_temperature > 0
        if not audio_do_sample:
            audio_temperature = 1

        cfg = self.config
        device = input_ids.device
        batch_size, seq_len, n_channels = input_ids.shape
        n_vq = n_channels - 1

        if attention_mask is None:
            # The mask is extended with torch.cat at every step, so it must be a tensor.
            attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=device)

        past_key_values = None
        current_input_ids = input_ids
        current_attention_mask = attention_mask

        generation_ids = input_ids[:]
        is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # audio_lengths: steps since the current audio segment started (0 = not in audio).
        # delayed_lengths: steps since the first delay-slot token; int64 max = not in the delay phase.
        audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device)
        torch_int64_max = torch.iinfo(torch.int64).max
        delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device)

        # A prompt that ends on an audio-start or gen-slot token continues an open audio segment.
        last_text_token = input_ids[:, -1, 0]
        is_continuation = (last_text_token == cfg.audio_start_token_id) | (
            last_text_token == cfg.audio_assistant_gen_slot_token_id
        )
        audio_start_indices = find_last_equal_C(input_ids[..., 0], cfg.audio_start_token_id)
        audio_start_mask = is_continuation & (audio_start_indices != -1)
        audio_lengths[audio_start_mask] = seq_len - audio_start_indices[audio_start_mask]

        is_audio = audio_start_mask.clone()

        # Outside audio: these text tokens may not be sampled.
        pre_exclude_mask0 = torch.tensor(
            [
                cfg.pad_token_id,
                cfg.audio_assistant_gen_slot_token_id,
                cfg.audio_assistant_delay_slot_token_id,
                cfg.audio_end_token_id,
            ],
            device=device,
        )
        # Inside audio: everything except the gen-slot / delay-slot tokens is excluded.
        pre_exclude_mask1 = torch.ones(cfg.language_config.vocab_size, device=device).bool()
        pre_exclude_mask1[
            [cfg.audio_assistant_gen_slot_token_id, cfg.audio_assistant_delay_slot_token_id]
        ] = False

        # Codebook index grid, loop-invariant: [batch_size, n_vq].
        vq_index = torch.arange(n_vq, dtype=torch.int64, device=device).expand(batch_size, n_vq)

        for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
            outputs = self(
                input_ids=current_input_ids,
                attention_mask=current_attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values

            # Last-position logits per head, temperature-scaled; each entry is [batch_size, vocab].
            # The division allocates new tensors, so the in-place masking below
            # never touches `outputs.logits`.
            next_token_logits = [
                logit[:, -1, :] / (text_temperature if logit_idx == 0 else audio_temperature)
                for logit_idx, logit in enumerate(outputs.logits)
            ]
            text_logits = next_token_logits[0]

            # ---- text channel -------------------------------------------------
            # Forced tokens first: delay slots while the delay phase runs, then
            # audio-end exactly n_vq steps after it began. Stopped rows keep pad.
            next_text_token = torch.full((batch_size,), cfg.pad_token_id, device=device)
            next_text_token[~is_stopping & (delayed_lengths < n_vq)] = cfg.audio_assistant_delay_slot_token_id
            is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
            next_text_token[is_audio_eos] = cfg.audio_end_token_id
            is_audio[is_audio_eos] = False
            sampling_text_mask = ~is_stopping & (delayed_lengths > n_vq)

            text_logits[~is_audio] = text_logits[~is_audio].index_fill(-1, pre_exclude_mask0, float("-inf"))
            text_logits[is_audio] = text_logits[is_audio].masked_fill(pre_exclude_mask1, float("-inf"))
            if time_step == 0:
                # NOTE(review): hard-coded token id banned on the first step; its meaning is not
                # visible in this file — confirm and move to the config.
                text_logits[..., 151662] = float("-inf")
            if time_step <= n_vq:
                # No end-of-turn during the first n_vq + 1 steps.
                text_logits[..., cfg.im_end_token_id] = float("-inf")

            next_text_token[sampling_text_mask] = sample_token(
                logits=text_logits[sampling_text_mask],
                top_p=text_top_p,
                top_k=text_top_k,
                do_sample=text_do_sample,
            )
            is_audio[next_text_token == cfg.audio_start_token_id] = True
            is_stopping[next_text_token == cfg.im_end_token_id] = True

            # ---- audio channels -----------------------------------------------
            # Codebook k is sampled once the segment is more than k steps old and,
            # in the delay phase, only while k >= delayed_lengths. Everything
            # else stays at the pad code.
            next_audio_tokens = torch.full((batch_size, n_vq), cfg.audio_pad_code, device=device)

            pre_audio_mask = audio_lengths.unsqueeze(1) > vq_index
            post_audio_mask = vq_index > delayed_lengths.unsqueeze(1) - 1
            post_audio_mask[delayed_lengths == torch_int64_max] = True
            sampling_audio_mask = pre_audio_mask & post_audio_mask

            if sampling_audio_mask.any():
                # NOTE(review): `prev_tokens` covers the full batch while the logits are
                # mask-selected rows — confirm `sample_token` expects this.
                audio_ch0_logits = next_token_logits[1][sampling_audio_mask[:, 0]]
                audio_ch0_logits[..., cfg.audio_pad_code] = float("-inf")
                next_audio_tokens[:, 0][sampling_audio_mask[:, 0]] = sample_token(
                    logits=audio_ch0_logits,
                    prev_tokens=generation_ids[:, :, 1],
                    repetition_penalty=audio_repetition_penalty,
                    top_p=audio_top_p,
                    top_k=audio_top_k,
                    do_sample=audio_do_sample,
                )
                # With a single codebook there are no further heads to stack.
                if n_vq > 1:
                    audio_logits = torch.stack(next_token_logits[2:], dim=1)[sampling_audio_mask[:, 1:]]
                    audio_logits[..., cfg.audio_pad_code] = float("-inf")
                    next_audio_tokens[:, 1:][sampling_audio_mask[:, 1:]] = sample_token(
                        logits=audio_logits,
                        prev_tokens=generation_ids[:, :, 2:],
                        repetition_penalty=audio_repetition_penalty,
                        top_p=audio_top_p,
                        top_k=audio_top_k,
                        do_sample=audio_do_sample,
                    )

            # ---- state update -------------------------------------------------
            in_audio_token = (
                (next_text_token == cfg.audio_start_token_id)
                | (next_text_token == cfg.audio_assistant_gen_slot_token_id)
                | (next_text_token == cfg.audio_assistant_delay_slot_token_id)
            )
            audio_lengths[in_audio_token] += 1
            audio_lengths[next_text_token == cfg.audio_end_token_id] = 0
            # Start the delay counter on the first delay slot, advance running
            # counters, and reset them once all n_vq codebooks have drained.
            delayed_lengths[
                (delayed_lengths == torch_int64_max)
                & (next_text_token == cfg.audio_assistant_delay_slot_token_id)
            ] = 0
            delayed_lengths[delayed_lengths != torch_int64_max] += 1
            delayed_lengths[delayed_lengths > n_vq] = torch_int64_max

            current_input_ids = torch.cat(
                [next_text_token[:, None, None], next_audio_tokens[:, None, :]], dim=2
            )
            # Rows that have stopped (including on this step) get mask value 0.
            new_mask_column = (~is_stopping).unsqueeze(-1).to(current_attention_mask.dtype)
            current_attention_mask = torch.cat([current_attention_mask, new_mask_column], dim=-1)
            generation_ids = torch.cat([generation_ids, current_input_ids], dim=1)

            if is_stopping.all():
                break

        # NOTE(review): the +3 offset presumably skips the role header that follows
        # the last im_start token — confirm against the chat template.
        start_indices = find_last_equal_C(input_ids[..., 0], cfg.im_start_token_id) + 3
        start_lengths = seq_len - start_indices

        return [
            (start_length, cur_generation_ids[start_idx:])
            for start_idx, start_length, cur_generation_ids in zip(start_indices, start_lengths, generation_ids)
        ]
| 526 | |