image_processor.py
5.9 KB · 126 lines · python Raw
1 # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
2 # you may not use this file except in compliance with the License.
3 # You may obtain a copy of the License at
4 #
5 # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
6 #
7 # Unless required by applicable law or agreed to in writing, software
8 # distributed under the License is distributed on an "AS IS" BASIS,
9 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 # See the License for the specific language governing permissions and
11 # limitations under the License.
12 # ==============================================================================
13
14 from typing import Tuple
15
16 from PIL import Image
17 from torchvision import transforms
18 from transformers import Siglip2ImageProcessorFast
19
20 from .tokenizer_wrapper import ImageInfo, JointImageInfo, ResolutionGroup
21
22
def resize_and_crop(image: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
    """Scale `image` so it covers `target_size`, then center-crop to exactly that size.

    The image is resized (aspect ratio preserved up to integer rounding) so that
    one side matches the target and the other side is at least as large as the
    target; the overflow on that side is then trimmed equally from both ends.

    Args:
        image: Source PIL image; its `size` is read as `(width, height)`.
        target_size: Desired output size as `(width, height)`.

    Returns:
        The image produced by `resize` (LANCZOS resampling) followed by `crop`.
    """
    target_w, target_h = target_size
    src_w, src_h = image.size

    # Height-over-width ratios: a smaller value means a relatively wider image.
    target_aspect = target_h / target_w
    source_aspect = src_h / src_w

    if source_aspect < target_aspect:
        # Source is relatively wider than the target: match heights, overflow in width.
        scaled_w, scaled_h = int(round(target_h / src_h * src_w)), target_h
    else:
        # Source is relatively taller (or equal): match widths, overflow in height.
        scaled_w, scaled_h = target_w, int(round(target_w / src_w * src_h))

    scaled = image.resize((scaled_w, scaled_h), resample=Image.Resampling.LANCZOS)

    # Center crop: split the overflow evenly between the two sides.
    left = int(round((scaled_w - target_w) / 2.0))
    top = int(round((scaled_h - target_h) / 2.0))
    return scaled.crop((left, top, left + target_w, top + target_h))
46
47
class HunyuanImage3ImageProcessor:
    """Prepare image metadata and tensors for the VAE and vision-encoder branches.

    `build_image_info` describes an image that is to be generated (no pixels),
    while `preprocess` converts an existing PIL image into the inputs of both
    branches.
    """

    def __init__(self, config):
        """Create the resolution group and the two per-branch processors.

        Args:
            config: Configuration object. This class reads `image_base_size`,
                `vit_processor` (a dict, also indexed by "patch_size" and
                "max_num_patches"), `vae_downsample_factor` (indexed at 0 and 1)
                and `patch_size` from it.
        """
        self.config = config

        self.reso_group = ResolutionGroup(base_size=config.image_base_size)
        self.vae_processor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # transform to [-1, 1]
        ])
        self.vision_encoder_processor = Siglip2ImageProcessorFast.from_dict(config.vit_processor)

    def _parse_image_size(self, image_size) -> Tuple[int, int]:
        """Normalize a size spec into a `(height, width)` pair, raising ValueError if malformed."""
        if isinstance(image_size, str):
            if image_size.startswith("<img_ratio_"):
                # "<img_ratio_i>": `i` indexes into the resolution group.
                try:
                    ratio_index = int(image_size.split("_")[-1].rstrip(">"))
                except ValueError as err:
                    raise ValueError(
                        f"`image_size` should be in the format of 'HxW', 'H:W' or <img_ratio_i>, "
                        f"got {image_size}.") from err
                reso = self.reso_group[ratio_index]
                return reso.height, reso.width

            # 'x' takes precedence over ':' when both appear, as before.
            if 'x' in image_size:
                separator = 'x'
            elif ':' in image_size:
                separator = ':'
            else:
                raise ValueError(
                    f"`image_size` should be in the format of 'HxW', 'H:W' or <img_ratio_i>, got {image_size}.")

            try:
                parts = [int(s) for s in image_size.split(separator)]
            except ValueError as err:
                # Re-raise with the expected format instead of a bare "invalid literal" message.
                raise ValueError(
                    f"`image_size` should be in the format of 'HxW', 'H:W' or <img_ratio_i>, "
                    f"got {image_size}.") from err
            # Explicit raise rather than `assert`, which is stripped under `python -O`.
            if len(parts) != 2:
                raise ValueError(f"`image_size` should be in the format of 'HxW', got {image_size}.")
            return parts[0], parts[1]

        if isinstance(image_size, (list, tuple)):
            if len(image_size) != 2 or not all(isinstance(s, int) for s in image_size):
                raise ValueError(
                    f"`image_size` should be a tuple of two integers or a string in the format of 'HxW', "
                    f"got {image_size}.")
            return image_size[0], image_size[1]

        raise ValueError(f"`image_size` should be a tuple of two integers or a string in the format of 'HxW', "
                         f"got {image_size}.")

    def _token_grid(self, image_width, image_height) -> Tuple[int, int]:
        """Return `(token_width, token_height)` for an image of the given pixel size."""
        # Index 0 of the downsample factor is applied to height, index 1 to width.
        token_height = image_height // (self.config.vae_downsample_factor[0] * self.config.patch_size)
        token_width = image_width // (self.config.vae_downsample_factor[1] * self.config.patch_size)
        return token_width, token_height

    def build_image_info(self, image_size) -> "ImageInfo":
        """Describe a to-be-generated image of the requested size.

        Args:
            image_size: One of
                * a string "HxW" or "H:W" (height first),
                * a string "<img_ratio_i>" selecting entry `i` of the resolution group,
                * a list/tuple of two integers `(height, width)`.

        Returns:
            An `ImageInfo` with `image_type="gen_image"` whose pixel size comes from
            `self.reso_group.get_target_size` (presumably the nearest supported
            resolution; confirm against `ResolutionGroup`).

        Raises:
            ValueError: If `image_size` matches none of the accepted formats.
        """
        height, width = self._parse_image_size(image_size)

        # The resolution-group helpers take (width, height), the reverse of the parsed order.
        image_width, image_height = self.reso_group.get_target_size(width, height)
        token_width, token_height = self._token_grid(image_width, image_height)
        base_size, ratio_idx = self.reso_group.get_base_size_and_ratio_index(width, height)
        return ImageInfo(
            image_type="gen_image", image_width=image_width, image_height=image_height,
            token_width=token_width, token_height=token_height, base_size=base_size, ratio_index=ratio_idx,
        )

    def preprocess(self, image: "Image.Image") -> "JointImageInfo":
        """Convert a PIL image into the inputs for both the VAE and the vision encoder.

        Args:
            image: Source PIL image.

        Returns:
            A `JointImageInfo` built from the VAE `ImageInfo`, the ViT `ImageInfo`
            and a dict holding the vision encoder's `pixel_attention_mask` and
            `spatial_shapes`.
        """
        # ==== VAE processor ====
        # The VAE branch sees the image resized and center-cropped to the target size.
        image_width, image_height = self.reso_group.get_target_size(image.width, image.height)
        resized_image = resize_and_crop(image, (image_width, image_height))
        image_tensor = self.vae_processor(resized_image)
        token_width, token_height = self._token_grid(image_width, image_height)
        base_size, ratio_index = self.reso_group.get_base_size_and_ratio_index(width=image_width, height=image_height)
        vae_image_info = ImageInfo(
            image_type="vae",
            image_tensor=image_tensor.unsqueeze(0),     # include batch dim
            image_width=image_width, image_height=image_height,
            token_width=token_width, token_height=token_height,
            base_size=base_size, ratio_index=ratio_index,
        )

        # ==== ViT processor ====
        # The vision encoder processor receives the original, un-cropped image.
        inputs = self.vision_encoder_processor(image)
        pixel_values = inputs["pixel_values"].squeeze(0)                     # seq_len x dim
        pixel_attention_mask = inputs["pixel_attention_mask"].squeeze(0)    # seq_len
        spatial_shapes = inputs["spatial_shapes"].squeeze(0)                # 2 (h, w)
        vision_encoder_kwargs = dict(
            pixel_attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
        )
        vit_token_height = spatial_shapes[0].item()
        vit_token_width = spatial_shapes[1].item()
        vit_patch_size = self.config.vit_processor["patch_size"]
        vision_image_info = ImageInfo(
            image_type="vit",
            image_tensor=pixel_values.unsqueeze(0),   # 1 x seq_len x dim
            image_width=vit_token_width * vit_patch_size,
            image_height=vit_token_height * vit_patch_size,
            token_width=vit_token_width,
            token_height=vit_token_height,
            image_token_length=self.config.vit_processor["max_num_patches"],
            # may not equal to token_width * token_height
        )
        return JointImageInfo(vae_image_info, vision_image_info, vision_encoder_kwargs)
126