tokenization_kimi.py
import os
from collections import OrderedDict
from logging import getLogger
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast

import tiktoken
from tiktoken.load import load_tiktoken_bpe
from tokenizers import AddedToken

from transformers.convert_slow_tokenizer import bytes_to_unicode
from transformers.tokenization_utils import PreTrainedTokenizer

from .tool_declaration_ts import encode_tools_to_typescript_style

logger = getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}


class TikTokenTokenizer(PreTrainedTokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            The path to the Tiktoken model file.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[BOS]"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence
            classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[EOS]"`):
            The end of sequence token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
            The token used for padding, for example when batching sequences of different lengths.
        additional_special_tokens (list of `str`, *optional*):
            A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
            skipped when decoding if `skip_special_tokens` is set to `True`.
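
    Example (a minimal usage sketch; the checkpoint path below is a placeholder):

    ```python
    tokenizer = TikTokenTokenizer.from_pretrained("/path/to/kimi/checkpoint")
    ids = tokenizer.encode("Hello world")
    text = tokenizer.decode(ids)
    ```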
43 """
44
45 vocab_files_names = VOCAB_FILES_NAMES
46
47 model_input_names = ["input_ids", "attention_mask"]
48
49 special_tokens: Dict[str, int]
50
51 num_reserved_special_tokens = 256
52
53 pat_str = "|".join([
54 r"""[\p{Han}]+""",
55 r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
56 r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
57 r"""\p{N}{1,3}""",
58 r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
59 r"""\s*[\r\n]+""",
60 r"""\s+(?!\S)""",
61 r"""\s+""",
62 ])
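
    # In `pat_str` above, `&&` inside `[...]` is character-class intersection, a
    # feature of the Rust regex engine that backs tiktoken: for example,
    # `[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]` matches those letter/mark
    # categories *excluding* Han characters, so runs of Han text are captured
    # whole by the first alternative while other scripts fall through to the
    # word, number, and punctuation alternatives.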

    def __init__(
        self,
        vocab_file,
        bos_token: Union[str, AddedToken] = "[BOS]",
        eos_token: Union[str, AddedToken] = "[EOS]",
        unk_token: Union[str, AddedToken, None] = None,
        pad_token: Union[str, AddedToken, None] = None,
        additional_special_tokens: Optional[List[str]] = None,
        added_tokens_decoder: Optional[dict] = None,
        **kwargs,
    ):
        assert os.path.isfile(vocab_file), vocab_file

        if additional_special_tokens is None:
            additional_special_tokens = [
                "<|im_end|>",
                "<|im_user|>",
                "<|im_assistant|>",
                "<|start_header_id|>",
                "<|end_header_id|>",
                "[EOT]",
                "<|im_system|>",
                "<|im_middle|>",
            ]

        if added_tokens_decoder:
            special_tokens_mapping = {
                i: added_tokens_decoder[i].content for i in added_tokens_decoder
            }
        else:
            special_tokens_mapping = {}

        self.vocab_file = vocab_file
        mergeable_ranks = load_tiktoken_bpe(vocab_file)
        num_base_tokens = len(mergeable_ranks)
        self.special_tokens = {
            special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
            for i in range(
                num_base_tokens, num_base_tokens + self.num_reserved_special_tokens
            )
        }

        self.model = tiktoken.Encoding(
            name=Path(vocab_file).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {vocab_file}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens[str(bos_token)]
        self.eos_id: int = self.special_tokens[str(eos_token)]
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

        self.pad_id: int = self.special_tokens[str(pad_token)]
        self.unk_id: int = self.special_tokens[str(unk_token)]

        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        self.decoder = {}
        for i in range(self.n_words):
            # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
            decoding = ''.join([
                self.byte_encoder[ord(char)] for char in
                self.model.decode_single_token_bytes(i).decode('latin-1')
            ])
            self.decoder[i] = decoding

        self.encoder = {}
        for i in range(self.n_words):
            if i in self.decoder:
                self.encoder[self.decoder[i]] = i

        self._token_config_cache = OrderedDict()
        self._cache_max_size = 128

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            added_tokens_decoder=added_tokens_decoder,
            **kwargs,
        )
        self.all_special_ids_set = set(self.all_special_ids)

    def encode(self,
               text: str,
               allow_special_tokens: bool = True,
               **kwargs) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            text (str): The input string to be encoded.
            allow_special_tokens (bool): If True, special-token strings in the
                input are encoded to their special-token IDs; if False, they are
                encoded as ordinary text.

        Returns:
            list[int]: A list of token IDs.
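
        Example (illustrative; actual IDs depend on the vocabulary file):

        ```python
        tokenizer.encode("<|im_end|>")                               # one special-token ID
        tokenizer.encode("<|im_end|>", allow_special_tokens=False)   # several plain-text IDs
        ```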
168 """
169 # If there are other args, we should call super().encode because there are a lot of code
170 # to handle those args. supper().encode finally will call _tokenize and _convert_token_to_id.
171 # NOTE: our encode method is not compatible with the super().encode method,
172 # e.g. split_special_tokens' default is True in our encode method.
173 if len(kwargs) > 0:
174 logger.warning(f"Calling super().encode with {kwargs}")
175 return super().encode(text, **kwargs)
176
177 assert type(text) is str
178
179 # The tiktoken tokenizer can handle <=400k chars without
180 # pyo3_runtime.PanicException.
181 TIKTOKEN_MAX_ENCODE_CHARS = 400_000
182
183 # https://github.com/openai/tiktoken/issues/195
184 # Here we iterate over subsequences and split if we exceed the limit
185 # of max consecutive non-whitespace or whitespace characters.
186 MAX_NO_WHITESPACES_CHARS = 25_000
187
188 texts = self.pre_tokenizer_process(text)
189
190 all_substrs = []
191 for text in texts:
192 substrs = (
193 substr for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
194 for substr in self._split_whitespaces_or_nonwhitespaces(
195 text[i:i +
196 TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS))
197 all_substrs.extend(substrs)
198
199 t: List[int] = []
200 for substr in all_substrs:
201 if allow_special_tokens:
202 t.extend(
203 # we should consider special token as a common token
204 self.model.encode(
205 substr,
206 allowed_special="all",
207 ))
208 else:
209 t.extend(
210 # we should consider special token as a common token
211 self.model.encode(
212 substr,
213 disallowed_special=(),
214 ))
215
216 return t

    def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            token_ids (`int` or `List[int]`): The token ID or list of token IDs to be decoded.

        Returns:
            str: The decoded string.
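
        Example (illustrative round-trip; exact IDs depend on the vocabulary file):

        ```python
        ids = tokenizer.encode("Hello world")
        assert tokenizer.decode(ids) == "Hello world"
        ```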
227 """
228 # If there are other args, we should call super().decode because there are a lot of code
229 # to handle those args. supper().encode finally will call convert_tokens_to_string and _convert_id_to_token.
230 if len(kwargs) > 0:
231 return super().decode(token_ids, **kwargs)
232
233 if type(token_ids) is int:
234 token_ids = [token_ids]
235
236 return self.model.decode(cast(List[int], token_ids))
237
    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
            s: str, max_consecutive_slice_len: int) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than
        `max_consecutive_slice_len` consecutive whitespaces or consecutive non-whitespaces.
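
        Example (illustrative):

        ```python
        list(TikTokenTokenizer._split_whitespaces_or_nonwhitespaces("aaaa bb", 3))
        # -> ["aaa", "a bb"]
        ```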
244 """
245 current_slice_len = 0
246 current_slice_is_space = s[0].isspace() if len(s) > 0 else False
247 slice_start = 0
248
249 for i in range(len(s)):
250 is_now_space = s[i].isspace()
251
252 if current_slice_is_space ^ is_now_space:
253 current_slice_len = 1
254 current_slice_is_space = is_now_space
255 else:
256 current_slice_len += 1
257 if current_slice_len > max_consecutive_slice_len:
258 yield s[slice_start:i]
259 slice_start = i
260 current_slice_len = 1
261 yield s[slice_start:]
262
    def pre_tokenizer_process(self, text: str) -> List[str]:
        """
        Pre-tokenizes the input text into a list of text chunks.
        This hook is used to split the input text into smaller chunks for internal
        processing; the default implementation returns the text unchanged.
        """
        return [text]

270 """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
271
272 @property
273 def vocab_size(self) -> int:
274 return self.n_words
275
276 def get_vocab(self) -> Dict[str, int]:
277 return self.encoder
278
279 def _tokenize(self, text: str, **kwargs) -> List[str]:
280 return [self.decoder[t] for t in self.encode(text)]
281
282 def _convert_token_to_id(self, token: str) -> int:
283 return self.encoder.get(token, self.unk_id)
284
285 def _convert_id_to_token(self, index: int) -> str:
286 return self.decoder.get(index)
287
288 @staticmethod
289 def clean_up_tokenization(out_string: str) -> str:
290 return out_string
291
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
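        # Tokens are byte-level unicode strings (via `bytes_to_unicode`); map each
        # character back to its original byte before decoding the result as UTF-8.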
        text = ''.join(tokens)
        text = bytearray([self.byte_decoder[c]
                          for c in text]).decode('utf-8', 'replace')
        return text

    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            raise ValueError(
                f"vocabulary path ({save_directory}) should be a directory")
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") +
            VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(self.vocab_file) != os.path.abspath(
                out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file, )

    def apply_chat_template(self,
                            conversation,
                            tools: Optional[List[dict]] = None,
                            tokenize: bool = False,
                            add_generation_prompt: bool = True,
                            thinking: bool = True,
                            **kwargs):
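        """
        Applies the chat template after deterministically sorting `tools` and
        rendering them as a TypeScript-style declaration string that the template
        can read from `tools_ts_str`.

        Example (a minimal sketch; the message and tool list are illustrative):

        ```python
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": "What's the weather in Paris?"}],
            tools=my_tools,  # hypothetical list of JSON-schema tool definitions
            tokenize=False,
        )
        ```
        """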

        tools = deep_sort_dict(tools)

        # Convert tools to a TypeScript-style string if tools are provided
        tools_ts_str = None
        if tools:
            try:
                tools_ts_str = encode_tools_to_typescript_style(tools)
            except Exception as e:
                logger.warning(f"Failed to convert tools to TypeScript style: {e}")
                tools_ts_str = None

        # Store the TypeScript string in kwargs so it can be accessed by the template
        if tools_ts_str is not None:
            kwargs['tools_ts_str'] = tools_ts_str
        return super().apply_chat_template(
            conversation,
            tools=tools,
            tokenize=tokenize,
            add_generation_prompt=add_generation_prompt,
            thinking=thinking,
            **kwargs)


def deep_sort_dict(obj: Any) -> Any:
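    """
    Recursively sorts dictionary keys so that tool schemas serialize deterministically.

    Example (illustrative):

    ```python
    deep_sort_dict({"b": 1, "a": [{"d": 2, "c": 3}]})
    # -> {"a": [{"c": 3, "d": 2}], "b": 1}
    ```
    """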
    if isinstance(obj, dict):
        return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
    if isinstance(obj, list):
        return [deep_sort_dict(item) for item in obj]
    return obj