# deepencoderv2.py
1 import torch.nn as nn
2 import torch
3 import torch.nn.functional as F
4 import copy
5
6
7 from typing import Optional, Tuple
8
9 # from megatron.model import LayerNorm
10
11 import transformers
12
13
14 from typing import Optional, Tuple, Type
15 from functools import partial
16
17
18
class MlpProjector(nn.Module):
    """Configurable projection head selected by ``cfg.projector_type``.

    ``cfg`` must support both attribute access (``cfg.n_embed``) and
    ``cfg.get(key, default)``. The supported projector types are
    ``identity``, ``linear``, ``mlp_gelu``, ``normlayer_downsample_mlp_gelu``,
    ``downsample_mlp_gelu``, ``low_high_hybrid_split_mlp_gelu``,
    ``hybrid_split_feature_mlp_gelu`` and ``low_high_split_mlp_gelu``;
    any other value raises ``ValueError``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        kind = cfg.projector_type

        if kind == "identity":
            stack = nn.Identity()

        elif kind == "linear":
            stack = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif kind == "mlp_gelu":
            depth = cfg.get("depth", 1)
            stages = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            stages += self._gelu_linear_stages(cfg.n_embed, depth - 1)
            stack = nn.Sequential(*stages)

        elif kind in ("normlayer_downsample_mlp_gelu", "downsample_mlp_gelu"):
            depth = cfg.get("depth", 1)
            ratio = cfg.get("mlp_ratio", 1)
            # Each output token is the concatenation of a
            # downsample_ratio x downsample_ratio neighbourhood of inputs.
            folded_dim = cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio
            hidden_dim = cfg.n_embed * ratio

            stages = []
            if kind == "normlayer_downsample_mlp_gelu":
                stages.append(nn.LayerNorm(folded_dim))
            stages.append(nn.Linear(folded_dim, hidden_dim))
            stages += self._gelu_linear_stages(hidden_dim, depth - 2)
            # The closing GELU + Linear pair is always present, whatever the depth.
            stages.append(nn.GELU())
            stages.append(nn.Linear(hidden_dim, cfg.n_embed))
            stack = nn.Sequential(*stages)

        elif kind == "low_high_hybrid_split_mlp_gelu":
            depth = cfg.get("depth", 1)
            half_width = cfg.n_embed // 2
            self.high_up_proj = nn.Linear(cfg.input_dim, half_width)
            self.low_up_proj = nn.Linear(cfg.input_dim, half_width)
            stack = nn.Sequential(*self._gelu_linear_stages(cfg.n_embed, depth - 1))

        elif kind == "hybrid_split_feature_mlp_gelu":
            depth = cfg.get("depth", 1)
            channel_div = cfg.get("channel_div", 0.5)
            high_width = int(cfg.n_embed * channel_div)
            self.high_up_proj = nn.Linear(cfg.input_dim[0], high_width)
            self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - high_width)
            stack = nn.Sequential(*self._gelu_linear_stages(cfg.n_embed, depth - 1))

        elif kind == "low_high_split_mlp_gelu":
            depth = cfg.get("depth", 1)
            stack = nn.Sequential(*self._gelu_linear_stages(cfg.n_embed // 2, depth - 1))
            # high_layers re-wraps the very same module objects as `stack`;
            # low_layers is an independent deep copy.
            self.high_layers = nn.Sequential(*stack)
            self.low_layers = copy.deepcopy(stack)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        if cfg.get("token_pooling", False):
            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)

        if cfg.get("conv_fusion_high_low_features", False):
            self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)

        self.layers = stack

    @staticmethod
    def _gelu_linear_stages(width, repeats):
        """Return ``repeats`` consecutive (GELU, Linear(width, width)) pairs as a flat list."""
        stages = []
        for _ in range(repeats):
            stages.append(nn.GELU())
            stages.append(nn.Linear(width, width))
        return stages

    def forward(self, x):
        """Project ``x`` according to the configured projector type.

        For the ``low_high_*`` types ``x`` is indexed as ``x[0]`` (high) and
        ``x[1]`` (low); for the other types it is used as a single tensor.
        """
        cfg = self.cfg
        kind = cfg.projector_type

        if cfg.get("token_pooling", False):
            n, tokens, c = x.shape
            side = int(tokens**0.5)
            grid = x.view(n, side, side, c).permute(0, 3, 1, 2)
            # Cut the grid into non-overlapping 2x2 tiles.
            tiles = grid.unfold(2, 2, 2).unfold(3, 2, 2)
            n, c, tile_rows, tile_cols, _, _ = tiles.size()
            # Concatenate along the channel dimension.
            tiles = tiles.contiguous().view(n, c, tile_rows * tile_cols, -1)
            # Feed through the linear layer.
            tiles = tiles.permute(0, 2, 1, 3).contiguous()
            tiles = tiles.view(n, tile_rows * tile_cols, c * 4)
            x = self.token_pooling_layer(tiles)

        if cfg.get("conv_fusion_high_low_features", False):
            x = self.fusion_layer(x[:, 0]) + x[:, 1]

        if kind == 'low_high_hybrid_split_mlp_gelu':
            high = self.high_up_proj(x[0])
            low = self.low_up_proj(x[1])
            x = torch.concat([high, low], dim=-1)

        elif kind == 'hybrid_split_feature_mlp_gelu':
            split_at = cfg.input_dim[0]
            high = self.high_up_proj(x[..., :split_at])
            low = self.low_up_proj(x[..., split_at:])
            x = torch.concat([high, low], dim=-1)

        elif kind == 'low_high_split_mlp_gelu':
            # This type bypasses self.layers entirely.
            high = self.high_layers(x[0])
            low = self.low_layers(x[1])
            return torch.concat([high, low], dim=-1)

        elif kind in ('downsample_mlp_gelu', 'normlayer_downsample_mlp_gelu'):
            ratio = cfg.downsample_ratio
            bs, hw, input_dim = x.shape
            h = w = int(hw**0.5)

            # Compute padding so the grid side becomes a multiple of the ratio.
            pad = (-h) % ratio
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            # Concatenate each ratio x ratio neighbourhood into one token.
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(x, kernel_size=ratio, stride=ratio, padding=0)
            x = x.permute(0, 2, 1)

        return self.layers(x)

    @staticmethod
    def get_flops_per_sample(cfg):
        """Return three times the forward multiply-add estimate for ``cfg``.

        Projector types that are neither ``linear`` nor contain ``mlp_gelu``
        in their name yield 0.
        """
        kind = cfg.projector_type
        if kind == "linear":
            forward_cost = 2 * cfg.input_dim * cfg.n_embed
        elif "mlp_gelu" in kind:
            depth = cfg.get("depth", 1)
            ratio = cfg.get("downsample_ratio", 1)
            width = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
            width = width * ratio * ratio
            forward_cost = 2 * width * cfg.n_embed + (depth - 1) * 2 * cfg.n_embed * cfg.n_embed
        else:
            forward_cost = 0

        return forward_cost * 3
185
186
187 #===================qwen2================================
188
class CustomQwen2Decoder(nn.Module):
    """Qwen2 transformer stack driven by embeddings with a mixed attention mask.

    ``token_type_ids`` selects the attention pattern per position:
    0 = non-causal (attends to every type-0 position),
    1 = causal (attends to every type-0 position and to type-1 positions at
    or before itself).

    The token embedding table of the wrapped ``Qwen2Model`` is deleted, so the
    module can only be called with ``inputs_embeds``.

    Raises:
        ValueError: if ``attn_implementation`` is ``"flash_attention_2"``,
            which cannot consume the custom additive mask.
    """

    def __init__(
        self,
        decoder_layer: int = 24,
        max_position_embeddings: int = 131072,
        hidden_dimension: int = 896,
        num_attention_heads: int = 14,
        num_key_value_heads: int = 2,
        intermediate_size: int = 4864,
        vocab_size: int = 151936,
        attn_implementation: str = "sdpa",
        rms_norm_eps: float = 1e-06,
        rope_theta: float = 1000000.0,
        attention_dropout: float = 0.0,
        hidden_act: str = "silu",
        initializer_range: float = 0.02,
    ):
        super().__init__()

        # The custom 4-D mask needs an attention backend that accepts an additive mask.
        if attn_implementation == "flash_attention_2":
            raise ValueError(
                "CustomQwen2Decoder do not support flash_attention_2,"
                "new attention mask needs 'sdpa' or 'eager'"
            )

        Qwen2Model = getattr(transformers.models.qwen2.modeling_qwen2, 'Qwen2Model')
        Qwen2Config = getattr(transformers, 'Qwen2Config')

        config = Qwen2Config(
            hidden_size=hidden_dimension,
            num_hidden_layers=decoder_layer,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            intermediate_size=intermediate_size,
            max_position_embeddings=max_position_embeddings,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            attention_dropout=attention_dropout,
            hidden_act=hidden_act,
            initializer_range=initializer_range,
            _attn_implementation=attn_implementation,
        )

        self.model = self._create_custom_model(Qwen2Model, config)

        # Inputs always arrive as embeddings, so the token table is dead weight.
        del self.model.embed_tokens

    @staticmethod
    def _build_mixed_attention_mask(token_type_ids, dtype, device, attention_mask=None):
        """Build the additive attention mask for a batch of token types.

        Args:
            token_type_ids: integer tensor [batch, seq]; 0 = non-causal, 1 = causal.
                Positions with any other value neither attend nor are attended to.
            dtype: floating dtype of the returned mask.
            device: device of the returned mask.
            attention_mask: optional [batch, seq] padding mask (non-zero = keep).
                Masks that are not 2-D are ignored.

        Returns:
            Tensor [batch, 1, seq, seq] holding 0 where attention is allowed and
            ``torch.finfo(dtype).min`` elsewhere.

        Raises:
            ValueError: if ``token_type_ids`` is None.
        """
        if token_type_ids is None:
            raise ValueError("token_type_ids is required to build the attention mask")

        type_ids = token_type_ids.to(device)
        seq_len = type_ids.shape[1]
        is_bidirectional = type_ids == 0
        is_causal = type_ids == 1

        # Every typed query may look at every type-0 key.
        allowed = (is_bidirectional | is_causal).unsqueeze(2) & is_bidirectional.unsqueeze(1)

        # A type-1 query may also look at type-1 keys at or before its own position.
        not_future = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        allowed = allowed | (is_causal.unsqueeze(2) & is_causal.unsqueeze(1) & not_future)

        if attention_mask is not None and attention_mask.dim() == 2:
            # Combine as booleans: adding two finfo.min penalties would overflow to -inf.
            allowed = allowed & attention_mask.to(device).bool().unsqueeze(1)

        min_value = torch.finfo(dtype).min
        mask = torch.full(allowed.shape, min_value, dtype=dtype, device=device)
        mask = mask.masked_fill(allowed, 0.0)
        return mask.unsqueeze(1)

    def _create_custom_model(self, Qwen2Model, config):
        """Instantiate a ``Qwen2Model`` subclass that uses the type-driven mask."""
        build_mask = self._build_mixed_attention_mask

        class CustomQwen2ModelInner(Qwen2Model):

            def forward(
                self,
                input_ids=None,
                attention_mask=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                token_type_ids=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                cache_position=None,
            ):
                # Stash the types so the mask hook below can read them; the
                # parent forward() does not accept this argument.
                self._current_token_type_ids = token_type_ids

                return super().forward(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    inputs_embeds=inputs_embeds,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    cache_position=cache_position,
                )

            # NOTE(review): this only takes effect if the installed transformers
            # Qwen2Model.forward calls `_update_causal_mask` — confirm for the
            # pinned transformers version.
            def _update_causal_mask(
                self,
                attention_mask,
                input_tensor,
                cache_position,
                past_key_values,
                output_attentions,
            ):
                batch_size, sequence_length = input_tensor.shape[0], input_tensor.shape[1]
                return self._create_custom_4d_mask(
                    sequence_length=sequence_length,
                    dtype=input_tensor.dtype,
                    device=input_tensor.device,
                    batch_size=batch_size,
                    token_type_ids=self._current_token_type_ids,
                    attention_mask=attention_mask,
                )

            def _create_custom_4d_mask(
                self,
                sequence_length,
                dtype,
                device,
                batch_size,
                token_type_ids,
                attention_mask=None,
            ):
                if token_type_ids is not None and tuple(token_type_ids.shape) != (
                    batch_size,
                    sequence_length,
                ):
                    raise ValueError(
                        f"token_type_ids has shape {tuple(token_type_ids.shape)}, "
                        f"expected {(batch_size, sequence_length)}"
                    )
                return build_mask(token_type_ids, dtype, device, attention_mask)

        return CustomQwen2ModelInner(config)

    def forward(
        self,
        inputs_embeds,
        token_type_ids,
        attention_mask=None,
        **kwargs
    ):
        """
        Args:
            inputs_embeds: [batch_size, seq_len, hidden_dim]
            token_type_ids: [batch_size, seq_len], 0=non-causal, 1=causal
            attention_mask: [batch_size, seq_len], optional

        Returns:
            Whatever the wrapped Qwen2 model returns for these arguments.
        """
        return self.model(
            inputs_embeds=inputs_embeds,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            **kwargs
        )
376
377
378
379
380
# Example usage (needs a CUDA device and `import time`):
#
# decoder_sdpa = CustomQwen2Decoder(attn_implementation="sdpa").cuda()
# batch_size = 2
# inputs_embeds = torch.randn(batch_size, 512, 896).cuda()
# token_type_ids = torch.cat([
#     torch.zeros(batch_size, 256, dtype=torch.long),
#     torch.ones(batch_size, 256, dtype=torch.long),
# ], dim=1).cuda()
#
# start = time.time()
# with torch.no_grad():
#     outputs_sdpa = decoder_sdpa(inputs_embeds, token_type_ids)
# print(outputs_sdpa[0].shape)
# print(f"SDPA time: {time.time() - start:.4f}s")
395
396
397
class Qwen2Decoder2Encoder(nn.Module):
    """Use a Qwen2 decoder stack as an encoder via learned query tokens.

    The flattened input features are followed by an equally long sequence of
    learned query embeddings. The features are marked non-causal (type 0) and
    the queries causal (type 1); only the hidden states at the query
    positions are returned.

    Two query tables exist, one for 144 input tokens and one for 256.
    ``max_query`` is accepted for interface compatibility but is not used.
    """

    def __init__(
        self,
        decoder_layer: int,
        hidden_dimension: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int,
        max_query: int,
    ):
        super().__init__()

        self.model = CustomQwen2Decoder(
            decoder_layer=decoder_layer,
            hidden_dimension=hidden_dimension,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            intermediate_size=intermediate_size,
            attn_implementation="sdpa",
        )

        # Presumably named after the 768px / 1024px input resolutions that
        # yield 144 / 256 tokens — TODO confirm against the caller.
        self.query_768 = nn.Embedding(144, hidden_dimension)
        self.query_1024 = nn.Embedding(256, hidden_dimension)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a feature map into query-token hidden states.

        Args:
            x: tensor whose dims from index 2 onward are flattened into the
                token axis; assumed to be (batch, hidden, H, W) — TODO confirm.

        Returns:
            Tensor (batch, n_tokens, hidden) taken from the query positions.

        Raises:
            ValueError: if the flattened token count is neither 144 nor 256.
        """
        x = x.flatten(2).transpose(1, 2)
        bs, n_query, _ = x.shape

        if n_query == 144:
            query_table = self.query_768.weight
        elif n_query == 256:
            query_table = self.query_1024.weight
        else:
            raise ValueError(
                f"Unsupported number of input tokens: {n_query} (expected 144 or 256)"
            )

        # (batch_size, num_queries, hidden_size), broadcast without copying.
        batch_queries = query_table.unsqueeze(0).expand(bs, -1, -1)
        x_combined = torch.cat([x, batch_queries], dim=1)

        # Created on the input's device so mask construction needs no cross-device indexing.
        token_type_ids = torch.cat(
            [
                torch.zeros(bs, n_query, dtype=torch.long, device=x.device),
                torch.ones(bs, n_query, dtype=torch.long, device=x.device),
            ],
            dim=1,
        )

        y = self.model(x_combined, token_type_ids)[0]

        # Keep only the causal query positions.
        return y[:, n_query:, :]
466
467
def build_qwen2_decoder_as_encoder(
    decoder_layer=24,
    hidden_dimension=896,
    num_attention_heads=14,
    num_key_value_heads=2,
    intermediate_size=4864,
    max_query=400,
    checkpoint=None,
):
    """Build a ``Qwen2Decoder2Encoder`` and optionally load its weights.

    Args:
        decoder_layer: number of transformer layers.
        hidden_dimension: hidden size of the transformer.
        num_attention_heads: number of attention heads.
        num_key_value_heads: number of key/value heads.
        intermediate_size: feed-forward width.
        max_query: forwarded to ``Qwen2Decoder2Encoder``.
        checkpoint: optional path to a state dict; loaded with ``strict=True``.

    Returns:
        The constructed ``Qwen2Decoder2Encoder``.
    """
    decoder_as_encoder = Qwen2Decoder2Encoder(
        decoder_layer=decoder_layer,
        hidden_dimension=hidden_dimension,
        num_attention_heads=num_attention_heads,
        num_key_value_heads=num_key_value_heads,
        intermediate_size=intermediate_size,
        max_query=max_query,
    )

    if checkpoint is not None:
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # a trusted source.
        # map_location="cpu" lets a checkpoint saved on GPU load on any host;
        # load_state_dict copies the values onto the model's own device.
        state_dict = torch.load(checkpoint, map_location="cpu")
        decoder_as_encoder.load_state_dict(state_dict, strict=True)
        print(checkpoint)

    return decoder_as_encoder
498
499
500
501
502 #=========================Sam-Vary=================================
503
504
def get_abs_pos_sam(abs_pos, tgt_size):
    """Resize an absolute position embedding grid to ``tgt_size`` x ``tgt_size``.

    ``abs_pos`` is indexed as (1, H, W, C); only ``abs_pos.size(1)`` is
    compared with ``tgt_size``. When they already match, the input is returned
    unchanged. Otherwise the grid is resampled in float32 with antialiased
    bicubic interpolation and cast back to the input dtype.
    """
    if abs_pos.size(1) == tgt_size:
        return abs_pos

    source_dtype = abs_pos.dtype
    # F.interpolate expects channels first; interpolate in float32 for accuracy.
    channels_first = abs_pos.permute(0, 3, 1, 2).to(torch.float32)
    resized = F.interpolate(
        channels_first,
        size=(tgt_size, tgt_size),
        mode='bicubic',
        antialias=True,
        align_corners=False,
    )
    return resized.to(source_dtype).permute(0, 2, 3, 1)
525
526
527
528
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Args:
            embedding_dim (int): input and output width.
            mlp_dim (int): hidden width.
            act (nn.Module): activation class, instantiated without arguments.
        """
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.lin1(x)
        hidden = self.act(hidden)
        return self.lin2(hidden)
543
544
545 # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
546 # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    """Layer normalisation over dim 1 of a (B, C, H, W) tensor.

    Each spatial location is normalised across its channels and then scaled
    and shifted by per-channel ``weight`` and ``bias``.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        centered = x - x.mean(1, keepdim=True)
        # Biased (population) variance across the channel dimension.
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift
560
561
562 # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
    """ViT image encoder followed by a convolutional neck and two stride-2 convolutions.

    The forward pass runs patch embedding, optional absolute position
    embedding (resized to the current grid), the transformer blocks, the neck,
    and finally ``net_2`` and ``net_3``, each of which halves the spatial size.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Number of channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size
        grid_size = img_size // patch_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, grid_size, grid_size, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                # Blocks listed in global_attn_indexes use global attention (window 0).
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(grid_size, grid_size),
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

        # net_2 consumes the neck output, so its input width must follow
        # out_chans (identical to the previous hard-coded 256 by default).
        self.net_2 = nn.Conv2d(out_chans, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.net_3 = nn.Conv2d(512, 896, kernel_size=3, stride=2, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode an image batch and return the ``net_3`` feature map (896 channels)."""
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            # Resample the position grid when the input grid differs from it.
            x = x + get_abs_pos_sam(self.pos_embed, x.size(1))

        for blk in self.blocks:
            x = blk(x)

        # Blocks work on (B, H, W, C); the convolutions need (B, C, H, W).
        x = self.neck(x.permute(0, 3, 1, 2))
        x2 = self.net_2(x)
        x3 = self.net_3(x2)

        return x3
669
670
class Block(nn.Module):
    """Pre-norm transformer block with optional windowed attention."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()

        # Windowed blocks size their relative-position tables to the window,
        # global blocks to the full input grid.
        if window_size == 0:
            attn_grid = input_size
        else:
            attn_grid = (window_size, window_size)

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_grid,
        )

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.norm1(x)

        if self.window_size > 0:
            # Attend inside non-overlapping windows, then stitch the grid back.
            height, width = x.shape[1], x.shape[2]
            x, padded_hw = window_partition(x, self.window_size)
            x = self.attn(x)
            x = window_unpartition(x, self.window_size, padded_hw, (height, width))
        else:
            x = self.attn(x)

        x = residual + x
        return x + self.mlp(self.norm2(x))
735
736
class Attention(nn.Module):
    """Multi-head self-attention over a (B, H, W, C) grid with optional relative position bias."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        self.scale = per_head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # One table per axis covering every relative offset on that axis.
            grid_h, grid_w = input_size[0], input_size[1]
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * grid_h - 1, per_head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * grid_w - 1, per_head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        heads = self.num_heads
        tokens = H * W

        # (3, B, heads, tokens, head_dim)
        packed = self.qkv(x).reshape(B, tokens, 3, heads, -1).permute(2, 0, 3, 1, 4)
        # Each of q, k, v: (B * heads, tokens, head_dim)
        q, k, v = packed.reshape(3, B * heads, tokens, -1).unbind(0)

        attn_bias = None
        if self.use_rel_pos:
            rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
            rel_h = rel_h.view(B, heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
            rel_w = rel_w.view(B, heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
            # Broadcast-sum the two axis terms, then flatten the key grid.
            attn_bias = (rel_h + rel_w).view(
                B, heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
            )

        q = q.view(B, heads, tokens, -1)
        k = k.view(B, heads, tokens, -1)
        v = v.view(B, heads, tokens, -1)

        attended = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)

        # (B, heads, tokens, head_dim) -> (B, H, W, heads * head_dim)
        merged = attended.view(B, heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        return self.proj(merged)
805
806
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Split a token grid into non-overlapping square windows, zero-padding the
    bottom and right edges when the grid is not a multiple of the window size.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    batch, height, width, channels = x.shape

    extra_rows = (-height) % window_size
    extra_cols = (-width) % window_size
    if extra_rows > 0 or extra_cols > 0:
        # F.pad lists dims last-to-first: channels, then width, then height.
        x = F.pad(x, (0, 0, 0, extra_cols, 0, extra_rows))
    padded_h = height + extra_rows
    padded_w = width + extra_cols

    tiled = x.view(
        batch,
        padded_h // window_size,
        window_size,
        padded_w // window_size,
        window_size,
        channels,
    )
    # Bring the two window-index axes together before flattening them into the batch.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = tiled.view(-1, window_size, window_size, channels)
    return windows, (padded_h, padded_w)
829
830
def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Reassemble windows into the original token grid and crop away any padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    height, width = hw

    windows_per_image = padded_h * padded_w // window_size // window_size
    batch = windows.shape[0] // windows_per_image

    grid = windows.view(
        batch,
        padded_h // window_size,
        padded_w // window_size,
        window_size,
        window_size,
        -1,
    )
    # Interleave window rows/cols back into full image rows/cols.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, padded_h, padded_w, -1)

    needs_crop = padded_h > height or padded_w > width
    if needs_crop:
        grid = grid[:, :height, :width, :].contiguous()
    return grid
854
855
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Look up relative positional embeddings for every (query, key) index pair
    along one axis.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    span = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == span:
        table = rel_pos
    else:
        # The table was built for another size: linearly resample it in float32.
        source_dtype = rel_pos.dtype
        rel_pos = rel_pos.to(torch.float32)
        as_signal = rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1)
        stretched = F.interpolate(as_signal, size=span, mode="linear").to(source_dtype)
        table = stretched.reshape(-1, span).permute(1, 0)

    # Scale the coords with short length if shapes for q and k are different.
    q_stride = max(k_size / q_size, 1.0)
    k_stride = max(q_size / k_size, 1.0)
    device = rel_pos.device
    q_coords = torch.arange(q_size, device=device)[:, None] * q_stride
    k_coords = torch.arange(k_size, device=device)[None, :] * k_stride
    # Shift so that the most negative offset maps to index 0.
    offsets = (q_coords - k_coords) + (k_size - 1) * k_stride

    return table[offsets.long()]
889
890
def add_decomposed_rel_pos(
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950

    The two axis terms are returned separately (not added to an attention
    map); the caller broadcast-sums them into a bias.

    Args:
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        rel_h (Tensor): height-axis term with shape (B, q_h * q_w, k_h, 1).
        rel_w (Tensor): width-axis term with shape (B, q_h * q_w, 1, k_w).
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)

    # Project each query onto the embedding of every offset along one axis.
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    # Trailing singleton dims let the caller broadcast-add the two terms
    # over the (k_h, k_w) key grid.
    rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
    rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)

    return rel_h, rel_w
926
927
class PatchEmbed(nn.Module):
    """
    Convolutional image-to-patch embedding that returns channels-last tokens.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Convolve, then move channels last: B C H W -> B H W C.
        return self.proj(x).permute(0, 2, 3, 1)
960
961
def build_sam_vit_b(checkpoint=None):
    """Build the ViT-B sized image encoder (768-dim, 12 blocks, 12 heads).

    Blocks 2, 5, 8 and 11 are passed as the global-attention indexes.
    ``checkpoint`` is forwarded unchanged to ``_build_sam``.
    """
    vit_b_settings = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
    )
    return _build_sam(checkpoint=checkpoint, **vit_b_settings)
970
def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
    """Build the ViT-B encoder in eval mode, cast it to ``dtype`` and wrap it with ``torch.compile``."""
    encoder = build_sam_vit_b(checkpoint)
    encoder = encoder.eval().to(dtype)
    return torch.compile(encoder, mode=compile_mode)
976
977
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """Construct an ``ImageEncoderViT`` in eval mode and optionally load weights.

    The encoder is built for 1024px inputs with 16px patches, window size 14,
    relative position embeddings enabled and 256 neck channels.

    Args:
        encoder_embed_dim: transformer embedding width.
        encoder_depth: number of transformer blocks.
        encoder_num_heads: attention heads per block.
        encoder_global_attn_indexes: indexes of blocks that use global attention.
        checkpoint: optional path to a state dict. Only entries whose key
            contains ``'vision_tower_high'`` are loaded, with ``strict=True``.

    Returns:
        The constructed ``ImageEncoderViT``.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=image_size,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=prompt_embed_dim,
    )
    image_encoder.eval()

    if checkpoint is not None:
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # a trusted source.
        # map_location="cpu" lets a checkpoint saved on GPU load on any host;
        # load_state_dict copies the values onto the model's own device.
        state_dict = torch.load(checkpoint, map_location="cpu")
        # NOTE(review): k[30:] drops a fixed 30-character key prefix;
        # presumably something like "model.model.vision_tower_high." —
        # confirm against the checkpoint's key layout.
        image_encoder.load_state_dict(
            {k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k},
            strict=True,
        )
        print(checkpoint)

    return image_encoder