hunyuan_image_3_pipeline.py
| 1 | # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); |
| 2 | # you may not use this file except in compliance with the License. |
| 3 | # You may obtain a copy of the License at |
| 4 | # |
| 5 | # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE |
| 6 | # |
| 7 | # Unless required by applicable law or agreed to in writing, software |
| 8 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 10 | # See the License for the specific language governing permissions and |
| 11 | # limitations under the License. |
| 12 | # ============================================================================== |
| 13 | # |
| 14 | # Copyright 2024 The HuggingFace Team. All rights reserved. |
| 15 | # |
| 16 | # Licensed under the Apache License, Version 2.0 (the "License"); |
| 17 | # you may not use this file except in compliance with the License. |
| 18 | # You may obtain a copy of the License at |
| 19 | # |
| 20 | # http://www.apache.org/licenses/LICENSE-2.0 |
| 21 | # |
| 22 | # Unless required by applicable law or agreed to in writing, software |
| 23 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 24 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 25 | # See the License for the specific language governing permissions and |
| 26 | # limitations under the License. |
| 27 | # ============================================================================================== |
| 28 | |
| 29 | import inspect |
| 30 | import math |
| 31 | from dataclasses import dataclass |
| 32 | from typing import Any, Callable, Dict, List |
| 33 | from typing import Optional, Tuple, Union |
| 34 | |
| 35 | import numpy as np |
| 36 | import torch |
| 37 | from PIL import Image |
| 38 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback |
| 39 | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| 40 | from diffusers.image_processor import VaeImageProcessor |
| 41 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline |
| 42 | from diffusers.schedulers.scheduling_utils import SchedulerMixin |
| 43 | from diffusers.utils import BaseOutput, logging |
| 44 | from diffusers.utils.torch_utils import randn_tensor |
| 45 | |
| 46 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
| 47 | |
| 48 | |
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` via its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Extra `kwargs` are forwarded to `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of steps to request when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after the call, and the number of inference steps
        (the length of the schedule when a custom schedule was given, otherwise `num_inference_steps`).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler build its own schedule from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: make sure the scheduler actually exposes the matching keyword.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
| 106 | |
| 107 | |
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Blend the guided prediction with a copy rescaled to the standard deviation of the text-conditioned one.

    Follows Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf), where this is proposed as a remedy for overexposure.

    Args:
        noise_cfg (`torch.Tensor`):
            The guided prediction.
        noise_pred_text (`torch.Tensor`):
            The text-conditioned prediction whose per-sample standard deviation is the target.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Mixing weight: 0 returns `noise_cfg` unchanged, 1 returns the fully rescaled prediction.
    Returns:
        noise_cfg (`torch.Tensor`): The blended prediction.
    """
    # Standard deviation per batch element, reduced over every dimension except the first.
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))
    std_text = noise_pred_text.std(dim=text_dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=cfg_dims, keepdim=True)

    # Bring the guided prediction back to the text prediction's scale.
    rescaled = noise_cfg * (std_text / std_cfg)

    # Mix with the unscaled guided prediction so images do not end up "plain looking".
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
| 131 | |
| 132 | |
@dataclass
class HunyuanImage3Text2ImagePipelineOutput(BaseOutput):
    """
    Output container with a single `samples` field.

    NOTE(review): presumably this is what `HunyuanImage3Text2ImagePipeline.__call__` returns when
    `return_dict=True` — confirm against the end of `__call__`.
    """

    # Generated samples: either a list of arbitrary objects or a NumPy array.
    samples: Union[List[Any], np.ndarray]
| 136 | |
| 137 | |
@dataclass
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for `FlowMatchDiscreteScheduler.step`.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample produced by the scheduler step, to be used as the next model input in the denoising
            loop. For multi-stage solvers this is an intermediate state until the final stage of a step.
    """

    # Computed in `step` as `sample + derivative * dt`.
    prev_sample: torch.FloatTensor
| 150 | |
| 151 | |
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Flow-matching scheduler that integrates the model output along a sigma schedule.

    Sigmas run linearly from 1 to 0 (optionally shifted), and the timestep handed to the model is
    `sigma * num_train_timesteps`. Multi-stage solvers ("heun-2", "midpoint-2", "kutta-4") require `step` to be
    called once per stage; the step index only advances after the final stage of a step.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model. Also the factor that maps a sigma to a timestep.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule (applied by `sd3_time_shift` when it differs from 1).
        reverse (`bool`, defaults to `True`):
            When `False`, the schedule is flipped so that sigmas run from 0 to 1.
        solver (`str`, defaults to `"euler"`):
            One of `"euler"`, `"heun-2"`, `"midpoint-2"`, `"kutta-4"`.
        use_flux_shift (`bool`, defaults to `False`):
            Use the token-count dependent shift (`flux_time_shift`) instead of `shift`.
        flux_base_shift (`float`, defaults to 0.5):
            Shift parameter at 256 tokens for the flux shift.
        flux_max_shift (`float`, defaults to 1.15):
            Shift parameter at 4096 tokens for the flux shift.
        n_tokens (`int`, *optional*):
            Registered in the config only; the token count actually used is the one passed to `set_timesteps`.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        reverse: bool = True,
        solver: str = "euler",
        use_flux_shift: bool = False,
        flux_base_shift: float = 0.5,
        flux_max_shift: float = 1.15,
        n_tokens: Optional[int] = None,
    ):
        sigmas = torch.linspace(1, 0, num_train_timesteps + 1)

        if not reverse:
            sigmas = sigmas.flip(0)

        self.sigmas = sigmas
        # the value fed to model
        self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
        self.timesteps_full = (sigmas * num_train_timesteps).to(dtype=torch.float32)

        self._step_index = None
        self._begin_index = None

        self.supported_solver = [
            "euler",
            "heun-2", "midpoint-2",
            "kutta-4",
        ]
        if solver not in self.supported_solver:
            raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}")

        # Per-step state of the multi-stage solvers: stored stage derivatives, the step size and the sample
        # the step started from. All `None` means "no step in progress".
        self.derivative_1 = None
        self.derivative_2 = None
        self.derivative_3 = None
        self.dt = None
        self.sample = None

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        """Map a sigma to the timestep value fed to the model."""
        return sigma * self.config.num_train_timesteps

    @property
    def state_in_first_order(self):
        """`True` while no first-stage derivative is stored, i.e. the next `step` call starts a new step."""
        return self.derivative_1 is None

    @property
    def state_in_second_order(self):
        """`True` while no second-stage derivative is stored (only meaningful for "kutta-4")."""
        return self.derivative_2 is None

    @property
    def state_in_third_order(self):
        """`True` while no third-stage derivative is stored (only meaningful for "kutta-4")."""
        return self.derivative_3 is None

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None,
                      n_tokens: int = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            n_tokens (`int`, *optional*):
                Number of tokens in the input sequence. Required when `use_flux_shift` is enabled.
        """
        self.num_inference_steps = num_inference_steps

        sigmas = torch.linspace(1, 0, num_inference_steps + 1)

        # Apply timestep shift
        if self.config.use_flux_shift:
            assert isinstance(n_tokens, int), "n_tokens should be provided for flux shift"
            mu = self.get_lin_function(y1=self.config.flux_base_shift, y2=self.config.flux_max_shift)(n_tokens)
            sigmas = self.flux_time_shift(mu, 1.0, sigmas)
        elif self.config.shift != 1.:
            sigmas = self.sd3_time_shift(sigmas)

        if not self.config.reverse:
            sigmas = 1 - sigmas

        self.sigmas = sigmas
        self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
        self.timesteps_full = (sigmas * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)

        # Drop any half-finished multi-stage step left over from a previous run.
        self.derivative_1 = None
        self.derivative_2 = None
        self.derivative_3 = None
        self.dt = None
        self.sample = None

        # Reset step index
        self._step_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the position of `timestep` in `schedule_timesteps` (defaults to `self.timesteps`)."""
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise the step index from `begin_index` if set, otherwise by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """Return `sample` unchanged; this scheduler applies no input scaling."""
        return sample

    @staticmethod
    def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
        """Return the line through `(x1, y1)` and `(x2, y2)` as a callable of `x`."""
        m = (y2 - y1) / (x2 - x1)
        b = y1 - m * x1
        return lambda x: m * x + b

    @staticmethod
    def flux_time_shift(mu: float, sigma: float, t: torch.Tensor):
        """Apply the shift `exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)` elementwise to `t`."""
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def sd3_time_shift(self, t: torch.Tensor):
        """Apply the shift `shift * t / (1 + (shift - 1) * t)` elementwise to `t`."""
        return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        pred_uncond: torch.FloatTensor = None,
        generator: Optional[torch.Generator] = None,
        n_tokens: Optional[int] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
        """
        Advance the sample by one solver stage using the model output as the derivative.

        For "euler" every call completes a step. For the multi-stage solvers a step takes several calls; the
        returned sample is an intermediate state until the last stage, and the step index only advances then.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain. Only used to initialise the step index
                on the first call.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            pred_uncond (`torch.FloatTensor`, *optional*):
                Accepted for interface compatibility; upcast but not otherwise used here.
            generator (`torch.Generator`, *optional*):
                A random number generator. Not used by this scheduler.
            n_tokens (`int`, *optional*):
                Number of tokens in the input sequence. Not used by this method.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is returned, otherwise a tuple
                is returned where the first element is the sample tensor (in float32).

        Raises:
            ValueError: If `timestep` is an integer index, or the configured solver is not supported.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        model_output = model_output.to(torch.float32)
        pred_uncond = pred_uncond.to(torch.float32) if pred_uncond is not None else None

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        if self.config.solver == "euler":
            derivative, dt, sample, last_inner_step = self.first_order_method(model_output, sigma, sigma_next, sample)
        elif self.config.solver in ["heun-2", "midpoint-2"]:
            derivative, dt, sample, last_inner_step = self.second_order_method(model_output, sigma, sigma_next, sample)
        elif self.config.solver == "kutta-4":
            derivative, dt, sample, last_inner_step = self.fourth_order_method(model_output, sigma, sigma_next, sample)
        else:
            raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")

        prev_sample = sample + derivative * dt

        # The result is deliberately left in float32 (not cast back to the model dtype).

        # upon completion increase step index by one
        if last_inner_step:
            self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)

    def first_order_method(self, model_output, sigma, sigma_next, sample):
        """Euler: a single stage using the model output over the full sigma interval."""
        derivative = model_output
        dt = sigma_next - sigma
        return derivative, dt, sample, True

    def second_order_method(self, model_output, sigma, sigma_next, sample):
        """
        Two-stage solvers ("heun-2" and "midpoint-2").

        Stage 1 stores the start-of-step sample, derivative and step size, and moves to the probe point (full
        step for Heun, half step for midpoint). Stage 2 restarts from the stored sample with the combined
        derivative and clears the stored state.
        """
        if self.state_in_first_order:
            # store for 2nd order step
            self.derivative_1 = model_output
            self.dt = sigma_next - sigma
            self.sample = sample

            derivative = model_output
            if self.config.solver == 'heun-2':
                dt = self.dt
            elif self.config.solver == 'midpoint-2':
                dt = self.dt / 2
            else:
                raise NotImplementedError(f"Solver {self.config.solver} not supported.")
            last_inner_step = False

        else:
            if self.config.solver == 'heun-2':
                derivative = 0.5 * (self.derivative_1 + model_output)
            elif self.config.solver == 'midpoint-2':
                derivative = model_output
            else:
                raise NotImplementedError(f"Solver {self.config.solver} not supported.")

            # take the stored step size and start-of-step sample
            dt = self.dt
            sample = self.sample
            last_inner_step = True

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.derivative_1 = None
            self.dt = None
            self.sample = None

        return derivative, dt, sample, last_inner_step

    def fourth_order_method(self, model_output, sigma, sigma_next, sample):
        """
        Classical four-stage Runge-Kutta ("kutta-4").

        Every stage advances from the sample stored at the start of the step:
        `x + dt/2 * k1`, `x + dt/2 * k2`, `x + dt * k3`, and finally
        `x + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6`. The `sample` argument is only used on the first stage; on
        later stages it is the previous probe point and must not be used as the base of the update.
        """
        if self.state_in_first_order:
            # Stage 1: remember the start of the step and probe half a step ahead along k1.
            self.derivative_1 = model_output
            self.dt = sigma_next - sigma
            self.sample = sample
            return model_output, self.dt / 2, sample, False

        if self.state_in_second_order:
            # Stage 2: probe half a step ahead along k2, starting again from the stored sample.
            self.derivative_2 = model_output
            return model_output, self.dt / 2, self.sample, False

        if self.state_in_third_order:
            # Stage 3: probe a full step ahead along k3, starting again from the stored sample.
            self.derivative_3 = model_output
            return model_output, self.dt, self.sample, False

        # Stage 4: weighted combination of the four stage derivatives over the full step.
        derivative = (1/6 * self.derivative_1 + 1/3 * self.derivative_2 + 1/3 * self.derivative_3 +
                      1/6 * model_output)
        dt = self.dt
        sample = self.sample

        # free dt and derivative
        # Note, this puts the scheduler in "first order mode"
        self.derivative_1 = None
        self.derivative_2 = None
        self.derivative_3 = None
        self.dt = None
        self.sample = None

        return derivative, dt, sample, True

    def __len__(self):
        return self.config.num_train_timesteps
| 500 | |
| 501 | |
class ClassifierFreeGuidance:
    """
    Combine conditional and unconditional predictions with classifier-free guidance.

    With `shift = pred_cond - pred_uncond` the result is `pred_uncond + guidance_scale * shift` by default, or
    `pred_cond + guidance_scale * shift` when `use_original_formulation` is `True`.

    Args:
        use_original_formulation (`bool`, defaults to `False`):
            Start from the conditional prediction instead of the unconditional one.
        start (`float`, defaults to 0.0):
            Stored on the instance; not applied by `__call__`.
        stop (`float`, defaults to 1.0):
            Stored on the instance; not applied by `__call__`.
    """

    def __init__(
            self,
            use_original_formulation: bool = False,
            start: float = 0.0,
            stop: float = 1.0,
    ):
        super().__init__()
        self.use_original_formulation = use_original_formulation
        # Keep the constructor arguments inspectable instead of silently dropping them.
        self.start = start
        self.stop = stop

    def __call__(
            self,
            pred_cond: torch.Tensor,
            pred_uncond: Optional[torch.Tensor],
            guidance_scale: float,
            step: int,
    ) -> torch.Tensor:
        """
        Return the guided prediction.

        Args:
            pred_cond (`torch.Tensor`): Conditional prediction.
            pred_uncond (`torch.Tensor`, *optional*): Unconditional prediction. When `None`, there is nothing
                to guide against and `pred_cond` is returned unchanged.
            guidance_scale (`float`): Weight applied to `pred_cond - pred_uncond`.
            step (`int`): Current step; accepted but not used.
        """
        if pred_uncond is None:
            return pred_cond

        shift = pred_cond - pred_uncond
        base = pred_cond if self.use_original_formulation else pred_uncond
        return base + guidance_scale * shift
| 525 | |
| 526 | |
| 527 | class HunyuanImage3Text2ImagePipeline(DiffusionPipeline): |
| 528 | r""" |
| 529 | Pipeline for condition-to-sample generation using Stable Diffusion. |
| 530 | |
| 531 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods |
| 532 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). |
| 533 | |
| 534 | Args: |
| 535 | model ([`ModelMixin`]): |
| 536 | A model to denoise the diffused latents. |
| 537 | scheduler ([`SchedulerMixin`]): |
| 538 | A scheduler to be used in combination with `diffusion_model` to denoise the diffused latents. Can be one of |
| 539 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. |
| 540 | """ |
| 541 | |
| 542 | model_cpu_offload_seq = "" |
| 543 | _optional_components = [] |
| 544 | _exclude_from_cpu_offload = [] |
| 545 | _callback_tensor_inputs = ["latents"] |
| 546 | |
def __init__(
    self,
    model,
    scheduler: SchedulerMixin,
    vae,
    progress_bar_config: Dict[str, Any] = None,
):
    """
    Register the pipeline components and set up the image processor and guidance operator.

    Args:
        model: Denoising model; its `config.vae_downsample_factor` is read as the latent scale factor.
        scheduler (`SchedulerMixin`): Scheduler registered on the pipeline.
        vae: VAE registered on the pipeline.
        progress_bar_config (`Dict[str, Any]`, *optional*): Options merged into `self._progress_bar_config`.
    """
    super().__init__()

    # Merge caller-supplied progress-bar options into the existing config, creating it if needed.
    if not hasattr(self, '_progress_bar_config'):
        self._progress_bar_config = {}
    if progress_bar_config is not None:
        self._progress_bar_config.update(progress_bar_config)

    self.register_modules(model=model, scheduler=scheduler, vae=vae)

    # Ratio between image size and latent size; `prepare_latents` accepts None, an int, or one value per
    # spatial dimension.
    downsample_factor = self.model.config.vae_downsample_factor
    self.latent_scale_factor = downsample_factor
    self.image_processor = VaeImageProcessor(vae_scale_factor=downsample_factor)

    # Operator that combines conditional and unconditional predictions.
    self.cfg_operator = ClassifierFreeGuidance()
| 577 | |
@staticmethod
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    """
    Map an image array from [-1, 1] to [0, 1], clipping values that fall outside the range.

    Args:
        images (`np.ndarray` or `torch.Tensor`): Image values, nominally in [-1, 1].

    Returns:
        An object of the same kind as the input with values `images / 2 + 0.5` limited to [0, 1].
    """
    shifted = images / 2 + 0.5
    # NumPy arrays have no `.clamp`; use `np.clip` so both annotated input types work.
    if isinstance(shifted, np.ndarray):
        return np.clip(shifted, 0, 1)
    return shifted.clamp(0, 1)
| 584 | |
@staticmethod
def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
    """
    Convert a 4-D PyTorch tensor to a float32 NumPy array with dimension 1 moved to the end.

    The tensor is moved to the CPU and permuted with `(0, 2, 3, 1)`.
    """
    channels_last = images.cpu().permute(0, 2, 3, 1)
    return channels_last.float().numpy()
| 592 | |
@staticmethod
def numpy_to_pil(images: np.ndarray):
    """
    Convert a NumPy image, or a batch of images, to a list of PIL images.

    A 3-D input is treated as a single image and given a leading batch dimension. Values are multiplied by
    255, rounded and cast to `uint8`. When the last dimension has size 1 the images are created in mode "L".
    """
    batch = images[None, ...] if images.ndim == 3 else images
    batch = (batch * 255).round().astype("uint8")

    # special case for grayscale (single channel) images
    single_channel = batch.shape[-1] == 1

    pil_images = []
    for frame in batch:
        if single_channel:
            pil_images.append(Image.fromarray(frame.squeeze(), mode="L"))
        else:
            pil_images.append(Image.fromarray(frame))
    return pil_images
| 608 | |
def prepare_extra_func_kwargs(self, func, kwargs):
    """
    Keep only the entries of `kwargs` whose key is a parameter name of `func`.

    Useful because callables such as scheduler `step` methods do not all share the same signature.

    Args:
        func: Callable whose signature decides which keys are kept.
        kwargs (`dict`): Candidate keyword arguments.

    Returns:
        `dict`: A new dict with the accepted key/value pairs, in the original order.
    """
    # Nothing to filter, so (as before) the signature is never inspected.
    if not kwargs:
        return {}

    # Inspect the signature once rather than once per key.
    accepted = inspect.signature(func).parameters
    return {name: value for name, value in kwargs.items() if name in accepted}
| 621 | |
def prepare_latents(self, batch_size, latent_channel, image_size, dtype, device, generator, latents=None):
    """
    Create, or move to `device`, the initial latents for sampling.

    The latent shape is `(batch_size, latent_channel, *spatial)`, where each spatial size is the matching
    `image_size` entry floor-divided by the latent scale factor. `self.latent_scale_factor` may be `None`
    (treated as 1), an int (applied to every dimension), or a tuple/list with one entry per dimension.

    Args:
        batch_size: Number of samples.
        latent_channel: Number of latent channels.
        image_size: Sizes of the generated image, one per spatial dimension.
        dtype: Dtype of freshly sampled latents.
        device: Device for the latents.
        generator: A generator or a list of generators; a list must have `batch_size` entries.
        latents: Optional pre-made latents. They are moved to `device` and not reshaped.

    Returns:
        The latents, multiplied by `scheduler.init_noise_sigma` when the scheduler defines it.

    Raises:
        ValueError: If the latent scale factor has an unsupported type, or the generator list length does not
            match `batch_size`.
    """
    scale = self.latent_scale_factor
    if scale is None:
        per_dim_scale = (1,) * len(image_size)
    elif isinstance(scale, int):
        per_dim_scale = (scale,) * len(image_size)
    elif isinstance(scale, (tuple, list)):
        assert len(scale) == len(image_size), \
            "len(latent_scale_factor) shoudl be the same as len(image_size)"
        per_dim_scale = scale
    else:
        raise ValueError(
            f"latent_scale_factor should be either None, int, tuple of int, or list of int, "
            f"but got {self.latent_scale_factor}"
        )

    spatial = [int(size) // factor for size, factor in zip(image_size, per_dim_scale)]
    latents_shape = (batch_size, latent_channel, *spatial)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)

    # Not every scheduler defines `init_noise_sigma` (e.g. FlowMatchEulerDiscreteScheduler), so check first.
    if hasattr(self.scheduler, "init_noise_sigma"):
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

    return latents
| 659 | |
@property
def guidance_scale(self):
    """
    Return the value stored in `self._guidance_scale`.

    NOTE(review): `_guidance_scale` is not assigned in the visible part of this class; presumably `__call__`
    sets it from its `guidance_scale` argument — confirm.
    """
    return self._guidance_scale
| 663 | |
@property
def guidance_rescale(self):
    """
    Return the value stored in `self._guidance_rescale`.

    NOTE(review): `_guidance_rescale` is not assigned in the visible part of this class; presumably `__call__`
    sets it from its `guidance_rescale` argument — confirm.
    """
    return self._guidance_rescale
| 667 | |
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """Return `True` when `self._guidance_scale` is strictly greater than 1.0."""
    return self._guidance_scale > 1.0
| 674 | |
@property
def num_timesteps(self):
    """
    Return the value stored in `self._num_timesteps`.

    NOTE(review): `_num_timesteps` is not assigned in the visible part of this class; presumably `__call__`
    sets it to the length of the timestep schedule — confirm.
    """
    return self._num_timesteps
| 678 | |
def set_scheduler(self, new_scheduler):
    """Replace the pipeline's scheduler by registering `new_scheduler` under the `scheduler` module name."""
    self.register_modules(scheduler=new_scheduler)
| 681 | |
| 682 | @torch.no_grad() |
| 683 | def __call__( |
| 684 | self, |
| 685 | batch_size: int, |
| 686 | image_size: List[int], |
| 687 | num_inference_steps: int = 50, |
| 688 | timesteps: List[int] = None, |
| 689 | sigmas: List[float] = None, |
| 690 | guidance_scale: float = 7.5, |
| 691 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
| 692 | latents: Optional[torch.Tensor] = None, |
| 693 | output_type: Optional[str] = "pil", |
| 694 | return_dict: bool = True, |
| 695 | guidance_rescale: float = 0.0, |
| 696 | callback_on_step_end: Optional[ |
| 697 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] |
| 698 | ] = None, |
| 699 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], |
| 700 | model_kwargs: Dict[str, Any] = None, |
| 701 | **kwargs, |
| 702 | ): |
| 703 | r""" |
| 704 | The call function to the pipeline for generation. |
| 705 | |
| 706 | Args: |
| 707 | prompt (`str` or `List[str]`): |
| 708 | The text to guide image generation. |
| 709 | image_size (`Tuple[int]` or `List[int]`): |
| 710 | The size (height, width) of the generated image. |
| 711 | num_inference_steps (`int`, *optional*, defaults to 50): |
| 712 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
| 713 | expense of slower inference. |
| 714 | timesteps (`List[int]`, *optional*): |
| 715 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument |
| 716 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is |
| 717 | passed will be used. Must be in descending order. |
| 718 | sigmas (`List[float]`, *optional*): |
| 719 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in |
| 720 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed |
| 721 | will be used. |
| 722 | guidance_scale (`float`, *optional*, defaults to 7.5): |
| 723 | A higher guidance scale value encourages the model to generate samples closely linked to the |
| 724 | `condition` at the expense of lower sample quality. Guidance scale is enabled when `guidance_scale > 1`. |
| 725 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): |
| 726 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make |
| 727 | generation deterministic. |
| 728 | latents (`torch.Tensor`, *optional*): |
| 729 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for sample |
| 730 | generation. Can be used to tweak the same generation with different conditions. If not provided, |
| 731 | a latents tensor is generated by sampling using the supplied random `generator`. |
| 732 | output_type (`str`, *optional*, defaults to `"pil"`): |
| 733 | The output format of the generated sample. |
| 734 | return_dict (`bool`, *optional*, defaults to `True`): |
| 735 | Whether or not to return a [`HunyuanImage3Text2ImagePipelineOutput`] instead of a |
| 736 | plain tuple. |
| 737 | guidance_rescale (`float`, *optional*, defaults to 0.0): |
| 738 | Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are |
| 739 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when |
| 740 | using zero terminal SNR. |
| 741 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): |
| 742 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of |
| 743 | each denoising step during inference, with the following arguments: `callback_on_step_end(self: |
| 744 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a |
| 745 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. |
| 746 | callback_on_step_end_tensor_inputs (`List`, *optional*): |
| 747 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list |
| 748 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the |
| 749 | `._callback_tensor_inputs` attribute of your pipeline class. |
| 750 | |
| 751 | Examples: |
| 752 | |
| 753 | Returns: |
| 754 | [`HunyuanImage3Text2ImagePipelineOutput`] or `tuple`: |
| 755 | If `return_dict` is `True`, a [`HunyuanImage3Text2ImagePipelineOutput`] is returned, |
| 756 | otherwise a `tuple` is returned where the first element is a list with the generated samples. |
| 757 | """ |
| 758 | |
| 759 | callback_steps = kwargs.pop("callback_steps", None) |
| 760 | pbar_steps = kwargs.pop("pbar_steps", None) |
| 761 | |
| 762 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): |
| 763 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs |
| 764 | |
| 765 | self._guidance_scale = guidance_scale |
| 766 | self._guidance_rescale = guidance_rescale |
| 767 | |
| 768 | cfg_factor = 1 + self.do_classifier_free_guidance |
| 769 | |
| 770 | # Define call parameters |
| 771 | device = self._execution_device |
| 772 | |
| 773 | # Prepare timesteps |
| 774 | timesteps, num_inference_steps = retrieve_timesteps( |
| 775 | self.scheduler, num_inference_steps, device, timesteps, sigmas, |
| 776 | ) |
| 777 | |
| 778 | # Prepare latent variables |
| 779 | latents = self.prepare_latents( |
| 780 | batch_size=batch_size, |
| 781 | latent_channel=self.model.config.vae["latent_channels"], |
| 782 | image_size=image_size, |
| 783 | dtype=torch.bfloat16, |
| 784 | device=device, |
| 785 | generator=generator, |
| 786 | latents=latents, |
| 787 | ) |
| 788 | |
| 789 | # Prepare extra step kwargs. |
| 790 | _scheduler_step_extra_kwargs = self.prepare_extra_func_kwargs( |
| 791 | self.scheduler.step, {"generator": generator} |
| 792 | ) |
| 793 | |
| 794 | # Prepare model kwargs |
| 795 | input_ids = model_kwargs.pop("input_ids") |
| 796 | attention_mask = self.model._prepare_attention_mask_for_generation( # noqa |
| 797 | input_ids, self.model.generation_config, model_kwargs=model_kwargs, |
| 798 | ) |
| 799 | model_kwargs["attention_mask"] = attention_mask.to(latents.device) |
| 800 | |
| 801 | # Sampling loop |
| 802 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
| 803 | self._num_timesteps = len(timesteps) |
| 804 | |
| 805 | with self.progress_bar(total=num_inference_steps) as progress_bar: |
| 806 | for i, t in enumerate(timesteps): |
| 807 | # expand the latents if we are doing classifier free guidance |
| 808 | latent_model_input = torch.cat([latents] * cfg_factor) |
| 809 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
| 810 | |
| 811 | t_expand = t.repeat(latent_model_input.shape[0]) |
| 812 | |
| 813 | model_inputs = self.model.prepare_inputs_for_generation( |
| 814 | input_ids, |
| 815 | images=latent_model_input, |
| 816 | timestep=t_expand, |
| 817 | **model_kwargs, |
| 818 | ) |
| 819 | |
| 820 | with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): |
| 821 | model_output = self.model(**model_inputs, first_step=(i == 0)) |
| 822 | pred = model_output["diffusion_prediction"] |
| 823 | pred = pred.to(dtype=torch.float32) |
| 824 | |
| 825 | # perform guidance |
| 826 | if self.do_classifier_free_guidance: |
| 827 | pred_cond, pred_uncond = pred.chunk(2) |
| 828 | pred = self.cfg_operator(pred_cond, pred_uncond, self.guidance_scale, step=i) |
| 829 | |
| 830 | if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: |
| 831 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf |
| 832 | pred = rescale_noise_cfg(pred, pred_cond, guidance_rescale=self.guidance_rescale) |
| 833 | |
| 834 | # compute the previous noisy sample x_t -> x_t-1 |
| 835 | latents = self.scheduler.step(pred, t, latents, **_scheduler_step_extra_kwargs, return_dict=False)[0] |
| 836 | |
| 837 | if i != len(timesteps) - 1: |
| 838 | model_kwargs = self.model._update_model_kwargs_for_generation( # noqa |
| 839 | model_output, |
| 840 | model_kwargs, |
| 841 | ) |
| 842 | if input_ids.shape[1] != model_kwargs["position_ids"].shape[1]: |
| 843 | input_ids = torch.gather(input_ids, 1, index=model_kwargs["position_ids"]) |
| 844 | |
| 845 | if callback_on_step_end is not None: |
| 846 | callback_kwargs = {} |
| 847 | for k in callback_on_step_end_tensor_inputs: |
| 848 | callback_kwargs[k] = locals()[k] |
| 849 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) |
| 850 | |
| 851 | latents = callback_outputs.pop("latents", latents) |
| 852 | |
| 853 | # update the progress bar on the last step, or after warmup on every `scheduler.order`-th step |
| 854 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
| 855 | progress_bar.update() |
| 856 | |
| 857 | if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor: |
| 858 | latents = latents / self.vae.config.scaling_factor |
| 859 | if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor: |
| 860 | latents = latents + self.vae.config.shift_factor |
| 861 | |
| 862 | if hasattr(self.vae, "ffactor_temporal"): |
| 863 | latents = latents.unsqueeze(2) |
| 864 | |
| 865 | with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True): |
| 866 | image = self.vae.decode(latents, return_dict=False, generator=generator)[0] |
| 867 | |
| 868 | # b c t h w |
| 869 | if hasattr(self.vae, "ffactor_temporal"): |
| 870 | assert image.shape[2] == 1, "image should have shape [B, C, T, H, W] and T should be 1" |
| 871 | image = image.squeeze(2) |
| 872 | |
| 873 | do_denormalize = [True] * image.shape[0] |
| 874 | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) |
| 875 | |
| 876 | if not return_dict: |
| 877 | return (image,) |
| 878 | |
| 879 | return HunyuanImage3Text2ImagePipelineOutput(samples=image) |
| 880 | |