briarmbg.py
12.8 KB · 459 lines · python Raw
1 import torch
2 import torch.nn as nn
3 import torch.nn.functional as F
4 from transformers import PreTrainedModel
5 from .MyConfig import RMBGConfig
6
class REBNCONV(nn.Module):
    """3x3 Conv2d -> BatchNorm2d -> ReLU building block.

    The padding equals the dilation rate, so with ``stride=1`` the output has
    the same spatial size as the input.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        dirate: dilation rate of the 3x3 convolution (also used as padding).
        stride: convolution stride.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super().__init__()

        # Attribute names are part of the state_dict keys, so they must not
        # change if existing checkpoints are to keep loading.
        self.conv_s1 = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=3,
            stride=stride,
            padding=dirate,
            dilation=dirate,
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU to ``x``."""
        features = self.conv_s1(x)
        features = self.bn_s1(features)
        return self.relu_s1(features)
21
def _upsample_like(src, tar):
    """Bilinearly resize ``src`` to the spatial size of ``tar``.

    Only the dimensions of ``tar`` after the first two (its spatial size) are
    used; the leading dimensions of ``src`` are kept.

    Args:
        src: tensor to resize.
        tar: reference tensor whose trailing dimensions give the target size.

    Returns:
        The resized tensor. ``align_corners=False`` is PyTorch's effective
        default for bilinear interpolation; it is spelled out for clarity.
    """
    target_size = tar.shape[2:]
    return F.interpolate(src, size=target_size, mode='bilinear', align_corners=False)
28
29
### RSU-7 ###
class RSU7(nn.Module):
    """Residual U-block of height 7.

    ``rebnconvin`` projects the input to ``out_ch`` channels (stride 1, so
    the spatial size is unchanged). A small U-structure then maps that
    projection back to ``out_ch`` channels:
    - five 2x max-pool encoder steps,
    - a dilation-2 bottom layer,
    - a decoder that bilinearly upsamples and concatenates skip features.
    The U-structure output is added to the projection as a residual.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels inside the U-structure.
        out_ch: number of output channels.
        img_size: unused; accepted only so existing callers that pass it
            keep working.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super().__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        # Input projection and residual branch. Stride 1: no downsampling.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder. ceil_mode=True rounds odd spatial sizes up when pooling.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: widens the receptive field with dilation, not pooling.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each stage consumes [deeper features, skip] concatenated,
        # hence mid_ch * 2 input channels.
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the RSU-7 block and return ``decoder_output + projection``."""
        hxin = self.rebnconvin(x)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(self.pool1(hx1))
        hx3 = self.rebnconv3(self.pool2(hx2))
        hx4 = self.rebnconv4(self.pool3(hx3))
        hx5 = self.rebnconv5(self.pool4(hx4))
        hx6 = self.rebnconv6(self.pool5(hx5))

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx5d = self.rebnconv5d(torch.cat((_upsample_like(hx6d, hx5), hx5), 1))
        hx4d = self.rebnconv4d(torch.cat((_upsample_like(hx5d, hx4), hx4), 1))
        hx3d = self.rebnconv3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        return hx1d + hxin
111
112
### RSU-6 ###
class RSU6(nn.Module):
    """Residual U-block of height 6.

    Projects the input to ``out_ch`` channels, then runs a U-structure:
    - four 2x max-pool encoder steps,
    - a dilation-2 bottom layer,
    - a decoder that upsamples and concatenates skip features.
    The decoder output is added to the projection.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels inside the U-structure.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Projection; also serves as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder (names and order define state_dict keys / parameter order).
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: dilated instead of pooled.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: inputs are [deeper features, skip] concatenated.
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the RSU-6 block and return ``decoder_output + projection``."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        enc5 = self.rebnconv5(self.pool4(enc4))

        bottom = self.rebnconv6(enc5)

        dec = self.rebnconv5d(torch.cat((bottom, enc5), dim=1))
        dec = self.rebnconv4d(torch.cat((_upsample_like(dec, enc4), enc4), dim=1))
        dec = self.rebnconv3d(torch.cat((_upsample_like(dec, enc3), enc3), dim=1))
        dec = self.rebnconv2d(torch.cat((_upsample_like(dec, enc2), enc2), dim=1))
        dec = self.rebnconv1d(torch.cat((_upsample_like(dec, enc1), enc1), dim=1))

        return dec + residual
181
### RSU-5 ###
class RSU5(nn.Module):
    """Residual U-block of height 5.

    Projects the input to ``out_ch`` channels, then runs a U-structure:
    - three 2x max-pool encoder steps,
    - a dilation-2 bottom layer,
    - a decoder that upsamples and concatenates skip features.
    The decoder output is added to the projection.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels inside the U-structure.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Projection; also serves as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder (names and order define state_dict keys / parameter order).
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: dilated instead of pooled.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: inputs are [deeper features, skip] concatenated.
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the RSU-5 block and return ``decoder_output + projection``."""
        residual = self.rebnconvin(x)

        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))

        deepest = self.rebnconv5(e4)

        out = self.rebnconv4d(torch.cat((deepest, e4), dim=1))
        out = self.rebnconv3d(torch.cat((_upsample_like(out, e3), e3), dim=1))
        out = self.rebnconv2d(torch.cat((_upsample_like(out, e2), e2), dim=1))
        out = self.rebnconv1d(torch.cat((_upsample_like(out, e1), e1), dim=1))

        return out + residual
239
### RSU-4 ###
class RSU4(nn.Module):
    """Residual U-block of height 4.

    Projects the input to ``out_ch`` channels, then runs a U-structure:
    - two 2x max-pool encoder steps,
    - a dilation-2 bottom layer,
    - a decoder that upsamples and concatenates skip features.
    The decoder output is added to the projection.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels inside the U-structure.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Projection; also serves as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder (names and order define state_dict keys / parameter order).
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottom layer: dilated instead of pooled.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: inputs are [deeper features, skip] concatenated.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the RSU-4 block and return ``decoder_output + projection``."""
        base = self.rebnconvin(x)

        skip1 = self.rebnconv1(base)
        skip2 = self.rebnconv2(self.pool1(skip1))
        skip3 = self.rebnconv3(self.pool2(skip2))

        bottom = self.rebnconv4(skip3)

        up = self.rebnconv3d(torch.cat((bottom, skip3), dim=1))
        up = self.rebnconv2d(torch.cat((_upsample_like(up, skip2), skip2), dim=1))
        up = self.rebnconv1d(torch.cat((_upsample_like(up, skip1), skip1), dim=1))

        return up + base
287
### RSU-4F ###
class RSU4F(nn.Module):
    """Residual U-block of height 4 without pooling.

    Every layer is a stride-1 ``REBNCONV`` whose padding equals its dilation,
    so the spatial size never changes. Instead of pooling, the encoder grows
    its dilation 1 -> 2 -> 4 -> 8, and the decoder shrinks it back 4 -> 2 -> 1.
    No upsampling is needed. The decoder output is added to the input
    projection.

    Args:
        in_ch: number of input channels.
        mid_ch: number of channels inside the U-structure.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Projection; also serves as the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder with increasing dilation.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        # Bottom layer.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder: inputs are [deeper features, skip] concatenated.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the RSU-4F block and return ``decoder_output + projection``."""
        proj = self.rebnconvin(x)

        f1 = self.rebnconv1(proj)
        f2 = self.rebnconv2(f1)
        f3 = self.rebnconv3(f2)
        f4 = self.rebnconv4(f3)

        merged = self.rebnconv3d(torch.cat((f4, f3), dim=1))
        merged = self.rebnconv2d(torch.cat((merged, f2), dim=1))
        merged = self.rebnconv1d(torch.cat((merged, f1), dim=1))

        return merged + proj
323
324
class myrebnconv(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU block with configurable conv geometry.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        dilation: convolution dilation.
        groups: number of convolution groups.
    """

    def __init__(self, in_ch=3,
                 out_ch=1,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1):
        super().__init__()

        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        # Attribute names (conv / bn / rl) form the state_dict keys.
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_options)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU in sequence."""
        out = self.conv(x)
        out = self.bn(out)
        return self.rl(out)
347
348
class BriaRMBG(PreTrainedModel):
    """Encoder-decoder segmentation network built from RSU blocks.

    Structure:
    - A stride-2 input convolution.
    - Six RSU encoder stages joined by 2x max pooling.
    - A decoder that, stage by stage, bilinearly upsamples the deeper output
      and concatenates it with the matching encoder feature.
    - Six side heads. Each predicts a ``config.out_ch``-channel map, which is
      resized to the input's spatial size and passed through a sigmoid.

    Args:
        config: configuration providing ``in_ch`` and ``out_ch``. When omitted,
            a fresh ``RMBGConfig()`` is created for this instance.
    """

    config_class = RMBGConfig

    def __init__(self, config=None):
        # Build the default config per call. A signature default of
        # ``RMBGConfig()`` is evaluated once at import time, so every model
        # constructed without an explicit config would share one object.
        if config is None:
            config = RMBGConfig()
        super().__init__(config)
        in_ch = config.in_ch
        out_ch = config.out_ch

        # Stride-2 stem: roughly halves the input resolution before stage 1.
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # Not used by forward(); kept so the ``pool_in`` attribute still
        # exists. MaxPool2d has no parameters, so it adds nothing to the
        # state_dict.
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # encoder
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder: each input is [upsampled deeper output, encoder skip]
        # concatenated along channels (e.g. 512 + 512 = 1024 for stage5d).
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side heads: 3x3 conv to out_ch on each decoder output and on stage 6.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: input batch with ``config.in_ch`` channels. It must be 4-D
                (batch, channels, height, width), because the side outputs
                are resized with bilinear interpolation.

        Returns:
            A pair ``(side_maps, features)``.

            ``side_maps`` holds six sigmoid-activated maps, each resized to
            the spatial size of ``x``. They are ordered finest decoder level
            first, stage 6 last.

            ``features`` holds the corresponding un-resized feature maps
            ``[hx1d, hx2d, hx3d, hx4d, hx5d, hx6]``.
        """
        hxin = self.conv_in(x)

        # encoder
        hx1 = self.stage1(hxin)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))

        # decoder
        hx5d = self.stage5d(torch.cat((_upsample_like(hx6, hx5), hx5), 1))
        hx4d = self.stage4d(torch.cat((_upsample_like(hx5d, hx4), hx4), 1))
        hx3d = self.stage3d(torch.cat((_upsample_like(hx4d, hx3), hx3), 1))
        hx2d = self.stage2d(torch.cat((_upsample_like(hx3d, hx2), hx2), 1))
        hx1d = self.stage1d(torch.cat((_upsample_like(hx2d, hx1), hx1), 1))

        # side outputs
        features = [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
        heads = (self.side1, self.side2, self.side3,
                 self.side4, self.side5, self.side6)
        side_maps = [
            torch.sigmoid(_upsample_like(head(feat), x))
            for head, feat in zip(heads, features)
        ]
        return side_maps, features
458
459