processing_moss_tts.py
| 1 | # coding=utf-8 |
| 2 | # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved. |
| 3 | # |
| 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | # you may not use this file except in compliance with the License. |
| 6 | # You may obtain a copy of the License at |
| 7 | # |
| 8 | # http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | # |
| 10 | # Unless required by applicable law or agreed to in writing, software |
| 11 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | # See the License for the specific language governing permissions and |
| 14 | # limitations under the License. |
| 15 | |
| 16 | import os |
| 17 | from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal, Final, cast |
| 18 | from dataclasses import dataclass |
| 19 | from pathlib import Path |
| 20 | import re |
| 21 | import torchaudio |
| 22 | |
| 23 | from transformers import processing_utils |
| 24 | |
# Register an "audio_tokenizer" modality with transformers' processor utilities,
# mapped to the "PreTrainedModel" base class.
# NOTE(review): presumably required so ProcessorMixin accepts/loads the
# `audio_tokenizer` attribute of MossTTSDelayProcessor — confirm. This mutates a
# transformers module-level dict at import time, so it affects every processor in
# the process and must run before the processor class below is used.
processing_utils.MODALITY_TO_BASE_CLASS_MAPPING["audio_tokenizer"] = "PreTrainedModel"
| 26 | |
| 27 | import torch |
| 28 | from transformers import ( |
| 29 | PreTrainedTokenizerBase, |
| 30 | BatchFeature, |
| 31 | ProcessorMixin, |
| 32 | logging, |
| 33 | AutoConfig, |
| 34 | AutoModel, |
| 35 | AutoTokenizer, |
| 36 | ) |
| 37 | |
| 38 | from .configuration_moss_tts import MossTTSDelayConfig |
| 39 | |
| 40 | |
logger = logging.get_logger(__name__)


# Marker placed in message content wherever an audio segment belongs; the
# processor later expands each occurrence into explicit audio start / slot / end
# tokens, and generated audio blocks are turned back into this marker on decode.
AUDIO_PLACEHOLDER = "<|audio|>"
| 45 | |
| 46 | |
@dataclass
class Message:
    """Abstract chat message understood by the MOSS-TTS processor.

    Concrete subclasses in this module implement :meth:`to_dict` and return a
    dict with ``role``, ``content`` and ``audio_codes_list`` keys.
    """

    def to_dict(self) -> Dict[str, Any]:
        """Return the dict form of the message; must be overridden."""
        raise NotImplementedError()
| 51 | |
| 52 | |
@dataclass
class UserMessage(Message):
    """User turn rendered into the MOSS-TTS ``<user_inst>`` prompt.

    Every field is optional; a ``None`` field is rendered as the literal string
    ``"None"`` in the prompt. ``reference`` is a per-speaker list: each
    non-``None`` entry (an audio path or a pre-computed code tensor) becomes a
    ``[S<k>]:`` line followed by :data:`AUDIO_PLACEHOLDER`, where ``k`` is the
    1-based position of the entry in the list.
    """

    text: Optional[str] = None
    reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None
    instruction: Optional[str] = None
    tokens: Optional[int] = None
    quality: Optional[str] = None
    sound_event: Optional[str] = None
    ambient_sound: Optional[str] = None
    language: Optional[str] = None

    def __post_init__(self):
        """Render the prompt text and collect the reference audio items.

        Raises:
            TypeError: if ``reference`` is neither ``None`` nor a list.
        """
        template = """<user_inst>
- Reference(s):
{reference}
- Instruction:
{instruction}
- Tokens:
{tokens}
- Quality:
{quality}
- Sound Event:
{sound_event}
- Ambient Sound:
{ambient_sound}
- Language:
{language}
- Text:
{text}
</user_inst>"""

        audio_codes_list = []
        if self.reference is None:
            reference = "None"
        elif isinstance(self.reference, list):
            # The speaker tag follows the position in the list, so a skipped
            # (None) speaker does not renumber the speakers after it.
            reference = "\n".join(
                f"[S{speaker_idx + 1}]:\n{AUDIO_PLACEHOLDER}"
                for speaker_idx, speaker_reference in enumerate(self.reference)
                if speaker_reference is not None
            )
            audio_codes_list = [
                speaker_reference
                for speaker_reference in self.reference
                if speaker_reference is not None
            ]
        else:
            raise TypeError("`reference` should be exactly a list when it is not None.")

        # Single-pass substitution. Chained ``str.replace`` calls would also
        # expand markers that appear inside earlier-substituted values (e.g. an
        # instruction mentioning "{text}" would receive the text).
        # ``str.format`` renders each value like ``str(value)``.
        self._content = template.format(
            reference=reference,
            instruction=self.instruction,
            tokens=self.tokens,
            quality=self.quality,
            sound_event=self.sound_event,
            ambient_sound=self.ambient_sound,
            language=self.language,
            text=self.text,
        )
        self._audio_codes_list = audio_codes_list

    def to_dict(self):
        """Return the processor-facing dict form of this user turn."""
        return {
            "role": "user",
            "content": self._content,
            "audio_codes_list": self._audio_codes_list,
        }
| 121 | |
| 122 | |
@dataclass
class AssistantMessage(Message):
    """Assistant turn carrying audio.

    ``audio_codes_list`` holds one item per :data:`AUDIO_PLACEHOLDER` in
    ``content``. The items are audio paths or tensors when the message is
    built as input; after :meth:`MossTTSDelayProcessor.decode` they are
    decoded waveforms.
    """

    audio_codes_list: List[Union[str, torch.Tensor]]
    content: str = AUDIO_PLACEHOLDER

    def to_dict(self):
        """Return the processor-facing dict form of this assistant turn."""
        payload = dict(
            role="assistant",
            content=self.content,
            audio_codes_list=self.audio_codes_list,
        )
        return payload
| 134 | |
| 135 | |
# Optional UserMessage fields, in declaration order. `_normalize_message`
# reads exactly these keys from a plain user dict and forwards them to
# `build_user_message`.
USER_MESSAGE_FIELDS = (
    "text", "reference", "instruction", "tokens",
    "quality", "sound_event", "ambient_sound", "language",
)
| 146 | |
| 147 | |
class MossTTSDelayProcessor(ProcessorMixin):
    """Pack MOSS-TTS conversations into multi-channel input ids, and back.

    Each packed row has ``1 + n_vq`` channels:

    * channel 0 carries text token ids;
    * channels ``1..n_vq`` carry audio codec codes with a per-codebook delay
      (codebook ``i`` is shifted down by ``i`` rows, see
      :meth:`apply_delay_pattern`).

    Rows without audio hold ``model_config.audio_pad_code`` in the audio
    channels. :meth:`decode` reverses the packing for generated outputs.
    """

    tokenizer_class = "AutoTokenizer"
    audio_tokenizer_class = "AutoModel"

    tokenizer: PreTrainedTokenizerBase
    audio_tokenizer: Any

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        audio_tokenizer: Any = None,
        model_config: Optional[MossTTSDelayConfig] = None,
        **kwargs,
    ):
        """Create the processor.

        Args:
            tokenizer: text tokenizer. It is expected to contain
                ``<|im_start|>``, ``<|im_end|>`` and the audio special tokens
                whose ids are listed in ``model_config``.
            audio_tokenizer: audio codec model. Optional here, but the audio
                encode/decode helpers raise ``RuntimeError`` without it.
            model_config: source of special-token ids, pad codes, ``n_vq`` and
                ``sampling_rate``. Defaults to ``MossTTSDelayConfig()``.
            **kwargs: forwarded to ``ProcessorMixin.__init__``.
        """
        super().__init__(tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, **kwargs)

        # Explicit assignments for type-checkers; ProcessorMixin sets these too.
        self.tokenizer = tokenizer
        self.audio_tokenizer = audio_tokenizer
        if model_config is None:
            model_config = MossTTSDelayConfig()
        self.model_config = model_config

        self.imstart_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        # NOTE(review): hard-coded instead of looked up from `tokenizer`;
        # presumably the newline token id of the Qwen-style vocabulary — confirm.
        self.newline_token_id = 198

        def _id_to_token(token_id: int) -> str:
            # convert_ids_to_tokens may return a list; keep its first entry.
            tok = tokenizer.convert_ids_to_tokens(int(token_id))
            if isinstance(tok, list):
                return tok[0] if len(tok) > 0 else ""
            return cast(str, tok)

        # String forms of the audio special tokens. They are used to build
        # prompt text and to parse generated text.
        self.audio_user_slot_token = _id_to_token(
            self.model_config.audio_user_slot_token_id
        )
        self.audio_assistant_gen_slot_token = _id_to_token(
            self.model_config.audio_assistant_gen_slot_token_id
        )
        self.audio_assistant_delay_slot_token = _id_to_token(
            self.model_config.audio_assistant_delay_slot_token_id
        )
        self.audio_start_token = _id_to_token(self.model_config.audio_start_token_id)
        self.audio_end_token = _id_to_token(self.model_config.audio_end_token_id)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load config, text tokenizer and audio codec, then build the processor.

        Args:
            pretrained_model_name_or_path: hub repo id or local directory that
                holds the MOSS-TTS config and tokenizer.
            *args: forwarded to ``AutoConfig`` / ``AutoTokenizer.from_pretrained``.
            **kwargs: forwarded to all ``Auto*.from_pretrained`` calls and to the
                constructor, except for these keys, which are consumed here:

                * ``codec_path``: repo id or path of the audio tokenizer
                  (default ``"OpenMOSS-Team/MOSS-Audio-Tokenizer"``).
                * ``trust_remote_code``: forwarded to the loaders
                  (default ``True``).
                * ``_from_auto``: discarded.
        """
        trust_remote_code = kwargs.pop("trust_remote_code", True)
        kwargs.pop("_from_auto", None)

        audio_tokenizer_name_or_path = kwargs.pop(
            "codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer"
        )

        # Pass the identifier through unchanged: wrapping a hub repo id such as
        # "org/name" in pathlib.Path rewrites the separator to "\" on Windows,
        # which breaks hub resolution. The Auto* loaders accept str or PathLike.
        model_config = cast(
            MossTTSDelayConfig,
            AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                *args,
                trust_remote_code=trust_remote_code,
                **kwargs,
            ),
        )
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
        audio_tokenizer = AutoModel.from_pretrained(
            audio_tokenizer_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

        return cls(
            tokenizer=tokenizer,
            audio_tokenizer=audio_tokenizer,
            model_config=model_config,
            **kwargs,
        )

    def __call__(self, *args, **kwargs) -> BatchFeature:
        """Pack one or more conversations into left-padded multi-channel ids.

        Args:
            conversations (first positional or keyword): a single message, a
                single conversation (list of messages), or a batch of
                conversations. Messages may be :class:`Message` instances or
                dicts.
            mode: one of

                * ``"generation"`` (default): an odd number of messages ending
                  with a user turn; a generation prompt is added after it.
                * ``"continuation"``: an even number of messages ending with a
                  non-user turn. The final message's last audio segment is
                  left open, with its trailing delay rows dropped, so the
                  model can continue it.
            apply_chat_template: render each message with the tokenizer's
                chat template (default ``True``). Otherwise the raw content
                is used.
            n_vq: number of codebooks requested when encoding audio paths.
                Pre-computed code tensors must have exactly this many columns.

        ``return_tensors``, ``padding`` and ``truncation`` are accepted and
        ignored: the output is always torch tensors.

        Returns:
            BatchFeature with ``input_ids`` of shape
            ``[batch, seq_len, 1 + n_vq]`` and a boolean ``attention_mask`` of
            shape ``[batch, seq_len]``.

        Raises:
            RuntimeError: unsupported ``mode``, or an audio tensor whose
                codebook count differs from ``n_vq``.
            ValueError: an empty batch or conversation, or a message count or
                final role inconsistent with ``mode``.
            TypeError: an audio item that is neither a tensor nor path-like.
        """
        conversations = args[0] if len(args) > 0 else kwargs.pop("conversations")
        mode: str = kwargs.pop("mode", "generation")
        apply_chat_template: bool = kwargs.pop("apply_chat_template", True)
        n_vq: Optional[int] = kwargs.pop("n_vq", None)

        # Common ProcessorMixin kwargs that we ignore because we always return torch tensors.
        kwargs.pop("return_tensors", None)
        kwargs.pop("padding", None)
        kwargs.pop("truncation", None)

        if mode not in {"generation", "continuation"}:
            raise RuntimeError(
                f"Unsupported mode {mode!r}; expected 'generation' or 'continuation'."
            )
        is_generation = mode == "generation"

        if isinstance(conversations, (Message, dict)):
            conversations = [conversations]

        input_ids_list = []
        for conversation in conversations:
            if isinstance(conversation, (Message, dict)):
                conversation = [conversation]

            # Normalize early so downstream logic always deals with dict messages.
            conversation = [self._normalize_message(m) for m in conversation]
            self._check_conversation(conversation, is_generation)

            last_idx = len(conversation) - 1
            unified_codes = []
            for message_idx, message in enumerate(conversation):
                is_last = message_idx == last_idx

                if apply_chat_template:
                    content = self._render_chat_template(
                        message, add_generation_prompt=is_generation and is_last
                    )
                else:
                    content = message["content"]
                if not isinstance(content, str):
                    content = str(content)

                audio_codes_list = self._encode_audio_items(
                    message.get("audio_codes_list") or [], n_vq
                )
                # Only the final message of a continuation stays open. Earlier
                # turns must keep their closing audio-end / turn-end tokens.
                truncation = (not is_generation) and is_last
                unified_codes.append(
                    self._get_unified_codes(
                        message["role"], content, audio_codes_list, truncation
                    )
                )

            input_ids_list.append(torch.cat(unified_codes))

        if not input_ids_list:
            raise ValueError("`conversations` must contain at least one conversation.")
        return BatchFeature(data=self._pad(input_ids_list))

    @staticmethod
    def _check_conversation(conversation: List[Dict], is_generation: bool) -> None:
        """Validate message count and final role against the processing mode.

        Raises:
            ValueError: if the conversation is empty or inconsistent with the mode.
        """
        if not conversation:
            raise ValueError("Each conversation must contain at least one message.")
        if is_generation ^ (len(conversation) % 2 != 0):
            raise ValueError(
                "In 'generation' mode a conversation must have an odd number of "
                "messages; in 'continuation' mode an even number."
            )
        if is_generation ^ (conversation[-1]["role"] == "user"):
            raise ValueError(
                "In 'generation' mode the last message must come from the user; "
                "in 'continuation' mode it must not."
            )

    def _render_chat_template(self, message: Dict, add_generation_prompt: bool) -> Any:
        """Render a single message with the tokenizer's chat template.

        The template is first requested as text (``tokenize=False``). If that
        raises ``TypeError``, it is retried with the tokenizer's default
        arguments. If the retry also fails, the raw content is returned and a
        warning is logged.
        """
        chat = [{"role": message["role"], "content": message["content"]}]
        try:
            return self.tokenizer.apply_chat_template(
                chat,
                add_generation_prompt=add_generation_prompt,
                tokenize=False,
            )
        except TypeError:
            pass

        try:
            rendered = self.tokenizer.apply_chat_template(
                chat,
                add_generation_prompt=add_generation_prompt,
            )
        except Exception:
            logger.warning("apply_chat_template failed; fallback to raw content.")
            return message["content"]

        # Without tokenize=False the template may return token ids. Decode them
        # back to text instead of stringifying the list.
        if isinstance(rendered, (list, tuple)) and all(
            isinstance(token_id, int) for token_id in rendered
        ):
            return self.tokenizer.decode(list(rendered))
        return rendered

    def _encode_audio_items(
        self,
        raw_audio_items: List[Union[str, "os.PathLike", torch.Tensor]],
        n_vq: Optional[int],
    ) -> List[torch.Tensor]:
        """Turn a message's audio items into code tensors, preserving order.

        Tensors are treated as pre-computed codes and checked against ``n_vq``.
        All path-like items are encoded together in one
        ``encode_audios_from_path`` call, so multi-reference prompts use a real
        batch instead of repeated batch-size-1 calls.

        Raises:
            RuntimeError: a tensor whose codebook count differs from ``n_vq``,
                or a batch encode that returns the wrong number of items.
            TypeError: an item that is neither a tensor nor path-like.
        """
        encoded_items: List[Optional[torch.Tensor]] = [None] * len(raw_audio_items)
        paths: List[str] = []
        path_positions: List[int] = []

        for idx, item in enumerate(raw_audio_items):
            if isinstance(item, torch.Tensor):
                if n_vq is not None and item.shape[1] != n_vq:
                    raise RuntimeError(
                        "audio_codes's n_vq is not equal to the parameter `n_vq`. You can set the parameter `n_vq` to None if you have already tokenized the wavs."
                    )
                encoded_items[idx] = item
            elif isinstance(item, (str, os.PathLike)):
                paths.append(str(item))
                path_positions.append(idx)
            else:
                raise TypeError(
                    "Each audio item must be a torch.Tensor of codes or a path-like string."
                )

        if paths:
            encoded_from_paths = self.encode_audios_from_path(paths, n_vq)
            if len(encoded_from_paths) != len(paths):
                raise RuntimeError(
                    "encode_audios_from_path returned an unexpected number of items."
                )
            for pos, codes in zip(path_positions, encoded_from_paths):
                encoded_items[pos] = codes

        return [cast(torch.Tensor, t) for t in encoded_items]

    @staticmethod
    def build_user_message(
        text: Optional[str] = None,
        reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None,
        instruction: Optional[str] = None,
        tokens: Optional[int] = None,
        quality: Optional[str] = None,
        sound_event: Optional[str] = None,
        ambient_sound: Optional[str] = None,
        language: Optional[str] = None,
    ) -> Dict:
        """Build the dict form of a :class:`UserMessage`.

        A non-list ``reference`` is wrapped in a single-element list.
        """
        if reference is not None and not isinstance(reference, list):
            reference = [reference]
        return UserMessage(
            text=text,
            reference=reference,
            instruction=instruction,
            tokens=tokens,
            quality=quality,
            sound_event=sound_event,
            ambient_sound=ambient_sound,
            language=language,
        ).to_dict()

    @staticmethod
    def build_assistant_message(
        audio_codes_list: List[Union[str, torch.Tensor]],
        content: str = AUDIO_PLACEHOLDER,
    ) -> Dict:
        """Build the dict form of an :class:`AssistantMessage`."""
        return AssistantMessage(
            audio_codes_list=audio_codes_list,
            content=content,
        ).to_dict()

    def _normalize_message(self, message: Union[Message, Dict]) -> Dict:
        """Convert a Message or a loose dict into the processor's dict form.

        Dicts that already have ``content`` and ``audio_codes_list`` are
        returned as-is. Otherwise user dicts are rebuilt from
        ``USER_MESSAGE_FIELDS`` and assistant dicts from their
        ``audio_codes_list`` / ``content``.

        Raises:
            TypeError: ``message`` is neither a Message nor a dict.
            ValueError: missing ``role`` or an unsupported role.
        """
        if isinstance(message, Message):
            return message.to_dict()
        if not isinstance(message, dict):
            raise TypeError("Each message must be a Message or dict.")
        if "role" not in message:
            raise ValueError("Message dict must include a 'role' field.")
        if "content" in message and "audio_codes_list" in message:
            return message
        role = message["role"]
        if role == "user":
            kwargs = {key: message.get(key) for key in USER_MESSAGE_FIELDS}
            return self.build_user_message(**kwargs)
        if role == "assistant":
            return self.build_assistant_message(
                audio_codes_list=message.get("audio_codes_list", []),
                content=message.get("content", AUDIO_PLACEHOLDER),
            )
        raise ValueError(f"Unsupported role: {role}")

    def _pad(self, input_ids_list: List[torch.Tensor]):
        """Left-pad packed ``[seq_len, 1 + n_vq]`` sequences into one batch.

        Audio channels are padded with ``audio_pad_code`` and the text channel
        with ``pad_token_id``. ``attention_mask`` is ``True`` on real rows.
        """
        device = input_ids_list[0].device
        lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
        pad_input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids_list,
            batch_first=True,
            padding_value=self.model_config.audio_pad_code,
            padding_side="left",
        )
        seq_len = pad_input_ids.shape[1]
        # True on left-padding rows (position < seq_len - length).
        is_padding = (seq_len - lengths).unsqueeze(1) > torch.arange(
            seq_len, device=device
        ).unsqueeze(0)
        pad_input_ids[..., 0][is_padding] = self.model_config.pad_token_id
        attention_mask = ~is_padding
        return {
            "input_ids": pad_input_ids,  # [batch_size, seq_len, 1 + n_vq]
            "attention_mask": attention_mask,
        }

    @staticmethod
    def _replace_audio_placeholders(
        content: str,
        lengths: List[int],
        n_vq: int,
        gen_slot_token: str,
        delay_slot_token: str,
        audio_start_token: str,
        audio_end_token: str,
    ) -> str:
        """Expand each AUDIO_PLACEHOLDER into an explicit audio token block.

        The k-th placeholder becomes, in order:

        * ``audio_start_token``;
        * ``lengths[k]`` gen slots;
        * ``n_vq - 1`` delay slots (room for the delayed codebooks);
        * ``audio_end_token``.

        A zero length produces only the start and end tokens.

        Raises:
            ValueError: ``n_vq < 1``, a negative length, or a placeholder count
                that differs from ``len(lengths)``.
        """
        if n_vq < 1:
            raise ValueError(f"n_vq must be >= 1, got {n_vq}")

        num_placeholders = content.count(AUDIO_PLACEHOLDER)
        if num_placeholders != len(lengths):
            raise ValueError(
                f"Number of {AUDIO_PLACEHOLDER} ({num_placeholders}) "
                f"does not match lengths ({len(lengths)})"
            )

        def build_audio_block(length: int) -> str:
            if length < 0:
                raise ValueError(f"length must be >= 0, got {length}")

            if length == 0:
                return f"{audio_start_token}{audio_end_token}"

            step_tokens = gen_slot_token * length + (delay_slot_token * (n_vq - 1))
            return f"{audio_start_token}{step_tokens}{audio_end_token}"

        lengths_iter = iter(lengths)

        def replacer(match: re.Match) -> str:
            return build_audio_block(next(lengths_iter))

        return re.sub(re.escape(AUDIO_PLACEHOLDER), replacer, content)

    @staticmethod
    def _merge_consecutive_audio_placeholders(
        content: str,
        audio_codes_list: List[torch.Tensor],
    ) -> Tuple[str, List[torch.Tensor]]:
        """Merge runs of placeholders separated only by whitespace.

        Each run collapses to one placeholder. Its codes are the corresponding
        code tensors concatenated along dim 0 (time); the whitespace between
        the merged placeholders is dropped.

        Raises:
            ValueError: the number of placeholders (when > 1) differs from
                ``len(audio_codes_list)``.
        """
        matches = list(re.finditer(re.escape(AUDIO_PLACEHOLDER), content))
        if len(matches) <= 1:
            return content, audio_codes_list

        if len(matches) != len(audio_codes_list):
            raise ValueError(
                "Audio placeholders do not match the provided audio codes list."
            )

        new_audio_codes_list = []
        new_parts = []
        last_pos = 0
        i = 0
        while i < len(matches):
            # Extend j while the gap to the next placeholder is only whitespace.
            j = i
            while (
                j + 1 < len(matches)
                and content[matches[j].end() : matches[j + 1].start()].strip() == ""
            ):
                j += 1

            new_parts.append(content[last_pos : matches[i].start()])
            new_parts.append(AUDIO_PLACEHOLDER)
            last_pos = matches[j].end()

            if j == i:
                new_audio_codes_list.append(audio_codes_list[i])
            else:
                new_audio_codes_list.append(
                    torch.cat(audio_codes_list[i : j + 1], dim=0)
                )

            i = j + 1

        new_parts.append(content[last_pos:])
        return "".join(new_parts), new_audio_codes_list

    @staticmethod
    def apply_delay_pattern(codes: torch.Tensor, pad_code: int) -> torch.Tensor:
        """Delay each codebook: ``[T, n_vq]`` -> ``[T + n_vq - 1, n_vq]``.

        Column ``i`` is shifted down by ``i`` rows. Uncovered cells are filled
        with ``pad_code``.
        """
        delayed_tokens = torch.full(
            (codes.shape[0] + codes.shape[1] - 1, codes.shape[1]),
            pad_code,
            device=codes.device,
            dtype=codes.dtype,
        )
        for i in range(codes.shape[1]):
            delayed_tokens[i : i + codes.shape[0], i] = codes[:, i]
        return delayed_tokens

    @staticmethod
    def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`apply_delay_pattern`: ``[L, n_vq]`` -> ``[L - n_vq + 1, n_vq]``.

        An input with fewer than ``n_vq - 1`` rows (e.g. a generation that
        stopped early) yields an empty ``[0, n_vq]`` tensor instead of failing
        on a negative size.
        """
        n_vq = delay_codes.shape[1]
        num_steps = max(delay_codes.shape[0] - n_vq + 1, 0)
        tokens = torch.full(
            (num_steps, n_vq),
            0,
            device=delay_codes.device,
            dtype=delay_codes.dtype,
        )
        for i in range(n_vq):
            tokens[:, i] = delay_codes[i : i + num_steps, i]
        return tokens

    def _get_unified_codes(
        self,
        role: str,
        content: str,
        audio_codes_list: List[torch.Tensor],
        truncation: bool,
    ) -> torch.Tensor:
        """Pack one message into a ``[seq_len, 1 + n_vq]`` tensor.

        ``content`` must already have the chat format applied. Its audio
        placeholders are expanded into slot tokens and tokenized. The delayed
        audio codes are then aligned so that each audio block's codes start
        right after its ``audio_start`` token.

        Args:
            role: message role. User audio uses the user slot token and is
                never truncated.
            content: chat-formatted message text.
            audio_codes_list: one ``[T, n_vq]`` code tensor per placeholder.
            truncation: drop the final ``n_vq - 1`` delay rows of the last audio
                block and everything after it, leaving that audio open.

        Raises:
            ValueError: the tokenized audio start/end markers do not match the
                number of audio items.
        """
        if role == "user":
            audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
            truncation = False
        else:
            audio_gen_slot_token = self.audio_assistant_gen_slot_token
            audio_delay_slot_token = self.audio_assistant_delay_slot_token

        if audio_codes_list:
            n_vq = audio_codes_list[0].shape[1]
        else:
            n_vq = self.model_config.n_vq

        if len(audio_codes_list) > 1 and AUDIO_PLACEHOLDER in content:
            content, audio_codes_list = self._merge_consecutive_audio_placeholders(
                content, audio_codes_list
            )
        content = self._replace_audio_placeholders(
            content=content,
            lengths=[len(audio_codes) for audio_codes in audio_codes_list],
            n_vq=n_vq,
            gen_slot_token=audio_gen_slot_token,
            delay_slot_token=audio_delay_slot_token,
            audio_start_token=self.audio_start_token,
            audio_end_token=self.audio_end_token,
        )
        text_codes = torch.tensor(
            self.tokenizer.encode(content),
            device=audio_codes_list[0].device if audio_codes_list else None,
        )

        audio_start_indices = torch.where(
            text_codes == self.model_config.audio_start_token_id
        )[0]
        audio_end_indices = torch.where(
            text_codes == self.model_config.audio_end_token_id
        )[0]
        if len(audio_start_indices) != len(audio_codes_list) or len(
            audio_end_indices
        ) != len(audio_codes_list):
            raise ValueError(
                "Audio placeholders do not match the provided audio codes list."
            )

        if not audio_codes_list:
            delay_audio_codes = torch.full(
                (len(text_codes), n_vq),
                self.model_config.audio_pad_code,
                device=text_codes.device,
                dtype=text_codes.dtype,
            )
        else:
            pieces: List[torch.Tensor] = []
            prefix_idx = 0
            for audio_start_idx_t, audio_end_idx_t, audio_codes in zip(
                audio_start_indices, audio_end_indices, audio_codes_list
            ):
                audio_start_idx = int(audio_start_idx_t.item())
                audio_end_idx = int(audio_end_idx_t.item())
                delay_audio_codes = self.apply_delay_pattern(
                    audio_codes, self.model_config.audio_pad_code
                )
                # Pad rows from the previous audio end (or 0) through this audio
                # start token inclusive; the delayed codes follow from start + 1.
                pad_codes = torch.full(
                    (audio_start_idx - prefix_idx + 1, n_vq),
                    self.model_config.audio_pad_code,
                    device=audio_codes.device,
                    dtype=audio_codes.dtype,
                )
                pieces.extend([pad_codes, delay_audio_codes])
                prefix_idx = audio_end_idx

            if truncation:
                # Drop the trailing delay rows so the last audio stays open.
                # With n_vq == 1 there are none, and slicing with [:-0] would
                # wrongly empty the block.
                if n_vq > 1:
                    pieces[-1] = pieces[-1][: -(n_vq - 1), :]
            else:
                last_audio_end_idx = int(audio_end_indices[-1].item())
                pad_codes = torch.full(
                    (len(text_codes) - last_audio_end_idx, n_vq),
                    self.model_config.audio_pad_code,
                    device=audio_codes_list[0].device,
                    dtype=audio_codes_list[0].dtype,
                )
                pieces.append(pad_codes)

            delay_audio_codes = torch.cat(pieces)

        # With truncation the audio channels end early; cut the text to match.
        if text_codes.shape[0] != delay_audio_codes.shape[0]:
            text_codes = text_codes[: delay_audio_codes.shape[0]]

        return torch.cat([text_codes.unsqueeze(1), delay_audio_codes], dim=1)

    def _parse_text_codes(self, start_length, text_codes):
        """Decode the text channel after its first ``start_length`` tokens.

        Each audio block (optional start token, gen slots, delay slots, end
        token) becomes AUDIO_PLACEHOLDER if it contains a gen slot. Blocks
        without a gen slot are removed.
        """
        text = cast(str, self.tokenizer.decode(text_codes))
        prefix = cast(str, self.tokenizer.decode(text_codes[:start_length]))
        text = text[len(prefix) :]

        # Escape the token strings: special tokens may contain regex
        # metacharacters (e.g. "|" in "<|...|>"-style tokens would otherwise
        # split the pattern into alternatives matching bare "<" or ">").
        audio_pattern = re.compile(
            f"(?:{re.escape(self.audio_start_token)})?"
            f"(?:{re.escape(self.audio_assistant_gen_slot_token)})*"
            f"(?:{re.escape(self.audio_assistant_delay_slot_token)})*"
            f"{re.escape(self.audio_end_token)}"
        )
        gen_slot_token = self.audio_assistant_gen_slot_token

        def repl(match: re.Match) -> str:
            # Keep a placeholder only for segments that contain generated audio.
            return AUDIO_PLACEHOLDER if gen_slot_token in match.group(0) else ""

        return audio_pattern.sub(repl, text)

    def _parse_audio_codes(self, start_length, audio_codes):
        """Turn delayed audio channels into a list of decoded waveforms.

        Rows made entirely of ``audio_pad_code`` separate segments. Each
        contiguous run of non-pad rows is decoded as one segment.

        When ``start_length > 0``, the first waveform is trimmed by the fraction
        ``start_length / first_segment_rows`` (clamped to [0, 1]). It is
        dropped entirely when that fraction reaches 1.
        """
        # De-delay back to [T', n_vq].
        audio_codes = self.apply_de_delay_pattern(audio_codes)

        # Rows that are all pad are separators between real audio segments.
        is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
        non_pad = ~is_pad
        if not non_pad.any():
            return []

        idx = torch.nonzero(non_pad).squeeze(1)
        # Positions in `idx` where the row index jumps, i.e. a new segment starts.
        breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1
        if breaks.numel() == 0:
            segments_idx = [idx]
        else:
            # tensor_split takes split *indices*; torch.split would read the
            # same list as chunk *sizes* and fail whenever there are 2+ segments.
            segments_idx = list(torch.tensor_split(idx, breaks.tolist()))

        audio_codes_list = [audio_codes[s] for s in segments_idx]

        # Batch-decode all audio segments together.
        decoded_audio_list = self.decode_audio_codes(audio_codes_list)

        # Keep codec causal context by decoding the whole first segment first,
        # then trim at waveform level according to start_length ratio.
        if start_length > 0 and len(audio_codes_list) > 0 and len(decoded_audio_list) > 0:
            first_codes_length = audio_codes_list[0].shape[0]
            if first_codes_length > 0:
                trim_ratio = max(
                    0.0, min(float(start_length) / float(first_codes_length), 1.0)
                )
                first_audio = decoded_audio_list[0]
                if trim_ratio >= 1.0:
                    decoded_audio_list = decoded_audio_list[1:]
                elif trim_ratio > 0.0:
                    trim_samples = int(first_audio.shape[-1] * trim_ratio)
                    decoded_audio_list[0] = first_audio[..., trim_samples:]

        return decoded_audio_list

    def decode(self, output: List[Tuple[int, torch.Tensor]]):
        """Convert generated unified ids back into assistant messages.

        1. Each ``generation_ids`` must be a complete assistant generation:
           the text channel in column 0 and the delayed audio channels after it.
        2. Output may be cut at any position via ``start_length``. The first
           ``start_length`` tokens are stripped from the text, and the first
           audio segment is trimmed proportionally.

        Args:
            output: ``(start_length, generation_ids)`` pairs.

        Returns:
            One entry per pair: an :class:`AssistantMessage` whose
            ``audio_codes_list`` holds decoded waveforms, or ``None`` when no
            text remains after the prefix.
        """
        generated_messages: List[Optional[AssistantMessage]] = []
        for start_length, generation_ids in output:
            content = self._parse_text_codes(start_length, generation_ids[:, 0])
            if content == "":
                # The message is discarded, so skip the costly audio decoding.
                generated_messages.append(None)
                continue
            audio_codes_list = self._parse_audio_codes(
                start_length, generation_ids[:, 1:]
            )
            generated_messages.append(
                AssistantMessage(
                    content=content,
                    audio_codes_list=cast(
                        List[Union[str, torch.Tensor]], audio_codes_list
                    ),
                )
            )
        return generated_messages

    @staticmethod
    def loudness_normalize(
        wav: torch.Tensor,
        target_dbfs: float = -20,
        gain_range: tuple[float, float] = (-3.0, 3.0),
    ) -> torch.Tensor:
        """Scale ``wav`` toward a mean power of ``target_dbfs``.

        The applied gain is clamped to ``gain_range`` (in dB). The result is
        float32; an empty input is returned unchanged, apart from the cast.
        """
        wav = wav.to(torch.float32)
        if wav.numel() == 0:
            return wav
        current_dbfs = 10.0 * torch.log10(torch.mean(wav**2) + 1e-9)
        gain = float(target_dbfs - current_dbfs)
        gain = max(gain_range[0], min(gain, gain_range[1]))
        factor = 10.0 ** (gain / 20.0)
        return wav * factor

    def _get_audio_tokenizer_device(self) -> torch.device:
        """Best-effort device inference for `self.audio_tokenizer`.

        Notes:
            - Old TAC wrapper exposed `.device`, but standard `torch.nn.Module` does not.
            - New MossAudioTokenizerModel is a `PreTrainedModel`; parameters define its device.
            - Falls back to CPU (with a warning) when there is no tokenizer or no parameters.
        """
        audio_tokenizer = getattr(self, "audio_tokenizer", None)
        if audio_tokenizer is None:
            logger.warning(
                "audio_tokenizer is not set on processor. Using CPU as default."
            )
            return torch.device("cpu")

        device_attr = getattr(audio_tokenizer, "device", None)
        if isinstance(device_attr, torch.device):
            return device_attr

        try:
            return next(audio_tokenizer.parameters()).device
        except StopIteration:
            # No parameters (shouldn't happen for real models); default to CPU.
            logger.warning(
                "No parameters found on audio_tokenizer. Using CPU as default."
            )
            return torch.device("cpu")

    def encode_audios_from_wav(
        self,
        wav_list: List[torch.Tensor],
        sampling_rate: int,
        n_vq: Optional[int] = None,
    ):
        """Encode waveforms into per-item code tensors of shape ``(T, NQ)``.

        Preprocessing of each waveform:

        * down-mixed to mono by averaging channels;
        * resampled to ``model_config.sampling_rate`` when needed;
        * moved to the audio tokenizer's device;
        * loudness-normalized.

        A 1-D input is treated as a single channel.

        Args:
            wav_list: a tensor or list of tensors shaped ``(channels, samples)``
                or ``(samples,)``.
            sampling_rate: sample rate shared by all inputs.
            n_vq: number of quantizers requested from the audio tokenizer.

        Returns:
            List of ``torch.long`` CPU tensors, one per input.

        Raises:
            RuntimeError: no audio tokenizer, or it returned empty outputs.
        """
        if self.audio_tokenizer is None:
            raise RuntimeError("audio_tokenizer is not set on processor.")
        audio_tokenizer = self.audio_tokenizer

        if isinstance(wav_list, torch.Tensor):
            wav_list = [wav_list]
        target_sr = self.model_config.sampling_rate
        resample = sampling_rate != target_sr
        device = self._get_audio_tokenizer_device()

        wav_list_ = []
        for wav in wav_list:
            if wav.dim() == 1:
                # A bare (samples,) tensor is one channel; without this the
                # channel average below would collapse it to a single value.
                wav = wav.unsqueeze(0)
            if wav.shape[0] > 1:
                wav = torch.mean(wav, dim=0, keepdim=True)
            if resample:
                wav = torchaudio.functional.resample(
                    waveform=wav,
                    orig_freq=sampling_rate,
                    new_freq=target_sr,
                )
            wav = wav.to(device)
            wav_list_.append(self.loudness_normalize(wav.squeeze(0)))

        # New MossAudioTokenizerModel API: prefer batch_encode(list[wav])
        if hasattr(audio_tokenizer, "batch_encode"):
            enc = audio_tokenizer.batch_encode(wav_list_, num_quantizers=n_vq)
            audio_codes = enc.audio_codes  # (NQ, B, T)
            audio_codes_lengths = enc.audio_codes_lengths  # (B,)
        else:
            # Fallback: use encode() with explicit padding.
            max_len = max(int(wav.shape[-1]) for wav in wav_list_)
            input_values = torch.zeros(
                len(wav_list_), 1, max_len, device=device, dtype=torch.float32
            )
            padding_mask = torch.zeros(
                len(wav_list_), max_len, device=device, dtype=torch.bool
            )
            for i, wav in enumerate(wav_list_):
                this_len = int(wav.shape[-1])
                input_values[i, 0, :this_len] = wav
                padding_mask[i, :this_len] = True
            enc = audio_tokenizer.encode(
                input_values,
                padding_mask=padding_mask,
                num_quantizers=n_vq,
                return_dict=True,
            )
            audio_codes = enc.audio_codes
            audio_codes_lengths = enc.audio_codes_lengths

        if audio_codes is None or audio_codes_lengths is None:
            raise RuntimeError(
                "audio_tokenizer.encode() returned empty outputs (audio_codes/audio_codes_lengths)."
            )

        # Keep processor's historical contract: list[Tensor] with shape (T, NQ)
        # and on CPU (so downstream text/audio packing remains device-agnostic).
        codes_list: List[torch.Tensor] = []
        for i in range(int(audio_codes.shape[1])):
            length_i = int(audio_codes_lengths[i].item())
            codes_i = (
                audio_codes[:, i, :length_i]
                .transpose(0, 1)
                .contiguous()
                .to(torch.long)
                .cpu()
            )
            codes_list.append(codes_i)
        return codes_list

    def encode_audios_from_path(
        self,
        wav_path_list: Union[str, "os.PathLike", List[str]],
        n_vq: Optional[int] = None,
    ):
        """Load audio files and encode them with :meth:`encode_audios_from_wav`.

        Each file is loaded and, if needed, resampled independently, so a batch
        of files with different sample rates still goes through a single
        encode call.

        Raises:
            ValueError: ``wav_path_list`` is empty.
        """
        if isinstance(wav_path_list, (str, os.PathLike)):
            wav_path_list = [wav_path_list]

        if len(wav_path_list) == 0:
            raise ValueError("Empty wav_path_list")

        target_sr = int(self.model_config.sampling_rate)
        wav_list: List[torch.Tensor] = []
        for wav_path in wav_path_list:
            wav, sr = torchaudio.load(wav_path)
            if int(sr) != target_sr:
                wav = torchaudio.functional.resample(
                    waveform=wav,
                    orig_freq=int(sr),
                    new_freq=target_sr,
                )
            wav_list.append(wav)

        return self.encode_audios_from_wav(wav_list, target_sr, n_vq)

    def decode_audio_codes(
        self, audio_tokens_list: Union[torch.Tensor, List[torch.Tensor]]
    ):
        """Decode ``(T, NQ)`` code tensors into 1-D float32 CPU waveforms.

        Raises:
            RuntimeError: no audio tokenizer, or it returned empty outputs.
        """
        if self.audio_tokenizer is None:
            raise RuntimeError("audio_tokenizer is not set on processor.")
        audio_tokenizer = self.audio_tokenizer

        if isinstance(audio_tokens_list, torch.Tensor):
            audio_tokens_list = [audio_tokens_list]
        if len(audio_tokens_list) == 0:
            return []

        device = self._get_audio_tokenizer_device()

        # Processor uses (T, NQ); MossAudioTokenizer expects (NQ, T) (or (NQ, B, T)).
        codes_list = [
            codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
            for codes in audio_tokens_list
        ]

        # Pad all items to a shared (NQ, B, T) batch with a validity mask.
        nq = int(codes_list[0].shape[0])
        max_t = max(int(c.shape[1]) for c in codes_list)
        audio_codes = torch.zeros(
            nq, len(codes_list), max_t, device=device, dtype=torch.long
        )
        padding_mask = torch.zeros(
            len(codes_list), max_t, device=device, dtype=torch.bool
        )
        for i, c in enumerate(codes_list):
            t = int(c.shape[1])
            audio_codes[:, i, :t] = c
            padding_mask[i, :t] = True
        dec = audio_tokenizer.decode(
            audio_codes, padding_mask=padding_mask, return_dict=True, chunk_duration=8
        )
        audio = dec.audio
        audio_lengths = dec.audio_lengths

        if audio is None or audio_lengths is None:
            raise RuntimeError(
                "audio_tokenizer.decode() returned empty outputs (audio/audio_lengths)."
            )

        # Return historical contract: list of 1D waveforms (T,)
        wav_list: List[torch.Tensor] = []
        for i in range(int(audio.shape[0])):
            length_i = int(audio_lengths[i].item())
            wav = audio[i, 0, :length_i].contiguous().to(torch.float32).cpu()
            wav_list.append(wav)
        return wav_list
| 931 | |