processing_moss_tts.py
33.7 KB · 931 lines · python Raw
1 # coding=utf-8
2 # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 import os
17 from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal, Final, cast
18 from dataclasses import dataclass
19 from pathlib import Path
20 import re
21 import torchaudio
22
23 from transformers import processing_utils
24
25 processing_utils.MODALITY_TO_BASE_CLASS_MAPPING["audio_tokenizer"] = "PreTrainedModel"
26
27 import torch
28 from transformers import (
29 PreTrainedTokenizerBase,
30 BatchFeature,
31 ProcessorMixin,
32 logging,
33 AutoConfig,
34 AutoModel,
35 AutoTokenizer,
36 )
37
38 from .configuration_moss_tts import MossTTSDelayConfig
39
40
41 logger = logging.get_logger(__name__)
42
43
# Marker inside a message's text content that stands for one audio clip.
AUDIO_PLACEHOLDER = "<|audio|>"
45
46
@dataclass
class Message:
    """Base class for conversation messages.

    Subclasses are expected to override `to_dict`; the base implementation
    always raises `NotImplementedError`.
    """

    def to_dict(self) -> Dict[str, Any]:
        """Return the message as a plain dict (must be overridden)."""
        raise NotImplementedError()
51
52
@dataclass
class UserMessage(Message):
    """User turn rendered into a ``<user_inst> ... </user_inst>`` prompt block.

    All fields are optional. A field left as ``None`` is rendered as the
    string ``"None"``. `reference` is a list with one entry per speaker; each
    non-``None`` entry contributes a ``[S<k>]:`` line followed by one
    `AUDIO_PLACEHOLDER`, and is collected (in order) into the audio list
    returned by `to_dict`.

    Raises:
        TypeError: if `reference` is neither ``None`` nor a list.
    """

    text: Optional[str] = None
    reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None
    instruction: Optional[str] = None
    tokens: Optional[int] = None
    quality: Optional[str] = None
    sound_event: Optional[str] = None
    ambient_sound: Optional[str] = None
    language: Optional[str] = None

    def __post_init__(self):
        if self.reference is None:
            reference_text = "None"
            audio_codes_list = []
        elif isinstance(self.reference, list):
            # Speaker numbering follows the position in `reference`, so a
            # ``None`` entry leaves a gap in the [S<k>] labels.
            reference_text = "\n".join(
                f"[S{speaker_idx + 1}]:\n{AUDIO_PLACEHOLDER}"
                for speaker_idx, speaker_reference in enumerate(self.reference)
                if speaker_reference is not None
            )
            audio_codes_list = [
                speaker_reference
                for speaker_reference in self.reference
                if speaker_reference is not None
            ]
        else:
            raise TypeError("`reference` should be exactly a list when it is not None.")

        # Render every section in a single pass. The previous chained
        # `.replace("{field}", ...)` calls re-scanned text that had already
        # been substituted, so a value containing e.g. "{text}" was itself
        # expanded by a later replacement.
        sections = (
            ("Reference(s)", reference_text),
            ("Instruction", self.instruction),
            ("Tokens", self.tokens),
            ("Quality", self.quality),
            ("Sound Event", self.sound_event),
            ("Ambient Sound", self.ambient_sound),
            ("Language", self.language),
            ("Text", self.text),
        )
        lines = ["<user_inst>"]
        for label, value in sections:
            lines.append(f"- {label}:")
            lines.append(str(value))
        lines.append("</user_inst>")

        self._content = "\n".join(lines)
        self._audio_codes_list = audio_codes_list

    def to_dict(self):
        """Return the rendered prompt and its reference audio items."""
        return {
            "role": "user",
            "content": self._content,
            "audio_codes_list": self._audio_codes_list,
        }
121
122
@dataclass
class AssistantMessage(Message):
    """Assistant turn: text content plus the audio items it refers to.

    `content` defaults to a single `AUDIO_PLACEHOLDER`.
    """

    audio_codes_list: List[Union[str, torch.Tensor]]
    content: str = AUDIO_PLACEHOLDER

    def to_dict(self):
        """Return the message as a dict with role, content and audio items."""
        payload = {"role": "assistant"}
        payload["content"] = self.content
        payload["audio_codes_list"] = self.audio_codes_list
        return payload
134
135
# Keys read from a raw user-message dict when it is turned into a UserMessage.
USER_MESSAGE_FIELDS = (
    "text", "reference", "instruction", "tokens",
    "quality", "sound_event", "ambient_sound", "language",
)
146
147
class MossTTSDelayProcessor(ProcessorMixin):
    """Processor that pairs a text tokenizer with an audio tokenizer.

    It turns conversations (text plus audio references) into multi-channel
    input ids and turns generated ids back into text and waveforms.
    """

    # Static type declarations for the two wrapped components.
    tokenizer: PreTrainedTokenizerBase
    audio_tokenizer: Any

    # Class-name strings read by ProcessorMixin (presumably used to resolve
    # the loader class of each component -- confirm against the installed
    # transformers version).
    tokenizer_class = "AutoTokenizer"
    audio_tokenizer_class = "AutoModel"
154
def __init__(
    self,
    tokenizer: PreTrainedTokenizerBase,
    audio_tokenizer: Any = None,
    model_config: Optional[MossTTSDelayConfig] = None,
    **kwargs,
):
    """Store the components and resolve the special tokens used for audio.

    Args:
        tokenizer: Text tokenizer; also used to map special token ids to
            their string form.
        audio_tokenizer: Audio codec model, or ``None``.
        model_config: Model configuration; a default `MossTTSDelayConfig`
            is created when omitted.
        **kwargs: Forwarded to `ProcessorMixin.__init__`.
    """
    super().__init__(tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, **kwargs)

    # ProcessorMixin sets these as well; repeated so type-checkers see them.
    self.tokenizer = tokenizer
    self.audio_tokenizer = audio_tokenizer
    self.model_config = MossTTSDelayConfig() if model_config is None else model_config

    self.imstart_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # Hard-coded id; presumably the newline token of the target vocabulary
    # -- TODO confirm against the tokenizer actually shipped with the model.
    self.newline_token_id = 198

    def _id_to_token(token_id: int) -> str:
        # `convert_ids_to_tokens` may hand back a list; use its first entry
        # (or "" when the list is empty).
        token = tokenizer.convert_ids_to_tokens(int(token_id))
        if not isinstance(token, list):
            return cast(str, token)
        return token[0] if len(token) > 0 else ""

    config = self.model_config
    self.audio_user_slot_token = _id_to_token(config.audio_user_slot_token_id)
    self.audio_assistant_gen_slot_token = _id_to_token(
        config.audio_assistant_gen_slot_token_id
    )
    self.audio_assistant_delay_slot_token = _id_to_token(
        config.audio_assistant_delay_slot_token_id
    )
    self.audio_start_token = _id_to_token(config.audio_start_token_id)
    self.audio_end_token = _id_to_token(config.audio_end_token_id)
192
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    """Load the model config, text tokenizer and audio tokenizer.

    Args:
        pretrained_model_name_or_path: Hub repo id or local path of the
            model; used for the config and the text tokenizer.
        *args: Forwarded to `AutoConfig.from_pretrained` and
            `AutoTokenizer.from_pretrained`.
        **kwargs: ``trust_remote_code`` (default ``True``), ``codec_path``
            (audio tokenizer repo or path, default
            ``"OpenMOSS-Team/MOSS-Audio-Tokenizer"``) and ``_from_auto`` are
            consumed here. The remaining kwargs are forwarded to all three
            loaders and to the constructor.

    Returns:
        A processor instance built from the three loaded components.
    """
    trust_remote_code = kwargs.pop("trust_remote_code", True)
    kwargs.pop("_from_auto", None)

    audio_tokenizer_name_or_path = kwargs.pop(
        "codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer"
    )

    # Do not wrap the identifier in `Path`: on Windows that rewrites a Hub
    # repo id such as "org/name" into "org\name", which is no longer a
    # valid repo id. Only real path objects are converted to their string.
    if isinstance(pretrained_model_name_or_path, os.PathLike):
        pretrained_model_name_or_path = os.fspath(pretrained_model_name_or_path)

    model_config = cast(
        MossTTSDelayConfig,
        AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs,
        ),
    )
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        *args,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )
    audio_tokenizer = AutoModel.from_pretrained(
        audio_tokenizer_name_or_path,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )

    # NOTE(review): the same kwargs (e.g. `revision`, `cache_dir`) reach the
    # codec loader and the constructor below; verify that this is intended
    # and that the constructor accepts them.
    return cls(
        tokenizer=tokenizer,
        audio_tokenizer=audio_tokenizer,
        model_config=model_config,
        **kwargs,
    )
230
def __call__(self, *args, **kwargs) -> BatchFeature:
    """Convert one or more conversations into padded model inputs.

    Args:
        *args: Optionally the conversations as first positional argument.
            A single message, a single conversation entry or a list of
            conversations is accepted; messages may be `Message` objects
            or dicts.
        **kwargs:
            conversations: Used when no positional argument is given.
            mode: ``"generation"`` (default) or ``"continuation"``.
                Generation expects an odd number of messages ending with
                a user message; continuation expects an even number of
                messages not ending with a user message.
            apply_chat_template: Whether each message is passed through
                the tokenizer chat template (default ``True``).
            n_vq: Number of quantizers used when encoding audio paths;
                pre-encoded tensors must match it when it is given.
            return_tensors, padding, truncation: Accepted and ignored.

    Returns:
        A `BatchFeature` holding the output of `_pad`.

    Raises:
        RuntimeError: Unsupported `mode`, `n_vq` mismatch, or an
            unexpected number of encoded audio items.
        ValueError: Empty input or a conversation that violates the
            message-count / last-role rule of the selected mode.
        TypeError: An audio item is neither a tensor nor path-like.
    """
    conversations = args[0] if len(args) > 0 else kwargs.pop("conversations")
    mode: str = kwargs.pop("mode", "generation")
    apply_chat_template: bool = kwargs.pop("apply_chat_template", True)
    n_vq: Optional[int] = kwargs.pop("n_vq", None)

    # Common ProcessorMixin kwargs that are ignored because torch tensors
    # are always returned.
    kwargs.pop("return_tensors", None)
    kwargs.pop("padding", None)
    kwargs.pop("truncation", None)

    if mode not in {"generation", "continuation"}:
        raise RuntimeError(
            f"Unsupported mode {mode!r}; expected 'generation' or 'continuation'."
        )

    if isinstance(conversations, (Message, dict)):
        conversations = [conversations]

    is_generation = mode == "generation"
    # In continuation mode the trailing assistant audio is left open.
    truncation = not is_generation

    input_ids_list = []
    for conversation in conversations:
        if isinstance(conversation, (Message, dict)):
            conversation = [conversation]

        # Normalize early so downstream logic always deals with dict messages.
        conversation = [self._normalize_message(m) for m in conversation]
        if len(conversation) == 0:
            raise ValueError("A conversation must contain at least one message.")

        if is_generation ^ (len(conversation) % 2 != 0):
            expected = "odd" if is_generation else "even"
            raise ValueError(
                f"In {mode!r} mode a conversation must contain an {expected} "
                f"number of messages, got {len(conversation)}."
            )

        if is_generation ^ (conversation[-1]["role"] == "user"):
            expectation = (
                "must be a user message"
                if is_generation
                else "must not be a user message"
            )
            raise ValueError(
                f"In {mode!r} mode the last message {expectation}, "
                f"got role {conversation[-1]['role']!r}."
            )

        last_idx = len(conversation) - 1
        unified_codes = []
        for message_idx, message in enumerate(conversation):
            if apply_chat_template:
                add_generation_prompt = is_generation and message_idx == last_idx
                chat = [{"role": message["role"], "content": message["content"]}]
                try:
                    content = self.tokenizer.apply_chat_template(
                        chat,
                        add_generation_prompt=add_generation_prompt,
                        tokenize=False,
                    )
                except TypeError:
                    # Retry without `tokenize`; if that fails too, fall back
                    # to the raw content (deliberate best effort).
                    try:
                        content = self.tokenizer.apply_chat_template(
                            chat,
                            add_generation_prompt=add_generation_prompt,
                        )
                    except Exception:
                        logger.warning(
                            "apply_chat_template failed; fallback to raw content."
                        )
                        content = message["content"]
            else:
                content = message["content"]

            if not isinstance(content, str):
                content = str(content)

            # Collect all path-based items of the message and encode them in
            # one call, so the audio tokenizer sees a real batch.
            raw_audio_items = message.get("audio_codes_list", [])

            audio_codes_list: List[torch.Tensor] = []
            if len(raw_audio_items) > 0:
                encoded_items: List[Optional[torch.Tensor]] = [None] * len(
                    raw_audio_items
                )
                paths: List[str] = []
                path_positions: List[int] = []

                for idx, item in enumerate(raw_audio_items):
                    if isinstance(item, torch.Tensor):
                        if n_vq is not None and item.shape[1] != n_vq:
                            raise RuntimeError(
                                "audio_codes's n_vq is not equal to the parameter `n_vq`. Your can set the parameter `n_vq` as None if you have already tokenzied the wavs."
                            )
                        encoded_items[idx] = item
                    elif isinstance(item, (str, os.PathLike)):
                        paths.append(str(item))
                        path_positions.append(idx)
                    else:
                        raise TypeError(
                            "Each audio item must be a torch.Tensor of codes or a path-like string."
                        )

                if len(paths) > 0:
                    encoded_from_paths = self.encode_audios_from_path(paths, n_vq)
                    if len(encoded_from_paths) != len(paths):
                        raise RuntimeError(
                            "encode_audios_from_path returned an unexpected number of items."
                        )
                    for pos, codes in zip(path_positions, encoded_from_paths):
                        encoded_items[pos] = codes

                audio_codes_list = [cast(torch.Tensor, t) for t in encoded_items]

            unified_codes.append(
                self._get_unified_codes(
                    message["role"], content, audio_codes_list, truncation
                )
            )

        input_ids_list.append(torch.cat(unified_codes))

    if len(input_ids_list) == 0:
        raise ValueError("`conversations` must contain at least one conversation.")

    return BatchFeature(data=self._pad(input_ids_list))
355
@staticmethod
def build_user_message(
    text: Optional[str] = None,
    reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None,
    instruction: Optional[str] = None,
    tokens: Optional[int] = None,
    quality: Optional[str] = None,
    sound_event: Optional[str] = None,
    ambient_sound: Optional[str] = None,
    language: Optional[str] = None,
) -> Dict:
    """Build a user message dict from the individual prompt fields.

    A `reference` that is not ``None`` and not a list is wrapped into a
    one-element list before the `UserMessage` is created.
    """
    needs_wrapping = reference is not None and not isinstance(reference, list)
    references = [reference] if needs_wrapping else reference

    fields = dict(
        text=text,
        reference=references,
        instruction=instruction,
        tokens=tokens,
        quality=quality,
        sound_event=sound_event,
        ambient_sound=ambient_sound,
        language=language,
    )
    return UserMessage(**fields).to_dict()
379
@staticmethod
def build_assistant_message(
    audio_codes_list: List[Union[str, torch.Tensor]],
    content: str = AUDIO_PLACEHOLDER,
) -> Dict:
    """Build an assistant message dict from audio items and text content."""
    message = AssistantMessage(audio_codes_list=audio_codes_list, content=content)
    return message.to_dict()
389
def _normalize_message(self, message: Union[Message, Dict]) -> Dict:
    """Return `message` as a dict with role, content and audio items.

    `Message` objects are converted with `to_dict`. Dicts that already have
    both ``"content"`` and ``"audio_codes_list"`` are returned unchanged.
    Other dicts are rebuilt through the user / assistant builders.

    Raises:
        TypeError: `message` is neither a `Message` nor a dict.
        ValueError: The dict has no ``"role"`` or the role is unsupported.
    """
    if isinstance(message, Message):
        return message.to_dict()
    if not isinstance(message, dict):
        raise TypeError("Each message must be a Message or dict.")
    if "role" not in message:
        raise ValueError("Message dict must include a 'role' field.")

    already_complete = all(key in message for key in ("content", "audio_codes_list"))
    if already_complete:
        return message

    role = message["role"]
    if role == "user":
        # Only the known user fields are read; missing ones become None.
        user_fields = {name: message.get(name) for name in USER_MESSAGE_FIELDS}
        return self.build_user_message(**user_fields)
    if role == "assistant":
        audio_items = message.get("audio_codes_list", [])
        text = message.get("content", AUDIO_PLACEHOLDER)
        return self.build_assistant_message(audio_codes_list=audio_items, content=text)
    raise ValueError(f"Unsupported role: {role}")
409
410 def _pad(self, input_ids_list: List[torch.Tensor]):
411 device = input_ids_list[0].device
412 lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
413 pad_input_ids = torch.nn.utils.rnn.pad_sequence(
414 input_ids_list,
415 batch_first=True,
416 padding_value=self.model_config.audio_pad_code,
417 padding_side="left",
418 )
419 other_channel_mask = (pad_input_ids.shape[1] - lengths).unsqueeze(
420 1
421 ) > torch.arange(pad_input_ids.shape[1], device=device).unsqueeze(0)
422 pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
423 attention_mask = torch.zeros(
424 pad_input_ids.shape[0], pad_input_ids.shape[1], device=device
425 )
426 attention_mask[~other_channel_mask] = 1
427 attention_mask = attention_mask.bool()
428 return {
429 "input_ids": pad_input_ids, # [batch_size, seqlen, n_vq]
430 "attention_mask": attention_mask,
431 }
432
@staticmethod
def _replace_audio_placeholders(
    content: str,
    lengths: List[int],
    n_vq: int,
    gen_slot_token: str,
    delay_slot_token: str,
    audio_start_token: str,
    audio_end_token: str,
) -> str:
    """Expand every audio placeholder into its start / slot / end tokens.

    The k-th placeholder becomes ``audio_start_token``, then ``lengths[k]``
    generation slots followed by ``n_vq - 1`` delay slots, then
    ``audio_end_token``. A length of 0 yields only start + end.

    Raises:
        ValueError: `n_vq` < 1, a negative length, or a placeholder count
            that differs from ``len(lengths)``.
    """
    if n_vq < 1:
        raise ValueError(f"n_vq must be >= 1, got {n_vq}")

    text_parts = content.split(AUDIO_PLACEHOLDER)
    num_placeholders = len(text_parts) - 1
    if num_placeholders != len(lengths):
        raise ValueError(
            f"Number of {AUDIO_PLACEHOLDER} ({num_placeholders}) "
            f"does not match lengths ({len(lengths)})"
        )

    pieces = [text_parts[0]]
    for length, following_text in zip(lengths, text_parts[1:]):
        if length < 0:
            raise ValueError(f"length must be >= 0, got {length}")
        if length == 0:
            slots = ""
        else:
            slots = gen_slot_token * length + (delay_slot_token * (n_vq - 1))
        pieces.append(f"{audio_start_token}{slots}{audio_end_token}")
        pieces.append(following_text)

    return "".join(pieces)
472
@staticmethod
def _merge_consecutive_audio_placeholders(
    content: str,
    audio_codes_list: List[torch.Tensor],
) -> Tuple[str, List[torch.Tensor]]:
    """Collapse runs of placeholders separated only by whitespace.

    Each run is replaced by a single placeholder and the matching code
    tensors are concatenated along dim 0. Content with at most one
    placeholder is returned unchanged.

    Raises:
        ValueError: The number of placeholders differs from the number of
            code tensors.
    """
    spans = [m.span() for m in re.finditer(re.escape(AUDIO_PLACEHOLDER), content)]
    total = len(spans)
    if total <= 1:
        return content, audio_codes_list

    if total != len(audio_codes_list):
        raise ValueError(
            "Audio placeholders do not match the provided audio codes list."
        )

    pieces = []
    merged_codes = []
    cursor = 0
    first = 0
    while first < total:
        # Extend the run while the gap to the next placeholder is blank.
        last = first
        while last + 1 < total:
            gap = content[spans[last][1] : spans[last + 1][0]]
            if gap.strip() != "":
                break
            last += 1

        pieces.extend((content[cursor : spans[first][0]], AUDIO_PLACEHOLDER))
        cursor = spans[last][1]

        group = audio_codes_list[first : last + 1]
        merged_codes.append(group[0] if len(group) == 1 else torch.cat(group, dim=0))

        first = last + 1

    pieces.append(content[cursor:])
    return "".join(pieces), merged_codes
514
515 @staticmethod
516 def apply_delay_pattern(codes: torch.Tensor, pad_code: int) -> torch.Tensor:
517 delayed_tokens = torch.full(
518 (codes.shape[0] + codes.shape[1] - 1, codes.shape[1]),
519 pad_code,
520 device=codes.device,
521 dtype=codes.dtype,
522 )
523 for i in range(codes.shape[1]):
524 delayed_tokens[i : i + codes.shape[0], i] = codes[:, i]
525 return delayed_tokens
526
527 @staticmethod
528 def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
529 tokens = torch.full(
530 (delay_codes.shape[0] - delay_codes.shape[1] + 1, delay_codes.shape[1]),
531 0,
532 device=delay_codes.device,
533 dtype=delay_codes.dtype,
534 )
535 for i in range(delay_codes.shape[1]):
536 tokens[:, i] = delay_codes[i : i + tokens.shape[0], i]
537 return tokens
538
539 def _get_unified_codes(
540 self,
541 role: str,
542 content: str,
543 audio_codes_list: List[torch.Tensor],
544 truncation: bool,
545 ) -> torch.Tensor:
546 """
547 此时的 content 已经是带上了对话格式
548 """
549 if role == "user":
550 audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
551 truncation = False
552 else:
553 audio_gen_slot_token = self.audio_assistant_gen_slot_token
554 audio_delay_slot_token = self.audio_assistant_delay_slot_token
555
556 if len(audio_codes_list):
557 n_vq = audio_codes_list[0].shape[1]
558 else:
559 n_vq = self.model_config.n_vq
560
561 if len(audio_codes_list) > 1 and AUDIO_PLACEHOLDER in content:
562 content, audio_codes_list = self._merge_consecutive_audio_placeholders(
563 content, audio_codes_list
564 )
565 content = self._replace_audio_placeholders(
566 content=content,
567 lengths=[len(audio_codes) for audio_codes in audio_codes_list],
568 n_vq=n_vq,
569 gen_slot_token=audio_gen_slot_token,
570 delay_slot_token=audio_delay_slot_token,
571 audio_start_token=self.audio_start_token,
572 audio_end_token=self.audio_end_token,
573 )
574 text_codes = torch.tensor(
575 self.tokenizer.encode(content),
576 device=audio_codes_list[0].device if audio_codes_list else None,
577 )
578
579 audio_start_indices = torch.where(
580 text_codes == self.model_config.audio_start_token_id
581 )[0]
582 audio_end_indices = torch.where(
583 text_codes == self.model_config.audio_end_token_id
584 )[0]
585 if len(audio_start_indices) != len(audio_codes_list) or len(
586 audio_end_indices
587 ) != len(audio_codes_list):
588 raise ValueError(
589 "Audio placeholders do not match the provided audio codes list."
590 )
591
592 delay_audio_codes_list = []
593 if len(audio_codes_list) == 0:
594 delay_audio_codes_list = torch.full(
595 (len(text_codes), n_vq),
596 self.model_config.audio_pad_code,
597 device=text_codes.device,
598 dtype=text_codes.dtype,
599 )
600 else:
601 prefix_idx = 0
602 for audio_start_idx_t, audio_end_idx_t, audio_codes in zip(
603 audio_start_indices, audio_end_indices, audio_codes_list
604 ):
605 audio_start_idx = int(audio_start_idx_t.item())
606 audio_end_idx = int(audio_end_idx_t.item())
607 delay_audio_codes = self.apply_delay_pattern(
608 audio_codes, self.model_config.audio_pad_code
609 )
610 pad_codes = torch.full(
611 (audio_start_idx - prefix_idx + 1, n_vq),
612 self.model_config.audio_pad_code,
613 device=audio_codes.device,
614 dtype=audio_codes.dtype,
615 )
616 delay_audio_codes_list.extend([pad_codes, delay_audio_codes])
617 prefix_idx = audio_end_idx
618
619 if truncation:
620 delay_audio_codes_list[-1] = delay_audio_codes_list[-1][
621 : -(n_vq - 1), :
622 ]
623 else:
624 last_audio_end_idx = int(audio_end_indices[-1].item())
625 pad_codes = torch.full(
626 (len(text_codes) - last_audio_end_idx, n_vq),
627 self.model_config.audio_pad_code,
628 device=audio_codes_list[0].device,
629 dtype=audio_codes_list[0].dtype,
630 )
631 delay_audio_codes_list.append(pad_codes)
632
633 delay_audio_codes_list = torch.cat(delay_audio_codes_list)
634
635 if text_codes.shape[0] != delay_audio_codes_list.shape[0]:
636 text_codes = text_codes[: delay_audio_codes_list.shape[0]]
637
638 unified_codes = torch.cat(
639 [text_codes.unsqueeze(1), delay_audio_codes_list], dim=1
640 )
641 return unified_codes
642
643 def _parse_text_codes(self, start_length, text_codes):
644 text = cast(str, self.tokenizer.decode(text_codes))
645 prefix = cast(str, self.tokenizer.decode(text_codes[:start_length]))
646 text = text[len(prefix) :]
647
648 AUDIO_PATTERN = re.compile(
649 rf"(?:{self.audio_start_token})?"
650 rf"(?:{self.audio_assistant_gen_slot_token})*"
651 rf"(?:{self.audio_assistant_delay_slot_token})*"
652 rf"{self.audio_end_token}"
653 )
654
655 def normalize_audio_segments(text: str) -> str:
656 def repl(match: re.Match) -> str:
657 seg = match.group(0)
658 # Replace with <|audio|> if gen_slot is present in the segment;
659 if self.audio_assistant_gen_slot_token in seg:
660 return AUDIO_PLACEHOLDER
661 # Otherwise, remove it.
662 return ""
663
664 return AUDIO_PATTERN.sub(repl, text)
665
666 return normalize_audio_segments(text)
667
668 def _parse_audio_codes(self, start_length, audio_codes):
669 # De-delay back to [T', n_vq]
670 audio_codes = self.apply_de_delay_pattern(audio_codes)
671
672 # Rows that are all pad are separators between real audio segments.
673 is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
674 non_pad = ~is_pad
675 if not non_pad.any():
676 return []
677
678 idx = torch.nonzero(non_pad).squeeze(1)
679 breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1
680 if breaks.numel() == 0:
681 segments_idx = [idx]
682 else:
683 segments_idx = torch.split(idx, breaks.tolist())
684
685 audio_codes_list = [audio_codes[s] for s in segments_idx]
686
687 # Batch-decode all audio segments together.
688 decoded_audio_list = self.decode_audio_codes(audio_codes_list)
689
690 # Keep codec causal context by decoding the whole first segment first,
691 # then trim at waveform level according to start_length ratio.
692 if (
693 start_length > 0
694 and len(audio_codes_list) > 0
695 and len(decoded_audio_list) > 0
696 ):
697 first_codes_length = audio_codes_list[0].shape[0]
698 if first_codes_length > 0:
699 trim_ratio = max(
700 0.0, min(float(start_length) / float(first_codes_length), 1.0)
701 )
702 first_audio = decoded_audio_list[0]
703 if trim_ratio >= 1.0:
704 decoded_audio_list = decoded_audio_list[1:]
705 elif trim_ratio > 0.0:
706 trim_samples = int(first_audio.shape[-1] * trim_ratio)
707 decoded_audio_list[0] = first_audio[..., trim_samples:]
708
709 return decoded_audio_list
710
def decode(self, output: List[Tuple[int, torch.Tensor]]):
    """Turn generated ids into assistant messages.

    1. A complete sequence of assistant generation ids is always required.
    2. Truncation from an arbitrary position is supported.

    Each item of `output` is ``(start_length, generation_ids)``; column 0 of
    `generation_ids` is parsed as text and the remaining columns as audio.
    An item whose parsed text is empty yields ``None``.
    """
    messages = []
    for start_length, generation_ids in output:
        text_channel = generation_ids[:, 0]
        audio_channels = generation_ids[:, 1:]

        content = self._parse_text_codes(start_length, text_channel)
        decoded_audio = self._parse_audio_codes(start_length, audio_channels)

        if content != "":
            messages.append(
                AssistantMessage(
                    content=content,
                    audio_codes_list=cast(
                        List[Union[str, torch.Tensor]], decoded_audio
                    ),
                )
            )
        else:
            messages.append(None)
    return messages
734
735 @staticmethod
736 def loudness_normalize(
737 wav: torch.Tensor,
738 target_dbfs: float = -20,
739 gain_range: tuple[float, float] = (-3.0, 3.0),
740 ) -> torch.Tensor:
741 wav = wav.to(torch.float32)
742 if wav.numel() == 0:
743 return wav
744 current_dbfs = 10.0 * torch.log10(torch.mean(wav**2) + 1e-9)
745 gain = float(target_dbfs - current_dbfs)
746 gain = max(gain_range[0], min(gain, gain_range[1]))
747 factor = 10.0 ** (gain / 20.0)
748 return wav * factor
749
750 def _get_audio_tokenizer_device(self) -> torch.device:
751 """Best-effort device inference for `self.audio_tokenizer`.
752
753 Notes:
754 - Old TAC wrapper exposed `.device`, but standard `torch.nn.Module` does not.
755 - New MossAudioTokenizerModel is a `PreTrainedModel`; parameters define its device.
756 """
757
758 audio_tokenizer = getattr(self, "audio_tokenizer", None)
759 if audio_tokenizer is None:
760 logger.warning(
761 "audio_tokenizer is not set on processor. Using CPU as default."
762 )
763 return torch.device("cpu")
764
765 device_attr = getattr(audio_tokenizer, "device", None)
766 if isinstance(device_attr, torch.device):
767 return device_attr
768
769 try:
770 return next(audio_tokenizer.parameters()).device
771 except StopIteration:
772 # No parameters (shouldn't happen for real models); default to CPU.
773 logger.warning(
774 "No parameters found on audio_tokenizer. Using CPU as default."
775 )
776 return torch.device("cpu")
777
def encode_audios_from_wav(
    self,
    wav_list: List[torch.Tensor],
    sampling_rate: int,
    n_vq: Optional[int] = None,
):
    """Encode waveforms into audio codes with the audio tokenizer.

    Each waveform is down-mixed to mono, resampled to
    `model_config.sampling_rate` when `sampling_rate` differs, moved to the
    tokenizer device and loudness-normalised before encoding.

    Args:
        wav_list: Waveforms shaped (channels, samples) or (samples,); a
            single tensor is treated as a one-element list.
        sampling_rate: Sampling rate shared by all waveforms.
        n_vq: Forwarded to the tokenizer as ``num_quantizers``.

    Returns:
        One CPU ``long`` tensor of shape (frames, quantizers) per waveform.

    Raises:
        RuntimeError: No audio tokenizer is set, or the tokenizer returned
            no codes / lengths.
    """
    if self.audio_tokenizer is None:
        raise RuntimeError("audio_tokenizer is not set on processor.")
    audio_tokenizer = self.audio_tokenizer

    if isinstance(wav_list, torch.Tensor):
        wav_list = [wav_list]

    target_sampling_rate = self.model_config.sampling_rate
    resample = sampling_rate != target_sampling_rate
    device = self._get_audio_tokenizer_device()

    wav_list_ = []
    for wav in wav_list:
        # A 1-D waveform has no channel axis. Without this, its sample axis
        # would be mistaken for channels and averaged into a single value.
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        if resample:
            wav = torchaudio.functional.resample(
                waveform=wav,
                orig_freq=sampling_rate,
                new_freq=target_sampling_rate,
            )
        wav = wav.to(device)
        wav_list_.append(self.loudness_normalize(wav.squeeze(0)))

    # New MossAudioTokenizerModel API: prefer batch_encode(list[wav])
    if hasattr(audio_tokenizer, "batch_encode"):
        enc = audio_tokenizer.batch_encode(wav_list_, num_quantizers=n_vq)
        audio_codes = enc.audio_codes  # (NQ, B, T)
        audio_codes_lengths = enc.audio_codes_lengths  # (B,)
    else:
        # Fallback: use encode() with explicit padding.
        max_len = max(int(wav.shape[-1]) for wav in wav_list_)
        input_values = torch.zeros(
            len(wav_list_), 1, max_len, device=device, dtype=torch.float32
        )
        padding_mask = torch.zeros(
            len(wav_list_), max_len, device=device, dtype=torch.bool
        )
        for i, wav in enumerate(wav_list_):
            this_len = int(wav.shape[-1])
            input_values[i, 0, :this_len] = wav
            padding_mask[i, :this_len] = True
        enc = audio_tokenizer.encode(
            input_values,
            padding_mask=padding_mask,
            num_quantizers=n_vq,
            return_dict=True,
        )
        audio_codes = enc.audio_codes
        audio_codes_lengths = enc.audio_codes_lengths

    if audio_codes is None or audio_codes_lengths is None:
        raise RuntimeError(
            "audio_tokenizer.encode() returned empty outputs (audio_codes/audio_codes_lengths)."
        )

    # Keep processor's historical contract: list[Tensor] with shape (T, NQ)
    # and on CPU (so downstream text/audio packing remains device-agnostic).
    codes_list: List[torch.Tensor] = []
    for i in range(int(audio_codes.shape[1])):
        length_i = int(audio_codes_lengths[i].item())
        codes_i = (
            audio_codes[:, i, :length_i]
            .transpose(0, 1)
            .contiguous()
            .to(torch.long)
            .cpu()
        )
        codes_list.append(codes_i)
    return codes_list
853
def encode_audios_from_path(
    self, wav_path_list: Union[str, List[str]], n_vq: Optional[int] = None
):
    """Load audio files and encode them with `encode_audios_from_wav`.

    A single string is treated as a one-element list. Every file is loaded
    with `torchaudio.load` and resampled on its own to
    `model_config.sampling_rate`, so files with different sampling rates
    can be encoded together in one batch.

    Raises:
        ValueError: `wav_path_list` is empty.
    """
    paths = [wav_path_list] if isinstance(wav_path_list, str) else wav_path_list
    if len(paths) == 0:
        raise ValueError("Empty wav_path_list")

    target_sr = int(self.model_config.sampling_rate)

    waveforms: List[torch.Tensor] = []
    for path in paths:
        waveform, source_sr = torchaudio.load(path)
        source_sr = int(source_sr)
        if source_sr != target_sr:
            waveform = torchaudio.functional.resample(
                waveform=waveform,
                orig_freq=source_sr,
                new_freq=target_sr,
            )
        waveforms.append(waveform)

    return self.encode_audios_from_wav(waveforms, target_sr, n_vq)
879
def decode_audio_codes(
    self, audio_tokens_list: Union[torch.Tensor, List[torch.Tensor]]
):
    """Decode audio code tensors into waveforms with the audio tokenizer.

    Args:
        audio_tokens_list: One tensor or a list of tensors, each accessed as
            (frames, quantizers).

    Returns:
        One 1-D float32 CPU waveform per input tensor; an empty list for an
        empty input.

    Raises:
        RuntimeError: No audio tokenizer is set, or the tokenizer returned
            no audio / lengths.
    """
    if self.audio_tokenizer is None:
        raise RuntimeError("audio_tokenizer is not set on processor.")
    audio_tokenizer = self.audio_tokenizer

    if isinstance(audio_tokens_list, torch.Tensor):
        audio_tokens_list = [audio_tokens_list]
    if len(audio_tokens_list) == 0:
        return []

    device = self._get_audio_tokenizer_device()

    # Processor uses (T, NQ); MossAudioTokenizer expects (NQ, T) (or (NQ, B, T)).
    per_item_codes = []
    for codes in audio_tokens_list:
        transposed = codes.transpose(0, 1).contiguous()
        per_item_codes.append(transposed.to(device=device, dtype=torch.long))

    # Pad to (NQ, B, T) + mask, then decode.
    batch_size = len(per_item_codes)
    num_quantizers = int(per_item_codes[0].shape[0])
    frame_counts = [int(item.shape[1]) for item in per_item_codes]
    max_frames = max(frame_counts)

    audio_codes = torch.zeros(
        num_quantizers, batch_size, max_frames, device=device, dtype=torch.long
    )
    padding_mask = torch.zeros(
        batch_size, max_frames, device=device, dtype=torch.bool
    )
    for row, (item, frames) in enumerate(zip(per_item_codes, frame_counts)):
        audio_codes[:, row, :frames] = item
        padding_mask[row, :frames] = True

    dec = audio_tokenizer.decode(
        audio_codes, padding_mask=padding_mask, return_dict=True, chunk_duration=8
    )
    audio, audio_lengths = dec.audio, dec.audio_lengths

    if audio is None or audio_lengths is None:
        raise RuntimeError(
            "audio_tokenizer.decode() returned empty outputs (audio/audio_lengths)."
        )

    # Return historical contract: list of 1D waveforms (T,)
    return [
        audio[row, 0, : int(audio_lengths[row].item())]
        .contiguous()
        .to(torch.float32)
        .cpu()
        for row in range(int(audio.shape[0]))
    ]
931