processing_locateanything.py
27.8 KB · 678 lines · python Raw
1 # coding=utf-8
2 # Copyright 2024 The HuggingFace Inc. team.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 """
16 Processor class for LocateAnything.
17 """
18
19 import math
20 import os
21 from typing import Iterable, List, Union, Literal
22 import base64
23 import sys
24 import time
25 import warnings
26 from functools import lru_cache
27 from io import BytesIO
28 import re
29 import requests
30 import torch
31 import torchvision
32 from packaging import version
33 from PIL import Image
34 from torchvision import io
35 from torchvision import transforms
36 from torchvision.transforms import InterpolationMode
37 from typing import Optional, Any
38 import numpy as np
39
40 from transformers.feature_extraction_utils import BatchFeature
41 from transformers.image_utils import ImageInput
42 try:
43 from transformers.image_utils import VideoInput
44 except ImportError:
45 VideoInput = None
46 from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
47 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
48 from transformers.utils import logging
49 import lmdb
50 import cv2
51 import pickle
52 import decord
53
54 logger = logging.get_logger(__name__)
55
56 FPS = 2.0
57 MAX_FRAMES = 64
58 VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', 32000 * 28 * 28 * 0.9)))
59 logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")
60
61
62 def to_rgb(pil_image: Image.Image) -> Image.Image:
63 if pil_image.mode == 'RGBA':
64 white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
65 white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
66 return white_background
67 else:
68 return pil_image.convert("RGB")
69
70 def read_img_from_lmdb_v2(image_data):
71 # special case for AgiBotWorld
72 lmdb_file, lmdb_key = image_data['lmdb_file'], image_data['lmdb_key']
73 key = lmdb_key.encode('ascii')
74 env = lmdb.open(lmdb_file, max_readers=10240, readonly=True, lock=False, readahead=False, meminit=False)
75 txn = env.begin()
76 value = txn.get(key)
77 if value is None:
78 print(f"Warning: Key {key} not found.")
79 return None
80 record = pickle.loads(value)
81 image_bgr = cv2.imdecode(np.frombuffer(record['image'], dtype=np.uint8), cv2.IMREAD_COLOR)
82 image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
83 image = Image.fromarray(image_rgb)
84
85 return image
86
87 def parse_lmdb_image_data(image_data):
88 lmdb_file = image_data['lmdb_file']
89 if not os.path.exists(lmdb_file):
90 if "/home/zhidingy/workspace/libs/eagle/Eagle2/" in lmdb_file:
91 image_data['lmdb_file'] = lmdb_file.replace("/home/zhidingy/workspace/libs/eagle/Eagle2/", "")
92 else:
93 raise ValueError(f"LMDB file {lmdb_file} does not exist")
94 # special case for AgiBotWorld
95 if 'AgiBotWorld' in image_data['lmdb_file']:
96 return read_img_from_lmdb_v2(image_data)
97
98 try:
99 env = lmdb.open(image_data['lmdb_file'], readonly=True, lock=False, max_readers=10240)
100 except Exception as e:
101 print(f"Failed to open lmdb file {image_data['lmdb_file']}. Error message: {e}", flush=True)
102 raise e
103
104 with env.begin(write=False) as txn:
105 try:
106 image_bin = txn.get(image_data['lmdb_key'].encode('ascii'))
107 buf = BytesIO(image_bin)
108 except Exception as e:
109 print(f"Failed to get image from lmdb file {image_data['lmdb_file']}. Error message: {e}", flush=True)
110 raise e
111 try:
112 image = Image.open(buf)
113 except Exception as e:
114 image_np = np.frombuffer(image_bin, dtype=np.uint8)
115 image_bgr = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
116 image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
117 image = Image.fromarray(image_rgb)
118 return image
119
120 def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
121 if "image" in ele:
122 image = ele["image"]
123 else:
124 image = ele["image_url"]
125 image_obj = None
126 if isinstance(image, Image.Image):
127 image_obj = image
128 elif isinstance(image, dict) and 'lmdb_file' in image:
129 image_obj = parse_lmdb_image_data(image)
130 elif image.startswith("http://") or image.startswith("https://"):
131 response = requests.get(image, stream=True)
132 image_obj = Image.open(BytesIO(response.content))
133 elif image.startswith("file://"):
134 image_obj = Image.open(image[7:])
135 elif image.startswith("data:image"):
136 if "base64," in image:
137 _, base64_data = image.split("base64,", 1)
138 data = base64.b64decode(base64_data)
139 image_obj = Image.open(BytesIO(data))
140 else:
141 image_obj = Image.open(image)
142 if image_obj is None:
143 raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
144 image = to_rgb(image_obj)
145
146 return image
147
148
149 def get_video_frame_indices(
150 ele: dict,
151 total_frames: int,
152 video_fps: int | float,
153 ) -> tuple[torch.Tensor, float]:
154 target_fps = ele.get("fps", FPS)
155 max_frames = ele.get("max_frames", MAX_FRAMES)
156
157 nframes = (total_frames / video_fps) * target_fps
158 nframes = int(round(nframes))
159 nframes = max(1, nframes)
160
161 if nframes > max_frames:
162 nframes = max_frames
163
164 nframes = min(nframes, total_frames)
165
166 if nframes == total_frames:
167 idx = torch.arange(total_frames).long()
168 else:
169 idx = torch.linspace(0, total_frames - 1, nframes).round().long()
170
171 sample_fps = nframes / max(total_frames, 1e-6) * video_fps
172
173 return idx, sample_fps
174
175 def _read_video_torchvision(
176 ele: dict,
177 ) -> (torch.Tensor, float, list):
178 """read video using torchvision.io.read_video and return also per-frame timestamps"""
179 video_path = ele["video"]
180 if version.parse(torchvision.__version__) < version.parse("0.19.0"):
181 if "http://" in video_path or "https://" in video_path:
182 warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
183 if "file://" in video_path:
184 video_path = video_path[7:]
185 st = time.time()
186
187 video, audio, info = io.read_video(
188 video_path,
189 start_pts=ele.get("video_start", 0.0),
190 end_pts=ele.get("video_end", None),
191 pts_unit="sec",
192 output_format="TCHW",
193 )
194 total_frames, video_fps = video.size(0), info["video_fps"]
195 logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
196
197 idx, sample_fps = get_video_frame_indices(ele, total_frames, video_fps)
198
199 start_time = ele.get("video_start", 0.0)
200 timestamps = (start_time + idx.to(torch.float32) / video_fps).tolist()
201
202 video = video[idx]
203 return video, sample_fps, timestamps
204
205
206 def is_decord_available() -> bool:
207 import importlib.util
208 return importlib.util.find_spec("decord") is not None
209
210 def _read_video_decord(
211 ele: dict,
212 ) -> (torch.Tensor, float, list):
213 """read video using decord.VideoReader and return also per-frame timestamps"""
214 video_path = ele["video"]
215 st = time.time()
216 vr = decord.VideoReader(video_path)
217
218 total_frames, video_fps = len(vr), vr.get_avg_fps()
219 logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
220
221 idx_tensor, sample_fps = get_video_frame_indices(ele, total_frames, video_fps)
222 idx = idx_tensor.tolist()
223
224 start_time = ele.get("video_start", 0.0)
225 timestamps = [start_time + i / video_fps for i in idx]
226
227 video = vr.get_batch(idx).asnumpy()
228 video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
229
230 return video, sample_fps, timestamps
231
232
233 VIDEO_READER_BACKENDS = {
234 "decord": _read_video_decord,
235 "torchvision": _read_video_torchvision,
236 }
237
238
239 @lru_cache(maxsize=1)
240 def get_video_reader_backend() -> str:
241 if is_decord_available():
242 video_reader_backend = "decord"
243 else:
244 video_reader_backend = "torchvision"
245 return video_reader_backend
246
247
248 def fetch_video(ele: dict, return_video_sample_fps: bool = False, video_reader_backend: str = "torchvision") -> torch.Tensor | list[Image.Image]:
249 """
250 Fetches video, samples frames, resizes based on video_total_pixels, and returns as Tensor (TCHW).
251 """
252 if isinstance(ele["video"], str):
253 video_reader_backend = video_reader_backend if video_reader_backend is not None else get_video_reader_backend()
254 try:
255 video, sample_fps, timestamps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
256 except Exception as e:
257 logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
258 video, sample_fps, timestamps = VIDEO_READER_BACKENDS["torchvision"](ele)
259
260 nframes, _, height, width = video.shape
261
262 video_total_pixels = ele.get("video_total_pixels", VIDEO_TOTAL_PIXELS)
263 current_pixels = nframes * height * width
264
265 if current_pixels > video_total_pixels:
266 scale_factor = math.sqrt(video_total_pixels / current_pixels)
267 new_height = int(height * scale_factor)
268 new_width = int(width * scale_factor)
269
270 video = transforms.functional.resize(
271 video,
272 [new_height, new_width],
273 interpolation=InterpolationMode.BICUBIC,
274 antialias=True,
275 ).float()
276 else:
277 video = video.float()
278
279 if return_video_sample_fps:
280 return video, sample_fps, timestamps
281 return video
282
283 else:
284 assert isinstance(ele["video"], (list, tuple))
285 process_info = ele.copy()
286 process_info.pop("type", None)
287 process_info.pop("video", None)
288
289 images = [
290 fetch_image({"image": video_element, **process_info})
291 for video_element in ele["video"]
292 ]
293
294 nframes = len(images)
295 timestamps = [-1 for i in range(nframes)]
296
297 # For list of images, we return list of PIL images directly,
298 # the processor will handle conversion to tensor later.
299 if return_video_sample_fps:
300 return images, process_info.get("fps", 2.0), timestamps
301 return images
302
303 class LocateAnythingProcessorKwargs(ProcessingKwargs, total=False):
304 _defaults = {
305 "text_kwargs": {
306 "padding": False,
307 },
308 "images_kwargs": {},
309 "videos_kwargs": {},
310 }
311
312
313 class LocateAnythingProcessor(ProcessorMixin):
314 attributes = ["image_processor", "tokenizer"]
315 valid_kwargs = [
316 "chat_template",
317 "num_image_tokens",
318 "image_token",
319 "video_token",
320 "images_kwargs",
321 "videos_kwargs",
322 "text_kwargs",
323 ]
324 image_processor_class = "AutoImageProcessor"
325 tokenizer_class = "AutoTokenizer"
326
327 def __init__(
328 self,
329 image_processor=None,
330 tokenizer=None,
331 chat_template=None,
332 image_token='<IMG_CONTEXT>',
333 video_token='<IMG_CONTEXT>',
334 merge_kernel_size=[2, 2], # Note: This might need adjustment based on your patch_size (14*14)
335 image_placeholder='image',
336 video_placeholder='video',
337 image_start_token='<img>',
338 image_end_token='</img>',
339 **kwargs,
340 ):
341 self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
342 self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
343 self.image_token_id = (
344 tokenizer.image_token_id
345 if getattr(tokenizer, "image_token_id", None)
346 else tokenizer.convert_tokens_to_ids(self.image_token)
347 )
348 self.video_token_id = (
349 tokenizer.video_token_id
350 if getattr(tokenizer, "video_token_id", None)
351 else tokenizer.convert_tokens_to_ids(self.video_token)
352 )
353 self.image_placeholder = image_placeholder
354 self.video_placeholder = video_placeholder
355 self.merge_kernel_size = merge_kernel_size
356 self.image_start_token = image_start_token
357 self.image_end_token = image_end_token
358 if 'auto_map' in kwargs:
359 self.auto_map = kwargs['auto_map']
360 super().__init__(image_processor, tokenizer, chat_template=chat_template)
361
362
363 def replace_media_placeholder(self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs):
364
365 num_of_images_in_this_sample = 0
366 num_of_videos_in_this_sample = 0
367 pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
368 unified_frame_list = []
369
370 def replace_in_text(text):
371 def repl(match):
372 nonlocal unified_frame_list
373 nonlocal num_of_images_in_this_sample
374 nonlocal num_of_videos_in_this_sample
375 media_type = match.group(1)
376 idx_in_list = int(match.group(2)) - 1
377 idx_mapper = {0: "first", 1: "second", 2: "third", 3: "fourth", 4: "fifth", 5: "sixth", 6: "seventh", 7: "eighth", 8: "ninth", 9: "tenth"}
378
379 if media_type == 'image':
380 # Call LocateAnythingImageProcessor with a single image in a list
381 image_inputs = self.image_processor(images=[image_list[idx_in_list]], **output_kwargs["images_kwargs"])
382
383 num_of_tokens_list = [int(h * w) // (self.image_processor.merge_kernel_size[0] * self.image_processor.merge_kernel_size[1]) for h, w in image_inputs['image_grid_hws']]
384
385 special_placeholder = f"<image {idx_in_list+1}>{self.image_start_token}{self.image_token * num_of_tokens_list[0]}{self.image_end_token}"
386 unified_frame_list.append(image_inputs)
387 num_of_images_in_this_sample += 1
388
389 elif media_type == 'video':
390 video_obj = video_list[idx_in_list]
391
392 # Convert Tensor TCHW to list of PIL Images for the ImageProcessor
393 if isinstance(video_obj, torch.Tensor):
394 # video_obj is [T, C, H, W], float, likely 0-255 or standardized
395 # LocateAnythingImageProcessor expects PIL or 0-255 inputs usually.
396 # We need to convert back to PIL or List[Tensor] compatible with make_list_of_images
397 video_frames = []
398 for i in range(video_obj.shape[0]):
399 frame = video_obj[i] # [C, H, W]
400 # Assuming fetch_video returns float tensors.
401 # If they are 0-255, convert to uint8.
402 if frame.dtype.is_floating_point and frame.max() > 1.0:
403 frame = frame.byte()
404 elif frame.dtype.is_floating_point:
405 frame = (frame * 255).byte()
406
407 img = transforms.ToPILImage()(frame)
408 video_frames.append(img)
409 elif isinstance(video_obj, list):
410 # Already list of PIL images
411 video_frames = video_obj
412 else:
413 raise ValueError("Unsupported video format")
414
415 # Call ImageProcessor with list of frames
416 video_inputs = self.image_processor(images=video_frames, **output_kwargs["videos_kwargs"])
417
418 # Calculate tokens per frame
419 num_of_tokens_list = [int(h * w) // (self.image_processor.merge_kernel_size[0] * self.image_processor.merge_kernel_size[1]) for h, w in video_inputs['image_grid_hws']]
420
421 if timestamps_list is not None and -1 not in timestamps_list:
422 frame_timestamps = timestamps_list[idx_in_list]
423 else:
424 frame_timestamps = None
425 sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
426
427 if frame_timestamps is not None:
428 # Ensure lengths match (sometimes rounding might cause off-by-one if not careful, but usually safe here)
429 if len(frame_timestamps) != len(num_of_tokens_list):
430 logger.warning(f"Timestamp mismatch: {len(frame_timestamps)} vs {len(num_of_tokens_list)}")
431 min_len = min(len(frame_timestamps), len(num_of_tokens_list))
432 frame_timestamps = frame_timestamps[:min_len]
433 num_of_tokens_list = num_of_tokens_list[:min_len]
434
435 special_placeholder = [f"Frame-{i+1}-{frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tokens}{self.image_end_token}" for i, num_of_tokens in enumerate(num_of_tokens_list)]
436 else:
437 special_placeholder = [f"Frame-{i+1}: {self.image_start_token}{self.image_token * num_of_tokens}{self.image_end_token}" for i, num_of_tokens in enumerate(num_of_tokens_list)]
438
439 if sampled_fps is not None:
440 special_placeholder = f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: " + "".join(special_placeholder)
441 else:
442 special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(special_placeholder)
443
444 unified_frame_list.append(video_inputs)
445 num_of_videos_in_this_sample += 1
446 else:
447 raise ValueError(f'Unknown media type: {media_type}')
448 return special_placeholder
449 return pattern.sub(repl, text)
450
451 text = replace_in_text(text)
452
453 if len(unified_frame_list) > 0:
454 # Concatenate all pixel values from all images/videos in this sample
455 pixel_values = torch.cat([frame['pixel_values'] for frame in unified_frame_list], dim=0)
456 # Concatenate grid hws
457 image_grid_hws = np.concatenate([frame['image_grid_hws'] for frame in unified_frame_list], axis=0)
458 else:
459 pixel_values = torch.empty(0)
460 image_grid_hws = np.empty(0)
461
462 return text, pixel_values, image_grid_hws, num_of_images_in_this_sample, num_of_videos_in_this_sample
463
464 def __call__(
465 self,
466 images: ImageInput = None,
467 text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
468 audio=None,
469 videos: VideoInput = None,
470 **kwargs: Unpack[LocateAnythingProcessorKwargs],
471 ) -> BatchFeature:
472 output_kwargs = self._merge_kwargs(
473 LocateAnythingProcessorKwargs,
474 tokenizer_init_kwargs=self.tokenizer.init_kwargs,
475 **kwargs,
476 )
477
478 if isinstance(text, str):
479 text_list = [text]
480 elif not isinstance(text, list) and not isinstance(text[0], str):
481 raise ValueError("Invalid input text. Please provide a string, or a list of strings")
482 elif isinstance(text, list) and isinstance(text[0], str):
483 text_list = text
484
485 if images is None: images = []
486 if videos is None: videos = []
487
488 pixel_values_list = []
489 image_grid_hws_list = []
490 new_sample_list = []
491 image_start_idx = 0
492 video_start_idx = 0
493 timestamps_batch = output_kwargs['videos_kwargs'].pop("timestamps", None)
494 fps_batch = output_kwargs['videos_kwargs'].pop("fps", None)
495
496 for sample in text_list:
497 timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
498 fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
499
500 sample, pixel_values, image_grid_hws, num_of_images_in_this_sample, num_of_videos_in_this_sample = self.replace_media_placeholder(
501 sample, images[image_start_idx:], videos[video_start_idx:], timestamps_list, fps_list, **output_kwargs
502 )
503 new_sample_list.append(sample)
504
505 if pixel_values.numel() > 0:
506 pixel_values_list.append(pixel_values)
507 image_grid_hws_list.append(image_grid_hws)
508
509 image_start_idx += num_of_images_in_this_sample
510 video_start_idx += num_of_videos_in_this_sample
511
512 image_inputs = {}
513 if len(pixel_values_list) > 0:
514 # Concatenate across the batch
515 image_inputs['pixel_values'] = torch.cat(pixel_values_list, dim=0)
516 image_inputs['image_grid_hws'] = np.concatenate(image_grid_hws_list, axis=0)
517
518 video_inputs = {} # Video data is merged into image_inputs now
519 text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
520
521 return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
522
523 def batch_decode(self, *args, **kwargs):
524 return self.tokenizer.batch_decode(*args, **kwargs)
525
526 def decode(self, *args, **kwargs):
527 return self.tokenizer.decode(*args, **kwargs)
528
529 @property
530 def model_input_names(self):
531 tokenizer_input_names = self.tokenizer.model_input_names
532 image_processor_input_names = self.image_processor.model_input_names
533 return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
534
535 def save_pretrained(self, save_directory, **kwargs):
536 if os.path.isfile(save_directory):
537 raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
538 os.makedirs(save_directory, exist_ok=True)
539 outputs = super().save_pretrained(save_directory, **kwargs)
540 return outputs
541
542 @classmethod
543 def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
544 processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
545 if isinstance(processor, tuple):
546 processor = processor[0]
547 return processor
548
549 def process_vision_info(
550 self,
551 conversations: list[dict] | list[list[dict]],
552 return_video_kwargs: bool = False,
553 video_reader_backend: str = "torchvision",
554 ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:
555
556 vision_infos = self.extract_vision_info(conversations)
557 image_inputs = []
558 video_inputs = []
559 video_sample_fps_list = []
560 video_timestamps_list = []
561
562 for vision_info in vision_infos:
563 if "image" in vision_info or "image_url" in vision_info:
564 image_inputs.append(fetch_image(vision_info))
565 elif "video" in vision_info:
566 video_input, video_sample_fps, video_timestamps = fetch_video(vision_info, return_video_sample_fps=True, video_reader_backend=video_reader_backend)
567 video_sample_fps_list.append(video_sample_fps)
568 video_inputs.append(video_input)
569 video_timestamps_list.append(video_timestamps)
570 else:
571 raise ValueError("image, image_url or video should in content.")
572
573 if len(image_inputs) == 0:
574 image_inputs = None
575 if len(video_inputs) == 0:
576 video_inputs = None
577
578 if return_video_kwargs:
579 return image_inputs, video_inputs, {'fps': video_sample_fps_list, 'timestamps': video_timestamps_list}
580 return image_inputs, video_inputs
581
582 def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
583 vision_infos = []
584 if isinstance(conversations[0], dict):
585 conversations = [conversations]
586 for conversation in conversations:
587 for message in conversation:
588 if isinstance(message["content"], list):
589 for ele in message["content"]:
590 if (
591 "image" in ele
592 or "image_url" in ele
593 or "video" in ele
594 or ele["type"] in ("image", "image_url", "video")
595 ):
596 vision_infos.append(ele)
597 return vision_infos
598
599 def py_apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):
600 assert tokenize == False, "tokenize is not supported yet"
601 result = ""
602 image_count = 0
603 video_count = 0
604
605 message_text = ""
606 for idx, message in enumerate(messages):
607 if message.get('role') != 'user': continue
608 content = message.get('content')
609 if isinstance(content, str):
610 message_text += content
611 elif isinstance(content, list):
612 for item in content:
613 if isinstance(item, dict) and "text" in item:
614 message_text += item["text"]
615 elif isinstance(item, str):
616 message_text += item
617
618 for idx, message in enumerate(messages):
619 if idx == 0 and message.get('role') != 'system':
620 result += "<|im_start|>system\n"
621 result += "You are a helpful assistant.\n"
622 result += "<|im_end|>\n"
623
624 result += f"<|im_start|>{message.get('role', '')}\n"
625 content = message.get('content')
626
627 if isinstance(content, str):
628 result += content
629 result += "<|im_end|>\n"
630 else:
631 for item in content:
632 if (isinstance(item, dict) and (item.get('type') == 'image' or 'image' in item or 'image_url' in item)):
633 image_count += 1
634 candidate_token = f"<image-{image_count}>"
635 if candidate_token not in message_text:
636 result += candidate_token
637 elif (isinstance(item, dict) and (item.get('type') == 'video' or 'video' in item)):
638 video_count += 1
639 candidate_token = f"<video-{video_count}>"
640 if candidate_token not in message_text:
641 result += candidate_token
642 elif isinstance(item, dict) and 'text' in item:
643 result += item['text']
644 elif isinstance(item, str):
645 result += item
646 result += "<|im_end|>\n"
647
648 if add_generation_prompt:
649 result += "<|im_start|>assistant\n"
650
651 return result
652
653
654 @classmethod
655 def from_args_and_dict(cls, args, processor_dict: dict[str, Any], **kwargs):
656 processor_dict = processor_dict.copy()
657 return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
658
659 if "processor_class" in processor_dict:
660 del processor_dict["processor_class"]
661
662 unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs)
663 processor = cls(*args, **processor_dict)
664
665 for key in set(kwargs.keys()):
666 if hasattr(processor, key):
667 setattr(processor, key, kwargs.pop(key))
668
669 if isinstance(unused_kwargs, dict):
670 kwargs.update(unused_kwargs)
671 logger.info(f"Processor {processor}")
672 if return_unused_kwargs:
673 return processor, kwargs
674 else:
675 return processor
676
677
678 __all__ = ["LocateAnythingProcessor"]