Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from torch.autograd import Variable | |
import numpy as np | |
import pdb | |
from torch.nn import functional as F | |
from torch.nn import init | |
''' | |
''' | |
class Concat_embed4(nn.Module):
    """Project an embedding and concatenate it, tiled spatially, onto a feature map.

    The embedding goes through a 3-layer MLP (Linear/BatchNorm1d/LeakyReLU
    twice, then Linear/LeakyReLU) down to ``projected_embed_dim`` features.
    Those features are broadcast over every spatial position of ``inp`` and
    appended to it along the channel axis.

    Args:
        embed_dim: size of the incoming embedding vector.
        projected_embed_dim: number of channels the embedding contributes to
            the concatenated output.
    """

    def __init__(self, embed_dim, projected_embed_dim):
        super(Concat_embed4, self).__init__()
        self.projection = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=embed_dim, out_features=projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, inp, embed):
        """Concatenate the projected embedding onto ``inp`` along channels.

        Args:
            inp: feature map of shape ``(B, C, H, W)``.
            embed: embedding of shape ``(B, embed_dim)``. In training mode the
                BatchNorm1d layers require ``B > 1``.

        Returns:
            Tensor of shape ``(B, C + projected_embed_dim, H, W)``.
        """
        projected_embed = self.projection(embed)
        # Broadcast (B, P) -> (B, P, H, W) to match inp's own spatial size
        # instead of assuming a fixed 4x4 map; torch.cat materialises the copy.
        replicated_embed = projected_embed[:, :, None, None].expand(
            -1, -1, inp.size(2), inp.size(3)
        )
        hidden_concat = torch.cat([inp, replicated_embed], 1)
        return hidden_concat
class generator(nn.Module):
    """Embedding-conditioned DCGAN-style generator producing 64x64 images.

    A ``embed_dim`` (768) conditioning vector is compressed to
    ``projected_embed_dim`` (128) features by an MLP. The result is stacked
    with ``noise_dim`` (100) noise channels as a 1x1 map. Five transposed
    convolutions then upsample 1 -> 4 -> 8 -> 16 -> 32 -> 64. The final
    ``Tanh`` yields a ``(B, num_channels, 64, 64)`` output in ``[-1, 1]``.
    """

    def __init__(self):
        super(generator, self).__init__()
        self.image_size = 64
        self.num_channels = 3
        self.noise_dim = 100
        self.embed_dim = 768
        self.projected_embed_dim = 128
        # Channels of the 1x1 map fed to netG: projected embedding + noise.
        self.latent_dim = self.noise_dim + self.projected_embed_dim
        self.ngf = 64
        # MLP compressing the conditioning embedding to projected_embed_dim.
        self.projection = nn.Sequential(
            nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
            nn.BatchNorm1d(num_features=self.embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
            nn.BatchNorm1d(num_features=self.embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=self.embed_dim, out_features=self.projected_embed_dim),
            nn.BatchNorm1d(num_features=self.projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True)
        )
        self.netG = nn.ModuleList([
            nn.ConvTranspose2d(self.latent_dim, self.ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(self.ngf, self.num_channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (num_channels) x 64 x 64
        ])

    def forward(self, embed_vector, z):
        """Generate images from conditioning embeddings and noise.

        Args:
            embed_vector: ``(B, embed_dim)`` tensor. In training mode the
                BatchNorm1d layers require ``B > 1``.
            z: noise, either ``(B, noise_dim, 1, 1)`` or flat ``(B, noise_dim)``.

        Returns:
            ``(B, num_channels, 64, 64)`` tensor in ``[-1, 1]``.
        """
        projected_embed = self.projection(embed_vector)
        if z.dim() == 2:
            # Accept flat noise vectors by lifting them to 1x1 spatial maps.
            z = z.reshape(z.size(0), -1, 1, 1)
        out = torch.cat([projected_embed.unsqueeze(2).unsqueeze(3), z], 1)
        for m in self.netG:
            out = m(out)
        return out
class discriminator(nn.Module):
    """Embedding-conditioned DCGAN-style discriminator for 64x64 images.

    Four stride-2 convolutions reduce a ``(B, num_channels, 64, 64)`` image to
    an ``(ndf*8) x 4 x 4`` feature map. A projected copy of the conditioning
    embedding is tiled onto that map by ``Concat_embed4``. A 1x1 fusion conv
    and a final 4x4 conv with ``Sigmoid`` then produce one score per sample.
    """

    def __init__(self):
        super(discriminator, self).__init__()
        self.image_size = 64
        self.num_channels = 3
        self.embed_dim = 768
        self.projected_embed_dim = 128
        self.ndf = 64
        # B_dim / C_dim are not used by this class.
        self.B_dim = 128
        self.C_dim = 16

        width = self.ndf
        # Stage 1: (nc) x 64 x 64 -> (ndf) x 32 x 32, no normalisation.
        downsample = [
            nn.Conv2d(self.num_channels, width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Stages 2-4 halve the resolution and double the channels:
        # 32x32 -> 16x16 -> 8x8 -> 4x4, ending at (ndf*8) channels.
        for mult_in, mult_out in ((1, 2), (2, 4), (4, 8)):
            downsample += [
                nn.Conv2d(width * mult_in, width * mult_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * mult_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.netD_1 = nn.Sequential(*downsample)

        self.projector = Concat_embed4(self.embed_dim, self.projected_embed_dim)

        fused_channels = width * 8 + self.projected_embed_dim
        self.netD_2 = nn.Sequential(
            # (ndf*8 + projected_embed_dim) x 4 x 4 -> (ndf*8) x 4 x 4
            nn.Conv2d(fused_channels, width * 8, 1, 1, 0, bias=False),
            nn.BatchNorm2d(width * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 kernel without padding collapses the 4x4 map to 1x1.
            nn.Conv2d(width * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, inp, embed):
        """Score image/embedding pairs.

        Args:
            inp: images of shape ``(B, num_channels, 64, 64)``.
            embed: conditioning embeddings of shape ``(B, embed_dim)``.

        Returns:
            A tuple ``(scores, features)``. ``scores`` is a flat tensor of
            Sigmoid outputs, one per sample for 64x64 inputs. ``features`` is
            the ``netD_1`` output, also handed back to the caller.
        """
        features = self.netD_1(inp)
        fused = self.projector(features, embed)
        scores = self.netD_2(fused)
        return scores.view(-1), features