kimi_k25_vision_processing.py
9.8 KB · 252 lines · python Raw
1 """Image processor class for Kimi-K2.5.
2 """
3
4 import json
5 from typing import Any, Dict, Optional, Union
6
7 import numpy as np
8 import torch
9 from PIL import Image
10 from transformers.image_processing_utils import (BaseImageProcessor,
11 BatchFeature)
12 from transformers.utils import TensorType
13
14 from .media_utils import (MediaInput, VideoChunkInput, _to_tensor,
15 ensure_media_type, get_video_meta, image_to_np,
16 navit_patchify, navit_resize_image,
17 navit_resize_video, normalize,
18 real_sample_fps_and_max_num_frames, timestamp_as_str)
19
20 try:
21 from mecord import VideoReader
22 except ImportError:
23 VideoReader = None
24
25
def resampling(video_bytes: bytes,
               sample_indices: list[int],
               key_indices=None,
               frame_time_info=None,
               num_threads=4) -> list[Image.Image]:
    """Decode an in-memory video and extract specific frames as PIL images.

    Args:
        video_bytes: Raw encoded video bytes.
        sample_indices: Frame indices to extract, in presentation order.
        key_indices: Optional keyframe index hints forwarded to the decoder.
        frame_time_info: Optional per-frame timing metadata forwarded to the
            decoder.
        num_threads: Number of decoder threads.

    Returns:
        The requested frames converted to ``PIL.Image.Image`` objects.

    Raises:
        ImportError: If the optional ``mecord`` dependency is not installed.
    """
    # Fail with a clear message instead of an opaque "NoneType is not
    # callable" TypeError when the optional import at module top failed.
    if VideoReader is None:
        raise ImportError(
            "mecord is required to decode videos but is not installed")
    video = VideoReader(video_bytes,
                        num_threads=num_threads,
                        frame_time_info=frame_time_info,
                        key_indices=key_indices)
    # Batch-extract the target frames, then convert each ndarray frame
    # to a PIL image for downstream resizing.
    frames = video[sample_indices]
    return [Image.fromarray(frame) for frame in frames]
40
class KimiK25VisionProcessor(BaseImageProcessor):
    """Image/video processor for the Kimi-K2.5 vision model.

    Converts images and fixed-size video chunks into flattened NaViT-style
    patch tensors (``pixel_values``) plus per-media temporal/height/width
    grids (``grid_thws``). All sizing behavior is driven by the
    ``media_proc_cfg`` dict (patch size, merge kernel sizes, patch limits,
    sampling fps, etc.).
    """

    model_type = "kimi_k25"

    def __init__(
        self,
        media_proc_cfg: dict,
        **kwargs,
    ):
        """Store the media-processing config.

        Args:
            media_proc_cfg: Dict of processing options; must at least contain
                ``'temporal_merge_kernel_size'`` (frames merged per video
                chunk). Other keys are read lazily by the methods below.
            **kwargs: Forwarded to ``BaseImageProcessor``.
        """
        super().__init__(**kwargs)
        self.media_proc_cfg = media_proc_cfg
        # Convenience alias: number of frames grouped into one video chunk.
        self.num_frames_per_chunk = media_proc_cfg[
            'temporal_merge_kernel_size']

    def media_tokens_calculator(self, media: MediaInput) -> int:
        """Return the number of model tokens a media item will occupy.

        Delegates the actual sizing math to ``get_resize_config``.
        """
        media = ensure_media_type(media)
        ret = self.get_resize_config(media)
        return ret['num_tokens']

    @classmethod
    def make_chunk_prompt(cls, timestamp_text: str) -> str:
        """Build the prompt text for one video chunk: timestamp followed by
        the model's media placeholder tokens."""
        return f"{timestamp_text}<|media_begin|>video<|media_content|><|media_pad|><|media_end|>"

    def split_video_chunks(self,
                           video_url: str | bytes) -> list[VideoChunkInput]:
        # video_url should be base64 str or bytes
        """Sample frames from a video and group them into prompt-ready chunks.

        Frames are sampled uniformly at ``min(cfg sample_fps, video fps)``,
        then grouped into chunks of ``temporal_merge_kernel_size`` frames.
        Each chunk carries a prompt containing the timestamp of its first
        frame.

        Args:
            video_url: Encoded video, as base64 str or raw bytes.

        Returns:
            One ``VideoChunkInput`` per chunk (the last chunk may hold
            fewer frames than the kernel size).
        """
        video_spec = get_video_meta(video_url)
        # Never sample above the video's native fps.
        sample_fps = min(self.media_proc_cfg['sample_fps'], video_spec.fps)
        # At least one frame is always sampled, even for very short videos.
        sampled_nframes = max(
            round(video_spec.num_frames * sample_fps / video_spec.fps), 1)
        # Uniformly spaced frame indices across the whole video.
        frame_inds = np.linspace(0, video_spec.num_frames - 1,
                                 sampled_nframes).round().astype(int)
        frame_inds = frame_inds.tolist()
        sampled_frame_ids = []
        temporal_merge_kernel_size = self.media_proc_cfg[
            "temporal_merge_kernel_size"]
        num_chunks = 0
        chunk_timestamp = []
        # Walk the sampled indices in chunk-sized strides, recording the
        # start-time label for each chunk.
        for i in range(0, len(frame_inds), temporal_merge_kernel_size):
            sampled_frame_ids.extend(frame_inds[i:i +
                                                temporal_merge_kernel_size])
            start_time = frame_inds[i] / float(video_spec.fps)
            timestamp_text = timestamp_as_str(
                start_time, self.media_proc_cfg["timestamp_mode"])
            chunk_timestamp.append(timestamp_text)
            num_chunks += 1

        # Decode all sampled frames in one pass, then slice back into chunks.
        sampled_frames = resampling(video_url, sampled_frame_ids)
        chunks = []
        for chunk_id in range(num_chunks):
            chunk = sampled_frames[chunk_id *
                                   temporal_merge_kernel_size:(chunk_id + 1) *
                                   temporal_merge_kernel_size]
            chunks.append(
                VideoChunkInput(type="video_chunk",
                                video_chunk=chunk,
                                prompt=self.make_chunk_prompt(
                                    chunk_timestamp[chunk_id])))
        return chunks

    def get_resize_config(self, media_input: MediaInput) -> dict:
        """Compute resize/pad geometry and token count for one media item.

        Dispatches on ``media_input['type']`` to the NaViT sizing helpers.
        The returned dict includes at least ``new_width``, ``new_height``,
        ``pad_width``, ``pad_height`` and ``num_tokens`` (keys used by
        ``preprocess`` and ``media_tokens_calculator``).

        Raises:
            ValueError: For any type other than ``'image'`` or
                ``'video_chunk'``.
        """
        if media_input['type'] == 'image':
            w, h = media_input['image'].size
            ret = navit_resize_image(
                w, h, self.media_proc_cfg['patch_size'],
                self.media_proc_cfg['merge_kernel_size'],
                self.media_proc_cfg['in_patch_limit'],
                self.media_proc_cfg['patch_limit_on_one_side'],
                self.media_proc_cfg['fixed_output_tokens'])
            return ret
        elif media_input['type'] == 'video_chunk':
            # All frames of a chunk share the first frame's dimensions.
            frame = media_input['video_chunk'][0]
            width, height = frame.size
            num_frames = len(media_input["video_chunk"])
            # Chunks are already temporally sampled, so treat them as 1 fps.
            fps = 1.0

            sample_fps, max_num_frames_each_video = real_sample_fps_and_max_num_frames(
                media_input["type"],
                self.media_proc_cfg['sample_fps'],
                self.media_proc_cfg['max_num_frames_each_video'],
            )

            # Per-frame patch limit falls back to the global image limit
            # when not configured explicitly.
            in_patch_limit_each_frame = self.media_proc_cfg[
                'in_patch_limit_each_frame']
            if in_patch_limit_each_frame is None:
                in_patch_limit_each_frame = self.media_proc_cfg[
                    'in_patch_limit']

            ret = navit_resize_video(
                width,
                height,
                num_frames,
                fps,
                sample_fps,
                self.media_proc_cfg['patch_size'],
                self.media_proc_cfg['merge_kernel_size'],
                in_patch_limit_each_frame,
                self.media_proc_cfg['patch_limit_on_one_side'],
                self.media_proc_cfg['in_patch_limit_video'],
                max_num_frames_each_video,
                self.media_proc_cfg['fixed_output_tokens'],
            )
            return ret
        else:
            raise ValueError("Unsupported type: {}".format(
                media_input['type']))

    def resize_image(self, image: Image.Image, new_width: int, new_height: int,
                     pad_width: int, pad_height: int) -> np.ndarray:
        """Resize an image then zero-pad its right/bottom edges.

        Padding on the right/bottom brings the array to a patch-aligned
        size without shifting image content.

        Returns:
            HWC uint-range ndarray of shape
            ``(new_height + pad_height, new_width + pad_width, C)``.
        """
        image_np = image_to_np(image, (new_width, new_height), "resize")
        image_np = np.pad(
            image_np,
            ((0, pad_height), (0, pad_width), (0, 0)),
            mode="constant",
            constant_values=0,
        )
        return image_np

    def preprocess(
        self,
        medias: list[MediaInput],
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """
        Preprocess atomic vision inputs (images/video_chunks) into model-ready tensors.

        Args:
            medias: List of MediaInput.
            return_tensors: Desired output format ('pt', 'np', 'tf', or None).

        Returns:
            BatchFeature containing 'pixel_values' and 'grid_thws' tensors.
        """
        # Accept a single media item for convenience.
        if not isinstance(medias, list):
            medias = [medias]
        if medias:
            pixel_values = []
            # Stage 1: resize + pad each media item into a THWC ndarray
            # (T=1 for still images).
            for item in medias:
                item = ensure_media_type(item)
                resize_config = self.get_resize_config(item)
                new_width, new_height, pad_width, pad_height = resize_config[
                    'new_width'], resize_config['new_height'], resize_config[
                        'pad_width'], resize_config['pad_height']
                if item['type'] == 'image':
                    image = item['image']
                    image_np = self.resize_image(image, new_width, new_height,
                                                 pad_width, pad_height)
                    # Add a singleton temporal axis so images and videos
                    # share the same downstream layout.
                    pixel_values.append(np.expand_dims(image_np, axis=0))
                elif item['type'] == 'video_chunk':
                    pixels = []
                    for frame in item['video_chunk']:
                        frame_np = self.resize_image(frame, new_width,
                                                     new_height, pad_width,
                                                     pad_height)
                        pixels.append(frame_np)
                    pixel_values.append(np.stack(pixels, axis=0))
                else:
                    raise ValueError("Unsupported type: {}".format(
                        item['type']))
            # Stage 2: normalize and patchify each item.
            normalized_pixel_values = []
            # Precompute the reciprocal std so normalize() can multiply
            # instead of divide.
            image_std_inv = 1.0 / np.array(self.media_proc_cfg['image_std'])
            image_mean = np.array(self.media_proc_cfg['image_mean'])
            for pixels in pixel_values:
                pixels = normalize(pixels, image_mean, image_std_inv)
                pixels_and_thw = navit_patchify(
                    pixels,
                    self.media_proc_cfg['patch_size'],
                )
                normalized_pixel_values.append(pixels_and_thw)

            # Stage 3: concatenate all items' patches along the patch axis
            # and stack their (t, h, w) grids.
            pixel_values = torch.cat([
                _to_tensor(pixel_value['pixel_values'])
                for pixel_value in normalized_pixel_values
            ])
            grid_thws = torch.cat([
                _to_tensor(pixel_value['grid_thw'],
                           dtype=torch.int64).unsqueeze(0)
                for pixel_value in normalized_pixel_values
            ])

            data = {
                'pixel_values': pixel_values,
                'grid_thws': grid_thws,
            }

        else:
            # No media: return an empty BatchFeature.
            data = {}

        return BatchFeature(data=data, tensor_type=return_tensors)

    def __repr__(self) -> str:
        return f"KimiK25VisionProcessor(media_proc_cfg={self.media_proc_cfg})"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, keeping ``media_proc_cfg`` and dropping
        the non-serializable ``media_processor`` entry if present."""
        output = super().to_dict()
        output["media_proc_cfg"] = self.media_proc_cfg
        if "media_processor" in output:
            del output["media_processor"]
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        """Reconstruct a processor from a ``to_dict``-style config dict."""
        config = config_dict.copy()
        media_proc_cfg = config.pop("media_proc_cfg", {})
        return cls(media_proc_cfg=media_proc_cfg, **config, **kwargs)

    def to_json_string(self) -> str:
        """Serialize the config to a JSON string, converting array-like
        values (anything with ``tolist``) to plain lists first."""
        dictionary = self.to_dict()
        for key, value in dictionary.items():
            if hasattr(value, 'tolist'):
                dictionary[key] = value.tolist()
        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
252