image_processing_locateanything.py
4.8 KB · 128 lines · python Raw
1 # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2 #
3 # NVIDIA CORPORATION and its licensors retain all intellectual property
4 # and proprietary rights in and to this software, related documentation
5 # and any modifications thereto. Any use, reproduction, disclosure or
6 # distribution of this software and related documentation without an express
7 # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
9 """Image processor class for KimiVL."""
10
11 import math
12 import numpy as np
13 from PIL import Image
14 from typing import Optional, Union
15
16 import torch
17 from torchvision.transforms import functional as TF
18 from transformers.image_utils import ImageInput, make_list_of_images, valid_images
19 from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
20 from transformers.utils import TensorType
21 from transformers import AutoImageProcessor
22
# Default per-channel normalization statistics (one value per RGB channel).
# For inputs in [0, 1], (x - 0.5) / 0.5 maps values onto [-1, 1].
MEAN = (0.5,) * 3
STD = (0.5,) * 3
25
26
class LocateAnythingImageProcessor(BaseImageProcessor):
    """Image processor for the LocateAnything model.

    Each image is resized so its patch count respects ``in_token_limit`` and
    its sides are multiples of ``merge_kernel_size * patch_size``. It is then
    converted to an RGB float tensor, normalized with ``image_mean`` /
    ``image_std``, and split into non-overlapping ``patch_size x patch_size``
    patches.
    """

    model_type = "locateanything"

    def __init__(
        self,
        patch_size: int = 14,
        image_mean: tuple[float, float, float] = MEAN,
        image_std: tuple[float, float, float] = STD,
        in_token_limit: int = 4096,
        merge_kernel_size: Union[tuple[int, int], list[int]] = (2, 2),
        **kwargs,
    ):
        """
        Args:
            patch_size: Side length, in pixels, of each square patch.
            image_mean: Per-channel mean passed to ``TF.normalize``.
            image_std: Per-channel std passed to ``TF.normalize``.
            in_token_limit: Upper bound on
                ``(w // patch_size) * (h // patch_size)``. Larger images are
                downscaled in ``rescale``.
            merge_kernel_size: ``(rows, cols)`` multiplier. Image height and
                width are resized to multiples of
                ``merge_kernel_size[i] * patch_size``.
            **kwargs: Forwarded to ``BaseImageProcessor``.
        """
        super().__init__(**kwargs)
        self.in_token_limit = in_token_limit
        self.patch_size = patch_size
        self.image_mean = image_mean
        self.image_std = image_std
        # Store a fresh list. The old ``[2, 2]`` default was one list object
        # shared by every default-constructed instance, and keeping a caller's
        # list would alias it as well.
        self.merge_kernel_size = list(merge_kernel_size)

    def rescale(
        self,
        image: Image.Image,
        merge_kernel_size: Union[tuple[int, int], list[int]] = (2, 2),
    ) -> Image.Image:
        """Resize a PIL image to the token budget and the merge grid.

        NOTE: this overrides ``BaseImageProcessor.rescale`` with a different
        signature and meaning. Here it changes the spatial size of the image.

        Steps:
          1. If ``(w // patch_size) * (h // patch_size)`` exceeds
             ``in_token_limit``, shrink both sides by
             ``sqrt(in_token_limit / patches)`` (bicubic).
          2. Resize (bicubic) so height and width become the next multiples of
             ``merge_kernel_size[0] * patch_size`` and
             ``merge_kernel_size[1] * patch_size``. Rounding up can push the
             patch count slightly above ``in_token_limit``.

        Args:
            image: Input PIL image.
            merge_kernel_size: ``(rows, cols)`` merge multiplier. Defaults to
                ``(2, 2)`` independently of ``self.merge_kernel_size``.

        Returns:
            The resized PIL image.

        Raises:
            ValueError: If the result spans 512 or more patches along either
                side.
        """
        patch_size = self.patch_size
        w, h = image.size

        num_patches = (w // patch_size) * (h // patch_size)
        if num_patches > self.in_token_limit:
            scale = math.sqrt(self.in_token_limit / num_patches)
            # max(1, ...) keeps extreme aspect ratios from collapsing a side to 0 px.
            new_w = max(1, int(w * scale))
            new_h = max(1, int(h * scale))
            image = image.resize((new_w, new_h), Image.Resampling.BICUBIC)

        new_w, new_h = image.size
        pad_size_h = merge_kernel_size[0] * patch_size
        pad_size_w = merge_kernel_size[1] * patch_size

        target_w = math.ceil(new_w / pad_size_w) * pad_size_w
        target_h = math.ceil(new_h / pad_size_h) * pad_size_h

        if (target_w, target_h) != (new_w, new_h):
            # Despite the "pad" naming, the image is stretched (not padded).
            image = image.resize((target_w, target_h), Image.Resampling.BICUBIC)

        w, h = image.size
        if w // patch_size >= 512 or h // patch_size >= 512:
            raise ValueError("Exceed pos emb")

        return image

    def to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a float ``(C, H, W)`` tensor after forcing RGB.

        Per torchvision's ``to_tensor`` contract, 8-bit images are scaled to
        [0, 1]. ``convert("RGB")`` discards any alpha channel without
        compositing.
        """
        return TF.to_tensor(image.convert("RGB"))

    def normalize(self, image: torch.Tensor) -> torch.Tensor:
        """Normalize a ``(C, H, W)`` tensor channel-wise with ``image_mean`` / ``image_std``.

        NOTE: overrides ``BaseImageProcessor.normalize`` with a narrower
        signature.
        """
        return TF.normalize(image, self.image_mean, self.image_std)

    def patchify(self, image: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
        """Split a ``(C, H, W)`` tensor into non-overlapping square patches.

        ``H`` and ``W`` must be divisible by ``patch_size`` (``rescale``
        guarantees this); otherwise ``reshape`` raises.

        Returns:
            patches: ``(grid_h * grid_w, C, p, p)`` tensor (``p = patch_size``),
                with patches ordered row by row over the grid.
            grid_hw: ``(grid_h, grid_w) = (H // p, W // p)``.
        """
        p = self.patch_size
        C, H, W = image.shape
        grid_h, grid_w = H // p, W // p
        patches = image.reshape(C, grid_h, p, grid_w, p)
        # (C, gh, p, gw, p) -> (gh, gw, C, p, p) so flattening walks the grid row-major.
        patches = patches.permute(1, 3, 0, 2, 4).contiguous().view(-1, C, p, p)
        return patches, (grid_h, grid_w)

    def _preprocess(self, image: ImageInput) -> tuple[torch.Tensor, tuple[int, int]]:
        """Resize, convert, normalize and patchify a single image.

        Args:
            image: A PIL image, or an array/tensor image that
                ``transformers.image_transforms.to_pil_image`` can convert
                (e.g. ``uint8`` values in [0, 255], or floats in [0, 1]).

        Returns:
            patches: ``(num_patches, C, patch_size, patch_size)`` tensor.
            grid_hw: ``(grid_h, grid_w)`` patch-grid size.
        """
        if not isinstance(image, Image.Image):
            # ``preprocess`` accepts numpy/torch images via ``valid_images``, but
            # ``rescale`` needs a PIL image (``image.size`` as (w, h), ``image.resize``).
            from transformers.image_transforms import to_pil_image

            image = to_pil_image(image)
        # NOTE(review): resizing happens before the RGB conversion in
        # ``to_tensor``. Per Pillow's docs, "1"/"P"-mode images are always
        # resized with NEAREST, not BICUBIC.
        image = self.rescale(image, self.merge_kernel_size)
        image = self.to_tensor(image)
        image = self.normalize(image)
        return self.patchify(image)

    def preprocess(
        self,
        images: ImageInput,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Preprocess one image or a batch of images.

        Args:
            images: A single image or a list of images. See ``_preprocess``
                for the accepted per-image types.
            return_tensors: Tensor type forwarded to ``BatchFeature``.

        Returns:
            BatchFeature with the following entries:
              - ``pixel_values``: the patches of all images concatenated along
                dim 0, with shape ``(total_patches, C, patch_size, patch_size)``.
              - ``image_grid_hws``: a ``(num_images, 2)`` array of per-image
                ``(grid_h, grid_w)``. Image ``i`` owns ``grid_h * grid_w``
                consecutive rows of ``pixel_values``.

        Raises:
            ValueError: If an image has an unsupported type, or if
                ``make_list_of_images`` yields no images.
        """
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        if len(images) == 0:
            # Otherwise torch.concat would fail with a less descriptive error.
            raise ValueError("No images were provided to preprocess.")

        pixel_values, image_grid_hws = [], []
        for image in images:
            patches, grid_hw = self._preprocess(image)
            pixel_values.append(patches)
            image_grid_hws.append(grid_hw)

        data = {
            "pixel_values": torch.concat(pixel_values, dim=0),
            "image_grid_hws": np.array(image_grid_hws),
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
127
# Register the processor with AutoImageProcessor under a plain string key.
# NOTE(review): `AutoImageProcessor.register` documents its first argument as
# a config class. A string key presumably does not match lookups by config
# type and only serves name-based resolution (e.g. `image_processor_type` in
# the preprocessor config). Confirm against the pinned transformers version.
AutoImageProcessor.register(
    "LocateAnythingImageProcessor",
    LocateAnythingImageProcessor,
)