configuration_moss_tts.py
5.5 KB · 115 lines · python Raw
1 # coding=utf-8
2 # Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 """ MossTTSDelay model configuration """
16
17 from typing import Optional, Union
18 from transformers.configuration_utils import PretrainedConfig
19 from transformers.utils import logging
20 from transformers.models.qwen3 import Qwen3Config
21
22 logger = logging.get_logger(__name__)
23
24
class MossTTSDelayConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MossTTSDelayModel`]. It is used to instantiate an
    MossTTSDelay model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the MossTTSDelay [MossTTSDelay-8B](https://huggingface.co/OpenMOSS/mosstts-8b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        language_config (`Union[Qwen3Config, dict]`, *optional*):
            Configuration for the backbone language model (Qwen3). A `dict` is expanded into a [`Qwen3Config`];
            `None` falls back to a default [`Qwen3Config`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        n_vq (`int`, *optional*, defaults to 32):
            Number of additional VQ (Vector Quantization) heads/channels for audio.
            Determines the number of codebooks used in the audio representation.
        pad_token_id (`int`, *optional*, defaults to 151643):
            Token ID used for padding text sequences.
        im_start_token_id (`int`, *optional*, defaults to 151644):
            Token ID of the `im_start` special token (presumably the start-of-turn marker of the chat format).
        im_end_token_id (`int`, *optional*, defaults to 151645):
            Token ID of the `im_end` special token (presumably the end-of-turn marker of the chat format).
        audio_vocab_size (`int`, *optional*, defaults to 1024):
            Vocabulary size for the audio tokens (codebooks 1 to N).
        audio_user_slot_token_id (`int`, *optional*, defaults to 151654):
            The specific token ID used as a placeholder/slot for user-side audio inputs in the prompt.
        audio_assistant_gen_slot_token_id (`int`, *optional*, defaults to 151656):
            The specific token ID representing the generation slot for the assistant's audio output.
            Acting as the trigger for the TTS generation process.
        audio_assistant_delay_slot_token_id (`int`, *optional*, defaults to 151662):
            The token ID used in the 'Delay Pattern' paradigm to represent the delayed/offset positions
            between different VQ channels.
        audio_start_token_id (`int`, *optional*, defaults to 151652):
            Special token ID used to denote the start of an audio sequence in the stream.
        audio_end_token_id (`int`, *optional*, defaults to 151653):
            Special token ID used to denote the end of an audio sequence (EOS for audio).
        audio_pad_code (`int`, *optional*, defaults to 1024):
            The padding value used within the audio VQ codebooks. Typically equals `audio_vocab_size`.
        sampling_rate (`int`, *optional*, defaults to 24000):
            Sampling rate of the audio (presumably in Hz).
        kwargs (*optional*):
            Remaining keyword arguments, forwarded unchanged to [`PretrainedConfig`].
    """

    model_type = "moss_tts_delay"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        language_config: Optional[Union[Qwen3Config, dict]] = None,
        initializer_range: float = 0.02,
        n_vq: int = 32,
        pad_token_id: int = 151643,
        im_start_token_id: int = 151644,
        im_end_token_id: int = 151645,
        audio_vocab_size: int = 1024,
        audio_user_slot_token_id: int = 151654,
        audio_assistant_gen_slot_token_id: int = 151656,
        audio_assistant_delay_slot_token_id: int = 151662,
        audio_start_token_id: int = 151652,
        audio_end_token_id: int = 151653,
        audio_pad_code: int = 1024,
        sampling_rate: int = 24000,
        **kwargs,
    ):
        """Build the configuration; see the class docstring for the meaning of each argument."""
        # Normalise the backbone configuration to a config object:
        # None -> default Qwen3Config, dict -> Qwen3Config(**dict), anything else is stored as given.
        if language_config is None:
            self.language_config = Qwen3Config()
        elif isinstance(language_config, dict):
            self.language_config = Qwen3Config(**language_config)
        else:
            self.language_config = language_config

        self.initializer_range = initializer_range
        self.n_vq = n_vq

        # Audio codebook / special-token settings.
        self.audio_vocab_size = audio_vocab_size
        self.audio_user_slot_token_id = audio_user_slot_token_id
        self.audio_assistant_gen_slot_token_id = audio_assistant_gen_slot_token_id
        self.audio_assistant_delay_slot_token_id = audio_assistant_delay_slot_token_id
        self.audio_start_token_id = audio_start_token_id
        self.audio_end_token_id = audio_end_token_id
        self.audio_pad_code = audio_pad_code
        self.sampling_rate = sampling_rate

        # Mirror the backbone's dimensions on the top-level config.
        self.hidden_size = self.language_config.hidden_size
        self.vocab_size = self.language_config.vocab_size

        # Chat-format special tokens.
        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id

        # `pad_token_id` is managed by the base class: in transformers 4.x
        # `PretrainedConfig.__init__` assigns `self.pad_token_id` from its own
        # kwargs (defaulting to None), so setting it only on `self` beforehand
        # would be silently overwritten. Forward it explicitly instead.
        super().__init__(pad_token_id=pad_token_id, **kwargs)

    def to_dict(self) -> dict:
        """Serialize this configuration to a plain dictionary.

        Starts from the base-class dictionary and replaces the `language_config`
        entry with the nested config's own `to_dict()` output when it provides
        one; otherwise the stored value is emitted unchanged.

        Returns:
            `dict`: The serialized configuration.
        """
        output = super().to_dict()
        language_config = self.language_config
        if hasattr(language_config, "to_dict"):
            output["language_config"] = language_config.to_dict()
        else:
            output["language_config"] = language_config
        return output
115