tokenization_voxcpm2.py
2.8 KB · 73 lines · python Raw
1 """Custom tokenizer for VoxCPM2 that splits multi-character Chinese tokens.
2
3 VoxCPM2 was trained with ``mask_multichar_chinese_tokens`` which splits
4 multi-character Chinese tokens (e.g. "你好" -> ["你", "好"]) into individual
5 character IDs before embedding. The base LlamaTokenizerFast produces
6 multi-character Chinese tokens that the model has never seen during training,
7 yielding garbled Chinese audio output in downstream inference frameworks.
8
9 This module provides ``VoxCPM2Tokenizer`` which transparently applies the
10 character splitting inside ``encode()`` and ``__call__()``, so any downstream
11 consumer (vLLM, vLLM-Omni, Nano-vLLM, etc.) gets correct single-character
12 IDs without code changes.
13 """
14
15 from transformers import LlamaTokenizerFast
16
17
class VoxCPM2Tokenizer(LlamaTokenizerFast):
    """``LlamaTokenizerFast`` that emits one ID per character for CJK tokens.

    After the base tokenizer has run, every token ID whose vocabulary entry
    consists solely of two or more CJK ideographs (ignoring the U+2581
    word-boundary marker) is replaced by the IDs of its individual
    characters.  The replacement is applied by ``encode()`` and
    ``__call__()``.

    Limitations:
        * The lookup table is built once in ``__init__``; tokens added to the
          vocabulary afterwards are not reflected in it.
        * Truncation is performed by the base tokenizer *before* splitting,
          so a split sequence can be longer than ``max_length``.
        * Only ``input_ids`` and the fields in ``_ALIGNED_KEYS`` are adjusted;
          other outputs (e.g. ``offset_mapping``) keep their original length.
    """

    # Per-token outputs that run parallel to ``input_ids`` and therefore have
    # to grow together with it.
    _ALIGNED_KEYS = ("attention_mask", "token_type_ids", "special_tokens_mask")

    def __init__(self, *args, **kwargs):
        """Initialise the base tokenizer and cache the token split table."""
        super().__init__(*args, **kwargs)
        self._split_map = self._build_split_map()

    def _build_split_map(self) -> dict[int, list[int]]:
        """Map each multi-character CJK token ID to its per-character IDs.

        A token is included only when every one of its characters resolves
        to a known (non-``unk``) vocabulary entry.
        """
        unk_id = self.unk_token_id
        split_map: dict[int, list[int]] = {}
        for token, token_id in self.get_vocab().items():
            # Drop the SentencePiece word-boundary marker before inspecting
            # the characters.
            chars = token.replace("\u2581", "")
            if len(chars) < 2 or not all(self._is_cjk(c) for c in chars):
                continue
            char_ids = self.convert_tokens_to_ids(list(chars))
            if all(cid != unk_id for cid in char_ids):
                split_map[token_id] = char_ids
        return split_map

    @staticmethod
    def _is_cjk(c: str) -> bool:
        """Return True if *c* lies in one of the handled CJK ideograph ranges.

        Covered ranges: U+4E00-U+9FFF, U+3400-U+4DBF, U+F900-U+FAFF and
        U+20000-U+2A6DF.
        """
        return (
            "\u4e00" <= c <= "\u9fff"
            or "\u3400" <= c <= "\u4dbf"
            or "\uf900" <= c <= "\ufaff"
            or "\U00020000" <= c <= "\U0002a6df"
        )

    def _expand_ids(self, ids: list[int]) -> list[int]:
        """Return *ids* with every splittable token replaced by its character IDs."""
        result: list[int] = []
        for tid in ids:
            expansion = self._split_map.get(tid)
            if expansion is not None:
                result.extend(expansion)
            else:
                result.append(tid)
        return result

    def _split_row(self, row: dict) -> dict:
        """Split one sequence and stretch its parallel per-token fields.

        *row* maps field names to the lists of a single sequence and must
        contain ``"input_ids"``.  Each value of a parallel field is repeated
        once per ID produced from the token it belonged to, so that e.g. a
        padding position stays masked.  Fields whose length does not match
        ``input_ids`` are returned unchanged.
        """
        ids = row["input_ids"]
        # Number of IDs each original token turns into (1 when it is kept).
        counts = [len(self._split_map.get(tid, ())) or 1 for tid in ids]
        out = {}
        for key, values in row.items():
            if key == "input_ids":
                out[key] = self._expand_ids(ids)
            elif len(values) == len(counts):
                out[key] = [v for v, n in zip(values, counts) for _ in range(n)]
            else:
                out[key] = values
        return out

    def _pad_rows(self, rows: list, side: str) -> None:
        """Pad every row dict in *rows* in place to the longest ``input_ids``.

        Needed because splitting lengthens rows by different amounts, which
        would otherwise undo the padding applied by the base tokenizer.
        """
        width = max((len(row["input_ids"]) for row in rows), default=0)
        fills = {
            "input_ids": self.pad_token_id,
            "attention_mask": 0,
            "token_type_ids": getattr(self, "pad_token_type_id", 0),
            "special_tokens_mask": 1,
        }
        for row in rows:
            for key, values in row.items():
                gap = [fills[key]] * (width - len(values))
                row[key] = gap + values if side == "left" else values + gap

    def encode(self, text, *args, **kwargs):
        """Encode *text* like the base class, then split multi-character CJK tokens.

        ``return_tensors`` is honoured: conversion is deferred until after the
        split so the returned tensor contains the per-character IDs.
        """
        return_tensors = kwargs.pop("return_tensors", None)
        ids = super().encode(text, *args, **kwargs)
        if isinstance(ids, list):
            ids = self._expand_ids(ids)
        if return_tensors is None:
            return ids
        # Local import keeps the module's top-level dependency list unchanged.
        from transformers import BatchEncoding

        # The base ``encode`` returns tensors with a leading batch axis.
        wrapped = BatchEncoding(
            {"input_ids": ids},
            tensor_type=return_tensors,
            prepend_batch_axis=True,
        )
        return wrapped["input_ids"]

    def __call__(self, text=None, *args, **kwargs):
        """Tokenize like the base class, then split multi-character CJK tokens.

        ``input_ids`` and the fields listed in ``_ALIGNED_KEYS`` are expanded
        together, so existing mask values (including zeros on padding) are
        preserved.  When padding was requested, batch rows are padded back to
        a common length after splitting.  ``return_tensors`` is applied last.
        """
        # Splitting works on plain lists, so tensor conversion is deferred.
        return_tensors = kwargs.pop("return_tensors", None)
        result = super().__call__(text, *args, **kwargs)

        ids = result["input_ids"] if hasattr(result, "input_ids") else None
        # Fallback used only when ``ids`` is empty and cannot tell us itself.
        batched = isinstance(text, (list, tuple))
        if isinstance(ids, list):
            if ids:
                batched = isinstance(ids[0], list)
            # Same outer length as ``input_ids`` means: same number of rows
            # (batched) or same number of tokens (single sequence).
            fields = ["input_ids"] + [
                key
                for key in self._ALIGNED_KEYS
                if isinstance(result.get(key), list)
                and len(result[key]) == len(ids)
            ]
            if batched:
                rows = [
                    self._split_row({key: result[key][i] for key in fields})
                    for i in range(len(ids))
                ]
                padding = kwargs.get("padding", False)
                if (
                    padding not in (False, None, "do_not_pad")
                    and self.pad_token_id is not None
                ):
                    side = kwargs.get("padding_side") or self.padding_side
                    self._pad_rows(rows, side)
                for key in fields:
                    result[key] = [row[key] for row in rows]
            else:
                row = self._split_row({key: result[key] for key in fields})
                for key in fields:
                    result[key] = row[key]

        if return_tensors is not None:
            # A single sequence gets the leading batch axis the base class
            # would have added had it performed the conversion itself.
            result.convert_to_tensors(
                tensor_type=return_tensors,
                prepend_batch_axis=not batched,
            )
        return result
73