hunyuan_image_3_pipeline.py
36.6 KB · 880 lines · python Raw
1 # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
2 # you may not use this file except in compliance with the License.
3 # You may obtain a copy of the License at
4 #
5 # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
6 #
7 # Unless required by applicable law or agreed to in writing, software
8 # distributed under the License is distributed on an "AS IS" BASIS,
9 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 # See the License for the specific language governing permissions and
11 # limitations under the License.
12 # ==============================================================================
13 #
14 # Copyright 2024 The HuggingFace Team. All rights reserved.
15 #
16 # Licensed under the Apache License, Version 2.0 (the "License");
17 # you may not use this file except in compliance with the License.
18 # You may obtain a copy of the License at
19 #
20 # http://www.apache.org/licenses/LICENSE-2.0
21 #
22 # Unless required by applicable law or agreed to in writing, software
23 # distributed under the License is distributed on an "AS IS" BASIS,
24 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 # See the License for the specific language governing permissions and
26 # limitations under the License.
27 # ==============================================================================================
28
29 import inspect
30 import math
31 from dataclasses import dataclass
32 from typing import Any, Callable, Dict, List
33 from typing import Optional, Tuple, Union
34
35 import numpy as np
36 import torch
37 from PIL import Image
38 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
39 from diffusers.configuration_utils import ConfigMixin, register_to_config
40 from diffusers.image_processor import VaeImageProcessor
41 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
42 from diffusers.schedulers.scheduling_utils import SchedulerMixin
43 from diffusers.utils import BaseOutput, logging
44 from diffusers.utils.torch_utils import randn_tensor
45
46 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
48
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one scheduling mode is used: custom `timesteps`, custom `sigmas`, or (when neither is given) a plain
    `num_inference_steps`. Extra keyword arguments are forwarded to `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read the timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of steps requested. Only used when neither `timesteps` nor `sigmas` is passed.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Requires that `scheduler.set_timesteps` has a `timesteps` parameter.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Requires that `scheduler.set_timesteps` has a `sigmas` parameter.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration and the number of inference steps.
        For custom schedules the count is `len(scheduler.timesteps)`; otherwise it is the `num_inference_steps`
        that was passed in.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: both variants only differ in the keyword they use and in one word of the error message.
    if timesteps is not None:
        argument_name, label, custom_values = "timesteps", "timestep", timesteps
    else:
        argument_name, label, custom_values = "sigmas", "sigmas", sigmas

    supported_arguments = inspect.signature(scheduler.set_timesteps).parameters
    if argument_name not in supported_arguments:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {label} schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(**{argument_name: custom_values}, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
106
107
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Blend the guided prediction with a copy rescaled to the per-sample standard deviation of the text prediction.

    Follows Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            Prediction obtained after applying guidance.
        noise_pred_text (`torch.Tensor`):
            Text-conditioned prediction whose statistics are used as the target.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Interpolation weight: `0.0` keeps `noise_cfg`, `1.0` uses the fully rescaled prediction.

    Returns:
        noise_cfg (`torch.Tensor`): The blended prediction, same shape as the input `noise_cfg`.
    """
    # Standard deviations are taken over every dimension except the leading (batch) one.
    text_reduce_dims = list(range(1, noise_pred_text.ndim))
    cfg_reduce_dims = list(range(1, noise_cfg.ndim))
    text_std = noise_pred_text.std(dim=text_reduce_dims, keepdim=True)
    cfg_std = noise_cfg.std(dim=cfg_reduce_dims, keepdim=True)

    # Match the guided prediction's spread to the text prediction's spread (counters overexposure).
    rescaled = noise_cfg * (text_std / cfg_std)

    # Mix with the un-rescaled guidance result so that images do not end up looking "plain".
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
131
132
@dataclass
class HunyuanImage3Text2ImagePipelineOutput(BaseOutput):
    """
    Output container returned by `HunyuanImage3Text2ImagePipeline.__call__` when `return_dict=True`.

    Args:
        samples (`List[Any]` or `np.ndarray`):
            The generated samples in the format produced by the pipeline's image post-processing step.
    """

    samples: Union[List[Any], np.ndarray]
136
137
@dataclass
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
    """
    Result object produced by `FlowMatchDiscreteScheduler.step` when `return_dict=True`.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample computed by the scheduler step. It is meant to be fed back to the model as the next input of
            the denoising loop.
    """

    prev_sample: torch.FloatTensor
150
151
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Flow-matching scheduler that advances a sample along a sigma schedule with a selectable ODE solver.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Multi-stage solvers (`"heun-2"`, `"midpoint-2"`, `"kutta-4"`) are driven by calling `step` once per stage: the
    scheduler keeps the intermediate derivatives between calls and only advances its step index after the final
    stage of a step.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model. Sigmas are multiplied by this value to obtain the
            timesteps fed to the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule. Only applied when it differs from 1 and `use_flux_shift` is
            `False`.
        reverse (`bool`, defaults to `True`):
            When `True` sigmas run from 1 to 0; when `False` the schedule is mirrored to run from 0 to 1.
        solver (`str`, defaults to `"euler"`):
            One of `"euler"`, `"heun-2"`, `"midpoint-2"` or `"kutta-4"`.
        use_flux_shift (`bool`, defaults to `False`):
            Whether to shift the schedule with `flux_time_shift`, using a shift amount interpolated linearly from
            the number of tokens.
        flux_base_shift (`float`, defaults to 0.5):
            Shift amount associated with the smallest token count of the linear interpolation.
        flux_max_shift (`float`, defaults to 1.15):
            Shift amount associated with the largest token count of the linear interpolation.
        n_tokens (`int`, *optional*):
            Stored in the config; not read by any method of this class.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        reverse: bool = True,
        solver: str = "euler",
        use_flux_shift: bool = False,
        flux_base_shift: float = 0.5,
        flux_max_shift: float = 1.15,
        n_tokens: Optional[int] = None,
    ):
        sigmas = torch.linspace(1, 0, num_train_timesteps + 1)

        if not reverse:
            sigmas = sigmas.flip(0)

        self.sigmas = sigmas
        # the value fed to model
        self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
        self.timesteps_full = (sigmas * num_train_timesteps).to(dtype=torch.float32)

        self._step_index = None
        self._begin_index = None

        self.supported_solver = [
            "euler",
            "heun-2", "midpoint-2",
            "kutta-4",
        ]
        if solver not in self.supported_solver:
            raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}")

        # State carried between the stages of a multi-stage solver step.
        self.derivative_1 = None
        self.derivative_2 = None
        self.derivative_3 = None
        self.dt = None
        self.sample = None

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each completed scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        """Convert a sigma into the timestep scale by multiplying with `num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    @property
    def state_in_first_order(self):
        """`True` while no first-stage derivative is stored, i.e. the next call starts a new step."""
        return self.derivative_1 is None

    @property
    def state_in_second_order(self):
        """`True` while no second-stage derivative is stored."""
        return self.derivative_2 is None

    @property
    def state_in_third_order(self):
        """`True` while no third-stage derivative is stored."""
        return self.derivative_3 is None

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None,
                      n_tokens: int = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Also clears any state left over from a partially executed multi-stage step and resets the step index.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            n_tokens (`int`, *optional*):
                Number of tokens in the input sequence. Required when the config enables `use_flux_shift`.
        """
        self.num_inference_steps = num_inference_steps

        sigmas = torch.linspace(1, 0, num_inference_steps + 1)

        # Apply timestep shift
        if self.config.use_flux_shift:
            assert isinstance(n_tokens, int), "n_tokens should be provided for flux shift"
            mu = self.get_lin_function(y1=self.config.flux_base_shift, y2=self.config.flux_max_shift)(n_tokens)
            sigmas = self.flux_time_shift(mu, 1.0, sigmas)
        elif self.config.shift != 1.:
            sigmas = self.sd3_time_shift(sigmas)

        if not self.config.reverse:
            sigmas = 1 - sigmas

        # `sigmas` itself stays on its original device; only the timestep tensors are moved.
        self.sigmas = sigmas
        self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
        self.timesteps_full = (sigmas * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)

        # Drop any state of an unfinished multi-stage step.
        self.derivative_1 = None
        self.derivative_2 = None
        self.derivative_3 = None
        self.dt = None
        self.sample = None

        # Reset step index
        self._step_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

        When the value occurs more than once, the second occurrence is returned.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise the step index from `begin_index` if it was set, otherwise by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """Return `sample` unchanged; this scheduler applies no input scaling."""
        return sample

    @staticmethod
    def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
        """Return the linear function passing through `(x1, y1)` and `(x2, y2)`."""
        m = (y2 - y1) / (x2 - x1)
        b = y1 - m * x1
        return lambda x: m * x + b

    @staticmethod
    def flux_time_shift(mu: float, sigma: float, t: torch.Tensor):
        """Apply the shift `exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)` element-wise to `t`."""
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def sd3_time_shift(self, t: torch.Tensor):
        """Apply the shift `shift * t / (1 + (shift - 1) * t)` with the configured `shift`."""
        return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        pred_uncond: torch.FloatTensor = None,
        generator: Optional[torch.Generator] = None,
        n_tokens: Optional[int] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
        """
        Advance `sample` by one solver stage using the model output as the derivative.

        For `"euler"` every call is a full step. For the multi-stage solvers each call executes one stage; the step
        index only increases after the last stage of a step.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain. Only used to initialise the step index on the
                first call.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            pred_uncond (`torch.FloatTensor`, *optional*):
                Accepted for interface compatibility; it is cast to float32 but not used in the update.
            generator (`torch.Generator`, *optional*):
                A random number generator. Not used by this scheduler.
            n_tokens (`int`, *optional*):
                Number of tokens in the input sequence. Not used by this method.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor (in float32).

        Raises:
            ValueError: If `timestep` is an integer or an integer tensor, or if the configured solver is unknown.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `FlowMatchDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        model_output = model_output.to(torch.float32)
        pred_uncond = pred_uncond.to(torch.float32) if pred_uncond is not None else None

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        last_inner_step = True
        if self.config.solver == "euler":
            derivative, dt, sample, last_inner_step = self.first_order_method(model_output, sigma, sigma_next, sample)
        elif self.config.solver in ["heun-2", "midpoint-2"]:
            derivative, dt, sample, last_inner_step = self.second_order_method(model_output, sigma, sigma_next, sample)
        elif self.config.solver == "kutta-4":
            derivative, dt, sample, last_inner_step = self.fourth_order_method(model_output, sigma, sigma_next, sample)
        else:
            raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")

        # The result is intentionally left in float32 (not cast back to the model dtype).
        prev_sample = sample + derivative * dt

        # upon completion increase step index by one
        if last_inner_step:
            self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)

    def first_order_method(self, model_output, sigma, sigma_next, sample):
        """
        Single-stage Euler update.

        Returns:
            `(derivative, dt, base_sample, last_inner_step)` such that the caller computes
            `base_sample + derivative * dt`.
        """
        derivative = model_output
        dt = sigma_next - sigma
        return derivative, dt, sample, True

    def second_order_method(self, model_output, sigma, sigma_next, sample):
        """
        One stage of the two-stage `"heun-2"` or `"midpoint-2"` solver.

        The first call stores the derivative, the full step size and the step-start sample, and produces a trial
        point (a full step for Heun, a half step for midpoint). The second call combines the derivatives and
        restarts from the stored step-start sample.

        Returns:
            `(derivative, dt, base_sample, last_inner_step)` such that the caller computes
            `base_sample + derivative * dt`.
        """
        if self.state_in_first_order:
            # store for 2nd order step
            self.derivative_1 = model_output
            self.dt = sigma_next - sigma
            self.sample = sample

            derivative = model_output
            if self.config.solver == 'heun-2':
                dt = self.dt
            elif self.config.solver == 'midpoint-2':
                dt = self.dt / 2
            else:
                raise NotImplementedError(f"Solver {self.config.solver} not supported.")
            last_inner_step = False

        else:
            if self.config.solver == 'heun-2':
                derivative = 0.5 * (self.derivative_1 + model_output)
            elif self.config.solver == 'midpoint-2':
                derivative = model_output
            else:
                raise NotImplementedError(f"Solver {self.config.solver} not supported.")

            # take the stored step size & step-start sample
            dt = self.dt
            sample = self.sample
            last_inner_step = True

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.derivative_1 = None
            self.dt = None
            self.sample = None

        return derivative, dt, sample, last_inner_step

    def fourth_order_method(self, model_output, sigma, sigma_next, sample):
        """
        One stage of the four-stage classical Runge-Kutta (`"kutta-4"`) solver.

        Every stage starts from the sample stored at the beginning of the step (`x_n`):
        stage 1 yields `x_n + dt/2 * k1`, stage 2 `x_n + dt/2 * k2`, stage 3 `x_n + dt * k3`, and stage 4 the final
        `x_n + dt * (k1 + 2*k2 + 2*k3 + k4) / 6`.

        Returns:
            `(derivative, dt, base_sample, last_inner_step)` such that the caller computes
            `base_sample + derivative * dt`.
        """
        if self.state_in_first_order:
            self.derivative_1 = model_output
            self.dt = sigma_next - sigma
            self.sample = sample
            derivative = model_output
            dt = self.dt / 2
            last_inner_step = False

        elif self.state_in_second_order:
            self.derivative_2 = model_output
            derivative = model_output
            dt = self.dt / 2
            # Restart from the step-start sample; the incoming `sample` is the stage-1 trial point.
            sample = self.sample
            last_inner_step = False

        elif self.state_in_third_order:
            self.derivative_3 = model_output
            derivative = model_output
            dt = self.dt
            # Restart from the step-start sample; the incoming `sample` is the stage-2 trial point.
            sample = self.sample
            last_inner_step = False

        else:
            derivative = (1/6 * self.derivative_1 + 1/3 * self.derivative_2 + 1/3 * self.derivative_3 +
                          1/6 * model_output)

            # take the stored step size & step-start sample
            dt = self.dt
            sample = self.sample
            last_inner_step = True

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.derivative_1 = None
            self.derivative_2 = None
            self.derivative_3 = None
            self.dt = None
            self.sample = None

        return derivative, dt, sample, last_inner_step

    def __len__(self):
        return self.config.num_train_timesteps
500
501
class ClassifierFreeGuidance:
    """
    Combines a conditional and an unconditional prediction with classifier-free guidance.

    Args:
        use_original_formulation (`bool`, defaults to `False`):
            If `True`, the guided prediction is `pred_cond + guidance_scale * (pred_cond - pred_uncond)`;
            otherwise it is `pred_uncond + guidance_scale * (pred_cond - pred_uncond)`.
        start (`float`, defaults to 0.0):
            Stored on the instance; not used by `__call__`.
        stop (`float`, defaults to 1.0):
            Stored on the instance; not used by `__call__`.
    """

    def __init__(
        self,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__()
        self.use_original_formulation = use_original_formulation
        # Keep the constructor arguments instead of silently discarding them.
        self.start = start
        self.stop = stop

    def __call__(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor],
        guidance_scale: float,
        step: int,
    ) -> torch.Tensor:
        """
        Return the guided prediction.

        Args:
            pred_cond (`torch.Tensor`): Conditional prediction.
            pred_uncond (`torch.Tensor`, *optional*): Unconditional prediction. When `None`, no guidance can be
                applied and `pred_cond` is returned unchanged.
            guidance_scale (`float`): Weight applied to the difference `pred_cond - pred_uncond`.
            step (`int`): Index of the current denoising step. Currently unused.

        Returns:
            `torch.Tensor`: The guided prediction.
        """
        # Without an unconditional branch there is nothing to guide against.
        if pred_uncond is None:
            return pred_cond

        shift = pred_cond - pred_uncond
        base = pred_cond if self.use_original_formulation else pred_uncond
        return base + guidance_scale * shift
525
526
class HunyuanImage3Text2ImagePipeline(DiffusionPipeline):
    r"""
    Pipeline that denoises latents with a token-conditioned model and decodes them to images with a VAE.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        model ([`ModelMixin`]):
            The model used to denoise the latents. The pipeline reads `config.vae_downsample_factor`,
            `config.vae["latent_channels"]` and `generation_config` from it and calls its
            `prepare_inputs_for_generation`, `_prepare_attention_mask_for_generation` and
            `_update_model_kwargs_for_generation` methods. Its forward output must contain the key
            `"diffusion_prediction"`.
        scheduler ([`SchedulerMixin`]):
            A scheduler used together with `model` to denoise the latents, e.g. `FlowMatchDiscreteScheduler`.
        vae:
            The autoencoder whose `decode` method turns the final latents into images.
        progress_bar_config (`Dict[str, Any]`, *optional*):
            Options merged into the pipeline's progress bar configuration.
    """

    model_cpu_offload_seq = ""
    _optional_components = []
    _exclude_from_cpu_offload = []
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        model,
        scheduler: SchedulerMixin,
        vae,
        progress_bar_config: Dict[str, Any] = None,
    ):
        super().__init__()

        # ==========================================================================================
        if progress_bar_config is None:
            progress_bar_config = {}
        if not hasattr(self, '_progress_bar_config'):
            self._progress_bar_config = {}
        self._progress_bar_config.update(progress_bar_config)
        # ==========================================================================================

        self.register_modules(
            model=model,
            scheduler=scheduler,
            vae=vae,
        )

        # should be a tuple or a list corresponding to the size of latents (batch_size, channel, *size)
        # if None, will be treated as a tuple of 1
        self.latent_scale_factor = self.model.config.vae_downsample_factor
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.latent_scale_factor)

        # Operator that merges conditional / unconditional predictions in the sampling loop.
        self.cfg_operator = ClassifierFreeGuidance()

    @staticmethod
    def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        """
        Rescale images with `x / 2 + 0.5` and clip the result to [0, 1].

        Works for both `torch.Tensor` and `np.ndarray` inputs and returns the same kind of object.
        """
        images = images / 2 + 0.5
        if isinstance(images, np.ndarray):
            # NumPy arrays have no `.clamp`; use the NumPy equivalent.
            return np.clip(images, 0, 1)
        return images.clamp(0, 1)

    @staticmethod
    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
        """
        Convert a PyTorch tensor to a NumPy image.

        The tensor is moved to CPU, dimensions are permuted with `(0, 2, 3, 1)` and values are cast to float32.
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images

    @staticmethod
    def numpy_to_pil(images: np.ndarray):
        """
        Convert a numpy image or a batch of images to a list of PIL images.

        A 3-dimensional input is treated as a single image. Values are multiplied by 255, rounded and cast to
        `uint8`. Arrays whose last dimension is 1 are converted as grayscale (`mode="L"`).
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    def prepare_extra_func_kwargs(self, func, kwargs):
        """
        Return the subset of `kwargs` whose keys are parameters of `func`.

        Used to build keyword arguments for calls such as `scheduler.step`, since not all schedulers share the
        same signature.
        """
        # Inspect the signature once instead of once per key.
        accepted = set(inspect.signature(func).parameters.keys())
        return {k: v for k, v in kwargs.items() if k in accepted}

    def prepare_latents(self, batch_size, latent_channel, image_size, dtype, device, generator, latents=None):
        """
        Create (or move) the initial latents for the sampling loop.

        Args:
            batch_size (`int`): Number of samples.
            latent_channel (`int`): Number of latent channels.
            image_size (sequence of `int`): Size of the output image; each entry is integer-divided by the
                matching latent scale factor to obtain the latent size.
            dtype (`torch.dtype`): Dtype of freshly sampled latents.
            device (`torch.device`): Target device.
            generator (`torch.Generator` or list of them, *optional*): Random generator(s) used for sampling.
            latents (`torch.Tensor`, *optional*): Pre-generated latents; only moved to `device` when given.

        Returns:
            `torch.Tensor`: The latents, multiplied by `scheduler.init_noise_sigma` if the scheduler defines it.

        Raises:
            ValueError: If `latent_scale_factor` has an unsupported type, or if a list of generators does not
                match `batch_size`.
        """
        if self.latent_scale_factor is None:
            latent_scale_factor = (1,) * len(image_size)
        elif isinstance(self.latent_scale_factor, int):
            latent_scale_factor = (self.latent_scale_factor,) * len(image_size)
        elif isinstance(self.latent_scale_factor, (tuple, list)):
            assert len(self.latent_scale_factor) == len(image_size), \
                "len(latent_scale_factor) shoudl be the same as len(image_size)"
            latent_scale_factor = self.latent_scale_factor
        else:
            raise ValueError(
                f"latent_scale_factor should be either None, int, tuple of int, or list of int, "
                f"but got {self.latent_scale_factor}"
            )

        latents_shape = (
            batch_size,
            latent_channel,
            *[int(s) // f for s, f in zip(image_size, latent_scale_factor)],
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
        if hasattr(self.scheduler, "init_noise_sigma"):
            # scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma

        return latents

    @property
    def guidance_scale(self):
        """Guidance scale of the current / most recent call."""
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        """Guidance rescale factor of the current / most recent call."""
        return self._guidance_rescale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        """Whether classifier-free guidance is active, i.e. `guidance_scale > 1`."""
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        """Number of timesteps of the current / most recent call."""
        return self._num_timesteps

    def set_scheduler(self, new_scheduler):
        """Replace the registered scheduler with `new_scheduler`."""
        self.register_modules(scheduler=new_scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int,
        image_size: List[int],
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 7.5,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        guidance_rescale: float = 0.0,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        model_kwargs: Dict[str, Any] = None,
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`):
                Number of samples to generate. The latents are duplicated along the batch dimension when
                classifier-free guidance is enabled.
            image_size (`Tuple[int]` or `List[int]`):
                The size (height, width) of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate samples closely linked to the
                `condition` at the expense of lower sample quality. Guidance scale is enabled when `guidance_scale > 1`.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for sample
                generation. Can be used to tweak the same generation with different conditions. If not provided,
                a latents tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`HunyuanImage3Text2ImagePipelineOutput`] instead of a
                plain tuple.
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`. The callback may return a
                dict; a `"latents"` entry in it replaces the current latents. A return value of `None` is ignored.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            model_kwargs (`Dict[str, Any]`):
                Keyword arguments for the model. Must contain `"input_ids"`, which is removed from the dict by this
                call; an `"attention_mask"` entry is added. The remaining entries are forwarded to
                `model.prepare_inputs_for_generation`.
            **kwargs:
                `callback_steps` and `pbar_steps` are accepted and ignored.

        Examples:

        Returns:
            [`HunyuanImage3Text2ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`HunyuanImage3Text2ImagePipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated samples.

        Raises:
            ValueError: If `model_kwargs` is not provided.
        """

        # Accepted for compatibility; neither value is used below.
        callback_steps = kwargs.pop("callback_steps", None)
        pbar_steps = kwargs.pop("pbar_steps", None)

        # Fail early with a clear message instead of an AttributeError on `None.pop` further down.
        if model_kwargs is None:
            raise ValueError("`model_kwargs` must be provided and must contain `input_ids`.")

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale

        # 2 when classifier-free guidance is on (cond + uncond batch), otherwise 1.
        cfg_factor = 1 + self.do_classifier_free_guidance

        # Define call parameters
        device = self._execution_device

        # Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas,
        )

        # Prepare latent variables
        latents = self.prepare_latents(
            batch_size=batch_size,
            latent_channel=self.model.config.vae["latent_channels"],
            image_size=image_size,
            dtype=torch.bfloat16,
            device=device,
            generator=generator,
            latents=latents,
        )

        # Prepare extra step kwargs.
        _scheduler_step_extra_kwargs = self.prepare_extra_func_kwargs(
            self.scheduler.step, {"generator": generator}
        )

        # Prepare model kwargs
        input_ids = model_kwargs.pop("input_ids")
        attention_mask = self.model._prepare_attention_mask_for_generation(     # noqa
            input_ids, self.model.generation_config, model_kwargs=model_kwargs,
        )
        model_kwargs["attention_mask"] = attention_mask.to(latents.device)

        # Sampling loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * cfg_factor)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                t_expand = t.repeat(latent_model_input.shape[0])

                model_inputs = self.model.prepare_inputs_for_generation(
                    input_ids,
                    images=latent_model_input,
                    timestep=t_expand,
                    **model_kwargs,
                )

                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    model_output = self.model(**model_inputs, first_step=(i == 0))
                    pred = model_output["diffusion_prediction"]
                pred = pred.to(dtype=torch.float32)

                # perform guidance
                if self.do_classifier_free_guidance:
                    pred_cond, pred_uncond = pred.chunk(2)
                    pred = self.cfg_operator(pred_cond, pred_uncond, self.guidance_scale, step=i)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    pred = rescale_noise_cfg(pred, pred_cond, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(pred, t, latents, **_scheduler_step_extra_kwargs, return_dict=False)[0]

                if i != len(timesteps) - 1:
                    model_kwargs = self.model._update_model_kwargs_for_generation(  # noqa
                        model_output,
                        model_kwargs,
                    )
                    # Re-index `input_ids` with `position_ids` when their sequence lengths differ.
                    if input_ids.shape[1] != model_kwargs["position_ids"].shape[1]:
                        input_ids = torch.gather(input_ids, 1, index=model_kwargs["position_ids"])

                if callback_on_step_end is not None:
                    # Keep this an explicit loop: `locals()` must be evaluated in this function's scope.
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # A plain callable may return nothing; only dict-like results can override the latents.
                    if callback_outputs is not None:
                        latents = callback_outputs.pop("latents", latents)

                # advance the progress bar once per completed scheduler step
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # Undo the VAE latent normalisation when the VAE config defines (non-zero) factors.
        if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor:
            latents = latents / self.vae.config.scaling_factor
        if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
            latents = latents + self.vae.config.shift_factor

        # VAEs exposing `ffactor_temporal` expect an extra dimension at index 2.
        if hasattr(self.vae, "ffactor_temporal"):
            latents = latents.unsqueeze(2)

        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
            image = self.vae.decode(latents, return_dict=False, generator=generator)[0]

        # b c t h w
        if hasattr(self.vae, "ffactor_temporal"):
            assert image.shape[2] == 1, "image should have shape [B, C, T, H, W] and T should be 1"
            image = image.squeeze(2)

        do_denormalize = [True] * image.shape[0]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image,)

        return HunyuanImage3Text2ImagePipelineOutput(samples=image)
880