# BEN2.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from PIL import Image, ImageOps
from torchvision import transforms
import random
import cv2
import os
import subprocess
import time
import tempfile

def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_random_seed(9)

torch.set_float32_matmul_precision('highest')

class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

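
# Hedged example: a quick shape round-trip for the two helpers above. The sizes
# are illustrative only (not tied to this model); it just shows that
# window_reverse undoes window_partition when H and W are multiples of the
# window size.
def _window_roundtrip_example(window_size=7):
    feat = torch.randn(2, 14, 14, 96)                 # (B, H, W, C)
    wins = window_partition(feat, window_size)        # (B * num_windows, 7, 7, 96) = (8, 7, 7, 96)
    back = window_reverse(wins, window_size, 14, 14)  # (2, 14, 14, 96)
    assert torch.equal(feat, back)
    return wins.shape, back.shape
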

class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class SwinTransformerBlock(nn.Module):
    """ Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

class PatchMerging(nn.Module):
    """ Patch Merging Layer
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x

class SwinTransformer(nn.Module):
    """ Swin Transformer backbone.
    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030
    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]

            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def forward(self, x):
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed)  # B C Wh Ww

        outs = [x.contiguous()]
        x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)

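
# Hedged sketch (not part of the model): what the backbone returns. With the
# default Swin-T style configuration and a 224x224 input, forward() yields a
# 5-tuple: the raw patch embedding followed by one feature map per stage.
def _swin_feature_shapes_example():
    backbone = SwinTransformer()  # embed_dim=96, depths=[2, 2, 6, 2]
    with torch.no_grad():
        feats = backbone(torch.randn(1, 3, 224, 224))
    # Expected shapes (strides 4, 4, 8, 16, 32):
    #   (1, 96, 56, 56), (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)
    return [tuple(f.shape) for f in feats]
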

def get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "gelu":
        return F.gelu

    raise RuntimeError(f"activation should be gelu, not {activation}.")

def make_cbr(in_dim, out_dim):
    return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())


def make_cbg(in_dim, out_dim):
    return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU())


def rescale_to(x, scale_factor: float = 2, interpolation='nearest'):
    return F.interpolate(x, scale_factor=scale_factor, mode=interpolation)


def resize_as(x, y, interpolation='bilinear'):
    return F.interpolate(x, size=y.shape[-2:], mode=interpolation)


def image2patches(x):
    """b c (hg h) (wg w) -> (hg wg b) c h w"""
    x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
    return x


def patches2image(x):
    """(hg wg b) c h w -> b c (hg h) (wg w)"""
    x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
    return x

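
# Hedged shape example (illustration only): image2patches splits each image into
# a 2x2 grid of local crops stacked along the batch dimension, and patches2image
# undoes it.
def _patching_roundtrip_example():
    img = torch.randn(1, 3, 1024, 1024)
    patches = image2patches(img)       # (4, 3, 512, 512): the four 2x2 crops
    restored = patches2image(patches)  # (1, 3, 1024, 1024)
    assert torch.equal(img, restored)
    return patches.shape, restored.shape
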

class PositionEmbeddingSine:
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale
        self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)

    def __call__(self, b, h, w):
        device = self.dim_t.device
        mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(dim=1, dtype=torch.float32)
        x_embed = not_mask.cumsum(dim=2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t

        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)

        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

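
# Hedged usage sketch: the sine embedding is built from (batch, height, width)
# alone and returns a (b, 2 * num_pos_feats, h, w) map. The numbers below are
# illustrative, matching how MCLM constructs it (num_pos_feats = d_model // 2).
def _position_embedding_example(d_model=128):
    pe = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)
    pos = pe(1, 16, 16)  # -> (1, 128, 16, 16)
    return pos.shape
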

class MCLM(nn.Module):
    def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
        super(MCLM, self).__init__()
        self.attention = nn.ModuleList([
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
        ])

        self.linear1 = nn.Linear(d_model, d_model * 2)
        self.linear2 = nn.Linear(d_model * 2, d_model)
        self.linear3 = nn.Linear(d_model, d_model * 2)
        self.linear4 = nn.Linear(d_model * 2, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.activation = get_activation_fn('gelu')
        self.pool_ratios = pool_ratios
        self.p_poses = []
        self.g_pos = None
        self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)

    def forward(self, l, g):
        """
        l: 4,c,h,w  (local patch features)
        g: 1,c,h,w  (global view features)
        """
        self.p_poses = []
        self.g_pos = None
        b, c, h, w = l.size()
        # 4,c,h,w -> 1,c,2h,2w
        concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)

        pools = []
        for pool_ratio in self.pool_ratios:
            # b,c,h,w
            tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
            pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw)
            pools.append(rearrange(pool, 'b c h w -> (h w) b c'))
            if self.g_pos is None:
                pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3])
                pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c')
                self.p_poses.append(pos_emb)
        pools = torch.cat(pools, 0)
        if self.g_pos is None:
            self.p_poses = torch.cat(self.p_poses, dim=0)
            pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3])
            self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c')

        device = pools.device
        self.p_poses = self.p_poses.to(device)
        self.g_pos = self.g_pos.to(device)

        # attention between glb (q) & multi-scale pooled concated locs (k, v)
        g_hw_b_c = rearrange(g, 'b c h w -> (h w) b c')

        g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0])
        g_hw_b_c = self.norm1(g_hw_b_c)
        g_hw_b_c = g_hw_b_c + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone())))
        g_hw_b_c = self.norm2(g_hw_b_c)

        # attention between original locs (q) & refreshed glb (k, v)
        l_hw_b_c = rearrange(l, "b c h w -> (h w) b c")
        _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w)
        _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2)
        outputs_re = []
        for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))):
            outputs_re.append(self.attention[i + 1](_l, _g, _g)[0])  # (h w) 1 c
        outputs_re = torch.cat(outputs_re, 1)  # (h w) 4 c

        l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
        l_hw_b_c = self.norm1(l_hw_b_c)
        l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
        l_hw_b_c = self.norm2(l_hw_b_c)

        l = torch.cat((l_hw_b_c, g_hw_b_c), 1)  # hw, b(5), c
        return rearrange(l, "(h w) b c -> b c h w", h=h, w=w)  # (5, c, h, w)

class MCRM(nn.Module):
    def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None):
        super(MCRM, self).__init__()
        self.attention = nn.ModuleList([
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1),
            nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
        ])
        self.linear3 = nn.Linear(d_model, d_model * 2)
        self.linear4 = nn.Linear(d_model * 2, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.sigmoid = nn.Sigmoid()
        self.activation = get_activation_fn('gelu')
        self.sal_conv = nn.Conv2d(d_model, 1, 1)
        self.pool_ratios = pool_ratios

    def forward(self, x):
        device = x.device
        b, c, h, w = x.size()
        loc, glb = x.split([4, 1], dim=0)  # loc: 4,c,h,w; glb: 1,c,h,w

        patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)

        token_attention_map = self.sigmoid(self.sal_conv(glb))
        token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
        loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)

        pools = []
        for pool_ratio in self.pool_ratios:
            tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
            pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
            pools.append(rearrange(pool, 'nl c h w -> nl c (h w)'))  # nl(4), c, hw

        pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
        loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')

        outputs = []
        for i, q in enumerate(loc_.unbind(dim=0)):  # traverse all local patches
            v = pools[i]
            k = v
            outputs.append(self.attention[i](q, k, v)[0])

        outputs = torch.cat(outputs, 1)
        src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
        src = self.norm1(src)
        src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
        src = self.norm2(src)
        src = src.permute(1, 2, 0).reshape(4, c, h, w)  # refreshed loc
        glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest')  # refreshed glb

        return torch.cat((src, glb), 0), token_attention_map

class BEN_Base(nn.Module):
    def __init__(self):
        super().__init__()

        self.backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
        emb_dim = 128
        self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
        self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
        self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
        self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))
        self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))

        self.output5 = make_cbr(1024, emb_dim)
        self.output4 = make_cbr(512, emb_dim)
        self.output3 = make_cbr(256, emb_dim)
        self.output2 = make_cbr(128, emb_dim)
        self.output1 = make_cbr(128, emb_dim)

        self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8])
        self.conv1 = make_cbr(emb_dim, emb_dim)
        self.conv2 = make_cbr(emb_dim, emb_dim)
        self.conv3 = make_cbr(emb_dim, emb_dim)
        self.conv4 = make_cbr(emb_dim, emb_dim)
        self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8])
        self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8])
        self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8])
        self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8])

        self.insmask_head = nn.Sequential(
            nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1),
            nn.InstanceNorm2d(384),
            nn.GELU(),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.InstanceNorm2d(384),
            nn.GELU(),
            nn.Conv2d(384, emb_dim, kernel_size=3, padding=1)
        )

        self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1))
        self.upsample1 = make_cbg(emb_dim, emb_dim)
        self.upsample2 = make_cbg(emb_dim, emb_dim)
        self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1))

        for m in self.modules():
            if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
                m.inplace = True

    @torch.inference_mode()
    @torch.autocast(device_type="cuda", dtype=torch.float16)
    def forward(self, x):
        real_batch = x.size(0)

        shallow_batch = self.shallow(x)
        glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear')

        # build the backbone input: for every image, its four local crops followed by the half-scale global view
        final_input = None
        for i in range(real_batch):
            loc_batch = image2patches(x[i, :, :, :].unsqueeze(dim=0))
            input_ = torch.cat((loc_batch, glb_batch[i, :, :, :].unsqueeze(dim=0)), dim=0)

            if final_input is None:
                final_input = input_
            else:
                final_input = torch.cat((final_input, input_), dim=0)

        features = self.backbone(final_input)
        outputs = []

        for i in range(real_batch):
            start = i * 5
            end = (i + 1) * 5

            f4 = features[4][start:end, :, :, :]  # shape: [5, C, H, W]
            f3 = features[3][start:end, :, :, :]
            f2 = features[2][start:end, :, :, :]
            f1 = features[1][start:end, :, :, :]
            f0 = features[0][start:end, :, :, :]
            e5 = self.output5(f4)
            e4 = self.output4(f3)
            e3 = self.output3(f2)
            e2 = self.output2(f1)
            e1 = self.output1(f0)
            loc_e5, glb_e5 = e5.split([4, 1], dim=0)
            e5 = self.multifieldcrossatt(loc_e5, glb_e5)  # (5,128,16,16)

            e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
            e4 = self.conv4(e4)
            e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
            e3 = self.conv3(e3)
            e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
            e2 = self.conv2(e2)
            e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1))
            e1 = self.conv1(e1)

            loc_e1, glb_e1 = e1.split([4, 1], dim=0)

            output1_cat = patches2image(loc_e1)  # (1,128,256,256)

            # add glb feat in
            output1_cat = output1_cat + resize_as(glb_e1, output1_cat)
            # merge
            final_output = self.insmask_head(output1_cat)  # (1,128,256,256)
            # shallow feature merge
            shallow = shallow_batch[i, :, :, :].unsqueeze(dim=0)
            final_output = final_output + resize_as(shallow, final_output)
            final_output = self.upsample1(rescale_to(final_output))
            final_output = rescale_to(final_output + resize_as(shallow, final_output))
            final_output = self.upsample2(final_output)
            final_output = self.output(final_output)
            mask = final_output.sigmoid()
            outputs.append(mask)

        return torch.cat(outputs, dim=0)

    def loadcheckpoints(self, model_path):
        model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        self.load_state_dict(model_dict['model_state_dict'], strict=True)
        del model_dict

    def inference(self, image, refine_foreground=False):
        set_random_seed(9)
        # image = ImageOps.exif_transpose(image)
        if isinstance(image, Image.Image):
            image, h, w, original_image = rgb_loader_refiner(image)
            if torch.cuda.is_available():
                img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
            else:
                img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)

            with torch.no_grad():
                res = self.forward(img_tensor)

            # Show results
            if refine_foreground:
                pred_pil = transforms.ToPILImage()(res.squeeze())
                image_masked = refine_foreground_process(original_image, pred_pil)

                image_masked.putalpha(pred_pil.resize(original_image.size))
                return image_masked
            else:
                alpha = postprocess_image(res, im_size=[w, h])
                pred_pil = transforms.ToPILImage()(alpha)
                mask = pred_pil.resize(original_image.size)
                original_image.putalpha(mask)
                # mask = Image.fromarray(alpha)
                return original_image

        else:
            foregrounds = []
            for batch in image:
                image, h, w, original_image = rgb_loader_refiner(batch)
                if torch.cuda.is_available():
                    img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
                else:
                    img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)

                with torch.no_grad():
                    res = self.forward(img_tensor)

                if refine_foreground:
                    pred_pil = transforms.ToPILImage()(res.squeeze())
                    image_masked = refine_foreground_process(original_image, pred_pil)

                    image_masked.putalpha(pred_pil.resize(original_image.size))
                    foregrounds.append(image_masked)
                else:
                    alpha = postprocess_image(res, im_size=[w, h])
                    pred_pil = transforms.ToPILImage()(alpha)
                    mask = pred_pil.resize(original_image.size)
                    original_image.putalpha(mask)
                    # mask = Image.fromarray(alpha)
                    foregrounds.append(original_image)

            return foregrounds

    def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1,
                      print_frames_processed=True, webm=False, rgb_value=(0, 255, 0)):
        """
        Segments the given video to extract the foreground (with alpha) from each frame
        and saves the result as either a WebM video (with alpha channel) or MP4 (with a
        color background).

        Args:
            video_path (str):
                Path to the input video file.

            output_path (str, optional):
                Directory (or full path) where the output video and/or files will be saved.
                Defaults to "./".

            fps (int, optional):
                The frames per second (FPS) to use for the output video. If 0 (default), the
                original FPS of the input video is used. Otherwise, overrides it.

            refine_foreground (bool, optional):
                Whether to run an additional "refine foreground" process on each frame.
                Defaults to False.

            batch (int, optional):
                Number of frames to process at once (inference batch size). Large batch sizes
                may require more GPU memory. Defaults to 1.

            print_frames_processed (bool, optional):
                If True (default), prints progress (how many frames have been processed) to
                the console.

            webm (bool, optional):
                If True, exports a WebM video with alpha channel (VP9 / yuva420p).
                If False (default), exports an MP4 video composited over a solid color background.

            rgb_value (tuple, optional):
                The RGB background color (e.g., green screen) used to composite frames when
                saving to MP4. Defaults to (0, 255, 0).

        Returns:
            None. Writes the output video(s) to disk in the specified format.
        """

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise IOError(f"Cannot open video: {video_path}")

        original_fps = cap.get(cv2.CAP_PROP_FPS)
        original_fps = 30 if original_fps == 0 else original_fps
        fps = original_fps if fps == 0 else fps

        ret, first_frame = cap.read()
        if not ret:
            raise ValueError("No frames found in the video.")
        height, width = first_frame.shape[:2]
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

        foregrounds = []
        frame_idx = 0
        processed_count = 0
        batch_frames = []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        while True:
            ret, frame = cap.read()
            if not ret:
                # flush any remaining frames that did not fill a whole batch
                if batch_frames:
                    batch_results = self.inference(batch_frames, refine_foreground)
                    if isinstance(batch_results, Image.Image):
                        foregrounds.append(batch_results)
                    else:
                        foregrounds.extend(batch_results)
                    if print_frames_processed:
                        print(f"Processed frames {frame_idx-len(batch_frames)+1} to {frame_idx} of {total_frames}")
                break

            # Process every frame instead of using intervals
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_frame = Image.fromarray(frame_rgb)
            batch_frames.append(pil_frame)

            if len(batch_frames) == batch:
                batch_results = self.inference(batch_frames, refine_foreground)
                if isinstance(batch_results, Image.Image):
                    foregrounds.append(batch_results)
                else:
                    foregrounds.extend(batch_results)
                if print_frames_processed:
                    print(f"Processed frames {frame_idx-batch+1} to {frame_idx} of {total_frames}")
                batch_frames = []
                processed_count += batch

            frame_idx += 1

        cap.release()

        if webm:
            alpha_webm_path = os.path.join(output_path, "foreground.webm")
            pil_images_to_webm_alpha(foregrounds, alpha_webm_path, fps=original_fps)

        else:
            fg_output = os.path.join(output_path, 'foreground.mp4')

            pil_images_to_mp4(foregrounds, fg_output, fps=original_fps, rgb_value=rgb_value)
            cv2.destroyAllWindows()

            try:
                fg_audio_output = os.path.join(output_path, 'foreground_output_with_audio.mp4')
                add_audio_to_video(fg_output, video_path, fg_audio_output)
            except Exception as e:
                print("No audio found in the original video")
                print(e)

def rgb_loader_refiner(original_image):
    # NOTE: this helper is redefined at the bottom of the file; the later
    # definition is the one in effect after import.
    h, w = original_image.size

    image = original_image
    # Convert to RGB if necessary
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Resize the image
    image = image.resize((1024, 1024), resample=Image.LANCZOS)

    return image.convert('RGB'), h, w, original_image


# Define the image transformations (fp16 when CUDA is available, fp32 otherwise)
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float16),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

img_transform32 = transforms.Compose([
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def pil_images_to_mp4(images, output_path, fps=24, rgb_value=(0, 255, 0)):
    """
    Converts an array of PIL images to an MP4 video.

    Args:
        images: List of PIL images
        output_path: Path to save the MP4 file
        fps: Frames per second (default: 24)
        rgb_value: Background RGB color tuple (default: green (0, 255, 0))
    """
    if not images:
        raise ValueError("No images provided to convert to MP4.")

    width, height = images[0].size
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    for image in images:
        # If image has alpha channel, composite onto the specified background color
        if image.mode == 'RGBA':
            # Create background image with specified RGB color
            background = Image.new('RGB', image.size, rgb_value)
            background = background.convert('RGBA')
            # Composite the image onto the background
            image = Image.alpha_composite(background, image)
            image = image.convert('RGB')
        else:
            # Ensure RGB format for non-alpha images
            image = image.convert('RGB')

        # Convert to OpenCV format and write
        open_cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        video_writer.write(open_cv_image)

    video_writer.release()

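# Hedged usage sketch for the MP4 writer above. The frames and output path are
# hypothetical placeholders; the point is the expected input: a non-empty list
# of same-sized PIL images, where RGBA frames get composited over rgb_value.
def _mp4_writer_example():
    frames = [Image.new("RGBA", (64, 64), (255, 0, 0, 128)) for _ in range(24)]
    pil_images_to_mp4(frames, "example_foreground.mp4", fps=24, rgb_value=(0, 255, 0))
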
def pil_images_to_webm_alpha(images, output_path, fps=30):
    """
    Converts a list of PIL RGBA images to a VP9 .webm video with alpha channel.

    NOTE: Not all players will display alpha in WebM.
    Browsers like Chrome/Firefox typically do support VP9 alpha.
    """
    if not images:
        raise ValueError("No images provided for WebM with alpha.")

    # Ensure output directory exists
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Save frames as PNG (with alpha)
        for idx, img in enumerate(images):
            if img.mode != "RGBA":
                img = img.convert("RGBA")
            out_path = os.path.join(tmpdir, f"{idx:06d}.png")
            img.save(out_path, "PNG")

        # Construct ffmpeg command
        # -c:v libvpx-vp9   => VP9 encoder
        # -pix_fmt yuva420p => alpha-enabled pixel format
        # -auto-alt-ref 0   => helps preserve alpha frames (libvpx quirk)
        ffmpeg_cmd = [
            "ffmpeg", "-y",
            "-framerate", str(fps),
            "-i", os.path.join(tmpdir, "%06d.png"),
            "-c:v", "libvpx-vp9",
            "-pix_fmt", "yuva420p",
            "-auto-alt-ref", "0",
            output_path
        ]

        subprocess.run(ffmpeg_cmd, check=True)

    print(f"WebM with alpha saved to {output_path}")

def add_audio_to_video(video_without_audio_path, original_video_path, output_path):
    """
    Check if the original video has an audio stream. If yes, add it. If not, skip.
    """
    # 1) Probe original video for audio streams
    probe_command = [
        'ffprobe', '-v', 'error',
        '-select_streams', 'a:0',
        '-show_entries', 'stream=index',
        '-of', 'csv=p=0',
        original_video_path
    ]
    result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

    # result.stdout is empty if no audio stream found
    if not result.stdout.strip():
        print("No audio track found in original video, skipping audio addition.")
        return

    print("Audio track detected; proceeding to mux audio.")
    # 2) If audio found, run ffmpeg to add it
    command = [
        'ffmpeg', '-y',
        '-i', video_without_audio_path,
        '-i', original_video_path,
        '-c', 'copy',
        '-map', '0:v:0',
        '-map', '1:a:0',  # we know there's an audio track now
        output_path
    ]
    subprocess.run(command, check=True)
    print(f"Audio added successfully => {output_path}")

### Thanks to the source: https://huggingface.co/ZhengPeng7/BiRefNet/blob/main/handler.py
def refine_foreground_process(image, mask, r=90):
    if mask.size != image.size:
        mask = mask.resize(image.size)
    image = np.array(image) / 255.0
    mask = np.array(mask) / 255.0
    estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
    image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
    return image_masked


def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
    alpha = alpha[:, :, None]
    F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
    return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]


def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0
    blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]

    blurred_FA = cv2.blur(F * alpha, (r, r))
    blurred_F = blurred_FA / (blurred_alpha + 1e-5)

    blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
    F = blurred_F + alpha * \
        (image - alpha * blurred_F - (1 - alpha) * blurred_B)
    F = np.clip(F, 0, 1)
    return F, blurred_B

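
# Hedged usage sketch for the blur-fusion foreground estimator above. The image
# and mask here are synthetic placeholders; in this model they come from the
# original photo and the predicted alpha matte (see BEN_Base.inference).
def _refine_foreground_example():
    rgb = Image.new("RGB", (256, 256), (120, 180, 200))
    alpha = Image.new("L", (256, 256), 255)
    refined = refine_foreground_process(rgb, alpha, r=90)  # returns a PIL RGB image
    return refined.size
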

def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
    im_array = np.squeeze(im_array)
    return im_array


def rgb_loader_refiner(original_image):
    h, w = original_image.size

    # EXIF-transposed copy (currently unused: the un-transposed original is the
    # one resized below)
    image = ImageOps.exif_transpose(original_image)

    if original_image.mode != 'RGB':
        original_image = original_image.convert('RGB')

    image = original_image

    # Resize the image
    image = image.resize((1024, 1024), resample=Image.LANCZOS)

    return image, h, w, original_image

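
# Hedged end-to-end usage sketch. The checkpoint and image paths below are
# placeholders (loadcheckpoints expects a checkpoint dict containing a
# 'model_state_dict' key); this is an illustration, not a packaged CLI.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BEN_Base().to(device).eval()
    model.loadcheckpoints("./BEN2_Base.pth")  # hypothetical checkpoint path

    # Single-image foreground extraction: returns a PIL image with alpha.
    img = Image.open("./example_input.jpg")   # hypothetical input image
    foreground = model.inference(img, refine_foreground=False)
    foreground.save("./example_foreground.png")

    # Video segmentation: writes foreground.mp4 (or foreground.webm) to output_path.
    # model.segment_video("./example_input.mp4", output_path="./", webm=False)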