deepencoderv2.py
| 1 | import torch.nn as nn |
| 2 | import torch |
| 3 | import torch.nn.functional as F |
| 4 | import copy |
| 5 | |
| 6 | |
| 7 | from typing import Optional, Tuple |
| 8 | |
| 9 | # from megatron.model import LayerNorm |
| 10 | |
| 11 | import transformers |
| 12 | |
| 13 | |
| 14 | from typing import Optional, Tuple, Type |
| 15 | from functools import partial |
| 16 | |
| 17 | |
| 18 | |
class MlpProjector(nn.Module):
    """Project vision features into the language-model embedding width.

    The architecture is selected by ``cfg.projector_type``. ``cfg`` must support
    attribute access (``cfg.projector_type``) as well as ``cfg.get(key, default)``.

    Supported ``projector_type`` values: ``identity``, ``linear``, ``mlp_gelu``,
    ``normlayer_downsample_mlp_gelu``, ``downsample_mlp_gelu``,
    ``low_high_hybrid_split_mlp_gelu``, ``hybrid_split_feature_mlp_gelu`` and
    ``low_high_split_mlp_gelu``; any other value raises ``ValueError``.

    Submodule names and ``nn.Sequential`` index layouts below determine the
    state_dict keys, so they are intentionally kept exactly as they are.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            # `depth` Linear layers separated by GELU activations.
            mlp_depth = cfg.get("depth", 1)
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
            # Input width is input_dim * downsample_ratio**2 because forward()
            # concatenates each downsample_ratio x downsample_ratio neighbourhood.
            mlp_depth = cfg.get("depth", 1)
            mlp_ratio = cfg.get("mlp_ratio", 1)
            modules = [
                nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
                nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
            ]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            # The output Linear is always appended, so even depth=1 yields two Linear layers.
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "downsample_mlp_gelu":
            # Same as above without the leading LayerNorm.
            mlp_depth = cfg.get("depth", 1)
            mlp_ratio = cfg.get("mlp_ratio", 1)
            modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
            # Two inputs are each projected to n_embed // 2 and concatenated in forward().
            mlp_depth = cfg.get("depth", 1)
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
            # input_dim is indexable as (high_dim, low_dim); the channel axis is split
            # at high_dim in forward() and the halves get channel_div / (1 - channel_div)
            # of n_embed.
            mlp_depth = cfg.get("depth", 1)
            channel_div = cfg.get("channel_div", 0.5)
            self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
            self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_split_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
            modules = nn.Sequential(*modules)
            # high_layers reuses the same submodule objects that are also stored in
            # self.layers below; low_layers is an independent deep copy.
            self.high_layers = nn.Sequential(*modules)
            self.low_layers = copy.deepcopy(modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        if cfg.get("token_pooling", False):
            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)

        if cfg.get("conv_fusion_high_low_features", False):
            self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
        self.layers = modules

    def forward(self, x):
        """Apply the configured projection to ``x`` and return the projected features.

        The expected input layout depends on the configuration; see the comment on
        each branch below.
        """
        if self.cfg.get("token_pooling", False):
            # 2x2 token pooling: x is (B, N, C) with N a square grid of side s;
            # output is (B, (s // 2) ** 2, C).
            batch_size, wxh, channels = x.shape
            w = h = int(wxh**0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            # Concatenate each 2x2 patch along the channel dimension.
            patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)

            # Reorder to (B, patches, C * 4), then project back to C with a linear layer.
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            x = self.token_pooling_layer(patches)

        if self.cfg.get("conv_fusion_high_low_features", False):
            # x[:, 0] and x[:, 1] are presumably the high/low feature streams — confirm with caller.
            x = self.fusion_layer(x[:, 0]) + x[:, 1]

        if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
            # x[0] / x[1] hold the high / low inputs (indexed along the first axis).
            high_x, low_x = x[0], x[1]
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)

        if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
            # Split the last (channel) axis at input_dim[0].
            high_x = x[..., :self.cfg.input_dim[0]]
            low_x = x[..., self.cfg.input_dim[0]:]
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)

        if self.cfg.projector_type == 'low_high_split_mlp_gelu':
            # Each half goes through its own stack; self.layers is not applied again.
            high_x, low_x = x[0], x[1]
            high_x = self.high_layers(high_x)
            low_x = self.low_layers(low_x)
            x = torch.concat([high_x, low_x], dim=-1)
            return x

        if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu':
            # x is (B, H*W, C) over a square grid.
            bs, hw, input_dim = x.shape
            h = w = int((hw) ** 0.5)

            # Pad H and W on the bottom/right up to a multiple of downsample_ratio.
            if h % self.cfg.downsample_ratio:
                pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            # Concatenate each r x r neighbourhood into one token:
            # (B, C, H, W) -> (B, C * r * r, HW / r^2) -> (B, HW / r^2, C * r * r).
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0)
            x = x.permute(0, 2, 1)

        return self.layers(x)

    @staticmethod
    def get_flops_per_sample(cfg):
        """Return a rough FLOP estimate for the projector's Linear layers.

        The forward estimate is 2 * in * out per Linear for one input vector,
        multiplied by 3 (presumably forward + backward). ``mlp_ratio``, LayerNorm
        and the exact Linear count of the downsample variants are not modelled.
        ``cfg.input_dim`` may be an int or a list/tuple of per-stream widths.
        Types without Linear layers (e.g. ``identity``) return 0.
        """
        if cfg.projector_type == "linear":
            fwd = 2 * cfg.input_dim * cfg.n_embed

        elif "mlp_gelu" in cfg.projector_type:
            mlp_depth = cfg.get("depth", 1)
            downsample_ratio = cfg.get("downsample_ratio", 1)
            # Accept tuples as well as lists: __init__ indexes input_dim[0]/[1]
            # for the hybrid type, so both sequence kinds are valid configs.
            input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, (list, tuple)) else cfg.input_dim
            input_dim = input_dim * downsample_ratio * downsample_ratio
            fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
        else:
            fwd = 0

        return fwd * 3
| 185 | |
| 186 | |
| 187 | #===================qwen2================================ |
| 188 | |
| 189 | class CustomQwen2Decoder(nn.Module): |
| 190 | """ |
| 191 | Qwen2 visual encoder |
| 192 | non-causal attention + causal attention |
| 193 | token_type_ids :0=non-causal, 1=causal |
| 194 | """ |
| 195 | |
| 196 | def __init__( |
| 197 | self, |
| 198 | decoder_layer: int = 24, |
| 199 | max_position_embeddings: int = 131072, |
| 200 | hidden_dimension: int = 896, |
| 201 | num_attention_heads: int = 14, |
| 202 | num_key_value_heads: int = 2, |
| 203 | intermediate_size: int = 4864, |
| 204 | vocab_size: int = 151936, |
| 205 | attn_implementation: str = "sdpa", # ⭐ |
| 206 | rms_norm_eps: float = 1e-06, |
| 207 | rope_theta: float = 1000000.0, |
| 208 | attention_dropout: float = 0.0, |
| 209 | hidden_act: str = "silu", |
| 210 | initializer_range: float = 0.02, |
| 211 | ): |
| 212 | super().__init__() |
| 213 | |
| 214 | # attn_implementation check |
| 215 | if attn_implementation == "flash_attention_2": |
| 216 | raise ValueError( |
| 217 | "CustomQwen2Decoder do not support flash_attention_2," |
| 218 | "new attention mask needs 'sdpa' or 'eager'" |
| 219 | ) |
| 220 | |
| 221 | # load |
| 222 | Qwen2Model = getattr(transformers.models.qwen2.modeling_qwen2, 'Qwen2Model') |
| 223 | Qwen2Config = getattr(transformers, 'Qwen2Config') |
| 224 | |
| 225 | # config |
| 226 | config = Qwen2Config( |
| 227 | hidden_size=hidden_dimension, |
| 228 | num_hidden_layers=decoder_layer, |
| 229 | num_attention_heads=num_attention_heads, |
| 230 | num_key_value_heads=num_key_value_heads, |
| 231 | intermediate_size=intermediate_size, |
| 232 | max_position_embeddings=max_position_embeddings, |
| 233 | vocab_size=vocab_size, |
| 234 | rms_norm_eps=rms_norm_eps, |
| 235 | rope_theta=rope_theta, |
| 236 | attention_dropout=attention_dropout, |
| 237 | hidden_act=hidden_act, |
| 238 | initializer_range=initializer_range, |
| 239 | _attn_implementation=attn_implementation, # ⭐ |
| 240 | ) |
| 241 | |
| 242 | # |
| 243 | self.model = self._create_custom_model(Qwen2Model, config) |
| 244 | |
| 245 | del self.model.embed_tokens |
| 246 | |
| 247 | def _create_custom_model(self, Qwen2Model, config): |
| 248 | """ Qwen2Model """ |
| 249 | |
| 250 | class CustomQwen2ModelInner(Qwen2Model): |
| 251 | |
| 252 | |
| 253 | def forward( |
| 254 | self, |
| 255 | input_ids=None, |
| 256 | attention_mask=None, |
| 257 | position_ids=None, |
| 258 | past_key_values=None, |
| 259 | inputs_embeds=None, |
| 260 | token_type_ids=None, # ⭐ |
| 261 | use_cache=None, |
| 262 | output_attentions=None, |
| 263 | output_hidden_states=None, |
| 264 | return_dict=None, |
| 265 | cache_position=None, |
| 266 | ): |
| 267 | # token_type_ids |
| 268 | self._current_token_type_ids = token_type_ids |
| 269 | |
| 270 | outputs = super().forward( |
| 271 | input_ids=input_ids, |
| 272 | attention_mask=attention_mask, |
| 273 | position_ids=position_ids, |
| 274 | past_key_values=past_key_values, |
| 275 | inputs_embeds=inputs_embeds, |
| 276 | use_cache=use_cache, |
| 277 | output_attentions=output_attentions, |
| 278 | output_hidden_states=output_hidden_states, |
| 279 | return_dict=return_dict, |
| 280 | cache_position=cache_position, |
| 281 | ) |
| 282 | |
| 283 | return outputs |
| 284 | |
| 285 | def _update_causal_mask( |
| 286 | self, |
| 287 | attention_mask, |
| 288 | input_tensor, |
| 289 | cache_position, |
| 290 | past_key_values, |
| 291 | output_attentions, |
| 292 | ): |
| 293 | dtype, device = input_tensor.dtype, input_tensor.device |
| 294 | min_dtype = torch.finfo(dtype).min |
| 295 | batch_size, sequence_length = input_tensor.shape[0], input_tensor.shape[1] |
| 296 | |
| 297 | token_type_ids = self._current_token_type_ids |
| 298 | |
| 299 | # attention mask |
| 300 | causal_mask = self._create_custom_4d_mask( |
| 301 | sequence_length=sequence_length, |
| 302 | dtype=dtype, |
| 303 | device=device, |
| 304 | batch_size=batch_size, |
| 305 | token_type_ids=token_type_ids, |
| 306 | ) |
| 307 | |
| 308 | # padding mask |
| 309 | if attention_mask is not None and attention_mask.dim() == 2: |
| 310 | padding_mask = attention_mask[:, None, None, :].to(dtype=dtype) |
| 311 | padding_mask = (1.0 - padding_mask) * min_dtype |
| 312 | causal_mask = causal_mask + padding_mask |
| 313 | |
| 314 | return causal_mask |
| 315 | |
| 316 | def _create_custom_4d_mask( |
| 317 | self, |
| 318 | sequence_length, |
| 319 | dtype, |
| 320 | device, |
| 321 | batch_size, |
| 322 | token_type_ids, |
| 323 | ): |
| 324 | min_dtype = torch.finfo(dtype).min |
| 325 | |
| 326 | masks = [] |
| 327 | for b in range(batch_size): |
| 328 | mask = torch.full( |
| 329 | (sequence_length, sequence_length), |
| 330 | fill_value=min_dtype, |
| 331 | dtype=dtype, |
| 332 | device=device |
| 333 | ) |
| 334 | |
| 335 | type_ids = token_type_ids[b] |
| 336 | |
| 337 | image_positions = (type_ids == 0).nonzero(as_tuple=True)[0] |
| 338 | text_positions = (type_ids == 1).nonzero(as_tuple=True)[0] |
| 339 | |
| 340 | # non-casual |
| 341 | if len(image_positions) > 0: |
| 342 | mask[image_positions[:, None], image_positions] = 0.0 |
| 343 | |
| 344 | # causal |
| 345 | for i, text_pos in enumerate(text_positions): |
| 346 | if len(image_positions) > 0: |
| 347 | mask[text_pos, image_positions] = 0.0 |
| 348 | mask[text_pos, text_positions[:i+1]] = 0.0 |
| 349 | |
| 350 | masks.append(mask) |
| 351 | |
| 352 | mask = torch.stack(masks, dim=0).unsqueeze(1) |
| 353 | return mask |
| 354 | |
| 355 | return CustomQwen2ModelInner(config) |
| 356 | |
| 357 | def forward( |
| 358 | self, |
| 359 | inputs_embeds, |
| 360 | token_type_ids, |
| 361 | attention_mask=None, |
| 362 | **kwargs |
| 363 | ): |
| 364 | """ |
| 365 | Args: |
| 366 | inputs_embeds: [batch_size, seq_len, hidden_dim] |
| 367 | token_type_ids: [batch_size, seq_len], 0=non-causal, 1=causal |
| 368 | attention_mask: [batch_size, seq_len], optional |
| 369 | """ |
| 370 | return self.model( |
| 371 | inputs_embeds=inputs_embeds, |
| 372 | token_type_ids=token_type_ids, |
| 373 | attention_mask=attention_mask, |
| 374 | **kwargs |
| 375 | ) |
| 376 | |
| 377 | |
| 378 | |
| 379 | |
| 380 | |
| 381 | # batch_size = 2 |
| 382 | # inputs_embeds = torch.randn(batch_size, 512, 896).cuda() |
| 383 | |
| 384 | # inputs_embeds = torch.randn(batch_size, 512, 896).cuda() |
| 385 | # token_type_ids = torch.cat([ |
| 386 | # torch.zeros(batch_size, 256, dtype=torch.long), |
| 387 | # torch.ones(batch_size, 256, dtype=torch.long), |
| 388 | # ], dim=1).cuda() |
| 389 | |
| 390 | # # start = time.time() |
| 391 | # with torch.no_grad(): |
| 392 | # outputs_sdpa = decoder_sdpa(inputs_embeds, token_type_ids) |
| 393 | # print(outputs_sdpa[0].shape) |
| 394 | # print(f"SDPA time: {time.time() - start:.4f}s") |
| 395 | |
| 396 | |
| 397 | |
| 398 | class Qwen2Decoder2Encoder(nn.Module): |
| 399 | """ |
| 400 | Decoder based on Multilingual BART |
| 401 | Set the initial weights and configuration with a pretrained multilingual BART model, |
| 402 | and modify the detailed configurations as a Nougat decoder |
| 403 | """ |
| 404 | |
| 405 | def __init__( |
| 406 | self, |
| 407 | decoder_layer: int, |
| 408 | hidden_dimension: int, |
| 409 | num_attention_heads: int, |
| 410 | num_key_value_heads: int, |
| 411 | intermediate_size: int, |
| 412 | max_query: int, |
| 413 | ): |
| 414 | super().__init__() |
| 415 | |
| 416 | self.model = CustomQwen2Decoder( |
| 417 | decoder_layer=decoder_layer, |
| 418 | hidden_dimension=hidden_dimension, |
| 419 | num_attention_heads=num_attention_heads, |
| 420 | num_key_value_heads=num_key_value_heads, |
| 421 | intermediate_size=intermediate_size, |
| 422 | attn_implementation="sdpa", |
| 423 | ) |
| 424 | |
| 425 | |
| 426 | |
| 427 | |
| 428 | self.query_768 = nn.Embedding(144, hidden_dimension) |
| 429 | self.query_1024 = nn.Embedding(256, hidden_dimension) |
| 430 | |
| 431 | |
| 432 | # self.query_refixation = nn.Embedding(int(math.sqrt(max_query)), hidden_dimension) |
| 433 | |
| 434 | |
| 435 | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 436 | x = x.flatten(2).transpose(1, 2) |
| 437 | |
| 438 | bs, n_query, _ = x.shape |
| 439 | |
| 440 | if n_query == 144: |
| 441 | param_img = self.query_768.weight |
| 442 | elif n_query == 256: |
| 443 | param_img = self.query_1024.weight |
| 444 | |
| 445 | batch_query_imgs = param_img.unsqueeze(0).expand( |
| 446 | bs, -1, -1 |
| 447 | ) # (batch_size, num_queries, hidden_size) |
| 448 | |
| 449 | |
| 450 | |
| 451 | x_combined = torch.cat([x, batch_query_imgs], dim=1) |
| 452 | |
| 453 | token_type_ids = torch.cat([ |
| 454 | torch.zeros(bs, n_query, dtype=torch.long), |
| 455 | torch.ones(bs, n_query, dtype=torch.long), |
| 456 | ], dim=1) |
| 457 | |
| 458 | |
| 459 | y = self.model(x_combined, token_type_ids)[0] |
| 460 | |
| 461 | |
| 462 | y = y[:, n_query:, :] # causal flow query |
| 463 | |
| 464 | |
| 465 | return y |
| 466 | |
| 467 | |
| 468 | def build_qwen2_decoder_as_encoder( |
| 469 | decoder_layer=24, |
| 470 | hidden_dimension=896, |
| 471 | num_attention_heads=14, |
| 472 | num_key_value_heads=2, |
| 473 | intermediate_size=4864, |
| 474 | max_query = 400, |
| 475 | checkpoint=None, |
| 476 | ): |
| 477 | |
| 478 | decoder_as_encoder = Qwen2Decoder2Encoder( |
| 479 | decoder_layer=decoder_layer, |
| 480 | hidden_dimension = hidden_dimension, |
| 481 | num_attention_heads = num_attention_heads, |
| 482 | num_key_value_heads = num_key_value_heads, |
| 483 | intermediate_size = intermediate_size, |
| 484 | max_query = max_query |
| 485 | ) |
| 486 | |
| 487 | |
| 488 | |
| 489 | |
| 490 | if checkpoint is not None: |
| 491 | # with open(checkpoint, "rb") as f: |
| 492 | state_dict = torch.load(checkpoint) |
| 493 | |
| 494 | decoder_as_encoder.load_state_dict(state_dict, strict=True) |
| 495 | # tob |
| 496 | print(checkpoint) |
| 497 | return decoder_as_encoder |
| 498 | |
| 499 | |
| 500 | |
| 501 | |
| 502 | #=========================Sam-Vary================================= |
| 503 | |
| 504 | |
| 505 | def get_abs_pos_sam(abs_pos, tgt_size): |
| 506 | |
| 507 | dtype = abs_pos.dtype |
| 508 | |
| 509 | src_size = abs_pos.size(1) |
| 510 | |
| 511 | if src_size != tgt_size: |
| 512 | old_pos_embed = abs_pos.permute(0, 3, 1, 2) |
| 513 | old_pos_embed = old_pos_embed.to(torch.float32) |
| 514 | new_pos_embed = F.interpolate( |
| 515 | old_pos_embed, |
| 516 | size=(tgt_size, tgt_size), |
| 517 | mode='bicubic', |
| 518 | antialias=True, |
| 519 | align_corners=False, |
| 520 | ).to(dtype) |
| 521 | new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) |
| 522 | return new_pos_embed |
| 523 | else: |
| 524 | return abs_pos |
| 525 | |
| 526 | |
| 527 | |
| 528 | |
| 529 | class MLPBlock(nn.Module): |
| 530 | def __init__( |
| 531 | self, |
| 532 | embedding_dim: int, |
| 533 | mlp_dim: int, |
| 534 | act: Type[nn.Module] = nn.GELU, |
| 535 | ) -> None: |
| 536 | super().__init__() |
| 537 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) |
| 538 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) |
| 539 | self.act = act() |
| 540 | |
| 541 | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 542 | return self.lin2(self.act(self.lin1(x))) |
| 543 | |
| 544 | |
| 545 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa |
| 546 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa |
| 547 | class LayerNorm2d(nn.Module): |
| 548 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: |
| 549 | super().__init__() |
| 550 | self.weight = nn.Parameter(torch.ones(num_channels)) |
| 551 | self.bias = nn.Parameter(torch.zeros(num_channels)) |
| 552 | self.eps = eps |
| 553 | |
| 554 | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 555 | u = x.mean(1, keepdim=True) |
| 556 | s = (x - u).pow(2).mean(1, keepdim=True) |
| 557 | x = (x - u) / torch.sqrt(s + self.eps) |
| 558 | x = self.weight[:, None, None] * x + self.bias[:, None, None] |
| 559 | return x |
| 560 | |
| 561 | |
| 562 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa |
| 563 | class ImageEncoderViT(nn.Module): |
| 564 | def __init__( |
| 565 | self, |
| 566 | img_size: int = 1024, |
| 567 | patch_size: int = 16, |
| 568 | in_chans: int = 3, |
| 569 | embed_dim: int = 768, |
| 570 | depth: int = 12, |
| 571 | num_heads: int = 12, |
| 572 | mlp_ratio: float = 4.0, |
| 573 | out_chans: int = 256, |
| 574 | qkv_bias: bool = True, |
| 575 | norm_layer: Type[nn.Module] = nn.LayerNorm, |
| 576 | act_layer: Type[nn.Module] = nn.GELU, |
| 577 | use_abs_pos: bool = True, |
| 578 | use_rel_pos: bool = False, |
| 579 | rel_pos_zero_init: bool = True, |
| 580 | window_size: int = 0, |
| 581 | global_attn_indexes: Tuple[int, ...] = (), |
| 582 | ) -> None: |
| 583 | """ |
| 584 | Args: |
| 585 | img_size (int): Input image size. |
| 586 | patch_size (int): Patch size. |
| 587 | in_chans (int): Number of input image channels. |
| 588 | embed_dim (int): Patch embedding dimension. |
| 589 | depth (int): Depth of ViT. |
| 590 | num_heads (int): Number of attention heads in each ViT block. |
| 591 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. |
| 592 | qkv_bias (bool): If True, add a learnable bias to query, key, value. |
| 593 | norm_layer (nn.Module): Normalization layer. |
| 594 | act_layer (nn.Module): Activation layer. |
| 595 | use_abs_pos (bool): If True, use absolute positional embeddings. |
| 596 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. |
| 597 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. |
| 598 | window_size (int): Window size for window attention blocks. |
| 599 | global_attn_indexes (list): Indexes for blocks using global attention. |
| 600 | """ |
| 601 | super().__init__() |
| 602 | self.img_size = img_size |
| 603 | |
| 604 | self.patch_embed = PatchEmbed( |
| 605 | kernel_size=(patch_size, patch_size), |
| 606 | stride=(patch_size, patch_size), |
| 607 | in_chans=in_chans, |
| 608 | embed_dim=embed_dim, |
| 609 | ) |
| 610 | |
| 611 | self.pos_embed: Optional[nn.Parameter] = None |
| 612 | if use_abs_pos: |
| 613 | # Initialize absolute positional embedding with pretrain image size. |
| 614 | self.pos_embed = nn.Parameter( |
| 615 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) |
| 616 | ) |
| 617 | |
| 618 | self.blocks = nn.ModuleList() |
| 619 | for i in range(depth): |
| 620 | block = Block( |
| 621 | dim=embed_dim, |
| 622 | num_heads=num_heads, |
| 623 | mlp_ratio=mlp_ratio, |
| 624 | qkv_bias=qkv_bias, |
| 625 | norm_layer=norm_layer, |
| 626 | act_layer=act_layer, |
| 627 | use_rel_pos=use_rel_pos, |
| 628 | rel_pos_zero_init=rel_pos_zero_init, |
| 629 | window_size=window_size if i not in global_attn_indexes else 0, |
| 630 | input_size=(img_size // patch_size, img_size // patch_size), |
| 631 | ) |
| 632 | self.blocks.append(block) |
| 633 | |
| 634 | self.neck = nn.Sequential( |
| 635 | nn.Conv2d( |
| 636 | embed_dim, |
| 637 | out_chans, |
| 638 | kernel_size=1, |
| 639 | bias=False, |
| 640 | ), |
| 641 | LayerNorm2d(out_chans), |
| 642 | nn.Conv2d( |
| 643 | out_chans, |
| 644 | out_chans, |
| 645 | kernel_size=3, |
| 646 | padding=1, |
| 647 | bias=False, |
| 648 | ), |
| 649 | LayerNorm2d(out_chans), |
| 650 | ) |
| 651 | |
| 652 | self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) |
| 653 | self.net_3 = nn.Conv2d(512, 896, kernel_size=3, stride=2, padding=1, bias=False) |
| 654 | |
| 655 | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 656 | x = self.patch_embed(x) |
| 657 | if self.pos_embed is not None: |
| 658 | # x = x + self.pos_embed |
| 659 | x = x + get_abs_pos_sam(self.pos_embed, x.size(1)) |
| 660 | |
| 661 | for blk in self.blocks: |
| 662 | x = blk(x) |
| 663 | |
| 664 | x = self.neck(x.permute(0, 3, 1, 2)) |
| 665 | x2 = self.net_2(x) |
| 666 | x3 = self.net_3(x2.clone()) |
| 667 | |
| 668 | return x3 |
| 669 | |
| 670 | |
| 671 | class Block(nn.Module): |
| 672 | """Transformer blocks with support of window attention and residual propagation blocks""" |
| 673 | |
| 674 | def __init__( |
| 675 | self, |
| 676 | dim: int, |
| 677 | num_heads: int, |
| 678 | mlp_ratio: float = 4.0, |
| 679 | qkv_bias: bool = True, |
| 680 | norm_layer: Type[nn.Module] = nn.LayerNorm, |
| 681 | act_layer: Type[nn.Module] = nn.GELU, |
| 682 | use_rel_pos: bool = False, |
| 683 | rel_pos_zero_init: bool = True, |
| 684 | window_size: int = 0, |
| 685 | input_size: Optional[Tuple[int, int]] = None, |
| 686 | ) -> None: |
| 687 | """ |
| 688 | Args: |
| 689 | dim (int): Number of input channels. |
| 690 | num_heads (int): Number of attention heads in each ViT block. |
| 691 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. |
| 692 | qkv_bias (bool): If True, add a learnable bias to query, key, value. |
| 693 | norm_layer (nn.Module): Normalization layer. |
| 694 | act_layer (nn.Module): Activation layer. |
| 695 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. |
| 696 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. |
| 697 | window_size (int): Window size for window attention blocks. If it equals 0, then |
| 698 | use global attention. |
| 699 | input_size (tuple(int, int) or None): Input resolution for calculating the relative |
| 700 | positional parameter size. |
| 701 | """ |
| 702 | super().__init__() |
| 703 | self.norm1 = norm_layer(dim) |
| 704 | self.attn = Attention( |
| 705 | dim, |
| 706 | num_heads=num_heads, |
| 707 | qkv_bias=qkv_bias, |
| 708 | use_rel_pos=use_rel_pos, |
| 709 | rel_pos_zero_init=rel_pos_zero_init, |
| 710 | input_size=input_size if window_size == 0 else (window_size, window_size), |
| 711 | ) |
| 712 | |
| 713 | self.norm2 = norm_layer(dim) |
| 714 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) |
| 715 | |
| 716 | self.window_size = window_size |
| 717 | |
| 718 | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 719 | shortcut = x |
| 720 | x = self.norm1(x) |
| 721 | # Window partition |
| 722 | if self.window_size > 0: |
| 723 | H, W = x.shape[1], x.shape[2] |
| 724 | x, pad_hw = window_partition(x, self.window_size) |
| 725 | |
| 726 | x = self.attn(x) |
| 727 | # Reverse window partition |
| 728 | if self.window_size > 0: |
| 729 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) |
| 730 | |
| 731 | x = shortcut + x |
| 732 | x = x + self.mlp(self.norm2(x)) |
| 733 | |
| 734 | return x |
| 735 | |
| 736 | |
| 737 | class Attention(nn.Module): |
| 738 | """Multi-head Attention block with relative position embeddings.""" |
| 739 | |
| 740 | def __init__( |
| 741 | self, |
| 742 | dim: int, |
| 743 | num_heads: int = 8, |
| 744 | qkv_bias: bool = True, |
| 745 | use_rel_pos: bool = False, |
| 746 | rel_pos_zero_init: bool = True, |
| 747 | input_size: Optional[Tuple[int, int]] = None, |
| 748 | ) -> None: |
| 749 | """ |
| 750 | Args: |
| 751 | dim (int): Number of input channels. |
| 752 | num_heads (int): Number of attention heads. |
| 753 | qkv_bias (bool): If True, add a learnable bias to query, key, value. |
| 754 | rel_pos (bool): If True, add relative positional embeddings to the attention map. |
| 755 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. |
| 756 | input_size (tuple(int, int) or None): Input resolution for calculating the relative |
| 757 | positional parameter size. |
| 758 | """ |
| 759 | super().__init__() |
| 760 | self.num_heads = num_heads |
| 761 | head_dim = dim // num_heads |
| 762 | self.scale = head_dim**-0.5 |
| 763 | |
| 764 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
| 765 | self.proj = nn.Linear(dim, dim) |
| 766 | |
| 767 | self.use_rel_pos = use_rel_pos |
| 768 | if self.use_rel_pos: |
| 769 | assert ( |
| 770 | input_size is not None |
| 771 | ), "Input size must be provided if using relative positional encoding." |
| 772 | # initialize relative positional embeddings |
| 773 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) |
| 774 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) |
| 775 | |
| 776 | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 777 | B, H, W, _ = x.shape |
| 778 | # qkv with shape (3, B, nHead, H * W, C) |
| 779 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) |
| 780 | # q, k, v with shape (B * nHead, H * W, C) |
| 781 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) |
| 782 | |
| 783 | rel_h, rel_w = None, None |
| 784 | if self.use_rel_pos: |
| 785 | rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) |
| 786 | |
| 787 | q = q.view(B, self.num_heads, H * W, -1) |
| 788 | k = k.view(B, self.num_heads, H * W, -1) |
| 789 | v = v.view(B, self.num_heads, H * W, -1) |
| 790 | |
| 791 | if self.use_rel_pos: |
| 792 | rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3)) |
| 793 | rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3)) |
| 794 | attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)) |
| 795 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) |
| 796 | # x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w) |
| 797 | else: |
| 798 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) |
| 799 | |
| 800 | x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) |
| 801 | |
| 802 | x = self.proj(x) |
| 803 | |
| 804 | return x |
| 805 | |
| 806 | |
| 807 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: |
| 808 | """ |
| 809 | Partition into non-overlapping windows with padding if needed. |
| 810 | Args: |
| 811 | x (tensor): input tokens with [B, H, W, C]. |
| 812 | window_size (int): window size. |
| 813 | |
| 814 | Returns: |
| 815 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. |
| 816 | (Hp, Wp): padded height and width before partition |
| 817 | """ |
| 818 | B, H, W, C = x.shape |
| 819 | |
| 820 | pad_h = (window_size - H % window_size) % window_size |
| 821 | pad_w = (window_size - W % window_size) % window_size |
| 822 | if pad_h > 0 or pad_w > 0: |
| 823 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) |
| 824 | Hp, Wp = H + pad_h, W + pad_w |
| 825 | |
| 826 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) |
| 827 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) |
| 828 | return windows, (Hp, Wp) |
| 829 | |
| 830 | |
| 831 | def window_unpartition( |
| 832 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] |
| 833 | ) -> torch.Tensor: |
| 834 | """ |
| 835 | Window unpartition into original sequences and removing padding. |
| 836 | Args: |
| 837 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. |
| 838 | window_size (int): window size. |
| 839 | pad_hw (Tuple): padded height and width (Hp, Wp). |
| 840 | hw (Tuple): original height and width (H, W) before padding. |
| 841 | |
| 842 | Returns: |
| 843 | x: unpartitioned sequences with [B, H, W, C]. |
| 844 | """ |
| 845 | Hp, Wp = pad_hw |
| 846 | H, W = hw |
| 847 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) |
| 848 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) |
| 849 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) |
| 850 | |
| 851 | if Hp > H or Wp > W: |
| 852 | x = x[:, :H, :W, :].contiguous() |
| 853 | return x |
| 854 | |
| 855 | |
| 856 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: |
| 857 | """ |
| 858 | Get relative positional embeddings according to the relative positions of |
| 859 | query and key sizes. |
| 860 | Args: |
| 861 | q_size (int): size of query q. |
| 862 | k_size (int): size of key k. |
| 863 | rel_pos (Tensor): relative position embeddings (L, C). |
| 864 | |
| 865 | Returns: |
| 866 | Extracted positional embeddings according to relative positions. |
| 867 | """ |
| 868 | max_rel_dist = int(2 * max(q_size, k_size) - 1) |
| 869 | # Interpolate rel pos if needed. |
| 870 | if rel_pos.shape[0] != max_rel_dist: |
| 871 | # Interpolate rel pos. |
| 872 | dtype = rel_pos.dtype |
| 873 | rel_pos = rel_pos.to(torch.float32) |
| 874 | rel_pos_resized = F.interpolate( |
| 875 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), |
| 876 | size=max_rel_dist, |
| 877 | mode="linear", |
| 878 | ).to(dtype) |
| 879 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) |
| 880 | else: |
| 881 | rel_pos_resized = rel_pos |
| 882 | |
| 883 | # Scale the coords with short length if shapes for q and k are different. |
| 884 | q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0) |
| 885 | k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0) |
| 886 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) |
| 887 | |
| 888 | return rel_pos_resized[relative_coords.long()] |
| 889 | |
| 890 | |
def add_decomposed_rel_pos(
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute decomposed relative positional bias terms from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950

    This function does NOT add anything to an attention map. It returns the
    height-axis and width-axis terms separately. They are shaped so that
    ``rel_h + rel_w`` broadcasts to (B, q_h * q_w, k_h, k_w). Combining them
    with the attention logits is left to the caller.

    Args:
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        rel_h (Tensor): height-axis bias with shape (B, q_h * q_w, k_h, 1).
        rel_w (Tensor): width-axis bias with shape (B, q_h * q_w, 1, k_w).
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)  # (q_h, k_h, C)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)  # (q_w, k_w, C)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)  # (B, q_h, q_w, k_h)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)  # (B, q_h, q_w, k_w)

    # Flatten the query grid and add singleton axes so the two terms
    # broadcast against each other over (k_h, k_w).
    rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
    rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)

    return rel_h, rel_w
| 926 | |
| 927 | |
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.

    A single strided convolution (``self.proj``) maps the image to patch
    embeddings. The result is returned in channels-last layout.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection convolution.
            stride (Tuple): stride of the projection convolution.
            padding (Tuple): padding of the projection convolution.
            in_chans (int): number of input image channels.
            embed_dim (int): output embedding dimension per patch.
        """
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, **conv_options)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and move channels last: B C H W -> B H W C."""
        embedded = self.proj(x)
        return embedded.permute(0, 2, 3, 1)
| 960 | |
| 961 | |
def build_sam_vit_b(checkpoint=None):
    """
    Build the ViT-B configuration of the image encoder via ``_build_sam``.

    Args:
        checkpoint: optional checkpoint forwarded unchanged to ``_build_sam``
            (passed to ``torch.load`` there when not None).

    Returns:
        Whatever ``_build_sam`` returns for this configuration.
    """
    vit_b_settings = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
    )
    return _build_sam(checkpoint=checkpoint, **vit_b_settings)
| 970 | |
def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
    """
    Build the ViT-B encoder in eval mode, cast it to ``dtype`` and wrap it with ``torch.compile``.

    Args:
        checkpoint: optional checkpoint forwarded to ``build_sam_vit_b``.
        compile_mode: ``mode`` argument passed to ``torch.compile``.
        dtype: dtype the encoder is cast to before compilation.

    Returns:
        The compiled encoder returned by ``torch.compile``.
    """
    encoder = build_sam_vit_b(checkpoint)
    encoder = encoder.eval().to(dtype)
    return torch.compile(encoder, mode=compile_mode)
| 976 | |
| 977 | |
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """
    Construct an ``ImageEncoderViT`` with fixed settings and optionally load weights.

    Fixed constructor arguments: img_size=1024, patch_size=16, mlp_ratio=4,
    norm_layer=LayerNorm(eps=1e-6), qkv_bias=True, use_rel_pos=True,
    window_size=14, out_chans=256.

    Args:
        encoder_embed_dim: passed as ``embed_dim``.
        encoder_depth: passed as ``depth``.
        encoder_num_heads: passed as ``num_heads``.
        encoder_global_attn_indexes: passed as ``global_attn_indexes``.
        checkpoint: optional file for ``torch.load``. When given:
            - only entries whose key contains ``'vision_tower_high'`` are kept;
            - the first 30 characters of each kept key are stripped;
            - the result is loaded with ``strict=True``;
            - the checkpoint argument is then printed.

    Returns:
        The encoder, switched to eval mode.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=image_size,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=prompt_embed_dim,
    )
    image_encoder.eval()
    if checkpoint is not None:
        # Map to CPU so a checkpoint saved from GPU also loads on CPU-only hosts.
        # load_state_dict copies values into the encoder's own parameters, so
        # the final weights are the same either way.
        state_dict = torch.load(checkpoint, map_location="cpu")
        # NOTE(review): the 30-character strip assumes every matching key shares
        # one fixed-length prefix in front of the encoder's parameter names.
        # Confirm this against the checkpoint layout.
        prefix_len = 30
        encoder_state = {
            k[prefix_len:]: v for k, v in state_dict.items() if 'vision_tower_high' in k
        }
        image_encoder.load_state_dict(encoder_state, strict=True)
        print(checkpoint)
    return image_encoder
| 1015 | return image_encoder |