modeling_videomaev2.py
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
from functools import partial
import logging
logger = logging.getLogger(__name__)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp

from transformers import AutoConfig, PreTrainedModel

from timm.layers import drop_path, to_2tuple, trunc_normal_
from .modeling_config import VideoMAEv2Config


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 400,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': .9,
        'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
        **kwargs
    }


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # Dropout after fc1 is commented out to follow the original BERT
        # implementation, which only applies dropout after fc2.
        x = self.fc2(x)
        x = self.drop(x)
        return x


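# The CosAttention block below is a cosine-similarity attention variant:
# queries and keys are L2-normalized before the dot product, and the result is
# multiplied by a learnable per-head logit scale (initialized to log(10) and
# clamped so that exp(scale) <= ~100), in the spirit of Swin Transformer V2's
# scaled cosine attention. Passing `qk_scale` uses a fixed scale instead.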
class CosAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        # self.scale = qk_scale or head_dim**-0.5
        # DO NOT RENAME [self.scale] (for no weight decay)
        if qk_scale is None:
            self.scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1))),
                requires_grad=True)
        else:
            self.scale = qk_scale

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat(
                (self.q_bias,
                 torch.zeros_like(self.v_bias, requires_grad=False),
                 self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (
            F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))

        # torch.log(torch.tensor(1. / 0.01)) = 4.6052
        logit_scale = torch.clamp(self.scale, max=4.6052).exp()

        attn = attn * logit_scale

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Attention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat(
                (self.q_bias,
                 torch.zeros_like(self.v_bias, requires_grad=False),
                 self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


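# Block is a standard pre-norm transformer block (LayerNorm -> attention and
# LayerNorm -> MLP, each wrapped in a residual connection with optional
# stochastic depth). When `init_values` is a positive number it also adds
# LayerScale-style learnable per-channel scales (gamma_1 / gamma_2) on the
# residual branches, which helps stabilize very deep models.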
class Block(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 init_values=None,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 attn_head_dim=None,
                 cos_attn=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        if cos_attn:
            self.attn = CosAttention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
                attn_head_dim=attn_head_dim)
        else:
            self.attn = Attention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
                attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth; we shall see if this is better than dropout here
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones(dim), requires_grad=True)
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones(dim), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 num_frames=16,
                 tubelet_size=2):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_spatial_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        num_patches = num_spatial_patches * (num_frames // tubelet_size)

        self.img_size = img_size
        self.tubelet_size = tubelet_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
            stride=(self.tubelet_size, patch_size[0], patch_size[1]))

    def forward(self, x, **kwargs):
        B, C, T, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, C, T, H, W) -> (B, embed_dim, T', H', W') -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

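# For reference: with the defaults (224x224 input, 16x16 patches, 16 frames,
# tubelet_size=2), PatchEmbed produces (224 // 16) ** 2 * (16 // 2)
# = 14 * 14 * 8 = 1568 tokens per clip, each of dimension `embed_dim`.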

# sin-cos position encoding
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
def get_sinusoid_encoding_table(n_position, d_hid):
    ''' Sinusoid position encoding table '''

    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        return [
            position / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ]

    sinusoid_table = np.array(
        [get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.tensor(
        sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)

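# The returned table has shape (1, n_position, d_hid) and is used below as a
# fixed (non-learnable) positional embedding, e.g.
# get_sinusoid_encoding_table(1568, 384).shape == torch.Size([1, 1568, 384]).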

class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 head_drop_rate=0.,
                 norm_layer=nn.LayerNorm,
                 layer_norm_eps=1e-12,
                 init_values=0.,
                 use_learnable_pos_emb=False,
                 init_scale=0.,
                 num_frames=16,
                 tubelet_size=2,
                 use_mean_pooling=True,
                 with_cp=False,
                 cos_attn=False):
        super().__init__()
        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.tubelet_size = tubelet_size
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            num_frames=num_frames,
            tubelet_size=tubelet_size)
        num_patches = self.patch_embed.num_patches
        self.with_cp = with_cp

        # `norm_layer` may arrive as a string (e.g. 'nn.LayerNorm' from the HF
        # config), as a class, or as an already-configured functools.partial
        # from the factory functions at the bottom of this file.
        if isinstance(norm_layer, str):
            norm_layer = eval(norm_layer)
        if not isinstance(norm_layer, partial):
            norm_layer = partial(norm_layer, eps=layer_norm_eps)

        if use_learnable_pos_emb:
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embed_dim))
        else:
            # fixed sine-cosine positional embeddings
            self.pos_embed = get_sinusoid_encoding_table(
                num_patches, embed_dim)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                init_values=init_values,
                cos_attn=cos_attn) for i in range(depth)
        ])
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(
            embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        self.head_dropout = nn.Dropout(head_drop_rate)
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if use_learnable_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)

        self.apply(self._init_weights)
        if num_classes > 0:
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(
            self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.size(0)

        x = self.patch_embed(x)

        if self.pos_embed is not None:
            x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(
                x.device).clone().detach()
        x = self.pos_drop(x)

        for blk in self.blocks:
            if self.with_cp:
                x = cp.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.fc_norm is not None:
            return self.fc_norm(x.mean(1))
        else:
            return self.norm(x[:, 0])

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head_dropout(x)
        x = self.head(x)
        return x


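# VideoMAEv2 is a thin PreTrainedModel wrapper around the VisionTransformer
# backbone above: the keyword arguments stored in `config.model_config` are
# forwarded verbatim to VisionTransformer, which is the usual pattern for
# loading the model through `from_pretrained(..., trust_remote_code=True)`.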
class VideoMAEv2(PreTrainedModel):
    config_class = VideoMAEv2Config

    def __init__(self, config=None):
        super().__init__(config=config)
        self.model_config = config.model_config
        logger.info("Model config: {}".format(self.model_config))
        self.model = VisionTransformer(**self.model_config)

    def forward(self, pixel_values):
        return self.model(pixel_values)

    def extract_features(self, pixel_values):
        return self.model.forward_features(pixel_values)


def vit_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model


def vit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model


# @register_model
def vit_huge_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model


# @register_model
def vit_giant_patch14_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=14,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=48 / 11,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model

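
if __name__ == "__main__":
    # Minimal smoke test (illustrative only; run as a module, e.g.
    # `python -m <package>.modeling_videomaev2`, so the relative import above
    # resolves). It builds the small variant defined in this file and pushes a
    # dummy clip through it: with 16 frames at 224x224, 16x16 patches and
    # tubelet_size=2 the backbone sees 14 * 14 * 8 = 1568 tokens.
    model = vit_small_patch16_224(num_classes=400, num_frames=16)
    clip = torch.randn(1, 3, 16, 224, 224)  # (B, C, T, H, W)
    logits = model(clip)
    print(logits.shape)  # expected: torch.Size([1, 400])
    features = model.forward_features(clip)
    print(features.shape)  # expected: torch.Size([1, 384]) after mean pooling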