tokenizer_wrapper.py
65.3 KB · 1427 lines · python Raw
1 # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
2 # you may not use this file except in compliance with the License.
3 # You may obtain a copy of the License at
4 #
5 # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
6 #
7 # Unless required by applicable law or agreed to in writing, software
8 # distributed under the License is distributed on an "AS IS" BASIS,
9 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 # See the License for the specific language governing permissions and
11 # limitations under the License.
12 # ==============================================================================
13
14 import warnings
15 import random
16 from typing import List, Optional, Union, Dict, Any
17 from collections import defaultdict
18 from copy import deepcopy
19
20 import numpy as np
21 import torch
22 import torch.nn.functional as F
23 from transformers import AutoTokenizer
24 from diffusers.utils import BaseOutput
25
26
def default(value, default_value):
    """Return ``value`` unless it is ``None``, in which case return ``default_value``.

    Unlike ``value or default_value``, falsy values such as ``0``, ``""`` or ``[]`` are kept.
    """
    if value is None:
        return default_value
    return value
29
30
def ensure_list(value):
    """Normalize ``value`` to a new list.

    ``None`` becomes ``[]``; a list or tuple is shallow-copied into a new list;
    any other object (including strings) is wrapped as a single element.
    """
    if value is None:
        return []
    return list(value) if isinstance(value, (list, tuple)) else [value]
37
38
class Resolution(object):
    """A height/width pair together with its aspect ratio (height / width).

    Accepted forms:
      * ``Resolution("HxW")`` - string, first number is the height;
      * ``Resolution("N")`` or ``Resolution(N)`` - square N x N;
      * ``Resolution(h, w)`` - explicit height and width;
      * ``Resolution((h, w))`` - any indexable pair.
    """

    def __init__(self, size, *args):
        parsed = size
        if isinstance(parsed, str):
            # "HxW" -> (H, W); anything else is parsed as a single integer.
            parsed = tuple(map(int, parsed.split('x')[:2])) if 'x' in parsed else int(parsed)
        if args:
            parsed = (parsed, args[0])
        elif isinstance(parsed, int):
            parsed = (parsed, parsed)

        height, width = parsed[0], parsed[1]
        self.h = self.height = height
        self.w = self.width = width
        self.r = self.ratio = height / width

    def __getitem__(self, idx):
        # Index 0 is the height, index 1 the width; this also makes `h, w = res` work.
        if idx == 0:
            return self.h
        if idx == 1:
            return self.w
        raise IndexError(f'Index {idx} out of range')

    def __str__(self):
        return f'{self.h}x{self.w}'
66
67
class ResolutionGroup(object):
    """A set of resolution buckets spanning aspect ratios around a square base size.

    Starting from ``base_size x base_size``, one walk increases the height and decreases
    the width by ``step`` per iteration and a second walk does the opposite. Each side is
    clamped to ``[base_size // 2, base_size * 2]`` and floored to a multiple of ``align``.
    The resulting resolutions are sorted by ratio (height / width).

    Parameters
    ----------
    base_size: int
        Side length of the square anchor resolution. Required despite the ``None`` default.
    step: int, optional
        Increment between neighbouring buckets; defaults to ``base_size // 16``.
    align: int
        Every generated height/width is floored to a multiple of this value.

    Raises
    ------
    ValueError
        If ``base_size`` is not an int, or ``step`` exceeds ``base_size // 2``.
    AssertionError
        If ``base_size`` is not divisible by ``align``, or ``align`` is larger than ``step``.
    """

    def __init__(self, base_size=None, step=None, align=1):
        # Validate the type first: otherwise `base_size=None` (the default) would fail
        # with an opaque TypeError from the modulo in the divisibility check.
        if not isinstance(base_size, int):
            raise ValueError(f'base_size must be an int, but got {type(base_size)}')
        assert base_size % align == 0, f'base_size {base_size} is not divisible by align {align}'
        if step is None:
            step = base_size // 16
        if step > base_size // 2:
            raise ValueError(f'step must be smaller than base_size // 2, but got {step} > {base_size // 2}')

        self.align = align
        self.base_size = base_size
        self.step = step
        self.data = self._calc_by_step()

        self.ratio = np.array([x.ratio for x in self.data])
        # Optional per-bucket annotation shown by __repr__.
        self.attr = ['' for _ in range(len(self.data))]
        # Indentation used by __repr__ (lets a parent object nest this repr).
        self.prefix_space = 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def __repr__(self):
        prefix = self.prefix_space * ' '
        prefix_close = (self.prefix_space - 4) * ' '
        res_str = f'ResolutionGroup(base_size={self.base_size}, step={self.step}, data='
        attr_maxlen = max([len(x) for x in self.attr] + [5])
        res_str += \
            f'\n{prefix}ID: height width ratio {" " * max(0, attr_maxlen - 4)}count h/16 w/16 tokens\n{prefix}'
        # Token count is the size of the (h/16) x (w/16) grid; the parentheses matter when a
        # side is not a multiple of 16.
        res_str += \
            ('\n' + prefix).join([f'{i:2d}: ({x.h:4d}, {x.w:4d}) {self.ratio[i]:.4f} {self.attr[i]:>{attr_maxlen}s} '
                                  f'({x.h // 16:3d}, {x.w // 16:3d}) {(x.h // 16) * (x.w // 16):6d}'
                                  for i, x in enumerate(self.data)])
        res_str += f'\n{prefix_close})'
        return res_str

    def _calc_by_step(self):
        """Enumerate the bucket resolutions, sorted by ratio (height / width).

        Each walk stops once both sides have reached their clamp value.
        """
        assert self.align <= self.step, f'align {self.align} must not be larger than step {self.step}'

        min_height = self.base_size // 2
        min_width = self.base_size // 2
        max_height = self.base_size * 2
        max_width = self.base_size * 2

        resolutions = [Resolution(self.base_size, self.base_size)]

        # Tall side: height grows towards max_height while width shrinks towards min_width.
        cur_height, cur_width = self.base_size, self.base_size
        while True:
            if cur_height >= max_height and cur_width <= min_width:
                break

            cur_height = min(cur_height + self.step, max_height)
            cur_width = max(cur_width - self.step, min_width)
            resolutions.append(Resolution(cur_height // self.align * self.align, cur_width // self.align * self.align))

        # Wide side: mirror image of the walk above.
        cur_height, cur_width = self.base_size, self.base_size
        while True:
            if cur_height <= min_height and cur_width >= max_width:
                break

            cur_height = max(cur_height - self.step, min_height)
            cur_width = min(cur_width + self.step, max_width)
            resolutions.append(Resolution(cur_height // self.align * self.align, cur_width // self.align * self.align))

        return sorted(resolutions, key=lambda x: x.ratio)

    def get_target_size(self, width, height):
        """Return ``(width, height)`` of the bucket whose ratio is closest to ``height / width``."""
        ratio = height / width
        idx = np.argmin(np.abs(self.ratio - ratio))
        reso = self.data[idx]
        return reso.w, reso.h

    def get_base_size_and_ratio_index(self, width, height):
        """Return ``(base_size, idx)`` where ``idx`` is the closest-ratio bucket index."""
        ratio = height / width
        idx = np.argmin(np.abs(self.ratio - ratio))
        return self.base_size, idx
149
150
class ImageInfo:
    """Store information about one image for processing and generation.

    Besides the explicit constructor arguments, four boolean layout flags are read from
    ``**kwargs``: ``add_timestep_token`` (default True), ``add_guidance_token`` (default
    False), ``use_front_boi_token`` (default True) and ``add_image_shape_token`` (default
    True). Existing attributes can also be read/written dictionary-style
    (``info["base_size"]``).
    """

    def __init__(
        self,
        image_type: str = None,
        image_tensor: torch.Tensor = None,
        image_width: int = None,
        image_height: int = None,
        token_width: int = None,
        token_height: int = None,
        image_token_length: int = None,
        base_size: int = None,
        ratio_index: int = None,
        **kwargs,
    ):
        self.image_type = image_type
        self.image_tensor = image_tensor
        # Short aliases (w/h, tk_w/tk_h) are kept alongside the long names.
        self.image_width = self.w = image_width
        self.image_height = self.h = image_height
        self.token_width = self.tk_w = token_width
        self.token_height = self.tk_h = token_height
        # Without an explicit token length, use the size of the token grid (if known).
        if image_token_length is None and token_width is not None and token_height is not None:
            image_token_length = token_width * token_height
        self.image_token_length = image_token_length
        self.base_size = base_size
        self.ratio_index = ratio_index

        self.add_timestep_token = kwargs.get("add_timestep_token", True)
        self.add_guidance_token = kwargs.get("add_guidance_token", False)
        self.use_front_boi_token = kwargs.get("use_front_boi_token", True)
        self.add_image_shape_token = kwargs.get("add_image_shape_token", True)

    def __getitem__(self, key: str) -> Any:
        """Allow dictionary-like access to attributes."""
        if hasattr(self, key):
            return getattr(self, key)
        raise KeyError(f"Key '{key}' not found in ImageInfo")

    def __setitem__(self, key: str, value: Any) -> None:
        """Allow dictionary-like assignment to existing attributes only."""
        if hasattr(self, key):
            setattr(self, key, value)
        else:
            raise KeyError(f"Key '{key}' not found in ImageInfo")

    def __contains__(self, key: str) -> bool:
        """Check if the key exists in the ImageInfo object."""
        return hasattr(self, key)

    def __repr__(self):
        return (f"ImageInfo(image_type={self.image_type}, image_tensor={self.image_tensor}, "
                f"image_width={self.image_width}, image_height={self.image_height}, "
                f"token_width={self.token_width}, token_height={self.token_height}, "
                f"image_token_length={self.image_token_length}, "
                f"base_size={self.base_size}, ratio_index={self.ratio_index})")

    @property
    def meta_info(self):
        """Keyword arguments describing this image for the image sections of ``tkwrapper.encode_general()``.

        Raises
        ------
        ValueError
            If ``image_type`` is not one of "vae", "gen_image" or "vit".
        """
        if self.image_type in ["vae", "gen_image"]:
            return dict(
                token_length=self.image_token_length,
                add_timestep_token=self.add_timestep_token,
                add_guidance_token=self.add_guidance_token,
                use_front_boi_token=self.use_front_boi_token,
                add_image_shape_token=self.add_image_shape_token,
                base_size=self.base_size,
                ratio_idx=self.ratio_index,
                # for rope 2d
                token_height=self.token_height,
                token_width=self.token_width,
                # for bc
                image_height=self.image_height,
                image_width=self.image_width,
            )
        elif self.image_type in ["vit"]:
            return dict(
                token_length=self.image_token_length,
                use_front_boi_token=self.use_front_boi_token,
                add_image_shape_token=self.add_image_shape_token,
                # for rope 2d
                token_height=self.token_height,
                token_width=self.token_width,
                # for bc
                image_height=self.image_height,
                image_width=self.image_width,
            )
        else:
            raise ValueError(f"Unknown image type '{self.image_type}'")

    @property
    def num_special_tokens(self):
        """Number of special tokens framing this image in a token sequence.

        Counts the begin/end-of-image pair (2), plus one each for the timestep and
        guidance tokens when enabled, plus two image-shape tokens (size and ratio)
        when enabled.

        Raises
        ------
        ValueError
            For image types other than "vae", "src_image" and "gen_image".
        """
        if self.image_type not in ["vae", "src_image", "gen_image"]:
            raise ValueError(f"Unknown image_type: {self.image_type}")
        return (
            2  # <boi> + <eoi> or <src_boi> + <src_eoi>
            + (1 if self.add_timestep_token else 0)
            + (1 if self.add_guidance_token else 0)
            + (2 if self.add_image_shape_token else 0)
        )

    def copy(self, copy_image_tensor=True):
        """Return a new ImageInfo with the same metadata and layout flags.

        With ``copy_image_tensor=True`` the tensor is cloned (it must not be None);
        otherwise the copy's ``image_tensor`` is None.

        Raises
        ------
        ValueError
            If ``copy_image_tensor`` is True but ``image_tensor`` is None.
        """
        if copy_image_tensor and self.image_tensor is None:
            raise ValueError("image_tensor is None, cannot copy")
        return ImageInfo(
            image_type=self.image_type,
            image_tensor=self.image_tensor.clone() if copy_image_tensor else None,
            image_width=self.image_width,
            image_height=self.image_height,
            token_width=self.token_width,
            token_height=self.token_height,
            image_token_length=self.image_token_length,
            base_size=self.base_size,
            ratio_index=self.ratio_index,
            # Carry the layout flags over; otherwise the copy silently reverts to defaults.
            add_timestep_token=self.add_timestep_token,
            add_guidance_token=self.add_guidance_token,
            use_front_boi_token=self.use_front_boi_token,
            add_image_shape_token=self.add_image_shape_token,
        )

    def zeros_(self):
        """Replace ``image_tensor`` with a zero tensor of the same shape/dtype/device."""
        self.image_tensor = torch.zeros_like(self.image_tensor)
279
280
class ImageTensor(torch.Tensor):
    """Type-hinting alias for a ``torch.Tensor`` that carries image metadata.

    The annotated attributes are not set by this class. They are expected to be
    attached directly to a regular tensor instance, e.g. ``tensor.i = ImageInfo(...)``.
    """

    i: ImageInfo
    vision_encoder_kwargs: dict  # presumably extra kwargs for the vision encoder -- confirm at call sites
286
287
class JointImageInfo(object):
    """Metadata for an image represented twice: once for the VAE and once for the vision encoder.

    Mirrors the key attributes of ``ImageInfo`` (``image_type``, ``image_token_length``,
    the layout flags, ``meta_info``, ``num_special_tokens``) so both can be handled
    uniformly. The layout flags are taken from the VAE-side info.
    """

    def __init__(self, vae_image_info: ImageInfo, vision_image_info: ImageInfo, vision_encoder_kwargs: dict = None):
        self.vae_image_info = vae_image_info
        self.vision_image_info = vision_image_info
        self.vision_encoder_kwargs = vision_encoder_kwargs

        vae = vae_image_info
        self.image_type = "joint_image"
        # The joint block holds both token runs back to back.
        self.image_token_length = vae.image_token_length + vision_image_info.image_token_length
        self.add_timestep_token = vae.add_timestep_token
        self.use_front_boi_token = vae.use_front_boi_token
        self.add_image_shape_token = vae.add_image_shape_token

    def __repr__(self):
        return f"JointImageInfo(vae_image={self.vae_image_info}, vision_image={self.vision_image_info})"

    @property
    def meta_info(self):
        """Keyword arguments for the image sections of ``tkwrapper.encode_general()``.

        Per-encoder values are ``[vae_value, vision_value]`` pairs; ``base_size`` and
        ``ratio_idx`` come from the VAE side.
        """
        vae, vit = self.vae_image_info, self.vision_image_info

        def both(attr):
            return [getattr(vae, attr), getattr(vit, attr)]

        return dict(
            token_length=both("image_token_length"),
            add_timestep_token=self.add_timestep_token,
            use_front_boi_token=self.use_front_boi_token,
            add_image_shape_token=self.add_image_shape_token,
            base_size=vae.base_size,
            ratio_idx=vae.ratio_index,
            # for rope 2d
            token_height=both("token_height"),
            token_width=both("token_width"),
            # for bc
            image_height=both("image_height"),
            image_width=both("image_width"),
        )

    @property
    def num_special_tokens(self):
        """Count of <boi>, <eoi>, <joint_image_sep> plus the optional timestep/shape tokens."""
        total = 3  # <boi> + <eoi> + <joint_image_sep>
        if self.add_timestep_token:
            total += 1
        if self.add_image_shape_token:
            total += 2
        return total

    def copy(self, copy_image_tensor=True):
        """Return a new JointImageInfo built from copies of both sub-infos.

        ``vision_encoder_kwargs`` is shared with the copy, not duplicated.
        """
        vae, vit = self.vae_image_info, self.vision_image_info
        if copy_image_tensor and (vae.image_tensor is None or vit.image_tensor is None):
            raise ValueError("image_tensor is None, cannot copy")
        vae_copy = vae.copy(copy_image_tensor)
        vit_copy = vit.copy(copy_image_tensor)
        return JointImageInfo(vae_copy, vit_copy, self.vision_encoder_kwargs)

    def zeros_(self):
        """Zero the VAE tensor, then the vision tensor."""
        for info in (self.vae_image_info, self.vision_image_info):
            info.zeros_()
345
346
class JointImage(object):
    """Pair of image tensors (VAE input and vision-encoder input) with combined metadata.

    Both tensors must carry their per-image info as the ``.i`` attribute.
    """

    def __init__(self, vae_image: ImageTensor, vision_image: ImageTensor):
        self.vae_image, self.vision_image = vae_image, vision_image
        # Combine the two per-tensor infos into a single joint description.
        self.i = JointImageInfo(vae_image.i, vision_image.i)
352
353
class TokenizerEncodeOutput(BaseOutput):
    """Container for the results of encoding a mixed text/image token sequence.

    Every field defaults to ``None``. Which fields get populated is decided by the code
    that builds this object (not visible in this section).
    """
    # Token ids of the encoded sequence (presumably a 1-D or batched id tensor -- confirm).
    tokens: torch.Tensor = None
    # Presumably positions of the <timestep> / <guidance> tokens, used to scatter their
    # embeddings -- confirm against the model code.
    timestep_scatter_index: Optional[torch.Tensor] = None
    guidance_scatter_index: Optional[torch.Tensor] = None
    # Per-section slices into `tokens`; presumably derived from the start/end positions that
    # `encode_sequence` records (e.g. "<text>_start", "<img>_start") -- confirm.
    text_slices: Optional[List[slice]] = None
    gen_image_slices: Optional[List[slice]] = None
    joint_image_slices: Optional[List[slice]] = None
    cond_vae_image_slices: Optional[List[slice]] = None
    cond_vit_image_slices: Optional[List[slice]] = None
    # Per-section masks over `tokens` (assumed boolean -- TODO confirm).
    text_mask: Optional[torch.Tensor] = None
    gen_image_mask: Optional[torch.Tensor] = None
    cond_vae_image_mask: Optional[torch.Tensor] = None
    cond_vit_image_mask: Optional[torch.Tensor] = None
    # NOTE(review): meaning of `real_pos` is not visible in this section.
    real_pos: Optional[torch.Tensor] = None
    all_image_slices: Optional[List[slice]] = None
    # Timestep positions split by conditional vs. generated images (cf. the "cond_timestep" /
    # "gen_timestep" keys recorded by `_add_image_meta_info_token`) -- presumably; confirm.
    cond_timestep_scatter_index: Optional[torch.Tensor] = None
    gen_timestep_scatter_index: Optional[torch.Tensor] = None
371
372
class Conversation:
    """Class-level constants for conversation-style prompts.

    ``roles`` lists the two speaker names. ``sep`` is a separator string, presumably
    placed between turns (confirm against the prompt builder).
    """

    # NOTE: shared mutable class attribute -- copy it before modifying.
    roles: List[str] = ["User", "Assistant"]
    sep: str = "\n\n"
376
377
378 class TokenizerWrapper(object):
379 def __init__(self, tokenizer):
380 if isinstance(tokenizer, str):
381 self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
382 else:
383 self.tokenizer = tokenizer
384
385 # Define short names
386 self.bos_token_id = self.tokenizer.bos_token_id
387 self.eos_token_id = self.tokenizer.eos_token_id
388 self.pad_token_id = self.tokenizer.pad_token_id
389 self.boi_token_id = self.tokenizer.convert_tokens_to_ids("<boi>")
390 self.eoi_token_id = self.tokenizer.convert_tokens_to_ids("<eoi>")
391 self.img_token_id = self.tokenizer.convert_tokens_to_ids("<img>")
392 self.cfg_token_id = self.tokenizer.convert_tokens_to_ids("<cfg>")
393 self.end_answer_token_id = self.tokenizer.convert_tokens_to_ids("</answer>")
394 self.end_recaption_token_id = self.tokenizer.convert_tokens_to_ids("</recaption>")
395 self.ratio_token_offset = self.tokenizer.convert_tokens_to_ids("<img_ratio_0>")
396 self.special_token_map = self.tokenizer.added_tokens_encoder
397
398 def pad(self, tensor_list, dim=0, pad_val=None):
399 if pad_val is None:
400 pad_val = self.pad_token_id
401 max_len = max([t.shape[dim] for t in tensor_list])
402 padded_tensor_list = []
403 for t in tensor_list:
404 if t.shape[dim] < max_len:
405 assert pad_val is not False, "Not allowed pad."
406 t = F.pad(t, (0, max_len - t.shape[dim]), value=pad_val)
407 padded_tensor_list.append(t)
408 return padded_tensor_list
409
410 def encode(self, *args, **kwargs):
411 return self.tokenizer.encode(*args, **kwargs)
412
413 def decode(self, *args, **kwargs):
414 return self.tokenizer.decode(*args, **kwargs)
415
416 def encode_text(
417 self,
418 *texts,
419 uncond_enabled: Optional[Union[bool, List[bool]]] = None,
420 uncond_p: Optional[float] = None,
421 max_length: Optional[int] = None,
422 pad: Optional[str] = None,
423 return_lengths: bool = False,
424 ):
425 """
426 Encode text and image for AR-like model training of the text-to-image/instruction tuning tasks.
427 Support encode multiple texts at once. Each text can be separately conditioned or unconditioned
428 based on the uncond_flags and a uniform uncond_p.
429 **<bos> token is always prepended to the text tokens.**
430
431 Parameters
432 ----------
433 texts: str or List[str]
434 List of texts to be encoded.
435 uncond_enabled: bool or List[bool]
436 List of flags to indicate whether the text should be unconditioned.
437 If False, the text will never be unconditioned.
438 If True, the text will be unconditioned with uncond_p.
439 uncond_p: float
440 Probability to the unconditional text. Only works when uncond_enabled is True.
441 max_length: int
442 Maximum length of the encoded text.
443 pad: Optional[str]
444 Padding method. Can be 'left' or 'right'.
445 return_lengths: bool
446 Whether to return the length of each encoded text.
447 """
448 if pad is not None:
449 assert max_length is not None, "max_length should be provided when pad is not None."
450
451 if uncond_enabled is None:
452 uncond_enabled = [True] * len(texts)
453 elif isinstance(uncond_enabled, bool):
454 uncond_enabled = [uncond_enabled] * len(texts)
455 if len(uncond_enabled) != len(texts):
456 print(uncond_enabled, texts)
457 assert len(uncond_enabled) == len(texts), (
458 f"Length of uncond_flags should be equal to the number of texts, "
459 f"but got {len(uncond_enabled)} and {len(texts)}."
460 )
461
462 # Prepare text/uncond tokens
463 # TODO: If len(texts) > 1, such as instruction + prompt in inpainting, we need to determine how to do uncond.
464 # Now all texts will be cond or uncond at the same time.
465 do_uncond_drop = (uncond_p is not None) and (random.random() < uncond_p)
466 text_tokens, lengths = [], []
467 cum_length = 0
468 for text, uncond_flag in zip(texts, uncond_enabled):
469 # If reach the max_length and there still have unencoded texts, give a warning message and break the loop.
470 if max_length is not None and cum_length >= max_length:
471 warnings.warn(
472 f"Text length exceeds the max_length({max_length}). The remaining texts will be ignored: "
473 f"{text[:80]}..."
474 )
475 break
476 # Set add_special_tokens=False to avoid adding <bos> token in some LLMs.
477 if isinstance(text, str):
478 text_token = self.tokenizer.encode(text, add_special_tokens=False)
479 else:
480 text_token = text
481 if uncond_flag and do_uncond_drop:
482 text_token = [self.cfg_token_id] * len(text_token)
483 # Cutoff the text by max_length if necessary
484 if max_length is not None and (cum_length + len(text_token)) > max_length:
485 text_token = text_token[:max_length - cum_length]
486 text_tokens.extend(text_token)
487 lengths.append(len(text_token))
488 cum_length += len(text_token)
489
490 # Prepend/Append <pad> tokens if applicable
491 if pad is not None and (pad_length := max_length - len(text_tokens)) > 0:
492 if pad == 'left':
493 text_tokens = [self.pad_token_id] * pad_length + text_tokens
494 elif pad == 'right':
495 text_tokens = text_tokens + [self.pad_token_id] * pad_length
496 else:
497 raise ValueError(f"Unsupported padding method: {pad}.")
498
499 if return_lengths:
500 return text_tokens, lengths
501 return text_tokens
502
503 @staticmethod
504 def _check_key_number_matched(keys, data):
505 # Assert keys and token_source are matched
506 assert set(keys) == set(data.keys()), (
507 f"Keys in the template and token source should be matched, but got {set(keys)} and {list(data.keys())}."
508 )
509 key_counts = {k: 0 for k in keys}
510 for key in keys:
511 key_counts[key] += 1
512 for key, count in key_counts.items():
513 assert len(data[key]) == count, (
514 f"Number of `{key}` in the token source should be matched with the template, but got "
515 f"{data[key]}({len(data[key])}) and {count}."
516 )
517
518 def _add_image_meta_info_token(self, token_seq, token_count, extra_token_pos, add_timestep_token=False,
519 add_image_shape_token=False, base_size=None, ratio_idx=None, image_type=None,
520 add_guidance_token=False):
521 if add_image_shape_token:
522 token_seq.extend([
523 self.special_token_map[f"<img_size_{base_size}>"],
524 self.special_token_map[f"<img_ratio_{ratio_idx}>"]
525 ])
526 token_count += 2
527 if add_timestep_token:
528 token_seq.extend([self.special_token_map["<timestep>"]])
529 extra_token_pos['timestep'].append(token_count)
530 if image_type is not None:
531 if image_type == "gen_image":
532 extra_token_pos['gen_timestep'].append(token_count)
533 elif image_type in ["joint_image"]:
534 extra_token_pos['cond_timestep'].append(token_count)
535 else:
536 raise ValueError(f"Unsupported image type: {image_type}.")
537 token_count += 1
538 if add_guidance_token:
539 token_seq.extend([self.special_token_map["<guidance>"]])
540 extra_token_pos['guidance'].append(token_count)
541 token_count += 1
542 return token_count
543
544 @staticmethod
545 def _shorten_text(text):
546 import re
547 text = re.sub(r"(<img>)+", lambda m: f"[<img>]{{{len(m.group(0)) // 5}}}", text)
548 text = re.sub(r"(<pad>)+", lambda m: f"[<pad>]{{{len(m.group(0)) // 5}}}", text)
549 return text
550
551 def encode_sequence(
552 self,
553 template: str,
554 token_source: Dict[str, List],
555 total_length=None,
556 add_timestep_token=False,
557 add_guidance_token=False,
558 last_key_only_prefix=False,
559 add_eos=True,
560 use_front_boi_token=True,
561 add_pad=True,
562 add_bos=True,
563 drop_last: Union[str, bool] = 'auto',
564 add_image_shape_token=False,
565 ):
566 """
567 Encode a sequence based on the template (e.g., `text-image` for t2i, `text-image-image` for instruction tuning)
568 and token source.
569
570 Parameters
571 ----------
572 template: str
573 Template of the sequence. E.g., "text-gen_image" means the sequence is composed of text and an image.
574 "text-text-gen_image" means the sequence is composed of two sections of text and an image.
575 token_source: Dict[str, List]
576 Token source for each key in the template, in order.
577 - text: List[Dict].
578 - gen_image: List[Dict].
579 - joint_image: List[Dict].
580 total_length: int
581 Total length of the encoded sequence, include padding tokens.
582 add_timestep_token: bool
583 Whether to add timestep token before the image tokens.
584 (Right after the <img_ratio_*><img_size_*> tokens)
585 add_guidance_token: bool
586 Whether to add guidance token before the image tokens.
587 last_key_only_prefix: bool
588 Whether to only use the modal prefix in the last key.
589 add_eos: bool or 'auto'
590 Whether to add eos token at the end of the sequence. If True, always add eos token. If 'auto',
591 add eos token only when the total_length is not reached and the last token is not <eos>.
592 use_front_boi_token: bool:
593 Whether to put the <boi> token at the front of iw, ih and timestep tokens.
594 add_pad: bool or 'auto'
595 Whether to add padding tokens to the sequence. If True and total_length is not reached, add padding tokens.
596 add_bos: bool
597 Whether to add bos token at the beginning of the sequence.
598 drop_last: bool or 'auto'
599 - If auto, drop last tokens exceeding the total_length if the total_length is provided. If cut point is
600 in the middle of the image tokens, an error will raised.
601 - If True, drop last tokens exceeding the total_length. If cut point is in the middle of the image tokens,
602 all the successive image tokens will be dropped.
603 - If False, keep the last tokens exceeding the total_length, even if the total_length is reached.
604 add_image_shape_token: bool
605 Whether to add image shape token before the image tokens. (Right before the <timestep> token)
606
607 Returns
608 -------
609 token_seq: list
610 Encoded token sequence.
611 extra_token_pos: dict
612 Positions of extra tokens.
613 """
614 if last_key_only_prefix:
615 assert add_eos is not True, "add_eos should not be True when last_key_only_prefix is True."
616 if drop_last is True and total_length is None:
617 raise ValueError("total_length should be provided when drop_last is True.")
618
619 keys = template.split('-')
620 modal_length = len(keys)
621 index_indicator = {k: 0 for k in token_source}
622 for k, v in token_source.items():
623 assert isinstance(v, (list, tuple)), (
624 f"Value of `{k}` in the token source should be a list or tuple, but got {type(v)}."
625 )
626 self._check_key_number_matched(keys, token_source)
627
628 token_seq = []
629 token_count = 0
630 extra_token_pos = defaultdict(list)
631 if add_bos:
632 token_seq.append(self.bos_token_id)
633 token_count += 1
634 # If drop_last is True, we check the token_count on the fly and exit the loop if the total_length is reached.
635 # This check is only applied to the block tokens. Block tokens mean the tokens that are unsplittable, like
636 # image tokens. Text tokens are splittable, so we don't need to check the token_count for text.
637 # If the loop is broken by drop_last, we don't add the eos token at the end because the sequence is not
638 # complete.
639 drop_last_break = False
640 for i, key in enumerate(keys):
641 source = token_source[key][index_indicator[key]]
642 if key == "text":
643 token_seq.extend(source) # text token sequence
644 extra_token_pos["<text>_start"].append(token_count)
645 token_count += len(source)
646 extra_token_pos["<text>_end"].append(token_count - 1)
647
648 elif key == "gen_image":
649 if isinstance(source, int):
650 source = {'length': source}
651 extra_count = 2 + (
652 1 if source.get('timestep', add_timestep_token) else 0) + (
653 1 if source.get('guidance', add_guidance_token) else 0) + (
654 2 if source.get('image_shape', add_image_shape_token) else 0
655 )
656 if drop_last is True and token_count + extra_count + source['length'] > total_length:
657 drop_last_break = True
658 break
659 if source.get('front_boi', use_front_boi_token):
660 token_seq.append(self.boi_token_id)
661 extra_token_pos["boi"].append(token_count)
662 token_count += 1
663 token_count = self._add_image_meta_info_token(
664 token_seq=token_seq,
665 token_count=token_count,
666 extra_token_pos=extra_token_pos,
667 add_timestep_token=source.get('timestep', add_timestep_token),
668 add_guidance_token=source.get('guidance', add_guidance_token),
669 add_image_shape_token=source.get('image_shape', add_image_shape_token),
670 base_size=source.get('base_size'),
671 ratio_idx=source.get('ratio_idx'),
672 image_type=key,
673 )
674 if not source.get('front_boi', use_front_boi_token):
675 token_seq.append(self.boi_token_id)
676 extra_token_pos["boi"].append(token_count)
677 token_count += 1
678 if last_key_only_prefix and i == modal_length - 1:
679 pass # for AR inference
680 else:
681 token_seq.extend(
682 [self.img_token_id] * source['length'] + # token number
683 [self.eoi_token_id]
684 )
685 extra_token_pos["<img>_start"].append(token_count)
686 extra_token_pos["<all_img>_start"].append(token_count)
687 token_count += source['length']
688 extra_token_pos["<img>_end"].append(token_count - 1)
689 extra_token_pos["<all_img>_end"].append(token_count - 1)
690 extra_token_pos["eoi"].append(token_count)
691 token_count += 1 # <eoi>
692
693 elif key == "joint_image":
694 assert isinstance(source['length'], list) and len(
695 source['length']) == 2, "joint_image length should be a list of two integers"
696 extra_count = 2 + 1 + ( # boi, eoi, joint_img_sep
697 1 if source.get('timestep', add_timestep_token) else 0) + (
698 2 if source.get('image_shape', add_image_shape_token) else 0
699 )
700 if drop_last is True and token_count + extra_count + sum(source['length']) > total_length:
701 drop_last_break = True
702 break
703 if source.get('front_boi', use_front_boi_token):
704 token_seq.append(self.boi_token_id) # Use patched boi for Janus, otherwise useing default <boi>
705 extra_token_pos["boi"].append(token_count)
706 token_count += 1
707 token_count = self._add_image_meta_info_token(
708 token_seq=token_seq,
709 token_count=token_count,
710 extra_token_pos=extra_token_pos,
711 add_timestep_token=source.get('timestep', add_timestep_token),
712 add_image_shape_token=source.get('image_shape', add_image_shape_token),
713 base_size=source.get('base_size'),
714 ratio_idx=source.get('ratio_idx'),
715 image_type=key,
716 )
717 if not source.get('front_boi', use_front_boi_token):
718 token_seq.append(self.boi_token_id)
719 extra_token_pos["boi"].append(token_count)
720 token_count += 1
721 if last_key_only_prefix and i == modal_length - 1:
722 pass # for AR inference
723 else:
724 token_seq.extend(
725 [self.img_token_id] * source['length'][0]
726 )
727 extra_token_pos["<vae_img>_start"].append(token_count)
728 extra_token_pos["<joint_img>_start"].append(token_count)
729 extra_token_pos["<all_img>_start"].append(token_count)
730 token_count += source['length'][0]
731 extra_token_pos["<vae_img>_end"].append(token_count - 1)
732 extra_token_pos["<all_img>_end"].append(token_count - 1)
733
734 token_seq.extend(
735 [self.special_token_map["<joint_img_sep>"]]
736 )
737 extra_token_pos["joint_img_sep"].append(token_count)
738 token_count += 1
739
740 token_seq.extend(
741 [self.img_token_id] * source['length'][1]
742 )
743 extra_token_pos["<vit_img>_start"].append(token_count)
744 extra_token_pos["<all_img>_start"].append(token_count)
745 token_count += source['length'][1]
746 extra_token_pos["<vit_img>_end"].append(token_count - 1)
747 extra_token_pos["<joint_img>_end"].append(token_count - 1)
748 extra_token_pos["<all_img>_end"].append(token_count - 1)
749
750 token_seq.extend(
751 [self.eoi_token_id]
752 )
753 extra_token_pos["eoi"].append(token_count)
754 token_count += 1 # <eoi>
755
756 else:
757 raise ValueError(f"Not supported key: {key}")
758 index_indicator[key] += 1
759
760 if add_eos is True and not drop_last_break:
761 # Typically used for t2i task.
762 token_seq.append(self.eos_token_id)
763 extra_token_pos["eos"].append(token_count)
764 token_count += 1
765 elif add_eos == 'auto' and not drop_last_break:
766 # Typically used for lm and mmu task.
767 if token_seq[-1] != self.eos_token_id and (total_length is None or token_count < total_length):
768 token_seq.append(self.eos_token_id)
769 extra_token_pos["eos"].append(token_count)
770 token_count += 1
771
772 if total_length:
773 # Check token count and clip sequence if necessary
774 if token_count > total_length and drop_last:
775 # Assert clip position is not in the middle of the block-wise tokens (gen_image, joint_image)
776 for start_key, end_key in [
777 ("<img>_start", "<img>_end"), ("<joint_img>_start", "<joint_img>_end"),
778 ("<vae_img>_start", "<vae_img>_end"), ("<vit_img>_start", "<vit_img>_end"),
779 ]:
780 if start_key in extra_token_pos and end_key in extra_token_pos:
781 assert all(
782 (start > total_length or end + 1 < total_length)
783 for start, end in zip(extra_token_pos[start_key], extra_token_pos[end_key])
784 ), ("Clip position should not be in the middle of the image tokens.\n"
785 f"Below is the text:\n{self._shorten_text(self.tokenizer.decode(token_seq))}")
786 token_seq = token_seq[:total_length]
787
788 # Pad the sequence if necessary
789 pad_num = max(0, total_length - len(token_seq))
790 if add_pad and pad_num:
791 token_seq.extend([self.pad_token_id] * pad_num)
792 extra_token_pos["first_pad"].append(token_count)
793
794 return token_seq, extra_token_pos
795
def batch_gen_infer(
    self,
    infer_fn,
    prompt_list: list,
    negative_prompt_list: list = None,
    infer_fn_kwargs_list: List[Dict[str, int]] = None,
    do_classifier_free_guidance=False,
    condition_repeat_times: int = 1,
    uncondition_repeat_times: int = 1,
):
    """
    Batch inference for the AR-like model training of the text-to-image/instruction tuning tasks.

    ``infer_fn`` is called once per sample (and once more per sample for the unconditional branch when
    classifier-free guidance is enabled); its outputs are then collated position-wise into batches.

    Parameters
    ----------
    infer_fn: callable
        Inference function to encode the prompt. It must return a single item, a list or a tuple (not a
        dict), and every call must return the same container type and number of outputs.
    prompt_list: list
        List of prompts. Each element can be a single prompt or a list of prompts passed positionally to
        the infer_fn. A single-element list is broadcast to the length of ``infer_fn_kwargs_list``.
    negative_prompt_list: list
        List of negative prompts, indexed like ``prompt_list``. Only used when do_classifier_free_guidance
        is True. If None, infer_fn is called again on the positive prompt with ``uncond_p=1.0`` (expected
        to make infer_fn produce the <cfg> token sequence as negative prompt).
    infer_fn_kwargs_list: List[Dict[str, int]]
        List of keyword arguments for the infer_fn, one dict per sample. A single-element list is
        broadcast to the length of ``prompt_list``.
    do_classifier_free_guidance: bool
        Whether to do classifier-free guidance. When True the conditional call receives ``uncond_p=0.0``.
    condition_repeat_times: int
        How many times the whole conditional batch is repeated (support multi-condition).
    uncondition_repeat_times: int
        How many times the whole unconditional batch is repeated (support multi-uncondition).

    Returns
    -------
    list, tuple or single item
        Mirrors the container type returned by infer_fn. Each output position is collated as follows:
        tensors are padded with ``self.pad`` and stacked; None stays None; list/tuple items become a plain
        list of per-sample values; TokenizerEncodeOutput items are merged key by key. Conditional samples
        come first, followed by the unconditional ones.

    Raises
    ------
    ValueError
        If prompt_list is empty, prompt_list and infer_fn_kwargs_list cannot be matched up, infer_fn
        returns a dict, or the calls return inconsistent outputs.
    """
    prompt_list = list(prompt_list)
    if infer_fn_kwargs_list is None:
        infer_fn_kwargs_list = [{} for _ in prompt_list]
    else:
        infer_fn_kwargs_list = list(infer_fn_kwargs_list)

    # zip() below would silently truncate the batch to the shorter list, so broadcast a single entry
    # and reject any other mismatch explicitly.
    if len(prompt_list) == 1 and len(infer_fn_kwargs_list) > 1:
        prompt_list = prompt_list * len(infer_fn_kwargs_list)
    elif len(infer_fn_kwargs_list) == 1 and len(prompt_list) > 1:
        infer_fn_kwargs_list = infer_fn_kwargs_list * len(prompt_list)
    if len(prompt_list) != len(infer_fn_kwargs_list):
        raise ValueError(
            f"prompt_list ({len(prompt_list)}) and infer_fn_kwargs_list ({len(infer_fn_kwargs_list)}) "
            f"must have the same length.")
    if not prompt_list:
        raise ValueError("prompt_list must not be empty.")

    # [n_output, bsz]
    cond_results_list = None
    uncond_results_list = None
    output_type_list = []

    for prompt_idx, (prompt, infer_fn_kwargs) in enumerate(zip(prompt_list, infer_fn_kwargs_list)):
        if not isinstance(prompt, (list, tuple)):
            prompt = [prompt]
        cond_kwargs = {"uncond_p": 0.0} if do_classifier_free_guidance else {}
        results = infer_fn(
            *prompt,
            **infer_fn_kwargs,
            **cond_kwargs,
        )
        if isinstance(results, dict):
            raise ValueError("Make batch on dict is not supported. Please return list or tuple for infer_fn.")
        output_type_list.append((type(results), len(results) if isinstance(results, (list, tuple)) else 1))
        # Check consistency before indexing into cond_results_list, which was sized by the first sample.
        if output_type_list[-1] != output_type_list[0]:
            raise ValueError(f"Number of outputs should be equal for all samples, but got {output_type_list}.")
        if not isinstance(results, (list, tuple)):
            results = (results,)
        if cond_results_list is None:
            cond_results_list = [[] for _ in results]
            uncond_results_list = [[] for _ in results]
        for i, result in enumerate(results):
            cond_results_list[i].append(result)

        if do_classifier_free_guidance:
            if negative_prompt_list is None:
                uncond_kwargs = {"uncond_p": 1.0}
                uncond_results = infer_fn(
                    *prompt,
                    **infer_fn_kwargs,
                    **uncond_kwargs,
                )
            else:
                negative_prompt = negative_prompt_list[prompt_idx]
                if not isinstance(negative_prompt, (list, tuple)):
                    negative_prompt = [negative_prompt]
                uncond_results = infer_fn(
                    *negative_prompt,
                    **infer_fn_kwargs,
                )
            # Normalise exactly like the conditional results: a single non-list/tuple result (a tensor,
            # a TokenizerEncodeOutput, ...) is ONE output and must land in slot 0 of the [n_output][bsz]
            # layout, not be iterated element-wise or appended to the outer list.
            if not isinstance(uncond_results, (list, tuple)):
                uncond_results = (uncond_results,)
            if len(uncond_results) != len(uncond_results_list):
                raise ValueError(
                    f"Unconditional call returned {len(uncond_results)} outputs, but the conditional call "
                    f"returned {len(uncond_results_list)}.")
            for i, result in enumerate(uncond_results):
                uncond_results_list[i].append(result)

    output_type, output_num = output_type_list[0]

    def make_batch(batch_cond_item, batch_uncond_item):
        # Process each output item to make batch
        first = batch_cond_item[0]  # The first element in the batch
        if isinstance(first, torch.Tensor):
            stacked_item = torch.stack(self.pad(
                batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times,
            ))

        elif first is None:
            assert all(item is None for item in batch_cond_item + batch_uncond_item), \
                (f"The first cond item is None, but some items are not None:\n\n"
                 f"condition: {batch_cond_item}\n\n"
                 f"uncondition: {batch_uncond_item}")
            stacked_item = None

        elif isinstance(first, (list, tuple)):
            # If the output item is a list or tuple, we treat it as a whole, and won't make nested batch any more.
            stacked_item = batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times

        elif isinstance(first, TokenizerEncodeOutput):
            stacked_item = {}
            # Traverse not-None attributes
            for key in list(first.keys()):
                merged_list = [cond_item[key] for cond_item in batch_cond_item] * condition_repeat_times + \
                    [uncond_item[key] for uncond_item in batch_uncond_item] * uncondition_repeat_times
                if isinstance(first[key], torch.Tensor):
                    if 'mask' in key:
                        pad_val = 0.0
                    elif key == 'tokens':
                        pad_val = self.special_token_map["<pad>"]
                    else:
                        pad_val = False  # Should not pad for other tensors
                    stacked_item[key] = torch.stack(self.pad(merged_list, pad_val=pad_val), dim=0)
                elif isinstance(first[key], list):
                    stacked_item[key] = merged_list
                elif first[key] is None:
                    pass
                else:
                    raise ValueError(f"Unsupported type of {key}: {type(first[key])}.")
            stacked_item = TokenizerEncodeOutput(stacked_item)

        else:
            raise TypeError(f"Making batch on type {type(first)} is not supported.")

        return stacked_item

    stacked_outputs = [
        make_batch(cond_results, uncond_results)
        for cond_results, uncond_results in zip(cond_results_list, uncond_results_list)
    ]

    if output_type == list:
        return stacked_outputs
    if output_type == tuple:
        return tuple(stacked_outputs)
    if output_num == 1:
        return stacked_outputs[0]
    raise ValueError(f"Unsupported output type: {output_type}.")
939
@staticmethod
def parse_extra_token_pos(extra_token_pos, prefix, tokens, rng=None):
    """
    Convert the ``<prefix>_start`` / ``<prefix>_end`` position lists into slices and a boolean mask.

    Parameters
    ----------
    extra_token_pos: dict
        Position bookkeeping; ``'<{prefix}>_start'`` and ``'<{prefix}>_end'`` hold inclusive start/end
        indices of each span.
    prefix: str
        Span name, e.g. 'img', 'joint_img', 'vae_img' or 'vit_img'.
    tokens: torch.Tensor
        Token tensor whose shape the mask copies.
    rng: slice, optional
        Selects a subset of the spans; all spans are used when None.

    Returns
    -------
    (List[slice], torch.Tensor or None)
        Half-open slices (``end + 1``) and a bool mask that is True inside any span. When there are no
        spans, an empty list and None are returned.
    """
    start_key = f'<{prefix}>_start'
    end_key = f'<{prefix}>_end'
    if start_key not in extra_token_pos or end_key not in extra_token_pos:
        return [], None

    window = slice(None) if rng is None else rng
    starts = extra_token_pos[start_key][window]
    ends = extra_token_pos[end_key][window]
    spans = [slice(first, last + 1) for first, last in zip(starts, ends)]
    if not spans:
        return spans, None

    mask = torch.zeros_like(tokens, dtype=torch.bool)
    for span in spans:
        mask[span] = True
    return spans, mask
955
def encode_general(
    self,
    sections: Optional[List[Dict[str, Any]]] = None,
    max_token_length: Optional[int] = None,
    add_eos='auto',
    use_text_mask=True,
    add_pad='auto',
    add_bos=True,
    drop_last='auto',
):
    """
    Encode a sequence made of several text and image sections into one token tensor plus layout info.

    Every section is a dict with a ``type`` key; the remaining keys depend on the type:

    - ``text``:
        - ``text``: str or List[int], the content to encode. Used when present, otherwise ``tokens``.
        - ``tokens``: List[int], pre-encoded content, read only when ``text`` is absent.
        - ``uncond_enabled``: bool, whether unconditional dropping is enabled for this section.
        - ``uncond_p``: float, probability of dropping this section for the unconditional branch.
        - ``max_length``: int, maximum length of this section.
        - ``ignore``: bool, leave this section out of the text mask (default False).
        - ``start_offset`` / ``end_offset``: int, shifts added to the start / stop of this section's
          text-mask range (default 0).
    - ``gen_image``:
        - ``token_length``: int, number of image tokens.
        - ``add_timestep_token``: bool, add a timestep token before the image tokens.
        - ``add_guidance_token``: bool, add a guidance token before the image tokens.
        - ``use_front_boi_token``: bool, put the <boi> token in front of size, ratio and timestep tokens.
        - ``add_image_shape_token``: bool, add an image shape token before the image tokens.
        - ``base_size``: int, base size of the image.
        - ``ratio_idx``: int, ratio index of the image.
    - ``joint_image``: the ``gen_image`` keys except ``add_guidance_token``; here ``token_length`` is a
      List[int] with the number of image tokens for the two images.

    Parameters
    ----------
    sections: List[Dict[str, Any]]
        Sections to encode, in order. Required; the caller's dicts are deep-copied, never modified.
    max_token_length: int
        Maximum length of the encoded token sequence (passed on as ``total_length``).
    add_eos: bool or 'auto'
        True always appends <eos>; 'auto' appends it only when total_length is not reached and the last
        token is not already <eos>.
    use_text_mask: bool
        Whether to build the float text mask; when False ``text_mask`` is None.
    add_pad: bool or 'auto'
        Whether to pad with <pad> tokens up to total_length when it is not reached.
    add_bos: bool
        Whether to prepend the bos token.
    drop_last: bool or 'auto'
        - 'auto': drop tokens beyond total_length when it is given; an error is raised if the cut
          point falls inside image tokens.
        - True: drop tokens beyond total_length; if the cut point falls inside image tokens, all the
          following image tokens are dropped as well.
        - False: keep tokens beyond total_length.

    Returns
    -------
    TokenizerEncodeOutput
        Encoded token sequence together with slices, masks and scatter indices.

    Raises
    ------
    ValueError
        If ``sections`` is None or contains an unknown section type.
    """
    if sections is None:
        raise ValueError("sections must be provided.")
    template = '-'.join(sec['type'] for sec in sections)

    # Deep-copy so nothing downstream can mutate the caller's section dicts.
    sections = deepcopy(sections)
    token_source = defaultdict(list)
    text_mask_specs = []
    for sec in sections:
        sec_type = sec['type']
        if sec_type == 'text':
            encoded = self.encode_text(
                sec['text'] if 'text' in sec else sec['tokens'],
                uncond_enabled=sec.get('uncond_enabled'),
                uncond_p=sec.get('uncond_p'),
                max_length=sec.get('max_length'),
            )
            token_source['text'].append(encoded)
            text_mask_specs.append({
                'ignore': sec.get('ignore', False),
                'start_offset': sec.get('start_offset', 0),
                'end_offset': sec.get('end_offset', 0),
            })
        elif sec_type in ('gen_image', 'joint_image'):
            image_spec = {
                'length': sec['token_length'],
                'timestep': sec.get('add_timestep_token', False),
            }
            if sec_type == 'gen_image':
                # joint_image sections have no guidance flag.
                image_spec['guidance'] = sec.get('add_guidance_token', False)
            image_spec['front_boi'] = sec.get('use_front_boi_token', False)
            image_spec['image_shape'] = sec.get('add_image_shape_token', False)
            image_spec['base_size'] = sec.get('base_size')
            image_spec['ratio_idx'] = sec.get('ratio_idx')
            token_source[sec_type].append(image_spec)
        else:
            raise ValueError(f"Invalid section type: {sec_type}")

    # Combine text and image tokens
    full_token_seq, extra_token_pos = self.encode_sequence(
        template=template,
        token_source=dict(token_source),
        total_length=max_token_length,
        add_eos=add_eos,
        add_pad=add_pad,
        add_bos=add_bos,
        drop_last=drop_last,
    )
    tokens = torch.tensor(full_token_seq, dtype=torch.long)

    def index_or_none(key):
        # Optional scatter indices: absent keys map to None.
        return torch.tensor(extra_token_pos[key], dtype=torch.long) if key in extra_token_pos else None

    def inclusive_spans(name):
        # Turn inclusive <name>_start / <name>_end positions into half-open slices.
        start_key, end_key = f'<{name}>_start', f'<{name}>_end'
        if start_key in extra_token_pos and end_key in extra_token_pos:
            return [slice(s, e + 1) for s, e in zip(extra_token_pos[start_key], extra_token_pos[end_key])]
        return []

    timestep_scatter_index = index_or_none('timestep')
    guidance_scatter_index = index_or_none('guidance')
    cond_timestep_scatter_index = index_or_none('cond_timestep')
    gen_timestep_scatter_index = index_or_none('gen_timestep')

    gen_image_slices, gen_image_mask = self.parse_extra_token_pos(extra_token_pos, 'img', tokens)
    joint_image_slices, _ = self.parse_extra_token_pos(extra_token_pos, 'joint_img', tokens)
    cond_vae_image_slices, cond_vae_image_mask = self.parse_extra_token_pos(extra_token_pos, 'vae_img', tokens)
    cond_vit_image_slices, cond_vit_image_mask = self.parse_extra_token_pos(extra_token_pos, 'vit_img', tokens)
    # Every image span (gen_image and joint_image)
    all_image_slices = inclusive_spans('all_img')
    text_slices = inclusive_spans('text')

    assert len(text_slices) <= len(text_mask_specs), \
        (f"Number of text slices ({len(text_slices)}) should be less than or equal to "
         f"number of text mask specs ({len(text_mask_specs)})")

    text_mask = None
    if use_text_mask:
        text_mask = torch.zeros_like(tokens, dtype=torch.float32)
        for span, spec in zip(text_slices, text_mask_specs):
            if spec['ignore']:
                continue
            text_mask[span.start + spec['start_offset']:span.stop + spec['end_offset']] = 1.0

    # real_pos is the first <pad> position, or the sequence length when no padding was recorded.
    real_pos = torch.tensor(extra_token_pos.get('first_pad', [tokens.shape[0]]), dtype=torch.long)

    return TokenizerEncodeOutput(
        tokens=tokens,
        timestep_scatter_index=timestep_scatter_index,
        guidance_scatter_index=guidance_scatter_index,
        text_slices=text_slices,
        gen_image_slices=gen_image_slices,
        joint_image_slices=joint_image_slices,
        cond_vae_image_slices=cond_vae_image_slices,
        cond_vit_image_slices=cond_vit_image_slices,
        text_mask=text_mask,
        gen_image_mask=gen_image_mask,
        cond_vae_image_mask=cond_vae_image_mask,
        cond_vit_image_mask=cond_vit_image_mask,
        real_pos=real_pos,
        all_image_slices=all_image_slices,
        cond_timestep_scatter_index=cond_timestep_scatter_index,
        gen_timestep_scatter_index=gen_timestep_scatter_index,
    )
1145
def get_cot_sections(self, cot_text, uncond_kwargs, cot_max_length=None, drop_think=False):
    """
    Split chain-of-thought text into ``text`` sections, isolating ``<think>`` and ``<recaption>`` blocks.

    A block is an opening tag followed later by its closing tag. ``<think>`` blocks are matched first,
    then ``<recaption>`` blocks. The text before and after a matched block is processed recursively,
    so several blocks and mixtures of both kinds are all expanded. A closing tag that only appears
    before its opening tag does not form a block; such text is kept as plain text.

    Parameters
    ----------
    cot_text: str or None
        Text to split.
    uncond_kwargs: dict
        Extra keys (e.g. ``uncond_enabled`` / ``uncond_p``) merged into every content section. The tag
        sections themselves do not receive them.
    cot_max_length: int, optional
        ``max_length`` applied to the content of every think/recaption block, at any nesting position.
    drop_think: bool
        If True, think blocks (tags and content) are omitted entirely.

    Returns
    -------
    List[dict]
        ``dict(type="text", ...)`` sections in order; an empty list for None or empty text.
    """
    def find_block(text, open_tag, close_tag):
        # Return (before, inner, after) for the first open_tag followed by close_tag, else None.
        start = text.find(open_tag)
        if start < 0:
            return None
        inner_start = start + len(open_tag)
        end = text.find(close_tag, inner_start)
        if end < 0:
            return None
        return text[:start], text[inner_start:end], text[end + len(close_tag):]

    def expand(text):
        if not text:  # None or empty
            return []

        block = find_block(text, '<think>', '</think>')
        if block is not None:
            before, inner, after = block
            middle = [] if drop_think else [
                dict(type="text", text="<think>"),
                dict(type="text", text=inner, max_length=cot_max_length, **uncond_kwargs),
                dict(type="text", text="</think>"),
            ]
            return expand(before) + middle + expand(after)

        block = find_block(text, '<recaption>', '</recaption>')
        if block is not None:
            before, inner, after = block
            return expand(before) + [
                dict(type="text", text="<recaption>"),
                dict(type="text", text=inner, max_length=cot_max_length, **uncond_kwargs),
                dict(type="text", text="</recaption>"),
            ] + expand(after)

        return [
            dict(type="text", text=text, **uncond_kwargs),
        ]

    return expand(cot_text)
1176
def apply_general_template(
    self,
    message_list,
    max_length=None,
    add_assistant_prefix=False,
    answer="auto",
    bot_task="auto",
    sequence_template="instruct",
    uncond_p=0.0,
    cfg_factor=1,
    batchify=False,
    image_base_size=1024,
    drop_think=False,
):
    """
    Render a list of chat messages into sections and encode them with ``encode_general``.

    Consecutive messages are grouped by role in rounds of system -> user -> assistant. Each non-empty
    group is wrapped with the role prefix/suffix implied by ``sequence_template``.

    Parameters
    ----------
    message_list: list
        Messages, each a dict with ``role`` ('system', 'user' or 'assistant'), ``type`` ('text',
        'gen_image' or 'joint_image') and ``content`` (text, ImageInfo or JointImageInfo respectively).
        With ``batchify=True`` this is a list of such lists.
    max_length: int, optional
        If given, a ValueError is raised when the encoded sequence is longer.
    add_assistant_prefix: bool
        Append an assistant prefix (selected by ``bot_task``) at the end of the sequence.
    answer: 'auto' or bool
        Wrap assistant answers/images in ``<answer>``/``</answer>``. 'auto' enables it for the
        'instruct' template.
    bot_task: str
        One of 'auto', 'image', 'think', 'recaption', 'img_ratio'; selects the appended assistant prefix.
    sequence_template: str
        'pretrain' uses empty role prefixes/suffixes; any other value uses the Conversation roles and
        separator.
    uncond_p: float
        Unconditional-drop probability attached to user/assistant text sections; 1.0 also sets
        ``uncond_enabled``. Not forwarded when ``batchify=True`` (batch_gen_infer supplies it there).
    cfg_factor: int
        Only used with ``batchify=True``. Values > 1 enable classifier-free guidance and repeat the
        unconditional batch ``cfg_factor - 1`` times.
    batchify: bool
        Encode every message list in ``message_list`` and collate the results via ``batch_gen_infer``.
    image_base_size: int
        Size used in the ``<img_size_...>`` token of the 'img_ratio' assistant prefix.
    drop_think: bool
        Drop ``<think>`` blocks from assistant text.

    Returns
    -------
    tuple
        ``(output, sections)``; both are batched by ``batch_gen_infer`` when ``batchify=True``.

    Raises
    ------
    ValueError
        If a message has an unsupported role or type, or the sequence exceeds ``max_length``.
    """
    # If cfg_factor > 1, we need to repeat the unconditioned part
    if batchify:
        assert isinstance(message_list[0], list), \
            f"When batchify is True, message_list should be a list of list, but got [{type(message_list[0])}, ...]."
        return self.batch_gen_infer(
            infer_fn=self.apply_general_template,
            # One empty positional-prompt entry per sample: batch_gen_infer zips prompt_list with
            # infer_fn_kwargs_list, so a single entry would truncate the batch to its first sample.
            prompt_list=[[] for _ in message_list],
            infer_fn_kwargs_list=[dict(
                message_list=message_list_i,
                max_length=max_length,
                add_assistant_prefix=add_assistant_prefix,
                answer=answer,
                bot_task=bot_task,
                sequence_template=sequence_template,
                image_base_size=image_base_size,
                drop_think=drop_think,
            ) for message_list_i in message_list],
            do_classifier_free_guidance=cfg_factor > 1,
            condition_repeat_times=1,
            uncondition_repeat_times=cfg_factor - 1,
        )

    conv = Conversation()
    uncond_kwargs = dict(uncond_enabled=uncond_p == 1.0, uncond_p=uncond_p)

    def process_successive_message(_message_list, _cur_message_idx, role, prefix, suffix,
                                   answer_prefix="", answer_suffix=""):
        # Consume the run of consecutive `role` messages starting at `_cur_message_idx`. Return the
        # produced sections (wrapped in prefix/suffix when any were produced) and the next message index.
        _sub_sections = []
        while _cur_message_idx < len(_message_list) and _message_list[_cur_message_idx]['role'] == role:
            message = _message_list[_cur_message_idx]
            if message['type'] == 'text':
                text = message['content']
                if role == "system":
                    _sub_sections.append(dict(type="text", text=text))
                elif role == "assistant":
                    if ("<recaption>" in text and "</recaption>" in text) or (
                            "<think>" in text and "</think>" in text):
                        _sub_sections.extend(self.get_cot_sections(text, uncond_kwargs, drop_think=drop_think))
                    else:
                        _sub_sections.append(dict(type="text", text=text, **uncond_kwargs))
                else:
                    _sub_sections.append(dict(
                        type="text", text=f"{answer_prefix}{text}{answer_suffix}", **uncond_kwargs))
            elif message['type'] == 'gen_image':
                info = message['content']
                assert isinstance(info, ImageInfo), f"Expected ImageInfo, but got {type(info)}"
                if role == "assistant":
                    _sub_sections.append(dict(type="text", text=answer_prefix))
                _sub_sections.append(dict(type=message['type'], **info.meta_info))
                if role == "assistant":
                    _sub_sections.append(dict(type="text", text=answer_suffix))
            elif message['type'] == 'joint_image':
                info = message['content']
                assert isinstance(info, JointImageInfo), f"Expected JointImageInfo, but got {type(info)}"
                _sub_sections.append(dict(type=message['type'], **info.meta_info))
            else:
                raise ValueError(f"Unknown message type: {message['type']}")
            _cur_message_idx += 1
        if len(_sub_sections) > 0:
            # Add role prefix and suffix
            _sub_sections.insert(0, dict(type='text', text=prefix))
            _sub_sections.append(dict(type='text', text=suffix))
        return _sub_sections, _cur_message_idx

    # Define assistant prefix and suffix
    if (answer == "auto" and sequence_template == "instruct") or answer is True:
        answer_prefix, answer_suffix = "<answer>", "</answer>"
    else:
        answer_prefix, answer_suffix = "", ""
    if sequence_template == "pretrain":
        system_suffix = ""
        user_prefix = ""
        user_suffix = ""
        bot_prefix = ""
        bot_suffix = ""
    else:
        system_suffix = f"{conv.sep}"
        user_prefix = f"{conv.roles[0]}: "
        user_suffix = f"{conv.sep}"
        bot_prefix = f"{conv.roles[1]}: "
        bot_suffix = f"{conv.sep}"

    # Process successive system, user and assistant messages, round after round
    sections = []
    cur_message_idx = 0
    final_role = None
    while cur_message_idx < len(message_list):
        round_start_idx = cur_message_idx

        sub_sections, cur_message_idx = process_successive_message(
            message_list, cur_message_idx, role="system", prefix="", suffix=system_suffix)
        sections.extend(sub_sections)
        if len(sub_sections) > 0:
            final_role = "system"

        sub_sections, cur_message_idx = process_successive_message(
            message_list, cur_message_idx, role="user", prefix=user_prefix, suffix=user_suffix)
        sections.extend(sub_sections)
        if len(sub_sections) > 0:
            final_role = "user"

        sub_sections, cur_message_idx = process_successive_message(
            message_list, cur_message_idx, role="assistant", prefix=bot_prefix, suffix=bot_suffix,
            answer_prefix=answer_prefix, answer_suffix=answer_suffix,
        )
        sections.extend(sub_sections)
        if len(sub_sections) > 0:
            final_role = "assistant"

        if cur_message_idx == round_start_idx:
            # No known role matched the current message; without this check the loop would never advance.
            raise ValueError(
                f"Unsupported message role: {message_list[cur_message_idx]['role']!r}. "
                f"Expected one of 'system', 'user', 'assistant'.")

    if add_assistant_prefix:
        if final_role == "assistant":
            # Avoid adding prefix twice
            _bot_prefix = ""
            # Remove the final bot_suffix
            if len(sections) > 0 and sections[-1]['type'] == 'text' and sections[-1]['text'] == bot_suffix:
                sections = sections[:-1]
        else:
            _bot_prefix = bot_prefix
        # We can add special tokens for the bot latest message according to different tasks
        bot_response_prefix = dict(
            auto=_bot_prefix,
            image="",
            think=f"{_bot_prefix}<think>",
            recaption=f"{_bot_prefix}<recaption>",
            img_ratio=f"{_bot_prefix}{answer_prefix}<boi><img_size_{image_base_size}>",
        )[bot_task]
        sections.append(dict(type='text', text=bot_response_prefix))

    output = self.encode_general(
        sections=sections,
        use_text_mask=False,
        add_eos=False,
        add_pad=False,
    )

    if max_length is not None:
        if output.tokens.shape[-1] > max_length:
            raise ValueError(
                f"Encoded token length {output.tokens.shape[-1]} exceeds max_length {max_length}.\n"
                f"Please set a larger max_length or check the input messages:\n{message_list}"
            )

    return output, sections
1338
def apply_chat_template(
    self,
    batch_prompt: Optional[List[str]] = None,
    batch_message_list: Optional[List[List[Dict[str, Any]]]] = None,
    mode: str = "gen_text",
    batch_gen_image_info: Optional[List[ImageInfo]] = None,
    batch_cond_image_info: Optional[Union[List[JointImageInfo], List[List[JointImageInfo]]]] = None,
    batch_system_prompt: Optional[List[str]] = None,
    batch_cot_text: Optional[List[str]] = None,
    max_length: Optional[int] = None,
    bot_task: str = "auto",  # auto/image/think/recaption/img_ratio
    image_base_size: int = 1024,
    sequence_template: str = "pretrain",
    cfg_factor: int = 1,
    add_assistant_prefix: Optional[bool] = None,
    drop_think: bool = False,
) -> Dict[str, Any]:
    """
    Build one message list per sample (unless given) and encode the batch via ``apply_general_template``.

    Parameters
    ----------
    batch_prompt: List[str], optional
        User prompts; a single str is treated as a batch of one. Required when ``batch_message_list``
        is None.
    batch_message_list: List[List[dict]], optional
        Ready-made message lists. When given, the other ``batch_*`` inputs are ignored.
    mode: str
        'gen_image' appends an assistant ``gen_image`` message and makes ``add_assistant_prefix``
        default to False; any other mode makes it default to True.
    batch_gen_image_info: ImageInfo or List[ImageInfo], optional
        Info of the image to generate; a non-list value is shared by every sample.
    batch_cond_image_info: list, optional
        One entry per sample, each a JointImageInfo or a list of them, added as user ``joint_image``
        messages before the prompt.
    batch_system_prompt: str or List[str], optional
        System prompt(s); a non-list value is shared by every sample. Falsy prompts are skipped.
    batch_cot_text: List[str], optional
        Per-sample assistant text (e.g. think/recaption) placed before the generated image.
    max_length, bot_task, image_base_size, sequence_template, cfg_factor, drop_think:
        Forwarded to ``apply_general_template``.
    add_assistant_prefix: bool, optional
        Overrides the mode-dependent default described above.

    Returns
    -------
    dict
        ``{"output": ..., "sections": ...}``: the batched encoding and the per-sample section lists.

    Raises
    ------
    ValueError
        If neither ``batch_prompt`` nor ``batch_message_list`` is given, or a per-sample list does not
        match the batch size.
    """
    assert bot_task in ["image", "auto", "think", "recaption", "img_ratio"], \
        f"bot_task should be one of ['image', 'auto', 'think', 'recaption', 'img_ratio'], but got {bot_task}."

    if batch_message_list is None:
        # Simple text-to-image or text-cot-to-image task
        if batch_prompt is None:
            raise ValueError("Either batch_prompt or batch_message_list must be provided.")
        if isinstance(batch_prompt, str):
            # A bare string would otherwise be zipped character by character.
            batch_prompt = [batch_prompt]
        batch_size = len(batch_prompt)

        # Batchify inputs: non-list values are shared by every sample
        if not isinstance(batch_system_prompt, list):
            batch_system_prompt = [batch_system_prompt] * batch_size
        if not isinstance(batch_gen_image_info, list):
            batch_gen_image_info = [batch_gen_image_info] * batch_size
        # zip() below would silently drop samples on a length mismatch.
        for name, values in (("batch_system_prompt", batch_system_prompt),
                             ("batch_gen_image_info", batch_gen_image_info)):
            if len(values) != batch_size:
                raise ValueError(
                    f"{name} should have the same length as batch_size ({batch_size}), "
                    f"but got {len(values)}.")
        if batch_cot_text is not None:
            assert len(batch_cot_text) == batch_size, \
                (f"batch_cot_text should have the same length as batch_size ({batch_size}), "
                 f"but got {len(batch_cot_text)}.")
        else:
            batch_cot_text = [None] * batch_size
        if batch_cond_image_info is not None:
            assert len(batch_cond_image_info) == batch_size, \
                (f"batch_cond_image_info should have the same length as batch_size ({batch_size}), "
                 f"but got {len(batch_cond_image_info)}.")
            batch_cond_image_info = [
                cond_image_info if isinstance(cond_image_info, list) else [cond_image_info]
                for cond_image_info in batch_cond_image_info
            ]
        else:
            batch_cond_image_info = [[] for _ in range(batch_size)]

        # Convert single round materials into standard message list
        batch_message_list = []
        for prompt, system_prompt, cot_text, gen_image_info, cond_image_info_list in zip(
            batch_prompt, batch_system_prompt, batch_cot_text, batch_gen_image_info,
            batch_cond_image_info,
        ):
            message_list = []
            # 1. system prompt section
            if system_prompt:
                message_list.append(dict(
                    role="system", type="text", content=system_prompt, context_type="str"))
            # 2. user inputs sections
            # 2.1 image inputs
            if len(cond_image_info_list) > 0:
                message_list.extend([
                    dict(role="user", type="joint_image", content=cond_image_info, context_type="image_info")
                    for cond_image_info in cond_image_info_list
                ])
            # 2.2 text inputs
            message_list.append(dict(
                role="user", type="text", content=prompt, context_type="str"))
            # 3. assistant answer sections
            if cot_text is not None:
                message_list.append(dict(role="assistant", type="text", content=cot_text, context_type="str"))
            if mode == "gen_image":
                message_list.append(dict(
                    role="assistant", type="gen_image", content=gen_image_info, context_type="image_info"))
            # ---
            batch_message_list.append(message_list)

    output, sections = self.apply_general_template(
        message_list=batch_message_list,
        max_length=max_length,
        add_assistant_prefix=default(add_assistant_prefix, mode != "gen_image"),
        bot_task=bot_task,
        sequence_template=sequence_template,
        cfg_factor=cfg_factor,
        batchify=True,
        image_base_size=image_base_size,
        drop_think=drop_think,
    )
    return dict(output=output, sections=sections)
1427