autoencoder_kl_3d.py
31.5 KB · 794 lines · python Raw
1 # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
2 # you may not use this file except in compliance with the License.
3 # You may obtain a copy of the License at
4 #
5 # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
6 #
7 # Unless required by applicable law or agreed to in writing, software
8 # distributed under the License is distributed on an "AS IS" BASIS,
9 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 # See the License for the specific language governing permissions and
11 # limitations under the License.
12 # ==============================================================================
13
14 from dataclasses import dataclass
15 from typing import Tuple, Optional
16 import math
17 import random
18 import numpy as np
19 from einops import rearrange
20 import torch
21 from torch import Tensor, nn
22 import torch.nn.functional as F
23
24 from diffusers.configuration_utils import ConfigMixin, register_to_config
25 from diffusers.models.modeling_outputs import AutoencoderKLOutput
26 from diffusers.models.modeling_utils import ModelMixin
27 from diffusers.utils.torch_utils import randn_tensor
28 from diffusers.utils import BaseOutput
29
30
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterised by concatenated mean and log-variance.

    ``parameters`` is split in two halves along the channel axis: dim 2 for 3-D inputs
    ``(B, L, C)`` and dim 1 for 4-D ``(B, C, H, W)`` / 5-D ``(B, C, T, H, W)`` inputs.
    The first half is the mean, the second half the log-variance (clamped to [-30, 20]).

    Args:
        parameters: Tensor holding ``2 * C`` channels along the channel axis.
        deterministic: If True, ``std`` and ``var`` are zero, so ``sample()`` returns the
            mean and ``kl()`` / ``nll()`` return a zero tensor.

    Raises:
        NotImplementedError: If ``parameters`` is not 3-, 4- or 5-dimensional.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        if parameters.ndim == 3:
            dim = 2  # (B, L, C)
        elif parameters.ndim in (4, 5):
            dim = 1  # (B, C, T, H ,W) / (B, C, H, W)
        else:
            raise NotImplementedError
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
        # Clamp so that exp() below cannot overflow / underflow.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def _zero(self) -> torch.Tensor:
        """Return a 1-element zero tensor on the same device/dtype as the parameters."""
        return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        """KL divergence to ``other`` (or to N(0, I) when ``other`` is None).

        Summed over every non-batch dimension, so the result has shape ``(B,)``.
        Returns a zero tensor on the parameters' device when ``deterministic``.
        """
        if self.deterministic:
            return self._zero()
        reduce_dim = list(range(1, self.mean.ndim))
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=reduce_dim,
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var +
            self.var / other.var -
            1.0 -
            self.logvar +
            other.logvar,
            dim=reduce_dim,
        )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        """Gaussian negative log-likelihood of ``sample``, summed over ``dims``.

        Returns a zero tensor on the parameters' device when ``deterministic``.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        """Return the distribution mode, i.e. the mean."""
        return self.mean
92
93
@dataclass
class DecoderOutput(BaseOutput):
    """Result of decoding (also returned by ``AutoencoderKLConv3D.forward``).

    Attributes:
        sample: The decoded tensor.
        posterior: The encoder's latent distribution, when the caller has one to report.
    """

    sample: torch.FloatTensor  # decoded output
    posterior: Optional[DiagonalGaussianDistribution] = None  # encoder posterior, if any
98
99
def swish(x: Tensor) -> Tensor:
    """Swish / SiLU activation: ``x * sigmoid(x)``."""
    gate = x.sigmoid()
    return x * gate
102
103
def forward_with_checkpointing(module, *inputs, use_checkpointing=False):
    """Call ``module(*inputs)``, optionally under activation checkpointing.

    With ``use_checkpointing=True`` the module's intermediate activations are not kept;
    they are recomputed during backward via non-reentrant
    ``torch.utils.checkpoint.checkpoint``, trading compute for memory.

    Args:
        module: Any callable, typically an ``nn.Module``.
        *inputs: Positional arguments forwarded to ``module``.
        use_checkpointing: Whether to enable activation checkpointing.

    Returns:
        Whatever ``module(*inputs)`` returns.
    """
    if not use_checkpointing:
        return module(*inputs)
    # Import the submodule explicitly: a bare ``import torch`` does not guarantee that
    # ``torch.utils.checkpoint`` is loaded as an attribute.
    from torch.utils.checkpoint import checkpoint
    return checkpoint(module, *inputs, use_reentrant=False)
114
115
class Conv3d(nn.Conv3d):
    """
    nn.Conv3d that splits very large inputs along the temporal axis to bound peak memory.

    When the estimated activation size exceeds ``memory_limit_gb``, the input is cut into
    temporal chunks. Each chunk is extended with ``padding[0]`` halo frames read from the
    full input (zeros beyond the sequence borders) and convolved without temporal padding.
    The concatenated result matches ``nn.Conv3d`` up to floating-point differences
    (documented as within 1e-5).

    Chunking is only exact for zero padding, temporal stride 1, temporal dilation 1 and
    symmetric "same" temporal padding (``kernel_size[0] == 2 * padding[0] + 1``). Every
    other configuration always takes the regular ``nn.Conv3d`` path.
    """

    # Chunking threshold in GiB. The estimate counts the C*T*H*W elements of one sample
    # at 2 bytes each (i.e. it assumes a 16-bit dtype and ignores the batch size).
    memory_limit_gb: float = 2.0

    def _supports_temporal_chunking(self) -> bool:
        """Return True when splitting along T reproduces the unsplit convolution exactly."""
        if self.padding_mode != "zeros" or isinstance(self.padding, str):
            return False
        return (
            self.stride[0] == 1
            and self.dilation[0] == 1
            and self.kernel_size[0] == 2 * self.padding[0] + 1
        )

    def forward(self, input):
        B, C, T, H, W = input.shape
        memory_count = (C * T * H * W) * 2 / 1024**3
        if memory_count <= self.memory_limit_gb or not self._supports_temporal_chunking():
            return super().forward(input)

        n_split = math.ceil(memory_count / self.memory_limit_gb)
        assert n_split >= 2
        # Same partition as torch.chunk(input, n_split, dim=-3).
        chunk_size = math.ceil(T / n_split)
        pad_t = self.padding[0]
        # Temporal padding is provided by the halo frames, so only pad spatially here.
        spatial_padding = (0, self.padding[1], self.padding[2])

        outputs = []
        for start in range(0, T, chunk_size):
            end = min(start + chunk_size, T)
            # Read halo frames directly from the full input; zero-pad past the borders.
            lo = max(start - pad_t, 0)
            hi = min(end + pad_t, T)
            piece = input[:, :, lo:hi]
            missing_front = pad_t - (start - lo)
            missing_back = pad_t - (hi - end)
            if missing_front or missing_back:
                piece = F.pad(piece, (0, 0, 0, 0, missing_front, missing_back))
            # Call the functional conv instead of temporarily mutating self.padding,
            # so the module is never left in a modified state.
            outputs.append(
                F.conv3d(piece, self.weight, self.bias, self.stride,
                         spatial_padding, self.dilation, self.groups)
            )
        return torch.cat(outputs, dim=-3)
154
155
class AttnBlock(nn.Module):
    """Single-head self-attention over all (frame, height, width) positions.

    Computed with ``torch.nn.functional.scaled_dot_product_attention``; the projected
    attention output is added back onto the input as a residual.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        # 1x1x1 convolutions act as per-position linear projections.
        self.q = Conv3d(in_channels, in_channels, kernel_size=1)
        self.k = Conv3d(in_channels, in_channels, kernel_size=1)
        self.v = Conv3d(in_channels, in_channels, kernel_size=1)
        self.proj_out = Conv3d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Attend over the normalised input; the result has the same shape as the input."""
        normed = self.norm(h_)
        batch, channels, frames, height, width = normed.shape

        def to_tokens(t: Tensor) -> Tensor:
            # (B, C, F, H, W) -> (B, 1, F*H*W, C): one head, every position is a token.
            return t.flatten(2).transpose(1, 2).unsqueeze(1).contiguous()

        query = to_tokens(self.q(normed))
        key = to_tokens(self.k(normed))
        value = to_tokens(self.v(normed))
        attended = nn.functional.scaled_dot_product_attention(query, key, value)

        # (B, 1, F*H*W, C) -> (B, C, F, H, W)
        return attended.squeeze(1).transpose(1, 2).reshape(batch, channels, frames, height, width)

    def forward(self, x: Tensor) -> Tensor:
        residual = x
        return residual + self.proj_out(self.attention(x))
185
186
class ResnetBlock(nn.Module):
    """Pre-activation residual block: two (GroupNorm -> swish -> 3x3x3 conv) stages plus a shortcut.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; ``None`` (the default) means
            ``in_channels``. When the two differ, a 1x1x1 convolution projects the shortcut.
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        # Project the shortcut only when the channel count changes.
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + h
214
215
class Downsample(nn.Module):
    """Strided 3x3x3 convolution that halves H and W, and also T when requested.

    Padding is applied manually because it is asymmetric. Spatially, one zero row/column
    is added at the bottom/right. Along T, replicate padding adds one frame at the end
    when downsampling T, or one frame on each side otherwise.
    """

    def __init__(self, in_channels: int, add_temporal_downsample: bool = True):
        super().__init__()
        self.add_temporal_downsample = add_temporal_downsample
        t_stride = 2 if add_temporal_downsample else 1
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = Conv3d(in_channels, in_channels, kernel_size=3, stride=(t_stride, 2, 2), padding=0)

    def forward(self, x: Tensor):
        # F.pad pairs run from the last dim backwards: (W_l, W_r, H_t, H_b, T_front, T_back).
        x = nn.functional.pad(x, (0, 1, 0, 1, 0, 0), mode="constant", value=0)
        t_front = 0 if self.add_temporal_downsample else 1
        x = nn.functional.pad(x, (0, 0, 0, 0, t_front, 1), mode="replicate")
        return self.conv(x)
233
234
class DownsampleDCAE(nn.Module):
    """Space(-time)-to-channel downsampling with a parameter-free residual shortcut.

    A 3x3x3 conv produces ``out_channels // factor`` channels. 2x2 spatial blocks (and
    2-frame temporal blocks when ``add_temporal_downsample``) are then folded into
    channels, yielding ``out_channels``. The shortcut folds the input the same way and
    averages groups of ``group_size`` adjacent channels so its width matches.
    """

    def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
        super().__init__()
        t_factor = 2 if add_temporal_downsample else 1
        factor = t_factor * 2 * 2
        assert out_channels % factor == 0
        self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)

        self.add_temporal_downsample = add_temporal_downsample
        self.group_size = factor * in_channels // out_channels

    def forward(self, x: Tensor):
        r1 = 2 if self.add_temporal_downsample else 1
        pattern = "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w"
        h = rearrange(self.conv(x), pattern, r1=r1, r2=2, r3=2)
        shortcut = rearrange(x, pattern, r1=r1, r2=2, r3=2)

        # Average each run of ``group_size`` consecutive channels down to h's width.
        B, C, T, H, W = shortcut.shape
        shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
        return h + shortcut
254
255
class Upsample(nn.Module):
    """Nearest-neighbour upsampling (x2 in H and W, optionally T) followed by a 3x3x3 conv."""

    def __init__(self, in_channels: int, add_temporal_upsample: bool = True):
        super().__init__()
        self.add_temporal_upsample = add_temporal_upsample
        t_scale = 2 if add_temporal_upsample else 1
        self.scale_factor = (t_scale, 2, 2)  # THW
        self.conv = Conv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        upsampled = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
        return self.conv(upsampled)
267
268
class UpsampleDCAE(nn.Module):
    """Channel-to-space(-time) upsampling with a parameter-free residual shortcut.

    A 3x3x3 conv produces ``out_channels * factor`` channels, which are unfolded into
    2x2 spatial blocks (and 2-frame temporal blocks when ``add_temporal_upsample``).
    The shortcut repeats each input channel ``repeats`` times and unfolds the same way.
    """

    def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
        super().__init__()
        t_factor = 2 if add_temporal_upsample else 1
        factor = t_factor * 2 * 2
        self.conv = Conv3d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)

        self.add_temporal_upsample = add_temporal_upsample
        self.repeats = factor * out_channels // in_channels

    def forward(self, x: Tensor):
        r1 = 2 if self.add_temporal_upsample else 1
        pattern = "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)"
        h = rearrange(self.conv(x), pattern, r1=r1, r2=2, r3=2)
        widened = x.repeat_interleave(repeats=self.repeats, dim=1)
        shortcut = rearrange(widened, pattern, r1=r1, r2=2, r3=2)
        return h + shortcut
285
286
class Encoder(nn.Module):
    """
    The encoder network of AutoencoderKLConv3D.

    Maps an input with ``in_channels`` channels to ``2 * z_channels`` channels (the output
    width of ``conv_out``). ``AutoencoderKLConv3D.encode`` feeds this output to
    ``DiagonalGaussianDistribution``.

    Each entry of ``block_out_channels`` is one level of ``num_res_blocks`` ResnetBlocks.
    Levels with ``i_level < log2(ffactor_spatial)`` end with a ``DownsampleDCAE``, which
    always halves H and W. That step also halves T when
    ``i_level >= log2(ffactor_spatial // ffactor_temporal)``. Assuming both factors are
    powers of two, this gives the last ``log2(ffactor_temporal)`` downsampling levels.

    Args:
        in_channels: Channels of the input tensor.
        z_channels: Latent channels; the output has ``2 * z_channels`` channels.
        block_out_channels: Channel width of each level.
        num_res_blocks: ResnetBlocks per level.
        ffactor_spatial: Total spatial downsampling factor.
        ffactor_temporal: Total temporal downsampling factor.
        downsample_match_channel: If True, a downsampling step switches to the next
            level's width; otherwise it keeps the current width.
    """
    def __init__(
        self,
        in_channels: int,
        z_channels: int,
        block_out_channels: Tuple[int, ...],
        num_res_blocks: int,
        ffactor_spatial: int,
        ffactor_temporal: int,
        downsample_match_channel: bool = True,
    ):
        super().__init__()
        # The output shortcut averages channel groups, so the last width must divide evenly.
        assert block_out_channels[-1] % (2 * z_channels) == 0

        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks

        # downsampling
        self.conv_in = Conv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        # NOTE: submodule registration order below defines state_dict keys
        # (down.<level>.block.<i>, down.<level>.downsample); keep it stable.
        self.down = nn.ModuleList()
        block_in = block_out_channels[0]
        for i_level, ch in enumerate(block_out_channels):
            block = nn.ModuleList()
            block_out = ch
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block

            add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial))
            add_temporal_downsample = (add_spatial_downsample and
                                       bool(i_level >= np.log2(ffactor_spatial // ffactor_temporal)))
            if add_spatial_downsample or add_temporal_downsample:
                # Downsampling switches to the next level's width, so a next level must exist.
                assert i_level < len(block_out_channels) - 1
                block_out = block_out_channels[i_level + 1] if downsample_match_channel else block_in
                down.downsample = DownsampleDCAE(block_in, block_out, add_temporal_downsample)
                block_in = block_out
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = Conv3d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

        # Toggled by AutoencoderKLConv3D._set_gradient_checkpointing.
        self.gradient_checkpointing = False

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` into ``2 * z_channels`` output channels."""
        # Checkpointing is only applied in training mode.
        use_checkpointing = bool(self.training and self.gradient_checkpointing)

        # downsampling
        h = self.conv_in(x)
        for i_level in range(len(self.block_out_channels)):
            for i_block in range(self.num_res_blocks):
                h = forward_with_checkpointing(
                    self.down[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
            if hasattr(self.down[i_level], "downsample"):
                h = forward_with_checkpointing(self.down[i_level].downsample, h, use_checkpointing=use_checkpointing)

        # middle
        h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
        h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
        h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)

        # end
        # Parameter-free shortcut: average groups of ``group_size`` adjacent channels so the
        # last level's width matches conv_out's 2 * z_channels output.
        group_size = self.block_out_channels[-1] // (2 * self.z_channels)
        shortcut = rearrange(h, "b (c r) f h w -> b c r f h w", r=group_size).mean(dim=2)
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        h += shortcut
        return h
369
370
class Decoder(nn.Module):
    """
    The decoder network of AutoencoderKLConv3D.

    Maps a latent with ``z_channels`` channels to ``out_channels`` channels. ``conv_in``
    lifts the latent to ``block_out_channels[0]`` and is added to a channel-repeated copy
    of the latent. A mid block (ResnetBlock, attention, ResnetBlock) follows.

    Each entry of ``block_out_channels`` is then one level of ``num_res_blocks + 1``
    ResnetBlocks. A level gets an ``UpsampleDCAE`` when
    ``i_level < log2(ffactor_spatial)`` or ``i_level < log2(ffactor_temporal)``. That step
    always doubles H and W, and doubles T only when ``i_level < log2(ffactor_temporal)``.

    Args:
        z_channels: Latent channels.
        out_channels: Channels of the decoded output.
        block_out_channels: Channel width of each level (the caller passes the encoder's
            widths reversed).
        num_res_blocks: Used to build ``num_res_blocks + 1`` ResnetBlocks per level.
        ffactor_spatial: Total spatial upsampling factor.
        ffactor_temporal: Total temporal upsampling factor.
        upsample_match_channel: If True, an upsampling step switches to the next level's
            width; otherwise it keeps the current width.
    """
    def __init__(
        self,
        z_channels: int,
        out_channels: int,
        block_out_channels: Tuple[int, ...],
        num_res_blocks: int,
        ffactor_spatial: int,
        ffactor_temporal: int,
        upsample_match_channel: bool = True,
    ):
        super().__init__()
        # The input shortcut repeats latent channels, so the first width must be a multiple.
        assert block_out_channels[0] % z_channels == 0

        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks

        # z to block_in
        block_in = block_out_channels[0]
        self.conv_in = Conv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        # NOTE: submodule registration order below defines state_dict keys
        # (up.<level>.block.<i>, up.<level>.upsample); keep it stable.
        self.up = nn.ModuleList()
        for i_level, ch in enumerate(block_out_channels):
            block = nn.ModuleList()
            block_out = ch
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block

            add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial))
            add_temporal_upsample = bool(i_level < np.log2(ffactor_temporal))
            if add_spatial_upsample or add_temporal_upsample:
                # Upsampling switches to the next level's width, so a next level must exist.
                assert i_level < len(block_out_channels) - 1
                block_out = block_out_channels[i_level + 1] if upsample_match_channel else block_in
                up.upsample = UpsampleDCAE(block_in, block_out, add_temporal_upsample)
                block_in = block_out
            self.up.append(up)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = Conv3d(block_in, out_channels, kernel_size=3, stride=1, padding=1)

        # Toggled by AutoencoderKLConv3D._set_gradient_checkpointing.
        self.gradient_checkpointing = False

    def forward(self, z: Tensor) -> Tensor:
        """Decode latent ``z`` into ``out_channels`` output channels."""
        # Checkpointing is only applied in training mode.
        use_checkpointing = bool(self.training and self.gradient_checkpointing)

        # z to block_in
        # Parameter-free shortcut: repeat each latent channel ``repeats`` times to match
        # conv_in's output width.
        repeats = self.block_out_channels[0] // (self.z_channels)
        h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)

        # middle
        h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
        h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
        h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)

        # upsampling
        for i_level in range(len(self.block_out_channels)):
            for i_block in range(self.num_res_blocks + 1):
                h = forward_with_checkpointing(self.up[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
            if hasattr(self.up[i_level], "upsample"):
                h = forward_with_checkpointing(self.up[i_level].upsample, h, use_checkpointing=use_checkpointing)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
452
453
class AutoencoderKLConv3D(ModelMixin, ConfigMixin):
    """
    Autoencoder model with KL-regularized latent space based on 3D convolutions.

    Videos are handled as ``(B, C, T, H, W)``. ``encode`` also accepts 4-D image batches
    and broadcasts the single frame to ``ffactor_temporal`` frames. The encoder reduces
    H/W by ``ffactor_spatial`` and T by ``ffactor_temporal``; the decoder inverts this.

    Optional memory-saving modes, all disabled by default:

    * slicing (``enable_slicing``): process the batch ``slicing_bsz`` samples at a time;
    * spatial / temporal tiling (``enable_spatial_tiling`` / ``enable_temporal_tiling``):
      process overlapping tiles and linearly blend the overlaps. Tile sizes default to
      ``sample_size`` (H/W) and ``sample_tsize`` (T) in input space.
    """
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        latent_channels: int,
        block_out_channels: Tuple[int, ...],
        layers_per_block: int,
        ffactor_spatial: int,
        ffactor_temporal: int,
        sample_size: int,
        sample_tsize: int,
        scaling_factor: Optional[float] = None,
        shift_factor: Optional[float] = None,
        downsample_match_channel: bool = True,
        upsample_match_channel: bool = True,
        only_encoder: bool = False,  # only build encoder for saving memory
        only_decoder: bool = False,  # only build decoder for saving memory
    ):
        super().__init__()
        self.ffactor_spatial = ffactor_spatial
        self.ffactor_temporal = ffactor_temporal
        self.scaling_factor = scaling_factor
        self.shift_factor = shift_factor

        # build model
        if not only_decoder:
            self.encoder = Encoder(
                in_channels=in_channels,
                z_channels=latent_channels,
                block_out_channels=block_out_channels,
                num_res_blocks=layers_per_block,
                ffactor_spatial=ffactor_spatial,
                ffactor_temporal=ffactor_temporal,
                downsample_match_channel=downsample_match_channel,
            )
        if not only_encoder:
            self.decoder = Decoder(
                z_channels=latent_channels,
                out_channels=out_channels,
                block_out_channels=list(reversed(block_out_channels)),
                num_res_blocks=layers_per_block,
                ffactor_spatial=ffactor_spatial,
                ffactor_temporal=ffactor_temporal,
                upsample_match_channel=upsample_match_channel,
            )

        # slicing and tiling related
        self.use_slicing = False
        self.slicing_bsz = 1  # samples per slice when slicing is enabled
        self.use_spatial_tiling = False
        self.use_temporal_tiling = False
        # NOTE(review): not read inside this class; presumably consumed by training code.
        self.use_tiling_during_training = False

        # only relevant if vae tiling is enabled (input-space sizes and their latent sizes)
        self.tile_sample_min_size = sample_size
        self.tile_latent_min_size = sample_size // ffactor_spatial
        self.tile_sample_min_tsize = sample_tsize
        self.tile_latent_min_tsize = sample_tsize // ffactor_temporal
        self.tile_overlap_factor = 0.25

        # use torch.compile for faster encode speed
        self.use_compile = False

    def _set_gradient_checkpointing(self, module, value=False):
        # NOTE(review): uses the older diffusers (module, value) hook signature; confirm it
        # matches the installed diffusers version.
        if isinstance(module, (Encoder, Decoder)):
            module.gradient_checkpointing = value

    def enable_tiling_during_training(self, use_tiling: bool = True):
        self.use_tiling_during_training = use_tiling

    def disable_tiling_during_training(self):
        self.enable_tiling_during_training(False)

    def enable_temporal_tiling(self, use_tiling: bool = True):
        self.use_temporal_tiling = use_tiling

    def disable_temporal_tiling(self):
        self.enable_temporal_tiling(False)

    def enable_spatial_tiling(self, use_tiling: bool = True):
        self.use_spatial_tiling = use_tiling

    def disable_spatial_tiling(self):
        self.enable_spatial_tiling(False)

    def enable_tiling(self, use_tiling: bool = True):
        """Enable (or disable) spatial tiling only; temporal tiling is controlled separately."""
        self.enable_spatial_tiling(use_tiling)

    def disable_tiling(self):
        self.disable_spatial_tiling()

    def enable_slicing(self):
        self.use_slicing = True

    def disable_slicing(self):
        self.use_slicing = False

    def _split_batch(self, x: torch.Tensor):
        """Split ``x`` along the batch dim into slices of at most ``slicing_bsz`` samples."""
        return x.split(max(1, self.slicing_bsz))

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
        """Linearly cross-fade the last columns (W) of ``a`` into the first of ``b``; modifies ``b`` in place."""
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = \
                a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
        return b

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
        """Linearly cross-fade the last rows (H) of ``a`` into the first of ``b``; modifies ``b`` in place."""
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = \
                a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
        return b

    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
        """Linearly cross-fade the last frames (T) of ``a`` into the first of ``b``; modifies ``b`` in place."""
        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
        for x in range(blend_extent):
            b[:, :, x, :, :] = \
                a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
        return b

    def spatial_tiled_encode(self, x: torch.Tensor):
        """ spatial tiling for frames: encode overlapping H/W tiles and blend them """
        B, C, T, H, W = x.shape
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))  # e.g. 256 * (1 - 0.25) = 192
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)  # e.g. 8 * 0.25 = 2
        row_limit = self.tile_latent_min_size - blend_extent  # e.g. 8 - 2 = 6

        rows = []
        for i in range(0, H, overlap_size):
            row = []
            for j in range(0, W, overlap_size):
                tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))
        moments = torch.cat(result_rows, dim=-2)
        return moments

    def temporal_tiled_encode(self, x: torch.Tensor):
        """ temporal tiling for frames: encode overlapping T tiles and blend them """
        B, C, T, H, W = x.shape
        overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))  # e.g. 64 * (1 - 0.25) = 48
        blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)  # e.g. 8 * 0.25 = 2
        t_limit = self.tile_latent_min_tsize - blend_extent  # e.g. 8 - 2 = 6

        row = []
        for i in range(0, T, overlap_size):
            tile = x[:, :, i: i + self.tile_sample_min_tsize, :, :]
            if self.use_spatial_tiling and (
                    tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
                tile = self.spatial_tiled_encode(tile)
            else:
                tile = self.encoder(tile)
            row.append(tile)
        result_row = []
        for i, tile in enumerate(row):
            if i > 0:
                tile = self.blend_t(row[i - 1], tile, blend_extent)
            result_row.append(tile[:, :, :t_limit, :, :])
        moments = torch.cat(result_row, dim=-3)
        return moments

    def spatial_tiled_decode(self, z: torch.Tensor):
        """ spatial tiling for frames: decode overlapping H/W latent tiles and blend them """
        B, C, T, H, W = z.shape
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))  # e.g. 8 * (1 - 0.25) = 6
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)  # e.g. 256 * 0.25 = 64
        row_limit = self.tile_sample_min_size - blend_extent  # e.g. 256 - 64 = 192

        rows = []
        for i in range(0, H, overlap_size):
            row = []
            for j in range(0, W, overlap_size):
                tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))
        dec = torch.cat(result_rows, dim=-2)
        return dec

    def temporal_tiled_decode(self, z: torch.Tensor):
        """ temporal tiling for frames: decode overlapping T latent tiles and blend them """
        B, C, T, H, W = z.shape
        overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))  # e.g. 8 * (1 - 0.25) = 6
        blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)  # e.g. 64 * 0.25 = 16
        t_limit = self.tile_sample_min_tsize - blend_extent  # e.g. 64 - 16 = 48
        assert 0 < overlap_size < self.tile_latent_min_tsize

        row = []
        for i in range(0, T, overlap_size):
            tile = z[:, :, i: i + self.tile_latent_min_tsize, :, :]
            if self.use_spatial_tiling and (
                    tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
                decoded = self.spatial_tiled_decode(tile)
            else:
                decoded = self.decoder(tile)
            row.append(decoded)

        result_row = []
        for i, tile in enumerate(row):
            if i > 0:
                tile = self.blend_t(row[i - 1], tile, blend_extent)
            result_row.append(tile[:, :, :t_limit, :, :])
        dec = torch.cat(result_row, dim=-3)
        return dec

    def encode(self, x: Tensor, return_dict: bool = True):
        """
        Encodes the input by passing through the encoder network.
        Support slicing and tiling for memory efficiency.

        Args:
            x: ``(B, C, T, H, W)`` video or ``(B, C, H, W)`` image batch. A single frame is
                broadcast to ``ffactor_temporal`` frames. Otherwise ``T`` must be a multiple
                of ``ffactor_temporal`` and must not equal ``ffactor_temporal`` itself.
            return_dict: Return an ``AutoencoderKLOutput`` instead of a tuple.

        Returns:
            ``AutoencoderKLOutput(latent_dist=posterior)`` or ``(posterior,)``.
        """
        def _encode(x):
            if self.use_temporal_tiling and x.shape[-3] > self.tile_sample_min_tsize:
                return self.temporal_tiled_encode(x)
            if self.use_spatial_tiling and (
                    x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
                return self.spatial_tiled_encode(x)

            if self.use_compile:
                @torch.compile
                def encoder(x):
                    return self.encoder(x)
                return encoder(x)
            return self.encoder(x)

        if len(x.shape) != 5:  # (B, C, H, W) -> (B, C, 1, H, W)
            x = x[:, :, None]
        assert len(x.shape) == 5  # (B, C, T, H, W)
        if x.shape[2] == 1:
            x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
        else:
            assert x.shape[2] != self.ffactor_temporal and x.shape[2] % self.ffactor_temporal == 0

        if self.use_slicing and x.shape[0] > 1:
            h = torch.cat([_encode(x_slice) for x_slice in self._split_batch(x)])
        else:
            h = _encode(x)
        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z: Tensor, return_dict: bool = True, generator=None):
        """
        Decodes the input by passing through the decoder network.
        Support slicing and tiling for memory efficiency.

        Args:
            z: Latent tensor ``(B, latent_channels, T, H, W)``. When ``T == 1`` only the
                last decoded frame is returned.
            return_dict: Return a ``DecoderOutput`` instead of a tuple.
            generator: Unused; accepted for API compatibility.

        Returns:
            ``DecoderOutput(sample=decoded)`` or ``(decoded,)``.
        """
        def _decode(z):
            if self.use_temporal_tiling and z.shape[-3] > self.tile_latent_min_tsize:
                return self.temporal_tiled_decode(z)
            if self.use_spatial_tiling and (
                    z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
                return self.spatial_tiled_decode(z)
            return self.decoder(z)

        if self.use_slicing and z.shape[0] > 1:
            # Honour slicing_bsz exactly like encode() does.
            decoded = torch.cat([_decode(z_slice) for z_slice in self._split_batch(z)])
        else:
            decoded = _decode(z)

        # Single-frame latents came from images broadcast over T in encode(); keep one frame.
        if z.shape[-3] == 1:
            decoded = decoded[:, :, -1:]

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_posterior: bool = True,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ):
        """Encode then decode ``sample``.

        Args:
            sample: Input video/image batch (see ``encode``).
            sample_posterior: Decode a posterior sample instead of the posterior mode.
            return_posterior: Include the posterior in the output.
            return_dict: Return a ``DecoderOutput`` instead of a tuple.
            generator: Optional RNG used when ``sample_posterior`` is True.

        Returns:
            ``DecoderOutput(sample, posterior)``, or ``(dec, posterior)`` / ``(dec,)``
            when ``return_dict`` is False.
        """
        posterior = self.encode(sample).latent_dist
        z = posterior.sample(generator=generator) if sample_posterior else posterior.mode()
        dec = self.decode(z).sample
        if return_dict:
            return DecoderOutput(sample=dec, posterior=posterior if return_posterior else None)
        return (dec, posterior) if return_posterior else (dec,)

    def random_reset_tiling(self, x: torch.Tensor):
        """Randomly re-configure (or disable) spatial and temporal tiling for input ``x``.

        Single-frame inputs disable both tilings. Otherwise each tiling is independently
        disabled, or enabled with a tile of 1x, 2x or 3x its minimum size. The minimum is
        ``int(1 / tile_overlap_factor)`` times the corresponding downsampling factor.
        """
        if x.shape[-3] == 1:
            self.disable_spatial_tiling()
            self.disable_temporal_tiling()
            return

        # Use fixed shape here
        min_sample_size = int(1 / self.tile_overlap_factor) * self.ffactor_spatial
        min_sample_tsize = int(1 / self.tile_overlap_factor) * self.ffactor_temporal
        sample_size = random.choice([None, 1 * min_sample_size, 2 * min_sample_size, 3 * min_sample_size])
        if sample_size is None:
            self.disable_spatial_tiling()
        else:
            self.tile_sample_min_size = sample_size
            self.tile_latent_min_size = sample_size // self.ffactor_spatial
            self.enable_spatial_tiling()

        sample_tsize = random.choice([None, 1 * min_sample_tsize, 2 * min_sample_tsize, 3 * min_sample_tsize])
        if sample_tsize is None:
            self.disable_temporal_tiling()
        else:
            self.tile_sample_min_tsize = sample_tsize
            self.tile_latent_min_tsize = sample_tsize // self.ffactor_temporal
            self.enable_temporal_tiling()
794