modeling_deepseekocr2.py
| 1 | from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM |
| 2 | from .configuration_deepseek_v2 import DeepseekV2Config |
| 3 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| 4 | from typing import List, Optional, Tuple, Union |
| 5 | from transformers.cache_utils import Cache |
| 6 | import requests |
| 7 | from PIL import Image, ImageOps, ImageDraw, ImageFont |
| 8 | from io import BytesIO |
| 9 | import torch |
| 10 | import torch.nn as nn |
| 11 | from torch.nn import CrossEntropyLoss |
| 12 | from torchvision import transforms |
| 13 | # from torchvision.transforms.functional import InterpolationMode |
| 14 | import os |
| 15 | from .deepencoderv2 import build_sam_vit_b, build_qwen2_decoder_as_encoder, MlpProjector |
| 16 | from addict import Dict |
| 17 | from transformers import TextStreamer |
| 18 | from .conversation import get_conv_template |
| 19 | from abc import ABC |
| 20 | import math |
| 21 | import re |
| 22 | from tqdm import tqdm |
| 23 | import numpy as np |
| 24 | # import time |
| 25 | |
| 26 | |
| 27 | |
def load_image(image_path):
    """Open an image and apply its EXIF orientation tag.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (path or file object).

    Returns:
        The EXIF-transposed image. If ``ImageOps.exif_transpose`` fails, the image
        as opened (original orientation) is returned. If the image cannot be opened
        at all, ``None`` is returned. Errors are printed, never raised.
    """
    try:
        image = Image.open(image_path)
    except Exception as e:
        # Nothing to fall back to when the file itself cannot be opened.
        print(f"error: {e}")
        return None

    try:
        return ImageOps.exif_transpose(image)
    except Exception as e:
        # A malformed EXIF block should not make the image unusable: keep the
        # already-opened image instead of opening the file a second time.
        print(f"error: {e}")
        return image
| 43 | |
| 44 | |
def re_match(text):
    """Find every ``<|ref|>label<|/ref|><|det|>coords<|/det|>`` span in ``text``.

    Returns:
        A 3-tuple:
        - the raw ``re.findall`` result: one ``(full_span, label, coords)`` tuple per match;
        - the full spans whose ref label is exactly ``image``;
        - the full spans of every other match.
    """
    found = re.findall(
        r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)', text, re.DOTALL
    )
    image_marker = '<|ref|>image<|/ref|>'
    full_spans = [groups[0] for groups in found]
    image_spans = [span for span in full_spans if image_marker in span]
    other_spans = [span for span in full_spans if image_marker not in span]
    return found, image_spans, other_spans
| 60 | |
| 61 | |
def extract_coordinates_and_label(ref_text, image_width, image_height):
    """Parse one ``re_match`` tuple into ``(label, coordinates)``.

    Args:
        ref_text: A tuple ``(full_span, label, coords_text)`` as produced by ``re_match``.
        image_width: Unused; kept for interface compatibility.
        image_height: Unused; kept for interface compatibility.

    Returns:
        ``(label_type, cor_list)``, where ``cor_list`` is the Python literal parsed
        from ``coords_text`` (e.g. ``[[x1, y1, x2, y2], ...]``). Returns ``None``
        (after printing the error) if the tuple is malformed or the text is not a
        plain literal.
    """
    import ast

    try:
        label_type = ref_text[1]
        # The coordinate text comes from model output. Parse it as a literal only;
        # ``eval`` would execute any expression embedded in the generated text.
        cor_list = ast.literal_eval(ref_text[2])
    except Exception as e:
        print(e)
        return None

    return (label_type, cor_list)
| 72 | |
| 73 | |
def draw_bounding_boxes(image, refs, ouput_path):
    """Draw the detected regions in ``refs`` onto a copy of ``image``.

    Each ref's coordinates are divided by 999 and scaled to the image size. Boxes
    labelled ``image`` are also cropped from the original image and saved as
    ``{ouput_path}/images/{n}.jpg``. ``n`` advances for every ``image`` box, even
    when its save fails, so file indices follow box order. Failures are printed,
    not raised.

    Args:
        image: Source PIL image (not modified).
        refs: Tuples produced by ``re_match``: ``(full_span, label, coords_text)``.
        ouput_path: Output directory for the crops (name kept for keyword callers).

    Returns:
        A new image with box outlines, labels and a translucent fill overlay.
    """
    image_width, image_height = image.size

    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Translucent fills go on a separate RGBA layer that is pasted on at the end.
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    font = ImageFont.load_default()

    img_idx = 0

    for ref in refs:
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if not result:
                continue
            label_type, points_list = result

            color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
            color_a = color + (20, )
            # Titles get a thicker outline than every other label.
            outline_width = 4 if label_type == 'title' else 2

            for points in points_list:
                x1, y1, x2, y2 = points

                # Rescale from the 0-999 coordinate range to pixels.
                x1 = int(x1 / 999 * image_width)
                y1 = int(y1 / 999 * image_height)
                x2 = int(x2 / 999 * image_width)
                y2 = int(y2 / 999 * image_height)

                if label_type == 'image':
                    try:
                        cropped = image.crop((x1, y1, x2, y2))
                        cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
                    except Exception as e:
                        print(e)
                    img_idx += 1

                try:
                    draw.rectangle([x1, y1, x2, y2], outline=color, width=outline_width)
                    draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)

                    # Label goes just above the box (clamped at the top edge) on a light background.
                    text_x = x1
                    text_y = max(0, y1 - 15)
                    text_bbox = draw.textbbox((0, 0), label_type, font=font)
                    text_width = text_bbox[2] - text_bbox[0]
                    text_height = text_bbox[3] - text_bbox[1]
                    draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
                                   fill=(255, 255, 255, 30))
                    draw.text((text_x, text_y), label_type, font=font, fill=color)
                except Exception as e:
                    # Drawing is best effort (e.g. degenerate boxes): skip this box only.
                    print(e)
        except Exception as e:
            # Malformed ref (e.g. a box without exactly four coordinates): skip the
            # rest of this ref. ``Exception`` (not a bare except) lets Ctrl-C through.
            print(e)

    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw
| 144 | |
| 145 | |
def process_image_with_refs(image, ref_texts, output_path):
    """Return ``image`` annotated with the regions described by ``ref_texts``.

    Thin wrapper that delegates to ``draw_bounding_boxes`` with the same arguments.
    """
    return draw_bounding_boxes(image, ref_texts, output_path)
| 151 | |
| 152 | |
| 153 | |
| 154 | |
| 155 | |
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the ``(cols, rows)`` candidate whose ``cols / rows`` is closest to ``aspect_ratio``.

    On an exact tie, a later candidate replaces the current pick when the image
    area ``width * height`` exceeds half of that candidate's pixel budget
    (``image_size**2 * cols * rows``). ``dynamic_preprocess`` passes candidates
    sorted by tile count, so ties favour more tiles for large images.

    Returns:
        The chosen element of ``target_ratios``, or ``(1, 1)`` when it is empty.
    """
    pixel_area = width * height
    chosen, smallest_gap = (1, 1), float('inf')
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            chosen, smallest_gap = candidate, gap
            continue
        is_tie = gap == smallest_gap
        if is_tie and pixel_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate
    return chosen
| 171 | |
| 172 | |
def dynamic_preprocess(image, min_num=2, max_num=6, image_size=768, use_thumbnail=False):
    """Split ``image`` into a grid of ``image_size`` x ``image_size`` tiles.

    The grid ``(cols, rows)`` is the candidate with ``min_num <= cols * rows <= max_num``
    whose aspect ratio best matches the image (see ``find_closest_aspect_ratio``).
    The image is resized to exactly ``cols * image_size`` by ``rows * image_size``,
    without preserving aspect, and cut into tiles in row-major order.

    Args:
        image: PIL-like image (uses ``.size``, ``.resize`` and ``.crop``).
        min_num, max_num: Bounds on the number of tiles.
        image_size: Tile side length in pixels.
        use_thumbnail: If True and more than one tile was produced, append a
            resized ``image_size`` x ``image_size`` copy of the whole image.

    Returns:
        ``(tiles, (cols, rows))``.
    """
    width, height = image.size

    candidate_grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    )
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    grid = find_closest_aspect_ratio(width / height, candidate_grids, width, height, image_size)

    canvas_w = image_size * grid[0]
    canvas_h = image_size * grid[1]
    tile_count = grid[0] * grid[1]
    canvas = image.resize((canvas_w, canvas_h))

    tiles_per_row = canvas_w // image_size
    tiles = []
    for tile_idx in range(tile_count):
        col = tile_idx % tiles_per_row
        row = tile_idx // tiles_per_row
        tiles.append(canvas.crop((
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles, grid
| 212 | |
| 213 | |
| 214 | |
def normalize_transform(mean, std):
    """Build a ``transforms.Normalize`` step, or ``None`` when neither stat is given.

    A missing ``mean`` defaults to zeros and a missing ``std`` to ones, sized to
    match the other argument.
    """
    if mean is None and std is None:
        return None
    if mean is None:
        mean = [0.] * len(std)
    if std is None:
        std = [1.] * len(mean)
    return transforms.Normalize(mean=mean, std=std)
| 228 | |
| 229 | |
| 230 | |
def format_messages(
    conversations: List[dict],
    sft_format: str = "deepseek",
    system_prompt: str = "",
) -> str:
    """Render a conversation into a single prompt string using an SFT template.

    Args:
        conversations: Messages, each a mapping with ``"role"`` and ``"content"`` keys.
            Other keys (e.g. ``"images"``) are ignored. ``content`` is stripped
            before it is appended.
        sft_format: Template name passed to ``get_conv_template``. Defaults to "deepseek".
        system_prompt: System message set on the template. Defaults to "".

    Returns:
        The rendered prompt with surrounding whitespace stripped.

    Note:
        The parameter is annotated with the builtin ``dict`` on purpose. ``Dict``
        in this module is ``addict.Dict``, not ``typing.Dict``.
    """
    conv = get_conv_template(sft_format)
    conv.set_system_message(system_prompt)
    for message in conversations:
        conv.append_message(message["role"], message["content"].strip())
    return conv.get_prompt().strip()
| 255 | |
| 256 | |
def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
    """Tokenize ``text`` without the tokenizer's own special tokens.

    Optionally prepends the hard-coded BOS id ``0`` and/or appends the hard-coded
    EOS id ``1``.

    Returns:
        The list of token ids. With neither flag set, this is exactly what
        ``tokenizer.encode`` returned.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    prefix = [0] if bos else []
    suffix = [1] if eos else []
    if not (prefix or suffix):
        return token_ids
    return prefix + token_ids + suffix
| 267 | |
def load_pil_images(conversations: List[dict]) -> List[Image.Image]:
    r"""Load every image referenced in ``conversations`` as an RGB PIL image.

    Args:
        conversations: Chat messages. Messages with an ``"images"`` key list image
            paths; messages without it are skipped. Example::

                [
                    {
                        "role": "<|User|>",
                        "content": "<image>\nExtract all information from this image and convert them into markdown format.",
                        "images": ["./examples/table_datasets.png"],
                    },
                    {"role": "<|Assistant|>", "content": ""},
                ]

    Returns:
        The images in message order (then list order), converted to RGB.

    Raises:
        ValueError: If ``load_image`` cannot open one of the paths.
    """
    pil_images = []

    for message in conversations:
        if "images" not in message:
            continue

        for image_path in message["images"]:
            pil_img = load_image(image_path)
            if pil_img is None:
                # load_image signals failure by returning None. Report the path
                # instead of failing later with "NoneType has no attribute convert".
                raise ValueError(f"could not load image: {image_path!r}")
            pil_images.append(pil_img.convert("RGB"))

    return pil_images
| 305 | |
| 306 | |
class BaseTransform(ABC):
    """Minimal transform interface with no-op defaults.

    ``set_rng`` and ``__call__`` accept any arguments and return ``None``.
    ``default_shape`` raises ``NotImplementedError`` unless a subclass overrides it.
    """

    def set_rng(self, *args, **kwargs):
        """Hook for seeding randomness; the base implementation does nothing."""
        return None

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """Apply the transform; the base implementation does nothing and returns None."""
        return None

    @property
    def default_shape(self):
        """Shape produced by the transform; must be provided by subclasses."""
        raise NotImplementedError
| 318 | |
| 319 | |
class BasicImageTransform(BaseTransform):
    """``transforms.ToTensor`` followed by an optional per-channel normalisation.

    Args:
        mean: Per-channel mean (zeros are used if only ``std`` is given).
        std: Per-channel std (ones are used if only ``mean`` is given).
        normalize: When False, an ``nn.Identity`` step replaces normalisation.
    """

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True
    ):
        self.mean = mean
        self.std = std

        steps = [transforms.ToTensor()]
        norm_step = normalize_transform(mean, std) if normalize else nn.Identity()
        # normalize_transform returns None when both mean and std are None.
        if norm_step is not None:
            steps.append(norm_step)

        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        """Run ``x`` through the composed pipeline."""
        return self.transform(x)
| 343 | |
class NoEOSTextStreamer(TextStreamer):
    """``TextStreamer`` that prints each finalized chunk with the EOS text turned into a newline."""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Decoded form of the EOS id, including special-token text.
        eos_marker = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
        cleaned = text.replace(eos_marker, "\n")
        print(cleaned, flush=True, end="")
| 350 | |
| 351 | |
class DeepseekOCR2Config(DeepseekV2Config):
    """``DeepseekV2Config`` subclass that only changes ``model_type``."""

    model_type = "DeepseekOCR2"
| 354 | |
class DeepseekOCR2Model(DeepseekV2Model):
    """DeepSeek-V2 language model with a vision front-end.

    Each image's views are encoded by ``sam_model`` -> ``qwen2_model`` ->
    ``projector``. The resulting feature rows are written into the token
    embeddings at the positions flagged by ``images_seq_mask`` before the regular
    ``DeepseekV2Model`` forward pass runs.
    """

    config_class = DeepseekOCR2Config

    def __init__(self, config: DeepseekV2Config):
        super(DeepseekOCR2Model, self).__init__(config)

        self.sam_model = build_sam_vit_b()
        self.qwen2_model = build_qwen2_decoder_as_encoder()
        n_embed = 1280
        # Linear projector from 896-wide encoder features to n_embed (1280) wide rows;
        # presumably n_embed equals the LM hidden size for this checkpoint — confirm.
        self.projector = MlpProjector(Dict(projector_type="linear", input_dim=896, n_embed=n_embed))
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        # Learned row appended after each image's features. The attribute name
        # (including its spelling) is the state_dict key, so renaming it would
        # break loading existing weights.
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)

    def _encode_view(self, pixels):
        """Encode a batch of views and flatten the result to 2-D ``(rows, feature_dim)``."""
        features = self.projector(self.qwen2_model(self.sam_model(pixels)))
        # assumes features are (views, tokens_per_view, feature_dim) — only the last dim is kept.
        return features.view(-1, features.shape[-1])

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the inputs, splice in image features, then run the DeepSeek-V2 decoder.

        Image features are computed only when all of these hold:
        - the vision tower exists;
        - ``images`` and ``images_seq_mask`` are given;
        - the step covers more than one position, or the model is training;
        - the first sample's global view is not all zeros.

        Args (image-specific):
            images: Per-sample ``(patches, global_view)`` pairs. An all-zero
                ``patches`` tensor means the sample has no local crops.
            images_seq_mask: Per-sample boolean mask of positions that receive image rows.
            images_spatial_crop: Iterated alongside ``images``; its values are not read here.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        sam_model = getattr(self, 'sam_model', None)

        should_encode_images = (
            sam_model is not None
            and images is not None
            and images_seq_mask is not None
            # A length-1 step outside training is treated as an incremental decode
            # step with no image positions. Use inputs_embeds so that callers
            # passing only inputs_embeds do not crash on input_ids.shape.
            and (inputs_embeds.shape[1] != 1 or self.training)
            and torch.sum(images[0][1]).item() != 0
        )

        if should_encode_images:
            for idx, (image, _crop_shape) in enumerate(zip(images, images_spatial_crop)):
                patches, image_ori = image[0], image[1]

                # Vision features are computed without gradient tracking.
                with torch.no_grad():
                    if torch.sum(patches).item() != 0:
                        local_features = self._encode_view(patches)
                        global_features = self._encode_view(image_ori)
                        # Row layout: local crops, global view, then the separator.
                        image_features = torch.cat(
                            [local_features, global_features, self.view_seperator[None, :]], dim=0
                        )
                    else:
                        global_features = self._encode_view(image_ori)
                        image_features = torch.cat(
                            [global_features, self.view_seperator[None, :]], dim=0
                        )

                # Put the mask on the embeddings' device instead of hard-coding CUDA,
                # so CPU and non-default-GPU placements work.
                seq_mask = images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device)
                inputs_embeds[idx].masked_scatter_(seq_mask, image_features)

        return super(DeepseekOCR2Model, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
| 506 | |
| 507 | |
| 508 | class DeepseekOCR2ForCausalLM(DeepseekV2ForCausalLM): |
| 509 | |
| 510 | config_class = DeepseekOCR2Config |
| 511 | # supports_gradient_checkpointing = True |
| 512 | |
def __init__(self, config):
    # Deliberately skip DeepseekV2ForCausalLM.__init__ and initialise from its base
    # class instead (presumably so the parent does not build its own backbone);
    # the OCR-capable backbone is attached right below.
    super(DeepseekV2ForCausalLM, self).__init__(config)
    self.model = DeepseekOCR2Model(config)

    self.vocab_size = config.vocab_size
    # Bias-free projection from hidden states to vocabulary logits.
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    # Initialize weights and apply final processing.
    self.post_init()
| 525 | |
def get_model(self):
    """Return the ``DeepseekOCR2Model`` backbone stored on ``self.model``."""
    backbone = self.model
    return backbone
| 528 | |
| 529 | |
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    images_seq_mask: Optional[torch.FloatTensor] = None,
    images_spatial_crop: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,

) -> Union[Tuple, CausalLMOutputWithPast]:
    """Run the backbone (with image splicing) and the LM head.

    Unset ``output_attentions``, ``output_hidden_states`` and ``return_dict`` fall
    back to the config. When ``labels`` are given, the loss is the mean
    cross-entropy of each position's float32 logits against the next label.
    Returns a tuple (loss first if computed) when ``return_dict`` is false,
    otherwise a ``CausalLMOutputWithPast``.
    """
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if return_dict is None:
        return_dict = cfg.use_return_dict

    outputs = self.model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        position_ids=position_ids,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        images=images,
        images_seq_mask=images_seq_mask,
        images_spatial_crop=images_spatial_crop,
        return_dict=return_dict,
    )

    logits = self.lm_head(outputs[0]).float()

    loss = None
    if labels is not None:
        # Next-token objective: logits at position t are scored against label t + 1.
        flat_logits = logits[..., :-1, :].contiguous().view(-1, cfg.vocab_size)
        flat_labels = labels[..., 1:].contiguous().view(-1)
        # Labels follow the logits' device (model parallelism).
        flat_labels = flat_labels.to(flat_logits.device)
        loss = CrossEntropyLoss()(flat_logits, flat_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return output if loss is None else (loss,) + output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
| 605 | |
| 606 | |
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble the keyword arguments for one ``forward`` call during generation.

    - Drops tokens from ``input_ids`` that are already covered by ``past_key_values``.
    - Crops ``attention_mask`` when a bounded cache would otherwise overflow.
    - Builds ``position_ids`` from ``attention_mask`` if ``kwargs`` has none.
    - Uses ``inputs_embeds`` instead of ``input_ids`` only when there is no cache.
    - Forwards ``images``, ``images_seq_mask`` and ``images_spatial_crop`` from ``kwargs``.
    """
    past_length = 0
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            # ``seen_tokens`` and ``get_max_length`` are missing (or yield None) on
            # some transformers releases; fall back instead of raising AttributeError.
            past_length = getattr(past_key_values, "seen_tokens", None)
            if past_length is None:
                past_length = cache_length
            get_max_length = getattr(past_key_values, "get_max_length", None)
            max_cache_length = get_max_length() if callable(get_max_length) else None
        else:
            # Legacy tuple cache: the cached length is dim 2 of layer 0's first tensor.
            cache_length = past_length = past_key_values[0][0].shape[2]
            max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - attention_mask longer than input_ids: some inputs exist only in the
        #     cache (e.g. when inputs_embeds were passed), keep the trailing part.
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
        # 2 - input_ids holds all tokens: discard the first past_length of them.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - otherwise assume input_ids only contains unprocessed tokens.

        # If the new tokens would exceed a bounded cache, crop the attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # Create position_ids on the fly for batch generation; masked positions get 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1]:]

    # inputs_embeds are only used on the first generation step (no cache yet).
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "images": kwargs.get("images", None),
            "images_seq_mask": kwargs.get("images_seq_mask", None),
            "images_spatial_crop": kwargs.get("images_spatial_crop", None),
        }
    )
    return model_inputs
| 681 | |
| 682 | |
def disable_torch_init(self):
    """
    Disable the redundant torch default initialization to accelerate model creation.

    Warning: this replaces ``reset_parameters`` on the ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` classes themselves. Every such layer created afterwards
    in this process skips default initialisation, and the originals are not restored.
    """
    import torch
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", lambda self: None)
| 690 | |
| 691 | |
| 692 | |
| 693 | def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False): |
| 694 | self.disable_torch_init() |
| 695 | |
| 696 | os.makedirs(output_path, exist_ok=True) |
| 697 | os.makedirs(f'{output_path}/images', exist_ok=True) |
| 698 | |
| 699 | if prompt and image_file: |
| 700 | conversation = [ |
| 701 | { |
| 702 | "role": "<|User|>", |
| 703 | # "content": "<image>\n<|grounding|>Given the layout of the image. ", |
| 704 | "content": f'{prompt}', |
| 705 | # "content": "君不见黄河之水天上来的下一句是什么?", |
| 706 | # "content": "<image>\nFree OCR. ", |
| 707 | # "content": "<image>\nParse the figure. ", |
| 708 | # "content": "<image>\nExtract the text in the image. ", |
| 709 | "images": [f'{image_file}'], |
| 710 | }, |
| 711 | {"role": "<|Assistant|>", "content": ""}, |
| 712 | ] |
| 713 | |
| 714 | elif prompt: |
| 715 | conversation = [ |
| 716 | { |
| 717 | "role": "<|User|>", |
| 718 | # "content": "<image>\n<|grounding|>Given the layout of the image. ", |
| 719 | "content": f'{prompt}', |
| 720 | # "content": "君不见黄河之水天上来的下一句是什么?", |
| 721 | # "content": "<image>\nFree OCR. ", |
| 722 | # "content": "<image>\nParse the figure. ", |
| 723 | # "content": "<image>\nExtract the text in the image. ", |
| 724 | # "images": [f'{image_file}'], |
| 725 | }, |
| 726 | {"role": "<|Assistant|>", "content": ""}, |
| 727 | ] |
| 728 | else: |
| 729 | assert False, f'prompt is none!' |
| 730 | |
| 731 | prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='') |
| 732 | |
| 733 | patch_size = 16 |
| 734 | downsample_ratio = 4 |
| 735 | images = load_pil_images(conversation) |
| 736 | |
| 737 | valid_img_tokens = 0 |
| 738 | ratio = 1 |
| 739 | |
| 740 | image_draw = images[0].copy() |
| 741 | |
| 742 | w,h = image_draw.size |
| 743 | # print(w, h) |
| 744 | ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h))) |
| 745 | |
| 746 | |
| 747 | image_transform=BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True) |
| 748 | images_seq_mask = [] |
| 749 | |
| 750 | image_token = '<image>' |
| 751 | image_token_id = 128815 |
| 752 | text_splits = prompt.split(image_token) |
| 753 | |
| 754 | images_list, images_crop_list, images_seq_mask = [], [], [] |
| 755 | tokenized_str = [] |
| 756 | images_spatial_crop = [] |
| 757 | for text_sep, image in zip(text_splits, images): |
| 758 | |
| 759 | tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False) |
| 760 | tokenized_str += tokenized_sep |
| 761 | images_seq_mask += [False] * len(tokenized_sep) |
| 762 | |
| 763 | if crop_mode: |
| 764 | |
| 765 | if image.size[0] <= 768 and image.size[1] <= 768: |
| 766 | crop_ratio = [1, 1] |
| 767 | |
| 768 | else: |
| 769 | if crop_mode: |
| 770 | # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions) |
| 771 | images_crop_raw, crop_ratio = dynamic_preprocess(image) |
| 772 | else: |
| 773 | # best_width, best_height = self.image_size, self.image_size |
| 774 | crop_ratio = [1, 1] |
| 775 | |
| 776 | """process the global view""" |
| 777 | # image = image.resize((base_size, base_size)) |
| 778 | global_view = ImageOps.pad(image, (base_size, base_size), |
| 779 | color=tuple(int(x * 255) for x in image_transform.mean)) |
| 780 | |
| 781 | if base_size == 1024: |
| 782 | valid_img_tokens += int(256 * ratio) |
| 783 | elif base_size == 1280: |
| 784 | valid_img_tokens += int(400 * ratio) |
| 785 | # elif base_size == 640: |
| 786 | # valid_img_tokens += int(100 * ratio) |
| 787 | |
| 788 | |
| 789 | |
| 790 | |
| 791 | |
| 792 | images_list.append(image_transform(global_view).to(torch.bfloat16)) |
| 793 | |
| 794 | # global_view_tensor = image_transform(global_view).to(torch.bfloat16) |
| 795 | |
| 796 | width_crop_num, height_crop_num = crop_ratio |
| 797 | |
| 798 | images_spatial_crop.append([width_crop_num, height_crop_num]) |
| 799 | |
| 800 | |
| 801 | if width_crop_num > 1 or height_crop_num > 1: |
| 802 | """process the local views""" |
| 803 | |
| 804 | for i in range(len(images_crop_raw)): |
| 805 | images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16)) |
| 806 | |
| 807 | if image_size == 768: |
| 808 | valid_img_tokens += len(images_crop_list) * 144 |
| 809 | |
| 810 | num_queries = math.ceil((image_size // patch_size) / downsample_ratio) |
| 811 | num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio) |
| 812 | |
| 813 | |
| 814 | |
| 815 | """add image tokens""" |
| 816 | |
| 817 | |
| 818 | |
| 819 | tokenized_image = ([image_token_id] * num_queries_base) * num_queries_base |
| 820 | tokenized_image += [image_token_id] |
| 821 | if width_crop_num > 1 or height_crop_num > 1: |
| 822 | tokenized_image += ([image_token_id] * (num_queries * width_crop_num)) * ( |
| 823 | num_queries * height_crop_num) |
| 824 | tokenized_str += tokenized_image |
| 825 | images_seq_mask += [True] * len(tokenized_image) |
| 826 | # num_image_tokens.append(len(tokenized_image)) |
| 827 | |
| 828 | else: |
| 829 | # best_width, best_height = self.image_size, self.image_size |
| 830 | # print(image.size, (best_width, best_height)) # check the select_best_resolutions func |
| 831 | |
| 832 | """process the global view""" |
| 833 | if image_size <= 768: |
| 834 | print('directly resize') |
| 835 | image = image.resize((image_size, image_size)) |
| 836 | # else: |
| 837 | global_view = ImageOps.pad(image, (image_size, image_size), |
| 838 | color=tuple(int(x * 255) for x in image_transform.mean)) |
| 839 | images_list.append(image_transform(global_view).to(torch.bfloat16)) |
| 840 | |
| 841 | if base_size == 1024: |
| 842 | valid_img_tokens += int(256 * ratio) |
| 843 | elif base_size == 1280: |
| 844 | valid_img_tokens += int(400 * ratio) |
| 845 | elif base_size == 640: |
| 846 | valid_img_tokens += int(100 * 1) |
| 847 | elif base_size == 512: |
| 848 | valid_img_tokens += int(64 * 1) |
| 849 | elif base_size == 768: |
| 850 | valid_img_tokens += int(144 * 1) |
| 851 | |
| 852 | width_crop_num, height_crop_num = 1, 1 |
| 853 | |
| 854 | images_spatial_crop.append([width_crop_num, height_crop_num]) |
| 855 | |
| 856 | |
| 857 | """add image tokens""" |
| 858 | num_queries = math.ceil((image_size // patch_size) / downsample_ratio) |
| 859 | |
| 860 | tokenized_image = ([image_token_id] * num_queries) * num_queries |
| 861 | tokenized_image += [image_token_id] |
| 862 | # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * ( |
| 863 | # num_queries * height_crop_num) |
| 864 | tokenized_str += tokenized_image |
| 865 | images_seq_mask += [True] * len(tokenized_image) |
| 866 | # num_image_tokens.append(len(tokenized_image)) |
| 867 | |
| 868 | |
| 869 | """process the last text split""" |
| 870 | tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False) |
| 871 | tokenized_str += tokenized_sep |
| 872 | images_seq_mask += [False] * len(tokenized_sep) |
| 873 | |
| 874 | """add the bos tokens""" |
| 875 | bos_id = 0 |
| 876 | tokenized_str = [bos_id] + tokenized_str |
| 877 | images_seq_mask = [False] + images_seq_mask |
| 878 | |
| 879 | |
| 880 | |
| 881 | input_ids = torch.LongTensor(tokenized_str) |
| 882 | |
| 883 | |
| 884 | |
| 885 | |
| 886 | images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) |
| 887 | |
| 888 | |
| 889 | if len(images_list) == 0: |
| 890 | images_ori = torch.zeros((1, 3, image_size, image_size)) |
| 891 | images_spatial_crop = torch.zeros((1, 2), dtype=torch.long) |
| 892 | images_crop = torch.zeros((1, 3, base_size, base_size)) |
| 893 | |
| 894 | else: |
| 895 | images_ori = torch.stack(images_list, dim=0) |
| 896 | images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) |
| 897 | if images_crop_list: |
| 898 | images_crop = torch.stack(images_crop_list, dim=0) |
| 899 | else: |
| 900 | images_crop = torch.zeros((1, 3, base_size, base_size)) |
| 901 | |
| 902 | |
| 903 | |
| 904 | if not eval_mode: |
| 905 | streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False) |
| 906 | with torch.autocast("cuda", dtype=torch.bfloat16): |
| 907 | with torch.no_grad(): |
| 908 | output_ids = self.generate( |
| 909 | input_ids.unsqueeze(0).cuda(), |
| 910 | images=[(images_crop.cuda(), images_ori.cuda())], |
| 911 | images_seq_mask = images_seq_mask.unsqueeze(0).cuda(), |
| 912 | images_spatial_crop = images_spatial_crop, |
| 913 | # do_sample=False, |
| 914 | # num_beams = 1, |
| 915 | temperature=0.0, |
| 916 | eos_token_id=tokenizer.eos_token_id, |
| 917 | streamer=streamer, |
| 918 | max_new_tokens=8192, |
| 919 | no_repeat_ngram_size = 20, |
| 920 | use_cache = True |
| 921 | ) |
| 922 | |
| 923 | else: |
| 924 | with torch.autocast("cuda", dtype=torch.bfloat16): |
| 925 | with torch.no_grad(): |
| 926 | output_ids = self.generate( |
| 927 | input_ids.unsqueeze(0).cuda(), |
| 928 | images=[(images_crop.cuda(), images_ori.cuda())], |
| 929 | images_seq_mask = images_seq_mask.unsqueeze(0).cuda(), |
| 930 | images_spatial_crop = images_spatial_crop, |
| 931 | # do_sample=False, |
| 932 | # num_beams = 1, |
| 933 | temperature=0.0, |
| 934 | eos_token_id=tokenizer.eos_token_id, |
| 935 | max_new_tokens=8192, |
| 936 | no_repeat_ngram_size = 35, |
| 937 | use_cache = True |
| 938 | ) |
| 939 | |
| 940 | |
| 941 | if '<image>' in conversation[0]['content'] and eval_mode: |
| 942 | outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) |
| 943 | stop_str = '<|end▁of▁sentence|>' |
| 944 | if outputs.endswith(stop_str): |
| 945 | outputs = outputs[:-len(stop_str)] |
| 946 | # re_match |
| 947 | outputs = outputs.strip() |
| 948 | |
| 949 | return outputs |
| 950 | |
| 951 | if '<image>' in conversation[0]['content'] and test_compress: |
| 952 | outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) |
| 953 | pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False)) |
| 954 | print('='*50) |
| 955 | print('image size: ', (w, h)) |
| 956 | print('valid image tokens: ', int(valid_img_tokens)) |
| 957 | print('output texts tokens (valid): ', pure_texts_outputs_token_length) |
| 958 | print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2)) |
| 959 | print('='*50) |
| 960 | |
| 961 | |
| 962 | if '<image>' in conversation[0]['content'] and save_results: |
| 963 | outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) |
| 964 | stop_str = '<|end▁of▁sentence|>' |
| 965 | |
| 966 | print('='*15 + 'save results:' + '='*15) |
| 967 | |
| 968 | # # # # conv.messages[-1][-1] = outputs |
| 969 | if outputs.endswith(stop_str): |
| 970 | outputs = outputs[:-len(stop_str)] |
| 971 | outputs = outputs.strip() |
| 972 | |
| 973 | matches_ref, matches_images, mathes_other = re_match(outputs) |
| 974 | # print(matches_ref) |
| 975 | result = process_image_with_refs(image_draw, matches_ref, output_path) |
| 976 | |
| 977 | |
| 978 | for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")): |
| 979 | outputs = outputs.replace(a_match_image, ' + '.jpg)\n') |
| 980 | |
| 981 | for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")): |
| 982 | outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:') |
| 983 | |
| 984 | |
| 985 | # if 'structural formula' in conversation[0]['content']: |
| 986 | # outputs = '<smiles>' + outputs + '</smiles>' |
| 987 | with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile: |
| 988 | afile.write(outputs) |
| 989 | |
| 990 | if 'line_type' in outputs: |
| 991 | import matplotlib.pyplot as plt |
| 992 | lines = eval(outputs)['Line']['line'] |
| 993 | |
| 994 | line_type = eval(outputs)['Line']['line_type'] |
| 995 | # print(lines) |
| 996 | |
| 997 | endpoints = eval(outputs)['Line']['line_endpoint'] |
| 998 | |
| 999 | fig, ax = plt.subplots(figsize=(3,3), dpi=200) |
| 1000 | ax.set_xlim(-15, 15) |
| 1001 | ax.set_ylim(-15, 15) |
| 1002 | |
| 1003 | for idx, line in enumerate(lines): |
| 1004 | try: |
| 1005 | p0 = eval(line.split(' -- ')[0]) |
| 1006 | p1 = eval(line.split(' -- ')[-1]) |
| 1007 | |
| 1008 | if line_type[idx] == '--': |
| 1009 | ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k') |
| 1010 | else: |
| 1011 | ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k') |
| 1012 | |
| 1013 | ax.scatter(p0[0], p0[1], s=5, color = 'k') |
| 1014 | ax.scatter(p1[0], p1[1], s=5, color = 'k') |
| 1015 | except: |
| 1016 | pass |
| 1017 | |
| 1018 | for endpoint in endpoints: |
| 1019 | |
| 1020 | label = endpoint.split(': ')[0] |
| 1021 | (x, y) = eval(endpoint.split(': ')[1]) |
| 1022 | ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points', |
| 1023 | fontsize=5, fontweight='light') |
| 1024 | |
| 1025 | |
| 1026 | plt.savefig(f'{output_path}/geo.jpg') |
| 1027 | plt.close() |
| 1028 | |
| 1029 | result.save(f"{output_path}/result_with_boxes.jpg") |
| 1030 | |