processing_prismatic.py
| 1 | """ |
| 2 | processing_prismatic.py |
| 3 | |
| 4 | HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration |
| 5 | specifies `siglip-224px+7b`. |
| 6 | """ |
| 7 | |
| 8 | from typing import Any, ClassVar, List, Optional, Tuple, Union |
| 9 | |
| 10 | import timm.data |
| 11 | import torch |
| 12 | import torchvision.transforms.functional as TVF |
| 13 | from PIL import Image |
| 14 | from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor |
| 15 | from transformers import PreTrainedTokenizerBase |
| 16 | from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin |
| 17 | from transformers.processing_utils import ProcessorMixin |
| 18 | from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy |
| 19 | from transformers.utils import TensorType |
| 20 | |
| 21 | |
# === Image Processing ===
def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
    """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
    (w, h), max_wh = image.size, max(image.size)
    horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
    padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)

    return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")
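
# Illustrative check (hypothetical values; doctest-style, not executed on import): letterbox-padding a landscape
# 640x480 image adds an 80px border to the top and bottom, yielding a 640x640 square.
#
#   >>> img = Image.new("RGB", (640, 480))
#   >>> letterbox_pad_transform(img, padding_fill_value=(127, 127, 127)).size
#   (640, 640)
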
class PrismaticImageProcessor(ImageProcessingMixin):
    model_input_names: ClassVar[List[str]] = ["pixel_values"]

    def __init__(
        self,
        use_fused_vision_backbone: bool = False,
        image_resize_strategy: str = "letterbox",
        input_sizes: Optional[List[Tuple[int, int, int]]] = None,
        interpolations: Optional[List[str]] = None,
        means: Optional[List[Tuple[float, float, float]]] = None,
        stds: Optional[List[Tuple[float, float, float]]] = None,
        **kwargs: str,
    ) -> None:
        """
        Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
        created by TIMM, and edited to follow our custom `image_resize_strategy` logic.

        @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
        @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
        @param input_sizes: [TIMM :: `data_cfg`] List of input image sizes as (channels, height, width) tuples
        @param interpolations: [TIMM :: `data_cfg`] List of interpolations as strings (default: "bicubic")
        @param means: [TIMM :: `data_cfg`] List of normalization means as float tuples (two entries if fused backbone)
        @param stds: [TIMM :: `data_cfg`] List of normalization stds as float tuples (two entries if fused backbone)
        """
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.image_resize_strategy = image_resize_strategy

        # Handle `None` default values
        input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
        interpolations = ["bicubic"] if interpolations is None else interpolations
        means = [(0.5, 0.5, 0.5)] if means is None else means
        stds = [(0.5, 0.5, 0.5)] if stds is None else stds

        # TIMM `data_cfg` Parameters
        self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds

        # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
        self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
        self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None

        for idx in range(len(input_sizes)):
            transform = timm.data.create_transform(
                input_size=self.input_sizes[idx],
                interpolation=self.interpolations[idx],
                mean=self.means[idx],
                std=self.stds[idx],
                crop_pct=1.0,         # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
                crop_mode="center",   # Default crop mode -- no-op when `crop_pct == 1.0`
                is_training=False,    # No image augmentations when loading the transform!
            )

            # [Validation] Ensure appropriate transform structure, expected sizes
            if not (
                isinstance(transform, Compose)
                and (len(transform.transforms) == 4)
                and isinstance(transform.transforms[0], Resize)
                and isinstance(transform.transforms[1], CenterCrop)
                and isinstance(transform.transforms[2], ToTensor)
                and isinstance(transform.transforms[3], Normalize)
                and (transform.transforms[0].size == self.input_sizes[idx][-1])
                and (transform.transforms[1].size == self.input_sizes[idx][-2:])
            ):
                raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")

            # HF Image Processors *must* be JSON-serializable; as such, we cannot keep torchvision transforms as attributes.
            # => Instead, we parse out the transform and call `torchvision.transforms.functional` (`TVF`) directly.
            resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
            self.tvf_resize_params.append(
                {
                    "size": resize_t.size,
                    "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
                    "max_size": None,
                    "antialias": True,
                }
            )
            self.tvf_crop_params.append({"output_size": crop_t.size})
            self.tvf_normalize_params.append(
                {
                    "mean": norm_t.mean.float().numpy().tolist(),
                    "std": norm_t.std.float().numpy().tolist(),
                    "inplace": False,
                }
            )
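            # [Illustration] For the default `siglip-224px` configuration, the parsed functional parameters land
            # roughly as follows (the exact interpolation constant depends on the installed torchvision version):
            #   tvf_resize_params[0]    ~ {"size": 224, "interpolation": <bicubic>, "max_size": None, "antialias": True}
            #   tvf_crop_params[0]      ~ {"output_size": (224, 224)}
            #   tvf_normalize_params[0] ~ {"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5], "inplace": False}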

            # Handle Prismatic `image_resize_strategy`
            if self.image_resize_strategy == "resize-naive":
                self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
            elif self.image_resize_strategy == "letterbox":
                self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
            elif self.image_resize_strategy == "resize-crop":
                pass
            else:
                raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")

        # Dispatch **kwargs to super()
        super().__init__(**kwargs)

    def apply_transform(self, img: Image.Image) -> torch.Tensor:
        """Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])"""
        if self.tvf_do_letterbox:
            img = letterbox_pad_transform(img, self.tvf_letterbox_fill)

        # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
        imgs_t = []
        for idx in range(len(self.input_sizes)):
            img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
            img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
            img_idx_t = TVF.to_tensor(img_idx)
            img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
            imgs_t.append(img_idx_t)

        # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
        img_t = torch.vstack(imgs_t)
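        # e.g., a single backbone yields a [3, 224, 224] tensor, while a fused (dual) backbone yields [6, 224, 224]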

        return img_t

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **_: str,
    ) -> BatchFeature:
        """
        Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor`, we
        explicitly only handle PIL.Image.Image instances for simplicity.

        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
        @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray

        @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
        """
        if not isinstance(images, list):
            images = [images]

        # Apply `self.apply_transform` to each image (returns a list of torch.Tensors); stack into a "batched" Tensor
        pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])

        # Return BatchFeature =>> note that for compatibility, the constructor expects Dict[str, np.ndarray], so we convert
        return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)

    def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
        return self.preprocess(images, **kwargs)

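# Illustrative usage of the image processor on its own (hypothetical values; assumes the default single-backbone
# `siglip-224px` configuration defined above):
#
#   >>> image_processor = PrismaticImageProcessor(image_resize_strategy="letterbox")
#   >>> pixels = image_processor(Image.new("RGB", (640, 480)), return_tensors="pt")["pixel_values"]
#   >>> pixels.shape
#   torch.Size([1, 3, 224, 224])
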
# === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
#   =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
class PrismaticProcessor(ProcessorMixin):
    attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
    image_processor_class: str = "AutoImageProcessor"
    tokenizer_class: str = "AutoTokenizer"

    def __init__(
        self,
        image_processor: Optional[ImageProcessingMixin] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
    ) -> None:
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        images: Union[Image.Image, List[Image.Image]],
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
| 199 | """ |
| 200 | Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer, |
| 201 | forwards images to PrismaticImageProcessor. |
| 202 | |
| 203 | @param text: The (batch) of text to encode; must be a string or list of strings. |
| 204 | @param images: A (batch of) PIL.Image.Image instance(s) to preprocess. |
| 205 | @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False > |
| 206 | @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified |
| 207 | @param max_length: Maximum length (in tokens) to truncate |
| 208 | @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH) |
| 209 | |
| 210 | @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`. |
| 211 | """ |
| 212 | pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] |
| 213 | text_inputs = self.tokenizer( |
| 214 | text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length |
| 215 | ) |
| 216 | |
| 217 | # [Validate] Need same number of images and text inputs! |
| 218 | if pixel_values.shape[0] != text_inputs.input_ids.shape[0]: |
| 219 | raise ValueError("Batch is malformed; expected same number of images and text inputs!") |
| 220 | |
| 221 | return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) |
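
    # Illustrative usage (hypothetical; `image_processor`, `tokenizer`, and `image` must be built/loaded elsewhere):
    #
    #   >>> processor = PrismaticProcessor(image_processor, tokenizer)
    #   >>> inputs = processor("In: What is in this image?\nOut:", image, return_tensors="pt")
    #   >>> sorted(inputs.keys())
    #   ['attention_mask', 'input_ids', 'pixel_values']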

    # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
    def batch_decode(
        self,
        sequences: Union[List[int], List[List[int]], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs: str,
    ) -> List[str]:
        return self.tokenizer.batch_decode(
            sequences=sequences,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    def decode(
        self,
        token_ids: Union[int, List[int], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs: str,
    ) -> str:
        return self.tokenizer.decode(
            token_ids=token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def model_input_names(self) -> List[str]:
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names

        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
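

# === Minimal smoke test (illustrative sketch; guarded so it never runs on import) ===
# Note :: the tokenizer checkpoint below is just an example (Prismatic `siglip-224px+7b` pairs SigLIP with a
#         Vicuña v1.5 7B LLM); any causal-LM tokenizer on the HF Hub can be substituted.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
    processor = PrismaticProcessor(PrismaticImageProcessor(), tokenizer)

    dummy_image = Image.new("RGB", (640, 480), color=(128, 128, 128))
    batch = processor("In: What is in this image?\nOut:", dummy_image)
    print({k: tuple(v.shape) for k, v in batch.items()})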
| 258 | |