birefnet.py
| 1 | ### config.py |
| 2 | |
| 3 | import os |
| 4 | import math |
| 5 | from transformers import PretrainedConfig |
| 6 | |
| 7 | |
class Config(PretrainedConfig):
    """Central settings object for BiRefNet (task, model, training, backbone, paths).

    Most options are chosen by indexing a literal list of alternatives, e.g.
    ``[...][0]``; change the index to switch. Several attributes are derived from
    earlier ones (``training_set`` from ``task``, ``lr``/``num_workers`` from
    ``batch_size``, ``cxt`` from the backbone channel list), so statement order matters.
    """
    def __init__(self) -> None:
        # Compatible with the latest version of transformers
        super().__init__()

        # PATH settings
        self.sys_home_dir = os.path.expanduser('~')  # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx

        # TASK settings
        self.task = ['DIS5K', 'COD', 'HRSOD', 'DIS5K+HRSOD+HRS10K', 'P3M-10k'][0]
        # Training split name(s) for the selected task; '+'-separated names list several
        # training sets (presumably combined by the data loader -- confirm there).
        self.training_set = {
            'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0],
            'COD': 'TR-COD10K+TR-CAMO',
            'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5],
            'DIS5K+HRSOD+HRS10K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TE-HRS10K+TE-HRSOD+TE-UHRSD+TR-HRS10K+TR-HRSOD+TR-UHRSD',  # leave DIS-VD for evaluation.
            'P3M-10k': 'TR-P3M-10k',
        }[self.task]
        self.prompt4loc = ['dense', 'sparse'][0]

        # Faster-Training settings
        self.load_all = True
        self.compile = True  # 1. Triggers a CPU memory leak to some extent, which is an inherent problem of PyTorch.
        #    Machines with > 70GB CPU memory can run the whole training on DIS5K with default setting.
        # 2. Higher PyTorch version may fix it: https://github.com/pytorch/pytorch/issues/119607.
        # 3. But compile in Pytorch > 2.0.1 seems to bring no acceleration for training.
        self.precisionHigh = True

        # MODEL settings
        self.ms_supervision = True
        self.out_ref = self.ms_supervision and True
        self.dec_ipt = True
        self.dec_ipt_split = True
        self.cxt_num = [0, 3][1]  # multi-scale skip connections from encoder
        self.mul_scl_ipt = ['', 'add', 'cat'][2]
        self.dec_att = ['', 'ASPP', 'ASPPDeformable'][2]
        self.squeeze_block = ['', 'BasicDecBlk_x1', 'ResBlk_x4', 'ASPP_x3', 'ASPPDeformable_x3'][1]
        self.dec_blk = ['BasicDecBlk', 'ResBlk', 'HierarAttDecBlk'][0]

        # TRAINING settings
        self.batch_size = 4
        self.IoU_finetune_last_epochs = [
            0,
            {
                'DIS5K': -50,
                'COD': -20,
                'HRSOD': -20,
                'DIS5K+HRSOD+HRS10K': -20,
                'P3M-10k': -20,
            }[self.task]
        ][1]  # choose 0 to skip
        # Base lr is 1e-4 for DIS tasks (DIS needs a high lr to converge faster) and 1e-5 otherwise,
        # then scaled by sqrt(batch_size / 4): square-root (not linear) scaling w.r.t. a reference batch of 4.
        self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4)
        self.size = 1024
        self.num_workers = max(4, self.batch_size)  # will be decreased to min(it, batch_size) at the initialization of the data_loader

        # Backbone settings
        self.bb = [
            'vgg16', 'vgg16bn', 'resnet50',  # 0, 1, 2
            'swin_v1_t', 'swin_v1_s',  # 3, 4
            'swin_v1_b', 'swin_v1_l',  # 5-bs9, 6-bs4
            'pvt_v2_b0', 'pvt_v2_b1',  # 7, 8
            'pvt_v2_b2', 'pvt_v2_b5',  # 9-bs10, 10-bs5
        ][6]
        # Encoder channel counts per backbone, listed largest-first
        # (presumably deepest stage first -- confirm against the backbone builders).
        self.lateral_channels_in_collection = {
            'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
            'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
            'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
            'swin_v1_t': [768, 384, 192, 96], 'swin_v1_s': [768, 384, 192, 96],
            'pvt_v2_b0': [256, 160, 64, 32], 'pvt_v2_b1': [512, 320, 128, 64],
        }[self.bb]
        if self.mul_scl_ipt == 'cat':
            # 'cat' multi-scale input doubles every channel count (presumably features are concatenated).
            self.lateral_channels_in_collection = [channel * 2 for channel in self.lateral_channels_in_collection]
        # Skip-connection channels: drop the first (largest) entry, reverse to ascending order,
        # and keep the last `cxt_num` values; empty when cxt_num == 0.
        self.cxt = self.lateral_channels_in_collection[1:][::-1][-self.cxt_num:] if self.cxt_num else []

        # MODEL settings - inactive
        self.lat_blk = ['BasicLatBlk'][0]
        self.dec_channels_inter = ['fixed', 'adap'][0]
        self.refine = ['', 'itself', 'RefUNet', 'Refiner', 'RefinerPVTInChannels4'][0]
        # With refine == '', the `and` chains below evaluate to '' (falsy), not to False / 0.
        self.progressive_ref = self.refine and True
        self.ender = self.progressive_ref and False
        self.scale = self.progressive_ref and 2
        self.auxiliary_classification = False  # Only for DIS5K, where class labels are saved in `dataset.py`.
        self.refine_iteration = 1
        self.freeze_bb = False
        self.model = [
            'BiRefNet',
        ][0]
        if self.dec_blk == 'HierarAttDecBlk':
            # NOTE(review): this override runs after `lr` and `num_workers` were derived from
            # batch_size above, so they are not recomputed for the new value -- confirm intended.
            self.batch_size = 2 ** [0, 1, 2, 3, 4][2]

        # TRAINING settings - inactive
        self.preproc_methods = ['flip', 'enhance', 'rotate', 'pepper', 'crop'][:4]
        self.optimizer = ['Adam', 'AdamW'][1]
        self.lr_decay_epochs = [1e5]  # Set to negative N to decay the lr in the last N-th epoch.
        self.lr_decay_rate = 0.5
        # Loss
        self.lambdas_pix_last = {
            # not 0 means opening this loss (the trailing `* 0` / `* 1` factor toggles each term)
            # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30
            'bce': 30 * 1,  # high performance
            'iou': 0.5 * 1,  # 0 / 255
            'iou_patch': 0.5 * 0,  # 0 / 255, win_size = (64, 64)
            'mse': 150 * 0,  # can smooth the saliency map
            'triplet': 3 * 0,
            'reg': 100 * 0,
            'ssim': 10 * 1,  # help contours,
            'cnt': 5 * 0,  # help contours
            'structure': 5 * 0,  # structure loss from codes of MVANet. A little improvement on DIS-TE[1,2,3], a bit more decrease on DIS-TE4.
        }
        self.lambdas_cls = {
            'ce': 5.0
        }
        # Adv
        self.lambda_adv_g = 10. * 0  # turn to 0 to avoid adv training
        self.lambda_adv_d = 3. * (self.lambda_adv_g > 0)  # 3.0 only when the generator adv weight is positive, else 0.0

        # PATH settings - inactive
        self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis')
        self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights')
        self.weights = {
            'pvt_v2_b2': os.path.join(self.weights_root_dir, 'pvt_v2_b2.pth'),
            'pvt_v2_b5': os.path.join(self.weights_root_dir, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]),
            'swin_v1_b': os.path.join(self.weights_root_dir, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]),
            'swin_v1_l': os.path.join(self.weights_root_dir, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]),
            'swin_v1_t': os.path.join(self.weights_root_dir, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]),
            'swin_v1_s': os.path.join(self.weights_root_dir, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]),
            'pvt_v2_b0': os.path.join(self.weights_root_dir, ['pvt_v2_b0.pth'][0]),
            'pvt_v2_b1': os.path.join(self.weights_root_dir, ['pvt_v2_b1.pth'][0]),
        }

        # Callbacks - inactive
        self.verbose_eval = True
        self.only_S_MAE = False
        self.use_fp16 = False  # Bugs. It may cause nan in training.
        self.SDPA_enabled = False  # Bugs. Slower and errors occur in multi-GPUs

        # others
        self.device = [0, 'cpu'][0]  # .to(0) == .to('cuda:0')

        self.batch_size_valid = 1
        self.rand_seed = 7
        # run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f]
        # with open(run_sh_file[0], 'r') as f:
        #     lines = f.readlines()
        #     self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0])
        #     self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0])
        #     self.val_step = [0, self.save_step][0]

    def print_task(self) -> None:
        # Print (not return) the task name so shell scripts can capture it to choose settings.
        print(self.task)
| 158 | |
| 159 | |
| 160 | |
| 161 | ### models/backbones/pvt_v2.py |
| 162 | |
| 163 | import torch |
| 164 | import torch.nn as nn |
| 165 | from functools import partial |
| 166 | |
| 167 | from timm.layers import DropPath, to_2tuple, trunc_normal_ |
| 168 | |
| 169 | |
| 170 | import math |
| 171 | |
| 172 | # from config import Config |
| 173 | |
| 174 | # config = Config() |
| 175 | |
def _init_pvt_weights(m):
    """Initialise one sub-module the way every PVTv2 component in this section does.

    Linear: truncated-normal weight (std 0.02), zero bias. LayerNorm: weight 1, bias 0.
    Conv2d: normal weight with std sqrt(2 / fan_out), where
    fan_out = kernel_h * kernel_w * out_channels // groups, and zero bias.
    Any other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()


class Mlp(nn.Module):
    """PVTv2 feed-forward block: Linear -> depth-wise conv (DWConv) -> activation -> Linear.

    Works on token sequences (B, N, C); ``H`` and ``W`` are passed through so ``DWConv``
    can fold the N = H * W tokens back into a feature map.

    Args:
        in_features (int): Input channels.
        hidden_features (int | None): Hidden channels; defaults to ``in_features``.
        out_features (int | None): Output channels; defaults to ``in_features``.
        act_layer: Activation class, instantiated without arguments.
        drop (float): Dropout probability applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        _init_pvt_weights(m)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    """Spatial-reduction multi-head self-attention (PVTv2).

    Queries are computed from every token. When ``sr_ratio > 1``, keys/values come from the
    feature map downsampled by a ``sr_ratio`` x ``sr_ratio`` strided conv followed by LayerNorm.

    When ``config.SDPA_enabled`` is set, the fused path matches the explicit path:
    the effective qk scale is ``self.scale`` (including a custom ``qk_scale``), and
    attention dropout is only active in training mode.

    Args:
        dim (int): Channels; must be divisible by ``num_heads``.
        num_heads (int): Number of attention heads.
        qkv_bias (bool): Bias on the q / kv projections.
        qk_scale (float | None): Overrides the default ``head_dim ** -0.5`` scale.
        attn_drop (float): Dropout on attention weights.
        proj_drop (float): Dropout after the output projection.
        sr_ratio (int): Spatial-reduction factor for keys/values.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop_prob = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        _init_pvt_weights(m)

    def forward(self, x, H, W):
        """Attend over tokens ``x`` of shape (B, N, C), with N == H * W; returns (B, N, C)."""
        B, N, C = x.shape
        head_dim = C // self.num_heads
        q = self.q(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        if config.SDPA_enabled:
            # SDPA multiplies the logits by head_dim ** -0.5 itself; rescale q so the net factor is self.scale.
            q_factor = self.scale * head_dim ** 0.5
            if abs(q_factor - 1.0) > 1e-6:
                q = q * q_factor
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                # SDPA applies dropout whenever dropout_p > 0, regardless of train/eval mode.
                dropout_p=self.attn_drop_prob if self.training else 0.,
                is_causal=False,
            ).transpose(1, 2).reshape(B, N, C)
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):
    """PVTv2 encoder block: ``x + DropPath(Attention(norm1(x)))`` then ``x + DropPath(Mlp(norm2(x)))``.

    Args:
        dim (int): Token channels.
        num_heads (int): Attention heads.
        mlp_ratio (float): Hidden width of the Mlp relative to ``dim``.
        qkv_bias, qk_scale, attn_drop, sr_ratio: Forwarded to ``Attention``.
        drop (float): Dropout inside the Mlp and after the attention projection.
        drop_path (float): Stochastic-depth rate; ``nn.Identity`` is used when it is 0.
        act_layer: Activation class for the Mlp.
        norm_layer: Normalisation layer factory.
    """

    # Bind the PVT ``Mlp`` defined above at class-definition time. This file later defines a
    # second, Swin-style ``Mlp`` at module level whose forward takes only ``x``; resolving the
    # global name in __init__ would pick that one up and make ``self.mlp(x, H, W)`` fail.
    _mlp_layer = Mlp

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = self._mlp_layer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        _init_pvt_weights(m)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x
| 321 | |
| 322 | |
class OverlapPatchEmbed(nn.Module):
    """Image-to-token embedding with overlapping patches (PVTv2).

    A ``patch_size`` x ``patch_size`` convolution with the given ``stride`` and
    ``patch_size // 2`` padding projects the input to ``embed_dim`` channels. The map is
    then flattened into a (B, H*W, embed_dim) token sequence and layer-normalised.
    ``forward`` returns ``(tokens, H, W)``, where H and W come from the actual conv output.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        # Nominal grid size derived from img_size / patch_size; forward() does not read these.
        self.H = img_size[0] // patch_size[0]
        self.W = img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        half_kernel = (patch_size[0] // 2, patch_size[1] // 2)
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride, padding=half_kernel)
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Conv2d / LayerNorm / Linear sub-modules; other module types are left alone."""
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        feat = self.proj(x)
        H, W = feat.shape[2], feat.shape[3]
        tokens = self.norm(feat.flatten(2).transpose(1, 2))
        return tokens, H, W
| 364 | |
| 365 | |
class PyramidVisionTransformerImpr(nn.Module):
    """PVTv2 backbone: four stages of overlapping patch embedding + transformer blocks + norm.

    ``forward`` returns a list with one feature map per stage, each laid out as
    (B, embed_dims[i], H_i, W_i). No classification head is built; ``num_classes`` is
    only stored (``reset_classifier`` can attach a head, which ``forward`` does not use).

    Args:
        img_size (int): Nominal input size forwarded to the patch embeddings.
        patch_size (int): Unused here (the stage embeddings use fixed 7/3/3/3 kernels); kept for API compatibility.
        in_channels (int): Channels of the input image.
        num_classes (int): Stored on the instance.
        embed_dims, num_heads, mlp_ratios, depths, sr_ratios (sequence of 4): Per-stage settings.
        qkv_bias, qk_scale, drop_rate, attn_drop_rate: Forwarded to every ``Block``.
        drop_path_rate (float): Largest stochastic-depth rate; rates rise linearly from 0 across all blocks.
        norm_layer: Normalisation layer factory used inside blocks and after each stage.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dims=(64, 128, 256, 512),
                 num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1)):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.embed_dims = embed_dims  # kept so reset_classifier can size a new head

        # patch_embed: stage 1 uses stride 4, stages 2-4 use stride 2.
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_channels=in_channels,
                                              embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_channels=embed_dims[0],
                                              embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_channels=embed_dims[1],
                                              embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_channels=embed_dims[2],
                                              embed_dim=embed_dims[3])

        # transformer encoder
        dpr = self._drop_path_schedule(drop_path_rate, depths)  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])
            for i in range(depths[0])])
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[1])
            for i in range(depths[1])])
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[2])
            for i in range(depths[2])])
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[3])
            for i in range(depths[3])])
        self.norm4 = norm_layer(embed_dims[3])

        self.apply(self._init_weights)

    @staticmethod
    def _drop_path_schedule(drop_path_rate, depths):
        """Return one stochastic-depth rate per block, spaced linearly from 0 to ``drop_path_rate``."""
        import numpy as np  # local import: the imports of this section of the file do not include numpy
        return np.linspace(0, drop_path_rate, sum(depths)).tolist()

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def init_weights(self, pretrained=None):
        """No-op: checkpoint loading from ``pretrained`` is disabled in this copy of the backbone."""
        # The original implementation called an external `load_checkpoint(...)` here.

    def reset_drop_path(self, drop_path_rate):
        """Re-assign ``drop_prob`` of every block's drop-path from a new linear schedule.

        Blocks built with a rate of 0 hold ``nn.Identity`` instead of ``DropPath``; setting
        ``drop_prob`` on those has no effect.
        """
        dpr = self._drop_path_schedule(drop_path_rate, self.depths)
        blocks = [*self.block1, *self.block2, *self.block3, *self.block4]
        for blk, prob in zip(blocks, dpr):
            blk.drop_path.drop_prob = prob

    def freeze_patch_emb(self):
        """Stop gradient updates for the first patch embedding's parameters."""
        # `requires_grad_` recurses into parameters; assigning `.requires_grad` on the module
        # would only create a plain attribute and freeze nothing.
        self.patch_embed1.requires_grad_(False)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def get_classifier(self):
        """Return the attached head, or a fresh ``nn.Identity`` when none has been attached."""
        head = getattr(self, 'head', None)
        return head if head is not None else nn.Identity()

    def reset_classifier(self, num_classes, global_pool=''):
        """Attach a new linear head on the last-stage channels (``nn.Identity`` when ``num_classes <= 0``)."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Run the four stages and return their outputs, each reshaped to (B, C, H, W)."""
        B = x.shape[0]
        outs = []
        stages = (
            (self.patch_embed1, self.block1, self.norm1),
            (self.patch_embed2, self.block2, self.norm2),
            (self.patch_embed3, self.block3, self.norm3),
            (self.patch_embed4, self.block4, self.norm4),
        )
        for patch_embed, blocks, norm in stages:
            x, H, W = patch_embed(x)
            for blk in blocks:
                x = blk(x, H, W)
            x = norm(x)
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)
        return outs

    def forward(self, x):
        return self.forward_features(x)
| 521 | |
| 522 | |
class DWConv(nn.Module):
    """Depth-wise 3x3 convolution applied to a token sequence laid out on an H x W grid.

    Args:
        dim (int): Channels of the tokens; the conv uses ``groups=dim`` (one filter per channel).
    """

    def __init__(self, dim=768):
        super().__init__()
        # kernel 3, stride 1, padding 1 keeps the spatial size unchanged.
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Map tokens (B, H*W, C) -> (B, H*W, C) through the depth-wise conv."""
        batch, _, channels = x.shape
        feature_map = x.transpose(1, 2).view(batch, channels, H, W).contiguous()
        return self.dwconv(feature_map).flatten(2).transpose(1, 2)
| 535 | |
| 536 | |
| 537 | def _conv_filter(state_dict, patch_size=16): |
| 538 | """ convert patch embedding weight from manual patchify + linear proj to conv""" |
| 539 | out_dict = {} |
| 540 | for k, v in state_dict.items(): |
| 541 | if 'patch_embed.proj.weight' in k: |
| 542 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) |
| 543 | out_dict[k] = v |
| 544 | |
| 545 | return out_dict |
| 546 | |
| 547 | |
class pvt_v2_b0(PyramidVisionTransformerImpr):
    """PVTv2-B0: embed dims [32, 64, 160, 256], depths [2, 2, 2, 2].

    Args:
        in_channels (int): Input image channels, forwarded to the encoder (same as ``pvt_v2_b2``).
        **kwargs: Accepted for call-site compatibility and ignored.
    """
    def __init__(self, in_channels=3, **kwargs):
        super().__init__(
            patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
| 554 | |
| 555 | |
| 556 | |
class pvt_v2_b1(PyramidVisionTransformerImpr):
    """PVTv2-B1: embed dims [64, 128, 320, 512], depths [2, 2, 2, 2].

    Args:
        in_channels (int): Input image channels, forwarded to the encoder (same as ``pvt_v2_b2``).
        **kwargs: Accepted for call-site compatibility and ignored.
    """
    def __init__(self, in_channels=3, **kwargs):
        super().__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
| 563 | |
class pvt_v2_b2(PyramidVisionTransformerImpr):
    """PVTv2-B2: embed dims [64, 128, 320, 512], depths [3, 4, 6, 3].

    Args:
        in_channels (int): Input image channels, forwarded to the encoder.
        **kwargs: Accepted for call-site compatibility and ignored.
    """
    def __init__(self, in_channels=3, **kwargs):
        settings = dict(
            patch_size=4,
            embed_dims=[64, 128, 320, 512],
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[3, 4, 6, 3],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
            drop_path_rate=0.1,
        )
        super().__init__(in_channels=in_channels, **settings)
| 570 | |
class pvt_v2_b3(PyramidVisionTransformerImpr):
    """PVTv2-B3: embed dims [64, 128, 320, 512], depths [3, 4, 18, 3].

    Args:
        in_channels (int): Input image channels, forwarded to the encoder (same as ``pvt_v2_b2``).
        **kwargs: Accepted for call-site compatibility and ignored.
    """
    def __init__(self, in_channels=3, **kwargs):
        super().__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
| 577 | |
class pvt_v2_b4(PyramidVisionTransformerImpr):
    """PVTv2-B4: embed dims [64, 128, 320, 512], depths [3, 8, 27, 3].

    Args:
        in_channels (int): Input image channels, forwarded to the encoder (same as ``pvt_v2_b2``).
        **kwargs: Accepted for call-site compatibility and ignored.
    """
    def __init__(self, in_channels=3, **kwargs):
        super().__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
| 584 | |
| 585 | |
class pvt_v2_b5(PyramidVisionTransformerImpr):
    """PVTv2-B5: embed dims [64, 128, 320, 512], depths [3, 6, 40, 3], mlp ratio 4 in every stage.

    Args:
        in_channels (int): Input image channels, forwarded to the encoder (same as ``pvt_v2_b2``).
        **kwargs: Accepted for call-site compatibility and ignored.
    """
    def __init__(self, in_channels=3, **kwargs):
        super().__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
| 592 | |
| 593 | |
| 594 | |
| 595 | ### models/backbones/swin_v1.py |
| 596 | |
| 597 | # -------------------------------------------------------- |
| 598 | # Swin Transformer |
| 599 | # Copyright (c) 2021 Microsoft |
| 600 | # Licensed under The MIT License [see LICENSE for details] |
| 601 | # Written by Ze Liu, Yutong Lin, Yixuan Wei |
| 602 | # -------------------------------------------------------- |
| 603 | |
| 604 | import torch |
| 605 | import torch.nn as nn |
| 606 | import torch.nn.functional as F |
| 607 | import torch.utils.checkpoint as checkpoint |
| 608 | import numpy as np |
| 609 | from timm.layers import DropPath, to_2tuple, trunc_normal_ |
| 610 | |
| 611 | # from config import Config |
| 612 | |
| 613 | |
| 614 | # config = Config() |
| 615 | |
| 616 | |
class Mlp(nn.Module):
    """Two-layer perceptron used in Swin blocks: fc1 -> act -> dropout -> fc2 -> dropout.

    Args:
        in_features (int): Input channels.
        hidden_features (int | None): Hidden width; falls back to ``in_features`` when falsy.
        out_features (int | None): Output channels; falls back to ``in_features`` when falsy.
        act_layer: Activation class, instantiated without arguments.
        drop (float): Dropout probability used after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        output = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, output)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
| 636 | |
| 637 | |
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C); H and W must be multiples of ``window_size``.
        window_size (int): window side length

    Returns:
        windows: (num_windows*B, window_size, window_size, C), ordered by batch, then
        window row, then window column.
    """
    batch, height, width, channels = x.shape
    tiles = x.view(batch, height // window_size, window_size, width // window_size, window_size, channels)
    # Put the two window-grid axes next to each other, then fold them into the leading axis.
    grouped = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grouped.view(-1, window_size, window_size, channels)
| 651 | |
| 652 | |
def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by ``window_partition`` into full feature maps.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    grid = windows.view(batch, H // window_size, W // window_size, window_size, window_size, -1)
    # Interleave window rows with in-window rows (and likewise columns) to restore the layout.
    restored = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return restored.view(batch, H, W, -1)
| 668 | |
| 669 | |
class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    When ``config.SDPA_enabled`` is set, the fused path computes the same result as the
    explicit path:
      - the relative position bias and the optional shifted-window mask are passed to
        ``scaled_dot_product_attention`` as one additive ``attn_mask``;
      - the qk scale is applied exactly once;
      - attention dropout is only applied in training mode.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop_prob = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def _relative_position_bias(self):
        """Gather the learned bias table into a (num_heads, Wh*Ww, Wh*Ww) tensor."""
        num_tokens = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            num_tokens, num_tokens, -1
        )  # Wh*Ww, Wh*Ww, nH
        return bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

    def forward(self, x, mask=None):
        """ Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

        Returns:
            Tensor of shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        relative_position_bias = self._relative_position_bias()

        if config.SDPA_enabled:
            # One additive mask: relative bias (+ shifted-window mask), broadcastable to (B_, nH, N, N).
            attn_bias = relative_position_bias.unsqueeze(0)  # 1, nH, N, N
            if mask is not None:
                nW = mask.shape[0]
                # Same window ordering as the explicit path below: row index b * nW + w.
                attn_bias = attn_bias.unsqueeze(1) + mask.unsqueeze(1).unsqueeze(0)  # 1, nW, nH, N, N
                attn_bias = attn_bias.expand(B_ // nW, -1, -1, -1, -1).reshape(B_, self.num_heads, N, N)
            # SDPA multiplies the logits by head_dim ** -0.5 itself; rescale q so the net factor is self.scale.
            q_factor = self.scale * (C // self.num_heads) ** 0.5
            if abs(q_factor - 1.0) > 1e-6:
                q = q * q_factor
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_bias.to(q.dtype),
                # SDPA applies dropout whenever dropout_p > 0, regardless of train/eval mode.
                dropout_p=self.attn_drop_prob if self.training else 0.,
                is_causal=False,
            ).transpose(1, 2).reshape(B_, N, C)
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))
            attn = attn + relative_position_bias.unsqueeze(0)

            if mask is not None:
                nW = mask.shape[0]
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
            attn = self.attn_drop(attn)

            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
| 760 | |
| 761 | |
class SwinTransformerBlock(nn.Module):
    """ Swin Transformer Block.

    Pre-norm window attention (regular or cyclically shifted windows) followed
    by a pre-norm MLP, each wrapped in a residual connection with optional
    stochastic depth.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA; 0 disables the cyclic shift.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm

    The caller must set ``self.H`` and ``self.W`` (spatial size of the token
    grid) before calling ``forward``.
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim, self.num_heads = dim, num_heads
        self.window_size, self.shift_size = window_size, shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        # Sub-modules are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

        # Spatial resolution of the incoming token grid, assigned externally.
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C), with H and W taken from
                ``self.H`` / ``self.W``.
            mask_matrix: Attention mask for cyclic shift; only used when
                ``shift_size > 0``.

        Returns:
            Tensor of size (B, H*W, C).
        """
        batch, length, channels = x.shape
        height, width = self.H, self.W
        assert length == height * width, "input feature has wrong size"
        ws = self.window_size
        shifted = self.shift_size > 0

        residual = x
        grid = self.norm1(x).view(batch, height, width, channels)

        # Pad right/bottom so both spatial dims become multiples of the window size.
        pad_right = (ws - width % ws) % ws
        pad_bottom = (ws - height % ws) % ws
        grid = F.pad(grid, (0, 0, 0, pad_right, 0, pad_bottom))
        _, padded_h, padded_w, _ = grid.shape

        if shifted:
            grid = torch.roll(grid, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

        # (nW*B, ws*ws, C) tokens per window -> attention -> back to a (B, Hp, Wp, C) grid.
        windows = window_partition(grid, ws).view(-1, ws * ws, channels)
        windows = self.attn(windows, mask=mask_matrix if shifted else None)
        grid = window_reverse(windows.view(-1, ws, ws, channels), ws, padded_h, padded_w)

        if shifted:
            grid = torch.roll(grid, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

        if pad_right > 0 or pad_bottom > 0:
            grid = grid[:, :height, :width, :].contiguous()

        tokens = residual + self.drop_path(grid.view(batch, height * width, channels))
        return tokens + self.drop_path(self.mlp(self.norm2(tokens)))
| 862 | |
| 863 | |
class PatchMerging(nn.Module):
    """ Patch Merging Layer.

    Halves the spatial resolution of a token grid by concatenating every 2x2
    neighbourhood along the channel axis (C -> 4C), normalising, and linearly
    projecting to 2C.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            Tensor of size (B, ceil(H/2)*ceil(W/2), 2*C).
        """
        batch, length, channels = x.shape
        assert length == H * W, "input feature has wrong size"

        grid = x.view(batch, H, W, channels)

        # Odd sizes get one row/column of zero padding at the bottom/right.
        odd_h, odd_w = H % 2, W % 2
        if odd_h or odd_w:
            grid = F.pad(grid, (0, 0, 0, odd_w, 0, odd_h))

        # Sub-grid order: (even row, even col), (odd row, even col),
        # (even row, odd col), (odd row, odd col).
        quads = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(quads, dim=-1).view(batch, -1, 4 * channels)

        return self.reduction(self.norm(merged))
| 905 | |
| 906 | |
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate, either
            shared by all blocks or given per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # A list or tuple supplies one rate per block; a scalar is shared by all
        # blocks. A tuple must be indexed here, otherwise SwinTransformerBlock's
        # `drop_path > 0.` comparison receives the whole tuple and fails.
        per_block_drop_path = isinstance(drop_path, (list, tuple))

        # build blocks: even-indexed blocks use regular windows, odd-indexed
        # blocks use windows shifted by window_size // 2.
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if per_block_drop_path else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def _compute_attn_mask(self, x, H, W):
        """Build the additive SW-MSA attention mask for an H x W token grid.

        Returns a tensor of shape (nW, window_size*window_size, window_size*window_size)
        holding 0 for token pairs from the same shifted region and -100 otherwise,
        on ``x.device`` and cast to ``x.dtype``.
        """
        # Turn int to torch.tensor for the compatibility with torch.compile in PyTorch 2.5.
        Hp = torch.ceil(torch.tensor(H) / self.window_size).to(torch.int64) * self.window_size
        Wp = torch.ceil(torch.tensor(W) / self.window_size).to(torch.int64) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        # The same three bands are used for rows and columns.
        region_slices = (slice(0, -self.window_size),
                         slice(-self.window_size, -self.shift_size),
                         slice(-self.shift_size, None))
        cnt = 0
        for h in region_slices:
            for w in region_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)).to(x.dtype)

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            (x_out, H, W, x_next, Wh, Ww): the stage output at resolution H x W and
            the input for the next stage (downsampled to ceil(H/2) x ceil(W/2) when
            a downsample layer is present, otherwise the stage output again).
        """
        attn_mask = self._compute_attn_mask(x, H, W)

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        return x, H, W, x, H, W
| 1010 | |
| 1011 | |
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding.

    Pads the image on the right/bottom to a multiple of the patch size, then
    embeds each non-overlapping patch with a strided convolution, optionally
    followed by a channel-wise norm.

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_channels (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        """Embed a (B, C, H, W) image into a (B, embed_dim, ceil(H/ph), ceil(W/pw)) map."""
        patch_h, patch_w = self.patch_size
        _, _, height, width = x.size()
        if width % patch_w != 0:
            x = F.pad(x, (0, patch_w - width % patch_w))
        if height % patch_h != 0:
            x = F.pad(x, (0, 0, 0, patch_h - height % patch_h))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is None:
            return x

        # Normalise over channels in token layout, then restore the spatial map.
        grid_h, grid_w = x.size(2), x.size(3)
        tokens = self.norm(x.flatten(2).transpose(1, 2))
        return tokens.transpose(1, 2).view(-1, self.embed_dim, grid_h, grid_w)
| 1053 | |
| 1054 | |
class SwinTransformer(nn.Module):
    """ Swin Transformer backbone.
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
          https://arxiv.org/pdf/2103.14030

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_channels (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (Sequence[int]): Depths of each Swin Transformer stage. Default: (2, 2, 6, 2).
        num_heads (Sequence[int]): Number of attention head of each stage. Default: (3, 6, 12, 24).
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_channels=3,
                 embed_dim=96,
                 depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]

            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth: rates grow linearly from 0 to drop_path_rate over all blocks
        dpr = np.linspace(0, drop_path_rate, sum(depths)).tolist()

        # build layers; every stage but the last ends with a PatchMerging downsample
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put the first `frozen_stages` parts in eval mode and stop their gradients."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Return a tuple with one (B, C_i, H_i, W_i) feature map per stage in `out_indices`."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed)  # B C Wh Ww

        outs = []
        x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keep layers freezed.

        Returns ``self`` like ``nn.Module.train`` so that ``model.train()`` /
        ``model.eval()`` (which delegates here) can be chained or assigned.
        """
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()
        return self
| 1211 | |
def swin_v1_t(**kwargs):
    """Build a Swin-T backbone (embed_dim 96, depths 2-2-6-2, heads 3-6-12-24, window 7).

    Keyword arguments are forwarded to ``SwinTransformer`` and override the
    preset values (e.g. ``use_checkpoint=True``), so a non-empty
    ``params_settings`` string in ``build_backbone`` can reach the model.
    """
    params = dict(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7)
    params.update(kwargs)
    return SwinTransformer(**params)
| 1215 | |
def swin_v1_s(**kwargs):
    """Build a Swin-S backbone (embed_dim 96, depths 2-2-18-2, heads 3-6-12-24, window 7).

    Keyword arguments are forwarded to ``SwinTransformer`` and override the
    preset values, so a non-empty ``params_settings`` string in
    ``build_backbone`` can reach the model.
    """
    params = dict(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7)
    params.update(kwargs)
    return SwinTransformer(**params)
| 1219 | |
def swin_v1_b(**kwargs):
    """Build a Swin-B backbone (embed_dim 128, depths 2-2-18-2, heads 4-8-16-32, window 12).

    Keyword arguments are forwarded to ``SwinTransformer`` and override the
    preset values, so a non-empty ``params_settings`` string in
    ``build_backbone`` can reach the model.
    """
    params = dict(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
    params.update(kwargs)
    return SwinTransformer(**params)
| 1223 | |
def swin_v1_l(**kwargs):
    """Build a Swin-L backbone (embed_dim 192, depths 2-2-18-2, heads 6-12-24-48, window 12).

    Keyword arguments are forwarded to ``SwinTransformer`` and override the
    preset values, so a non-empty ``params_settings`` string in
    ``build_backbone`` can reach the model.
    """
    params = dict(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12)
    params.update(kwargs)
    return SwinTransformer(**params)
| 1227 | |
| 1228 | |
| 1229 | |
| 1230 | ### models/modules/deform_conv.py |
| 1231 | |
| 1232 | import torch |
| 1233 | import torch.nn as nn |
| 1234 | from torchvision.ops import deform_conv2d |
| 1235 | |
| 1236 | |
class DeformableConv2d(nn.Module):
    """Modulated deformable convolution built on ``torchvision.ops.deform_conv2d``.

    Two auxiliary convolutions predict, from the input itself, 2 sampling
    offsets per kernel tap and one modulation weight per kernel tap
    (``2 * sigmoid(.)``, so in the range (0, 2)). Both auxiliary convs are
    zero-initialised, so initially the offsets are 0 and the modulation is 1.
    The actual filter weights and bias live in ``regular_conv``.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        kernel_size: int or tuple kernel size. Default: 3.
        stride: int or tuple stride. Default: 1.
        padding: padding shared by all three convs and deform_conv2d. Default: 1.
        bias: whether ``regular_conv`` has a bias. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=False):

        super(DeformableConv2d, self).__init__()

        assert type(kernel_size) in (tuple, int)

        if type(kernel_size) != tuple:
            kernel_size = (kernel_size, kernel_size)
        self.stride = stride if type(stride) == tuple else (stride, stride)
        self.padding = padding
        taps = kernel_size[0] * kernel_size[1]
        geometry = dict(kernel_size=kernel_size, stride=stride, padding=self.padding)

        self.offset_conv = nn.Conv2d(in_channels, 2 * taps, bias=True, **geometry)
        self.modulator_conv = nn.Conv2d(in_channels, taps, bias=True, **geometry)
        self.regular_conv = nn.Conv2d(in_channels, out_channels=out_channels, bias=bias, **geometry)

        for aux in (self.offset_conv, self.modulator_conv):
            nn.init.constant_(aux.weight, 0.)
            nn.init.constant_(aux.bias, 0.)

    def forward(self, x):
        sampling_offsets = self.offset_conv(x)
        modulation = torch.sigmoid(self.modulator_conv(x)) * 2.
        return deform_conv2d(
            input=x,
            offset=sampling_offsets,
            weight=self.regular_conv.weight,
            bias=self.regular_conv.bias,
            padding=self.padding,
            mask=modulation,
            stride=self.stride,
        )
| 1298 | |
| 1299 | |
| 1300 | |
| 1301 | |
| 1302 | ### utils.py |
| 1303 | |
| 1304 | import torch.nn as nn |
| 1305 | |
| 1306 | |
def build_act_layer(act_layer):
    """Return a new activation module for the name ``act_layer``.

    Supported names are 'ReLU' and 'SiLU' (both created with ``inplace=True``)
    and 'GELU'.

    Raises:
        NotImplementedError: for any other name.
    """
    candidates = (
        ('ReLU', lambda: nn.ReLU(inplace=True)),
        ('SiLU', lambda: nn.SiLU(inplace=True)),
        ('GELU', nn.GELU),
    )
    for name, make in candidates:
        if act_layer == name:
            return make()
    raise NotImplementedError(f'build_act_layer does not support {act_layer}')
| 1316 | |
| 1317 | |
def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    """Build a normalisation layer wrapped with any needed axis permutations.

    Args:
        dim: number of channels to normalise.
        norm_layer: 'BN' (``nn.BatchNorm2d``, works channels-first) or
            'LN' (``nn.LayerNorm`` with ``eps``, works channels-last).
        in_format / out_format: 'channels_first' or 'channels_last' layout of
            the incoming / outgoing tensor; permutation modules are inserted
            around the norm when the layout differs from what it needs.
        eps: epsilon for LayerNorm (not passed to BatchNorm2d).

    Returns:
        ``nn.Sequential`` of the permutations and the norm.

    Raises:
        NotImplementedError: for an unknown ``norm_layer``.
    """
    if norm_layer == 'BN':
        layers = [nn.BatchNorm2d(dim)]
        if in_format == 'channels_last':
            layers.insert(0, to_channels_first())
        if out_format == 'channels_last':
            layers.append(to_channels_last())
    elif norm_layer == 'LN':
        layers = [nn.LayerNorm(dim, eps=eps)]
        if in_format == 'channels_first':
            layers.insert(0, to_channels_last())
        if out_format == 'channels_first':
            layers.append(to_channels_first())
    else:
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')
    return nn.Sequential(*layers)
| 1340 | |
| 1341 | |
class to_channels_first(nn.Module):
    """Move the last axis of a 4-D tensor to position 1 (e.g. NHWC -> NCHW)."""

    def forward(self, x):
        return x.permute(0, 3, 1, 2)
| 1349 | |
| 1350 | |
class to_channels_last(nn.Module):
    """Move axis 1 of a 4-D tensor to the end (e.g. NCHW -> NHWC)."""

    def forward(self, x):
        return x.permute(0, 2, 3, 1)
| 1358 | |
| 1359 | |
| 1360 | |
| 1361 | ### dataset.py |
| 1362 | |
# Category names stored as one ', '-separated string and split into a list below.
# NOTE(review): the "TR" suffix suggests these are the training-set categories
# (e.g. DIS-TR) — confirm against the dataset. The names are in ASCII-sorted
# order, so list positions can presumably serve as stable label indices.
_class_labels_TR_sorted = (
    'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, '
    'BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, '
    'CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, '
    'Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, '
    'Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, '
    'Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, '
    'KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, '
    'Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, '
    'OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, '
    'RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, '
    'ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, '
    'Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, '
    'TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, '
    'UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht'
)
# List form of the names above, in the same order.
class_labels_TR_sorted = _class_labels_TR_sorted.split(', ')
| 1380 | |
| 1381 | |
| 1382 | ### models/backbones/build_backbones.py |
| 1383 | |
| 1384 | import torch |
| 1385 | import torch.nn as nn |
| 1386 | from collections import OrderedDict |
| 1387 | from torchvision.models import vgg16, vgg16_bn, VGG16_Weights, VGG16_BN_Weights, resnet50, ResNet50_Weights |
| 1388 | # from models.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5 |
| 1389 | # from models.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l |
| 1390 | # from config import Config |
| 1391 | |
| 1392 | |
# Module-level configuration object, created at import time; `load_weights`
# below looks up checkpoint paths through `config.weights`.
config = Config()
| 1394 | |
def build_backbone(bb_name, pretrained=True, params_settings=''):
    """Construct the encoder backbone named ``bb_name``.

    Args:
        bb_name: 'vgg16', 'vgg16bn', 'resnet50' (torchvision models, split into
            four named stages conv1..conv4), or the name of a backbone factory
            function available in this module (e.g. 'swin_v1_l').
        pretrained: load pretrained weights (torchvision weights for the
            torchvision models, ``load_weights`` for the others).
        params_settings: argument string spliced into the factory call for
            non-torchvision backbones.

    Returns:
        The backbone module, or None when ``load_weights`` fails.
    """
    if bb_name == 'vgg16':
        bb_net = list(vgg16(weights=VGG16_Weights.DEFAULT if pretrained else None).children())[0]
        bb = nn.Sequential(OrderedDict({'conv1': bb_net[:4], 'conv2': bb_net[4:9], 'conv3': bb_net[9:16], 'conv4': bb_net[16:23]}))
    elif bb_name == 'vgg16bn':
        bb_net = list(vgg16_bn(weights=VGG16_BN_Weights.DEFAULT if pretrained else None).children())[0]
        bb = nn.Sequential(OrderedDict({'conv1': bb_net[:6], 'conv2': bb_net[6:13], 'conv3': bb_net[13:23], 'conv4': bb_net[23:33]}))
    elif bb_name == 'resnet50':
        # The deprecated `pretrained=` flag only checks truthiness, so for
        # resnet50 it loaded IMAGENET1K_V1 (DEFAULT is IMAGENET1K_V2); pin V1
        # to keep the weights this code has always used.
        bb_net = list(resnet50(weights=ResNet50_Weights.IMAGENET1K_V1 if pretrained else None).children())
        # NOTE(review): child index 3 (torchvision's max-pool) is not included — confirm intended.
        bb = nn.Sequential(OrderedDict({'conv1': nn.Sequential(*bb_net[0:3]), 'conv2': bb_net[4], 'conv3': bb_net[5], 'conv4': bb_net[6]}))
    else:
        # NOTE(review): eval() on a string built from bb_name/params_settings;
        # only safe while both come from trusted configuration.
        bb = eval('{}({})'.format(bb_name, params_settings))
        if pretrained:
            bb = load_weights(bb, bb_name)
    return bb
| 1410 | |
def _filter_state_dict(source, model_dict):
    """Keep entries of ``source`` whose key is in ``model_dict``; on a size mismatch keep the model's own tensor."""
    return {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in source.items() if k in model_dict}


def load_weights(model, model_name):
    """Load the checkpoint at ``config.weights[model_name]`` into ``model``.

    Only keys present in the model are taken, and tensors whose size differs
    from the model's are replaced by the model's own values (so a modified
    backbone can still reuse the rest of a checkpoint). If no key matches at
    the top level and the checkpoint has exactly one top-level item, that
    item is tried as the state dict instead.

    Returns:
        The model with weights loaded, or None if no matching weights were found.
    """
    save_model = torch.load(config.weights[model_name], map_location='cpu')
    model_dict = model.state_dict()
    state_dict = _filter_state_dict(save_model, model_dict)
    if not state_dict:
        save_model_keys = list(save_model.keys())
        sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None
        # Check sub_item before indexing: with several top-level keys it is
        # None and save_model[None] would raise KeyError.
        if sub_item is not None and hasattr(save_model[sub_item], 'items'):
            state_dict = _filter_state_dict(save_model[sub_item], model_dict)
        if not state_dict:
            print('Weights are not successully loaded. Check the state dict of weights file.')
            return None
        print('Found correct weights in the "{}" item of loaded state_dict.'.format(sub_item))
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
    return model
| 1428 | |
| 1429 | |
| 1430 | |
| 1431 | ### models/modules/decoder_blocks.py |
| 1432 | |
| 1433 | import torch |
| 1434 | import torch.nn as nn |
| 1435 | # from models.aspp import ASPP, ASPPDeformable |
| 1436 | # from config import Config |
| 1437 | |
| 1438 | |
| 1439 | # config = Config() |
| 1440 | |
| 1441 | |
class BasicDecBlk(nn.Module):
    """Decoder block: conv3x3 -> BN -> ReLU -> optional ASPP attention -> conv3x3 -> BN.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        inter_channels: hidden width used when ``config.dec_channels_inter`` is
            not 'adap'; with 'adap' the hidden width is ``in_channels // 4``.

    The attention module is ASPP or ASPPDeformable depending on
    ``config.dec_att`` (absent otherwise). BatchNorm layers become identities
    when ``config.batch_size`` is not greater than 1.
    """
    def __init__(self, in_channels=64, out_channels=64, inter_channels=64):
        super(BasicDecBlk, self).__init__()
        # Honour the `inter_channels` argument in the non-adaptive case (it was
        # previously overwritten with 64; the default keeps that value).
        inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else inter_channels
        self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
        self.relu_in = nn.ReLU(inplace=True)
        if config.dec_att == 'ASPP':
            self.dec_att = ASPP(in_channels=inter_channels)
        elif config.dec_att == 'ASPPDeformable':
            self.dec_att = ASPPDeformable(in_channels=inter_channels)
        self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
        self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
        self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()

    def forward(self, x):
        x = self.conv_in(x)
        x = self.bn_in(x)
        x = self.relu_in(x)
        if hasattr(self, 'dec_att'):
            x = self.dec_att(x)
        x = self.conv_out(x)
        x = self.bn_out(x)
        return x
| 1465 | |
| 1466 | |
class ResBlk(nn.Module):
    """Residual decoder block.

    Main branch: conv3x3 -> BN -> ReLU -> optional ASPP attention -> conv3x3 -> BN;
    shortcut: conv1x1. The output is the sum of both.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output; defaults to ``in_channels``.
        inter_channels: hidden width used when ``config.dec_channels_inter`` is
            not 'adap'; with 'adap' the hidden width is ``in_channels // 4``.

    BatchNorm layers become identities when ``config.batch_size`` is not
    greater than 1.
    """
    def __init__(self, in_channels=64, out_channels=None, inter_channels=64):
        super(ResBlk, self).__init__()
        if out_channels is None:
            out_channels = in_channels
        # Honour the `inter_channels` argument in the non-adaptive case (it was
        # previously overwritten with 64; the default keeps that value).
        inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else inter_channels

        self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
        self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
        self.relu_in = nn.ReLU(inplace=True)

        if config.dec_att == 'ASPP':
            self.dec_att = ASPP(in_channels=inter_channels)
        elif config.dec_att == 'ASPPDeformable':
            self.dec_att = ASPPDeformable(in_channels=inter_channels)

        self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
        self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()

        self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        _x = self.conv_resi(x)
        x = self.conv_in(x)
        x = self.bn_in(x)
        x = self.relu_in(x)
        if hasattr(self, 'dec_att'):
            x = self.dec_att(x)
        x = self.conv_out(x)
        x = self.bn_out(x)
        return x + _x
| 1498 | |
| 1499 | |
| 1500 | |
| 1501 | ### models/modules/lateral_blocks.py |
| 1502 | |
| 1503 | import numpy as np |
| 1504 | import torch |
| 1505 | import torch.nn as nn |
| 1506 | import torch.nn.functional as F |
| 1507 | from functools import partial |
| 1508 | |
| 1509 | # from config import Config |
| 1510 | |
| 1511 | |
| 1512 | # config = Config() |
| 1513 | |
| 1514 | |
| 1515 | class BasicLatBlk(nn.Module): |
| 1516 | def __init__(self, in_channels=64, out_channels=64, inter_channels=64): |
| 1517 | super(BasicLatBlk, self).__init__() |
| 1518 | inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 |
| 1519 | self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0) |
| 1520 | |
| 1521 | def forward(self, x): |
| 1522 | x = self.conv(x) |
| 1523 | return x |
| 1524 | |
| 1525 | |
| 1526 | |
| 1527 | ### models/modules/aspp.py |
| 1528 | |
| 1529 | import torch |
| 1530 | import torch.nn as nn |
| 1531 | import torch.nn.functional as F |
| 1532 | # from models.deform_conv import DeformableConv2d |
| 1533 | # from config import Config |
| 1534 | |
| 1535 | |
| 1536 | # config = Config() |
| 1537 | |
| 1538 | |
| 1539 | class _ASPPModule(nn.Module): |
| 1540 | def __init__(self, in_channels, planes, kernel_size, padding, dilation): |
| 1541 | super(_ASPPModule, self).__init__() |
| 1542 | self.atrous_conv = nn.Conv2d(in_channels, planes, kernel_size=kernel_size, |
| 1543 | stride=1, padding=padding, dilation=dilation, bias=False) |
| 1544 | self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity() |
| 1545 | self.relu = nn.ReLU(inplace=True) |
| 1546 | |
| 1547 | def forward(self, x): |
| 1548 | x = self.atrous_conv(x) |
| 1549 | x = self.bn(x) |
| 1550 | |
| 1551 | return self.relu(x) |
| 1552 | |
| 1553 | |
| 1554 | class ASPP(nn.Module): |
| 1555 | def __init__(self, in_channels=64, out_channels=None, output_stride=16): |
| 1556 | super(ASPP, self).__init__() |
| 1557 | self.down_scale = 1 |
| 1558 | if out_channels is None: |
| 1559 | out_channels = in_channels |
| 1560 | self.in_channelster = 256 // self.down_scale |
| 1561 | if output_stride == 16: |
| 1562 | dilations = [1, 6, 12, 18] |
| 1563 | elif output_stride == 8: |
| 1564 | dilations = [1, 12, 24, 36] |
| 1565 | else: |
| 1566 | raise NotImplementedError |
| 1567 | |
| 1568 | self.aspp1 = _ASPPModule(in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0]) |
| 1569 | self.aspp2 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[1], dilation=dilations[1]) |
| 1570 | self.aspp3 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[2], dilation=dilations[2]) |
| 1571 | self.aspp4 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[3], dilation=dilations[3]) |
| 1572 | |
| 1573 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), |
| 1574 | nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False), |
| 1575 | nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(), |
| 1576 | nn.ReLU(inplace=True)) |
| 1577 | self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False) |
| 1578 | self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() |
| 1579 | self.relu = nn.ReLU(inplace=True) |
| 1580 | self.dropout = nn.Dropout(0.5) |
| 1581 | |
| 1582 | def forward(self, x): |
| 1583 | x1 = self.aspp1(x) |
| 1584 | x2 = self.aspp2(x) |
| 1585 | x3 = self.aspp3(x) |
| 1586 | x4 = self.aspp4(x) |
| 1587 | x5 = self.global_avg_pool(x) |
| 1588 | x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True) |
| 1589 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) |
| 1590 | |
| 1591 | x = self.conv1(x) |
| 1592 | x = self.bn1(x) |
| 1593 | x = self.relu(x) |
| 1594 | |
| 1595 | return self.dropout(x) |
| 1596 | |
| 1597 | |
| 1598 | ##################### Deformable |
| 1599 | class _ASPPModuleDeformable(nn.Module): |
| 1600 | def __init__(self, in_channels, planes, kernel_size, padding): |
| 1601 | super(_ASPPModuleDeformable, self).__init__() |
| 1602 | self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size, |
| 1603 | stride=1, padding=padding, bias=False) |
| 1604 | self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity() |
| 1605 | self.relu = nn.ReLU(inplace=True) |
| 1606 | |
| 1607 | def forward(self, x): |
| 1608 | x = self.atrous_conv(x) |
| 1609 | x = self.bn(x) |
| 1610 | |
| 1611 | return self.relu(x) |
| 1612 | |
| 1613 | |
| 1614 | class ASPPDeformable(nn.Module): |
| 1615 | def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7]): |
| 1616 | super(ASPPDeformable, self).__init__() |
| 1617 | self.down_scale = 1 |
| 1618 | if out_channels is None: |
| 1619 | out_channels = in_channels |
| 1620 | self.in_channelster = 256 // self.down_scale |
| 1621 | |
| 1622 | self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0) |
| 1623 | self.aspp_deforms = nn.ModuleList([ |
| 1624 | _ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2)) for conv_size in parallel_block_sizes |
| 1625 | ]) |
| 1626 | |
| 1627 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), |
| 1628 | nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False), |
| 1629 | nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(), |
| 1630 | nn.ReLU(inplace=True)) |
| 1631 | self.conv1 = nn.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False) |
| 1632 | self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() |
| 1633 | self.relu = nn.ReLU(inplace=True) |
| 1634 | self.dropout = nn.Dropout(0.5) |
| 1635 | |
| 1636 | def forward(self, x): |
| 1637 | x1 = self.aspp1(x) |
| 1638 | x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms] |
| 1639 | x5 = self.global_avg_pool(x) |
| 1640 | x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True) |
| 1641 | x = torch.cat((x1, *x_aspp_deforms, x5), dim=1) |
| 1642 | |
| 1643 | x = self.conv1(x) |
| 1644 | x = self.bn1(x) |
| 1645 | x = self.relu(x) |
| 1646 | |
| 1647 | return self.dropout(x) |
| 1648 | |
| 1649 | |
| 1650 | |
| 1651 | ### models/refinement/refiner.py |
| 1652 | |
| 1653 | import torch |
| 1654 | import torch.nn as nn |
| 1655 | from collections import OrderedDict |
| 1656 | import torch |
| 1657 | import torch.nn as nn |
| 1658 | import torch.nn.functional as F |
| 1659 | from torchvision.models import vgg16, vgg16_bn |
| 1660 | from torchvision.models import resnet50 |
| 1661 | |
| 1662 | # from config import Config |
| 1663 | # from dataset import class_labels_TR_sorted |
| 1664 | # from models.build_backbone import build_backbone |
| 1665 | # from models.decoder_blocks import BasicDecBlk |
| 1666 | # from models.lateral_blocks import BasicLatBlk |
| 1667 | # from models.ing import * |
| 1668 | # from models.stem_layer import StemLayer |
| 1669 | |
| 1670 | |
| 1671 | class RefinerPVTInChannels4(nn.Module): |
| 1672 | def __init__(self, in_channels=3+1): |
| 1673 | super(RefinerPVTInChannels4, self).__init__() |
| 1674 | self.config = Config() |
| 1675 | self.epoch = 1 |
| 1676 | self.bb = build_backbone(self.config.bb, params_settings='in_channels=4') |
| 1677 | |
| 1678 | lateral_channels_in_collection = { |
| 1679 | 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], |
| 1680 | 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], |
| 1681 | 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], |
| 1682 | } |
| 1683 | channels = lateral_channels_in_collection[self.config.bb] |
| 1684 | self.squeeze_module = BasicDecBlk(channels[0], channels[0]) |
| 1685 | |
| 1686 | self.decoder = Decoder(channels) |
| 1687 | |
| 1688 | if 0: |
| 1689 | for key, value in self.named_parameters(): |
| 1690 | if 'bb.' in key: |
| 1691 | value.requires_grad = False |
| 1692 | |
| 1693 | def forward(self, x): |
| 1694 | if isinstance(x, list): |
| 1695 | x = torch.cat(x, dim=1) |
| 1696 | ########## Encoder ########## |
| 1697 | if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: |
| 1698 | x1 = self.bb.conv1(x) |
| 1699 | x2 = self.bb.conv2(x1) |
| 1700 | x3 = self.bb.conv3(x2) |
| 1701 | x4 = self.bb.conv4(x3) |
| 1702 | else: |
| 1703 | x1, x2, x3, x4 = self.bb(x) |
| 1704 | |
| 1705 | x4 = self.squeeze_module(x4) |
| 1706 | |
| 1707 | ########## Decoder ########## |
| 1708 | |
| 1709 | features = [x, x1, x2, x3, x4] |
| 1710 | scaled_preds = self.decoder(features) |
| 1711 | |
| 1712 | return scaled_preds |
| 1713 | |
| 1714 | |
| 1715 | class Refiner(nn.Module): |
| 1716 | def __init__(self, in_channels=3+1): |
| 1717 | super(Refiner, self).__init__() |
| 1718 | self.config = Config() |
| 1719 | self.epoch = 1 |
| 1720 | self.stem_layer = StemLayer(in_channels=in_channels, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN') |
| 1721 | self.bb = build_backbone(self.config.bb) |
| 1722 | |
| 1723 | lateral_channels_in_collection = { |
| 1724 | 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], |
| 1725 | 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], |
| 1726 | 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], |
| 1727 | } |
| 1728 | channels = lateral_channels_in_collection[self.config.bb] |
| 1729 | self.squeeze_module = BasicDecBlk(channels[0], channels[0]) |
| 1730 | |
| 1731 | self.decoder = Decoder(channels) |
| 1732 | |
| 1733 | if 0: |
| 1734 | for key, value in self.named_parameters(): |
| 1735 | if 'bb.' in key: |
| 1736 | value.requires_grad = False |
| 1737 | |
| 1738 | def forward(self, x): |
| 1739 | if isinstance(x, list): |
| 1740 | x = torch.cat(x, dim=1) |
| 1741 | x = self.stem_layer(x) |
| 1742 | ########## Encoder ########## |
| 1743 | if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: |
| 1744 | x1 = self.bb.conv1(x) |
| 1745 | x2 = self.bb.conv2(x1) |
| 1746 | x3 = self.bb.conv3(x2) |
| 1747 | x4 = self.bb.conv4(x3) |
| 1748 | else: |
| 1749 | x1, x2, x3, x4 = self.bb(x) |
| 1750 | |
| 1751 | x4 = self.squeeze_module(x4) |
| 1752 | |
| 1753 | ########## Decoder ########## |
| 1754 | |
| 1755 | features = [x, x1, x2, x3, x4] |
| 1756 | scaled_preds = self.decoder(features) |
| 1757 | |
| 1758 | return scaled_preds |
| 1759 | |
| 1760 | |
| 1761 | class Decoder(nn.Module): |
| 1762 | def __init__(self, channels): |
| 1763 | super(Decoder, self).__init__() |
| 1764 | self.config = Config() |
| 1765 | DecoderBlock = eval('BasicDecBlk') |
| 1766 | LateralBlock = eval('BasicLatBlk') |
| 1767 | |
| 1768 | self.decoder_block4 = DecoderBlock(channels[0], channels[1]) |
| 1769 | self.decoder_block3 = DecoderBlock(channels[1], channels[2]) |
| 1770 | self.decoder_block2 = DecoderBlock(channels[2], channels[3]) |
| 1771 | self.decoder_block1 = DecoderBlock(channels[3], channels[3]//2) |
| 1772 | |
| 1773 | self.lateral_block4 = LateralBlock(channels[1], channels[1]) |
| 1774 | self.lateral_block3 = LateralBlock(channels[2], channels[2]) |
| 1775 | self.lateral_block2 = LateralBlock(channels[3], channels[3]) |
| 1776 | |
| 1777 | if self.config.ms_supervision: |
| 1778 | self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0) |
| 1779 | self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0) |
| 1780 | self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0) |
| 1781 | self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2, 1, 1, 1, 0)) |
| 1782 | |
| 1783 | def forward(self, features): |
| 1784 | x, x1, x2, x3, x4 = features |
| 1785 | outs = [] |
| 1786 | p4 = self.decoder_block4(x4) |
| 1787 | _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True) |
| 1788 | _p3 = _p4 + self.lateral_block4(x3) |
| 1789 | |
| 1790 | p3 = self.decoder_block3(_p3) |
| 1791 | _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True) |
| 1792 | _p2 = _p3 + self.lateral_block3(x2) |
| 1793 | |
| 1794 | p2 = self.decoder_block2(_p2) |
| 1795 | _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True) |
| 1796 | _p1 = _p2 + self.lateral_block2(x1) |
| 1797 | |
| 1798 | _p1 = self.decoder_block1(_p1) |
| 1799 | _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True) |
| 1800 | p1_out = self.conv_out1(_p1) |
| 1801 | |
| 1802 | if self.config.ms_supervision: |
| 1803 | outs.append(self.conv_ms_spvn_4(p4)) |
| 1804 | outs.append(self.conv_ms_spvn_3(p3)) |
| 1805 | outs.append(self.conv_ms_spvn_2(p2)) |
| 1806 | outs.append(p1_out) |
| 1807 | return outs |
| 1808 | |
| 1809 | |
| 1810 | class RefUNet(nn.Module): |
| 1811 | # Refinement |
| 1812 | def __init__(self, in_channels=3+1): |
| 1813 | super(RefUNet, self).__init__() |
| 1814 | self.encoder_1 = nn.Sequential( |
| 1815 | nn.Conv2d(in_channels, 64, 3, 1, 1), |
| 1816 | nn.Conv2d(64, 64, 3, 1, 1), |
| 1817 | nn.BatchNorm2d(64), |
| 1818 | nn.ReLU(inplace=True) |
| 1819 | ) |
| 1820 | |
| 1821 | self.encoder_2 = nn.Sequential( |
| 1822 | nn.MaxPool2d(2, 2, ceil_mode=True), |
| 1823 | nn.Conv2d(64, 64, 3, 1, 1), |
| 1824 | nn.BatchNorm2d(64), |
| 1825 | nn.ReLU(inplace=True) |
| 1826 | ) |
| 1827 | |
| 1828 | self.encoder_3 = nn.Sequential( |
| 1829 | nn.MaxPool2d(2, 2, ceil_mode=True), |
| 1830 | nn.Conv2d(64, 64, 3, 1, 1), |
| 1831 | nn.BatchNorm2d(64), |
| 1832 | nn.ReLU(inplace=True) |
| 1833 | ) |
| 1834 | |
| 1835 | self.encoder_4 = nn.Sequential( |
| 1836 | nn.MaxPool2d(2, 2, ceil_mode=True), |
| 1837 | nn.Conv2d(64, 64, 3, 1, 1), |
| 1838 | nn.BatchNorm2d(64), |
| 1839 | nn.ReLU(inplace=True) |
| 1840 | ) |
| 1841 | |
| 1842 | self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True) |
| 1843 | ##### |
| 1844 | self.decoder_5 = nn.Sequential( |
| 1845 | nn.Conv2d(64, 64, 3, 1, 1), |
| 1846 | nn.BatchNorm2d(64), |
| 1847 | nn.ReLU(inplace=True) |
| 1848 | ) |
| 1849 | ##### |
| 1850 | self.decoder_4 = nn.Sequential( |
| 1851 | nn.Conv2d(128, 64, 3, 1, 1), |
| 1852 | nn.BatchNorm2d(64), |
| 1853 | nn.ReLU(inplace=True) |
| 1854 | ) |
| 1855 | |
| 1856 | self.decoder_3 = nn.Sequential( |
| 1857 | nn.Conv2d(128, 64, 3, 1, 1), |
| 1858 | nn.BatchNorm2d(64), |
| 1859 | nn.ReLU(inplace=True) |
| 1860 | ) |
| 1861 | |
| 1862 | self.decoder_2 = nn.Sequential( |
| 1863 | nn.Conv2d(128, 64, 3, 1, 1), |
| 1864 | nn.BatchNorm2d(64), |
| 1865 | nn.ReLU(inplace=True) |
| 1866 | ) |
| 1867 | |
| 1868 | self.decoder_1 = nn.Sequential( |
| 1869 | nn.Conv2d(128, 64, 3, 1, 1), |
| 1870 | nn.BatchNorm2d(64), |
| 1871 | nn.ReLU(inplace=True) |
| 1872 | ) |
| 1873 | |
| 1874 | self.conv_d0 = nn.Conv2d(64, 1, 3, 1, 1) |
| 1875 | |
| 1876 | self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) |
| 1877 | |
| 1878 | def forward(self, x): |
| 1879 | outs = [] |
| 1880 | if isinstance(x, list): |
| 1881 | x = torch.cat(x, dim=1) |
| 1882 | hx = x |
| 1883 | |
| 1884 | hx1 = self.encoder_1(hx) |
| 1885 | hx2 = self.encoder_2(hx1) |
| 1886 | hx3 = self.encoder_3(hx2) |
| 1887 | hx4 = self.encoder_4(hx3) |
| 1888 | |
| 1889 | hx = self.decoder_5(self.pool4(hx4)) |
| 1890 | hx = torch.cat((self.upscore2(hx), hx4), 1) |
| 1891 | |
| 1892 | d4 = self.decoder_4(hx) |
| 1893 | hx = torch.cat((self.upscore2(d4), hx3), 1) |
| 1894 | |
| 1895 | d3 = self.decoder_3(hx) |
| 1896 | hx = torch.cat((self.upscore2(d3), hx2), 1) |
| 1897 | |
| 1898 | d2 = self.decoder_2(hx) |
| 1899 | hx = torch.cat((self.upscore2(d2), hx1), 1) |
| 1900 | |
| 1901 | d1 = self.decoder_1(hx) |
| 1902 | |
| 1903 | x = self.conv_d0(d1) |
| 1904 | outs.append(x) |
| 1905 | return outs |
| 1906 | |
| 1907 | |
| 1908 | |
| 1909 | ### models/stem_layer.py |
| 1910 | |
| 1911 | import torch.nn as nn |
| 1912 | # from utils import build_act_layer, build_norm_layer |
| 1913 | |
| 1914 | |
| 1915 | class StemLayer(nn.Module): |
| 1916 | r""" Stem layer of InternImage |
| 1917 | Args: |
| 1918 | in_channels (int): number of input channels |
| 1919 | out_channels (int): number of output channels |
| 1920 | act_layer (str): activation layer |
| 1921 | norm_layer (str): normalization layer |
| 1922 | """ |
| 1923 | |
| 1924 | def __init__(self, |
| 1925 | in_channels=3+1, |
| 1926 | inter_channels=48, |
| 1927 | out_channels=96, |
| 1928 | act_layer='GELU', |
| 1929 | norm_layer='BN'): |
| 1930 | super().__init__() |
| 1931 | self.conv1 = nn.Conv2d(in_channels, |
| 1932 | inter_channels, |
| 1933 | kernel_size=3, |
| 1934 | stride=1, |
| 1935 | padding=1) |
| 1936 | self.norm1 = build_norm_layer( |
| 1937 | inter_channels, norm_layer, 'channels_first', 'channels_first' |
| 1938 | ) |
| 1939 | self.act = build_act_layer(act_layer) |
| 1940 | self.conv2 = nn.Conv2d(inter_channels, |
| 1941 | out_channels, |
| 1942 | kernel_size=3, |
| 1943 | stride=1, |
| 1944 | padding=1) |
| 1945 | self.norm2 = build_norm_layer( |
| 1946 | out_channels, norm_layer, 'channels_first', 'channels_first' |
| 1947 | ) |
| 1948 | |
| 1949 | def forward(self, x): |
| 1950 | x = self.conv1(x) |
| 1951 | x = self.norm1(x) |
| 1952 | x = self.act(x) |
| 1953 | x = self.conv2(x) |
| 1954 | x = self.norm2(x) |
| 1955 | return x |
| 1956 | |
| 1957 | |
| 1958 | ### models/birefnet.py |
| 1959 | |
| 1960 | import torch |
| 1961 | import torch.nn as nn |
| 1962 | import torch.nn.functional as F |
| 1963 | from kornia.filters import laplacian |
| 1964 | from transformers import PreTrainedModel |
| 1965 | from einops import rearrange |
| 1966 | |
| 1967 | # from config import Config |
| 1968 | # from dataset import class_labels_TR_sorted |
| 1969 | # from models.build_backbone import build_backbone |
| 1970 | # from models.decoder_blocks import BasicDecBlk, ResBlk, HierarAttDecBlk |
| 1971 | # from models.lateral_blocks import BasicLatBlk |
| 1972 | # from models.aspp import ASPP, ASPPDeformable |
| 1973 | # from models.ing import * |
| 1974 | # from models.refiner import Refiner, RefinerPVTInChannels4, RefUNet |
| 1975 | # from models.stem_layer import StemLayer |
| 1976 | from .BiRefNet_config import BiRefNetConfig |
| 1977 | |
| 1978 | |
| 1979 | def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'): |
| 1980 | if patch_ref is not None: |
| 1981 | grid_h, grid_w = image.shape[-2] // patch_ref.shape[-2], image.shape[-1] // patch_ref.shape[-1] |
| 1982 | patches = rearrange(image, transformation, hg=grid_h, wg=grid_w) |
| 1983 | return patches |
| 1984 | |
| 1985 | def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'): |
| 1986 | if patch_ref is not None: |
| 1987 | grid_h, grid_w = patch_ref.shape[-2] // patches[0].shape[-2], patch_ref.shape[-1] // patches[0].shape[-1] |
| 1988 | image = rearrange(patches, transformation, hg=grid_h, wg=grid_w) |
| 1989 | return image |
| 1990 | |
| 1991 | class BiRefNet( |
| 1992 | PreTrainedModel |
| 1993 | ): |
| 1994 | config_class = BiRefNetConfig |
| 1995 | def __init__(self, bb_pretrained=True, config=BiRefNetConfig()): |
| 1996 | super(BiRefNet, self).__init__(config) |
| 1997 | bb_pretrained = config.bb_pretrained |
| 1998 | self.config = Config() |
| 1999 | self.epoch = 1 |
| 2000 | self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained) |
| 2001 | |
| 2002 | channels = self.config.lateral_channels_in_collection |
| 2003 | |
| 2004 | if self.config.auxiliary_classification: |
| 2005 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) |
| 2006 | self.cls_head = nn.Sequential( |
| 2007 | nn.Linear(channels[0], len(class_labels_TR_sorted)) |
| 2008 | ) |
| 2009 | |
| 2010 | if self.config.squeeze_block: |
| 2011 | self.squeeze_module = nn.Sequential(*[ |
| 2012 | eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0]) |
| 2013 | for _ in range(eval(self.config.squeeze_block.split('_x')[1])) |
| 2014 | ]) |
| 2015 | |
| 2016 | self.decoder = Decoder(channels) |
| 2017 | |
| 2018 | if self.config.ender: |
| 2019 | self.dec_end = nn.Sequential( |
| 2020 | nn.Conv2d(1, 16, 3, 1, 1), |
| 2021 | nn.Conv2d(16, 1, 3, 1, 1), |
| 2022 | nn.ReLU(inplace=True), |
| 2023 | ) |
| 2024 | |
| 2025 | # refine patch-level segmentation |
| 2026 | if self.config.refine: |
| 2027 | if self.config.refine == 'itself': |
| 2028 | self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN') |
| 2029 | else: |
| 2030 | self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1')) |
| 2031 | |
| 2032 | if self.config.freeze_bb: |
| 2033 | # Freeze the backbone... |
| 2034 | print(self.named_parameters()) |
| 2035 | for key, value in self.named_parameters(): |
| 2036 | if 'bb.' in key and 'refiner.' not in key: |
| 2037 | value.requires_grad = False |
| 2038 | |
| 2039 | self.post_init() |
| 2040 | |
| 2041 | def forward_enc(self, x): |
| 2042 | if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: |
| 2043 | x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3) |
| 2044 | else: |
| 2045 | x1, x2, x3, x4 = self.bb(x) |
| 2046 | if self.config.mul_scl_ipt == 'cat': |
| 2047 | B, C, H, W = x.shape |
| 2048 | x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True)) |
| 2049 | x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1) |
| 2050 | x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1) |
| 2051 | x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1) |
| 2052 | x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1) |
| 2053 | elif self.config.mul_scl_ipt == 'add': |
| 2054 | B, C, H, W = x.shape |
| 2055 | x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True)) |
| 2056 | x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True) |
| 2057 | x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True) |
| 2058 | x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True) |
| 2059 | x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True) |
| 2060 | class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None |
| 2061 | if self.config.cxt: |
| 2062 | x4 = torch.cat( |
| 2063 | ( |
| 2064 | *[ |
| 2065 | F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True), |
| 2066 | F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True), |
| 2067 | F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True), |
| 2068 | ][-len(self.config.cxt):], |
| 2069 | x4 |
| 2070 | ), |
| 2071 | dim=1 |
| 2072 | ) |
| 2073 | return (x1, x2, x3, x4), class_preds |
| 2074 | |
| 2075 | def forward_ori(self, x): |
| 2076 | ########## Encoder ########## |
| 2077 | (x1, x2, x3, x4), class_preds = self.forward_enc(x) |
| 2078 | if self.config.squeeze_block: |
| 2079 | x4 = self.squeeze_module(x4) |
| 2080 | ########## Decoder ########## |
| 2081 | features = [x, x1, x2, x3, x4] |
| 2082 | if self.training and self.config.out_ref: |
| 2083 | features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5)) |
| 2084 | scaled_preds = self.decoder(features) |
| 2085 | return scaled_preds, class_preds |
| 2086 | |
| 2087 | def forward(self, x): |
| 2088 | scaled_preds, class_preds = self.forward_ori(x) |
| 2089 | class_preds_lst = [class_preds] |
| 2090 | return [scaled_preds, class_preds_lst] if self.training else scaled_preds |
| 2091 | |
| 2092 | |
| 2093 | class Decoder(nn.Module): |
| 2094 | def __init__(self, channels): |
| 2095 | super(Decoder, self).__init__() |
| 2096 | self.config = Config() |
| 2097 | DecoderBlock = eval(self.config.dec_blk) |
| 2098 | LateralBlock = eval(self.config.lat_blk) |
| 2099 | |
| 2100 | if self.config.dec_ipt: |
| 2101 | self.split = self.config.dec_ipt_split |
| 2102 | N_dec_ipt = 64 |
| 2103 | DBlock = SimpleConvs |
| 2104 | ic = 64 |
| 2105 | ipt_cha_opt = 1 |
| 2106 | self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic) |
| 2107 | self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic) |
| 2108 | self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic) |
| 2109 | self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic) |
| 2110 | self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic) |
| 2111 | else: |
| 2112 | self.split = None |
| 2113 | |
| 2114 | self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[1]) |
| 2115 | self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[2]) |
| 2116 | self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]) |
| 2117 | self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]//2) |
| 2118 | self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt] if self.config.dec_ipt else 0), 1, 1, 1, 0)) |
| 2119 | |
| 2120 | self.lateral_block4 = LateralBlock(channels[1], channels[1]) |
| 2121 | self.lateral_block3 = LateralBlock(channels[2], channels[2]) |
| 2122 | self.lateral_block2 = LateralBlock(channels[3], channels[3]) |
| 2123 | |
| 2124 | if self.config.ms_supervision: |
| 2125 | self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0) |
| 2126 | self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0) |
| 2127 | self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0) |
| 2128 | |
| 2129 | if self.config.out_ref: |
| 2130 | _N = 16 |
| 2131 | self.gdt_convs_4 = nn.Sequential(nn.Conv2d(channels[1], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True)) |
| 2132 | self.gdt_convs_3 = nn.Sequential(nn.Conv2d(channels[2], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True)) |
| 2133 | self.gdt_convs_2 = nn.Sequential(nn.Conv2d(channels[3], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True)) |
| 2134 | |
| 2135 | self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) |
| 2136 | self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) |
| 2137 | self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) |
| 2138 | |
| 2139 | self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) |
| 2140 | self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) |
| 2141 | self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) |
| 2142 | |
| 2143 | def forward(self, features): |
| 2144 | if self.training and self.config.out_ref: |
| 2145 | outs_gdt_pred = [] |
| 2146 | outs_gdt_label = [] |
| 2147 | x, x1, x2, x3, x4, gdt_gt = features |
| 2148 | else: |
| 2149 | x, x1, x2, x3, x4 = features |
| 2150 | outs = [] |
| 2151 | |
| 2152 | if self.config.dec_ipt: |
| 2153 | patches_batch = image2patches(x, patch_ref=x4, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x |
| 2154 | x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1) |
| 2155 | p4 = self.decoder_block4(x4) |
| 2156 | m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None |
| 2157 | if self.config.out_ref: |
| 2158 | p4_gdt = self.gdt_convs_4(p4) |
| 2159 | if self.training: |
| 2160 | # >> GT: |
| 2161 | m4_dia = m4 |
| 2162 | gdt_label_main_4 = gdt_gt * F.interpolate(m4_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) |
| 2163 | outs_gdt_label.append(gdt_label_main_4) |
| 2164 | # >> Pred: |
| 2165 | gdt_pred_4 = self.gdt_convs_pred_4(p4_gdt) |
| 2166 | outs_gdt_pred.append(gdt_pred_4) |
| 2167 | gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid() |
| 2168 | # >> Finally: |
| 2169 | p4 = p4 * gdt_attn_4 |
| 2170 | _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True) |
| 2171 | _p3 = _p4 + self.lateral_block4(x3) |
| 2172 | |
| 2173 | if self.config.dec_ipt: |
| 2174 | patches_batch = image2patches(x, patch_ref=_p3, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x |
| 2175 | _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1) |
| 2176 | p3 = self.decoder_block3(_p3) |
| 2177 | m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None |
| 2178 | if self.config.out_ref: |
| 2179 | p3_gdt = self.gdt_convs_3(p3) |
| 2180 | if self.training: |
| 2181 | # >> GT: |
| 2182 | # m3 --dilation--> m3_dia |
| 2183 | # G_3^gt * m3_dia --> G_3^m, which is the label of gradient |
| 2184 | m3_dia = m3 |
| 2185 | gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) |
| 2186 | outs_gdt_label.append(gdt_label_main_3) |
| 2187 | # >> Pred: |
| 2188 | # p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx |
| 2189 | # F_3^G --sigmoid--> A_3^G |
| 2190 | gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt) |
| 2191 | outs_gdt_pred.append(gdt_pred_3) |
| 2192 | gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid() |
| 2193 | # >> Finally: |
| 2194 | # p3 = p3 * A_3^G |
| 2195 | p3 = p3 * gdt_attn_3 |
| 2196 | _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True) |
| 2197 | _p2 = _p3 + self.lateral_block3(x2) |
| 2198 | |
| 2199 | if self.config.dec_ipt: |
| 2200 | patches_batch = image2patches(x, patch_ref=_p2, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x |
| 2201 | _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1) |
| 2202 | p2 = self.decoder_block2(_p2) |
| 2203 | m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None |
| 2204 | if self.config.out_ref: |
| 2205 | p2_gdt = self.gdt_convs_2(p2) |
| 2206 | if self.training: |
| 2207 | # >> GT: |
| 2208 | m2_dia = m2 |
| 2209 | gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) |
| 2210 | outs_gdt_label.append(gdt_label_main_2) |
| 2211 | # >> Pred: |
| 2212 | gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt) |
| 2213 | outs_gdt_pred.append(gdt_pred_2) |
| 2214 | gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid() |
| 2215 | # >> Finally: |
| 2216 | p2 = p2 * gdt_attn_2 |
| 2217 | _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True) |
| 2218 | _p1 = _p2 + self.lateral_block2(x1) |
| 2219 | |
| 2220 | if self.config.dec_ipt: |
| 2221 | patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x |
| 2222 | _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1) |
| 2223 | _p1 = self.decoder_block1(_p1) |
| 2224 | _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True) |
| 2225 | |
| 2226 | if self.config.dec_ipt: |
| 2227 | patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x |
| 2228 | _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1) |
| 2229 | p1_out = self.conv_out1(_p1) |
| 2230 | |
| 2231 | if self.config.ms_supervision and self.training: |
| 2232 | outs.append(m4) |
| 2233 | outs.append(m3) |
| 2234 | outs.append(m2) |
| 2235 | outs.append(p1_out) |
| 2236 | return outs if not (self.config.out_ref and self.training) else ([outs_gdt_pred, outs_gdt_label], outs) |
| 2237 | |
| 2238 | |
| 2239 | class SimpleConvs(nn.Module): |
| 2240 | def __init__( |
| 2241 | self, in_channels: int, out_channels: int, inter_channels=64 |
| 2242 | ) -> None: |
| 2243 | super().__init__() |
| 2244 | self.conv1 = nn.Conv2d(in_channels, inter_channels, 3, 1, 1) |
| 2245 | self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1) |
| 2246 | |
| 2247 | def forward(self, x): |
| 2248 | return self.conv_out(self.conv1(x)) |
| 2249 | |