# EasyAnimate: easyanimate/vae/ldm/models/omnigen_enc_dec.py
# Last commit 43ed08d by bubbliiiing ("add requirements"); file size 17 kB.
import torch
import torch.nn as nn
import numpy as np
from ..modules.vaemodules.activations import get_activation
from ..modules.vaemodules.common import CausalConv3d
from ..modules.vaemodules.down_blocks import get_down_block
from ..modules.vaemodules.mid_blocks import get_mid_block
from ..modules.vaemodules.up_blocks import get_up_block
class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    The per-level channel widths are derived as ``block_out_channels = [ch * m for m in ch_mult]``.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 8):
            The number of latent channels. The final convolution emits ``2 * out_channels`` when `double_z` is True.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("SpatialDownBlock3D",)`):
            The types of down blocks to use, one per entry of `ch_mult`. NOTE: the default has length 1 while the
            default `ch_mult` has length 4, so at least one of them must be overridden or the length assertion in
            `__init__` fails.
        ch (`int`, *optional*, defaults to 128):
            Base channel width.
        ch_mult (`Sequence[int]`, *optional*, defaults to `(1, 2, 4, 4)`):
            Per-level multipliers applied to `ch`.
        use_gc_blocks (`Tuple[bool, ...]`, *optional*, defaults to `None`):
            Whether to use global context blocks for each down block. `None` disables them for every block.
        mid_block_type (`str`, *optional*, defaults to `"MidBlock3D"`):
            The type of mid block to use.
        mid_block_use_attention (`bool`, *optional*, defaults to `True`):
            Forwarded to the mid block as `add_attention`.
        mid_block_attention_type (`str`, *optional*, defaults to `"3d"`):
            Forwarded to the mid block as `attention_type`.
        mid_block_num_attention_heads (`int`, *optional*, defaults to 1):
            Number of attention heads forwarded to the mid block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `get_activation` (from `..modules.vaemodules.activations`)
            for available options.
        num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads forwarded to each down block.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels of the final convolution.
        slice_compression_vae (`bool`, *optional*, defaults to `False`):
            If True, `forward` encodes the time axis chunk by chunk instead of in one pass.
        mini_batch_encoder (`int`, *optional*, defaults to 9):
            Number of frames per chunk when `slice_compression_vae` is True.
        verbose (`bool`, *optional*, defaults to `False`):
            Print every module whose padding mode is changed.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 8,
        down_block_types=("SpatialDownBlock3D",),
        ch: int = 128,
        ch_mult=(1, 2, 4, 4),
        use_gc_blocks=None,
        mid_block_type: str = "MidBlock3D",
        mid_block_use_attention: bool = True,
        mid_block_attention_type: str = "3d",
        mid_block_num_attention_heads: int = 1,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        num_attention_heads: int = 1,
        double_z: bool = True,
        slice_compression_vae: bool = False,
        mini_batch_encoder: int = 9,
        verbose: bool = False,
    ):
        super().__init__()
        # `ch_mult` defaults to an immutable tuple so the default is never shared mutable state.
        block_out_channels = [ch * mult for mult in ch_mult]
        assert len(down_block_types) == len(block_out_channels), (
            "Number of down block types must match number of block output channels."
        )
        if use_gc_blocks is not None:
            assert len(use_gc_blocks) == len(down_block_types), (
                "Number of GC blocks must match number of down block types."
            )
        else:
            use_gc_blocks = [False] * len(down_block_types)

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
        )

        self.down_blocks = nn.ModuleList([])
        output_channels = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channels = output_channels
            output_channels = block_out_channels[i]
            # Every level except the last one is built with a downsampling stage.
            is_final_block = i == len(block_out_channels) - 1
            down_block = get_down_block(
                down_block_type,
                in_channels=input_channels,
                out_channels=output_channels,
                num_layers=layers_per_block,
                act_fn=act_fn,
                norm_num_groups=norm_num_groups,
                norm_eps=1e-6,
                num_attention_heads=num_attention_heads,
                add_gc_block=use_gc_blocks[i],
                add_downsample=not is_final_block,
            )
            self.down_blocks.append(down_block)

        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            num_layers=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=1e-6,
            add_attention=mid_block_use_attention,
            attention_type=mid_block_attention_type,
            num_attention_heads=mid_block_num_attention_heads,
        )

        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1],
            num_groups=norm_num_groups,
            eps=1e-6,
        )
        self.conv_act = get_activation(act_fn)
        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

        self.slice_compression_vae = slice_compression_vae
        self.mini_batch_encoder = mini_batch_encoder
        # Neighbouring-frame context is disabled for the encoder; single_forward ignores it unless this is True.
        self.features_share = False
        self.verbose = verbose

    def _set_padding_flag(self, flag: int) -> None:
        """Set ``padding_flag = flag`` on every descendant module that defines a ``padding_flag`` attribute.

        The child tree is walked recursively; the root module itself is not checked. When ``self.verbose``
        is set, each updated module is printed together with its local (not fully-qualified) child name.
        """
        def _visit(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = flag
            for sub_name, sub_mod in module.named_children():
                _visit(sub_name, sub_mod)

        for name, module in self.named_children():
            _visit(name, module)

    def set_padding_one_frame(self):
        """Put all ``padding_flag``-aware submodules into mode 1 (used by `forward` for a lone first frame).

        The meaning of the mode is defined by those submodules (presumably `CausalConv3d`) — not visible here.
        """
        self._set_padding_flag(1)

    def set_padding_more_frame(self):
        """Put all ``padding_flag``-aware submodules into mode 2 (used by `forward` for multi-frame chunks)."""
        self._set_padding_flag(2)

    def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
        """Encode one clip or chunk in a single pass.

        Args:
            x: Input of shape (B, C, T, H, W).
            previous_features: Frames preceding ``x``, or None.
            after_features: Frames following ``x``, or None.

        The neighbouring frames are only consumed when ``self.features_share`` is True (it is False after
        ``__init__``): they are concatenated to ``x`` along the time axis before encoding and the matching
        latent frames are sliced off afterwards.
        """
        share = self.features_share
        has_prev = previous_features is not None
        has_after = after_features is not None

        if share and has_prev and has_after:
            x = torch.concat([previous_features, x, after_features], 2)
        elif share and has_prev:
            x = torch.concat([previous_features, x], 2)
        elif share and has_after:
            x = torch.concat([x, after_features], 2)

        x = self.conv_in(x)
        for down_block in self.down_blocks:
            x = down_block(x)
        x = self.mid_block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)

        # NOTE(review): these latent slices are hard-coded; they appear to assume each context window maps to
        # exactly one latent frame and the chunk itself to two — confirm against the down blocks' temporal stride.
        if share and has_prev and has_after:
            x = x[:, :, 1:3]
        elif share and has_prev:
            x = x[:, :, 1:]
        elif share and has_after:
            x = x[:, :, :2]
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape (B, C, T, H, W).

        Without ``slice_compression_vae`` the whole clip goes through `single_forward` at once. Otherwise the
        time axis is processed piecewise. For an odd frame count, the first frame is encoded alone in padding
        mode 1. The remaining frames are encoded in chunks of ``mini_batch_encoder`` frames in padding mode 2.
        The per-chunk outputs are concatenated along the time axis.
        """
        if not self.slice_compression_vae:
            return self.single_forward(x, None, None)

        _, _, num_frames, _, _ = x.size()
        if num_frames % 2 != 0:
            self.set_padding_one_frame()
            first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
            self.set_padding_more_frame()
            outputs = [first_frames]
            start_index = 1
        else:
            self.set_padding_more_frame()
            outputs = []
            start_index = 0

        chunk = self.mini_batch_encoder
        previous_features = None
        for i in range(start_index, num_frames, chunk):
            end = i + chunk
            # Up to 4 frames on either side are handed over as context; single_forward only uses them when
            # features_share is True. NOTE(review): `end - 4` can go negative when chunk < 4.
            after_features = x[:, :, end:end + 4, :, :] if end < num_frames else None
            outputs.append(self.single_forward(x[:, :, i:end, :, :], previous_features, after_features))
            previous_features = x[:, :, end - 4:end, :, :]
        return torch.cat(outputs, dim=2)
class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    The per-level channel widths are ``block_out_channels = [ch * m for m in ch_mult]``. The up blocks walk them
    in reverse order, starting from ``block_out_channels[-1]``.

    Args:
        in_channels (`int`, *optional*, defaults to 8):
            The number of input (latent) channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("SpatialUpBlock3D",)`):
            The types of up blocks to use, one per entry of `ch_mult`. NOTE: the default has length 1 while the
            default `ch_mult` has length 4, so at least one of them must be overridden or the length assertion in
            `__init__` fails.
        ch (`int`, *optional*, defaults to 128):
            Base channel width.
        ch_mult (`Sequence[int]`, *optional*, defaults to `(1, 2, 4, 4)`):
            Per-level multipliers applied to `ch`.
        use_gc_blocks (`Tuple[bool, ...]`, *optional*, defaults to `None`):
            Whether to use global context blocks for each up block. `None` disables them for every block.
        mid_block_type (`str`, *optional*, defaults to `"MidBlock3D"`):
            The type of mid block to use.
        mid_block_use_attention (`bool`, *optional*, defaults to `True`):
            Forwarded to the mid block as `add_attention`.
        mid_block_attention_type (`str`, *optional*, defaults to `"3d"`):
            Forwarded to the mid block as `attention_type`.
        mid_block_num_attention_heads (`int`, *optional*, defaults to 1):
            Number of attention heads forwarded to the mid block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers in the mid block; each up block is built with ``layers_per_block + 1`` layers.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `get_activation` (from `..modules.vaemodules.activations`)
            for available options.
        num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads forwarded to each up block.
        slice_compression_vae (`bool`, *optional*, defaults to `False`):
            If True, `forward` decodes the time axis chunk by chunk instead of in one pass.
        mini_batch_decoder (`int`, *optional*, defaults to 3):
            Number of latent frames per chunk when `slice_compression_vae` is True.
        verbose (`bool`, *optional*, defaults to `False`):
            Print every module whose padding mode is changed.
    """

    def __init__(
        self,
        in_channels: int = 8,
        out_channels: int = 3,
        up_block_types=("SpatialUpBlock3D",),
        ch: int = 128,
        ch_mult=(1, 2, 4, 4),
        use_gc_blocks=None,
        mid_block_type: str = "MidBlock3D",
        mid_block_use_attention: bool = True,
        mid_block_attention_type: str = "3d",
        mid_block_num_attention_heads: int = 1,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        num_attention_heads: int = 1,
        slice_compression_vae: bool = False,
        mini_batch_decoder: int = 3,
        verbose: bool = False,
    ):
        super().__init__()
        # `ch_mult` defaults to an immutable tuple so the default is never shared mutable state.
        block_out_channels = [ch * mult for mult in ch_mult]
        assert len(up_block_types) == len(block_out_channels), (
            "Number of up block types must match number of block output channels."
        )
        if use_gc_blocks is not None:
            assert len(use_gc_blocks) == len(up_block_types), (
                "Number of GC blocks must match number of up block types."
            )
        else:
            use_gc_blocks = [False] * len(up_block_types)

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
        )

        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            num_layers=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=1e-6,
            add_attention=mid_block_use_attention,
            attention_type=mid_block_attention_type,
            num_attention_heads=mid_block_num_attention_heads,
        )

        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channels = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            input_channels = output_channels
            output_channels = reversed_block_out_channels[i]
            # Every level except the last one is built with an upsampling stage.
            is_final_block = i == len(block_out_channels) - 1
            up_block = get_up_block(
                up_block_type,
                in_channels=input_channels,
                out_channels=output_channels,
                num_layers=layers_per_block + 1,
                act_fn=act_fn,
                norm_num_groups=norm_num_groups,
                norm_eps=1e-6,
                num_attention_heads=num_attention_heads,
                add_gc_block=use_gc_blocks[i],
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)

        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0],
            num_groups=norm_num_groups,
            eps=1e-6,
        )
        self.conv_act = get_activation(act_fn)
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.slice_compression_vae = slice_compression_vae
        self.mini_batch_decoder = mini_batch_decoder
        # Neighbouring latent chunks are used as context by conv_in/mid_block in single_forward.
        self.features_share = True
        self.verbose = verbose

    def _set_padding_flag(self, flag: int) -> None:
        """Set ``padding_flag = flag`` on every descendant module that defines a ``padding_flag`` attribute.

        The child tree is walked recursively; the root module itself is not checked. When ``self.verbose``
        is set, each updated module is printed together with its local (not fully-qualified) child name.
        """
        def _visit(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = flag
            for sub_name, sub_mod in module.named_children():
                _visit(sub_name, sub_mod)

        for name, module in self.named_children():
            _visit(name, module)

    def set_padding_one_frame(self):
        """Put all ``padding_flag``-aware submodules into mode 1 (used by `forward` for a lone first frame).

        The meaning of the mode is defined by those submodules (presumably `CausalConv3d`) — not visible here.
        """
        self._set_padding_flag(1)

    def set_padding_more_frame(self):
        """Put all ``padding_flag``-aware submodules into mode 2 (used by `forward` for multi-frame chunks)."""
        self._set_padding_flag(2)

    def single_forward(self, x: torch.Tensor, previous_features: torch.Tensor, after_features: torch.Tensor) -> torch.Tensor:
        """Decode one latent clip or chunk in a single pass.

        Args:
            x: Latent input of shape (B, C, T, H, W).
            previous_features: Latent frames preceding ``x``, or None.
            after_features: Latent frames following ``x``, or None.

        When ``self.features_share`` is True and any neighbour is given, it is concatenated to ``x`` along
        the time axis. The result goes through ``conv_in`` and ``mid_block`` and is cut back to ``x``'s own
        frames. The up blocks and output head then process the current chunk alone.
        """
        has_prev = previous_features is not None
        has_after = after_features is not None

        if self.features_share and (has_prev or has_after):
            _, _, num_frames, _, _ = x.size()
            if has_prev and has_after:
                _, _, num_prev, _, _ = previous_features.size()
                x = torch.concat([previous_features, x, after_features], 2)
                keep = slice(num_prev, num_prev + num_frames)
            elif has_prev:
                x = torch.concat([previous_features, x], 2)
                keep = slice(-num_frames, None)
            else:
                x = torch.concat([x, after_features], 2)
                keep = slice(None, num_frames)
            x = self.conv_in(x)
            x = self.mid_block(x)
            # Drop the context frames again so only the current chunk reaches the up blocks.
            x = x[:, :, keep]
        else:
            x = self.conv_in(x)
            x = self.mid_block(x)

        for up_block in self.up_blocks:
            x = up_block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode latents ``x`` of shape (B, C, T, H, W).

        Without ``slice_compression_vae`` the whole clip goes through `single_forward` at once. Otherwise the
        time axis is processed piecewise. For an odd latent frame count, the first frame is decoded alone in
        padding mode 1. The remaining frames are decoded in chunks of ``mini_batch_decoder`` frames in padding
        mode 2. Each chunk receives the previous chunk and the next chunk as context, and the outputs are
        concatenated along the time axis.
        """
        if not self.slice_compression_vae:
            return self.single_forward(x, None, None)

        _, _, num_frames, _, _ = x.size()
        if num_frames % 2 != 0:
            self.set_padding_one_frame()
            first_frames = self.single_forward(x[:, :, 0:1, :, :], None, None)
            self.set_padding_more_frame()
            outputs = [first_frames]
            start_index = 1
        else:
            self.set_padding_more_frame()
            outputs = []
            start_index = 0

        chunk = self.mini_batch_decoder
        previous_features = None
        for i in range(start_index, num_frames, chunk):
            end = i + chunk
            # The next chunk (if any) is look-ahead context; the current chunk becomes the look-behind
            # context of the following iteration.
            after_features = x[:, :, end:end + chunk, :, :] if end < num_frames else None
            outputs.append(self.single_forward(x[:, :, i:end, :, :], previous_features, after_features))
            previous_features = x[:, :, i:end, :, :]
        return torch.cat(outputs, dim=2)