kimi_k25_vision_processing.py
| 1 | """Image processor class for Kimi-K2.5. |
| 2 | """ |
| 3 | |
| 4 | import json |
| 5 | from typing import Any, Dict, Optional, Union |
| 6 | |
| 7 | import numpy as np |
| 8 | import torch |
| 9 | from PIL import Image |
| 10 | from transformers.image_processing_utils import (BaseImageProcessor, |
| 11 | BatchFeature) |
| 12 | from transformers.utils import TensorType |
| 13 | |
| 14 | from .media_utils import (MediaInput, VideoChunkInput, _to_tensor, |
| 15 | ensure_media_type, get_video_meta, image_to_np, |
| 16 | navit_patchify, navit_resize_image, |
| 17 | navit_resize_video, normalize, |
| 18 | real_sample_fps_and_max_num_frames, timestamp_as_str) |
| 19 | |
| 20 | try: |
| 21 | from mecord import VideoReader |
| 22 | except ImportError: |
| 23 | VideoReader = None |
| 24 | |
| 25 | |
| 26 | def resampling(video_bytes: bytes, |
| 27 | sample_indices: list[int], |
| 28 | key_indices=None, |
| 29 | frame_time_info=None, |
               num_threads=4) -> list[Image.Image]:
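    """Decode the frames at `sample_indices` from raw video bytes and
    return them as PIL images."""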
    if VideoReader is None:
        raise ImportError(
            "mecord's VideoReader is required for video decoding but could "
            "not be imported.")
    video = VideoReader(video_bytes,
                        num_threads=num_threads,
                        frame_time_info=frame_time_info,
                        key_indices=key_indices)
    # extract target frames
    frames = video[sample_indices]
    frames = [Image.fromarray(frame) for frame in frames]
    return frames


class KimiK25VisionProcessor(BaseImageProcessor):
    model_type = "kimi_k25"

    def __init__(
        self,
        media_proc_cfg: dict,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.media_proc_cfg = media_proc_cfg
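        # Each video chunk carries `temporal_merge_kernel_size` frames that
        # are merged along the temporal axis.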
        self.num_frames_per_chunk = media_proc_cfg[
            'temporal_merge_kernel_size']

    def media_tokens_calculator(self, media: MediaInput):
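        """Return the number of tokens a media item occupies in the model."""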
        media = ensure_media_type(media)
        ret = self.get_resize_config(media)
        return ret['num_tokens']

    @classmethod
    def make_chunk_prompt(cls, timestamp_text: str) -> str:
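        """Wrap a chunk's timestamp in the media prompt template."""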
| 61 | return f"{timestamp_text}<|media_begin|>video<|media_content|><|media_pad|><|media_end|>" |
| 62 | |
| 63 | def split_video_chunks(self, |
| 64 | video_url: str | bytes) -> list[list[Image.Image]]: |
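        """Sample frames at the configured fps and group them into timestamped
        chunks of `temporal_merge_kernel_size` frames each."""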
        # video_url should be a base64 str or bytes
        video_spec = get_video_meta(video_url)
        sample_fps = min(self.media_proc_cfg['sample_fps'], video_spec.fps)
        sampled_nframes = max(
            round(video_spec.num_frames * sample_fps / video_spec.fps), 1)
        frame_inds = np.linspace(0, video_spec.num_frames - 1,
                                 sampled_nframes).round().astype(int)
        frame_inds = frame_inds.tolist()
        sampled_frame_ids = []
        temporal_merge_kernel_size = self.media_proc_cfg[
            "temporal_merge_kernel_size"]
        num_chunks = 0
        chunk_timestamp = []
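        # Group the sampled frame indices into chunks of
        # temporal_merge_kernel_size frames and record each chunk's start
        # time; the frames themselves are decoded in one pass afterwards and
        # sliced back into per-chunk lists.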
        for i in range(0, len(frame_inds), temporal_merge_kernel_size):
            sampled_frame_ids.extend(frame_inds[i:i +
                                                temporal_merge_kernel_size])
            start_time = frame_inds[i] / float(video_spec.fps)
            timestamp_text = timestamp_as_str(
                start_time, self.media_proc_cfg["timestamp_mode"])
            chunk_timestamp.append(timestamp_text)
            num_chunks += 1

        sampled_frames = resampling(video_url, sampled_frame_ids)
        chunks = []
        for chunk_id in range(num_chunks):
            chunk = sampled_frames[chunk_id *
                                   temporal_merge_kernel_size:(chunk_id + 1) *
                                   temporal_merge_kernel_size]
            chunks.append(
                VideoChunkInput(type="video_chunk",
                                video_chunk=chunk,
                                prompt=self.make_chunk_prompt(
                                    chunk_timestamp[chunk_id])))
        return chunks

    def get_resize_config(self, media_input: MediaInput) -> dict:
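        """Compute the resize/padding geometry and token count for a single
        image or video chunk."""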
        if media_input['type'] == 'image':
            w, h = media_input['image'].size
            ret = navit_resize_image(
                w, h, self.media_proc_cfg['patch_size'],
                self.media_proc_cfg['merge_kernel_size'],
                self.media_proc_cfg['in_patch_limit'],
                self.media_proc_cfg['patch_limit_on_one_side'],
                self.media_proc_cfg['fixed_output_tokens'])
            return ret
        elif media_input['type'] == 'video_chunk':
            frame = media_input['video_chunk'][0]
            width, height = frame.size
            num_frames = len(media_input["video_chunk"])
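            # Chunk frames were already sampled upstream, so a nominal fps of
            # 1.0 is used for the resize computation.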
            fps = 1.0

            sample_fps, max_num_frames_each_video = real_sample_fps_and_max_num_frames(
                media_input["type"],
                self.media_proc_cfg['sample_fps'],
                self.media_proc_cfg['max_num_frames_each_video'],
            )

            in_patch_limit_each_frame = self.media_proc_cfg[
                'in_patch_limit_each_frame']
            if in_patch_limit_each_frame is None:
                in_patch_limit_each_frame = self.media_proc_cfg[
                    'in_patch_limit']

            ret = navit_resize_video(
                width,
                height,
                num_frames,
                fps,
                sample_fps,
                self.media_proc_cfg['patch_size'],
                self.media_proc_cfg['merge_kernel_size'],
                in_patch_limit_each_frame,
                self.media_proc_cfg['patch_limit_on_one_side'],
                self.media_proc_cfg['in_patch_limit_video'],
                max_num_frames_each_video,
                self.media_proc_cfg['fixed_output_tokens'],
            )
            return ret
        else:
            raise ValueError("Unsupported type: {}".format(
                media_input['type']))

    def resize_image(self, image: Image.Image, new_width: int, new_height: int,
                     pad_width: int, pad_height: int) -> np.ndarray:
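        """Resize an image to (new_width, new_height), then zero-pad its
        bottom and right edges by (pad_height, pad_width)."""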
        image_np = image_to_np(image, (new_width, new_height), "resize")
        image_np = np.pad(
            image_np,
            ((0, pad_height), (0, pad_width), (0, 0)),
            mode="constant",
            constant_values=0,
        )
        return image_np

    def preprocess(
        self,
        medias: list[MediaInput],
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
| 163 | """ |
| 164 | Preprocess a atom vision input (images/video_chunk) into model-ready tensors. |
| 165 | |
| 166 | Args: |
| 167 | medias: List of MediaInput. |
| 168 | return_tensors: Desired output format ('pt', 'np', 'tf', or None). |
| 169 | |
| 170 | Returns: |
| 171 | BatchFeature containing 'pixel_values' and 'grid_thws' tensors. |
| 172 | """ |
        if not isinstance(medias, list):
            medias = [medias]
        if medias:
            pixel_values = []
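            # First pass: resize and pad every media item into a
            # (num_frames, height, width, channels) array.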
            for item in medias:
                item = ensure_media_type(item)
                resize_config = self.get_resize_config(item)
                new_width = resize_config['new_width']
                new_height = resize_config['new_height']
                pad_width = resize_config['pad_width']
                pad_height = resize_config['pad_height']
                if item['type'] == 'image':
                    image = item['image']
                    image_np = self.resize_image(image, new_width, new_height,
                                                 pad_width, pad_height)
                    pixel_values.append(np.expand_dims(image_np, axis=0))
                elif item['type'] == 'video_chunk':
                    pixels = []
                    for frame in item['video_chunk']:
                        frame_np = self.resize_image(frame, new_width,
                                                     new_height, pad_width,
                                                     pad_height)
                        pixels.append(frame_np)
                    pixel_values.append(np.stack(pixels, axis=0))
                else:
                    raise ValueError("Unsupported type: {}".format(
                        item['type']))
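            # Second pass: normalize pixel values and patchify each item into
            # a flattened patch sequence plus its (t, h, w) grid shape.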
            normalized_pixel_values = []
            image_std_inv = 1.0 / np.array(self.media_proc_cfg['image_std'])
            image_mean = np.array(self.media_proc_cfg['image_mean'])
            for pixels in pixel_values:
                pixels = normalize(pixels, image_mean, image_std_inv)
                pixels_and_thw = navit_patchify(
                    pixels,
                    self.media_proc_cfg['patch_size'],
                )
                normalized_pixel_values.append(pixels_and_thw)

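            # Concatenate patch sequences across items; grid_thws stacks one
            # (t, h, w) patch-grid row per media item.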
            pixel_values = torch.cat([
                _to_tensor(pixel_value['pixel_values'])
                for pixel_value in normalized_pixel_values
            ])
            grid_thws = torch.cat([
                _to_tensor(pixel_value['grid_thw'],
                           dtype=torch.int64).unsqueeze(0)
                for pixel_value in normalized_pixel_values
            ])

            data = {
                'pixel_values': pixel_values,
                'grid_thws': grid_thws,
            }

        else:
            data = {}

        return BatchFeature(data=data, tensor_type=return_tensors)

    def __repr__(self):
        return f"KimiK25VisionProcessor(media_proc_cfg={self.media_proc_cfg})"

    def to_dict(self) -> Dict[str, Any]:
        output = super().to_dict()
        output["media_proc_cfg"] = self.media_proc_cfg
        if "media_processor" in output:
            del output["media_processor"]
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        config = config_dict.copy()
        media_proc_cfg = config.pop("media_proc_cfg", {})
        return cls(media_proc_cfg=media_proc_cfg, **config, **kwargs)

    def to_json_string(self):
        dictionary = self.to_dict()
        for key, value in dictionary.items():
            if hasattr(value, 'tolist'):
                dictionary[key] = value.tolist()
        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
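

# Illustrative usage sketch (not part of the original file). The config keys
# below are the ones this processor reads; the concrete values are
# hypothetical and should come from the model's preprocessor config.
#
#     cfg = {
#         "patch_size": 14,
#         "merge_kernel_size": 2,
#         "temporal_merge_kernel_size": 4,
#         "sample_fps": 2,
#         "max_num_frames_each_video": 256,
#         "in_patch_limit": 4096,
#         "in_patch_limit_each_frame": None,
#         "in_patch_limit_video": 8192,
#         "patch_limit_on_one_side": 512,
#         "fixed_output_tokens": None,
#         "timestamp_mode": "sec",
#         "image_mean": [0.5, 0.5, 0.5],
#         "image_std": [0.5, 0.5, 0.5],
#     }
#     processor = KimiK25VisionProcessor(media_proc_cfg=cfg)
#     features = processor.preprocess(
#         [{"type": "image", "image": Image.open("example.jpg")}],
#         return_tensors="pt")
#     # features["pixel_values"]: flattened patch sequences;
#     # features["grid_thws"]: one (t, h, w) patch-grid row per media item.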