modeling_videomaev2.py
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
from functools import partial
import logging

logger = logging.getLogger(__name__)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp

from transformers import AutoConfig, PreTrainedModel

from timm.layers import drop_path, to_2tuple, trunc_normal_

from .modeling_config import VideoMAEv2Config

def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 400,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': .9,
        'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
        **kwargs
    }


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commented out to match the original BERT implementation
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CosAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        # self.scale = qk_scale or head_dim**-0.5
        # DO NOT RENAME [self.scale] (for no weight decay)
        if qk_scale is None:
            self.scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1))),
                requires_grad=True)
        else:
            self.scale = qk_scale

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat(
                (self.q_bias,
                 torch.zeros_like(self.v_bias, requires_grad=False),
                 self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # cosine attention: dot product of L2-normalized q and k
        attn = (
            F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))

        # torch.log(torch.tensor(1. / 0.01)) = 4.6052
        logit_scale = torch.clamp(self.scale, max=4.6052).exp()

        attn = attn * logit_scale

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Attention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat(
                (self.q_bias,
                 torch.zeros_like(self.v_bias, requires_grad=False),
                 self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 init_values=None,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 attn_head_dim=None,
                 cos_attn=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        if cos_attn:
            self.attn = CosAttention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
                attn_head_dim=attn_head_dim)
        else:
            self.attn = Attention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
                attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

        # layer-scale parameters; guard against the default init_values=None
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 num_frames=16,
                 tubelet_size=2):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_spatial_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        num_patches = num_spatial_patches * (num_frames // tubelet_size)

        self.img_size = img_size
        self.tubelet_size = tubelet_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
            stride=(self.tubelet_size, patch_size[0], patch_size[1]))

    def forward(self, x, **kwargs):
        B, C, T, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, C, T, H, W) -> (B, embed_dim, T', H', W') -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


# sin-cos position encoding
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
def get_sinusoid_encoding_table(n_position, d_hid):
    ''' Sinusoid position encoding table '''

    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        return [
            position / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ]

    sinusoid_table = np.array(
        [get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.tensor(
        sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
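
# Editor's note (not in the original source): the returned table has shape
# (1, n_position, d_hid) and must match the token count produced by PatchEmbed.
# For the default ViT-B settings (224x224 input, patch 16, 16 frames, tubelet 2)
# that is (224 // 16) * (224 // 16) * (16 // 2) = 1568 positions, so
# get_sinusoid_encoding_table(1568, 768) returns a tensor of shape (1, 1568, 768).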


class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 head_drop_rate=0.,
                 norm_layer=nn.LayerNorm,
                 layer_norm_eps=1e-12,
                 init_values=0.,
                 use_learnable_pos_emb=False,
                 init_scale=0.,
                 num_frames=16,
                 tubelet_size=2,
                 use_mean_pooling=True,
                 with_cp=False,
                 cos_attn=False):
        super().__init__()
        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.tubelet_size = tubelet_size
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            num_frames=num_frames,
            tubelet_size=tubelet_size)
        num_patches = self.patch_embed.num_patches
        self.with_cp = with_cp

        # `norm_layer` may come from a config as a string (e.g. "nn.LayerNorm");
        # only eval() that form, and use callables (e.g. a partial) as given.
        if isinstance(norm_layer, str):
            norm_layer = partial(eval(norm_layer), eps=layer_norm_eps)

        if use_learnable_pos_emb:
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embed_dim))
        else:
            # fixed sine-cosine positional embeddings
            self.pos_embed = get_sinusoid_encoding_table(
                num_patches, embed_dim)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                init_values=init_values,
                cos_attn=cos_attn) for i in range(depth)
        ])
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(
            embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        self.head_dropout = nn.Dropout(head_drop_rate)
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if use_learnable_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)

        self.apply(self._init_weights)
        if num_classes > 0:
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(
            self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.size(0)

        x = self.patch_embed(x)

        if self.pos_embed is not None:
            x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(
                x.device).clone().detach()
        x = self.pos_drop(x)

        for blk in self.blocks:
            if self.with_cp:
                x = cp.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.fc_norm is not None:
            return self.fc_norm(x.mean(1))
        else:
            return self.norm(x[:, 0])

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head_dropout(x)
        x = self.head(x)
        return x


class VideoMAEv2(PreTrainedModel):
    config_class = VideoMAEv2Config

    def __init__(self, config=None):
        super().__init__(config=config)
        self.model_config = config.model_config
        logger.info("Model config: {}".format(self.model_config))
        self.model = VisionTransformer(**self.model_config)

    def forward(self, pixel_values):
        return self.model(pixel_values)

    def extract_features(self, pixel_values):
        return self.model.forward_features(pixel_values)

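
# Illustrative usage (editor's sketch, not part of the original file). It assumes
# VideoMAEv2Config exposes the VisionTransformer kwargs as a `model_config` dict,
# as read in __init__ above, and that `pixel_values` is a (B, C, T, H, W) video
# tensor, e.g. torch.randn(1, 3, 16, 224, 224) for the defaults:
#
#   config = VideoMAEv2Config(model_config=dict(num_classes=400, qkv_bias=True))
#   model = VideoMAEv2(config)
#   logits = model(pixel_values)                     # (B, num_classes)
#   features = model.extract_features(pixel_values)  # (B, embed_dim), mean-pooled tokens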


def vit_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model


def vit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model


# @register_model
def vit_huge_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model


# @register_model
def vit_giant_patch14_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=14,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=48 / 11,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model

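
# Minimal smoke test (editor's sketch, not part of the original file). It builds
# the small variant above and runs a random 16-frame clip through it, assuming the
# default 224x224 input size and tubelet size 2. Because of the relative import at
# the top, run it as a module (python -m <package>.modeling_videomaev2).
if __name__ == "__main__":
    model = vit_small_patch16_224(num_classes=400)
    model.eval()
    clip = torch.randn(1, 3, 16, 224, 224)  # (B, C, T, H, W)
    with torch.no_grad():
        logits = model(clip)
    print(logits.shape)  # expected: torch.Size([1, 400])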