# File: CRM/model/archs/unet.py
# Author: Zhengyi
# Commit a71660a (verified): "RM xformers in unet"
# Size: 1.78 kB
'''
Originally adapted from:
https://github.com/jaxony/unet-pytorch/blob/master/model.py
The current implementation wraps the UNet2DModel from diffusers.
'''
import torch
import torch.nn as nn
from diffusers import UNet2DModel
import einops
class UNetPP(nn.Module):
    """Wrapper around the diffusers ``UNet2DModel``.

    The wrapped UNet consumes ``in_channels`` input channels and produces 32
    output channels. When ``in_channels`` exceeds 12, the extra
    ``in_channels - 12`` channels come from a learned, zero-initialised plane
    of spatial size 256 x 768. ``forward`` appends this plane to inputs that
    arrive with fewer than ``in_channels`` channels.

    Args:
        in_channels: Number of channels fed to the wrapped UNet.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.unet = UNet2DModel(
            sample_size=[256, 256 * 3],
            in_channels=in_channels,
            out_channels=32,
            layers_per_block=2,
            block_out_channels=(64, 128, 128, 128 * 2, 128 * 2, 128 * 4, 128 * 4),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )
        # NOTE: xformers memory-efficient attention is deliberately not enabled.
        if in_channels > 12:
            # Registered only when needed, so state_dict keys for
            # in_channels <= 12 stay exactly as before.
            self.learned_plane = nn.Parameter(
                torch.zeros([1, in_channels - 12, 256, 256 * 3])
            )

    def _append_learned_plane(self, x):
        """Pad ``x`` with the learned plane along dim 1 when it has too few channels.

        Args:
            x: Input tensor indexed as (batch, channel, H, W). H and W must
                match the learned plane (256 x 768) whenever padding happens.

        Returns:
            ``x`` itself if it already has at least ``in_channels`` channels,
            otherwise ``x`` concatenated with the learned plane, broadcast
            over the batch dimension.

        Raises:
            ValueError: If padding is needed but no learned plane exists
                (``in_channels <= 12``), or the padded channel count would not
                equal ``in_channels``.
        """
        if x.shape[1] >= self.in_channels:
            return x
        # The parameter is absent when in_channels <= 12; reading it directly
        # would raise AttributeError.
        plane = getattr(self, "learned_plane", None)
        if plane is None:
            raise ValueError(
                f"input has {x.shape[1]} channels but the model expects "
                f"{self.in_channels} and has no learned plane to pad with"
            )
        if x.shape[1] + plane.shape[1] != self.in_channels:
            raise ValueError(
                f"input has {x.shape[1]} channels; appending the "
                f"{plane.shape[1]}-channel learned plane would not give the "
                f"expected {self.in_channels} channels"
            )
        # expand() is a zero-copy broadcast over the batch dimension;
        # torch.cat materialises the result.
        plane = plane.expand(x.shape[0], -1, -1, -1).to(x.device)
        return torch.cat([x, plane], dim=1)

    def forward(self, x, t=256):
        """Run the wrapped UNet, padding ``x`` with the learned plane if needed.

        Args:
            x: Input tensor indexed as (batch, channel, H, W).
            t: Timestep forwarded to ``UNet2DModel`` (defaults to 256).

        Returns:
            The ``.sample`` tensor produced by the wrapped ``UNet2DModel``.
        """
        x = self._append_learned_plane(x)
        return self.unet(x, t).sample