1 """
2 processing_prismatic.py
3
4 HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
5 specifies `siglip-224px+7b`.
6 """
7
from typing import Any, ClassVar, List, Optional, Tuple, Union

import timm.data
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType


# === Image Processing ===
def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
    """Given a PIL.Image, pad to square by adding a (near-)symmetric border around the height/width."""
    (w, h), max_wh = image.size, max(image.size)
    horizontal_pad, vertical_pad = (max_wh - w) // 2, (max_wh - h) // 2

    # Assign any odd remainder to the right/bottom edge, so the output is always exactly square
    padding = (horizontal_pad, vertical_pad, max_wh - w - horizontal_pad, max_wh - h - vertical_pad)

    return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")


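# Example (illustrative sketch): letterboxing a 256x128 image yields a 256x256 square; the fill color
# corresponds to the normalization mean scaled to [0, 255] -- e.g., (127, 127, 127) for a mean of 0.5.
#
#   >>> img = Image.new("RGB", (256, 128))
#   >>> letterbox_pad_transform(img, (127, 127, 127)).size
#   (256, 256)

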
class PrismaticImageProcessor(ImageProcessingMixin):
    model_input_names: ClassVar[List[str]] = ["pixel_values"]

    def __init__(
        self,
        use_fused_vision_backbone: bool = False,
        image_resize_strategy: str = "letterbox",
        input_sizes: Optional[List[Tuple[int, int, int]]] = None,
        interpolations: Optional[List[str]] = None,
        means: Optional[List[Tuple[float, float, float]]] = None,
        stds: Optional[List[Tuple[float, float, float]]] = None,
        **kwargs: str,
    ) -> None:
        """
        Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
        created by TIMM, then edited to follow our custom `image_resize_strategy` logic.

        @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
        @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
        @param input_sizes: [TIMM :: `data_cfg`] Per-backbone input image sizes as tuples (channels, width, height)
        @param interpolations: [TIMM :: `data_cfg`] Per-backbone interpolations as strings (default: "bicubic")
        @param means: [TIMM :: `data_cfg`] Normalization means as float tuples (two entries if `fused_backbone`)
        @param stds: [TIMM :: `data_cfg`] Normalization stds as float tuples (two entries if `fused_backbone`)
        """
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.image_resize_strategy = image_resize_strategy

        # Handle `None` default values
        input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
        interpolations = ["bicubic"] if interpolations is None else interpolations
        means = [(0.5, 0.5, 0.5)] if means is None else means
        stds = [(0.5, 0.5, 0.5)] if stds is None else stds

        # TIMM `data_cfg` Parameters
        self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds

        # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
        self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
        self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None

        for idx in range(len(input_sizes)):
            transform = timm.data.create_transform(
                input_size=self.input_sizes[idx],
                interpolation=self.interpolations[idx],
                mean=self.means[idx],
                std=self.stds[idx],
                crop_pct=1.0,  # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
                crop_mode="center",  # Default crop mode -- no-op when `crop_pct == 1.0`
                is_training=False,  # No image augmentations when loading the transform!
            )

            # [Validation] Ensure appropriate transform structure, expected sizes
            if not (
                isinstance(transform, Compose)
                and (len(transform.transforms) == 4)
                and isinstance(transform.transforms[0], Resize)
                and isinstance(transform.transforms[1], CenterCrop)
                and isinstance(transform.transforms[2], ToTensor)
                and isinstance(transform.transforms[3], Normalize)
                and (transform.transforms[0].size == self.input_sizes[idx][-1])
                and (transform.transforms[1].size == self.input_sizes[idx][-2:])
            ):
                raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")

            # HF Image Processors *must* be JSON-serializable; as such, we cannot keep the torchvision transform
            # as an attribute. Instead, we parse out each transform's parameters here, and apply the corresponding
            # "torchvision.transforms.functional" (`TVF`) ops at runtime.
            resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
            self.tvf_resize_params.append(
                {
                    "size": resize_t.size,
                    "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
                    "max_size": None,
                    "antialias": True,
                }
            )
            self.tvf_crop_params.append({"output_size": crop_t.size})
            self.tvf_normalize_params.append(
                {
                    "mean": norm_t.mean.float().numpy().tolist(),
                    "std": norm_t.std.float().numpy().tolist(),
                    "inplace": False,
                }
            )

            # Handle Prismatic `image_resize_strategy`
            if self.image_resize_strategy == "resize-naive":
                self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
            elif self.image_resize_strategy == "letterbox":
                self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
            elif self.image_resize_strategy == "resize-crop":
                pass
            else:
                raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")

        # Dispatch **kwargs to super()
        super().__init__(**kwargs)

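    # Example (illustrative sketch, assuming TIMM's standard eval transform): with the default
    # single-backbone configuration, there is one entry per parameter list, and the `letterbox`
    # strategy maps the 0.5 normalization mean to a gray border fill:
    #
    #   >>> proc = PrismaticImageProcessor()  # defaults: `letterbox` strategy @ (3, 224, 224)
    #   >>> proc.tvf_do_letterbox, proc.tvf_letterbox_fill
    #   (True, (127, 127, 127))
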
    def apply_transform(self, img: Image.Image) -> torch.Tensor:
        """Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])"""
        if self.tvf_do_letterbox:
            img = letterbox_pad_transform(img, self.tvf_letterbox_fill)

        # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
        imgs_t = []
        for idx in range(len(self.input_sizes)):
            img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
            img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
            img_idx_t = TVF.to_tensor(img_idx)
            img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
            imgs_t.append(img_idx_t)

        # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
        img_t = torch.vstack(imgs_t)

        return img_t

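    # Note (illustrative): for a fused (dual) backbone with two (3, 224, 224) inputs, `apply_transform`
    # returns a single channel-stacked Tensor of shape [6, 224, 224]; the model splits this back into
    # two [3, 224, 224] views, one per backbone.
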
    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **_: str,
    ) -> BatchFeature:
        """
        Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor`, we
        explicitly only handle PIL.Image.Image instances for simplicity.

        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
        @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray

        @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
        """
        if not isinstance(images, list):
            images = [images]

        # Apply `self.apply_transform` to each image (returns a list of torch.Tensors); stack into a "batched" Tensor
        pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])

        # Return BatchFeature =>> for compatibility, the constructor expects Dict[str, np.ndarray], so we convert
        return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)

    def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs: str) -> BatchFeature:
        return self.preprocess(images, **kwargs)


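# Example (illustrative sketch): preprocessing a single image end-to-end; shapes assume the default
# single-backbone (3, 224, 224) configuration.
#
#   >>> image_processor = PrismaticImageProcessor()
#   >>> batch = image_processor(Image.new("RGB", (640, 480)), return_tensors="pt")
#   >>> batch["pixel_values"].shape
#   torch.Size([1, 3, 224, 224])

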
# === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
# =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
class PrismaticProcessor(ProcessorMixin):
    attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
    image_processor_class: str = "AutoImageProcessor"
    tokenizer_class: str = "AutoTokenizer"

    def __init__(
        self,
        image_processor: Optional[ImageProcessingMixin] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
    ) -> None:
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        images: Union[Image.Image, List[Image.Image]],
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """
        Preprocess a given (batch of) text/images for a Prismatic VLM; forwards text to the underlying LLM's
        tokenizer, and forwards images to the PrismaticImageProcessor.

        @param text: The (batch of) text to encode; must be a string or list of strings.
        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
        @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
        @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
        @param max_length: Maximum length (in tokens) to truncate sequences to
        @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)

        @return: BatchFeature with keys `input_ids`, `attention_mask`, and `pixel_values`.
        """
        pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
        text_inputs = self.tokenizer(
            text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
        )

        # [Validation] Need the same number of images and text inputs!
        if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
            raise ValueError("Batch is malformed; expected the same number of images and text inputs!")

        return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})

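    # Example (illustrative sketch; `PATH` is a hypothetical checkpoint directory that bundles this
    # processor's config alongside a compatible tokenizer):
    #
    #   >>> from transformers import AutoTokenizer
    #   >>> processor = PrismaticProcessor(PrismaticImageProcessor(), AutoTokenizer.from_pretrained(PATH))
    #   >>> inputs = processor("A photo of a cat", Image.open("cat.jpg"))
    #   >>> sorted(inputs.keys())
    #   ['attention_mask', 'input_ids', 'pixel_values']
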
    # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
    def batch_decode(
        self,
        sequences: Union[List[int], List[List[int]], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs: str,
    ) -> List[str]:
        return self.tokenizer.batch_decode(
            sequences=sequences,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    def decode(
        self,
        token_ids: Union[int, List[int], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs: str,
    ) -> str:
        return self.tokenizer.decode(
            token_ids=token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def model_input_names(self) -> List[str]:
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names

        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
258