media_utils.py
import base64
import io
import math
import os
from datetime import datetime, timezone
# NotRequired needs Python 3.11+ (or typing_extensions on older versions).
from typing import List, Literal, NotRequired, Optional, TypedDict

import numpy as np
from PIL import Image
from pydantic import BaseModel, Field

try:
    from mecord import VideoReader
except ImportError:
    VideoReader = None


class VideoSpec(BaseModel):
    media_type: Literal['video'] = 'video'
    height: int = Field(..., gt=0, description="video frame height")
    width: int = Field(..., gt=0, description="video frame width")
    num_frames: int = Field(..., gt=0, description="number of frames")
    fps: float = Field(..., gt=0, description="average fps")

    # Optional; these help accelerate video reading.
    key_indices: list[int] | None = Field(None, description="key indices")
    frame_time_info: dict | None = Field(None, description="frame time info")


class ImageInput(TypedDict):
    type: Literal['image']
    image: Image.Image


class VideoChunkInput(TypedDict):
    type: Literal['video_chunk']
    video_chunk: List[Image.Image]
    # Default values are not allowed in TypedDict bodies; NotRequired marks
    # the key as optional instead.
    prompt: NotRequired[Optional[str]]


MediaInput = ImageInput | VideoChunkInput


def get_video_meta(video_src: bytes | str | os.PathLike,
                   accurate: bool = True) -> VideoSpec:
    """Get the metadata (dimensions, frame count, fps) of a video."""
    if VideoReader is None:
        raise ImportError(
            "mecord.VideoReader is required to read video metadata.")
    if isinstance(video_src, os.PathLike):
        video_src = str(video_src)
    # If given a base64 data URI, decode it to raw bytes.
    if isinstance(video_src,
                  str) and video_src.startswith('data:video/mp4;base64,'):
        video_src = base64.b64decode(video_src.split(',')[1])
    video = VideoReader(video_src, auto_init=accurate, num_threads=1)
    assert video.num_frames > 0, "Invalid video format."
    assert video.original_width > 0 and video.original_height > 0, (
        "Invalid video format.")
    assert video.avg_fps > 0, "Invalid video format."
    return VideoSpec(media_type='video',
                     height=video.original_height,
                     width=video.original_width,
                     num_frames=video.num_frames,
                     fps=video.avg_fps,
                     key_indices=video.key_indices,
                     frame_time_info=video.frame_time_info)


def timestamp_as_str(timestamp: float,
                     timestamp_mode: str = "hh:mm:ss.fff") -> str:
    """Convert a timestamp in seconds to a string in the requested format."""
    if timestamp_mode == "hh:mm:ss.fff":
        return (datetime.fromtimestamp(timestamp,
                                       tz=timezone.utc).strftime("%H:%M:%S") +
                f".{int((timestamp % 1) * 1000):03d}")
    elif timestamp_mode == "mm:ss.fff":
        return (datetime.fromtimestamp(timestamp,
                                       tz=timezone.utc).strftime("%M:%S") +
                f".{int((timestamp % 1) * 1000):03d}")
    elif timestamp_mode == "mm:ss":
        return datetime.fromtimestamp(timestamp,
                                      tz=timezone.utc).strftime("%M:%S")
    else:
        raise ValueError(f"Invalid timestamp mode: {timestamp_mode}")


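# Example usage of timestamp_as_str (illustrative values): the mode string
# selects how much of the clock is rendered, and the fractional part is
# truncated to milliseconds.
#   timestamp_as_str(3725.5)               -> "01:02:05.500"
#   timestamp_as_str(3725.5, "mm:ss.fff")  -> "02:05.500"
#   timestamp_as_str(3725.5, "mm:ss")      -> "02:05"

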
def navit_resize_image(
    width: int,
    height: int,
    patch_size: int,
    merge_kernel_size: int,
    in_patch_limit: int,
    patch_limit_on_one_side: int,
    fixed_output_tokens: int | None,
):
    # Apply the patch limits.
    s1 = math.sqrt(
        in_patch_limit /
        (max(1.0, width // patch_size) * max(1.0, height // patch_size)))
    s2 = patch_limit_on_one_side * patch_size / width
    s3 = patch_limit_on_one_side * patch_size / height
    scale = min(1.0, s1, s2, s3)
    new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale))
    new_w = min(new_w, patch_limit_on_one_side * patch_size)
    new_h = min(new_h, patch_limit_on_one_side * patch_size)

    # Calculate the padding needed to make the height and width divisible by
    # the merge kernel size times the patch size.
    factor = merge_kernel_size * patch_size

    pad_height = (factor - new_h % factor) % factor
    pad_width = (factor - new_w % factor) % factor

    if fixed_output_tokens is not None:
        num_tokens = fixed_output_tokens
    else:
        # Calculate new dimensions after padding and patching.
        token_height = (new_h + pad_height) // factor
        token_width = (new_w + pad_width) // factor

        assert token_height * merge_kernel_size <= patch_limit_on_one_side, (
            f"token_height {token_height} * merge_kernel_size "
            f"{merge_kernel_size} > patch_limit_on_one_side "
            f"{patch_limit_on_one_side}")
        assert token_width * merge_kernel_size <= patch_limit_on_one_side, (
            f"token_width {token_width} * merge_kernel_size "
            f"{merge_kernel_size} > patch_limit_on_one_side "
            f"{patch_limit_on_one_side}")

        num_tokens = token_height * token_width
    return {
        "num_tokens": num_tokens,
        "new_width": new_w,
        "new_height": new_h,
        "pad_width": pad_width,
        "pad_height": pad_height,
        "sampled_nframes": 1,
    }


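# Illustrative call of navit_resize_image, assuming a 14-pixel patch and a
# 2x2 merge kernel (hypothetical values, not tied to any specific model):
#   navit_resize_image(width=1280, height=720, patch_size=14,
#                      merge_kernel_size=2, in_patch_limit=1024,
#                      patch_limit_on_one_side=100, fixed_output_tokens=None)
# The image is scaled down so it covers at most in_patch_limit patches, then
# padded so both sides are divisible by merge_kernel_size * patch_size; with
# fixed_output_tokens=None, num_tokens is token_height * token_width after
# the 2x2 merge. For the values above, the image is resized to 601x338,
# padded to 616x364, and yields 22 * 13 = 286 tokens.

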
def navit_resize_video(
    width: int,
    height: int,
    nframes: int,
    avg_fps: float,
    sample_fps: float,
    patch_size: int,
    merge_kernel_size: int,
    in_patch_limit_each_frame: int,
    patch_limit_on_one_side: int,
    in_patch_limit_total: int | None,
    max_num_frames_each_video: int | None,
    fixed_output_tokens_each_frame: int | None,
):
    sample_fps = min(sample_fps, avg_fps)
    # Calculate the number of frames to sample based on the target FPS.
    sampled_nframes = max(round(nframes * sample_fps / avg_fps), 1)
    if max_num_frames_each_video is not None:
        sampled_nframes = min(sampled_nframes, max_num_frames_each_video)

    if in_patch_limit_total is not None:
        in_patch_limit_each_frame = min(
            round(in_patch_limit_total / sampled_nframes),
            in_patch_limit_each_frame)

    ret = navit_resize_image(
        width,
        height,
        patch_size,
        merge_kernel_size,
        in_patch_limit_each_frame,
        patch_limit_on_one_side,
        fixed_output_tokens_each_frame,
    )
    ret["sampled_nframes"] = sampled_nframes
    return ret


def real_sample_fps_and_max_num_frames(
    type_name: Literal["video", "video_chunk"],
    sample_fps: float,
    max_num_frames_each_video: int | None,
) -> tuple[float, int | None]:
    if type_name == "video":
        return sample_fps, max_num_frames_each_video
    elif type_name == "video_chunk":
        # For video chunks, keep every frame (no fps subsampling or frame cap).
        max_num_frames_each_video = None
        sample_fps = math.inf
        return sample_fps, max_num_frames_each_video
    else:
        return math.inf, None


def _to_pil(data: Image.Image | str | bytes) -> Image.Image:
    if isinstance(data, Image.Image):
        return data.convert("RGB")
    elif isinstance(data, str):
        if data.startswith("data:"):
            # Base64 data URI.
            raw_base64 = data.split(",")[1]
            return Image.open(io.BytesIO(
                base64.b64decode(raw_base64))).convert("RGB")
        else:
            # Treat the string as a file path.
            return Image.open(data).convert("RGB")
    elif isinstance(data, bytes):
        return Image.open(io.BytesIO(data)).convert("RGB")
    else:
        raise ValueError(f"Unsupported data type: {type(data)}")


def ensure_media_type(media: MediaInput) -> MediaInput:
    if media['type'] == 'image':
        media['image'] = _to_pil(media['image'])
        return media
    elif media['type'] == 'video_chunk':
        media['video_chunk'] = [
            _to_pil(frame) for frame in media['video_chunk']
        ]
        return media
    else:
        raise ValueError(f"Unsupported media type: {media['type']}")


def image_to_np(
    image: Image.Image,
    resize_to: tuple[int, int] | None = None,
    mode: str = "resize",
    raise_error_for_ill_resize: bool = True,
) -> np.ndarray:
    """Convert an image to a numpy array.

    Args:
        image: The image to convert.
        resize_to: The (width, height) to resize the image to.
        mode: The resize mode ("resize", "rescale_and_pad_to_center", or
            "rescale_and_pad_to_rightbottom").
        raise_error_for_ill_resize: Whether to raise an error when the resize
            would produce an empty image.

    Returns:
        The image as a numpy array.
    """
    assert isinstance(image, Image.Image), "image must be a PIL Image"
    if resize_to is not None:
        if mode == "resize":
            image = image.resize(resize_to, resample=Image.Resampling.BICUBIC)

        elif mode == "rescale_and_pad_to_center":
            scale = min(resize_to[0] / image.width,
                        resize_to[1] / image.height, 1.0)
            new_width = round(image.width * scale)
            new_height = round(image.height * scale)
            if new_width == 0 or new_height == 0:
                if raise_error_for_ill_resize:
                    raise ValueError(
                        f"Invalid resize to: {resize_to}, "
                        f"from image size: {image.size}")
                else:
                    return np.zeros((resize_to[1], resize_to[0], 3),
                                    dtype=np.uint8)

            image = image.resize((new_width, new_height),
                                 resample=Image.Resampling.BICUBIC)
            padding_left = (resize_to[0] - new_width) // 2
            padding_right = resize_to[0] - new_width - padding_left
            padding_top = (resize_to[1] - new_height) // 2
            padding_bottom = resize_to[1] - new_height - padding_top
            image = np.asarray(image)
            image = np.pad(
                image,
                ((padding_top, padding_bottom), (padding_left, padding_right),
                 (0, 0)),
                mode="constant",
                constant_values=0,
            )
            assert image.shape == (resize_to[1], resize_to[0], 3)

        elif mode == "rescale_and_pad_to_rightbottom":
            scale = min(resize_to[0] / image.width,
                        resize_to[1] / image.height, 1.0)
            new_width = round(image.width * scale)
            new_height = round(image.height * scale)
            if new_width == 0 or new_height == 0:
                if raise_error_for_ill_resize:
                    raise ValueError(
                        f"Invalid resize to: {resize_to}, "
                        f"from image size: {image.size}")
                else:
                    return np.zeros((resize_to[1], resize_to[0], 3),
                                    dtype=np.uint8)

            image = image.resize((new_width, new_height),
                                 resample=Image.Resampling.BICUBIC)
            padding_right = resize_to[0] - new_width
            padding_bottom = resize_to[1] - new_height
            image = np.asarray(image)
            image = np.pad(
                image,
                ((0, padding_bottom), (0, padding_right), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            assert image.shape == (resize_to[1], resize_to[0], 3)

        else:
            raise ValueError(f"Invalid mode: {mode}")

    if isinstance(image, Image.Image):
        return np.asarray(image)
    else:
        return image


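# Illustrative behavior of image_to_np in "rescale_and_pad_to_center" mode
# (hypothetical sizes): a 100x50 image with resize_to=(64, 64) is first
# scaled by min(64/100, 64/50, 1.0) = 0.64 to 64x32, then padded with 16
# black rows above and below, giving an array of shape (64, 64, 3).

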
def navit_patchify(pixel_values: np.ndarray,
                   patch_size: int) -> dict[str, np.ndarray]:
    """Reshape the pixel values to a navit shape.

    Args:
        pixel_values: np.ndarray, shape (t, h, w, c)
        patch_size: int

    Returns:
        dict[str, np.ndarray]
        - pixel_values: np.ndarray, shape
          (t * h//patch_size * w//patch_size, c, patch_size, patch_size)
        - grid_thw: np.ndarray, (t, h//patch_size, w//patch_size)
    """
    T, H, W, C = pixel_values.shape
    assert C == 3, "pixel_values must have 3 channels"
    assert H % patch_size == 0 and W % patch_size == 0, (
        "height and width must be divisible by patch_size")

    patches = pixel_values.reshape(T, H // patch_size, patch_size,
                                   W // patch_size, patch_size, C)
    # (T, H//patch_size, W//patch_size, C, patch_size, patch_size)
    patches = patches.transpose(0, 1, 3, 5, 2, 4)
    patches = patches.reshape(-1, C, patch_size, patch_size)
    grid_thw = np.array([T, H // patch_size, W // patch_size])
    return {"pixel_values": patches, "grid_thw": grid_thw}


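# Shape example for navit_patchify (illustrative): a single 28x28 RGB frame,
# i.e. pixel_values of shape (1, 28, 28, 3), with patch_size=14 yields
# "pixel_values" of shape (4, 3, 14, 14) (four patches in row-major order)
# and grid_thw = [1, 2, 2].

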
def normalize(x: np.ndarray,
              mean,
              std_inv,
              pixels_dtype: np.dtype = np.float32) -> np.ndarray:
    """Normalize the image.

    Args:
        x: The image to normalize. The shape is (..., 3). The dtype is uint8.
            The range is [0, 255].
        mean: The mean of the image.
        std_inv: The inverse of the std of the image.
        pixels_dtype: The dtype of the returned array.

    Returns:
        The normalized image. The shape is (..., 3). The dtype is determined
        by pixels_dtype.
    """
    x = (x / 255.0).astype(pixels_dtype)
    x -= mean
    x *= std_inv
    return x


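# Example usage of normalize (illustrative ImageNet statistics; the actual
# mean/std depend on the model's preprocessing):
#   mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
#   std_inv = 1.0 / np.array([0.229, 0.224, 0.225], dtype=np.float32)
#   pixels = normalize(frame_uint8, mean, std_inv)
# std_inv is the precomputed reciprocal of the std, so normalization is a
# multiply rather than a divide.

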
def _to_tensor(data, **kwargs):
    import torch

    if isinstance(data, np.ndarray):
        return torch.from_numpy(data).to(**kwargs)
    elif isinstance(data, torch.Tensor):
        return data.to(**kwargs)
    elif isinstance(data, list):
        return [_to_tensor(item, **kwargs) for item in data]
    elif isinstance(data, tuple):
        return tuple(_to_tensor(item, **kwargs) for item in data)
    elif isinstance(data, dict):
        return {k: _to_tensor(v, **kwargs) for k, v in data.items()}
    elif data is None:
        return None
    else:
        raise ValueError(f"Unsupported data type: {type(data)}")
| 369 | |