hunyuan.py
114.3 KB · 2660 lines · python Raw
1 # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
2 # you may not use this file except in compliance with the License.
3 # You may obtain a copy of the License at
4 #
5 # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
6 #
7 # Unless required by applicable law or agreed to in writing, software
8 # distributed under the License is distributed on an "AS IS" BASIS,
9 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 # See the License for the specific language governing permissions and
11 # limitations under the License.
12 # ==============================================================================
13
14 import math
15 import random
16 import re
17 import warnings
18 from dataclasses import dataclass
19 from typing import TYPE_CHECKING, List, Union, Optional, Dict, Any, Tuple, Callable
20
21 import torch
22 import torch.nn.functional as F
23 import torch.utils.checkpoint
24 from einops import rearrange
25 from torch import Tensor
26 from torch import nn
27 from torch.cuda import nvtx
28 from transformers.activations import ACT2FN
29 from transformers.cache_utils import Cache, StaticCache
30 from transformers.generation.logits_process import LogitsProcessorList
31 from transformers.generation.stopping_criteria import StoppingCriteriaList
32 from transformers.generation.utils import GenerationMixin, GenerationConfig, ALL_CACHE_NAMES
33 from transformers.modeling_outputs import (
34 BaseModelOutputWithPast,
35 CausalLMOutputWithPast,
36 )
37 from transformers.modeling_utils import PreTrainedModel
38 from transformers.utils import (
39 ModelOutput,
40 add_start_docstrings,
41 add_start_docstrings_to_model_forward,
42 is_flash_attn_2_available,
43 logging,
44 )
45
46 if TYPE_CHECKING:
47 from transformers.generation.streamers import BaseStreamer
48
49 try:
50 import flashinfer
51 except Exception as e:
52 flashinfer = None
53
54 from .autoencoder_kl_3d import AutoencoderKLConv3D
55 from .configuration_hunyuan import HunyuanImage3Config
56 from .hunyuan_image_3_pipeline import HunyuanImage3Text2ImagePipeline, FlowMatchDiscreteScheduler
57 from .image_processor import HunyuanImage3ImageProcessor
58 from .siglip2 import Siglip2VisionTransformer, LightProjector
59 from .tokenizer_wrapper import TokenizerWrapper, ImageInfo, JointImageInfo
60 from .system_prompt import get_system_prompt, t2i_system_prompts
61
62
63 logger = logging.get_logger(__name__)
64
65
66 if is_flash_attn_2_available():
67 from flash_attn import flash_attn_func
68
# Type aliases
# Either a single tensor, or a list whose elements are each a tensor or a list of tensors.
BatchRaggedImages = Union[torch.Tensor, List[Union[torch.Tensor, List[torch.Tensor]]]]
# Either a single tensor or a list of tensors.
BatchRaggedTensor = Union[torch.Tensor, List[torch.Tensor]]


# Name of the configuration class as a string.
# NOTE(review): presumably consumed by docstring decorators elsewhere in the file -- confirm.
_CONFIG_FOR_DOC = "HunyuanImage3Config"
75
# Shared docstring preamble describing the model family and its `config` parameter.
# NOTE(review): presumably attached to model classes via `add_start_docstrings` -- confirm.
Hunyuan_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)

This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.

Parameters:
config ([`HunyuanImage3Config`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
91
92 # =======================================================
93 # Helper Functions
94 # =======================================================
95
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``.

    Only ``None`` triggers the fallback; other falsy values (``0``, ``""``,
    empty containers) are returned unchanged.
    """
    if val is None:
        return d
    return val
98
99
def to_device(data, device):
    """Move tensors (optionally nested inside lists) to ``device``.

    Args:
        data: A tensor, a (possibly nested) list, or any other object.
        device: Target device. ``None`` means "leave everything where it is".

    Returns:
        ``data`` itself when ``device`` is ``None`` or when ``data`` is neither
        a tensor nor a list; otherwise the moved tensor, or a new list whose
        items were processed recursively.
    """
    if device is None:
        return data
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, list):
        return [to_device(item, device) for item in data]
    # Anything else (tuples, dicts, scalars, ...) is passed through untouched.
    return data
109
110
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    the input of shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    """
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat; hand back the very same tensor.
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)
121
122
def real_batched_index_select(t, dim, idx):
    """Apply ``torch.index_select`` per batch element with per-element indices.

    For every batch position ``i`` this selects ``idx[i]`` along axis ``dim``
    of ``t`` (expressed in the coordinates of the full, batched tensor) and
    stacks the per-element results back along a new leading axis.

    Args:
        t (torch.Tensor): Batched tensor with at least 2 dimensions.
        dim (int): Axis of ``t`` to select along. Negative values count from
            the end. The batch axis (0) cannot be selected.
        idx (torch.Tensor): Batched indices with at least 2 dimensions and the
            same leading length as ``t``.

    Returns:
        torch.Tensor: The stacked per-element selections.

    Raises:
        ValueError: If ``dim`` is out of range or refers to the batch axis.
    """
    assert t.ndim >= 2 and idx.ndim >= 2, f"{t.ndim=} {idx.ndim=}"
    assert len(t) == len(idx), f"{len(t)=} != {len(idx)=}"
    # Normalize negative axes first: the previous `dim - 1` shortcut silently
    # shifted e.g. dim=-1 to the second-to-last axis of each batch element.
    norm_dim = dim + t.ndim if dim < 0 else dim
    if not 1 <= norm_dim < t.ndim:
        raise ValueError(
            f"dim must address a non-batch axis of a {t.ndim}-D tensor, but got {dim}"
        )
    # Each t[i] has lost the batch axis, hence the shift by one.
    inner_dim = norm_dim - 1
    return torch.stack(
        [torch.index_select(sample, inner_dim, sample_idx) for sample, sample_idx in zip(t, idx)]
    )
128
129
130 # =======================================================
131 # Module Functions
132 # =======================================================
133
def timestep_embedding(t, dim, max_period=10000):
    """
    Build sinusoidal embeddings for a batch of (possibly fractional) timesteps.

    Args:
        t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
        dim (int): the dimension of the output.
        max_period (int): controls the minimum frequency of the embeddings.

    Returns:
        embedding (torch.Tensor): An (N, D) Tensor of positional embeddings, laid out as
            [cos | sin], with one trailing zero column when ``dim`` is odd.

    .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    """
    half = dim // 2
    # Geometric frequency ladder from 1 down towards 1 / max_period.
    steps = torch.arange(start=0, end=half, dtype=torch.float32)
    freqs = torch.exp(-math.log(max_period) * steps / half).to(device=t.device)
    angles = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # Odd target width: pad with a single zero column.
        zero_col = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_col], dim=-1)
    return embedding
161
162
def conv_nd(dims, *args, **kwargs):
    """
    Build an ``nn.Conv1d``, ``nn.Conv2d`` or ``nn.Conv3d`` depending on ``dims``.

    All remaining positional and keyword arguments are forwarded to the
    convolution constructor. Raises ``ValueError`` for any other ``dims``.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
174
175
def linear(*args, **kwargs):
    """
    Build an ``nn.Linear`` layer, forwarding every argument unchanged.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer
181
182
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an ``nn.AvgPool1d``, ``nn.AvgPool2d`` or ``nn.AvgPool3d`` depending on ``dims``.

    All remaining positional and keyword arguments are forwarded to the
    pooling constructor. Raises ``ValueError`` for any other ``dims``.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
194
195
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the same module.

    The zeroing happens outside of autograd, so it is not recorded in any graph.
    """
    for param in module.parameters():
        nn.init.zeros_(param)
    return module
203
204
def normalization(channels, **kwargs):
    """
    Create the standard normalization layer: a 32-group ``nn.GroupNorm``.

    :param channels: number of input channels.
    :param kwargs: extra keyword arguments forwarded to ``nn.GroupNorm``.
    :return: a nn.Module for normalization.
    """
    num_groups = 32
    return nn.GroupNorm(num_groups, channels, **kwargs)
213
214
def topkgating(
    logits: Tensor,
    topk: int,
    group_limited_greedy: bool = False,
    n_group: int = None,
    topk_group: int = None,
    norm_topk_prob: bool = True,
    routed_scaling_factor: float = 1.0,
    capacity_factor: float = 1.0,
    drop_tokens: bool = False,
):
    """Route every token to its top-k experts and build dispatch/combine tensors.

    Args:
        logits: Router logits; softmax is taken over dim 1 (the expert axis).
        topk: Number of experts selected per token.
        group_limited_greedy: If True, experts are split into ``n_group`` groups
            and only experts in the ``topk_group`` best-scoring groups stay eligible.
        n_group: Number of expert groups (used only with ``group_limited_greedy``).
        topk_group: Number of groups kept (used only with ``group_limited_greedy``).
        norm_topk_prob: If True and ``topk > 1``, gate values are divided by the
            summed gate of the selected experts; otherwise they are multiplied by
            ``routed_scaling_factor``.
        routed_scaling_factor: Scale applied in the non-normalized branch.
        capacity_factor: Multiplier on the expert capacity when ``drop_tokens`` is True.
        drop_tokens: If True, capacity is ``max(topk, topk * tokens // experts) *
            capacity_factor``; otherwise it is the largest per-expert token count,
            so no token is dropped.

    Returns:
        A 4-tuple ``([l_aux, exp_capacity_rate], combine_weights, dispatch_mask, exp_counts)``:
        the auxiliary load-balancing loss and the fraction of (token, choice)
        pairs that were dispatched; the ``[tokens, experts, capacity]`` combine
        weights; the boolean dispatch mask of the same shape; and the detached
        per-expert assignment counts.
    """
    gates = F.softmax(logits.float(), dim=1)

    if group_limited_greedy:
        grouped_shape = [*gates.shape[:-1], n_group, gates.shape[-1] // n_group]
        # A group's score is the gate of its best expert.  [n, n_group]
        group_scores = gates.reshape(grouped_shape).max(dim=-1).values
        # [n, top_k_group]
        kept_groups = torch.topk(group_scores, topk_group, dim=-1, sorted=False).indices
        # [n, n_group]
        group_mask = torch.zeros_like(group_scores).scatter_(1, kept_groups, 1)
        # Broadcast the group decision back to its member experts.  [n, e]
        score_mask = group_mask.unsqueeze(-1).expand(grouped_shape).reshape(list(gates.shape))
        gates = gates.masked_fill(~score_mask.bool(), 0.0)

    num_experts = int(gates.shape[1])
    # Expert indices of the top-k router probabilities for each token.
    # Shape: [tokens_per_group, num_selected_experts].
    _, expert_index = torch.topk(gates, topk)
    selection = F.one_hot(expert_index, num_experts)
    # Whether a token was routed to a given expert.
    # Shape: [tokens_per_group, num_experts]
    routed_any = selection.max(dim=-2).values
    load_fraction = routed_any.float().mean(dim=-2)
    mean_router_prob = gates.float().mean(dim=-2)
    l_aux = num_experts ** 2 * torch.mean(load_fraction * mean_router_prob)

    if drop_tokens:
        expert_capacity = int(max(topk, topk * gates.shape[0] // gates.shape[1]) * capacity_factor)
    else:
        # Size the buffers for the busiest expert so nothing is dropped.
        per_expert = torch.bincount(expert_index.flatten(), minlength=num_experts)
        expert_capacity = per_expert.max().item()

    if norm_topk_prob and topk > 1:
        selected_mass = torch.matmul(selection.float(), gates.unsqueeze(-1)).sum(dim=1)
        selected_mass = torch.clamp(selected_mass, min=torch.finfo(gates.dtype).eps)
        router_probs = gates / selected_mass
    else:
        router_probs = gates * routed_scaling_factor

    # Put num_selected_experts first so that top-1 choices take priority over
    # top-2 choices, and so on, then flatten.
    # Shape: [num_selected_experts * tokens_per_group]
    flat_index = expert_index.transpose(0, 1).reshape(-1)

    # Shape: [tokens_per_group * num_selected_experts, num_experts].
    flat_mask = F.one_hot(flat_index, num_experts).to(torch.int32)
    exp_counts = flat_mask.sum(dim=0).detach()

    # A token's slot inside an expert's buffer is the masked running count of
    # assignments to that expert; unselected entries become -1.
    priority = torch.cumsum(flat_mask, dim=0) * flat_mask - 1
    # [num_selected_experts, tokens, experts] -> [tokens, num_selected_experts, experts]
    priority = priority.reshape((topk, -1, num_experts)).transpose(0, 1)
    # Across a token's selected experts, keep the only non-negative priority.
    # Shape: [tokens_per_group, num_experts].
    priority = priority.max(dim=1).values

    # Only slots inside [0, expert_capacity) are dispatched.
    in_capacity = (priority >= 0) & (priority < expert_capacity)
    priority = priority.masked_fill(~in_capacity, 0)
    # Shape: [tokens_per_group, num_experts, expert_capacity].
    dispatch_mask = F.one_hot(priority, expert_capacity).to(torch.bool)
    slot_valid = in_capacity.unsqueeze(-1).expand(-1, -1, expert_capacity)
    dispatch_mask = dispatch_mask.masked_fill(~slot_valid, 0)

    # Combine weights scale each dispatched slot by its router probability.
    combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
    exp_capacity_rate = torch.sum(dispatch_mask) / (logits.shape[0] * topk)

    return [l_aux, exp_capacity_rate], combine_weights, dispatch_mask, exp_counts
322
323
324 # =======================================================
325 # Multi-Dimensional RoPE
326 # =======================================================
327
328 def _to_tuple(x, dim=2):
329 if isinstance(x, int):
330 return (x,) * dim
331 elif len(x) == dim:
332 return x
333 else:
334 raise ValueError(f"Expected length {dim} or int, but got {x}")
335
336
def get_meshgrid_nd(start, *args, dim=2):
    """
    Build an n-D meshgrid from start / stop / num specifications.

    Args:
        start (int or tuple): With no extra args, ``start`` is the grid size (the grid begins at 0).
            With one extra arg, ``start`` is the start and ``args[0]`` the stop, with step 1.
            With two extra args, ``start`` is the start, ``args[0]`` the stop and ``args[1]`` the
            number of points. Each value is an int or an n-tuple; tuples follow the axis order.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (torch.Tensor): float32 tensor of shape [dim, ...]; the stop value is excluded.
    """
    n_extra = len(args)
    if n_extra == 0:
        # `start` is really the grid size.
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif n_extra == 1:
        # Unit step between start and stop.
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        spans = [hi - lo for lo, hi in zip(start, stop)]
        # The number of points along each axis must be a whole number.
        whole = [int(span) for span in spans]
        assert (torch.tensor(spans) == torch.tensor(whole)).all(), f"num should be int, but got {spans}"
        num = whole
    elif n_extra == 2:
        start = _to_tuple(start, dim=dim)   # Left-Top       eg: 12,0
        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom   eg: 20,32
        num = _to_tuple(args[1], dim=dim)   # Target Size    eg: 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch version of np.linspace(start[i], stop[i], num[i], endpoint=False):
    # sample n + 1 points and drop the endpoint.
    axis_grid = [
        torch.linspace(lo, hi, count + 1, dtype=torch.float32)[:count]
        for lo, hi, count in zip(start, stop, num)
    ]
    mesh = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [H, W]
    return torch.stack(mesh, dim=0)  # [dim, H, W]
384
385
def build_2d_rope(
    seq_len: int, n_elem: int, image_infos: Optional[List[Tuple[slice, Tuple[int, int]]]] = None,
    device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
    return_all_pos: bool = False,
):
    """
    Build 2-D rotary position embeddings for one sequence mixing text and image sections.

    Reference: https://kexue.fm/archives/10352

    Text tokens use the same index for both axes. For an image section starting at
    sequence index L with token grid (h, w), the grid origin is
        beta_y = L + (wh - h)/2
        beta_x = L + (wh - w)/2
    and positions are truncated to integers after concatenation.

    Args:
        seq_len: Total sequence length.
        n_elem: Rotary dimension; must be divisible by 4.
        image_infos: Entries of ``(section_slice, (h, w))``. An entry with ``h is None``
            whose start is not after the running position gets plain 1-D positions
            over its slice.
        device: Device for the returned tensors.
        base: RoPE base.
        base_rescale_factor: If not 1.0, ``base`` is multiplied by
            ``base_rescale_factor ** (n_elem / (n_elem - 2))``.
        return_all_pos: Also return the integer (y, x) positions.

    Returns
    -------
    cos: torch.Tensor with shape of [seq_len, n_elem]
    sin: torch.Tensor with shape of [seq_len, n_elem]
    all_pos (only if ``return_all_pos``): torch.Tensor with shape of [seq_len, 1, 2]
    """
    assert n_elem % 4 == 0, f"n_elem must be divisible by 4, but got {n_elem}."

    # theta
    if base_rescale_factor != 1.0:
        base *= base_rescale_factor ** (n_elem / (n_elem - 2))
    exponents = torch.arange(0, n_elem, 2, device=device).float() / n_elem
    theta = (1.0 / (base ** exponents)).reshape(1, n_elem // 4, 2)  # [1, half_d, 2]

    # position indices
    if image_infos is None:
        image_infos = []

    y_sections, x_sections = [], []
    last_pos = 0
    for sec_slice, (h, w) in image_infos:
        L = sec_slice.start  # start from 0, so image_slice.start is just L
        if last_pos < L:
            # Text preceding this section.
            y_sections.append(torch.arange(last_pos, L))
            x_sections.append(torch.arange(last_pos, L))
        elif h is None:
            # Interleave data has overlapped positions for <boi> <size> <ratio> <timestep> <eoi> tokens.
            y_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
            x_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
            continue
        # Otherwise: interleave data has overlapped positions for a noised image and the
        # successive clean image, so last_pos (= last text end L + noise w * h) > L.

        # Current image.
        beta_y = L + (w * h - h) / 2
        beta_x = L + (w * h - w) / 2
        grid = get_meshgrid_nd((beta_y, beta_x), (beta_y + h, beta_x + w))  # [2, h, w]
        grid = grid.reshape(2, -1)  # (y, x)
        y_sections.append(grid[0])
        x_sections.append(grid[1])
        last_pos = L + w * h

    # Text after the last section.
    y_sections.append(torch.arange(last_pos, seq_len))
    x_sections.append(torch.arange(last_pos, seq_len))

    # Overlapping sections can yield more than seq_len entries; keep the first seq_len.
    x_pos = torch.cat(x_sections).long()[:seq_len]
    y_pos = torch.cat(y_sections).long()[:seq_len]
    all_pos = torch.stack((y_pos, x_pos), dim=1).unsqueeze(1).to(device)  # [seq_len, 1, 2]

    # calc rope
    idx_theta = (all_pos * theta).reshape(all_pos.shape[0], n_elem // 2).repeat(1, 2)
    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)

    if return_all_pos:
        return cos, sin, all_pos
    return cos, sin
468
469
def build_batch_2d_rope(
    seq_len: int, n_elem: int, image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None,
    device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
    return_all_pos: bool = False,
):
    """Batched wrapper around ``build_2d_rope``.

    Calls ``build_2d_rope`` once per entry of ``image_infos`` (one entry per
    sample; ``None`` is treated as a single sample without images) and stacks
    the per-sample cos/sin tensors along a new leading batch axis.

    Returns:
        ``(cos, sin)``, each the stack of the per-sample results, plus a list of
        per-sample position tensors when ``return_all_pos`` is True.
    """
    if image_infos is None:
        image_infos = [None]

    cos_list, sin_list, all_pos_list = [], [], []
    for sample_infos in image_infos:
        res = build_2d_rope(
            seq_len, n_elem, image_infos=sample_infos, device=device,
            base=base, base_rescale_factor=base_rescale_factor,
            return_all_pos=return_all_pos,
        )
        # The helper returns two items, or three when positions are requested.
        cos, sin, *extra = res if return_all_pos else (*res, None)
        cos_list.append(cos)
        sin_list.append(sin)
        all_pos_list.append(extra[0])

    stacked_cos = torch.stack(cos_list, dim=0)
    stacked_sin = torch.stack(sin_list, dim=0)

    if return_all_pos:
        return stacked_cos, stacked_sin, all_pos_list
    return stacked_cos, stacked_sin
500
501
def rotate_half(x):
    """Split the last dim in two halves (x1, x2) and return ``cat(-x2, x1)``."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
507
508
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with Rotary Position Embedding.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            When given, ``cos`` and ``sin`` are first indexed with it (e.g. to pass shifted
            position ids when working with a KV-cache). When ``None`` they are used as is.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which ``cos`` and ``sin`` are unsqueezed so that they broadcast
            against ``q`` and ``k``. With cos/sin of shape [batch_size, seq_len, head_dim],
            use 1 for q/k shaped [batch_size, heads, seq_len, head_dim] and 2 for q/k shaped
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    if position_ids is not None:
        cos, sin = cos[position_ids], sin[position_ids]

    # Insert the broadcast axis (typically the head axis).
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)
540
541
542 # =======================================================
543 # Modules for Image Generation
544 # =======================================================
545
class TimestepEmbedder(nn.Module):
    """
    Maps scalar timesteps to vectors: sinusoidal features followed by a two-layer MLP.

    Args:
        hidden_size: Width of the MLP hidden layer.
        act_layer: Activation class instantiated between the two linear layers.
        frequency_embedding_size: Width of the sinusoidal features fed to the MLP.
        max_period: Passed to ``timestep_embedding``.
        out_size: Output width; defaults to ``hidden_size``.
        dtype, device: Factory arguments for the linear layers.
    """
    def __init__(self,
                 hidden_size,
                 act_layer=nn.GELU,
                 frequency_embedding_size=256,
                 max_period=10000,
                 out_size=None,
                 dtype=None,
                 device=None
                 ):
        super().__init__()
        factory_kwargs = {'dtype': dtype, 'device': device}
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        out_size = hidden_size if out_size is None else out_size

        fc_in = nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs)
        activation = act_layer()
        fc_out = nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs)
        self.mlp = nn.Sequential(fc_in, activation, fc_out)
        # Re-initialise both projection weights with a small normal.
        for proj in (fc_in, fc_out):
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, t):
        # Cast the sinusoidal features to the MLP's weight dtype before projecting.
        mlp_dtype = self.mlp[0].weight.dtype
        t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(mlp_dtype)
        return self.mlp(t_freq)
578
579
class Upsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the convolution; falls back to
                 ``channels`` when not given.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1, **factory_kwargs)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # Keep the first spatial axis, double only the last two.
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target_size, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
611
612
class Downsample(nn.Module):
    """
    2x downsampling via a strided 3x3 convolution or, without it, average pooling.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the convolution; falls back to
                 ``channels``. Must equal ``channels`` when pooling is used.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # In 3D the first spatial axis keeps its resolution.
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=1, **factory_kwargs
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
642
643
class ResBlock(nn.Module):
    """
    Residual block conditioned on an embedding, optionally resampling and/or
    changing the number of channels.

    :param in_channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
            self,
            in_channels,
            emb_channels,
            out_channels=None,
            dropout=0.0,
            use_conv=False,
            dims=2,
            up=False,
            down=False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        factory_kwargs = {'dtype': dtype, 'device': device}
        self.in_channels = in_channels
        self.dropout = dropout
        self.out_channels = out_channels or self.in_channels
        self.use_conv = use_conv

        # norm -> SiLU -> 3x3 conv on the input.
        self.in_layers = nn.Sequential(
            normalization(self.in_channels, **factory_kwargs),
            nn.SiLU(),
            conv_nd(dims, self.in_channels, self.out_channels, 3, padding=1, **factory_kwargs),
        )

        self.updown = up or down

        # Conv-free resamplers, applied to both the residual and the skip path.
        if up:
            self.h_upd = Upsample(self.in_channels, False, dims, **factory_kwargs)
            self.x_upd = Upsample(self.in_channels, False, dims, **factory_kwargs)
        elif down:
            self.h_upd = Downsample(self.in_channels, False, dims, **factory_kwargs)
            self.x_upd = Downsample(self.in_channels, False, dims, **factory_kwargs)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Projects the embedding to a (scale, shift) pair per output channel.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(emb_channels, 2 * self.out_channels, **factory_kwargs)
        )

        # The final conv is zero-initialised.
        self.out_layers = nn.Sequential(
            normalization(self.out_channels, **factory_kwargs),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, **factory_kwargs)
            ),
        )

        if self.out_channels == self.in_channels:
            self.skip_connection = nn.Identity()
        else:
            # 3x3 conv when requested, otherwise a 1x1 conv, to match channel counts.
            if use_conv:
                self.skip_connection = conv_nd(
                    dims, self.in_channels, self.out_channels, 3, padding=1, **factory_kwargs
                )
            else:
                self.skip_connection = conv_nd(dims, self.in_channels, self.out_channels, 1, **factory_kwargs)

    def forward(self, x, emb):
        """Apply the block to ``x`` conditioned on ``emb``; returns skip(x) + residual."""
        if self.updown:
            # Resample between the (norm, act) prefix and the conv, and resample
            # the skip input the same way.
            pre_conv, conv = self.in_layers[:-1], self.in_layers[-1]
            h = conv(self.h_upd(pre_conv(x)))
            x = self.x_upd(x)
        else:
            h = self.in_layers(x)

        emb_out = self.emb_layers(emb)
        # Append singleton axes until the embedding broadcasts against h.
        while emb_out.ndim < h.ndim:
            emb_out = emb_out.unsqueeze(-1)

        # Adaptive Group Normalization
        norm, tail = self.out_layers[0], self.out_layers[1:]
        scale, shift = emb_out.chunk(2, dim=1)
        h = tail(norm(h) * (1. + scale) + shift)

        return self.skip_connection(x) + h
741
742
class UNetDown(nn.Module):
    """
    Turns a latent image into a token sequence: a 3x3 conv followed by ResBlocks.

    patch_size: one of [1, 2 ,4 ,8]
    in_channels: vae latent dim
    hidden_channels: hidden dim for reducing parameters
    out_channels: transformer model dim

    NOTE(review): the number of downsampling ResBlocks is ``patch_size // 2``, so
    ``patch_size=8`` builds four 2x stages -- confirm this is intended.
    """
    def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
                 dropout=0.0, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'dtype': dtype, 'device': device}

        self.patch_size = patch_size
        assert self.patch_size in [1, 2, 4, 8]

        stem = conv_nd(
            2,
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=3,
            padding=1,
            **factory_kwargs
        )
        self.model = nn.ModuleList([stem])

        if self.patch_size == 1:
            # No spatial reduction: a single channel-changing ResBlock.
            self.model.append(ResBlock(
                in_channels=hidden_channels,
                emb_channels=emb_channels,
                out_channels=out_channels,
                dropout=dropout,
                **factory_kwargs
            ))
        else:
            for stage in range(self.patch_size // 2):
                # Only the stage with (stage + 1) * 2 == patch_size projects to out_channels.
                reaches_target = (stage + 1) * 2 == self.patch_size
                self.model.append(ResBlock(
                    in_channels=hidden_channels,
                    emb_channels=emb_channels,
                    out_channels=out_channels if reaches_target else hidden_channels,
                    dropout=dropout,
                    down=True,
                    **factory_kwargs
                ))

    def forward(self, x, t):
        """Return ``(tokens, token_h, token_w)`` with tokens shaped (b, h * w, c)."""
        height, width = x.shape[2], x.shape[3]
        assert height % self.patch_size == 0 and width % self.patch_size == 0
        for layer in self.model:
            # Only ResBlocks take the embedding.
            x = layer(x, t) if isinstance(layer, ResBlock) else layer(x)
        _, _, token_h, token_w = x.shape
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        return tokens, token_h, token_w
798
799
class UNetUp(nn.Module):
    """
    Turns a token sequence back into a latent image: ResBlocks followed by an output conv.

    patch_size: one of [1, 2 ,4 ,8]
    in_channels: transformer model dim
    hidden_channels: hidden dim for reducing parameters
    out_channels: vae latent dim
    out_norm: if True, the output conv is preceded by normalization and SiLU

    NOTE(review): the number of upsampling ResBlocks is ``patch_size // 2``, so
    ``patch_size=8`` builds four 2x stages -- confirm this is intended.
    """
    def __init__(self, patch_size, in_channels, emb_channels, hidden_channels, out_channels,
                 dropout=0.0, device=None, dtype=None, out_norm=False):
        super().__init__()
        factory_kwargs = {'dtype': dtype, 'device': device}

        self.patch_size = patch_size
        assert self.patch_size in [1, 2, 4, 8]

        self.model = nn.ModuleList()

        if self.patch_size == 1:
            # No spatial change: a single channel-changing ResBlock.
            self.model.append(ResBlock(
                in_channels=in_channels,
                emb_channels=emb_channels,
                out_channels=hidden_channels,
                dropout=dropout,
                **factory_kwargs
            ))
        else:
            for stage in range(self.patch_size // 2):
                # Only the first stage sees the transformer width.
                stage_in = hidden_channels if stage else in_channels
                self.model.append(ResBlock(
                    in_channels=stage_in,
                    emb_channels=emb_channels,
                    out_channels=hidden_channels,
                    dropout=dropout,
                    up=True,
                    **factory_kwargs
                ))

        if out_norm:
            head = nn.Sequential(
                normalization(hidden_channels, **factory_kwargs),
                nn.SiLU(),
                conv_nd(
                    2,
                    in_channels=hidden_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    **factory_kwargs
                ),
            )
        else:
            head = conv_nd(
                2,
                in_channels=hidden_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                **factory_kwargs
            )
        self.model.append(head)

    # batch_size, seq_len, model_dim
    def forward(self, x, t, token_h, token_w):
        """Reshape tokens (b, h * w, c) to (b, c, h, w) and run the layer stack."""
        x = rearrange(x, 'b (h w) c -> b c h w', h=token_h, w=token_w)
        for layer in self.model:
            # Only ResBlocks take the embedding.
            x = layer(x, t) if isinstance(layer, ResBlock) else layer(x)
        return x
868
869
870 # =======================================================
871 # Modules for Transformer Backbone
872 # =======================================================
873
@dataclass
class CausalMMOutputWithPast(CausalLMOutputWithPast):
    """``CausalLMOutputWithPast`` extended with an optional ``diffusion_prediction`` tensor."""

    # Defaults to None when not set by the caller.
    diffusion_prediction: Optional[torch.Tensor] = None
877
878
class HunyuanStaticCache(StaticCache):
    """
    A custom static cache for multi-modal models that supports dynamic extension of the cache
    and inplace updates of the cache.

    This cache supports batch cache_position updates.

    Extra keyword argument:
        dynamic (`bool`, defaults to `False`): when True, `update` returns views truncated
            to the last written position + 1 instead of the full buffers.
    """
    def __init__(self, *args, **kwargs):
        # `dynamic` is consumed here so it is not forwarded to StaticCache.
        self.dynamic = kwargs.pop("dynamic", False)
        super().__init__(*args, **kwargs)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `cache_position` entry tells
                where to write in the cache: a 1-D tensor shared by the whole batch, or a 2-D
                tensor with one row of positions per batch element. When `cache_kwargs` is
                `None` or has no `cache_position`, the whole buffer is overwritten.

        Return:
            A tuple containing the updated key and value states.
        """
        # `cache_kwargs` defaults to None, so guard before reading from it.
        cache_position = None if cache_kwargs is None else cache_kwargs.get("cache_position")

        if hasattr(self, "key_cache") and hasattr(self, "value_cache"):
            # Layout with per-layer `key_cache` / `value_cache` lists.
            # NOTE(review): presumably the older transformers cache layout -- confirm.
            if self.key_cache[layer_idx].device != key_states.device:
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
            k_out = self.key_cache[layer_idx]
            v_out = self.value_cache[layer_idx]
            key_states = key_states.to(k_out.dtype)
            value_states = value_states.to(v_out.dtype)
        else:
            # Layout with `layers[i].keys` / `.values`, allocated on first use.
            if self.layers[layer_idx].keys is None:
                self.layers[layer_idx].lazy_initialization(key_states)
            k_out = self.layers[layer_idx].keys
            v_out = self.layers[layer_idx].values

        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
            # operation, that avoids copies and uses less memory.
            if cache_position.dim() == 1:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)

                if self.dynamic:
                    # Expose only the filled prefix of the buffers.
                    end = cache_position[-1].item() + 1
                    k_out = k_out[:, :, :end]
                    v_out = v_out[:, :, :end]
            else:
                assert cache_position.dim() == 2, f"multiple batch dims not yet {cache_position.shape=}"
                batch_size, idx_size = cache_position.shape
                assert batch_size == k_out.size(0)
                assert batch_size == v_out.size(0)
                assert batch_size == key_states.size(0)
                assert batch_size == value_states.size(0)
                for i in range(batch_size):
                    # Per-sample tensors have lost the batch axis, so the sequence axis is 1.
                    unbatched_dim = 1
                    k_out[i].index_copy_(unbatched_dim, cache_position[i], key_states[i])
                    v_out[i].index_copy_(unbatched_dim, cache_position[i], value_states[i])

                if self.dynamic:
                    # Truncation uses a single end index, so only batch size 1 is supported here.
                    assert len(cache_position) == 1
                    end = cache_position[0, -1].item() + 1
                    k_out = k_out[:, :, :end]
                    v_out = v_out[:, :, :end]

        return k_out, v_out
964
965
class HunyuanRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable per-channel scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        HunyuanRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise `hidden_states` by the RMS of its last dimension, then scale by `self.weight`."""
        source_dtype = hidden_states.dtype
        # The statistics are always computed in float32, whatever the input precision is.
        upcast = hidden_states.to(torch.float32)
        mean_square = torch.mean(upcast.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back to the input dtype before applying the learned scale.
        normalized = (upcast * inv_rms).to(source_dtype)
        return self.weight * normalized
981
982
class HunyuanMLP(nn.Module):
    """
    Feed-forward block used as the dense MLP, as the shared expert, or as one routed MoE expert.

    With `hidden_act == "silu"` a single fused projection produces both halves of a gated unit
    (`first_half * act(second_half)`); with `"gelu"` it is a plain two-layer MLP. Any other activation is rejected.

    Args:
        config: Model configuration; reads `hidden_size`, `hidden_act`, `intermediate_size`,
            `moe_intermediate_size`, `num_shared_expert` and `mlp_bias`.
        layer_idx: Index used to pick per-layer values when a config entry is not a plain int.
        is_shared_mlp: Build the shared expert; its width is multiplied by `num_shared_expert`.
        is_moe: Build a routed expert.
    """

    def __init__(self, config: HunyuanImage3Config, layer_idx=None, is_shared_mlp=False, is_moe=False):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.hidden_act = config.hidden_act

        self.intermediate_size = config.intermediate_size
        if is_shared_mlp or is_moe:
            # For MoE layers, prefer `moe_intermediate_size` when it is configured.
            if config.moe_intermediate_size is not None:
                self.intermediate_size = config.moe_intermediate_size \
                    if isinstance(config.moe_intermediate_size, int) else config.moe_intermediate_size[layer_idx]

            if is_shared_mlp:
                num_shared_expert = config.num_shared_expert \
                    if isinstance(config.num_shared_expert, int) else config.num_shared_expert[layer_idx]
                self.intermediate_size *= num_shared_expert

        self.act_fn = ACT2FN[config.hidden_act]
        if self.hidden_act == "silu":
            # The fused projection emits both halves of the gated unit, hence twice the width.
            self.intermediate_size *= 2  # SwiGLU
            self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
            self.down_proj = nn.Linear(self.intermediate_size // 2, self.hidden_size, bias=config.mlp_bias)
        elif self.hidden_act == "gelu":
            self.gate_and_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
            self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        else:
            # Raised explicitly (not via `assert`) so the check also holds when Python runs with -O.
            raise AssertionError("other hidden_act are not supported")

    def forward(self, x):
        """
        Apply the feed-forward block to `x`.

        Args:
            x: Tensor whose last dimension is `hidden_size`; any number of leading dimensions is accepted.

        Returns:
            Tensor with the same leading dimensions as `x` and last dimension `hidden_size`.
        """
        if self.hidden_act == "silu":
            gate_and_up_proj = self.gate_and_up_proj(x)
            # Split along the feature axis; `dim=-1` keeps this valid for 2-D as well as 3-D inputs.
            x1, x2 = gate_and_up_proj.chunk(2, dim=-1)
            down_proj = self.down_proj(x1 * self.act_fn(x2))
            return down_proj
        elif self.hidden_act == "gelu":
            intermediate = self.gate_and_up_proj(x)
            intermediate = self.act_fn(intermediate)
            output = self.down_proj(intermediate)
            return output
        else:
            # Raised explicitly so that this branch can never fall through and return None under -O.
            raise AssertionError("other hidden_act are not supported")
1027
1028
class HunyuanTopKGate(nn.Module):
    """
    Router of a mixture-of-experts layer: a bias-free float32 linear layer scoring every expert per token,
    followed by one of two top-k selection routines.
    """

    def __init__(self, config: HunyuanImage3Config, layer_idx: Optional[int] = None):
        super().__init__()

        def per_layer(setting):
            # Config entries are either one int for all layers or a per-layer sequence.
            return setting if isinstance(setting, int) else setting[layer_idx]

        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = per_layer(config.moe_topk)
        self.drop_tokens = config.moe_drop_tokens
        self.min_capacity = 8
        self.random_routing_dropped_token = config.moe_random_routing_dropped_token
        expert_count = per_layer(config.num_experts)
        self.wg = nn.Linear(config.hidden_size, expert_count, bias=False, dtype=torch.float32)

        # DeepSeek gating args
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob
        self.group_limited_greedy = config.group_limited_greedy

    def forward(self, hidden_states, topk_impl='default'):
        """
        Score the experts for every token and select the top-k.

        Args:
            hidden_states: 3-D tensor `(batch, seq_len, hidden_size)`; it is flattened to 2-D before routing.
            topk_impl: `'default'` delegates to `topkgating`; `'easy'` uses `easy_topk`.

        Returns:
            Whatever the selected routine returns; for `'easy'` that is `(weights, expert_index)`.

        Raises:
            ValueError: If `topk_impl` is neither `'default'` nor `'easy'`.
        """
        # The three-way unpacking doubles as a rank check on the input.
        _, _, feature_dim = hidden_states.shape
        flat_tokens = hidden_states.reshape(-1, feature_dim)
        if self.wg.weight.dtype == torch.float32:
            # Match the float32 router weights.
            flat_tokens = flat_tokens.float()
        logits = self.wg(flat_tokens)

        if topk_impl == 'easy':
            return self.easy_topk(logits, self.moe_topk)
        if topk_impl == 'default':
            return topkgating(
                logits,
                self.moe_topk,
                group_limited_greedy=self.group_limited_greedy,
                n_group=self.n_group,
                topk_group=self.topk_group,
                norm_topk_prob=self.norm_topk_prob,
                routed_scaling_factor=self.routed_scaling_factor,
                capacity_factor=self.config.capacity_factor,
                drop_tokens=self.drop_tokens,
            )
        raise ValueError(f"Unsupported topk_impl: {topk_impl}")

    @staticmethod
    def easy_topk(logits, moe_topk):
        """Softmax over experts, keep the `moe_topk` largest probabilities and renormalise them to sum to 1."""
        probs = F.softmax(logits, dim=1)
        top_probs, expert_index = torch.topk(probs, moe_topk)
        # The clamp guards the division against a vanishing sum.
        normaliser = torch.clamp(top_probs.sum(dim=1, keepdim=True), min=1e-8)
        return top_probs / normaliser, expert_index
1077
1078
class HunyuanMoE(nn.Module):
    """
    Mixture-of-experts feed-forward layer: a router, a list of expert MLPs and, when
    `config.use_mixed_mlp_moe` is set, a shared MLP whose output is added to the routed output.

    Two execution paths are selected by `moe_impl`:
      * `"flashinfer"`: expert weights are stacked once and a fused flashinfer kernel is called;
      * anything else: a dense dispatch/combine via einsum, running every expert module in turn.
    """

    def __init__(self, config: HunyuanImage3Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = config.moe_topk
        self.num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx]
        if config.use_mixed_mlp_moe:
            self.shared_mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
        self.gate = HunyuanTopKGate(config, layer_idx=layer_idx)
        self.experts = nn.ModuleList(
            [HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=True) for _ in range(self.num_experts)]
        )

        # Stored directly, i.e. without the flashinfer availability check done by the `moe_impl` setter.
        self._moe_impl = config.moe_impl
        # For FlashInfer
        self.moe_weight = None
        self.moe_weight_2 = None
        self._weights_initialized = False

    @property
    def moe_impl(self):
        """Name of the active MoE implementation."""
        return self._moe_impl

    @moe_impl.setter
    def moe_impl(self, value):
        self._moe_impl = value
        if self._moe_impl == "flashinfer":
            assert flashinfer is not None, "When using fused_moe, flashinfer must be installed."

    def forward(self, hidden_states):
        """
        Route every token through its selected experts.

        Args:
            hidden_states: Tensor of shape `(batch, seq_len, hidden_size)`.

        Returns:
            Tensor of the same shape: the routed expert output, plus the shared-MLP output when enabled.

        Raises:
            RuntimeError: If the flashinfer path is selected but flashinfer could not be imported.
        """
        # A non-CUDA tensor has `device.index` None, which `torch.cuda.set_device` does not accept,
        # so the current device is only switched for CUDA inputs.
        if hidden_states.device.type == "cuda":
            torch.cuda.set_device(hidden_states.device.index)
        bsz, seq_len, hidden_size = hidden_states.shape

        if self.config.use_mixed_mlp_moe:
            hidden_states_mlp = self.shared_mlp(hidden_states)

        reshaped_input = hidden_states.reshape(-1, hidden_size)  # [bsz*seq_len, hidden_size]

        with nvtx.range("MoE"):
            if self._moe_impl == "flashinfer":
                # `__init__` bypasses the setter's check, so fail clearly here instead of on `None.fused_moe`.
                if flashinfer is None:
                    raise RuntimeError("When using fused_moe, flashinfer must be installed.")
                # Get expert weights
                if not self._weights_initialized:
                    self._initialize_weights_on_device(hidden_states.device)
                topk_weight, topk_index = self.gate(hidden_states, topk_impl='easy')

                # The kernel writes into `output`; its return value is not needed.
                combined_output = torch.zeros_like(reshaped_input)
                _ = flashinfer.fused_moe.cutlass_fused_moe(  # noqa
                    reshaped_input.contiguous(),
                    topk_index.to(torch.int).contiguous(),
                    topk_weight.to(torch.float).contiguous(),
                    self.moe_weight,
                    self.moe_weight_2,
                    torch.bfloat16,
                    output=combined_output,
                    quant_scales=None,
                )
            else:
                # Original implementation - fallback for compatibility
                _, combine_weights, dispatch_mask, _ = self.gate(hidden_states, topk_impl='default')
                # s: token, e: expert, c: capacity slot, m: model dim.
                dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
                chunks = dispatched_input.chunk(self.num_experts, dim=0)
                expert_outputs = [expert(chunk) for chunk, expert in zip(chunks, self.experts)]

                expert_output = torch.cat(expert_outputs, dim=0)
                combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)

        combined_output = combined_output.reshape(bsz, seq_len, hidden_size)

        if self.config.use_mixed_mlp_moe:
            output = hidden_states_mlp + combined_output  # noqa
        else:
            output = combined_output

        return output

    def _initialize_weights_on_device(self, device):
        """
        Stack the expert weights into the two tensors consumed by the fused kernel, then release the
        per-expert parameters.

        Only the weight matrices are stacked; expert biases, if present, are not handed to the fused kernel.
        After this call the expert modules hold empty tensors and can no longer be run individually.
        """
        expert_weights_gate_up = []
        expert_weights_down = []

        for expert in self.experts:
            expert.to(device)
            expert_weights_gate_up.append(expert.gate_and_up_proj.weight.to(device))
            expert_weights_down.append(expert.down_proj.weight.to(device))

        self.moe_weight = torch.stack(expert_weights_gate_up).contiguous()
        self.moe_weight_2 = torch.stack(expert_weights_down).contiguous()
        # empty the expert weights
        for expert in self.experts:
            expert.gate_and_up_proj.weight.data = torch.empty(0, device=device)
            if expert.gate_and_up_proj.bias is not None:
                expert.gate_and_up_proj.bias.data = torch.empty(0, device=device)
            expert.down_proj.weight.data = torch.empty(0, device=device)
            if expert.down_proj.bias is not None:
                expert.down_proj.bias.data = torch.empty(0, device=device)

        self._weights_initialized = True
1178
1179
class HunyuanImage3SDPAAttention(nn.Module):
    """PyTorch SDPA attention implementation using torch.nn.functional.scaled_dot_product_attention"""

    def __init__(self, config: HunyuanImage3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_type = 'self'

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # The head size comes from the config rather than from `hidden_size // num_heads`.
        self.head_dim = config.attention_head_dim
        self.num_key_value_heads = config.num_key_value_heads if config.num_key_value_heads else self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.use_qk_norm = config.use_qk_norm
        self.use_rotary_pos_emb = config.use_rotary_pos_emb
        self.hidden_size_q = self.head_dim * self.num_heads
        self.hidden_size_kv = self.head_dim * self.num_key_value_heads

        # define layers
        # One fused projection produces queries, keys and values.
        self.qkv_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size_q + 2 * self.hidden_size_kv,
            bias=config.attention_bias
        )
        self.o_proj = nn.Linear(self.hidden_size_q, self.hidden_size, bias=config.attention_bias)

        if self.use_qk_norm:
            self.query_layernorm = HunyuanRMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = HunyuanRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        if self.use_rotary_pos_emb:
            self._init_rope()

    def _init_rope(self):
        """Validate the configured RoPE scaling type; only `"custom"` is accepted."""
        scaling_type = self.config.rope_scaling["type"]
        if scaling_type == "custom":
            # Using custom rotary embedding
            # cos/sin are supplied to `forward` through `custom_pos_emb`, so no module is created here.
            self.rotary_emb = None
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `(bsz, seq_len, num_heads * head_dim)` to a contiguous `(bsz, num_heads, seq_len, head_dim)`."""
        return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: Optional[bool] = False,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input of shape `(batch, q_len, hidden_size)`.
            attention_mask: Passed to SDPA as `attn_mask`.
            position_ids: Forwarded to the cache as `cache_position` when `past_key_value` is given.
            past_key_value: Cache object; its `update` method returns the keys/values to attend over.
            output_attentions: Must be falsy; attention weights cannot be returned by this implementation.
            use_cache: Accepted for interface compatibility; not read here.
            custom_pos_emb: `(cos, sin)` pair, required when rotary embeddings are enabled.
            **kwargs: Ignored by this implementation.

        Returns:
            `(attn_output, None, past_key_value)` where `attn_output` has shape `(batch, q_len, hidden_size)`.

        Raises:
            NotImplementedError: If `output_attentions` is truthy.
        """
        if output_attentions:
            raise NotImplementedError(
                'HunyuanImage3Model is using HunyuanImage3SDPAAttention, '
                'but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`.'
            )

        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)
        # The fused output is grouped per key/value head: `num_key_value_groups` query slots, then one key
        # slot and one value slot.
        qkv_states = qkv_states.reshape(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups + 2,
                                        self.head_dim)
        query_states, key_states, value_states = torch.split(qkv_states, [self.num_key_value_groups, 1, 1], dim=3)

        query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.use_rotary_pos_emb:
            cos, sin = custom_pos_emb
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if self.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)

        # Bring queries and keys back to the dtype of the values before attention.
        query_states = query_states.to(value_states.dtype)
        key_states = key_states.to(value_states.dtype)

        if past_key_value is not None:
            cache_kwargs = {"cache_position": position_ids}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            # The cache may hand back tensors in its own dtype.
            query_states = query_states.to(key_states.dtype)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
        # custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # Dropout is fixed to 0.0 here; `self.attention_dropout` is not applied.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0
        )
        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
1295
1296
class HunyuanImage3FlashAttention2(HunyuanImage3SDPAAttention):
    """
    Attention variant that calls `flash_attn_func` instead of SDPA.

    No explicit mask is passed to the kernel. Causality is chosen per call from `mode`, `first_step` and
    `gen_timestep_scatter_index`, and `attention_mask` is only used as a prefill/decode indicator.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: Optional[bool] = False,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """
        Run self-attention over `hidden_states` with flash attention.

        Args:
            hidden_states: Input of shape `(batch, q_len, hidden_size)`.
            attention_mask: In `"gen_text"` mode, `None` selects non-causal (decode) attention and any
                other value selects causal (prefill) attention. The mask values themselves are not used.
            position_ids: Forwarded to the cache as `cache_position` when `past_key_value` is given.
            past_key_value: Cache object; its `update` method returns the keys/values to attend over.
            output_attentions: When truthy, the call is delegated to the parent implementation.
            use_cache: Accepted for interface compatibility; not read here.
            custom_pos_emb: `(cos, sin)` pair, required when rotary embeddings are enabled.
            **kwargs: `mode` (default `"gen_text"`); for any other mode `gen_timestep_scatter_index` and
                `first_step` are required.

        Returns:
            `(attn_output, None, past_key_value)` where `attn_output` has shape `(batch, q_len, hidden_size)`.

        Raises:
            ValueError: In image mode when `first_step` is not provided.
        """
        if output_attentions:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
            )

        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)
        # The fused output is grouped per key/value head: `num_key_value_groups` query slots, then one key
        # slot and one value slot.
        qkv_states = qkv_states.reshape(bsz, q_len, self.num_key_value_heads, self.num_key_value_groups + 2,
                                        self.head_dim)
        query_states, key_states, value_states = torch.split(qkv_states, [self.num_key_value_groups, 1, 1], dim=3)

        query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.use_rotary_pos_emb:
            cos, sin = custom_pos_emb
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if self.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)

        query_states = query_states.to(value_states.dtype)
        key_states = key_states.to(value_states.dtype)

        if past_key_value is not None:
            cache_kwargs = {"cache_position": position_ids}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # The kernel is fed half-precision tensors; anything else is cast to bfloat16.
        target_dtype = key_states.dtype if key_states.dtype in [torch.bfloat16, torch.float16] else torch.bfloat16

        # Move to the (batch, seq, heads, head_dim) layout. The single `.contiguous()` here makes any
        # earlier contiguity pass on the (batch, heads, seq, head_dim) tensors redundant.
        q_fa = query_states.to(target_dtype).transpose(1, 2).contiguous()
        k_fa = key_states.to(target_dtype).transpose(1, 2).contiguous()
        v_fa = value_states.to(target_dtype).transpose(1, 2).contiguous()

        mode = kwargs.get("mode", "gen_text")
        # For gen_text and gen_image, we need to handle the attention differently
        with nvtx.range("attention"):
            if mode == "gen_text":
                if attention_mask is None:
                    attn_output = flash_attn_func(q_fa, k_fa, v_fa, causal=False)  # decode attention
                else:
                    attn_output = flash_attn_func(q_fa, k_fa, v_fa, causal=True)  # prefill attention
            else:  # image attention
                gen_timestep_scatter_index: Optional[torch.Tensor] = kwargs.get("gen_timestep_scatter_index", None)
                assert gen_timestep_scatter_index is not None, \
                    "When gen_image, `gen_timestep_scatter_index` must be provided."
                # TODO: batchify
                timestep_index = gen_timestep_scatter_index[0, 0].item()
                # When image generation, different attention implementations for the first step and the following steps
                # help to improve the inference speed.
                first_step = kwargs.get("first_step", None)
                if first_step is None:
                    raise ValueError("When gen_image, `first_step` must be provided.")
                # Everything up to and including the timestep token is attended causally.
                causal_len = timestep_index + 1
                if first_step:
                    # Full sequence present: causal prefix, then image queries over all keys.
                    text_attn_output = flash_attn_func(
                        q_fa[:, :causal_len, :, :],
                        k_fa[:, :causal_len, :, :],
                        v_fa[:, :causal_len, :, :],
                        causal=True)
                    image_query_states = q_fa[:, causal_len:, :, :]
                    image_attn_output = flash_attn_func(image_query_states, k_fa, v_fa, causal=False)
                    attn_output = torch.cat((text_attn_output, image_attn_output), dim=1)
                else:
                    # Later steps: the first query is the timestep token, the remaining queries are image tokens.
                    timestep_attn_output = flash_attn_func(
                        q_fa[:, 0:1, :, :],
                        k_fa[:, :causal_len, :, :],
                        v_fa[:, :causal_len, :, :],
                        causal=True)
                    image_query_states = q_fa[:, 1:, :, :]
                    image_attn_output = flash_attn_func(image_query_states, k_fa, v_fa, causal=False)
                    attn_output = torch.cat((timestep_attn_output, image_attn_output), dim=1)

        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
1407
1408
# Maps `config._attn_implementation` to the attention module class built by each decoder layer.
# "eager" and "sdpa" both resolve to the SDPA implementation.
Hunyuan_ATTENTION_CLASSES = {
    "eager": HunyuanImage3SDPAAttention,
    "sdpa": HunyuanImage3SDPAAttention,
    "flash_attention_2": HunyuanImage3FlashAttention2,
}
1414
1415
class HunyuanImage3DecoderLayer(nn.Module):
    """
    One pre-norm transformer block: `x + attn(norm(x))` followed by `x + mlp(norm(x))`.

    The feed-forward part is a `HunyuanMoE` when the config declares more than one expert and
    `layer_idx >= config.moe_layer_num_skipped`; otherwise it is a dense `HunyuanMLP`.
    """

    def __init__(self, config: HunyuanImage3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        attn_impl = config._attn_implementation  # noqa
        if attn_impl in Hunyuan_ATTENTION_CLASSES:
            self.self_attn = Hunyuan_ATTENTION_CLASSES[attn_impl](config=config, layer_idx=layer_idx)
        else:
            raise ValueError(f"Unsupported attention implementation: {attn_impl}")

        # `num_experts` is either one int for all layers or a per-layer list.
        num_experts = config.num_experts
        has_experts = (
            (isinstance(num_experts, int) and num_experts > 1)
            or (isinstance(num_experts, list) and max(num_experts) > 1)
        )
        if has_experts and layer_idx >= config.moe_layer_num_skipped:
            self.mlp = HunyuanMoE(config, layer_idx=layer_idx)
        else:
            self.mlp = HunyuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False, is_moe=False)

        if config.norm_type in ('hf_rms', 'rms'):
            self.input_layernorm = HunyuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.post_attention_layernorm = HunyuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        elif config.norm_type in ('fused', 'torch_nn'):
            self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            # Raised explicitly (not via `assert`) so the layer is never left half-built under -O.
            raise AssertionError("other norm_type are not supported")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        **kwargs,
    ) -> Tuple[Any, ...]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*):
                Indices of positions of each input sequence tokens in the position embeddings.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            custom_pos_emb (`Tuple[torch.FloatTensor]`, *optional*): custom position embedding for rotary
                position embedding
            kwargs: forwarded unchanged to the attention module.

        Returns:
            A tuple `(hidden_states,)`, extended with the attention weights when `output_attentions` is set and
            with the cache object when `use_cache` is set, in that order.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use "
                "`attention_mask` instead.`"
            )

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            custom_pos_emb=custom_pos_emb,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
1510
1511
@add_start_docstrings(
    "The bare Hunyuan Image 3 Model outputting raw hidden-states without any specific head on top.",
    Hunyuan_START_DOCSTRING,
)
class HunyuanImage3PreTrainedModel(PreTrainedModel):
    # Class-level declarations read by the `PreTrainedModel` machinery.
    config_class = HunyuanImage3Config
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    _no_split_modules = ["HunyuanImage3DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        # Linear and embedding weights are drawn from N(0, initializer_range); every other module is untouched.
        init_std = self.config.initializer_range
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=init_std)
        if isinstance(module, nn.Linear):
            # Linear biases start at zero.
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            # The embedding row of the padding token is zeroed.
            module.weight.data[module.padding_idx].zero_()
1536
1537
# Argument documentation attached to `HunyuanImage3Model.forward` via `add_start_docstrings_to_model_forward`.
# NOTE(review): the `attention_mask` entry ends with two bullets about heads being masked, which look like
# leftovers from a `head_mask` entry, and the extra `forward` arguments (`custom_pos_emb`, `mode`,
# `first_step`, `gen_timestep_scatter_index`) are not described here - confirm and update the text.
Hunyuan_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
1606
1607
@add_start_docstrings(
    "The bare Hunyuan Model outputting raw hidden-states without any specific head on top.",
    Hunyuan_START_DOCSTRING,
)
class HunyuanImage3Model(HunyuanImage3PreTrainedModel):
    def __init__(self, config: HunyuanImage3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.add_classification_head = config.add_classification_head
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_stack = [HunyuanImage3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_stack)
        # The final norm only exists when no classification head is configured.
        if not config.add_classification_head:
            self.ln_f = HunyuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

        self.shared_tensor = None

    @add_start_docstrings_to_model_forward(Hunyuan_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
        mode: str = "gen_text",
        first_step: Optional[bool] = None,
        gen_timestep_scatter_index: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Flags left as None fall back to the model configuration.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Token ids are embedded only when no embeddings were supplied by the caller.
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = inputs_embeds

        # Per-layer outputs are collected only on request.
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None
        # Each layer returns (hidden_states[, attn_weights][, cache]).
        cache_slot = 2 if output_attentions else 1

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # The same cache object is handed to every layer.
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                custom_pos_emb=custom_pos_emb,
                mode=mode,
                first_step=first_step,
                gen_timestep_scatter_index=gen_timestep_scatter_index,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[cache_slot]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # `ln_f` is intentionally NOT applied here: it is done outside of the model for compatibility
        # with image generation.

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
            )
        # Tuple form: entries that are None are dropped.
        return tuple(v for v in (hidden_states, next_cache, all_hidden_states, all_self_attns) if v is not None)
1712
1713
1714 class HunyuanImage3ForCausalMM(HunyuanImage3PreTrainedModel, GenerationMixin):
def __init__(self, config: HunyuanImage3Config):
    """Create every sub-module of the multimodal model from `config`.

    Raises:
        ValueError: if `config.img_proj_type` is anything other than "unet".
    """
    super().__init__(config)
    self.config = config
    # Set later through `load_tokenizer`.
    self._tkwrapper: Optional[TokenizerWrapper] = None

    # Image preprocessor (for conditional images)
    self.image_processor = HunyuanImage3ImageProcessor(config)

    # VAE; the gen_image pipeline is created lazily by the `pipeline` property
    self.vae = AutoencoderKLConv3D.from_config(config.vae)
    self._pipeline = None

    # ViT encoder and the projector applied to its output
    self.vision_model = Siglip2VisionTransformer(config.vit)
    self.vision_aligner = LightProjector(config.vit_aligner)

    # Image generation related modules
    self.timestep_emb = TimestepEmbedder(hidden_size=config.hidden_size)
    if config.img_proj_type != "unet":
        raise ValueError(f"Unknown img_proj_type {config.img_proj_type}")

    hidden_size = config.hidden_size
    latent_channels = config.vae["latent_channels"]
    shared_unet_kwargs = dict(
        patch_size=config.patch_size,
        emb_channels=hidden_size,
        hidden_channels=config.patch_embed_hidden_dim,
    )
    # latent -> token embeddings
    self.patch_embed = UNetDown(
        in_channels=latent_channels,
        out_channels=hidden_size,
        **shared_unet_kwargs,
    )
    self.time_embed = TimestepEmbedder(hidden_size=hidden_size)

    # token embeddings -> latent
    self.final_layer = UNetUp(
        in_channels=hidden_size,
        out_channels=latent_channels,
        out_norm=True,
        **shared_unet_kwargs,
    )
    self.time_embed_2 = TimestepEmbedder(hidden_size=hidden_size)

    # transformer backbone
    self.model = HunyuanImage3Model(config)

    self.pad_id = config.pad_id
    self.vocab_size = config.vocab_size

    # linear head producing vocabulary logits
    self.lm_head = nn.Linear(hidden_size, config.vocab_size, bias=False)

    # Initialize weights and apply final processing
    self.post_init()
1766
@property
def tokenizer(self):
    """Return the tokenizer wrapper; raise ValueError if none has been set yet."""
    wrapper = self._tkwrapper
    if wrapper is not None:
        return wrapper
    raise ValueError("Attribute `tokenizer` has not been initialized yet. Please set it first.")
1772
def load_tokenizer(self, tokenizer):
    """Wrap `tokenizer` in a `TokenizerWrapper` and store it for the `tokenizer` property."""
    wrapped = TokenizerWrapper(tokenizer)
    self._tkwrapper = wrapped
1775
@property
def pipeline(self):
    """Return the text-to-image pipeline, building it (and `self.scheduler`) on first access."""
    if self._pipeline is not None:
        return self._pipeline

    # First access: create the scheduler, keep it on the model, then build the pipeline.
    self.scheduler = FlowMatchDiscreteScheduler(
        shift=self.generation_config.flow_shift,
        reverse=True,
        solver="euler",
    )
    self._pipeline = HunyuanImage3Text2ImagePipeline(
        model=self,
        scheduler=self.scheduler,
        vae=self.vae,
    )
    return self._pipeline
1786
@staticmethod
def get_pos_emb(custom_pos_emb, position_ids):
    """Select the (cos, sin) entries addressed by `position_ids`.

    `custom_pos_emb` is a `(cos, sin)` pair; both tables are indexed along dim 1
    with `real_batched_index_select` and returned as a `(cos, sin)` tuple.
    """
    cos_table, sin_table = custom_pos_emb
    selected_cos = real_batched_index_select(cos_table, dim=1, idx=position_ids)
    selected_sin = real_batched_index_select(sin_table, dim=1, idx=position_ids)
    return selected_cos, selected_sin
1793
def instantiate_vae_image_tokens(
    self,
    x: torch.Tensor,
    images: BatchRaggedImages,
    ts: BatchRaggedTensor,
    image_mask: torch.Tensor,
):
    """
    Instantiate the VAE image embeddings into the input embedding sequence.

    The positions where `image_mask` is set are overwritten IN PLACE in `x` with the
    patch embeddings of `images`.

    Args:
        x: input sequence, (batch_size, seq_len, n_embd)
        images: BatchRaggedImages
            images can be a 4-D tensor, or a list of 4-D tensors, or a list of lists of 3-D tensors.
        ts: BatchRaggedTensor
            ts can be a 1-D tensor, or a list of 1-D tensors
        image_mask: (batch_size, seq_len)

    Returns:
        `(x, token_h, token_w)`. `token_h` / `token_w` are the values returned by
        `self.patch_embed` when `images` is a single tensor, and `None` when `images`
        is a list (ragged batch).

    Raises:
        TypeError: if an element of a list `images` is neither a tensor nor a list.
        ValueError: if an image inside a nested list is not 3-D or 4-D.
    """
    batch_size, seq_len, n_embd = x.shape
    # index[b, s] == s; masked_select on it yields the sequence positions of image tokens.
    index = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1)

    if not isinstance(images, list):
        # images is a 4-D tensor: the whole batch is embedded and scattered at once.
        t_emb = self.time_embed(ts)
        image_seq, token_h, token_w = self.patch_embed(images, t_emb)
        image_scatter_index = index.masked_select(image_mask.bool()).reshape(batch_size, -1)
        x.scatter_(
            dim=1,
            index=image_scatter_index.unsqueeze(-1).repeat(1, 1, n_embd),
            src=image_seq,
        )
        return x, token_h, token_w

    # Ragged batch: every sample is embedded and scattered on its own.
    for i, (image_i, t_i) in enumerate(zip(images, ts)):
        if isinstance(image_i, torch.Tensor):
            # time_embed needs a 1-D tensor as input
            t_i_emb = self.time_embed(t_i)
            # n_{i} x one_image_seq_len x n_embd
            image_i_seq, _, _ = self.patch_embed(image_i, t_i_emb)
        elif isinstance(image_i, list):
            # time_embed needs a 1-D tensor as input
            t_i_emb = self.time_embed(t_i)  # n_{i} x d
            # Must be a list: the previous `[], []` built a tuple and `.append` failed.
            image_i_seq_list = []
            for j, image_ij in enumerate(image_i):
                if image_ij.dim() == 4:
                    assert image_ij.shape[0] == 1, "image_i[j] should have a batch dimension of 1"
                elif image_ij.dim() == 3:
                    image_ij = image_ij.unsqueeze(0)
                else:
                    raise ValueError(f"image_i[j] should have 3 or 4 dimensions, got {image_ij.dim()}")
                # 1 x one_image_seq_len_{j} x n_embd
                image_i_seq_j, _, _ = self.patch_embed(image_ij, t_i_emb[j:j + 1])
                image_i_seq_list.append(image_i_seq_j)
            # 1 x sum_{j}(one_image_seq_len_{j}) x n_embd
            image_i_seq = torch.cat(image_i_seq_list, dim=1)
        else:
            raise TypeError(f"image_i should be a torch.Tensor or a list, got {type(image_i)}")

        # 1 x (number of image tokens of sample i)
        image_i_scatter_index = index[i:i + 1].masked_select(image_mask[i:i + 1].bool()).reshape(1, -1)
        # x[i:i + 1] is a view, so the in-place scatter writes into x itself.
        x[i:i + 1].scatter_(
            dim=1,
            index=image_i_scatter_index.unsqueeze(-1).repeat(1, 1, n_embd),
            src=image_i_seq.reshape(1, -1, n_embd),
        )

    return x, None, None
1875
def instantiate_timestep_tokens(
    self,
    x: torch.Tensor,
    t: BatchRaggedTensor,
    timestep_scatter_index: BatchRaggedTensor,
):
    """Write timestep embeddings into `x` (in place) at `timestep_scatter_index`.

    `t` is flattened, embedded with `self.timestep_emb`, reshaped to
    (batch_size, -1, n_embd) and scattered along the sequence dimension of `x`.
    Returns the same `x` tensor.
    """
    batch_size, _, n_embd = x.shape
    flat_embeddings = self.timestep_emb(t.reshape(-1))
    # batch_size x n x n_embd
    scatter_src = flat_embeddings.reshape(batch_size, -1, n_embd)
    # Repeat each target position across the embedding dimension.
    scatter_index = timestep_scatter_index.unsqueeze(-1).repeat(1, 1, n_embd)
    x.scatter_(dim=1, index=scatter_index, src=scatter_src)
    return x
1892
def instantiate_vit_image_tokens(
    self,
    x: torch.Tensor,
    cond_vit_images: Union[torch.Tensor, List[torch.Tensor]],
    cond_vit_image_mask: torch.Tensor,
    vit_kwargs: Dict[str, Any],
):
    """Encode the conditional images with the ViT and write the result into `x` in place.

    Each entry of `cond_vit_images` is run through `self.vision_model` (with the
    matching per-sample entries of `vit_kwargs`) and `self.vision_aligner`; the
    resulting embeddings replace the positions of `x` selected by the corresponding
    row of `cond_vit_image_mask`. Returns the same `x` tensor.
    """
    # 1. Forward the vit encoder and vit aligner, then flatten all images of a sample
    #    into a single (tokens, dim) matrix.
    flattened_embeds = []
    for sample_idx, sample_images in enumerate(cond_vit_images):
        sample_kwargs = {}
        for key, per_sample_values in vit_kwargs.items():
            sample_kwargs[key] = per_sample_values[sample_idx]
        encoded = self.vision_model(sample_images, **sample_kwargs).last_hidden_state
        aligned = self.vision_aligner(encoded)
        num_images, tokens_per_image, width = aligned.shape
        flattened_embeds.append(aligned.reshape(num_images * tokens_per_image, width))

    # 2. Instantiate the vit image embeddings into the input sequence
    batch_size, seq_len, n_embd = x.shape
    positions = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1)

    for row, (embed, mask) in enumerate(zip(flattened_embeds, cond_vit_image_mask)):
        row_positions = positions[row:row + 1]
        target_positions = row_positions.masked_select(mask.bool()).reshape(1, -1)
        # x[row:row + 1] is a view, so the scatter updates x itself.
        x[row:row + 1].scatter_(
            dim=1,
            index=target_positions.unsqueeze(-1).repeat(1, 1, n_embd),
            src=embed.reshape(1, -1, n_embd),
        )

    return x
1924
def ragged_final_layer(self, x, image_mask, timestep, token_h, token_w, first_step):
    """Project the image-token hidden states with `self.final_layer`.

    On the first step the image tokens are picked out of `x` with `image_mask`;
    on later steps everything after the first position of `x` is used. The
    timestep is embedded with `self.time_embed_2` and passed along.
    """
    bsz, _, n_embd = x.shape
    if not first_step:
        # Drop the leading position and keep the rest.
        image_output = x[:, 1:, :]
    else:
        keep = image_mask.unsqueeze(-1).bool()
        image_output = x.masked_select(keep).reshape(bsz, -1, n_embd)
    timestep_emb = self.time_embed_2(timestep)
    return self.final_layer(image_output, timestep_emb, token_h, token_w)
1934
@staticmethod
def _check_inputs(cond, target, check_list):
    """Assert that every `(name, value)` in `check_list` is not None when `cond` holds.

    `target` is only used to word the assertion message.
    """
    if not cond:
        return
    for name, item in check_list:
        assert item is not None, f"`{name}` should be provided when `{target}`."
1940
@add_start_docstrings_to_model_forward(Hunyuan_INPUTS_DOCSTRING)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    custom_pos_emb: Optional[Tuple[torch.FloatTensor]] = None,
    mode: str = "gen_text",
    first_step: Optional[bool] = None,
    # for gen image
    images: Optional[BatchRaggedImages] = None,
    image_mask: Optional[torch.Tensor] = None,
    timestep: Optional[BatchRaggedTensor] = None,
    gen_timestep_scatter_index: Optional[torch.Tensor] = None,
    # for cond image
    cond_vae_images: Optional[BatchRaggedImages] = None,
    cond_timestep: Optional[BatchRaggedTensor] = None,
    cond_vae_image_mask: Optional[torch.Tensor] = None,
    cond_vit_images: Optional[BatchRaggedImages] = None,
    cond_vit_image_mask: Optional[torch.Tensor] = None,
    vit_kwargs: Optional[Dict[str, Any]] = None,
    cond_timestep_scatter_index: Optional[torch.Tensor] = None,
) -> Union[Tuple, CausalMMOutputWithPast]:
    # Overview:
    #   1. embed `input_ids` with the backbone's token embedding,
    #   2. overwrite placeholder positions with image / timestep embeddings
    #      (generated image, then conditional VAE and ViT images),
    #   3. run the transformer backbone,
    #   4. produce vocabulary logits (mode == "gen_text") or a diffusion prediction
    #      through `ragged_final_layer` (any other mode).
    # Returns a `CausalMMOutputWithPast`, or the tuple
    # `(logits, *backbone_outputs[1:], diffusion_prediction)` when `return_dict` is falsy.
    #
    # NOTE(review): `get_pos_emb` unpacks `custom_pos_emb` into (cos, sin), so although
    # the parameter defaults to None a real pair has to be supplied.
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    # Sanity Check of Inputs
    self._check_inputs(mode == "gen_image", "in `gen_image` mode", [
        ("images", images), ("timestep", timestep), ("gen_timestep_scatter_index", gen_timestep_scatter_index),
    ])
    self._check_inputs(mode == "gen_image" and first_step, "in `gen_image` mode at the first step", [
        ("image_mask", image_mask),
    ])
    self._check_inputs(cond_vae_images is not None, "`cond_vae_images` is provided", [
        ("cond_timestep", cond_timestep), ("cond_vae_image_mask", cond_vae_image_mask),
        ("cond_timestep_scatter_index", cond_timestep_scatter_index),
    ])
    self._check_inputs(cond_vit_images is not None, "`cond_vit_images` is provided", [
        ("cond_vit_image_mask", cond_vit_image_mask), ("vit_kwargs", vit_kwargs),
    ])

    # Keep only the (cos, sin) entries addressed by `position_ids`.
    custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)

    inputs_embeds = self.model.wte(input_ids)
    bsz, seq_len, n_embd = inputs_embeds.shape

    # Instantiate placeholder tokens: <timestep>, <img> for the gen image
    if mode == "gen_text":
        # For gen_text, make sure gen_timestep_scatter_index is None
        gen_timestep_scatter_index = None
        token_h, token_w = None, None
    else:
        if first_step:
            # First step: the full sequence is present; image and timestep embeddings
            # are scattered in place into the token embeddings.
            inputs_embeds, token_h, token_w = self.instantiate_vae_image_tokens(
                inputs_embeds, images, timestep, image_mask)
            inputs_embeds = self.instantiate_timestep_tokens(
                inputs_embeds, timestep, gen_timestep_scatter_index)
        else:
            # Later steps: the token embeddings computed above are discarded and the
            # input becomes [timestep embedding, image embeddings].
            t_emb = self.time_embed(timestep)
            image_emb, token_h, token_w = self.patch_embed(images, t_emb)
            timestep_emb = self.timestep_emb(timestep).reshape(bsz, -1, n_embd)
            inputs_embeds = torch.cat([timestep_emb, image_emb], dim=1)

    # Instantiate placeholder tokens: <timestep>, <img> for cond images
    # Should only run once with kv-cache enabled.
    if cond_vae_images is not None:
        inputs_embeds, _, _ = self.instantiate_vae_image_tokens(
            inputs_embeds, cond_vae_images, cond_timestep, cond_vae_image_mask)
        inputs_embeds = self.instantiate_timestep_tokens(
            inputs_embeds, cond_timestep, cond_timestep_scatter_index)
    if cond_vit_images is not None:
        inputs_embeds = self.instantiate_vit_image_tokens(
            inputs_embeds, cond_vit_images, cond_vit_image_mask, vit_kwargs)

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    # `inputs_embeds` is always passed, so the backbone does not re-embed `input_ids`.
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        custom_pos_emb=custom_pos_emb,
        mode=mode,
        first_step=first_step,
        gen_timestep_scatter_index=gen_timestep_scatter_index,
    )
    hidden_states = outputs[0]

    if mode == "gen_text":
        # The backbone leaves its output un-normalized; apply the final norm here
        # before the language-model head.
        hidden_states = self.model.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)
        logits = logits.float()
        diffusion_prediction = None
    else:
        logits = None
        # Move the hidden states onto the device of `input_ids` before the final layer.
        hidden_states = hidden_states.to(input_ids.device)
        diffusion_prediction = self.ragged_final_layer(
            hidden_states, image_mask, timestep, token_h, token_w, first_step)

    if not return_dict:
        # Tuple layout: logits first, the backbone extras, diffusion prediction last.
        output = (logits,) + outputs[1:] + (diffusion_prediction,)
        return output

    output = CausalMMOutputWithPast(
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        diffusion_prediction=diffusion_prediction,
    )

    return output
2060
@staticmethod
def check_inputs(prompt=None, message_list=None):
    """Validate that exactly one of `prompt` / `message_list` is given and well-formed.

    Raises:
        ValueError: if neither or both arguments are given, or `message_list` is not a list.
        AssertionError: if `prompt` is not a str / non-empty list of str, if
            `message_list` is empty, or if one of its entries is not a list or dict.
    """
    has_prompt = prompt is not None
    has_messages = message_list is not None

    if not (has_prompt or has_messages):
        raise ValueError("Either `prompt` or `message_list` should be provided.")
    if has_prompt and has_messages:
        raise ValueError("Only one of `prompt` or `message_list` should be provided.")

    if has_prompt:
        assert isinstance(prompt, (str, list)), \
            f"`prompt` should be a string or a list of strings, but got {type(prompt)}."
        if isinstance(prompt, list):
            assert len(prompt) > 0 and all(isinstance(p, str) for p in prompt), \
                "`prompt` should be a non-empty list of strings."
        return

    # Only `message_list` was given.
    if not isinstance(message_list, list):
        raise ValueError(f"`message_list` should be a list of messages, but got {type(message_list)}.")
    assert len(message_list) > 0, "`message_list` should be a non-empty list."
    for message in message_list:
        assert isinstance(message, (list, dict)), \
            f"Each message should be a list of dicts or a dict, but got {type(message)}."
2080
@staticmethod
def prepare_seed(seed, batch_size):
    """Return a list of `batch_size` integer seeds.

    `seed` may be None (random seeds in [0, 10_000_000] are drawn), a single int
    (repeated for every sample), a list/tuple of length `batch_size`, or a tensor
    (converted with `tolist()` first).

    Raises:
        ValueError: on a length mismatch or an unsupported `seed` type.
    """
    if isinstance(seed, torch.Tensor):
        seed = seed.tolist()

    if seed is None:
        return [random.randint(0, 10_000_000) for _ in range(batch_size)]
    if isinstance(seed, int):
        return [seed] * batch_size
    if isinstance(seed, (list, tuple)):
        if len(seed) != batch_size:
            raise ValueError(f"Length of seed must be equal to the batch_size({batch_size}), got {seed}.")
        return [int(value) for value in seed]
    raise ValueError(f"Seed must be an integer, a list of integers, or None, got {seed}.")
2097
@staticmethod
def build_batch_rope_image_info(output, sections):
    """Pair every image slice of every sample with its (token_height, token_width).

    Args:
        output: object exposing `all_image_slices`, one list of slices per sample.
        sections: one list of section dicts per sample. Sections whose `type`
            contains "image" contribute `token_height` / `token_width`, either as
            scalars (one image) or as equally long lists (several images).

    Returns:
        One list per sample of `(image_slice, (token_height, token_width))` tuples.

    Raises:
        AssertionError: if `token_height` and `token_width` lists differ in length,
            or the number of image slices differs from the number of image shapes.
    """
    rope_image_info = []
    for image_slices, sections_i in zip(output.all_image_slices, sections):
        image_shapes = []
        for section in sections_i:
            if 'image' not in section['type']:
                continue
            token_height, token_width = section['token_height'], section['token_width']
            if isinstance(token_height, list):
                # Compare height against WIDTH; the old check compared height with
                # itself, so zip() below silently truncated on a mismatch.
                assert len(token_height) == len(token_width), \
                    (f"token_height and token_width should have the same length, "
                     f"but got {len(token_height)} and {len(token_width)}")
                image_shapes.extend(zip(token_height, token_width))
            else:
                image_shapes.append((token_height, token_width))
        assert len(image_slices) == len(image_shapes), (
            f"Size miss matching: Image slices({len(image_slices)}) != image shapes({len(image_shapes)})"
        )
        rope_image_info.append(list(zip(image_slices, image_shapes)))
    return rope_image_info
2117
def vae_encode(self, image, cfg_factor=1):
    """Encode `image` with the VAE and return `(t, latents)`.

    The latents are shifted / scaled in place when the VAE config defines truthy
    `shift_factor` / `scaling_factor`. If the VAE has an `ffactor_temporal`
    attribute, the latents must have a singleton dim 2, which is squeezed away.
    `t` is an all-zero vector with one entry per latent; both outputs are repeated
    `cfg_factor` times along dim 0 when `cfg_factor > 1`.
    """
    vae_config = self.vae.config

    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
        encoded = self.vae.encode(image)
        # `encode` may return the latents directly or an object carrying a distribution.
        latents = encoded if isinstance(encoded, torch.Tensor) else encoded.latent_dist.sample()
        shift = getattr(vae_config, 'shift_factor', None)
        if shift:
            latents.sub_(shift)
        scale = getattr(vae_config, 'scaling_factor', None)
        if scale:
            latents.mul_(scale)

    if hasattr(self.vae, "ffactor_temporal"):
        assert latents.shape[2] == 1, "latents should have shape [B, C, T, H, W] and T should be 1"
        latents = latents.squeeze(2)

    # Here we always use t=0 to declare it is a clean conditional image
    t = torch.zeros((latents.shape[0],))

    if cfg_factor <= 1:
        return t, latents
    return t.repeat(cfg_factor), latents.repeat(cfg_factor, 1, 1, 1)
2144
def _encode_cond_image(
    self,
    batch_cond_image_info_list: List[List[JointImageInfo]],
    cfg_factor: int = 1,
):
    """VAE-encode the conditional images of a batch and collect their ViT inputs.

    Returns `(cond_vae_images, cond_t, batch_cond_vit_images)`:
      * when every sample has exactly one image and all latents share one shape,
        `cond_vae_images` is a single stacked tensor and `cond_t` a single tensor;
      * otherwise both are per-sample lists, where a sample's latents are stacked
        if they share a shape and kept as a list if they do not;
      * `batch_cond_vit_images` is always a per-sample list of concatenated tensors.
    All three are repeated `cfg_factor` times when `cfg_factor > 1`.
    """
    # VAE encode one by one, as we assume cond images have different sizes
    batch_cond_vae_images, batch_cond_t, batch_cond_vit_images = [], [], []
    for cond_image_info_list in batch_cond_image_info_list:
        sample_latents, sample_timesteps, sample_vit_images = [], [], []
        for image_info in cond_image_info_list:
            vae_input = image_info.vae_image_info.image_tensor.to(self.device)
            timestep, latent = self.vae_encode(vae_input)
            sample_vit_images.append(image_info.vision_image_info.image_tensor)
            sample_latents.append(latent.squeeze(0))
            sample_timesteps.append(timestep)
        batch_cond_vae_images.append(sample_latents)
        batch_cond_t.append(sample_timesteps)
        batch_cond_vit_images.append(torch.cat(sample_vit_images, dim=0))

    use_cfg = cfg_factor > 1

    # If only one cond image for each sample and all have the same size, we can batch them together
    # In this case, cond_vae_images is a 4-D tensor.
    can_batch = all(len(items) == 1 for items in batch_cond_vae_images) and all(
        items[0].shape == batch_cond_vae_images[0][0].shape for items in batch_cond_vae_images
    )
    if can_batch:
        cond_vae_images = torch.stack([items[0] for items in batch_cond_vae_images], dim=0)
        cond_t = torch.cat([items[0] for items in batch_cond_t], dim=0)
        if use_cfg:
            cond_t = cond_t.repeat(cfg_factor)
            cond_vae_images = cond_vae_images.repeat(cfg_factor, 1, 1, 1)
    else:
        # In this case, cond_vae_images is a list of 4-D tensors or a list of lists of 3-D tensors.
        cond_t = [torch.cat(item, dim=0) for item in batch_cond_t]
        cond_vae_images = [
            torch.stack(items, dim=0) if all(items[0].shape == item.shape for item in items) else items
            for items in batch_cond_vae_images
        ]
        if use_cfg:
            # List repetition: the per-sample entries are duplicated, not copied.
            cond_t = cond_t * cfg_factor
            cond_vae_images = cond_vae_images * cfg_factor

    if use_cfg:
        batch_cond_vit_images = batch_cond_vit_images * cfg_factor

    return cond_vae_images, cond_t, batch_cond_vit_images
2191
def prepare_model_inputs(
    self,
    prompt=None,
    mode="gen_text",
    system_prompt=None,
    cot_text=None,
    image_size="auto",
    message_list=None,
    device=None,
    max_new_tokens=None,
    **kwargs,
):
    """Turn user-level inputs into the keyword arguments used for generation.

    Exactly one of `prompt` (str or list of str) and `message_list` (one message
    list, or a batch of them) must be given. The method applies the chat template,
    encodes conditional images, builds the rotary position tables, allocates the
    kv cache and assembles everything into a dict.

    Keys read from `kwargs`: `seed`, `bot_task` (popped, default "auto"),
    `drop_think`, `max_length`.

    Requires `load_tokenizer` to have been called: `self._tkwrapper` is used directly.

    Returns:
        dict of model inputs (`input_ids`, `position_ids`, `past_key_values`,
        `custom_pos_emb`, masks / scatter indices, conditional image tensors), plus
        `tokenizer_output`, `batch_gen_image_info`, `generator`, `eos_token_id`
        and `max_new_tokens`.
    """
    # 1. Sanity check
    self.check_inputs(prompt, message_list)
    device = default(device, self.device)

    # 2. Format inputs
    batch_message_list = message_list
    batch_prompt = prompt
    batch_cot_text = cot_text
    batch_system_prompt = system_prompt
    batch_gen_image_info = None
    # TODO: construct with user input images
    batch_cond_image_info = None

    # -- 2.1 message_list
    if batch_message_list is not None:
        # A single message list (list of dicts) is promoted to a batch of one.
        if isinstance(batch_message_list[0], dict):
            batch_message_list = [batch_message_list]
        batch_size = len(batch_message_list)

        batch_gen_image_info = [
            [message['content'] for message in message_list_ if message['type'] == 'gen_image']
            for message_list_ in batch_message_list
        ]
        # At most one gen_image is allowed for each message_list
        # (if several are present, only the last one is kept).
        batch_gen_image_info = [info[-1] if len(info) > 0 else None for info in batch_gen_image_info]
        # Multiple cond images are allowed.
        batch_cond_image_info = [
            [message['content'] for message in message_list_ if message['type'] == 'joint_image']
            for message_list_ in batch_message_list
        ]

    # -- 2.2 Prompt, cot text, system prompt
    else:
        if isinstance(batch_prompt, str):
            batch_prompt = [batch_prompt]
        batch_size = len(batch_prompt)

        if batch_cot_text is not None:
            if isinstance(batch_cot_text, str):
                batch_cot_text = [batch_cot_text]
            else:
                assert isinstance(batch_cot_text, list) and len(batch_cot_text) == batch_size, \
                    "`cot_text` should be a string or a list of strings with the same length as `prompt`."

        if batch_system_prompt is not None:
            if isinstance(batch_system_prompt, str):
                batch_system_prompt = [batch_system_prompt]
            else:
                assert isinstance(batch_system_prompt, list) and len(batch_system_prompt) == batch_size, \
                    "`system_prompts` should be a string or a list of strings with the same length as `prompt`."

        if mode == "gen_image":
            batch_gen_image_info = [self.image_processor.build_image_info(image_size) for _ in range(batch_size)]

    # -- 2.3 seed
    seeds = self.prepare_seed(seed=kwargs.get('seed'), batch_size=batch_size)
    # NOTE(review): the generators live on `self.device`, not on the `device` argument.
    generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]

    # 3. apply chat template
    # Batch multiplier per mode: image generation uses twice the batch size
    # (presumably conditional + unconditional for classifier-free guidance -- TODO confirm).
    cfg_factor = {"gen_text": 1, "gen_image": 2}
    bot_task = kwargs.pop("bot_task", "auto")
    # If `drop_think` enabled, always drop <think> parts in the context.
    drop_think = kwargs.get('drop_think', self.generation_config.drop_think)
    # Apply batched prompt or batched message_list to build input sequence with associated info.
    out = self._tkwrapper.apply_chat_template(
        batch_prompt=batch_prompt,
        batch_message_list=batch_message_list,
        mode=mode,
        batch_gen_image_info=batch_gen_image_info,
        batch_cond_image_info=batch_cond_image_info,
        batch_system_prompt=batch_system_prompt,
        batch_cot_text=batch_cot_text,
        max_length=kwargs.get('max_length'),
        bot_task=bot_task,
        image_base_size=self.config.image_base_size,
        sequence_template=self.generation_config.sequence_template,
        cfg_factor=cfg_factor[mode],
        drop_think=drop_think,
    )
    output, sections = out['output'], out['sections']

    # 4. Encode conditional images
    # NOTE(review): only the first sample is inspected to decide whether the batch
    # has conditional images.
    if batch_cond_image_info is not None and len(batch_cond_image_info[0]) > 0:
        cond_vae_images, cond_timestep, cond_vit_images = self._encode_cond_image(
            batch_cond_image_info, cfg_factor[mode]
        )
        # Pack vit kwargs. Siglip2-so requires spatial_shapes and attention_mask for inference.
        vit_kwargs = {"spatial_shapes": [], "attention_mask": []}
        for cond_image_info in batch_cond_image_info:
            vit_kwargs["spatial_shapes"].append(
                torch.stack([item.vision_encoder_kwargs["spatial_shapes"] for item in cond_image_info]))
            vit_kwargs["attention_mask"].append(
                torch.stack([item.vision_encoder_kwargs["pixel_attention_mask"] for item in cond_image_info]))
        if cfg_factor[mode] > 1:
            # List repetition, matching the repetition done in `_encode_cond_image`.
            vit_kwargs["spatial_shapes"] = vit_kwargs["spatial_shapes"] * cfg_factor[mode]
            vit_kwargs["attention_mask"] = vit_kwargs["attention_mask"] * cfg_factor[mode]
    else:
        cond_vae_images, cond_timestep, cond_vit_images = None, None, None
        vit_kwargs = None

    # 5. Build position embeddings
    rope_image_info = self.build_batch_rope_image_info(output, sections)
    if mode == "gen_text":
        # Text generation extends the sequence, so the table covers the configured max length.
        seq_len = self.generation_config.max_length
    else:
        seq_len = output.tokens.shape[1]
    cos, sin = build_batch_2d_rope(
        image_infos=rope_image_info,
        seq_len=seq_len,
        n_elem=self.config.attention_head_dim,
        device=device,
        base=self.config.rope_theta,
    )

    # 6. Build kv cache
    if bot_task == "img_ratio":
        # Only a single token is generated for this task.
        max_new_tokens = 1
    if mode == "gen_image":
        # Image generation will not extend sequence length, using token length as max_cache_len is enough.
        max_cache_len = output.tokens.shape[1]
    else:
        max_cache_len = output.tokens.shape[1] + default(max_new_tokens, self.generation_config.max_length)
    cache = HunyuanStaticCache(
        config=self.config,
        batch_size=batch_size * cfg_factor[mode],
        max_cache_len=max_cache_len,
        dtype=torch.bfloat16,
        dynamic=mode == "gen_text",
    )

    # 7. Build position ids
    batch_input_pos = torch.arange(
        0, output.tokens.shape[1], dtype=torch.long, device=device)[None].expand(
        batch_size * cfg_factor[mode], -1)  # use expand to share indices to save memory

    # 8. Build model input kwargs
    tkw = self._tkwrapper
    if image_size == "auto":
        # Stop on any of the `<img_ratio_0>` .. `<img_ratio_32>` special tokens.
        extra_auto_stops = [tkw.special_token_map[f"<img_ratio_{i}>"] for i in range(33)]
    else:
        extra_auto_stops = [tkw.boi_token_id]
    # Stop tokens per `bot_task`; an unknown task raises KeyError at the lookup below.
    stop_token_id = dict(
        auto=[tkw.eos_token_id] + extra_auto_stops,
        image=[tkw.eos_token_id],
        recaption=[tkw.end_recaption_token_id, tkw.end_answer_token_id, tkw.eos_token_id],
        think=[tkw.end_recaption_token_id, tkw.end_answer_token_id, tkw.eos_token_id],
        img_ratio=extra_auto_stops,
    )
    model_input_kwargs = dict(
        input_ids=output.tokens.to(device),
        position_ids=batch_input_pos,
        past_key_values=cache,
        custom_pos_emb=(cos, sin),
        mode=mode,
        image_mask=to_device(output.gen_image_mask, device),
        gen_timestep_scatter_index=to_device(output.gen_timestep_scatter_index, device),
        cond_vae_images=to_device(cond_vae_images, device),
        cond_timestep=to_device(cond_timestep, device),
        cond_vae_image_mask=to_device(output.cond_vae_image_mask, device),
        cond_vit_images=to_device(cond_vit_images, device),
        cond_vit_image_mask=to_device(output.cond_vit_image_mask, device),
        # NOTE(review): vit kwargs are moved to `self.device`, the other tensors to `device`.
        vit_kwargs={
            k: to_device(v, self.device) for k, v in vit_kwargs.items()
        } if vit_kwargs is not None else None,
        cond_timestep_scatter_index=to_device(output.cond_timestep_scatter_index, device),
        # for inner usage
        tokenizer_output=output,
        batch_gen_image_info=batch_gen_image_info,
        generator=generator,
        # generation config
        eos_token_id=stop_token_id[bot_task],
        max_new_tokens=max_new_tokens,
    )

    return model_input_kwargs
2379
def _prepare_attention_mask_for_generation(
    self,
    inputs_tensor: torch.Tensor,
    generation_config: GenerationConfig,
    model_kwargs: Dict[str, Any],
) -> torch.Tensor:
    """Build a 4-D boolean attention mask of shape (b, 1, seqlen, seqlen).

    This bypasses the 2-D mask requirement of
    `transformers.generation_utils.GenerationMixin.generate`. Text tokens attend
    causally (lower-triangular mask); inside every image slice listed in
    `model_kwargs["tokenizer_output"]` the tokens attend to each other fully.
    `generation_config` is not used.
    """
    bsz, seq_len = inputs_tensor.shape
    tokenizer_output = model_kwargs["tokenizer_output"]
    joint_slices = tokenizer_output.joint_image_slices
    gen_slices = tokenizer_output.gen_image_slices
    per_sample_slices = [joint_slices[i] + gen_slices[i] for i in range(bsz)]

    # Start from a causal mask shared by all samples.
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0)
    attention_mask = causal.repeat(bsz, 1, 1)

    # Open up full attention inside every image block.
    for sample_idx, image_slices in enumerate(per_sample_slices):
        for image_slice in image_slices:
            attention_mask[sample_idx, image_slice, image_slice] = True

    return attention_mask.unsqueeze(1)
2402
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None,
    tokenizer_output=None, batch_gen_image_info=None, generator=None, **kwargs
):
    """Assemble the keyword arguments of one `forward` call during generation.

    `position_ids`, `custom_pos_emb` and `mode` are required entries of `kwargs`;
    the remaining forwarded entries default to None. `tokenizer_output`,
    `batch_gen_image_info` and `generator` are accepted but not forwarded.
    """
    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    use_embeds = inputs_embeds is not None and past_key_values is None
    if use_embeds:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        if input_ids.shape[1] != kwargs["position_ids"].shape[1]:  # in decode steps
            # Keep only the tokens addressed by the current position ids.
            input_ids = torch.gather(input_ids, dim=1, index=kwargs["position_ids"])
        model_inputs = {"input_ids": input_ids}

    model_inputs["attention_mask"] = attention_mask
    model_inputs["position_ids"] = kwargs["position_ids"]
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["custom_pos_emb"] = kwargs["custom_pos_emb"]
    model_inputs["mode"] = kwargs["mode"]

    optional_names = (
        "images",
        "image_mask",
        "timestep",
        "gen_timestep_scatter_index",
        "cond_vae_images",
        "cond_timestep",
        "cond_vae_image_mask",
        "cond_vit_images",
        "cond_vit_image_mask",
        "vit_kwargs",
        "cond_timestep_scatter_index",
    )
    for name in optional_names:
        model_inputs[name] = kwargs.get(name)
    return model_inputs
2437
def _update_model_kwargs_for_generation(
    self,
    outputs: ModelOutput,
    model_kwargs: Dict[str, Any],
    is_encoder_decoder: bool = False,
    num_new_tokens: int = 1,
) -> Dict[str, Any]:
    """Build the model kwargs for the next generation step.

    The returned dict is created from scratch: only `mode`, `custom_pos_emb`,
    the cache found on `outputs`, and the position / mask entries computed
    below are carried over. The presence of `tokenizer_output` in
    `model_kwargs` is what marks the prefill step; because that key is not
    copied into the result, later calls take the decode path.

    `is_encoder_decoder` and `num_new_tokens` are accepted for signature
    compatibility and are not used.
    """
    mode = model_kwargs["mode"]
    is_text_mode = mode == "gen_text"
    next_kwargs = {
        "mode": mode,
        "custom_pos_emb": model_kwargs["custom_pos_emb"],
    }

    # Carry over the first cache present on the outputs. Legacy cache names
    # (xlnet / reformer style) are stored under `past_key_values`.
    legacy_cache_names = ("past_buckets_states", "mems")
    found_cache_name = next((name for name in ALL_CACHE_NAMES if name in outputs), None)
    if found_cache_name is not None:
        target_name = "past_key_values" if found_cache_name in legacy_cache_names else found_cache_name
        next_kwargs[target_name] = getattr(outputs, found_cache_name)

    in_prefill = "tokenizer_output" in model_kwargs

    if is_text_mode:
        if in_prefill:
            # Batching uses right padding, so `real_pos` supplies the valid end
            # position of each sequence for the first decode step.
            next_kwargs["position_ids"] = to_device(model_kwargs["tokenizer_output"].real_pos, self.device)
        else:
            # Decode steps: advance every position by one.
            next_kwargs["position_ids"] = model_kwargs["position_ids"] + 1
        return next_kwargs

    if not in_prefill:
        # Image mode after the first step: reuse the entries unchanged.
        for key in ("position_ids", "attention_mask", "gen_timestep_scatter_index"):
            next_kwargs[key] = model_kwargs[key]
        return next_kwargs

    # Image mode, first step: positions are the last timestep slot followed by
    # every position flagged in `image_mask`.
    # NOTE(review): the reshape assumes each sample has the same number of
    # flagged positions -- confirm against the caller.
    image_mask = model_kwargs["image_mask"]
    bsz, seq_len = image_mask.shape
    token_index = torch.arange(seq_len, device=image_mask.device).unsqueeze(0).repeat(bsz, 1)
    image_positions = token_index.masked_select(image_mask.bool()).reshape(bsz, -1)
    last_timestep_slot = model_kwargs["gen_timestep_scatter_index"][:, -1]
    timestep_positions = token_index[torch.arange(bsz), last_timestep_slot].unsqueeze(-1)
    position_ids = torch.cat([timestep_positions, image_positions], dim=1)
    next_kwargs["position_ids"] = position_ids

    # Keep only the attention-mask entries (along dim 1 of each per-sample
    # mask) that correspond to the selected positions.
    next_kwargs["attention_mask"] = torch.stack(
        [
            torch.index_select(mask_i, dim=1, index=positions_i.reshape(-1))
            for mask_i, positions_i in zip(model_kwargs["attention_mask"], position_ids)
        ],
        dim=0,
    )
    next_kwargs["gen_timestep_scatter_index"] = model_kwargs["gen_timestep_scatter_index"]
    return next_kwargs
2499
def _generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    use_model_defaults: Optional[bool] = None,
    generator: Optional[List[torch.Generator]] = None,
    verbose: int = 0,
    **kwargs,
):
    """Dispatch a generation request to text decoding or to the image pipeline.

    Args:
        inputs .. use_model_defaults: forwarded positionally to the parent
            `generate` when `mode` is `gen_text`.
        generation_config: optional override of `self.generation_config`.
            It is used for the verbose report and, in `gen_image` mode, as the
            fallback source of `diff_infer_steps` / `diff_guidance_scale`.
        generator: per-sample random generators, passed to the image pipeline.
        verbose: when >= 1, print a summary of the inputs and settings.
            Requires `tokenizer_output` in `kwargs`.
        **kwargs: must carry `mode` (defaults to `gen_text`); in `gen_image`
            mode it must also carry a non-empty `batch_gen_image_info`.

    Returns:
        The parent `generate` result in `gen_text` mode, or the first element
        of the pipeline result in `gen_image` mode.

    Raises:
        ValueError: unknown `mode`, or missing / empty `batch_gen_image_info`
            in `gen_image` mode.
    """
    mode = kwargs.get("mode", "gen_text")
    # Resolve the effective config once so the values reported below are the
    # same ones the image pipeline is actually given.
    gen_config = default(generation_config, self.generation_config)

    # Log info
    if verbose >= 1:
        output = kwargs["tokenizer_output"]
        context = self._tkwrapper.tokenizer.decode(output.tokens[0], skip_special_tokens=False)
        # Replace <img><img>...<img> with [<img>]{number}
        context = re.sub(r"(<img>)+", lambda m: f"[<img>]{{{len(m.group(0)) // 5}}}", context)
        info_list = [
            ("token shape", output.tokens.shape),
            ("context[0]", context),
        ]
        if mode == "gen_image":
            if generator is not None:
                info_list.extend([
                    ("seed", [g.initial_seed() for g in generator]),
                ])
            info_list.extend([
                ("image_size", [f"{info.image_height}x{info.image_width}" for info in kwargs["batch_gen_image_info"]]),
                ("infer_steps", kwargs.get("diff_infer_steps", gen_config.diff_infer_steps)),
                ("guidance_scale", kwargs.get("diff_guidance_scale", gen_config.diff_guidance_scale)),
                ("flow_shift", kwargs.get("flow_shift", gen_config.flow_shift)),
            ])
        else:
            info_list.extend([
                ("do_sample", kwargs.get("do_sample", gen_config.do_sample)),
                ("max_new_tokens", kwargs.get("max_new_tokens", gen_config.max_new_tokens)),
                ("top_k", kwargs.get("top_k", gen_config.top_k)),
                ("top_p", kwargs.get("top_p", gen_config.top_p)),
                ("temperature", kwargs.get("temperature", gen_config.temperature)),
                ("repetition_penalty", kwargs.get("repetition_penalty", gen_config.repetition_penalty)),
            ])
        max_key_len = max(len(k) for k, _ in info_list)
        info_str = "=" * 50 + \
            "\nModel input info:\n" + \
            "\n".join([f" {k.rjust(max_key_len)}: {v}" for k, v in info_list]) + \
            "\n--------------------------------------------------"
        print(info_str)

    if mode == "gen_text":
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            return super().generate(
                inputs,
                generation_config,
                logits_processor,
                stopping_criteria,
                prefix_allowed_tokens_fn,
                synced_gpus,
                assistant_model,
                streamer,
                negative_prompt_ids,
                negative_prompt_attention_mask,
                use_model_defaults,
                **kwargs,
            )

    elif mode == "gen_image":
        batch_gen_image_info: List[ImageInfo] = kwargs.get("batch_gen_image_info")
        if batch_gen_image_info is None:
            raise ValueError("`batch_gen_image_info` should be provided when `mode` is `gen_image`.")
        if len(batch_gen_image_info) == 0:
            # Fail with a clear message instead of an IndexError on `[0]` below.
            raise ValueError("`batch_gen_image_info` must contain at least one entry.")

        # The pipeline takes a single image size, read from the first entry.
        first_info = batch_gen_image_info[0]
        results = self.pipeline(
            batch_size=len(batch_gen_image_info),
            image_size=[first_info.image_height, first_info.image_width],
            num_inference_steps=kwargs.get("diff_infer_steps", gen_config.diff_infer_steps),
            guidance_scale=kwargs.get("diff_guidance_scale", gen_config.diff_guidance_scale),
            generator=generator,
            model_kwargs=kwargs,
        )
        samples = results[0]
        return samples

    else:
        raise ValueError(f"Unknown mode {mode}, only `gen_text` and `gen_image` are supported.")
2592
def get_cot_text(self, output: torch.Tensor):
    """Decode generated token ids and return the text after the assistant marker.

    Args:
        output: a 1D tensor of token ids, or a 2D tensor that is processed
            row by row.

    Returns:
        For 1D input, the decoded text that follows the first occurrence of
        "Assistant: " (a trailing EOS token is dropped before decoding).
        For 2D input, a list with one such string per row.

    Raises:
        ValueError: if `output` is neither 1D nor 2D.
        IndexError: if the decoded text does not contain "Assistant: ".
    """
    if output.ndim == 2:
        return [self.get_cot_text(output_i) for output_i in output]
    if output.ndim != 1:
        raise ValueError(f"output should be 1D or 2D tensor, but got {output.ndim}D tensor.")

    if output[-1] == self._tkwrapper.eos_token_id:
        output = output[:-1]
    decoded = self._tkwrapper.decode(output)
    # Split only on the first marker: the generated text may itself contain
    # "Assistant: ", and an unbounded split would cut the result off there.
    return decoded.split("Assistant: ", 1)[1]
2603
def generate_image(
    self,
    prompt,
    seed=None,
    image_size="auto",
    use_system_prompt=None,
    system_prompt=None,
    bot_task=None,
    stream=False,
    **kwargs,
):
    """Generate an image for `prompt` in up to three stages.

    1. If `bot_task` is "think" or "recaption", generate chain-of-thought text.
    2. If `image_size` is "auto", generate a ratio token and map it to a
       resolution from `self.image_processor.reso_group`.
    3. Run image generation and return the first element of its result.

    `max_new_tokens` (default 8192) and `verbose` (default 0) are taken out of
    `kwargs`; everything left in `kwargs` is forwarded to every `_generate`
    call. `use_system_prompt` and `bot_task` fall back to the values on
    `self.generation_config` when not given.
    """
    max_new_tokens = kwargs.pop("max_new_tokens", 8192)
    verbose = kwargs.pop("verbose", 0)

    if stream:
        from transformers import TextStreamer
        kwargs["streamer"] = TextStreamer(
            self._tkwrapper.tokenizer, skip_prompt=True, skip_special_tokens=False)

    use_system_prompt = default(use_system_prompt, self.generation_config.use_system_prompt)
    bot_task = default(bot_task, self.generation_config.bot_task)
    system_prompt = get_system_prompt(use_system_prompt, bot_task, system_prompt)

    # Stage 1: optional chain-of-thought text.
    cot_text = None
    if bot_task in ("think", "recaption"):
        cot_inputs = self.prepare_model_inputs(
            prompt=prompt, bot_task=bot_task, system_prompt=system_prompt, max_new_tokens=max_new_tokens)
        print(f"<{bot_task}>", end="", flush=True)
        cot_outputs = self._generate(**cot_inputs, **kwargs, verbose=verbose)
        cot_text = self.get_cot_text(cot_outputs[0])
        # With drop_think enabled, the later stages use the `en_recaption` system prompt.
        if self.generation_config.drop_think and system_prompt:
            system_prompt = t2i_system_prompts["en_recaption"][0]

    # Stage 2: let the model pick the aspect ratio.
    if image_size == "auto":
        ratio_inputs = self.prepare_model_inputs(
            prompt=prompt, cot_text=cot_text, bot_task="img_ratio", system_prompt=system_prompt, seed=seed)
        ratio_outputs = self._generate(**ratio_inputs, **kwargs, verbose=verbose)
        ratio_index = ratio_outputs[0, -1].item() - self._tkwrapper.ratio_token_offset
        reso_group = self.image_processor.reso_group
        # The generated index can fall outside the resolution table; in that
        # case fall back to index 16 (i.e., 1:1).
        if not (0 <= ratio_index < len(reso_group)):
            ratio_index = 16
        chosen_reso = reso_group[ratio_index]
        image_size = chosen_reso.height, chosen_reso.width

    # Stage 3: generate the image.
    image_inputs = self.prepare_model_inputs(
        prompt=prompt, cot_text=cot_text, system_prompt=system_prompt, mode="gen_image", seed=seed,
        image_size=image_size,
    )
    image_outputs = self._generate(**image_inputs, **kwargs, verbose=verbose)
    return image_outputs[0]
2660