"""Media (image/video) preprocessing utilities: metadata probing, timestamp
formatting, NaViT resize/patchify geometry, and numpy/torch conversions."""
1 import base64
2 import io
3 import math
4 import os
5 from datetime import datetime, timezone
6 from typing import List, Literal, Optional, TypedDict
7
8 import numpy as np
9 from PIL import Image
10 from pydantic import BaseModel, Field
11
# VideoReader comes from the optional `mecord` package; degrade to None on
# ImportError so this module can still be imported without mecord installed.
try:
    from mecord import VideoReader
except ImportError:
    VideoReader = None
16
17
class VideoSpec(BaseModel):
    """Static metadata describing a decoded video stream."""

    # Discriminator tag. The original declaration was
    # `media_type: str = Literal['video']`, which assigns the typing special
    # form itself as the field's *default value*; the intended declaration is
    # a Literal-typed field whose default is the string 'video'.
    media_type: Literal['video'] = 'video'
    height: int = Field(..., gt=0, description="video frame height")
    width: int = Field(..., gt=0, description="video frame width")
    num_frames: int = Field(..., gt=0, description="num frames")
    fps: float = Field(..., gt=0, description="average fps")

    # Optional decode-acceleration hints (see get_video_meta); these default
    # to None, so their types must be Optional for the declaration to be
    # consistent with the default.
    key_indices: Optional[list[int]] = Field(None, description="key indices")
    frame_time_info: Optional[dict] = Field(None,
                                            description="frame time info")
28
29
class ImageInput(TypedDict):
    """A single still-image media input."""

    # Discriminator tag distinguishing this variant of MediaInput.
    type: Literal['image']
    # The image payload as a PIL image (coerced by ensure_media_type).
    image: Image.Image
33
34
class _VideoChunkInputRequired(TypedDict):
    """Required keys of a video-chunk media input."""

    # Discriminator tag distinguishing this variant of MediaInput.
    type: Literal['video_chunk']
    # Pre-decoded frames of the chunk as PIL images.
    video_chunk: List[Image.Image]


class VideoChunkInput(_VideoChunkInputRequired, total=False):
    """A chunk of pre-decoded video frames with an optional per-chunk prompt.

    The original declaration wrote ``prompt: Optional[str] = None``; TypedDict
    fields cannot carry default values, so the optional key is expressed via
    this ``total=False`` subclass instead — ``prompt`` may be absent entirely.
    """

    prompt: Optional[str]


# Any accepted media payload: a still image or a chunk of video frames.
MediaInput = ImageInput | VideoChunkInput
42
43
def get_video_meta(video_src: bytes | str | os.PathLike,
                   accurate: bool = True) -> VideoSpec:
    """Probe a video source and return its metadata as a VideoSpec.

    Args:
        video_src: raw container bytes, a filesystem path, or a base64 data
            URL of the form ``data:video/<fmt>;base64,<payload>``.
        accurate: forwarded to VideoReader as ``auto_init``; presumably
            triggers a full index scan for exact counts — TODO confirm
            against mecord's documentation.

    Returns:
        A VideoSpec with dimensions, frame count, average fps and optional
        decode-acceleration hints.  (The previous ``-> dict`` annotation was
        wrong: the function has always returned a VideoSpec.)

    Raises:
        RuntimeError: if the optional ``mecord`` dependency is missing.
        AssertionError: if the source cannot be decoded as a valid video.
    """
    if VideoReader is None:
        # The module-level import of mecord failed; raise a clear error
        # instead of "'NoneType' object is not callable".
        raise RuntimeError(
            "mecord is required by get_video_meta but is not installed")
    if isinstance(video_src, os.PathLike):
        video_src = str(video_src)
    # Decode base64 data URLs to raw bytes (any video MIME type; previously
    # only 'data:video/mp4;base64,' was recognized).
    if isinstance(video_src,
                  str) and video_src.startswith('data:video/'):
        video_src = base64.b64decode(video_src.split(',', 1)[1])
    video = VideoReader(video_src, auto_init=accurate, num_threads=1)
    assert video.num_frames > 0, "Invalid video format."
    assert video.original_width > 0 and video.original_height > 0, (
        "Invalid video format.")
    assert video.avg_fps > 0, "Invalid video format."
    return VideoSpec(media_type='video',
                     height=video.original_height,
                     width=video.original_width,
                     num_frames=video.num_frames,
                     fps=video.avg_fps,
                     key_indices=video.key_indices,
                     frame_time_info=video.frame_time_info)
65
66
def timestamp_as_str(timestamp: float,
                     timestamp_mode: str = "hh:mm:ss.fff") -> str:
    """Format a timestamp (seconds) as a human-readable string.

    Fixes the previous datetime-based implementation, which wrapped at
    24 hours (90000 s rendered as "01:00:00") and, in the "mm:ss" modes,
    silently dropped whole hours.  Here hours are unbounded and the "mm:ss"
    modes report *total* minutes.  For timestamps below one hour the output
    is identical to the old behavior.

    Args:
        timestamp: non-negative time offset in seconds.
        timestamp_mode: one of "hh:mm:ss.fff", "mm:ss.fff", "mm:ss".

    Returns:
        The formatted timestamp string.

    Raises:
        ValueError: if timestamp is negative or timestamp_mode is unknown.
    """
    if timestamp < 0:
        raise ValueError(f"Timestamp must be non-negative, got {timestamp}")
    total_seconds = int(timestamp)
    # Truncate (not round) the fraction to milliseconds, matching the
    # previous int((timestamp % 1) * 1000) behavior.
    millis = int((timestamp % 1) * 1000)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    if timestamp_mode == "hh:mm:ss.fff":
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}"
    elif timestamp_mode == "mm:ss.fff":
        return f"{total_seconds // 60:02d}:{seconds:02d}.{millis:03d}"
    elif timestamp_mode == "mm:ss":
        return f"{total_seconds // 60:02d}:{seconds:02d}"
    else:
        raise ValueError(f"Invalid timestamp mode: {timestamp_mode}")
83
84
85 def navit_resize_image(
86 width: int,
87 height: int,
88 patch_size: int,
89 merge_kernel_size: int,
90 in_patch_limit: int,
91 patch_limit_on_one_side: int,
92 fixed_output_tokens: int | None,
93 ):
94 # Apply the patch limits.
95 s1 = math.sqrt(
96 in_patch_limit /
97 (max(1.0, width // patch_size) * max(1.0, height // patch_size)))
98 s2 = patch_limit_on_one_side * patch_size / width
99 s3 = patch_limit_on_one_side * patch_size / height
100 scale = min(1.0, s1, s2, s3)
101 new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale))
102 new_w = min(new_w, patch_limit_on_one_side * patch_size)
103 new_h = min(new_h, patch_limit_on_one_side * patch_size)
104
105 # Calculate the padding to make the height and width divisible by the merge kernel size and patch size.
106 factor = merge_kernel_size * patch_size
107
108 pad_height = (factor - new_h % factor) % factor
109 pad_width = (factor - new_w % factor) % factor
110
111 if fixed_output_tokens is not None:
112 num_tokens = fixed_output_tokens
113 else:
114 # Calculate new dimensions after padding and patching
115 token_height = (new_h + pad_height) // factor
116 token_width = (new_w + pad_width) // factor
117
118 assert token_height * merge_kernel_size <= patch_limit_on_one_side, (
119 f"token_height {token_height} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
120 )
121 assert token_width * merge_kernel_size <= patch_limit_on_one_side, (
122 f"token_width {token_width} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
123 )
124
125 num_tokens = token_height * token_width
126 return {
127 "num_tokens": num_tokens,
128 "new_width": new_w,
129 "new_height": new_h,
130 "pad_width": pad_width,
131 "pad_height": pad_height,
132 "sampled_nframes": 1,
133 }
134
135
def navit_resize_video(
    width: int,
    height: int,
    nframes: int,
    avg_fps: float,
    sample_fps: float,
    patch_size: int,
    merge_kernel_size: int,
    in_patch_limit_each_frame: int,
    patch_limit_on_one_side: int,
    in_patch_limit_total: int | None,
    max_num_frames_each_video: int | None,
    fixed_output_tokens_each_frame: int | None,
):
    """Compute per-frame resize geometry and the sampled frame count.

    Picks how many frames to keep when resampling a ``nframes``-long video
    from ``avg_fps`` down to ``sample_fps`` (optionally capped by
    ``max_num_frames_each_video``), splits an optional total patch budget
    evenly across those frames, and delegates the spatial geometry to
    navit_resize_image.  Returns that dict with "sampled_nframes" updated.
    """
    # Never sample faster than the source provides frames.
    sample_fps = min(sample_fps, avg_fps)
    sampled_nframes = max(1, round(nframes * sample_fps / avg_fps))
    if max_num_frames_each_video is not None:
        sampled_nframes = min(max_num_frames_each_video, sampled_nframes)

    # Share a total patch budget evenly across the sampled frames, without
    # ever exceeding the per-frame limit.
    if in_patch_limit_total is not None:
        per_frame_budget = round(in_patch_limit_total / sampled_nframes)
        in_patch_limit_each_frame = min(in_patch_limit_each_frame,
                                        per_frame_budget)

    result = navit_resize_image(
        width,
        height,
        patch_size,
        merge_kernel_size,
        in_patch_limit_each_frame,
        patch_limit_on_one_side,
        fixed_output_tokens_each_frame,
    )
    result["sampled_nframes"] = sampled_nframes
    return result
172
173
174 def real_sample_fps_and_max_num_frames(
175 type_name: Literal["video", "video_chunk"],
176 sample_fps: float,
177 max_num_frames_each_video: int | None,
178 ) -> tuple[int, int | None]:
179 if type_name == "video":
180 return sample_fps, max_num_frames_each_video
181 elif type_name == "video_chunk":
182 max_num_frames_each_video = None
183 sample_fps = math.inf
184 return sample_fps, max_num_frames_each_video
185 else:
186 return math.inf, None
187
188
def _to_pil(data: "Image.Image | str | bytes") -> "Image.Image":
    """Coerce an image-like value to an RGB PIL image.

    (The previous ``str | bytes`` annotation omitted the ``Image.Image``
    case that the body explicitly handles first.)

    Args:
        data: an existing PIL image, a ``data:`` base64 URL, a filesystem
            path string, or raw encoded image bytes.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        ValueError: for unsupported input types.
    """
    if isinstance(data, Image.Image):
        return data.convert("RGB")
    elif isinstance(data, str):
        if data.startswith("data:"):
            # Strip the "data:<mime>;base64," header; base64 payloads contain
            # no commas, so splitting once is sufficient.
            raw_base64 = data.split(",", 1)[1]
            return Image.open(io.BytesIO(
                base64.b64decode(raw_base64))).convert("RGB")
        else:
            # Otherwise treat the string as a filesystem path.
            return Image.open(data).convert("RGB")
    elif isinstance(data, bytes):
        return Image.open(io.BytesIO(data)).convert("RGB")
    else:
        raise ValueError(f"Unsupported data type: {type(data)}")
204
205
def ensure_media_type(media: MediaInput) -> MediaInput:
    """Coerce a media dict's payload(s) to RGB PIL images, in place.

    Dispatches on media['type']: an 'image' payload is converted directly,
    a 'video_chunk' payload has every frame converted.  The (mutated) input
    dict is returned.

    Raises:
        ValueError: for any other media type.
    """
    kind = media['type']
    if kind == 'image':
        media['image'] = _to_pil(media['image'])
    elif kind == 'video_chunk':
        frames = media['video_chunk']
        media['video_chunk'] = [_to_pil(frame) for frame in frames]
    else:
        raise ValueError(f"Unsupported media type: {media['type']}")
    return media
217
218
def _rescale_and_pad(
    image: "Image.Image",
    resize_to: tuple[int, int],
    center: bool,
    raise_error_for_ill_resize: bool,
) -> np.ndarray:
    """Downscale `image` to fit in resize_to (w, h) preserving aspect ratio
    (never upscaling), then zero-pad to exactly resize_to; padding is split
    around the image when `center`, otherwise appended right/bottom only."""
    target_w, target_h = resize_to
    scale = min(target_w / image.width, target_h / image.height, 1.0)
    new_width = round(image.width * scale)
    new_height = round(image.height * scale)
    if new_width == 0 or new_height == 0:
        # Extreme aspect ratios can collapse a side to 0 pixels.
        if raise_error_for_ill_resize:
            raise ValueError(
                f"Invalid resize to: {resize_to}, from image size: {image.size}"
            )
        else:
            # Best-effort fallback: an all-black frame of the target size.
            return np.zeros((target_h, target_w, 3), dtype=np.uint8)

    image = image.resize((new_width, new_height),
                         resample=Image.Resampling.BICUBIC)
    if center:
        padding_left = (target_w - new_width) // 2
        padding_top = (target_h - new_height) // 2
    else:
        padding_left = 0
        padding_top = 0
    padding_right = target_w - new_width - padding_left
    padding_bottom = target_h - new_height - padding_top
    arr = np.pad(
        np.asarray(image),
        ((padding_top, padding_bottom), (padding_left, padding_right),
         (0, 0)),
        mode="constant",
        constant_values=0,
    )
    assert arr.shape == (target_h, target_w, 3)
    return arr


def image_to_np(
    image: Image.Image,
    resize_to: tuple[int, int] | None = None,
    mode: str = "resize",
    raise_error_for_ill_resize: bool = True,
) -> np.ndarray:
    """Convert a PIL image to a numpy array, optionally resizing first.

    The two aspect-preserving pad modes were previously duplicated inline;
    they now share the _rescale_and_pad helper.

    Args:
        image: The image to convert.
        resize_to: Target (width, height); None keeps the original size.
        mode: "resize" (direct bicubic resize), "rescale_and_pad_to_center"
            or "rescale_and_pad_to_rightbottom" (aspect-preserving downscale
            plus zero padding, centered or anchored top-left).
        raise_error_for_ill_resize: if an aspect-preserving rescale collapses
            a side to 0 pixels, raise ValueError when True, otherwise return
            an all-black frame of the target size.

    Returns:
        A numpy array; (resize_to[1], resize_to[0], 3) when resizing an RGB
        image.

    Raises:
        ValueError: for an unknown mode, or an ill-sized rescale (see above).
    """
    assert isinstance(image, Image.Image), "image must be a PIL Image"
    if resize_to is not None:
        if mode == "resize":
            image = image.resize(resize_to, resample=Image.Resampling.BICUBIC)
        elif mode == "rescale_and_pad_to_center":
            return _rescale_and_pad(image, resize_to, True,
                                    raise_error_for_ill_resize)
        elif mode == "rescale_and_pad_to_rightbottom":
            return _rescale_and_pad(image, resize_to, False,
                                    raise_error_for_ill_resize)
        else:
            raise ValueError(f"Invalid mode: {mode}")
    return np.asarray(image)
305
306
def navit_patchify(pixel_values: np.ndarray,
                   patch_size: int) -> dict[str, np.ndarray]:
    """Split frames into non-overlapping square patches (NaViT layout).

    Args:
        pixel_values: np.ndarray, shape (t, h, w, c) with c == 3; h and w
            must be divisible by patch_size.
        patch_size: side length of each square patch.

    Returns:
        dict[str, np.ndarray]
        - pixel_values: np.ndarray, shape (t * h//patch_size * w//patch_size, c, patch_size, patch_size)
          (the docstring previously called this key "patches"; the returned
          key has always been "pixel_values")
        - grid_thw: np.ndarray, (t, h//patch_size, w//patch_size)
    """
    T, H, W, C = pixel_values.shape
    assert C == 3, "pixel_values must have 3 channels"
    # Fail early with a clear message instead of an opaque reshape error.
    assert H % patch_size == 0 and W % patch_size == 0, (
        f"H ({H}) and W ({W}) must be divisible by patch_size ({patch_size})")

    patches = pixel_values.reshape(T, H // patch_size, patch_size,
                                   W // patch_size, patch_size, C)
    # (T, H//patch_size, W//patch_size, C, patch_size, patch_size)
    patches = patches.transpose(0, 1, 3, 5, 2, 4)
    patches = patches.reshape(-1, C, patch_size, patch_size)
    grid_thw = np.array([T, H // patch_size, W // patch_size])
    return {"pixel_values": patches, "grid_thw": grid_thw}
330
331
def normalize(x: np.ndarray,
              mean,
              std_inv,
              pixels_dtype: np.dtype = np.float32) -> np.ndarray:
    """Scale uint8 pixels into [0, 1] and standardize them.

    Computes ``(x / 255 - mean) * std_inv`` in ``pixels_dtype``.

    Args:
        x: image array, shape (..., 3), dtype uint8, values in [0, 255].
        mean: per-channel (or scalar) mean to subtract.
        std_inv: per-channel (or scalar) reciprocal of the std to multiply by.
        pixels_dtype: dtype of the returned array.

    Returns:
        The normalized array, shape (..., 3), dtype ``pixels_dtype``.
    """
    # The subtraction/multiplication are done in place so the result stays in
    # pixels_dtype even when mean/std_inv are float64 arrays.
    out = np.asarray(x / 255.0, dtype=pixels_dtype)
    out -= mean
    out *= std_inv
    return out
350
351
def _to_tensor(data, **kwargs):
    """Recursively convert numpy arrays (and tensors) to torch tensors.

    Containers (list/tuple/dict) are rebuilt with their elements converted;
    None passes through unchanged.  ``kwargs`` are forwarded to
    ``Tensor.to`` (e.g. device=..., dtype=...).

    Raises:
        ValueError: for any other data type.
    """
    import torch

    if data is None:
        return None
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data).to(**kwargs)
    if isinstance(data, torch.Tensor):
        return data.to(**kwargs)
    if isinstance(data, dict):
        return {key: _to_tensor(value, **kwargs) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        converted = [_to_tensor(item, **kwargs) for item in data]
        return converted if isinstance(data, list) else tuple(converted)
    raise ValueError(f"Unsupported data type: {type(data)}")
369