tokenization_voxcpm2.py
| 1 | """Custom tokenizer for VoxCPM2 that splits multi-character Chinese tokens. |
| 2 | |
| 3 | VoxCPM2 was trained with ``mask_multichar_chinese_tokens`` which splits |
| 4 | multi-character Chinese tokens (e.g. "你好" -> ["你", "好"]) into individual |
| 5 | character IDs before embedding. The base LlamaTokenizerFast produces |
| 6 | multi-character Chinese tokens that the model has never seen during training, |
| 7 | yielding garbled Chinese audio output in downstream inference frameworks. |
| 8 | |
| 9 | This module provides ``VoxCPM2Tokenizer``, which transparently applies the |
| 10 | character splitting inside ``encode()`` and ``__call__()``, so downstream |
| 11 | consumers that tokenize through those methods (vLLM, vLLM-Omni, Nano-vLLM, |
| 12 | etc.) get correct single-character IDs without code changes. |
| 13 | """ |
| 14 | |
| 15 | from transformers import LlamaTokenizerFast |
| 16 | |
| 17 | |
class VoxCPM2Tokenizer(LlamaTokenizerFast):
    """``LlamaTokenizerFast`` that splits multi-character CJK tokens into characters.

    ``encode()`` and ``__call__()`` run the base tokenizer and then replace each
    token ID found in ``_split_map`` with the IDs of that token's individual
    characters. Other inherited entry points (``encode_plus``,
    ``batch_encode_plus``, ``tokenize``, ``convert_tokens_to_ids``, ...) are not
    overridden and still return unsplit output.

    Truncation (``truncation`` / ``max_length``) is performed by the base
    tokenizer before splitting, so a split sequence may exceed ``max_length``.
    """

    # Per-token fields that the base tokenizer returns parallel to ``input_ids``.
    # When one token is split into several IDs, every piece repeats the value of
    # the token it came from, so padding positions keep ``attention_mask == 0``.
    _ALIGNED_KEYS = (
        "attention_mask",
        "token_type_ids",
        "special_tokens_mask",
        "offset_mapping",
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Token ID -> IDs of its single characters, for every splittable token.
        self._split_map = self._build_split_map()

    def _build_split_map(self) -> dict[int, list[int]]:
        """Map every all-CJK vocabulary token of 2+ characters to per-char IDs.

        The SentencePiece marker U+2581 is removed before the check, so a token
        such as "▁你好" maps to the IDs of "你" and "好". Tokens containing a
        character that converts to ``unk_token_id`` are left out of the map.
        """
        unk_id = self.unk_token_id
        split_map: dict[int, list[int]] = {}
        for token, tid in self.get_vocab().items():
            clean = token.replace("\u2581", "")
            if len(clean) >= 2 and all(self._is_cjk(c) for c in clean):
                char_ids = self.convert_tokens_to_ids(list(clean))
                if all(c != unk_id for c in char_ids):
                    split_map[tid] = char_ids
        return split_map

    @staticmethod
    def _is_cjk(c: str) -> bool:
        """Return True if ``c`` lies in one of the CJK ideograph blocks below."""
        return (
            "\u4e00" <= c <= "\u9fff"  # CJK Unified Ideographs
            or "\u3400" <= c <= "\u4dbf"  # Extension A
            or "\uf900" <= c <= "\ufaff"  # CJK Compatibility Ideographs
            or "\U00020000" <= c <= "\U0002a6df"  # Extension B
        )

    def _expand_ids(self, ids: list[int]) -> list[int]:
        """Return ``ids`` with each splittable token replaced by its char IDs."""
        expanded: list[int] = []
        for tid in ids:
            expanded.extend(self._split_map.get(tid, (tid,)))
        return expanded

    def _expand_row(self, row: dict) -> dict:
        """Split the tokens of one sequence, keeping per-token fields aligned.

        Args:
            row: ``"input_ids"`` (a flat list of IDs) plus any ``_ALIGNED_KEYS``.

        Returns:
            A new dict with split ``input_ids``. A field from ``_ALIGNED_KEYS``
            is expanded only if it is a list as long as ``input_ids``; any other
            value is copied through unchanged.
        """
        ids = row["input_ids"]
        aligned = [
            key
            for key in self._ALIGNED_KEYS
            if isinstance(row.get(key), list) and len(row[key]) == len(ids)
        ]
        out = dict(row)
        out["input_ids"] = []
        for key in aligned:
            out[key] = []
        for pos, tid in enumerate(ids):
            pieces = self._split_map.get(tid, (tid,))
            out["input_ids"].extend(pieces)
            for key in aligned:
                # Each piece inherits the value of the token it was split from.
                out[key].extend([row[key][pos]] * len(pieces))
        return out

    def encode(self, text, *args, **kwargs):
        """Encode like the base tokenizer, then split multi-character CJK tokens.

        A keyword ``return_tensors`` is applied after splitting, with one batch
        row per returned sequence (the base ``encode`` layout).
        """
        return_tensors = kwargs.pop("return_tensors", None)
        ids = super().encode(text, *args, **kwargs)
        if not isinstance(ids, list):
            # Already a tensor (``return_tensors`` passed positionally); the IDs
            # cannot be split here, so the base output is returned unchanged.
            return ids
        # When the base returns a list of sequences rather than one flat list,
        # expand each sequence on its own.
        nested = bool(ids) and isinstance(ids[0], list)
        ids = [self._expand_ids(row) for row in ids] if nested else self._expand_ids(ids)
        if return_tensors is None:
            return ids
        from transformers import BatchEncoding

        return BatchEncoding(
            {"input_ids": ids if nested else [ids]}, tensor_type=return_tensors
        )["input_ids"]

    def __call__(self, text=None, *args, **kwargs):
        """Tokenize like the base tokenizer, then split multi-character CJK tokens.

        - Fields in ``_ALIGNED_KEYS`` are expanded in lockstep with
          ``input_ids`` (padding keeps ``attention_mask == 0``), and a
          ``length`` field is recomputed from the split sequences.
        - A keyword ``return_tensors`` is applied after splitting, with the same
          batch-axis layout the base tokenizer uses.
        - A batch whose rows had equal length before splitting but not after is
          re-padded to its longest row via ``self.pad`` when ``padding`` or
          ``return_tensors`` was requested, so it stays rectangular.
        """
        return_tensors = kwargs.pop("return_tensors", None)
        result = super().__call__(text, *args, **kwargs)
        ids = result.get("input_ids")
        if not isinstance(ids, list):
            # No ``input_ids`` to split, or they are already tensors because
            # ``return_tensors`` was passed positionally: keep the base output.
            if return_tensors is not None:
                result.convert_to_tensors(tensor_type=return_tensors)
            return result

        keys = ["input_ids"] + [key for key in self._ALIGNED_KEYS if key in result]
        batched = bool(ids) and isinstance(ids[0], list)
        if batched:
            was_rectangular = len({len(row) for row in ids}) <= 1
            rows = [
                self._expand_row({key: result[key][i] for key in keys})
                for i in range(len(ids))
            ]
            for key in keys:
                result[key] = [row[key] for row in rows]
            ragged = len({len(row) for row in result["input_ids"]}) > 1
            padding = kwargs.get("padding", False)
            wants_padding = padding not in (False, None, "do_not_pad")
            if was_rectangular and ragged and (wants_padding or return_tensors is not None):
                # Splitting grew some rows more than others; restore equal length.
                pad_kwargs = {
                    "padding": "longest",
                    "pad_to_multiple_of": kwargs.get("pad_to_multiple_of"),
                    "return_attention_mask": "attention_mask" in result,
                }
                if "padding_side" in kwargs:
                    pad_kwargs["padding_side"] = kwargs["padding_side"]
                result = self.pad(result, **pad_kwargs)
        else:
            row = self._expand_row({key: result[key] for key in keys})
            for key in keys:
                result[key] = row[key]

        final_rows = result["input_ids"] if batched else [result["input_ids"]]
        length = result.get("length")
        if isinstance(length, list) and len(length) == len(final_rows):
            result["length"] = [len(row) for row in final_rows]

        if return_tensors is not None:
            if not batched:
                # The base tokenizer keeps a batch axis of 1 on per-token fields
                # when returning tensors for a single text; restore it.
                for key in keys:
                    result[key] = [result[key]]
            result.convert_to_tensors(tensor_type=return_tensors)
        return result
| 73 | |