modeling_deepseekocr2.py
38.3 KB · 1030 lines · python Raw
1 from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM
2 from .configuration_deepseek_v2 import DeepseekV2Config
3 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
4 from typing import List, Optional, Tuple, Union
5 from transformers.cache_utils import Cache
6 import requests
7 from PIL import Image, ImageOps, ImageDraw, ImageFont
8 from io import BytesIO
9 import torch
10 import torch.nn as nn
11 from torch.nn import CrossEntropyLoss
12 from torchvision import transforms
13 # from torchvision.transforms.functional import InterpolationMode
14 import os
15 from .deepencoderv2 import build_sam_vit_b, build_qwen2_decoder_as_encoder, MlpProjector
16 from addict import Dict
17 from transformers import TextStreamer
18 from .conversation import get_conv_template
19 from abc import ABC
20 import math
21 import re
22 from tqdm import tqdm
23 import numpy as np
24 # import time
25
26
27
def load_image(image_path):
    """Open an image file and apply its EXIF orientation.

    If applying the EXIF transpose fails, the error is printed and the image is
    re-opened without rotation. Returns None when the file cannot be opened at all.

    Args:
        image_path: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The (possibly re-oriented) PIL image, or None on failure.
    """
    try:
        image = Image.open(image_path)
        return ImageOps.exif_transpose(image)
    except Exception as e:
        print(f"error: {e}")

    # Fallback: raw image without EXIF handling. Only ordinary errors are caught,
    # so KeyboardInterrupt / SystemExit still propagate.
    try:
        return Image.open(image_path)
    except Exception:
        return None
43
44
def re_match(text):
    """Locate ``<|ref|>label<|/ref|><|det|>coords<|/det|>`` annotations in ``text``.

    Returns:
        A 3-tuple ``(found, image_refs, other_refs)``:
        ``found`` is the raw ``re.findall`` result, a list of
        ``(full_match, label, coords_text)`` tuples;
        ``image_refs`` holds the full matched strings whose label is ``image``;
        ``other_refs`` holds all remaining full matched strings.
        Both lists keep document order.
    """
    found = re.findall(r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)', text, re.DOTALL)

    full_spans = [groups[0] for groups in found]
    image_refs = [span for span in full_spans if '<|ref|>image<|/ref|>' in span]
    other_refs = [span for span in full_spans if '<|ref|>image<|/ref|>' not in span]
    return found, image_refs, other_refs
60
61
def extract_coordinates_and_label(ref_text, image_width, image_height):
    """Turn one ``re_match`` tuple into ``(label, boxes)``.

    Args:
        ref_text: a ``(full_match, label, coords_text)`` tuple as produced by
            ``re.findall`` in ``re_match``.
        image_width, image_height: accepted for API compatibility; not used here.

    Returns:
        ``(label_type, cor_list)`` where ``cor_list`` is ``coords_text`` parsed as a
        Python literal, or None if the tuple is malformed or the text is not a
        literal (the error is printed).
    """
    import ast

    try:
        label_type = ref_text[1]
        # The coordinates come from model-generated text, so they are parsed as a
        # literal only. Using eval() here would execute arbitrary expressions.
        cor_list = ast.literal_eval(ref_text[2])
    except Exception as e:
        print(e)
        return None

    return (label_type, cor_list)
72
73
def draw_bounding_boxes(image, refs, ouput_path):
    """Draw grounding boxes on a copy of ``image`` and save crops labelled ``image``.

    Args:
        image: PIL image the boxes refer to.
        refs: ``(full_match, label, coords_text)`` tuples from ``re_match``. Box
            coordinates are on a 0..999 grid and are scaled to the image size.
        ouput_path: directory whose ``images/`` subfolder receives the crops, named
            ``0.jpg``, ``1.jpg``, ... (the misspelled name is kept for API
            compatibility).

    Returns:
        A copy of ``image`` with box outlines, label text, and a translucent fill
        overlay composited on top.

    Drawing is best-effort. A ref whose coordinates cannot be parsed or unpacked is
    skipped, and a failed crop save or text draw does not abort the remaining boxes.
    """
    image_width, image_height = image.size

    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)

    # Translucent fills go on a separate RGBA layer that is pasted with itself as
    # the mask at the end.
    overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)

    font = ImageFont.load_default()

    img_idx = 0  # running index of saved "image" crops

    for ref in refs:
        try:
            result = extract_coordinates_and_label(ref, image_width, image_height)
            if not result:
                continue
            label_type, points_list = result

            color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
            color_a = color + (20, )
            outline_width = 4 if label_type == 'title' else 2

            for points in points_list:
                x1, y1, x2, y2 = points

                # Map 0..999 grid coordinates back to pixels.
                x1 = int(x1 / 999 * image_width)
                y1 = int(y1 / 999 * image_height)
                x2 = int(x2 / 999 * image_width)
                y2 = int(y2 / 999 * image_height)

                if label_type == 'image':
                    try:
                        cropped = image.crop((x1, y1, x2, y2))
                        cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
                    except Exception as e:
                        print(e)
                    img_idx += 1

                try:
                    draw.rectangle([x1, y1, x2, y2], outline=color, width=outline_width)
                    draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)

                    # Label sits just above the box, clamped to the top edge.
                    text_x = x1
                    text_y = max(0, y1 - 15)

                    text_bbox = draw.textbbox((0, 0), label_type, font=font)
                    text_width = text_bbox[2] - text_bbox[0]
                    text_height = text_bbox[3] - text_bbox[1]
                    draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
                                   fill=(255, 255, 255, 30))

                    draw.text((text_x, text_y), label_type, font=font, fill=color)
                except Exception:
                    pass  # best-effort: a failed label draw must not drop later boxes
        except Exception:
            continue  # malformed ref (e.g. a box without 4 coordinates): skip it

    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw
144
145
def process_image_with_refs(image, ref_texts, output_path):
    """Return ``image`` annotated with the grounding boxes in ``ref_texts``.

    A thin wrapper around ``draw_bounding_boxes``. Crops labelled ``image`` are
    written under ``output_path``/images.
    """
    return draw_bounding_boxes(image, ref_texts, output_path)
151
152
153
154
155
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Choose the ``(cols, rows)`` grid whose ``cols / rows`` best matches ``aspect_ratio``.

    Candidates are scanned in order. A strictly smaller difference always wins.
    On an exact tie, the later candidate replaces the current choice only if the
    source area ``width * height`` exceeds half the area that grid covers with
    ``image_size``-sized tiles. Returns ``(1, 1)`` when ``target_ratios`` is empty.
    """
    pixel_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
        elif gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate

    return chosen
171
172
def dynamic_preprocess(image, min_num=2, max_num=6, image_size=768, use_thumbnail=False):
    """Split ``image`` into a grid of ``image_size`` x ``image_size`` tiles.

    The grid ``(cols, rows)`` is selected by ``find_closest_aspect_ratio`` among all
    layouts with ``min_num <= cols * rows <= max_num``. The image is resized to
    ``(image_size * cols, image_size * rows)`` and cut row by row, left to right.

    Returns:
        ``(tiles, grid)``: the list of tile images and the chosen ``(cols, rows)``
        tuple. With ``use_thumbnail`` and more than one tile, a square
        ``image_size`` thumbnail of the whole image is appended to ``tiles``.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Every grid shape whose tile count lies within [min_num, max_num], smallest first.
    candidate_grids = {
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    }
    candidate_grids = sorted(candidate_grids, key=lambda grid_shape: grid_shape[0] * grid_shape[1])

    grid = find_closest_aspect_ratio(aspect_ratio, candidate_grids, orig_width, orig_height, image_size)
    cols, rows = grid[0], grid[1]

    resized_img = image.resize((image_size * cols, image_size * rows))

    tiles = []
    for tile_idx in range(cols * rows):
        row, col = divmod(tile_idx, cols)
        left, top = col * image_size, row * image_size
        tiles.append(resized_img.crop((left, top, left + image_size, top + image_size)))

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles, grid
212
213
214
def normalize_transform(mean, std):
    """Build a ``transforms.Normalize`` step, or None when neither statistic is given.

    A missing ``mean`` defaults to zeros and a missing ``std`` to ones, each sized
    to match the other argument.
    """
    if mean is None and std is None:
        return None
    if mean is None:
        mean = [0.] * len(std)
    elif std is None:
        std = [1.] * len(mean)
    return transforms.Normalize(mean=mean, std=std)
228
229
230
def format_messages(
    conversations: List[Dict[str, str]],
    sft_format: str = "deepseek",
    system_prompt: str = "",
):
    """Render ``conversations`` into a single prompt string using an SFT template.

    Args:
        conversations (List[Dict]): messages, each with ``role`` and ``content`` keys.
            Each ``content`` is stripped before being appended.
        sft_format (str, optional): name of the conversation template to look up.
            Defaults to "deepseek".
        system_prompt (str, optional): system message set on the template. Defaults to "".

    Returns:
        str: the template's rendered prompt, with surrounding whitespace stripped.
    """
    template = get_conv_template(sft_format)
    template.set_system_message(system_prompt)
    for turn in conversations:
        template.append_message(turn["role"], turn["content"].strip())
    return template.get_prompt().strip()
255
256
def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
    """Tokenize ``text`` without the tokenizer's own special tokens.

    When ``bos`` is set, id 0 is prepended. When ``eos`` is set, id 1 is appended.
    With neither flag, the tokenizer's result is returned as-is.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if not (bos or eos):
        return token_ids
    return ([0] if bos else []) + token_ids + ([1] if eos else [])
267
def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
    """Load every image referenced by ``conversations`` as an RGB PIL image.

    Args:
        conversations (List[Dict[str, str]]): the conversation messages. Messages
            that carry an ``"images"`` key contribute each listed path, in order.
            Example::

                [
                    {
                        "role": "User",
                        "content": "<image_placeholder>\\nExtract all information from this image and convert them into markdown format.",
                        "images": ["./examples/table_datasets.png"]
                    },
                    {"role": "Assistant", "content": ""},
                ]

    Returns:
        pil_images (List[PIL.Image.Image]): the loaded images, converted to RGB.

    Raises:
        ValueError: if an image path cannot be opened (``load_image`` returned None).
            Previously this surfaced as an opaque ``AttributeError`` on ``None``.
    """
    pil_images = []

    for message in conversations:
        if "images" not in message:
            continue

        for image_path in message["images"]:
            pil_img = load_image(image_path)
            if pil_img is None:
                raise ValueError(f"could not load image: {image_path!r}")
            pil_images.append(pil_img.convert("RGB"))

    return pil_images
305
306
class BaseTransform(ABC):
    """Minimal transform interface: an RNG hook, a call operator, and a default shape."""

    def set_rng(self, *args, **kwargs):
        """Hook for seeding randomness; this base implementation ignores its arguments."""
        return None

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """Apply the transform; the base implementation returns None."""
        return None

    @property
    def default_shape(self):
        """Output shape of the transform; subclasses must override."""
        raise NotImplementedError
318
319
class BasicImageTransform(BaseTransform):
    """Convert an image to a tensor with ``ToTensor``, then optionally normalize it.

    With ``normalize=True`` the normalize step comes from
    ``normalize_transform(mean, std)`` and is omitted when that returns None. With
    ``normalize=False`` an ``nn.Identity`` step is appended instead.
    """

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True
    ):
        self.mean = mean
        self.std = std

        steps = [transforms.ToTensor()]

        if normalize:
            norm_step = normalize_transform(mean, std)
        else:
            norm_step = nn.Identity()
        if norm_step is not None:
            steps.append(norm_step)

        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        return self.transform(x)
343
class NoEOSTextStreamer(TextStreamer):
    """TextStreamer that prints each finalized chunk with the EOS token's text turned into a newline."""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # stream_end is part of the TextStreamer hook signature; it is not used here,
        # so no trailing newline is printed at the end of the stream.
        eos_marker = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
        print(text.replace(eos_marker, "\n"), flush=True, end="")
350
351
class DeepseekOCR2Config(DeepseekV2Config):
    """DeepseekV2 configuration registered under the ``DeepseekOCR2`` model type."""

    model_type = "DeepseekOCR2"
354
class DeepseekOCR2Model(DeepseekV2Model):
    """DeepseekV2 language model with a vision tower whose features replace image-token embeddings.

    Each view is encoded as ``sam_model`` -> ``qwen2_model`` -> ``projector``. The
    resulting features, followed by the learned ``view_seperator`` vector, are
    scattered into the token embeddings at the positions flagged by
    ``images_seq_mask`` before the DeepseekV2 decoder runs.
    """

    config_class = DeepseekOCR2Config

    def __init__(self, config: DeepseekV2Config):
        super(DeepseekOCR2Model, self).__init__(config)

        self.sam_model = build_sam_vit_b()
        self.qwen2_model = build_qwen2_decoder_as_encoder()
        n_embed = 1280
        # Linear projector from input_dim=896 to the n_embed=1280 feature size.
        self.projector = MlpProjector(Dict(projector_type="linear", input_dim=896, n_embed=n_embed))
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        # Learned vector appended after each image's features. The (misspelled)
        # attribute name is a state_dict key; presumably renaming it would break
        # loading existing checkpoints, so it is kept.
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)

    def _encode_view(self, sam_model, qwen2_model, pixel_values):
        """Encode a batch of views: SAM backbone, then the Qwen2 encoder, then the projector."""
        return self.projector(qwen2_model(sam_model(pixel_values)))

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed tokens, splice in vision features, then run the DeepseekV2 decoder.

        Args:
            images: one ``(patches, image_ori)`` pair per batch row. ``patches`` holds
                the local tiles and is treated as absent when it sums to zero;
                ``image_ori`` is the global view. May be None for text-only input.
            images_seq_mask: per batch row, a boolean mask of the sequence positions
                that receive image features.
            images_spatial_crop: per-image tile grid; only iterated alongside
                ``images`` here.
            All other arguments are forwarded to ``DeepseekV2Model.forward``.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        sam_model = getattr(self, 'sam_model', None)
        qwen2_model = getattr(self, 'qwen2_model', None)

        # Vision features are injected only when the sequence is longer than one
        # token (or in training) and the first global view is non-zero. The sum
        # check comes last so single-token decode steps skip the device sync.
        if (
            sam_model is not None
            and (inputs_embeds.shape[1] != 1 or self.training)
            and images is not None
            and len(images) > 0
            and torch.sum(images[0][1]).item() != 0
        ):
            for idx, (image, _crop_shape) in enumerate(zip(images, images_spatial_crop)):
                patches = image[0]
                image_ori = image[1]

                # The vision tower runs without gradients.
                with torch.no_grad():
                    if torch.sum(patches).item() != 0:
                        local_features = self._encode_view(sam_model, qwen2_model, patches)
                        global_features = self._encode_view(sam_model, qwen2_model, image_ori)

                        print('=====================')
                        print('BASE: ', global_features.shape)
                        print('PATCHES: ', local_features.shape)
                        print('=====================')

                        _, _, n_dim = global_features.shape
                        _, _, n_dim2 = local_features.shape

                        # Order in the sequence: local tiles, global view, separator.
                        global_local_features = torch.cat(
                            [
                                local_features.view(-1, n_dim2),
                                global_features.view(-1, n_dim),
                                self.view_seperator[None, :],
                            ],
                            dim=0,
                        )
                    else:
                        global_features = self._encode_view(sam_model, qwen2_model, image_ori)

                        print('=====================')
                        print('BASE: ', global_features.shape)
                        print('NO PATCHES')
                        print('=====================')

                        _, _, n_dim = global_features.shape
                        global_local_features = torch.cat(
                            [global_features.view(-1, n_dim), self.view_seperator[None, :]], dim=0
                        )

                # Move the mask to wherever the embeddings live instead of assuming CUDA.
                image_mask = images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device)
                inputs_embeds[idx].masked_scatter_(image_mask, global_local_features)

        return super(DeepseekOCR2Model, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
506
507
class DeepseekOCR2ForCausalLM(DeepseekV2ForCausalLM):
    """Causal-LM head over :class:`DeepseekOCR2Model`, plus an ``infer`` helper for single-image OCR."""

    config_class = DeepseekOCR2Config

    def __init__(self, config):
        # Calls the grandparent __init__, bypassing DeepseekV2ForCausalLM.__init__,
        # presumably so that a plain DeepseekV2Model is not built before being
        # replaced by the OCR model below.
        super(DeepseekV2ForCausalLM, self).__init__(config)
        self.model = DeepseekOCR2Model(config)

        self.vocab_size = config.vocab_size

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the OCR model and LM head.

        When ``labels`` are given, also compute the next-token cross-entropy loss.
        Logits are returned as float32.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states).float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1)
            # Labels may sit on a different device under model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = CrossEntropyLoss()(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Trim tokens already in the KV cache and forward the image kwargs on every generate step."""
        # Omit tokens covered by past_key_values
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
                "images_seq_mask": kwargs.get("images_seq_mask", None),
                "images_spatial_crop": kwargs.get("images_spatial_crop", None),
            }
        )
        return model_inputs

    def disable_torch_init(self):
        """Globally make ``reset_parameters`` a no-op for ``torch.nn.Linear`` and ``torch.nn.LayerNorm``.

        WARNING: this monkey-patches the torch classes for the whole process. Every
        Linear/LayerNorm constructed afterwards keeps its uninitialised
        ``torch.empty`` weights. It is only useful right before building a model
        whose weights will be loaded from a checkpoint.
        """
        import torch
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    @staticmethod
    def _save_geometry_plot(outputs, output_path):
        """Render a geometry answer to ``{output_path}/geo.jpg``.

        ``outputs`` is expected to be a literal like ``{'Line': {'line': [...],
        'line_type': [...], 'line_endpoint': [...]}}``. It is model-generated text,
        so it is parsed with ``ast.literal_eval`` rather than ``eval``. If it cannot
        be parsed, the error is printed and no plot is written.
        """
        import ast
        import matplotlib.pyplot as plt

        try:
            line_info = ast.literal_eval(outputs)['Line']
            lines = line_info['line']
            line_type = line_info['line_type']  # noqa: F841 - required key; all segments drawn alike
            endpoints = line_info['line_endpoint']
        except (ValueError, SyntaxError, TypeError, KeyError, MemoryError, RecursionError) as e:
            print(f"error: {e}")
            return

        fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
        ax.set_xlim(-15, 15)
        ax.set_ylim(-15, 15)

        for line in lines:
            try:
                # Each line looks like "(x0, y0) -- (x1, y1)".
                p0 = ast.literal_eval(line.split(' -- ')[0])
                p1 = ast.literal_eval(line.split(' -- ')[-1])
                ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
                ax.scatter(p0[0], p0[1], s=5, color='k')
                ax.scatter(p1[0], p1[1], s=5, color='k')
            except Exception:
                pass  # skip malformed segments

        for endpoint in endpoints:
            try:
                # Each endpoint looks like "A: (x, y)".
                label = endpoint.split(': ')[0]
                (x, y) = ast.literal_eval(endpoint.split(': ')[1])
            except Exception:
                continue  # skip malformed endpoints
            ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                        fontsize=5, fontweight='light')

        plt.savefig(f'{output_path}/geo.jpg')
        plt.close()

    def infer(self, tokenizer, prompt='', image_file='', output_path='', base_size=1024, image_size=640,
              crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
        """Generate a response for one prompt, optionally with one image (CUDA only).

        Args:
            tokenizer: tokenizer providing ``encode``, ``decode`` and ``eos_token_id``.
            prompt: user prompt. ``<image>`` marks where the image tokens are inserted.
            image_file: optional path of the image to attach.
            output_path: directory for saved artefacts. It and its ``images/``
                subfolder are created when non-empty. When ``save_results`` is set
                and this is empty, the current directory is used.
            base_size: side of the padded global view in crop mode.
            image_size: side of the single view when ``crop_mode`` is False; also used
                to count local-tile tokens in crop mode.
            crop_mode: tile images larger than 768px into local crops plus a global view.
            test_compress: print vision-token vs. output-token statistics.
            save_results: write ``result.mmd``, figure crops, ``result_with_boxes.jpg``
                and, for geometry answers, ``geo.jpg``.
            eval_mode: return the decoded text instead of streaming it.

        Returns:
            The decoded, stripped output when ``eval_mode`` is set and the prompt
            contains ``<image>``; otherwise None.

        Raises:
            ValueError: if ``prompt`` is empty.
        """
        # disable_torch_init() is deliberately not called: the model already exists
        # here, so it would only leave later-created Linear/LayerNorm layers uninitialised.

        if save_results and not output_path:
            output_path = '.'
        if output_path:
            os.makedirs(output_path, exist_ok=True)
            os.makedirs(f'{output_path}/images', exist_ok=True)

        if not prompt:
            raise ValueError('prompt is none!')

        user_turn = {"role": "<|User|>", "content": f'{prompt}'}
        if image_file:
            user_turn["images"] = [f'{image_file}']
        conversation = [user_turn, {"role": "<|Assistant|>", "content": ""}]

        prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')

        patch_size = 16
        downsample_ratio = 4
        images = load_pil_images(conversation)

        valid_img_tokens = 0
        ratio = 1
        image_draw = None
        w, h = 0, 0
        if images:
            image_draw = images[0].copy()
            w, h = image_draw.size
            # min(w, h) / max(w, h): 1.0 for a square image, smaller for elongated ones.
            ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))

        image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
        pad_color = tuple(int(x * 255) for x in image_transform.mean)

        image_token = '<image>'
        image_token_id = 128815
        text_splits = prompt.split(image_token)

        images_list, images_crop_list, images_seq_mask = [], [], []
        tokenized_str = []
        images_spatial_crop = []
        for text_sep, image in zip(text_splits, images):

            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            if crop_mode:
                images_crop_raw = []
                if image.size[0] <= 768 and image.size[1] <= 768:
                    crop_ratio = [1, 1]
                else:
                    # NOTE(review): dynamic_preprocess always tiles at its default 768px,
                    # while the local-token count below uses image_size; the two only
                    # agree when image_size == 768.
                    images_crop_raw, crop_ratio = dynamic_preprocess(image)

                # Global view: letterbox-pad the whole image to base_size x base_size.
                global_view = ImageOps.pad(image, (base_size, base_size), color=pad_color)

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)

                images_list.append(image_transform(global_view).to(torch.bfloat16))

                width_crop_num, height_crop_num = crop_ratio
                images_spatial_crop.append([width_crop_num, height_crop_num])

                if width_crop_num > 1 or height_crop_num > 1:
                    # Local views: one tensor per tile.
                    for crop in images_crop_raw:
                        images_crop_list.append(image_transform(crop).to(torch.bfloat16))

                if image_size == 768:
                    valid_img_tokens += len(images_crop_list) * 144

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)

                # Placeholder ids: the global grid, one separator slot, then all local-tile slots.
                tokenized_image = [image_token_id] * (num_queries_base * num_queries_base)
                tokenized_image += [image_token_id]
                if width_crop_num > 1 or height_crop_num > 1:
                    tokenized_image += [image_token_id] * (
                        (num_queries * width_crop_num) * (num_queries * height_crop_num))
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

            else:
                # Single view at image_size, no tiling.
                if image_size <= 768:
                    print('directly resize')
                    image = image.resize((image_size, image_size))
                global_view = ImageOps.pad(image, (image_size, image_size), color=pad_color)
                images_list.append(image_transform(global_view).to(torch.bfloat16))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)
                elif base_size == 640:
                    valid_img_tokens += 100
                elif base_size == 512:
                    valid_img_tokens += 64
                elif base_size == 768:
                    valid_img_tokens += 144

                width_crop_num, height_crop_num = 1, 1
                images_spatial_crop.append([width_crop_num, height_crop_num])

                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                tokenized_image = [image_token_id] * (num_queries * num_queries)
                tokenized_image += [image_token_id]
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

        # Text after the last <image> (the whole prompt when there is no image).
        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        # Prepend the BOS token.
        bos_id = 0
        tokenized_str = [bos_id] + tokenized_str
        images_seq_mask = [False] + images_seq_mask

        input_ids = torch.LongTensor(tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        if len(images_list) == 0:
            # All-zero placeholders: the model skips the vision tower when these sum to zero.
            images_ori = torch.zeros((1, 3, image_size, image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
            images_crop = torch.zeros((1, 3, base_size, base_size))
        else:
            images_ori = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0)
            else:
                images_crop = torch.zeros((1, 3, base_size, base_size))

        generate_kwargs = dict(
            images=[(images_crop.cuda(), images_ori.cuda())],
            images_seq_mask=images_seq_mask.unsqueeze(0).cuda(),
            images_spatial_crop=images_spatial_crop,
            temperature=0.0,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=8192,
            use_cache=True,
        )
        if eval_mode:
            generate_kwargs['no_repeat_ngram_size'] = 35
        else:
            generate_kwargs['streamer'] = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
            generate_kwargs['no_repeat_ngram_size'] = 20

        with torch.autocast("cuda", dtype=torch.bfloat16):
            with torch.no_grad():
                output_ids = self.generate(input_ids.unsqueeze(0).cuda(), **generate_kwargs)

        prompt_len = input_ids.shape[0]
        has_image_tag = '<image>' in conversation[0]['content']
        stop_str = '<|end▁of▁sentence|>'

        if has_image_tag and eval_mode:
            outputs = tokenizer.decode(output_ids[0, prompt_len:])
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            return outputs.strip()

        if has_image_tag and test_compress:
            outputs = tokenizer.decode(output_ids[0, prompt_len:])
            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
            compression = (round(pure_texts_outputs_token_length / valid_img_tokens, 2)
                           if valid_img_tokens else float('nan'))
            print('=' * 50)
            print('image size: ', (w, h))
            print('valid image tokens: ', int(valid_img_tokens))
            print('output texts tokens (valid): ', pure_texts_outputs_token_length)
            print('compression ratio: ', compression)
            print('=' * 50)

        if has_image_tag and save_results:
            outputs = tokenizer.decode(output_ids[0, prompt_len:])

            print('=' * 15 + 'save results:' + '=' * 15)

            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            matches_ref, matches_images, mathes_other = re_match(outputs)
            result = None
            if image_draw is not None:
                result = process_image_with_refs(image_draw, matches_ref, output_path)

            # Image refs become markdown links to the crops saved by draw_bounding_boxes.
            for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
                outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')

            for a_match_other in tqdm(mathes_other, desc="other"):
                outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')

            with open(f'{output_path}/result.mmd', 'w', encoding='utf-8') as afile:
                afile.write(outputs)

            if 'line_type' in outputs:
                self._save_geometry_plot(outputs, output_path)

            if result is not None:
                result.save(f"{output_path}/result_with_boxes.jpg")
1030