tokenization_kimi.py
import os
from collections import OrderedDict
from logging import getLogger
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast

import tiktoken
from tiktoken.load import load_tiktoken_bpe
from tokenizers import AddedToken

from transformers.convert_slow_tokenizer import bytes_to_unicode
from transformers.tokenization_utils import PreTrainedTokenizer

from .tool_declaration_ts import encode_tools_to_typescript_style

logger = getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}


class TikTokenTokenizer(PreTrainedTokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.

    This tokenizer inherits from [`PreTrainedTokenizer`], which contains most of the main methods. Users should refer
    to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            The path to the Tiktoken model file.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[BOS]"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[EOS]"`):
            The end of sequence token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead. Must name one of the special tokens built in `__init__`.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
            The token used for padding, for example when batching sequences of different lengths. Must name one of
            the special tokens built in `__init__`.
        additional_special_tokens (list of `str`, *optional*):
            A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
            skipped when decoding if `skip_special_tokens` is set to `True`.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    model_input_names = ["input_ids", "attention_mask"]

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 256

    # Pre-tokenization pattern. tiktoken compiles this with a Rust regex
    # engine, which supports character-class intersection (`&&`) and Unicode
    # properties such as `\p{Han}`; Python's built-in `re` module does not.
    pat_str = "|".join([
        r"""[\p{Han}]+""",  # runs of Han (CJK) characters
        # words of non-Han letters, with an optional leading non-letter (e.g.
        # a space) and an optional English contraction suffix
        r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
        r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
        r"""\p{N}{1,3}""",  # numbers, at most three digits at a time
        r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",  # punctuation/symbol runs
        r"""\s*[\r\n]+""",  # newlines with any leading whitespace
        r"""\s+(?!\S)""",  # whitespace not followed by non-whitespace
        r"""\s+""",  # any remaining whitespace
    ])

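    # A hedged sketch of what the pattern does (assuming the third-party
    # `regex` module with V1 set operations, which is close to but not
    # identical to tiktoken's Rust engine):
    #
    #   import regex
    #   regex.findall(TikTokenTokenizer.pat_str, "Hello world 123",
    #                 flags=regex.V1)
    #   # -> ['Hello', ' world', ' ', '123'] (approximately; engines differ)
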
    def __init__(
        self,
        vocab_file,
        bos_token: Union[str, AddedToken] = "[BOS]",
        eos_token: Union[str, AddedToken] = "[EOS]",
        unk_token: Union[str, AddedToken, None] = None,
        pad_token: Union[str, AddedToken, None] = None,
        additional_special_tokens: Optional[List[str]] = None,
        added_tokens_decoder: Optional[dict] = None,
        **kwargs,
    ):
        assert os.path.isfile(vocab_file), vocab_file

        if additional_special_tokens is None:
            additional_special_tokens = [
                "<|im_end|>",
                "<|im_user|>",
                "<|im_assistant|>",
                "<|start_header_id|>",
                "<|end_header_id|>",
                "[EOT]",
                "<|im_system|>",
                "<|im_middle|>",
            ]

        if added_tokens_decoder:
            special_tokens_mapping = {
                i: added_tokens_decoder[i].content for i in added_tokens_decoder
            }
        else:
            special_tokens_mapping = {}

        self.vocab_file = vocab_file
        mergeable_ranks = load_tiktoken_bpe(vocab_file)
        num_base_tokens = len(mergeable_ranks)
        # Reserve a block of IDs above the base vocabulary for special tokens;
        # any ID without an explicit name becomes "<|reserved_token_{id}|>".
        self.special_tokens = {
            special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
            for i in range(
                num_base_tokens, num_base_tokens + self.num_reserved_special_tokens
            )
        }

        self.model = tiktoken.Encoding(
            name=Path(vocab_file).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {vocab_file}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens[str(bos_token)]
        self.eos_id: int = self.special_tokens[str(eos_token)]
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

        # NOTE: although `pad_token` and `unk_token` default to None in the
        # signature, they must name entries of `self.special_tokens`; leaving
        # them as None raises a KeyError here (`str(None)` is not a key).
        self.pad_id: int = self.special_tokens[str(pad_token)]
        self.unk_id: int = self.special_tokens[str(unk_token)]

        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        self.decoder = {}
        for i in range(self.n_words):
            # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
            decoding = ''.join([
                self.byte_encoder[ord(char)] for char in
                self.model.decode_single_token_bytes(i).decode('latin-1')
            ])
            self.decoder[i] = decoding

        self.encoder = {}
        for i in range(self.n_words):
            if i in self.decoder:
                self.encoder[self.decoder[i]] = i
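
        # For example (an illustration, not from the original file): the raw
        # token bytes b" world" map to the printable string "Ġworld" under the
        # GPT-2 byte-to-unicode table, so `self.encoder` and `self.decoder`
        # hold byte-level strings rather than raw bytes.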

        self._token_config_cache = OrderedDict()
        self._cache_max_size = 128

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            added_tokens_decoder=added_tokens_decoder,
            **kwargs,
        )
        self.all_special_ids_set = set(self.all_special_ids)

    def encode(self,
               text: str,
               allow_special_tokens: bool = True,
               **kwargs) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            text (str): The input string to be encoded.

        Returns:
            list[int]: A list of token IDs.
        """
        # If any other kwargs are passed, fall back to super().encode, which
        # contains a lot of logic for handling them and ultimately calls
        # _tokenize and _convert_token_to_id.
        # NOTE: this encode method is not compatible with super().encode,
        # e.g. `split_special_tokens` defaults to True in this encode method.
        if len(kwargs) > 0:
            logger.warning(f"Calling super().encode with {kwargs}")
            return super().encode(text, **kwargs)

        assert type(text) is str

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        texts = self.pre_tokenizer_process(text)

        all_substrs = []
        for text in texts:
            substrs = (
                substr
                for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
                for substr in self._split_whitespaces_or_nonwhitespaces(
                    text[i:i + TIKTOKEN_MAX_ENCODE_CHARS],
                    MAX_NO_WHITESPACES_CHARS))
            all_substrs.extend(substrs)

        t: List[int] = []
        for substr in all_substrs:
            if allow_special_tokens:
                # Special-token strings in the text are mapped to their
                # special IDs.
                t.extend(self.model.encode(substr, allowed_special="all"))
            else:
                # Special-token strings are treated as ordinary text and
                # BPE-encoded like any other substring.
                t.extend(self.model.encode(substr, disallowed_special=()))

        return t
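
    # For example (hypothetical IDs, for illustration only):
    #
    #   tok.encode("<|im_end|>", allow_special_tokens=True)
    #   # -> [<single special-token ID>]
    #   tok.encode("<|im_end|>", allow_special_tokens=False)
    #   # -> several IDs for the literal characters "<", "|", "im", ...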

    def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            token_ids (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # If any other kwargs are passed, fall back to super().decode, which
        # contains a lot of logic for handling them and ultimately calls
        # convert_tokens_to_string and _convert_id_to_token.
        if len(kwargs) > 0:
            return super().decode(token_ids, **kwargs)

        if type(token_ids) is int:
            token_ids = [token_ids]

        return self.model.decode(cast(List[int], token_ids))
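
    # For example (illustrative): tok.decode(tok.encode("Hello")) == "Hello",
    # and a single int is accepted as well, e.g. tok.decode(ids[0]).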

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
            s: str, max_consecutive_slice_len: int) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
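
    # A worked example (not from the original file): with a limit of 2,
    #
    #   list(TikTokenTokenizer._split_whitespaces_or_nonwhitespaces("aaaa bb", 2))
    #   # -> ['aa', 'aa bb']
    #
    # Only a run longer than the limit forces a split; a whitespace/
    # non-whitespace boundary merely resets the counter.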

    def pre_tokenizer_process(self, text: str) -> List[str]:
        """
        Pre-tokenizes the input text into a list of chunks.
        This hook splits the input text into smaller chunks for internal
        processing and can be overridden by subclasses; the base implementation
        returns the text unchanged as a single-element list.
        """
        return [text]

    # ----- Below are the abstract methods required by PreTrainedTokenizer -----

    @property
    def vocab_size(self) -> int:
        return self.n_words

    def get_vocab(self) -> Dict[str, int]:
        return self.encoder

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        return [self.decoder[t] for t in self.encode(text)]

    def _convert_token_to_id(self, token: str) -> int:
        return self.encoder.get(token, self.unk_id)

    def _convert_id_to_token(self, index: int) -> Optional[str]:
        return self.decoder.get(index)

    @staticmethod
    def clean_up_tokenization(out_string: str) -> str:
        return out_string

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        text = ''.join(tokens)
        text = bytearray([self.byte_decoder[c]
                          for c in text]).decode('utf-8', 'replace')
        return text
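
    # For example (an illustration, not from the original file): the byte-level
    # tokens ['Hello', 'Ġworld'] join to "HelloĠworld", whose characters map
    # back through `byte_decoder` to the UTF-8 bytes of "Hello world".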

    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            raise ValueError(
                f"vocabulary path ({save_directory}) should be a directory")
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") +
            VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(self.vocab_file) != os.path.abspath(
                out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file, )

    def apply_chat_template(self,
                            conversation,
                            tools: Optional[List[dict]] = None,
                            tokenize: bool = False,
                            add_generation_prompt: bool = True,
                            thinking: bool = True,
                            **kwargs):

        # Sort tool dicts recursively so that semantically identical tool
        # specs always render identically in the template.
        tools = deep_sort_dict(tools)

        # Convert tools to a TypeScript-style string if tools are provided
        tools_ts_str = None
        if tools:
            try:
                tools_ts_str = encode_tools_to_typescript_style(tools)
            except Exception as e:
                logger.warning(f"Failed to convert tools to TypeScript style: {e}")
                tools_ts_str = None

        # Store the TypeScript string in kwargs so it can be accessed by the template
        if tools_ts_str is not None:
            kwargs['tools_ts_str'] = tools_ts_str
        return super().apply_chat_template(
            conversation,
            tools=tools,
            tokenize=tokenize,
            add_generation_prompt=add_generation_prompt,
            thinking=thinking,
            **kwargs)
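
    # A usage sketch (the message and tool contents are illustrative):
    #
    #   messages = [{"role": "user", "content": "What's the weather in Paris?"}]
    #   tools = [{"type": "function",
    #             "function": {"name": "get_weather",
    #                          "parameters": {"type": "object",
    #                                         "properties": {"city": {"type": "string"}}}}}]
    #   prompt = tok.apply_chat_template(messages, tools=tools,
    #                                    add_generation_prompt=True)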


def deep_sort_dict(obj: Any) -> Any:
    if isinstance(obj, dict):
        return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
    if isinstance(obj, list):
        return [deep_sort_dict(item) for item in obj]
    return obj
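
# For example, deep_sort_dict({"b": 1, "a": {"d": 2, "c": 3}}) returns
# {"a": {"c": 3, "d": 2}, "b": 1}: keys are sorted at every nesting level,
# while list order is preserved.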