# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm
from .modeling_block import (
    UNetMidBlock2D,
    CausalUNetMidBlock2D,
    get_down_block,
    get_up_block,
    get_input_layer,
    get_output_layer,
)
from .modeling_resnet import (
    Downsample2D,
    Upsample2D,
    TemporalDownsample2x,
    TemporalUpsample2x,
)
from .modeling_causal_conv import CausalConv3d, CausalGroupNorm


@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: torch.FloatTensor


def _checkpointed_block_forward(module, sample, is_init_image, temporal_chunk):
    r"""
    Run `module` on `sample` under activation checkpointing.

    The two flags are bound inside the closure and handed to `module` by keyword, exactly as the
    non-checkpointed code paths do. Only the tensor goes through `torch.utils.checkpoint.checkpoint`, so the
    wrapped callable never receives keyword arguments from `checkpoint` itself (the legacy reentrant
    implementation rejects them, and a `*inputs`-only wrapper cannot accept them).
    """

    def custom_forward(hidden_states):
        return module(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

    if is_torch_version(">=", "1.11.0"):
        return torch.utils.checkpoint.checkpoint(custom_forward, sample, use_reentrant=False)
    # Older torch versions do not get the `use_reentrant` argument.
    return torch.utils.checkpoint.checkpoint(custom_forward, sample)


class CausalVaeEncoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlockCausal3D",)`):
            The types of down blocks to use; each entry is passed to `get_down_block`.
        spatial_down_sample (`Tuple[bool, ...]`, *optional*, defaults to `(True,)`):
            Per-block flag forwarded to `get_down_block` as `add_spatial_downsample`.
        temporal_down_sample (`Tuple[bool, ...]`, *optional*, defaults to `(False,)`):
            Per-block flag forwarded to `get_down_block` as `add_temporal_downsample`.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(2,)`):
            The number of layers for each block (indexed per down block).
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function name forwarded to the down blocks and the mid block.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether `conv_out` produces `2 * out_channels` channels instead of `out_channels`.
        block_dropout (`Tuple[float, ...]`, *optional*, defaults to `(0.0,)`):
            Per-block dropout forwarded to the down blocks; the last entry is also used for the mid block.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Forwarded to the mid block as `add_attention`.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        spatial_down_sample: Tuple[bool, ...] = (True,),
        temporal_down_sample: Tuple[bool, ...] = (False,),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: Tuple[int, ...] = (2,),
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        block_dropout: Tuple[float, ...] = (0.0,),
        mid_block_add_attention=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
        )

        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            # Each block consumes the previous block's output width.
            input_channel = output_channel
            output_channel = block_out_channels[i]

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                add_spatial_downsample=spatial_down_sample[i],
                add_temporal_downsample=temporal_down_sample[i],
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
                dropout=block_dropout[i],
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = CausalUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            dropout=block_dropout[-1],
        )

        # out
        self.conv_norm_out = CausalGroupNorm(
            num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
        )
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3, stride=1)

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
        r"""
        The forward method of the `Encoder` class.

        Args:
            sample (`torch.FloatTensor`):
                Input tensor fed to `conv_in`.
            is_init_image (`bool`, *optional*, defaults to `True`):
                Forwarded unchanged to every causal convolution and block; its meaning is defined by those
                modules.
            temporal_chunk (`bool`, *optional*, defaults to `False`):
                Forwarded unchanged to every causal convolution and block; its meaning is defined by those
                modules.

        Returns:
            `torch.FloatTensor`: The output of `conv_out`.
        """
        sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if self.training and self.gradient_checkpointing:
            # down
            for down_block in self.down_blocks:
                sample = _checkpointed_block_forward(down_block, sample, is_init_image, temporal_chunk)
            # middle
            sample = _checkpointed_block_forward(self.mid_block, sample, is_init_image, temporal_chunk)
        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
            # middle
            sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        return sample


class CausalVaeDecoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlockCausal3D",)`):
            The types of up blocks to use; each entry is passed to `get_up_block`.
        spatial_up_sample (`Tuple[bool, ...]`, *optional*, defaults to `(True,)`):
            Per-block flag forwarded to `get_up_block` as `add_spatial_upsample`.
        temporal_up_sample (`Tuple[bool, ...]`, *optional*, defaults to `(False,)`):
            Per-block flag forwarded to `get_up_block` as `add_temporal_upsample`.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block. The up blocks walk this tuple in reverse.
        layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(2,)`):
            The number of layers for each block (indexed per up block).
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function name forwarded to the up blocks and the mid block.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Forwarded to the mid block as `add_attention`.
        interpolate (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to `get_up_block`.
        block_dropout (`Tuple[float, ...]`, *optional*, defaults to `(0.0,)`):
            Per-block dropout forwarded to the up blocks; the last entry is also used for the mid block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        spatial_up_sample: Tuple[bool, ...] = (True,),
        temporal_up_sample: Tuple[bool, ...] = (False,),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: Tuple[int, ...] = (2,),
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        mid_block_add_attention=True,
        interpolate: bool = True,
        block_dropout: Tuple[float, ...] = (0.0,),
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
        )

        self.up_blocks = nn.ModuleList([])

        # mid
        self.mid_block = CausalUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            dropout=block_dropout[-1],
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            # Each block consumes the previous block's output width.
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block[i],
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_spatial_upsample=spatial_up_sample[i],
                add_temporal_upsample=temporal_up_sample[i],
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
                resnet_time_scale_shift='default',
                interpolate=interpolate,
                dropout=block_dropout[i],
            )
            self.up_blocks.append(up_block)

        # out
        self.conv_norm_out = CausalGroupNorm(
            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
        )
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3, stride=1)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        is_init_image=True,
        temporal_chunk=False,
    ) -> torch.FloatTensor:
        r"""
        The forward method of the `Decoder` class.

        Args:
            sample (`torch.FloatTensor`):
                Latent tensor fed to `conv_in`.
            is_init_image (`bool`, *optional*, defaults to `True`):
                Forwarded unchanged to every causal convolution and block; its meaning is defined by those
                modules.
            temporal_chunk (`bool`, *optional*, defaults to `False`):
                Forwarded unchanged to every causal convolution and block; its meaning is defined by those
                modules.

        Returns:
            `torch.FloatTensor`: The output of `conv_out`.
        """
        sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        # The mid-block output is cast to the dtype of the up blocks' first parameter before upsampling.
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        use_checkpoint = self.training and self.gradient_checkpointing

        # middle
        if use_checkpoint:
            sample = _checkpointed_block_forward(self.mid_block, sample, is_init_image, temporal_chunk)
        else:
            sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
        sample = sample.to(upscale_dtype)

        # up
        for up_block in self.up_blocks:
            if use_checkpoint:
                sample = _checkpointed_block_forward(up_block, sample, is_init_image, temporal_chunk)
            else:
                sample = up_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        return sample


class DiagonalGaussianDistribution(object):
    r"""
    Diagonal Gaussian parameterised by a tensor that stacks mean and log-variance along dim 1.

    Args:
        parameters (`torch.Tensor`):
            Tensor split into two equal chunks along dim 1: the mean, then the log-variance. The log-variance
            is clamped to `[-30, 20]` before the standard deviation and variance are derived from it.
        deterministic (`bool`, *optional*, defaults to `False`):
            If `True`, the variance and standard deviation are set to zero, so `sample()` returns the mean
            and `kl()` / `nll()` return a single zero.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def _zero(self) -> torch.Tensor:
        r"""Return a one-element zero tensor on the same device and dtype as `parameters`."""
        return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        r"""Draw `mean + std * eps`, with `eps` standard-normal noise from `randn_tensor`."""
        # make sure sample is on the same device as the parameters and has same dtype
        noise = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        return self.mean + self.std * noise

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        r"""
        KL divergence from this distribution to `other`, or to a standard normal when `other` is `None`.

        The result is summed over dims 2, 3 and 4, so the statistics must have at least five dimensions
        (presumably batch, channels, frames, height, width - confirm against callers).
        In deterministic mode a single zero is returned on the parameters' device and dtype.
        """
        if self.deterministic:
            return self._zero()

        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[2, 3, 4],
            )

        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[2, 3, 4],
        )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        r"""
        Negative log-likelihood of `sample` under this distribution, summed over `dims`.

        Computes `0.5 * sum(log(2 * pi) + logvar + (sample - mean) ** 2 / var)`.
        In deterministic mode a single zero is returned on the parameters' device and dtype.
        """
        if self.deterministic:
            return self._zero()

        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        r"""Return the mode of the distribution, which is its mean."""
        return self.mean