processing_phi3_v.py
21.5 KB · 478 lines · python Raw
1 # coding=utf-8
2 # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 """
17 Processor class for Phi3-V.
18 """
19 import re
20 from typing import List, Optional, Union
21
22 import torch
23
24 import transformers
25 from transformers.feature_extraction_utils import BatchFeature
26 from transformers.image_utils import ImageInput
27 from transformers.processing_utils import ProcessorMixin
28 from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy
29 from transformers.utils import TensorType
30
31
32 """Image processor class for Phi3-V."""
33
34 from typing import List, Optional, Union
35
36 import numpy as np
37
38 from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
39 from transformers.image_transforms import (
40 convert_to_rgb,
41 )
42 from transformers.image_utils import (
43 OPENAI_CLIP_MEAN,
44 OPENAI_CLIP_STD,
45 ImageInput,
46 make_list_of_images,
47 valid_images,
48 )
49 from transformers.utils import TensorType, is_vision_available, logging
50
51 from transformers import AutoImageProcessor
52
53 logger = logging.get_logger(__name__)
54
55
56 if is_vision_available():
57 from PIL import Image
58
59 import torch
60 import torchvision
61
def padding_336(b):
    """Pad an image vertically with white so its height becomes a multiple of 336.

    The added rows are split between the top and the bottom edge; when the total
    is odd, the bottom receives the extra row. The width is never changed.

    Args:
        b: image exposing ``.size`` as ``(width, height)`` and accepted by
            ``torchvision.transforms.functional.pad`` (a PIL image in this file).

    Returns:
        The padded image produced by torchvision.
    """
    _, height = b.size
    padded_height = int(np.ceil(height / 336) * 336)
    extra_rows = padded_height - height
    top = int(extra_rows / 2)
    bottom = extra_rows - top
    # torchvision expects [left, top, right, bottom]; horizontal padding stays 0.
    return torchvision.transforms.functional.pad(
        b, [0, top, 0, bottom], fill=[255, 255, 255]
    )
72
def calc_padded_size(width, height, padding_unit=336):
    """Return the ``(width, height)`` an image would have after vertical padding.

    This is the size-only counterpart of ``padding_336``. The height is rounded
    up to the next multiple of ``padding_unit``, with the extra rows split
    between top and bottom. No horizontal padding is applied.

    Args:
        width: original width in pixels.
        height: original height in pixels.
        padding_unit: the height is padded up to a multiple of this value.

    Returns:
        Tuple ``(padded_width, padded_height)``.
    """
    pad_left = pad_right = 0  # padding is vertical only
    vertical_total = int(np.ceil(height / padding_unit) * padding_unit) - height
    pad_top = int(vertical_total / 2)
    pad_bottom = vertical_total - pad_top
    return width + pad_left + pad_right, height + pad_top + pad_bottom
82
def HD_transform(img, hd_num=16):
    """Resize and pad a PIL image so both sides are multiples of 336.

    The image is handled in landscape orientation: portrait inputs are
    transposed first and transposed back at the end. The landscape width
    becomes ``cols * 336``. Here ``cols`` is the largest integer with
    ``cols * ceil(cols / aspect) <= hd_num``, where ``aspect`` is
    width / height. The height follows the aspect ratio and is then
    white-padded by ``padding_336`` to a multiple of 336.

    Args:
        img: PIL image.
        hd_num: budget for the ``cols * ceil(cols / aspect)`` product above.

    Returns:
        The transformed PIL image, in the same orientation as the input.
    """
    w, h = img.size
    portrait = w < h
    if portrait:
        # Work in landscape orientation; undone before returning.
        img = img.transpose(Image.TRANSPOSE)
        w, h = img.size
    aspect = w / h

    # Grow the column count until the budget is exceeded, then step back once.
    cols = 1
    while cols * np.ceil(cols / aspect) <= hd_num:
        cols += 1
    cols -= 1

    target_w = int(cols * 336)
    target_h = int(target_w / aspect)
    resized = torchvision.transforms.functional.resize(img, [target_h, target_w])
    padded = padding_336(resized)
    return padded.transpose(Image.TRANSPOSE) if portrait else padded
105
def calc_hd_transform_size(width, height, hd_num=16):
    """Predict the ``(width, height)`` that ``HD_transform`` produces for an image size.

    This uses the same scale search as ``HD_transform``, applied to plain
    numbers. The longer side becomes ``scale * 336``. The shorter side follows
    the aspect ratio and is padded via ``calc_padded_size``. Portrait inputs get
    their result swapped back to portrait order.

    Args:
        width: original image width in pixels.
        height: original image height in pixels.
        hd_num: crop budget forwarded from the image processor.

    Returns:
        Tuple ``(padded_width, padded_height)``.
    """
    is_portrait = width < height
    long_side, short_side = (height, width) if is_portrait else (width, height)
    aspect = long_side / short_side

    scale = 1
    while scale * np.ceil(scale / aspect) <= hd_num:
        scale += 1
    scale -= 1

    resized_long = int(scale * 336)
    resized_short = int(resized_long / aspect)
    padded_long, padded_short = calc_padded_size(resized_long, resized_short)
    if is_portrait:
        return padded_short, padded_long
    return padded_long, padded_short
127
def pad_to_max_num_crops_tensor(images, max_crops=5):
    """Zero-pad a stack of crops along dim 0 up to ``max_crops`` entries.

    Args:
        images: tensor of shape ``(B, C, H, W)``.
        max_crops: desired size of dim 0.

    Returns:
        ``images`` itself when ``B >= max_crops``. Otherwise a new tensor of
        shape ``(max_crops, C, H, W)``: the first ``B`` entries are ``images``
        and the remainder are zeros with the same dtype and device.
    """
    num_crops, channels, height, width = images.shape
    if num_crops >= max_crops:
        return images
    # Use the real channel count (previously hard-coded to 3, which broke
    # torch.cat for non-3-channel stacks); new_zeros keeps dtype and device.
    filler = images.new_zeros((max_crops - num_crops, channels, height, width))
    return torch.cat([images, filler], dim=0)
137
138
class Phi3VImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Phi3 image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques
    for processing high resolution images as explained in the [InternLM-XComposer2-4KHD](https://arxiv.org/pdf/2404.06512)

    Args:
        num_crops (`int`, *optional*, defaults to 1):
            Crop budget passed as `hd_num` to `HD_transform`. `preprocess` returns `num_crops + 1` 336x336 slots per
            image: one global view followed by the local crops, zero-padded up to that count.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        num_crops: int = 1,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.num_crops = num_crops
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.do_convert_rgb = do_convert_rgb

    @staticmethod
    def _num_tokens_for_padded_size(height, width):
        """Return the image-token count for an HD-transformed image of `height` x `width` pixels.

        `height` and `width` are expected to be multiples of 336 (as produced by `HD_transform` /
        `calc_hd_transform_size`). The count is `(rows * cols + 1) * 144 + 1 + (rows + 1) * 12`, where
        `rows = height // 336` and `cols = width // 336`. The `+ 1` inside the product matches the extra global
        view that `preprocess` concatenates with the local crops.
        """
        rows, cols = height // 336, width // 336
        return int((rows * cols + 1) * 144 + 1 + (rows + 1) * 12)

    def calc_num_image_tokens(
        self,
        images: ImageInput
    ):
        """ Calculate the number of image tokens for each image.
        Args:
            images (`ImageInput`):
                A single image or a batch of images. Each one is converted with `.convert('RGB')`, so PIL images
                are expected.

        Returns:
            `List[int]`: one token count per image.

        Raises:
            ValueError: if `images` are not a valid image type.
        """
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        images = [image.convert('RGB') for image in images]
        elems = [HD_transform(im, hd_num=self.num_crops) for im in images]
        # PIL `.size` is (width, height).
        return [self._num_tokens_for_padded_size(im.size[1], im.size[0]) for im in elems]

    def calc_num_image_tokens_from_image_size(self, width, height):
        """
        Calculate the number of image tokens for a given image size.
        Args:
            width (`int`): Width of the image.
            height (`int`): Height of the image.

        Returns:
            `int`: the token count `preprocess` would report for an image of this size.
        """
        new_width, new_height = calc_hd_transform_size(width, height, hd_num=self.num_crops)
        return self._num_tokens_for_padded_size(new_height, new_width)

    def preprocess(
        self,
        images: ImageInput,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ):
        """
        Args:
            images (`ImageInput`):
                A single image or a batch of images. Each one is converted with `.convert('RGB')`, so PIL images
                are expected.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Mean used by the (always applied) normalization.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation used by the (always applied) normalization.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to run `convert_to_rgb` on the inputs first.
            return_tensors (`str` or `TensorType`, *optional*):
                Forwarded to `BatchFeature` as `tensor_type`. `pixel_values` is always built as a stacked
                `torch.Tensor`; the other fields are plain lists unless converted by `BatchFeature`.

        Returns:
            `BatchFeature` with:
            - `pixel_values`: shape `(num_images, num_crops + 1, 3, 336, 336)` (global view, local crops, zero pad).
            - `image_sizes`: `[height, width]` of each image after HD transform.
            - `num_img_tokens`: token count per image.

        Raises:
            ValueError: if `images` are not a valid image type.
        """
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        img_processor = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(image_mean, image_std),
        ])

        # HD_transform works on PIL images; RGB conversion here is unconditional.
        images = [image.convert('RGB') for image in images]
        # Pad/resize each image so both sides are multiples of 336.
        elems = [HD_transform(im, hd_num=self.num_crops) for im in images]
        # Tensor transform and normalize -> [(3, h, w)] with h, w multiples of 336.
        hd_images = [img_processor(im) for im in elems]
        # 336x336 global view of every image.
        global_image = [
            torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(336, 336), mode='bicubic').to(im.dtype)
            for im in hd_images
        ]

        shapes = [[im.size(1), im.size(2)] for im in hd_images]
        num_img_tokens = [self._num_tokens_for_padded_size(h, w) for h, w in shapes]

        # Cut each image into 336x336 local crops:
        # (1, 3, h//336, 336, w//336, 336) -> (1, h//336, w//336, 3, 336, 336) -> (h//336*w//336, 3, 336, 336)
        local_crops = [
            im.reshape(1, 3, h // 336, 336, w // 336, 336)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(-1, 3, 336, 336)
            .contiguous()
            for im, (h, w) in zip(hd_images, shapes)
        ]
        # Global view first, then the local crops.
        combined = [torch.cat([glob, crops], dim=0) for glob, crops in zip(global_image, local_crops)]

        # Pad every image to the same number of slots so they can be stacked.
        padded = [pad_to_max_num_crops_tensor(im, self.num_crops + 1) for im in combined]
        pixel_values = torch.stack(padded, dim=0)

        data = {
            "pixel_values": pixel_values,
            "image_sizes": shapes,
            "num_img_tokens": num_img_tokens,
        }

        return BatchFeature(data=data, tensor_type=return_tensors)
289
# Register the image processor with the Auto API and expose it on the
# top-level `transformers` namespace.
AutoImageProcessor.register("Phi3VImageProcessor", Phi3VImageProcessor)
setattr(transformers, "Phi3VImageProcessor", Phi3VImageProcessor)
293
class _ImageTagList:
    """Lazy, read-only stand-in for `[f"<|image_{i+1}|>" for i in range(size)]`.

    Supports `len`, integer and slice indexing (slices return real lists), iteration, `reversed` and `in`,
    without materializing `size` strings up front.
    """

    # Matches exactly the strings the eager list would contain (no leading zeros).
    _TAG_PATTERN = re.compile(r"<\|image_([1-9]\d*)\|>")

    def __init__(self, size):
        self._size = size

    def __len__(self):
        return self._size

    def __getitem__(self, index):
        # `range` already implements list-style negative indices, slicing and IndexError.
        picked = range(1, self._size + 1)[index]
        if isinstance(picked, range):
            return [f"<|image_{i}|>" for i in picked]
        return f"<|image_{picked}|>"

    def __iter__(self):
        return (f"<|image_{i}|>" for i in range(1, self._size + 1))

    def __contains__(self, item):
        if not isinstance(item, str):
            return False
        match = self._TAG_PATTERN.fullmatch(item)
        return match is not None and int(match.group(1)) <= self._size

    def __repr__(self):
        return f"{type(self).__name__}(size={self._size})"


class Phi3VProcessor(ProcessorMixin):
    r"""
    Constructs a Phi3-V processor which wraps a Phi3-V image processor and a LLaMa tokenizer into a single processor.

    [`Phi3VProcessor`] offers all the functionalities of [`Phi3VImageProcessor`] and [`LlamaTokenizerFast`]. See the
    [`~Phi3VProcessor.__call__`] and [`~Phi3VProcessor.decode`] for more information.

    Args:
        image_processor ([`Phi3VImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`LlamaTokenizerFast`], *optional*):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "Phi3VImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
    special_image_token = "<|image|>"

    def __init__(self, image_processor, tokenizer):
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        # Per-crop token count, only needed when image inputs carry `num_crops` instead of `num_img_tokens`.
        # Phi3VImageProcessor.__init__ does not set this attribute itself (presumably it arrives through config
        # kwargs - confirm), so read it defensively and check it where it is actually used.
        self.num_img_tokens = getattr(image_processor, "num_img_tokens", None)
        # "<|image_1|>" ... "<|image_1000000|>", generated on demand instead of building a million strings.
        self.img_tokens = _ImageTagList(1000000)

    def __call__(
        self,
        text: Union[TextInput, List[TextInput]],
        images: ImageInput = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length=None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model.

        Without images, `text` is passed straight to the tokenizer with `padding`, `truncation`, `max_length` and
        `return_tensors`. With images, the images go through [`Phi3VImageProcessor`]. The text is then split on
        `<|image_N|>` tags and tokenized chunk by chunk. Each tag is replaced by `num_img_tokens` copies of `-N`.

        Args:
            text (`str`, `List[str]`):
                The sequence (or, without images, batch of sequences) to encode. When `images` is given, `text` must
                be a single string containing `<|image_1|>`, `<|image_2|>`, ... tags.
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
                Forwarded to the tokenizer (text-only path).
            truncation (`bool`, *optional*):
                Forwarded to the tokenizer (text-only path).
            max_length (`int`, *optional*):
                Forwarded to the tokenizer (text-only path).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                Tensor type for the tokenizer / image processor outputs. With images, `input_ids` and
                `attention_mask` are always built as `torch.Tensor`.

        Returns:
            [`BatchFeature`]: with `input_ids` and `attention_mask`, plus `pixel_values` and `image_sizes` when
            `images` is not `None`.
        """
        if images is not None:
            image_inputs = self.image_processor(images, return_tensors=return_tensors)
        else:
            image_inputs = {}
        inputs = self._convert_images_texts_to_inputs(image_inputs, text, padding=padding, truncation=truncation, max_length=max_length, return_tensors=return_tensors)
        return inputs

    def calc_num_image_tokens(self, images: ImageInput):
        """ Calculate the number of image tokens for each image.
        Args:
            images (`ImageInput`):
                A single image or a batch of images; forwarded to the image processor.
        """
        return self.image_processor.calc_num_image_tokens(images)

    def calc_num_image_tokens_from_image_size(self, width, height):
        """ Calculate the number of image tokens for an image with given width and height.
        Args:
            width (`int`):
                Width of the image.
            height (`int`):
                Height of the image.
        """
        return self.image_processor.calc_num_image_tokens_from_image_size(width, height)

    @property
    def special_image_token_id(self):
        """Token id of `special_image_token` according to the tokenizer."""
        return self.tokenizer.convert_tokens_to_ids(self.special_image_token)

    def get_special_image_token_id(self):
        """Method form of `special_image_token_id`."""
        return self.special_image_token_id

    def _convert_images_texts_to_inputs(self, images, texts, padding=False, truncation=None, max_length=None, return_tensors=None):
        """Build model inputs from processed images and a prompt containing `<|image_N|>` tags.

        Each tag `<|image_N|>` is replaced by `num_img_tokens[N-1]` copies of the placeholder id `-N`.
        The text chunks between tags are tokenized independently.
        """
        if not len(images):
            model_inputs = self.tokenizer(texts, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length)
            return BatchFeature(data={**model_inputs})

        pattern = r"<\|image_\d+\|>"
        prompt_chunks = [self.tokenizer(chunk).input_ids for chunk in re.split(pattern, texts)]

        if 'num_img_tokens' in images:
            num_img_tokens = images['num_img_tokens']
        else:
            assert 'num_crops' in images, 'num_crops must be provided in images if num_img_tokens is not provided'
            if self.num_img_tokens is None:
                raise ValueError(
                    "image_processor has no `num_img_tokens`; cannot derive image token counts from `num_crops`"
                )
            num_crops = images['num_crops']
            num_img_tokens = [_num_crops * self.num_img_tokens for _num_crops in num_crops]

        images, image_sizes = images['pixel_values'], images['image_sizes']

        # Image tags must be numbered 1..n.
        image_tags = re.findall(pattern, texts)
        image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
        unique_image_ids = sorted(list(set(image_ids)))
        # NOTE: these checks use assert and therefore disappear under `python -O`.
        assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
        assert len(unique_image_ids) == len(images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(images)} images"

        # Negative ids mark image placeholder positions: tag N -> num_img_tokens[N-1] copies of -N.
        image_ids_pad = [[-iid]*num_img_tokens[iid-1] for iid in image_ids]

        def insert_separator(X, sep_list):
            # Interleave text chunks with image placeholders; re.split yields one more chunk than tags.
            if len(X) > len(sep_list):
                sep_list.append([])
            return [ele for sublist in zip(X, sep_list) for ele in sublist]

        input_ids = []
        for x in insert_separator(prompt_chunks, image_ids_pad):
            input_ids.extend(x)

        input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
        attention_mask = (input_ids > -1000000).to(torch.long)

        return BatchFeature(data={"input_ids": input_ids,
                                  "attention_mask": attention_mask,
                                  "pixel_values": images,
                                  "image_sizes": image_sizes})

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))