tokenizer_wrapper.py
| 1 | # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); |
| 2 | # you may not use this file except in compliance with the License. |
| 3 | # You may obtain a copy of the License at |
| 4 | # |
| 5 | # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE |
| 6 | # |
| 7 | # Unless required by applicable law or agreed to in writing, software |
| 8 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 10 | # See the License for the specific language governing permissions and |
| 11 | # limitations under the License. |
| 12 | # ============================================================================== |
| 13 | |
| 14 | import warnings |
| 15 | import random |
| 16 | from typing import List, Optional, Union, Dict, Any |
| 17 | from collections import defaultdict |
| 18 | from copy import deepcopy |
| 19 | |
| 20 | import numpy as np |
| 21 | import torch |
| 22 | import torch.nn.functional as F |
| 23 | from transformers import AutoTokenizer |
| 24 | from diffusers.utils import BaseOutput |
| 25 | |
| 26 | |
def default(value, default_value):
    """Return ``value`` unless it is ``None``, in which case return ``default_value``.

    Only ``None`` triggers the fallback; other falsy values (0, "", []) are returned unchanged.
    """
    if value is None:
        return default_value
    return value
| 29 | |
| 30 | |
def ensure_list(value):
    """Normalize ``value`` into a new list.

    ``None`` becomes ``[]``; a list or tuple is copied into a fresh list; anything else
    (including strings) is wrapped as a single-element list.
    """
    if isinstance(value, (list, tuple)):
        return list(value)
    return [] if value is None else [value]
| 37 | |
| 38 | |
class Resolution(object):
    """A (height, width) pair together with its aspect ratio (height / width).

    Accepted constructor forms:
      - ``Resolution("HxW")``               -> (H, W)
      - ``Resolution("S")`` / ``Resolution(S)`` -> (S, S)
      - ``Resolution(H, W)`` (H may also be a numeric string)
      - ``Resolution((H, W))``

    ``h``/``height``, ``w``/``width`` and ``r``/``ratio`` are aliases of each other.
    Indexing with 0 / 1 returns height / width; any other index raises IndexError.
    """

    def __init__(self, size, *args):
        if isinstance(size, str):
            if 'x' in size:
                # Only the first two 'x'-separated fields are used.
                height_str, width_str = size.split('x')[:2]
                size = (int(height_str), int(width_str))
            else:
                size = int(size)
        if args:
            # A second positional argument is the width.
            size = (size, args[0])
        if isinstance(size, int):
            size = (size, size)

        height, width = size[0], size[1]
        self.h = self.height = height
        self.w = self.width = width
        self.r = self.ratio = height / width

    def __getitem__(self, idx):
        if idx == 0:
            return self.h
        if idx == 1:
            return self.w
        raise IndexError(f'Index {idx} out of range')

    def __str__(self):
        return '{}x{}'.format(self.h, self.w)
| 66 | |
| 67 | |
class ResolutionGroup(object):
    """A set of resolution buckets around ``base_size`` spanning aspect ratios 1:2 .. 2:1.

    Starting from the square ``base_size x base_size`` resolution, the height is grown (and the
    width shrunk) by ``step`` per bucket until the height reaches ``2 * base_size`` and the width
    reaches ``base_size // 2``; the mirror image produces the wide buckets. Every generated side is
    floored to a multiple of ``align``. Buckets are sorted by ratio (height / width).

    Parameters
    ----------
    base_size: int
        Side of the square base resolution. Required (despite the ``None`` default) and must be
        divisible by ``align``.
    step: int, optional
        Per-bucket increment. Defaults to ``base_size // 16``; must not exceed ``base_size // 2``.
    align: int
        Alignment for every produced side; must not exceed ``step``.

    Raises
    ------
    ValueError
        If ``base_size`` is not an int, or ``step`` is too large.
    AssertionError
        If ``base_size`` is not divisible by ``align`` or ``align > step``.
    """

    def __init__(self, base_size=None, step=None, align=1):
        # Validate the type before any arithmetic: with base_size=None the divisibility check used to
        # fail first with an opaque TypeError, and neither the default step nor the buckets could be
        # computed without an int base_size anyway.
        if not isinstance(base_size, int):
            raise ValueError(f'base_size must be an int, but got {type(base_size)}')
        assert base_size % align == 0, f'base_size {base_size} is not divisible by align {align}'
        if step is None:
            step = base_size // 16
        if step > base_size // 2:
            raise ValueError(f'step must be smaller than base_size // 2, but got {step} > {base_size // 2}')

        self.align = align
        self.base_size = base_size
        self.step = step
        self.data = self._calc_by_step()

        self.ratio = np.array([x.ratio for x in self.data])
        # Per-bucket annotation column shown by __repr__; empty by default.
        self.attr = ['' for _ in range(len(self.data))]
        # Left indentation (in spaces) used by __repr__ for nested pretty-printing.
        self.prefix_space = 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def __repr__(self):
        prefix = self.prefix_space * ' '
        prefix_close = (self.prefix_space - 4) * ' '
        res_str = f'ResolutionGroup(base_size={self.base_size}, step={self.step}, data='
        attr_maxlen = max([len(x) for x in self.attr] + [5])
        res_str += \
            f'\n{prefix}ID: height width ratio {" " * max(0, attr_maxlen - 4)}count h/16 w/16 tokens\n{prefix}'
        # Both factors are parenthesized: `x.h // 16 * x.w // 16` parses as `((x.h // 16) * x.w) // 16`,
        # which overstates the token count whenever the width is not a multiple of 16.
        res_str += ('\n' + prefix).join([
            f'{i:2d}: ({x.h:4d}, {x.w:4d}) {self.ratio[i]:.4f} {self.attr[i]:>{attr_maxlen}s} '
            f'({x.h // 16:3d}, {x.w // 16:3d}) {(x.h // 16) * (x.w // 16):6d}'
            for i, x in enumerate(self.data)])
        res_str += f'\n{prefix_close})'
        return res_str

    def _calc_by_step(self):
        """Build the bucket list described in the class docstring, sorted by ratio."""
        assert self.align <= self.step, f'align {self.align} must be smaller than step {self.step}'

        min_side = self.base_size // 2
        max_side = self.base_size * 2

        def aligned(height, width):
            # Floor both sides to a multiple of `align`.
            return Resolution(height // self.align * self.align, width // self.align * self.align)

        resolutions = [Resolution(self.base_size, self.base_size)]

        # Tall buckets: grow height, shrink width, each clamped, until both limits are reached.
        cur_height, cur_width = self.base_size, self.base_size
        while not (cur_height >= max_side and cur_width <= min_side):
            cur_height = min(cur_height + self.step, max_side)
            cur_width = max(cur_width - self.step, min_side)
            resolutions.append(aligned(cur_height, cur_width))

        # Wide buckets: the mirror image of the loop above.
        cur_height, cur_width = self.base_size, self.base_size
        while not (cur_height <= min_side and cur_width >= max_side):
            cur_height = max(cur_height - self.step, min_side)
            cur_width = min(cur_width + self.step, max_side)
            resolutions.append(aligned(cur_height, cur_width))

        return sorted(resolutions, key=lambda x: x.ratio)

    def get_target_size(self, width, height):
        """Return ``(width, height)`` of the bucket whose ratio is closest to ``height / width``."""
        ratio = height / width
        idx = np.argmin(np.abs(self.ratio - ratio))
        reso = self.data[idx]
        return reso.w, reso.h

    def get_base_size_and_ratio_index(self, width, height):
        """Return ``(base_size, idx)`` where ``idx`` is the index of the closest-ratio bucket."""
        ratio = height / width
        idx = np.argmin(np.abs(self.ratio - ratio))
        return self.base_size, idx
| 149 | |
| 150 | |
class ImageInfo:
    """ Class to store image information for processing and generation.

    Alias attributes ``w``/``h``/``tk_w``/``tk_h`` are set from ``image_width``/``image_height``/
    ``token_width``/``token_height`` at construction time only (later assignments do not sync them).
    Existing attributes can also be read/written dict-style (``info["h"]``); unknown keys raise KeyError.

    Parameters
    ----------
    image_type: str
        ``meta_info`` understands "vae", "gen_image" and "vit"; ``num_special_tokens`` understands
        "vae", "src_image" and "gen_image".
    image_tensor: torch.Tensor
        The image data (may be None).
    image_width, image_height: int
        Image size in pixels.
    token_width, token_height: int
        Size of the image token grid.
    image_token_length: int
        Number of image tokens. Defaults to ``token_width * token_height`` when both are given.
    base_size, ratio_index: int
        Resolution bucket identifiers, emitted as ``base_size`` / ``ratio_idx`` in ``meta_info``.
    **kwargs:
        Layout flags ``add_timestep_token`` (default True), ``add_guidance_token`` (default False),
        ``use_front_boi_token`` (default True), ``add_image_shape_token`` (default True).
        Other keys are ignored.
    """

    def __init__(
            self,
            image_type: str = None,
            image_tensor: torch.Tensor = None,
            image_width: int = None,
            image_height: int = None,
            token_width: int = None,
            token_height: int = None,
            image_token_length: int = None,
            base_size: int = None,
            ratio_index: int = None,
            **kwargs,
    ):
        self.image_type = image_type
        self.image_tensor = image_tensor
        self.image_width = image_width
        self.w = image_width
        self.image_height = image_height
        self.h = image_height
        self.token_width = token_width
        self.tk_w = token_width
        self.token_height = token_height
        self.tk_h = token_height
        # Derive the token count from the token grid when it was not given explicitly.
        if image_token_length is None and token_width is not None and token_height is not None:
            image_token_length = token_width * token_height
        self.image_token_length = image_token_length
        self.base_size = base_size
        self.ratio_index = ratio_index

        self.add_timestep_token = kwargs.get("add_timestep_token", True)
        self.add_guidance_token = kwargs.get("add_guidance_token", False)
        self.use_front_boi_token = kwargs.get("use_front_boi_token", True)
        self.add_image_shape_token = kwargs.get("add_image_shape_token", True)

    def __getitem__(self, key: str) -> Any:
        """Allow dictionary-like access to attributes."""
        if hasattr(self, key):
            return getattr(self, key)
        raise KeyError(f"Key '{key}' not found in ImageInfo")

    def __setitem__(self, key: str, value: Any) -> None:
        """Allow dictionary-like assignment to existing attributes only."""
        if hasattr(self, key):
            setattr(self, key, value)
        else:
            raise KeyError(f"Key '{key}' not found in ImageInfo")

    def __contains__(self, key: str) -> bool:
        """Check if the key exists in the ImageInfo object."""
        return hasattr(self, key)

    def __repr__(self):
        return (f"ImageInfo(image_type={self.image_type}, image_tensor={self.image_tensor}, "
                f"image_width={self.image_width}, image_height={self.image_height}, "
                f"token_width={self.token_width}, token_height={self.token_height}, "
                f"image_token_length={self.image_token_length}, "
                f"base_size={self.base_size}, ratio_index={self.ratio_index})")

    @property
    def meta_info(self):
        """Per-image metadata dict; raises ValueError for an unknown ``image_type``."""
        # Used for image sections of tkwrapper.encode_general()
        if self.image_type in ["vae", "gen_image"]:
            return dict(
                token_length=self.image_token_length,
                add_timestep_token=self.add_timestep_token,
                add_guidance_token=self.add_guidance_token,
                use_front_boi_token=self.use_front_boi_token,
                add_image_shape_token=self.add_image_shape_token,
                base_size=self.base_size,
                ratio_idx=self.ratio_index,
                # for rope 2d
                token_height=self.token_height,
                token_width=self.token_width,
                # for bc
                image_height=self.image_height,
                image_width=self.image_width,
            )
        elif self.image_type in ["vit"]:
            return dict(
                token_length=self.image_token_length,
                use_front_boi_token=self.use_front_boi_token,
                add_image_shape_token=self.add_image_shape_token,
                # for rope 2d
                token_height=self.token_height,
                token_width=self.token_width,
                # for bc
                image_height=self.image_height,
                image_width=self.image_width,
            )
        else:
            raise ValueError(f"Unknown image type '{self.image_type}'")

    @property
    def num_special_tokens(self):
        """Number of special (non-<img>) tokens framing this image.

        Always 2 for the begin/end pair, plus 1 for <timestep>, 1 for <guidance> and 2 for the
        image-shape tokens when the corresponding flags are enabled.

        Raises ValueError for image types other than "vae", "src_image" and "gen_image".
        """
        # No extra attribute (such as `args`) is required: the count depends only on the image type
        # and the layout flags. A check on a never-assigned `self.args` made this property always fail.
        if self.image_type not in ["vae", "src_image", "gen_image"]:
            raise ValueError(f"Unknown image_type: {self.image_type}")
        return (
            2 +  # <boi> + <eoi> or <src_boi> + <src_eoi>
            (1 if self.add_timestep_token else 0) +
            (1 if self.add_guidance_token else 0) +
            (2 if self.add_image_shape_token else 0)
        )

    def copy(self, copy_image_tensor=True):
        """Return a new ImageInfo with the same metadata and layout flags.

        With ``copy_image_tensor=True`` the tensor is cloned (ValueError if it is None); otherwise the
        copy's ``image_tensor`` is None.
        """
        if copy_image_tensor and self.image_tensor is None:
            raise ValueError("image_tensor is None, cannot copy")
        return ImageInfo(
            image_type=self.image_type,
            image_tensor=self.image_tensor.clone() if copy_image_tensor else None,
            image_width=self.image_width,
            image_height=self.image_height,
            token_width=self.token_width,
            token_height=self.token_height,
            image_token_length=self.image_token_length,
            base_size=self.base_size,
            ratio_index=self.ratio_index,
            # Carry the layout flags over; without them the copy silently reverts to the defaults.
            add_timestep_token=self.add_timestep_token,
            add_guidance_token=self.add_guidance_token,
            use_front_boi_token=self.use_front_boi_token,
            add_image_shape_token=self.add_image_shape_token,
        )

    def zeros_(self):
        """Replace ``image_tensor`` with a zero tensor of the same shape/dtype/device."""
        self.image_tensor = torch.zeros_like(self.image_tensor)
| 279 | |
| 280 | |
class ImageTensor(torch.Tensor):
    """Type-hint-only view of a ``torch.Tensor`` that carries image metadata.

    Nothing here is enforced at runtime. The attributes below are expected to be attached to an
    ordinary tensor instance, e.g. ``tensor.i = ImageInfo(...)``.
    """

    i: ImageInfo  # image metadata attached to the tensor
    vision_encoder_kwargs: dict  # extra kwargs, presumably for the vision encoder — TODO confirm
| 286 | |
| 287 | |
class JointImageInfo(object):
    """Pairs the VAE-side and vision-encoder-side ``ImageInfo`` of a single image.

    It exposes the same key attributes as ``ImageInfo`` (``image_type``, ``image_token_length``,
    ``meta_info``, ``num_special_tokens``, ``copy``, ``zeros_``) so both kinds can be handled
    uniformly. The layout flags are taken from the VAE half.
    """

    def __init__(self, vae_image_info: ImageInfo, vision_image_info: ImageInfo, vision_encoder_kwargs: dict = None):
        self.vae_image_info = vae_image_info
        self.vision_image_info = vision_image_info
        self.vision_encoder_kwargs = vision_encoder_kwargs

        # Mirror ImageInfo's key attributes for uniform handling.
        self.image_type = "joint_image"
        self.image_token_length = sum(
            part.image_token_length for part in (vae_image_info, vision_image_info))

        # The joint image follows the VAE half's layout flags.
        self.add_timestep_token = vae_image_info.add_timestep_token
        self.use_front_boi_token = vae_image_info.use_front_boi_token
        self.add_image_shape_token = vae_image_info.add_image_shape_token

    def __repr__(self):
        return f"JointImageInfo(vae_image={self.vae_image_info}, vision_image={self.vision_image_info})"

    @property
    def meta_info(self):
        """Metadata dict; per-part values are ``[vae_value, vision_value]`` lists."""
        # Consumed by the image sections of tkwrapper.encode_general().
        vae, vit = self.vae_image_info, self.vision_image_info

        def both(attr):
            return [getattr(vae, attr), getattr(vit, attr)]

        return {
            "token_length": both("image_token_length"),
            "add_timestep_token": self.add_timestep_token,
            "use_front_boi_token": self.use_front_boi_token,
            "add_image_shape_token": self.add_image_shape_token,
            "base_size": vae.base_size,
            "ratio_idx": vae.ratio_index,
            # inputs for 2D RoPE
            "token_height": both("token_height"),
            "token_width": both("token_width"),
            # kept for backward compatibility
            "image_height": both("image_height"),
            "image_width": both("image_width"),
        }

    @property
    def num_special_tokens(self):
        """Special tokens around the joint image: <boi>, <eoi>, <joint_image_sep> plus optional extras."""
        count = 2 + 1  # <boi> + <eoi>, and the <joint_image_sep> between the two parts
        if self.add_timestep_token:
            count += 1
        if self.add_image_shape_token:
            count += 2
        return count

    def copy(self, copy_image_tensor=True):
        """Copy both parts; with ``copy_image_tensor`` both tensors must be present (else ValueError)."""
        parts = (self.vae_image_info, self.vision_image_info)
        if copy_image_tensor and any(part.image_tensor is None for part in parts):
            raise ValueError("image_tensor is None, cannot copy")
        vae_copy, vision_copy = (part.copy(copy_image_tensor) for part in parts)
        return JointImageInfo(vae_copy, vision_copy, self.vision_encoder_kwargs)

    def zeros_(self):
        """Zero the image tensors of both parts in place."""
        for part in (self.vae_image_info, self.vision_image_info):
            part.zeros_()
| 345 | |
| 346 | |
class JointImage(object):
    """Bundle of the VAE-side and vision-encoder-side tensors of the same image.

    Both tensors must carry an ``.i`` ImageInfo attribute; they are combined into ``self.i``.
    """

    def __init__(self, vae_image: ImageTensor, vision_image: ImageTensor):
        self.vae_image, self.vision_image = vae_image, vision_image
        self.i = JointImageInfo(self.vae_image.i, self.vision_image.i)
| 352 | |
| 353 | |
class TokenizerEncodeOutput(BaseOutput):
    """Output container, presumably returned by the tokenizer wrapper's encode routines (not visible
    in this chunk — TODO confirm). Every field defaults to ``None``.
    """
    # NOTE(review): diffusers ``BaseOutput`` subclasses are normally decorated with ``@dataclass``;
    # this one is not — confirm construction and field access behave as intended.
    # NOTE(review): field order presumably matters for BaseOutput's tuple conversion; do not reorder.
    tokens: torch.Tensor = None
    # Index tensors for the <timestep> / <guidance> token positions (semantics assumed from names).
    timestep_scatter_index: Optional[torch.Tensor] = None
    guidance_scatter_index: Optional[torch.Tensor] = None
    # Slices locating each section type inside `tokens` (assumed from names — TODO confirm).
    text_slices: Optional[List[slice]] = None
    gen_image_slices: Optional[List[slice]] = None
    joint_image_slices: Optional[List[slice]] = None
    cond_vae_image_slices: Optional[List[slice]] = None
    cond_vit_image_slices: Optional[List[slice]] = None
    # Masks over `tokens` marking each section type (assumed from names — TODO confirm).
    text_mask: Optional[torch.Tensor] = None
    gen_image_mask: Optional[torch.Tensor] = None
    cond_vae_image_mask: Optional[torch.Tensor] = None
    cond_vit_image_mask: Optional[torch.Tensor] = None
    real_pos: Optional[torch.Tensor] = None
    all_image_slices: Optional[List[slice]] = None
    # Timestep positions split by conditioning vs generated image (assumed from names).
    cond_timestep_scatter_index: Optional[torch.Tensor] = None
    gen_timestep_scatter_index: Optional[torch.Tensor] = None
| 371 | |
| 372 | |
class Conversation:
    """Class-level constants for a two-party conversation template."""

    # Speaker role names, user first.
    roles: List[str] = ["User", "Assistant"]
    # Separator string; presumably inserted between turns — TODO confirm at call sites.
    sep: str = "\n\n"
| 376 | |
| 377 | |
| 378 | class TokenizerWrapper(object): |
| 379 | def __init__(self, tokenizer): |
| 380 | if isinstance(tokenizer, str): |
| 381 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) |
| 382 | else: |
| 383 | self.tokenizer = tokenizer |
| 384 | |
| 385 | # Define short names |
| 386 | self.bos_token_id = self.tokenizer.bos_token_id |
| 387 | self.eos_token_id = self.tokenizer.eos_token_id |
| 388 | self.pad_token_id = self.tokenizer.pad_token_id |
| 389 | self.boi_token_id = self.tokenizer.convert_tokens_to_ids("<boi>") |
| 390 | self.eoi_token_id = self.tokenizer.convert_tokens_to_ids("<eoi>") |
| 391 | self.img_token_id = self.tokenizer.convert_tokens_to_ids("<img>") |
| 392 | self.cfg_token_id = self.tokenizer.convert_tokens_to_ids("<cfg>") |
| 393 | self.end_answer_token_id = self.tokenizer.convert_tokens_to_ids("</answer>") |
| 394 | self.end_recaption_token_id = self.tokenizer.convert_tokens_to_ids("</recaption>") |
| 395 | self.ratio_token_offset = self.tokenizer.convert_tokens_to_ids("<img_ratio_0>") |
| 396 | self.special_token_map = self.tokenizer.added_tokens_encoder |
| 397 | |
| 398 | def pad(self, tensor_list, dim=0, pad_val=None): |
| 399 | if pad_val is None: |
| 400 | pad_val = self.pad_token_id |
| 401 | max_len = max([t.shape[dim] for t in tensor_list]) |
| 402 | padded_tensor_list = [] |
| 403 | for t in tensor_list: |
| 404 | if t.shape[dim] < max_len: |
| 405 | assert pad_val is not False, "Not allowed pad." |
| 406 | t = F.pad(t, (0, max_len - t.shape[dim]), value=pad_val) |
| 407 | padded_tensor_list.append(t) |
| 408 | return padded_tensor_list |
| 409 | |
| 410 | def encode(self, *args, **kwargs): |
| 411 | return self.tokenizer.encode(*args, **kwargs) |
| 412 | |
| 413 | def decode(self, *args, **kwargs): |
| 414 | return self.tokenizer.decode(*args, **kwargs) |
| 415 | |
| 416 | def encode_text( |
| 417 | self, |
| 418 | *texts, |
| 419 | uncond_enabled: Optional[Union[bool, List[bool]]] = None, |
| 420 | uncond_p: Optional[float] = None, |
| 421 | max_length: Optional[int] = None, |
| 422 | pad: Optional[str] = None, |
| 423 | return_lengths: bool = False, |
| 424 | ): |
| 425 | """ |
| 426 | Encode text and image for AR-like model training of the text-to-image/instruction tuning tasks. |
| 427 | Support encode multiple texts at once. Each text can be separately conditioned or unconditioned |
| 428 | based on the uncond_flags and a uniform uncond_p. |
| 429 | **<bos> token is always prepended to the text tokens.** |
| 430 | |
| 431 | Parameters |
| 432 | ---------- |
| 433 | texts: str or List[str] |
| 434 | List of texts to be encoded. |
| 435 | uncond_enabled: bool or List[bool] |
| 436 | List of flags to indicate whether the text should be unconditioned. |
| 437 | If False, the text will never be unconditioned. |
| 438 | If True, the text will be unconditioned with uncond_p. |
| 439 | uncond_p: float |
| 440 | Probability to the unconditional text. Only works when uncond_enabled is True. |
| 441 | max_length: int |
| 442 | Maximum length of the encoded text. |
| 443 | pad: Optional[str] |
| 444 | Padding method. Can be 'left' or 'right'. |
| 445 | return_lengths: bool |
| 446 | Whether to return the length of each encoded text. |
| 447 | """ |
| 448 | if pad is not None: |
| 449 | assert max_length is not None, "max_length should be provided when pad is not None." |
| 450 | |
| 451 | if uncond_enabled is None: |
| 452 | uncond_enabled = [True] * len(texts) |
| 453 | elif isinstance(uncond_enabled, bool): |
| 454 | uncond_enabled = [uncond_enabled] * len(texts) |
| 455 | if len(uncond_enabled) != len(texts): |
| 456 | print(uncond_enabled, texts) |
| 457 | assert len(uncond_enabled) == len(texts), ( |
| 458 | f"Length of uncond_flags should be equal to the number of texts, " |
| 459 | f"but got {len(uncond_enabled)} and {len(texts)}." |
| 460 | ) |
| 461 | |
| 462 | # Prepare text/uncond tokens |
| 463 | # TODO: If len(texts) > 1, such as instruction + prompt in inpainting, we need to determine how to do uncond. |
| 464 | # Now all texts will be cond or uncond at the same time. |
| 465 | do_uncond_drop = (uncond_p is not None) and (random.random() < uncond_p) |
| 466 | text_tokens, lengths = [], [] |
| 467 | cum_length = 0 |
| 468 | for text, uncond_flag in zip(texts, uncond_enabled): |
| 469 | # If reach the max_length and there still have unencoded texts, give a warning message and break the loop. |
| 470 | if max_length is not None and cum_length >= max_length: |
| 471 | warnings.warn( |
| 472 | f"Text length exceeds the max_length({max_length}). The remaining texts will be ignored: " |
| 473 | f"{text[:80]}..." |
| 474 | ) |
| 475 | break |
| 476 | # Set add_special_tokens=False to avoid adding <bos> token in some LLMs. |
| 477 | if isinstance(text, str): |
| 478 | text_token = self.tokenizer.encode(text, add_special_tokens=False) |
| 479 | else: |
| 480 | text_token = text |
| 481 | if uncond_flag and do_uncond_drop: |
| 482 | text_token = [self.cfg_token_id] * len(text_token) |
| 483 | # Cutoff the text by max_length if necessary |
| 484 | if max_length is not None and (cum_length + len(text_token)) > max_length: |
| 485 | text_token = text_token[:max_length - cum_length] |
| 486 | text_tokens.extend(text_token) |
| 487 | lengths.append(len(text_token)) |
| 488 | cum_length += len(text_token) |
| 489 | |
| 490 | # Prepend/Append <pad> tokens if applicable |
| 491 | if pad is not None and (pad_length := max_length - len(text_tokens)) > 0: |
| 492 | if pad == 'left': |
| 493 | text_tokens = [self.pad_token_id] * pad_length + text_tokens |
| 494 | elif pad == 'right': |
| 495 | text_tokens = text_tokens + [self.pad_token_id] * pad_length |
| 496 | else: |
| 497 | raise ValueError(f"Unsupported padding method: {pad}.") |
| 498 | |
| 499 | if return_lengths: |
| 500 | return text_tokens, lengths |
| 501 | return text_tokens |
| 502 | |
| 503 | @staticmethod |
| 504 | def _check_key_number_matched(keys, data): |
| 505 | # Assert keys and token_source are matched |
| 506 | assert set(keys) == set(data.keys()), ( |
| 507 | f"Keys in the template and token source should be matched, but got {set(keys)} and {list(data.keys())}." |
| 508 | ) |
| 509 | key_counts = {k: 0 for k in keys} |
| 510 | for key in keys: |
| 511 | key_counts[key] += 1 |
| 512 | for key, count in key_counts.items(): |
| 513 | assert len(data[key]) == count, ( |
| 514 | f"Number of `{key}` in the token source should be matched with the template, but got " |
| 515 | f"{data[key]}({len(data[key])}) and {count}." |
| 516 | ) |
| 517 | |
| 518 | def _add_image_meta_info_token(self, token_seq, token_count, extra_token_pos, add_timestep_token=False, |
| 519 | add_image_shape_token=False, base_size=None, ratio_idx=None, image_type=None, |
| 520 | add_guidance_token=False): |
| 521 | if add_image_shape_token: |
| 522 | token_seq.extend([ |
| 523 | self.special_token_map[f"<img_size_{base_size}>"], |
| 524 | self.special_token_map[f"<img_ratio_{ratio_idx}>"] |
| 525 | ]) |
| 526 | token_count += 2 |
| 527 | if add_timestep_token: |
| 528 | token_seq.extend([self.special_token_map["<timestep>"]]) |
| 529 | extra_token_pos['timestep'].append(token_count) |
| 530 | if image_type is not None: |
| 531 | if image_type == "gen_image": |
| 532 | extra_token_pos['gen_timestep'].append(token_count) |
| 533 | elif image_type in ["joint_image"]: |
| 534 | extra_token_pos['cond_timestep'].append(token_count) |
| 535 | else: |
| 536 | raise ValueError(f"Unsupported image type: {image_type}.") |
| 537 | token_count += 1 |
| 538 | if add_guidance_token: |
| 539 | token_seq.extend([self.special_token_map["<guidance>"]]) |
| 540 | extra_token_pos['guidance'].append(token_count) |
| 541 | token_count += 1 |
| 542 | return token_count |
| 543 | |
| 544 | @staticmethod |
| 545 | def _shorten_text(text): |
| 546 | import re |
| 547 | text = re.sub(r"(<img>)+", lambda m: f"[<img>]{{{len(m.group(0)) // 5}}}", text) |
| 548 | text = re.sub(r"(<pad>)+", lambda m: f"[<pad>]{{{len(m.group(0)) // 5}}}", text) |
| 549 | return text |
| 550 | |
| 551 | def encode_sequence( |
| 552 | self, |
| 553 | template: str, |
| 554 | token_source: Dict[str, List], |
| 555 | total_length=None, |
| 556 | add_timestep_token=False, |
| 557 | add_guidance_token=False, |
| 558 | last_key_only_prefix=False, |
| 559 | add_eos=True, |
| 560 | use_front_boi_token=True, |
| 561 | add_pad=True, |
| 562 | add_bos=True, |
| 563 | drop_last: Union[str, bool] = 'auto', |
| 564 | add_image_shape_token=False, |
| 565 | ): |
| 566 | """ |
| 567 | Encode a sequence based on the template (e.g., `text-image` for t2i, `text-image-image` for instruction tuning) |
| 568 | and token source. |
| 569 | |
| 570 | Parameters |
| 571 | ---------- |
| 572 | template: str |
| 573 | Template of the sequence. E.g., "text-gen_image" means the sequence is composed of text and an image. |
| 574 | "text-text-gen_image" means the sequence is composed of two sections of text and an image. |
| 575 | token_source: Dict[str, List] |
| 576 | Token source for each key in the template, in order. |
| 577 | - text: List[Dict]. |
| 578 | - gen_image: List[Dict]. |
| 579 | - joint_image: List[Dict]. |
| 580 | total_length: int |
| 581 | Total length of the encoded sequence, include padding tokens. |
| 582 | add_timestep_token: bool |
| 583 | Whether to add timestep token before the image tokens. |
| 584 | (Right after the <img_ratio_*><img_size_*> tokens) |
| 585 | add_guidance_token: bool |
| 586 | Whether to add guidance token before the image tokens. |
| 587 | last_key_only_prefix: bool |
| 588 | Whether to only use the modal prefix in the last key. |
| 589 | add_eos: bool or 'auto' |
| 590 | Whether to add eos token at the end of the sequence. If True, always add eos token. If 'auto', |
| 591 | add eos token only when the total_length is not reached and the last token is not <eos>. |
| 592 | use_front_boi_token: bool: |
| 593 | Whether to put the <boi> token at the front of iw, ih and timestep tokens. |
| 594 | add_pad: bool or 'auto' |
| 595 | Whether to add padding tokens to the sequence. If True and total_length is not reached, add padding tokens. |
| 596 | add_bos: bool |
| 597 | Whether to add bos token at the beginning of the sequence. |
| 598 | drop_last: bool or 'auto' |
| 599 | - If auto, drop last tokens exceeding the total_length if the total_length is provided. If cut point is |
| 600 | in the middle of the image tokens, an error will raised. |
| 601 | - If True, drop last tokens exceeding the total_length. If cut point is in the middle of the image tokens, |
| 602 | all the successive image tokens will be dropped. |
| 603 | - If False, keep the last tokens exceeding the total_length, even if the total_length is reached. |
| 604 | add_image_shape_token: bool |
| 605 | Whether to add image shape token before the image tokens. (Right before the <timestep> token) |
| 606 | |
| 607 | Returns |
| 608 | ------- |
| 609 | token_seq: list |
| 610 | Encoded token sequence. |
| 611 | extra_token_pos: dict |
| 612 | Positions of extra tokens. |
| 613 | """ |
| 614 | if last_key_only_prefix: |
| 615 | assert add_eos is not True, "add_eos should not be True when last_key_only_prefix is True." |
| 616 | if drop_last is True and total_length is None: |
| 617 | raise ValueError("total_length should be provided when drop_last is True.") |
| 618 | |
| 619 | keys = template.split('-') |
| 620 | modal_length = len(keys) |
| 621 | index_indicator = {k: 0 for k in token_source} |
| 622 | for k, v in token_source.items(): |
| 623 | assert isinstance(v, (list, tuple)), ( |
| 624 | f"Value of `{k}` in the token source should be a list or tuple, but got {type(v)}." |
| 625 | ) |
| 626 | self._check_key_number_matched(keys, token_source) |
| 627 | |
| 628 | token_seq = [] |
| 629 | token_count = 0 |
| 630 | extra_token_pos = defaultdict(list) |
| 631 | if add_bos: |
| 632 | token_seq.append(self.bos_token_id) |
| 633 | token_count += 1 |
| 634 | # If drop_last is True, we check the token_count on the fly and exit the loop if the total_length is reached. |
| 635 | # This check is only applied to the block tokens. Block tokens mean the tokens that are unsplittable, like |
| 636 | # image tokens. Text tokens are splittable, so we don't need to check the token_count for text. |
| 637 | # If the loop is broken by drop_last, we don't add the eos token at the end because the sequence is not |
| 638 | # complete. |
| 639 | drop_last_break = False |
| 640 | for i, key in enumerate(keys): |
| 641 | source = token_source[key][index_indicator[key]] |
| 642 | if key == "text": |
| 643 | token_seq.extend(source) # text token sequence |
| 644 | extra_token_pos["<text>_start"].append(token_count) |
| 645 | token_count += len(source) |
| 646 | extra_token_pos["<text>_end"].append(token_count - 1) |
| 647 | |
| 648 | elif key == "gen_image": |
| 649 | if isinstance(source, int): |
| 650 | source = {'length': source} |
| 651 | extra_count = 2 + ( |
| 652 | 1 if source.get('timestep', add_timestep_token) else 0) + ( |
| 653 | 1 if source.get('guidance', add_guidance_token) else 0) + ( |
| 654 | 2 if source.get('image_shape', add_image_shape_token) else 0 |
| 655 | ) |
| 656 | if drop_last is True and token_count + extra_count + source['length'] > total_length: |
| 657 | drop_last_break = True |
| 658 | break |
| 659 | if source.get('front_boi', use_front_boi_token): |
| 660 | token_seq.append(self.boi_token_id) |
| 661 | extra_token_pos["boi"].append(token_count) |
| 662 | token_count += 1 |
| 663 | token_count = self._add_image_meta_info_token( |
| 664 | token_seq=token_seq, |
| 665 | token_count=token_count, |
| 666 | extra_token_pos=extra_token_pos, |
| 667 | add_timestep_token=source.get('timestep', add_timestep_token), |
| 668 | add_guidance_token=source.get('guidance', add_guidance_token), |
| 669 | add_image_shape_token=source.get('image_shape', add_image_shape_token), |
| 670 | base_size=source.get('base_size'), |
| 671 | ratio_idx=source.get('ratio_idx'), |
| 672 | image_type=key, |
| 673 | ) |
| 674 | if not source.get('front_boi', use_front_boi_token): |
| 675 | token_seq.append(self.boi_token_id) |
| 676 | extra_token_pos["boi"].append(token_count) |
| 677 | token_count += 1 |
| 678 | if last_key_only_prefix and i == modal_length - 1: |
| 679 | pass # for AR inference |
| 680 | else: |
| 681 | token_seq.extend( |
| 682 | [self.img_token_id] * source['length'] + # token number |
| 683 | [self.eoi_token_id] |
| 684 | ) |
| 685 | extra_token_pos["<img>_start"].append(token_count) |
| 686 | extra_token_pos["<all_img>_start"].append(token_count) |
| 687 | token_count += source['length'] |
| 688 | extra_token_pos["<img>_end"].append(token_count - 1) |
| 689 | extra_token_pos["<all_img>_end"].append(token_count - 1) |
| 690 | extra_token_pos["eoi"].append(token_count) |
| 691 | token_count += 1 # <eoi> |
| 692 | |
| 693 | elif key == "joint_image": |
| 694 | assert isinstance(source['length'], list) and len( |
| 695 | source['length']) == 2, "joint_image length should be a list of two integers" |
| 696 | extra_count = 2 + 1 + ( # boi, eoi, joint_img_sep |
| 697 | 1 if source.get('timestep', add_timestep_token) else 0) + ( |
| 698 | 2 if source.get('image_shape', add_image_shape_token) else 0 |
| 699 | ) |
| 700 | if drop_last is True and token_count + extra_count + sum(source['length']) > total_length: |
| 701 | drop_last_break = True |
| 702 | break |
| 703 | if source.get('front_boi', use_front_boi_token): |
| 704 | token_seq.append(self.boi_token_id) # Use patched boi for Janus, otherwise useing default <boi> |
| 705 | extra_token_pos["boi"].append(token_count) |
| 706 | token_count += 1 |
| 707 | token_count = self._add_image_meta_info_token( |
| 708 | token_seq=token_seq, |
| 709 | token_count=token_count, |
| 710 | extra_token_pos=extra_token_pos, |
| 711 | add_timestep_token=source.get('timestep', add_timestep_token), |
| 712 | add_image_shape_token=source.get('image_shape', add_image_shape_token), |
| 713 | base_size=source.get('base_size'), |
| 714 | ratio_idx=source.get('ratio_idx'), |
| 715 | image_type=key, |
| 716 | ) |
| 717 | if not source.get('front_boi', use_front_boi_token): |
| 718 | token_seq.append(self.boi_token_id) |
| 719 | extra_token_pos["boi"].append(token_count) |
| 720 | token_count += 1 |
| 721 | if last_key_only_prefix and i == modal_length - 1: |
| 722 | pass # for AR inference |
| 723 | else: |
| 724 | token_seq.extend( |
| 725 | [self.img_token_id] * source['length'][0] |
| 726 | ) |
| 727 | extra_token_pos["<vae_img>_start"].append(token_count) |
| 728 | extra_token_pos["<joint_img>_start"].append(token_count) |
| 729 | extra_token_pos["<all_img>_start"].append(token_count) |
| 730 | token_count += source['length'][0] |
| 731 | extra_token_pos["<vae_img>_end"].append(token_count - 1) |
| 732 | extra_token_pos["<all_img>_end"].append(token_count - 1) |
| 733 | |
| 734 | token_seq.extend( |
| 735 | [self.special_token_map["<joint_img_sep>"]] |
| 736 | ) |
| 737 | extra_token_pos["joint_img_sep"].append(token_count) |
| 738 | token_count += 1 |
| 739 | |
| 740 | token_seq.extend( |
| 741 | [self.img_token_id] * source['length'][1] |
| 742 | ) |
| 743 | extra_token_pos["<vit_img>_start"].append(token_count) |
| 744 | extra_token_pos["<all_img>_start"].append(token_count) |
| 745 | token_count += source['length'][1] |
| 746 | extra_token_pos["<vit_img>_end"].append(token_count - 1) |
| 747 | extra_token_pos["<joint_img>_end"].append(token_count - 1) |
| 748 | extra_token_pos["<all_img>_end"].append(token_count - 1) |
| 749 | |
| 750 | token_seq.extend( |
| 751 | [self.eoi_token_id] |
| 752 | ) |
| 753 | extra_token_pos["eoi"].append(token_count) |
| 754 | token_count += 1 # <eoi> |
| 755 | |
| 756 | else: |
| 757 | raise ValueError(f"Not supported key: {key}") |
| 758 | index_indicator[key] += 1 |
| 759 | |
| 760 | if add_eos is True and not drop_last_break: |
| 761 | # Typically used for t2i task. |
| 762 | token_seq.append(self.eos_token_id) |
| 763 | extra_token_pos["eos"].append(token_count) |
| 764 | token_count += 1 |
| 765 | elif add_eos == 'auto' and not drop_last_break: |
| 766 | # Typically used for lm and mmu task. |
| 767 | if token_seq[-1] != self.eos_token_id and (total_length is None or token_count < total_length): |
| 768 | token_seq.append(self.eos_token_id) |
| 769 | extra_token_pos["eos"].append(token_count) |
| 770 | token_count += 1 |
| 771 | |
| 772 | if total_length: |
| 773 | # Check token count and clip sequence if necessary |
| 774 | if token_count > total_length and drop_last: |
| 775 | # Assert clip position is not in the middle of the block-wise tokens (gen_image, joint_image) |
| 776 | for start_key, end_key in [ |
| 777 | ("<img>_start", "<img>_end"), ("<joint_img>_start", "<joint_img>_end"), |
| 778 | ("<vae_img>_start", "<vae_img>_end"), ("<vit_img>_start", "<vit_img>_end"), |
| 779 | ]: |
| 780 | if start_key in extra_token_pos and end_key in extra_token_pos: |
| 781 | assert all( |
| 782 | (start > total_length or end + 1 < total_length) |
| 783 | for start, end in zip(extra_token_pos[start_key], extra_token_pos[end_key]) |
| 784 | ), ("Clip position should not be in the middle of the image tokens.\n" |
| 785 | f"Below is the text:\n{self._shorten_text(self.tokenizer.decode(token_seq))}") |
| 786 | token_seq = token_seq[:total_length] |
| 787 | |
| 788 | # Pad the sequence if necessary |
| 789 | pad_num = max(0, total_length - len(token_seq)) |
| 790 | if add_pad and pad_num: |
| 791 | token_seq.extend([self.pad_token_id] * pad_num) |
| 792 | extra_token_pos["first_pad"].append(token_count) |
| 793 | |
| 794 | return token_seq, extra_token_pos |
| 795 | |
def batch_gen_infer(
        self,
        infer_fn,
        prompt_list: list,
        negative_prompt_list: list = None,
        infer_fn_kwargs_list: List[Dict[str, int]] = None,
        do_classifier_free_guidance=False,
        condition_repeat_times: int = 1,
        uncondition_repeat_times: int = 1,
):
    """
    Batch inference for the AR-like model training of the text-to-image/instruction tuning tasks.

    ``infer_fn`` is called once per prompt (and, when classifier-free guidance is enabled, once more
    per prompt for the unconditional branch); the per-sample outputs are then collated into batches.

    Parameters
    ----------
    infer_fn: callable
        Inference function to encode the prompt. It must return a single item or a list/tuple of items
        (one per output); dict results are rejected.
    prompt_list: list
        List of prompts. Each element can be a single prompt or a list of prompts passed to the infer_fn
        as positional arguments.
    negative_prompt_list: list
        List of negative prompts. Only used when do_classifier_free_guidance is True. If None, infer_fn is
        called again on the prompt with ``uncond_p=1.0``.
    infer_fn_kwargs_list: List[Dict[str, int]]
        List of keyword arguments for the infer_fn, one dict per prompt (zipped with prompt_list).
    do_classifier_free_guidance: bool
        Whether to do classifier-free guidance. If True, conditional calls get ``uncond_p=0.0``.
    condition_repeat_times: int
        How many times the conditional samples are repeated in each batch (multi-condition support).
    uncondition_repeat_times: int
        How many times the unconditional samples are repeated in each batch (multi-uncondition support).

    Returns
    -------
    list, tuple or a single item
        Mirrors the return structure of ``infer_fn``. Each collated item holds the conditional samples
        (repeated ``condition_repeat_times`` times) followed by the unconditional samples (repeated
        ``uncondition_repeat_times`` times).

    Raises
    ------
    ValueError
        If there are no samples, if infer_fn returns a dict, or if the output type is unsupported.
    """
    if infer_fn_kwargs_list is None:
        infer_fn_kwargs_list = [{} for _ in prompt_list]

    def as_output_tuple(results):
        # Normalize one infer_fn return value into a sequence of output items. The conditional and
        # unconditional branches both go through here so their per-output lists stay aligned.
        if isinstance(results, dict):
            raise ValueError("Make batch on dict is not supported. Please return list or tuple for infer_fn.")
        if not isinstance(results, (list, tuple)):
            results = (results,)
        return results

    # [n_output, bsz]
    cond_results_list = None
    uncond_results_list = None
    output_type_list = []

    for prompt_idx, (prompt, infer_fn_kwargs) in enumerate(zip(prompt_list, infer_fn_kwargs_list)):
        if not isinstance(prompt, (list, tuple)):
            prompt = [prompt]
        cond_kwargs = {"uncond_p": 0.0} if do_classifier_free_guidance else {}
        results = infer_fn(
            *prompt,
            **infer_fn_kwargs,
            **cond_kwargs,
        )
        output_type_list.append((type(results), len(results) if isinstance(results, (list, tuple)) else 1))
        results = as_output_tuple(results)
        if cond_results_list is None:
            cond_results_list = [[] for _ in results]
            uncond_results_list = [[] for _ in results]
        for i, result in enumerate(results):
            cond_results_list[i].append(result)

        if do_classifier_free_guidance:
            if negative_prompt_list is None:
                uncond_kwargs = {"uncond_p": 1.0}
                uncond_results = infer_fn(
                    *prompt,
                    **infer_fn_kwargs,
                    **uncond_kwargs,
                )
            else:
                negative_prompt = negative_prompt_list[prompt_idx]
                if not isinstance(negative_prompt, (list, tuple)):
                    negative_prompt = [negative_prompt]
                uncond_results = infer_fn(
                    *negative_prompt,
                    **infer_fn_kwargs,
                )
            # Previously a non-list/tuple unconditional result was either appended to the OUTER list
            # (TokenizerEncodeOutput) or iterated element-wise (e.g. a tensor), corrupting the batch.
            for i, result in enumerate(as_output_tuple(uncond_results)):
                uncond_results_list[i].append(result)

    if not output_type_list:
        raise ValueError(
            "batch_gen_infer received no samples: prompt_list and infer_fn_kwargs_list must be non-empty.")
    assert all(output_type_list[0] == n for n in output_type_list), \
        f"Number of outputs should be equal for all samples, but got {output_type_list}."
    output_type, output_num = output_type_list[0]

    def make_batch(batch_cond_item, batch_uncond_item):
        # Collate one output slot across samples: conditional samples first, then unconditional ones.
        first = batch_cond_item[0]  # The first element in the batch
        if isinstance(first, torch.Tensor):
            # Tensors go through self.pad (presumably to a common length -- confirm) and are stacked.
            stacked_item = torch.stack(self.pad(
                batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times,
            ))

        elif first is None:
            assert all(item is None for item in batch_cond_item + batch_uncond_item), \
                (f"The first cond item is None, but some items are not None:\n\n"
                 f"condition: {batch_cond_item}\n\n"
                 f"uncondition: {batch_uncond_item}")
            stacked_item = None

        elif isinstance(first, (list, tuple)):
            # If the output item is a list or tuple, we treat it as a whole, and won't make nested batch any more.
            stacked_item = batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times

        elif isinstance(first, TokenizerEncodeOutput):
            stacked_item = {}
            # Traverse not-None attributes
            for key in list(first.keys()):
                merged_list = [cond_item[key] for cond_item in batch_cond_item] * condition_repeat_times + \
                    [uncond_item[key] for uncond_item in batch_uncond_item] * uncondition_repeat_times
                if isinstance(first[key], torch.Tensor):
                    if 'mask' in key:
                        pad_val = 0.0
                    elif key == 'tokens':
                        pad_val = self.special_token_map["<pad>"]
                    else:
                        pad_val = False  # Should not pad for other tensors
                    stacked_item[key] = torch.stack(self.pad(merged_list, pad_val=pad_val), dim=0)
                elif isinstance(first[key], list):
                    stacked_item[key] = merged_list
                elif first[key] is None:
                    pass
                else:
                    raise ValueError(f"Unsupported type of {key}: {type(first[key])}.")
            stacked_item = TokenizerEncodeOutput(stacked_item)

        else:
            raise TypeError(f"Making batch on type {type(first)} is not supported.")

        return stacked_item

    stacked_outputs = [
        make_batch(cond_results, uncond_results)
        for cond_results, uncond_results in zip(cond_results_list, uncond_results_list)
    ]

    if output_type == list:
        return stacked_outputs
    elif output_type == tuple:
        return tuple(stacked_outputs)
    elif output_num == 1:
        return stacked_outputs[0]
    else:
        raise ValueError(f"Unsupported output type: {output_type}.")
| 939 | |
@staticmethod
def parse_extra_token_pos(extra_token_pos, prefix, tokens, rng=None):
    """
    Turn recorded ``<prefix>_start`` / ``<prefix>_end`` positions into slices and a boolean mask.

    Parameters
    ----------
    extra_token_pos: dict
        Mapping from position keys (e.g. ``'<img>_start'``) to lists of token indices. The end
        indices are inclusive.
    prefix: str
        Span name without angle brackets, e.g. ``'img'`` or ``'vae_img'``.
    tokens: torch.Tensor
        Token tensor; the mask is created with ``zeros_like(tokens)``.
    rng: slice, optional
        Selects which recorded spans to use (all spans when None).

    Returns
    -------
    (list of slice, torch.Tensor or None)
        Half-open slices for every selected span, and a bool mask that is True inside those spans.
        The mask is None when there is no span.
    """
    start_key = f'<{prefix}>_start'
    end_key = f'<{prefix}>_end'
    if start_key not in extra_token_pos or end_key not in extra_token_pos:
        return [], None

    selector = slice(None) if rng is None else rng
    starts = extra_token_pos[start_key][selector]
    ends = extra_token_pos[end_key][selector]
    image_slices = [slice(first, last + 1) for first, last in zip(starts, ends)]
    if not image_slices:
        return image_slices, None

    image_mask = torch.zeros_like(tokens, dtype=torch.bool)
    for span in image_slices:
        image_mask[span] = True
    return image_slices, image_mask
| 955 | |
def encode_general(
        self,
        sections: Optional[List[Dict[str, Any]]] = None,
        max_token_length: Optional[int] = None,
        add_eos='auto',
        use_text_mask=True,
        add_pad='auto',
        add_bos=True,
        drop_last='auto',
):
    """
    Encode a sequence made of text and image sections into tokens plus position metadata.

    Each entry of ``sections`` is a dict whose ``type`` key selects how it is encoded:

    - ``text``:
        - text: str or List[int], text to be encoded. Either ``text`` or ``tokens`` must be present.
        - tokens: List[int], pre-encoded tokens, used when ``text`` is absent.
        - uncond_enabled: bool, whether to enable uncondition for this text section.
        - uncond_p: float, probability to drop the text section for uncondition.
        - max_length: int, maximum length of the text section.
        - ignore: bool, exclude this section from the text mask (default False).
        - start_offset: int, offset added to the section start for the text mask (default 0).
        - end_offset: int, offset added to the section end for the text mask (default 0).
    - ``gen_image``:
        - token_length: int, number of image tokens.
        - add_timestep_token: bool, whether to add a timestep token before the image tokens.
        - add_guidance_token: bool, whether to add a guidance token before the image tokens.
        - use_front_boi_token: bool, whether to put <boi> in front of the size, ratio and timestep tokens.
        - add_image_shape_token: bool, whether to add image shape tokens before the image tokens.
        - base_size: int, base size of the image.
        - ratio_idx: int, ratio index of the image.
    - ``joint_image``: the same keys as ``gen_image`` except ``add_guidance_token``; ``token_length``
      is a list of two ints, the token counts of the two images.

    Boolean image flags default to False; ``base_size`` and ``ratio_idx`` default to None.

    Parameters
    ----------
    sections: List[Dict[str, Any]]
        Sections to encode, in order.
    max_token_length: int
        Maximum length of the encoded token sequence (forwarded as ``total_length``).
    add_eos: bool or 'auto'
        True always appends <eos>. 'auto' appends it only when the total length is not reached and the
        last token is not already <eos>.
    use_text_mask: bool
        Whether to build the float text mask.
    add_pad: bool or 'auto'
        Whether to pad up to the total length when it is not reached.
    add_bos: bool
        Whether to add a bos token at the beginning of the sequence.
    drop_last: bool or 'auto'
        - 'auto': drop tokens beyond the total length when one is given; an error is raised if the cut
          falls inside image tokens.
        - True: drop tokens beyond the total length; if the cut falls inside image tokens, all the
          following image tokens are dropped.
        - False: keep tokens beyond the total length.

    Returns
    -------
    TokenizerEncodeOutput
        Encoded token sequence and extra information.

    Raises
    ------
    ValueError
        If ``sections`` is None or a section has an unknown type.
    """
    if sections is None:
        raise ValueError("sections must be provided.")
    template = '-'.join([section['type'] for section in sections])

    token_source = defaultdict(list)
    text_mask_specs = []
    for section in deepcopy(sections):
        section_type = section['type']
        if section_type == 'text':
            raw_text = section['text'] if 'text' in section else section['tokens']
            token_source['text'].append(self.encode_text(
                raw_text,
                uncond_enabled=section.get('uncond_enabled'),
                uncond_p=section.get('uncond_p'),
                max_length=section.get('max_length'),
            ))
            text_mask_specs.append({
                'ignore': section.get('ignore', False),
                'start_offset': section.get('start_offset', 0),
                'end_offset': section.get('end_offset', 0),
            })
        elif section_type in ('gen_image', 'joint_image'):
            image_source = {
                'length': section['token_length'],
                'timestep': section.get('add_timestep_token', False),
            }
            if section_type == 'gen_image':
                # Only gen_image sections carry a guidance flag.
                image_source['guidance'] = section.get('add_guidance_token', False)
            image_source['front_boi'] = section.get('use_front_boi_token', False)
            image_source['image_shape'] = section.get('add_image_shape_token', False)
            image_source['base_size'] = section.get('base_size')
            image_source['ratio_idx'] = section.get('ratio_idx')
            token_source[section_type].append(image_source)
        else:
            raise ValueError(f"Invalid section type: {section_type}")

    # Combine text and image tokens
    full_token_seq, extra_token_pos = self.encode_sequence(
        template=template,
        token_source=dict(token_source),
        total_length=max_token_length,
        add_eos=add_eos,
        add_pad=add_pad,
        add_bos=add_bos,
        drop_last=drop_last,
    )
    tokens = torch.tensor(full_token_seq, dtype=torch.long)

    def index_tensor(key):
        # Positions recorded under `key` as a LongTensor, or None if nothing was recorded.
        if key not in extra_token_pos:
            return None
        return torch.tensor(extra_token_pos[key], dtype=torch.long)

    def span_slices(prefix):
        # Inclusive <prefix>_start/<prefix>_end pairs converted to half-open slices.
        start_key, end_key = f'<{prefix}>_start', f'<{prefix}>_end'
        if start_key not in extra_token_pos or end_key not in extra_token_pos:
            return []
        return [slice(lo, hi + 1) for lo, hi in zip(extra_token_pos[start_key], extra_token_pos[end_key])]

    gen_image_slices, gen_image_mask = self.parse_extra_token_pos(extra_token_pos, 'img', tokens)
    joint_image_slices, _ = self.parse_extra_token_pos(extra_token_pos, 'joint_img', tokens)
    cond_vae_image_slices, cond_vae_image_mask = self.parse_extra_token_pos(extra_token_pos, 'vae_img', tokens)
    cond_vit_image_slices, cond_vit_image_mask = self.parse_extra_token_pos(extra_token_pos, 'vit_img', tokens)
    # Spans of every image (gen_image and both parts of joint_image)
    all_image_slices = span_slices('all_img')
    text_slices = span_slices('text')

    assert len(text_slices) <= len(text_mask_specs), \
        (f"Number of text slices ({len(text_slices)}) should be less than or equal to "
         f"number of text mask specs ({len(text_mask_specs)})")
    text_mask = None
    if use_text_mask:
        text_mask = torch.zeros_like(tokens, dtype=torch.float32)
        for text_slice, spec in zip(text_slices, text_mask_specs):
            if spec['ignore']:
                continue
            text_mask[text_slice.start + spec['start_offset']:text_slice.stop + spec['end_offset']] = 1.0

    # First <pad> position when padding was added, otherwise the full sequence length.
    real_pos = torch.tensor(extra_token_pos.get('first_pad', [tokens.shape[0]]), dtype=torch.long)

    return TokenizerEncodeOutput(
        tokens=tokens,
        timestep_scatter_index=index_tensor('timestep'),
        guidance_scatter_index=index_tensor('guidance'),
        text_slices=text_slices,
        gen_image_slices=gen_image_slices,
        joint_image_slices=joint_image_slices,
        cond_vae_image_slices=cond_vae_image_slices,
        cond_vit_image_slices=cond_vit_image_slices,
        text_mask=text_mask,
        gen_image_mask=gen_image_mask,
        cond_vae_image_mask=cond_vae_image_mask,
        cond_vit_image_mask=cond_vit_image_mask,
        real_pos=real_pos,
        all_image_slices=all_image_slices,
        cond_timestep_scatter_index=index_tensor('cond_timestep'),
        gen_timestep_scatter_index=index_tensor('gen_timestep'),
    )
| 1145 | |
def get_cot_sections(self, cot_text, uncond_kwargs, cot_max_length=None, drop_think=False):
    """
    Split chain-of-thought text into text sections.

    The first ``<think>...</think>`` span (checked first) or ``<recaption>...</recaption>`` span expands
    into three text sections:
    - the opening tag;
    - the enclosed text, carrying ``max_length=cot_max_length`` and ``uncond_kwargs``;
    - the closing tag.

    The text before and after the span is processed recursively with the same settings, so every
    well-formed span is expanded. Remaining plain text becomes one section carrying ``uncond_kwargs``.

    Parameters
    ----------
    cot_text: str or None
        Text to split. None or empty yields no sections.
    uncond_kwargs: dict
        Extra keys added to every content section (not to the tag sections).
    cot_max_length: int, optional
        ``max_length`` applied to the text enclosed by each tag pair.
    drop_think: bool
        If True, ``<think>`` spans (tags and content) are omitted.

    Returns
    -------
    List[dict]
        Section dicts with ``type="text"``.
    """
    if not cot_text:  # None or empty
        return []

    def split_span(text, open_tag, close_tag):
        # (before, inner, after) around the first open_tag and the first close_tag following it, or
        # None when there is no such pair. partition keeps everything after that closing tag in
        # `after` (str.split(...)[1] discarded text after a second closing tag).
        before, opened, rest = text.partition(open_tag)
        if not opened:
            return None
        inner, closed, after = rest.partition(close_tag)
        if not closed:
            return None
        return before, inner, after

    def recurse(text):
        # Propagate cot_max_length too; it used to be dropped for the text around the first span.
        return self.get_cot_sections(text, uncond_kwargs, cot_max_length=cot_max_length, drop_think=drop_think)

    for open_tag, close_tag in (("<think>", "</think>"), ("<recaption>", "</recaption>")):
        parts = split_span(cot_text, open_tag, close_tag)
        if parts is None:
            continue
        before, inner, after = parts
        if open_tag == "<think>" and drop_think:
            middle = []
        else:
            middle = [
                dict(type="text", text=open_tag),
                dict(type="text", text=inner, max_length=cot_max_length, **uncond_kwargs),
                dict(type="text", text=close_tag),
            ]
        return recurse(before) + middle + recurse(after)

    return [
        dict(type="text", text=cot_text, **uncond_kwargs),
    ]
| 1176 | |
def apply_general_template(
        self,
        message_list,
        max_length=None,
        add_assistant_prefix=False,
        answer="auto",
        bot_task="auto",
        sequence_template="instruct",
        uncond_p=0.0,
        cfg_factor=1,
        batchify=False,
        image_base_size=1024,
        drop_think=False,
):
    """
    Render a conversation into sections and encode it with ``encode_general``.

    Each message is a dict with ``role`` ('system', 'user' or 'assistant'), ``type`` ('text',
    'gen_image' or 'joint_image') and ``content``:
    - a str for 'text';
    - an ``ImageInfo`` for 'gen_image';
    - a ``JointImageInfo`` for 'joint_image'.

    Consecutive messages with the same role are grouped and wrapped in that role's prefix/suffix.

    Parameters
    ----------
    message_list: list
        Messages of one conversation, or (when ``batchify`` is True) a list of such lists.
    max_length: int, optional
        If given, a ValueError is raised when the encoded sequence is longer.
    add_assistant_prefix: bool
        Whether to append an assistant-turn prefix (chosen by ``bot_task``) after the messages.
    answer: bool or "auto"
        Whether assistant content is wrapped in <answer>/</answer>. "auto" enables it only for the
        "instruct" sequence template.
    bot_task: str
        One of "auto", "image", "think", "recaption", "img_ratio"; selects the assistant prefix.
    sequence_template: str
        "pretrain" uses no role prefixes/separators; any other value uses the Conversation roles and sep.
    uncond_p: float
        Forwarded to user/assistant text sections; ``uncond_enabled`` is set when it equals 1.0.
    cfg_factor: int
        Only used with ``batchify``. If > 1, ``cfg_factor - 1`` unconditional copies are appended.
    batchify: bool
        Encode each conversation separately and collate the results with ``batch_gen_infer``.
    image_base_size: int
        Used by the "img_ratio" assistant prefix (``<img_size_{image_base_size}>``).
    drop_think: bool
        Drop <think>...</think> spans from assistant text.

    Returns
    -------
    tuple
        ``(output, sections)`` where ``output`` is the ``encode_general`` result. With ``batchify``,
        the collated result of ``batch_gen_infer`` over all conversations.

    Raises
    ------
    ValueError
        If a message has an unsupported role or type, or the encoded length exceeds ``max_length``.
    """
    if batchify:
        assert isinstance(message_list[0], list), \
            f"When batchify is True, message_list should be a list of list, but got [{type(message_list[0])}, ...]."
        do_cfg = cfg_factor > 1
        per_sample_kwargs = dict(
            max_length=max_length,
            add_assistant_prefix=add_assistant_prefix,
            answer=answer,
            bot_task=bot_task,
            sequence_template=sequence_template,
            image_base_size=image_base_size,
            drop_think=drop_think,
        )
        if not do_cfg:
            # batch_gen_infer only injects uncond_p when doing CFG; otherwise honor the caller's value
            # (it used to be dropped, silently falling back to 0.0).
            per_sample_kwargs['uncond_p'] = uncond_p
        return self.batch_gen_infer(
            infer_fn=self.apply_general_template,
            # One (empty) positional-argument list per conversation: batch_gen_infer zips prompt_list
            # with infer_fn_kwargs_list, so the former `[[]]` kept only the first conversation.
            prompt_list=[[] for _ in message_list],
            infer_fn_kwargs_list=[dict(per_sample_kwargs, message_list=message_list_i)
                                  for message_list_i in message_list],
            do_classifier_free_guidance=do_cfg,
            condition_repeat_times=1,
            uncondition_repeat_times=cfg_factor - 1,
        )

    conv = Conversation()
    uncond_kwargs = dict(uncond_enabled=uncond_p == 1.0, uncond_p=uncond_p)

    def process_successive_message(_message_list, _cur_message_idx, role, prefix, suffix,
                                   answer_prefix="", answer_suffix=""):
        # Consume consecutive messages of `role` starting at `_cur_message_idx`. Returns the sections
        # (wrapped in prefix/suffix when non-empty) and the index of the first unconsumed message.
        _sub_sections = []
        while _cur_message_idx < len(_message_list) and _message_list[_cur_message_idx]['role'] == role:
            message = _message_list[_cur_message_idx]
            if message['type'] == 'text':
                text = message['content']
                if role == "system":
                    _sub_sections.append(dict(type="text", text=text))
                elif role == "assistant":
                    if ("<recaption>" in text and "</recaption>" in text) or (
                            "<think>" in text and "</think>" in text):
                        _sub_sections.extend(self.get_cot_sections(text, uncond_kwargs, drop_think=drop_think))
                    else:
                        _sub_sections.append(dict(type="text", text=text, **uncond_kwargs))
                else:
                    _sub_sections.append(dict(
                        type="text", text=f"{answer_prefix}{text}{answer_suffix}", **uncond_kwargs))
            elif message['type'] == 'gen_image':
                info = message['content']
                assert isinstance(info, ImageInfo), f"Expected ImageInfo, but got {type(info)}"
                if role == "assistant":
                    _sub_sections.append(dict(type="text", text=answer_prefix))
                _sub_sections.append(dict(type=message['type'], **info.meta_info))
                if role == "assistant":
                    _sub_sections.append(dict(type="text", text=answer_suffix))
            elif message['type'] == 'joint_image':
                info = message['content']
                assert isinstance(info, JointImageInfo), f"Expected JointImageInfo, but got {type(info)}"
                _sub_sections.append(dict(type=message['type'], **info.meta_info))
            else:
                raise ValueError(f"Unknown message type: {message['type']}")
            _cur_message_idx += 1
        if len(_sub_sections) > 0:
            # Add role prefix and suffix
            _sub_sections.insert(0, dict(type='text', text=prefix))
            _sub_sections.append(dict(type='text', text=suffix))
        return _sub_sections, _cur_message_idx

    # Define assistant prefix and suffix
    if (answer == "auto" and sequence_template == "instruct") or answer is True:
        answer_prefix, answer_suffix = "<answer>", "</answer>"
    else:
        answer_prefix, answer_suffix = "", ""
    if sequence_template == "pretrain":
        system_suffix = ""
        user_prefix = ""
        user_suffix = ""
        bot_prefix = ""
        bot_suffix = ""
    else:
        system_suffix = f"{conv.sep}"
        user_prefix = f"{conv.roles[0]}: "
        user_suffix = f"{conv.sep}"
        bot_prefix = f"{conv.roles[1]}: "
        bot_suffix = f"{conv.sep}"

    # Process successive system, user and assistant messages
    sections = []
    cur_message_idx = 0
    final_role = None
    while cur_message_idx < len(message_list):
        round_start_idx = cur_message_idx

        # Process successive system messages
        sub_sections, cur_message_idx = process_successive_message(
            message_list, cur_message_idx, role="system", prefix="", suffix=system_suffix)
        sections.extend(sub_sections)
        if len(sub_sections) > 0:
            final_role = "system"

        # Process successive user messages
        sub_sections, cur_message_idx = process_successive_message(
            message_list, cur_message_idx, role="user", prefix=user_prefix, suffix=user_suffix)
        sections.extend(sub_sections)
        if len(sub_sections) > 0:
            final_role = "user"

        # Process successive assistant messages
        sub_sections, cur_message_idx = process_successive_message(
            message_list, cur_message_idx, role="assistant", prefix=bot_prefix, suffix=bot_suffix,
            answer_prefix=answer_prefix, answer_suffix=answer_suffix,
        )
        sections.extend(sub_sections)
        if len(sub_sections) > 0:
            final_role = "assistant"

        if cur_message_idx == round_start_idx:
            # No role handler consumed this message; without this check the loop never terminates.
            raise ValueError(
                f"Unsupported message role: {message_list[cur_message_idx]['role']!r}. "
                f"Expected 'system', 'user' or 'assistant'.")

    if add_assistant_prefix:
        if final_role == "assistant":
            # Avoid adding prefix twice
            _bot_prefix = ""
            # Remove the final bot_suffix
            if len(sections) > 0 and sections[-1]['type'] == 'text' and sections[-1]['text'] == bot_suffix:
                sections = sections[:-1]
        else:
            _bot_prefix = bot_prefix
        # We can add special tokens for the bot latest message according to different tasks
        bot_response_prefix = dict(
            auto=_bot_prefix,
            image="",
            think=f"{_bot_prefix}<think>",
            recaption=f"{_bot_prefix}<recaption>",
            img_ratio=f"{_bot_prefix}{answer_prefix}<boi><img_size_{image_base_size}>",
        )[bot_task]
        sections.append(dict(type='text', text=bot_response_prefix))

    output = self.encode_general(
        sections=sections,
        use_text_mask=False,
        add_eos=False,
        add_pad=False,
    )

    if max_length is not None:
        if output.tokens.shape[-1] > max_length:
            raise ValueError(
                f"Encoded token length {output.tokens.shape[-1]} exceeds max_length {max_length}.\n"
                f"Please set a larger max_length or check the input messages:\n{message_list}"
            )

    return output, sections
| 1338 | |
def apply_chat_template(
        self,
        batch_prompt: Optional[List[str]] = None,
        batch_message_list: Optional[List[List[Dict[str, Any]]]] = None,
        mode: str = "gen_text",
        batch_gen_image_info: Optional[List[ImageInfo]] = None,
        batch_cond_image_info: Optional[Union[List[JointImageInfo], List[List[JointImageInfo]]]] = None,
        batch_system_prompt: Optional[List[str]] = None,
        batch_cot_text: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        bot_task: str = "auto",  # auto/image/think/recaption/img_ratio
        image_base_size: int = 1024,
        sequence_template: str = "pretrain",
        cfg_factor: int = 1,
        add_assistant_prefix: Optional[bool] = None,
        drop_think: bool = False,
) -> Dict[str, Any]:
    """
    Build and batch-encode chat sequences for text or image generation.

    Either ``batch_message_list`` (ready-made conversations) or ``batch_prompt`` must be given. In the
    ``batch_prompt`` case, each single-round sample is turned into a message list in this order:
    1. optional system prompt;
    2. conditional images, then the user prompt;
    3. optional CoT text from the assistant;
    4. a gen_image message from the assistant when ``mode == "gen_image"``.

    Parameters
    ----------
    batch_prompt: List[str] or str
        User prompts; a bare string is treated as a batch of one.
    batch_message_list: List[List[dict]]
        Complete conversations; when given, the other ``batch_*`` inputs are ignored.
    mode: str
        "gen_image" appends a gen_image assistant message; it also sets the default of
        ``add_assistant_prefix`` to False (True otherwise).
    batch_gen_image_info: List[ImageInfo] or ImageInfo
        Per-sample generated-image info; a non-list value is shared by all samples.
    batch_cond_image_info: list
        Per-sample conditional image(s); each entry may be a single JointImageInfo or a list.
    batch_system_prompt: List[str] or str
        Per-sample system prompts; a non-list value is shared by all samples.
    batch_cot_text: List[str]
        Per-sample CoT text for the assistant.
    max_length, bot_task, image_base_size, sequence_template, cfg_factor, drop_think:
        Forwarded to ``apply_general_template``.
    add_assistant_prefix: bool, optional
        Defaults to ``mode != "gen_image"``.

    Returns
    -------
    dict
        ``dict(output=..., sections=...)`` from the batched ``apply_general_template`` call.

    Raises
    ------
    ValueError
        If neither ``batch_prompt`` nor ``batch_message_list`` is provided.
    """
    assert bot_task in ["image", "auto", "think", "recaption", "img_ratio"], \
        f"bot_task should be one of ['image', 'auto', 'think', 'recaption', 'img_ratio'], but got {bot_task}."

    if batch_message_list is None:
        # Simple text-to-image or text-cot-to-image task
        if batch_prompt is None:
            raise ValueError("Either batch_prompt or batch_message_list must be provided.")
        if isinstance(batch_prompt, str):
            # A bare string would otherwise be iterated character by character.
            batch_prompt = [batch_prompt]
        batch_size = len(batch_prompt)

        # Batchify inputs. Per-sample lists must match batch_size; zip below would otherwise silently
        # truncate the batch.
        if isinstance(batch_system_prompt, list):
            assert len(batch_system_prompt) == batch_size, \
                (f"batch_system_prompt should have the same length as batch_size ({batch_size}), "
                 f"but got {len(batch_system_prompt)}.")
        else:
            batch_system_prompt = [batch_system_prompt] * batch_size
        if isinstance(batch_gen_image_info, list):
            assert len(batch_gen_image_info) == batch_size, \
                (f"batch_gen_image_info should have the same length as batch_size ({batch_size}), "
                 f"but got {len(batch_gen_image_info)}.")
        else:
            batch_gen_image_info = [batch_gen_image_info] * batch_size
        if batch_cot_text is not None:
            assert len(batch_cot_text) == batch_size, \
                (f"batch_cot_text should have the same length as batch_size ({batch_size}), "
                 f"but got {len(batch_cot_text)}.")
        else:
            batch_cot_text = [None] * batch_size
        if batch_cond_image_info is not None:
            assert len(batch_cond_image_info) == batch_size, \
                (f"batch_cond_image_info should have the same length as batch_size ({batch_size}), "
                 f"but got {len(batch_cond_image_info)}.")
            batch_cond_image_info = [
                cond_image_info if isinstance(cond_image_info, list) else [cond_image_info]
                for cond_image_info in batch_cond_image_info
            ]
        else:
            batch_cond_image_info = [[] for _ in range(batch_size)]

        # Convert single round materials into standard message list
        batch_message_list = []
        for prompt, system_prompt, cot_text, gen_image_info, cond_image_info_list in zip(
                batch_prompt, batch_system_prompt, batch_cot_text, batch_gen_image_info,
                batch_cond_image_info,
        ):
            message_list = []
            # 1. system prompt section
            if system_prompt:
                message_list.append(dict(
                    role="system", type="text", content=system_prompt, context_type="str"))
            # 2. user inputs sections
            # 2.1 image inputs
            if len(cond_image_info_list) > 0:
                message_list.extend([
                    dict(role="user", type="joint_image", content=cond_image_info, context_type="image_info")
                    for cond_image_info in cond_image_info_list
                ])
            # 2.2 text inputs
            message_list.append(dict(
                role="user", type="text", content=prompt, context_type="str"))
            # 3. assistant answer sections
            if cot_text is not None:
                message_list.append(dict(role="assistant", type="text", content=cot_text, context_type="str"))
            if mode == "gen_image":
                message_list.append(dict(
                    role="assistant", type="gen_image", content=gen_image_info, context_type="image_info"))
            # ---
            batch_message_list.append(message_list)

    output, sections = self.apply_general_template(
        message_list=batch_message_list,
        max_length=max_length,
        add_assistant_prefix=default(add_assistant_prefix, mode != "gen_image"),
        bot_task=bot_task,
        sequence_template=sequence_template,
        cfg_factor=cfg_factor,
        batchify=True,
        image_base_size=image_base_size,
        drop_think=drop_think,
    )
    return dict(output=output, sections=sections)
| 1427 | |