# T2I/gan_cls_768.py
"""DCGAN-style text-to-image GAN conditioned on 768-dimensional sentence embeddings.

The generator concatenates a projected text embedding with the noise vector;
the discriminator fuses the projected embedding into its 4x4 feature map
(Concat_embed4) before the final real/fake prediction.
"""
import torch
import torch.nn as nn
class Concat_embed4(nn.Module):
    """Projects a text embedding and concatenates it with a 4x4 feature map."""

    def __init__(self, embed_dim, projected_embed_dim):
        super(Concat_embed4, self).__init__()
        self.projection = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=embed_dim, out_features=projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, inp, embed):
        # (B, embed_dim) -> (B, projected_embed_dim)
        projected_embed = self.projection(embed)
        # Tile the projected embedding over the 4x4 spatial grid:
        # (B, D) -> (4, 4, B, D) -> (B, D, 4, 4)
        replicated_embed = projected_embed.repeat(4, 4, 1, 1).permute(2, 3, 0, 1)
        # Concatenate with the image feature map along the channel dimension.
        hidden_concat = torch.cat([inp, replicated_embed], 1)
        return hidden_concat
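
# Shape sketch (illustrative values, not from the original file), assuming
# embed_dim=768 and projected_embed_dim=128 as used below:
#   fuse = Concat_embed4(768, 128)
#   feat = torch.randn(8, 512, 4, 4)   # discriminator feature map at the 4x4 stage
#   emb  = torch.randn(8, 768)         # sentence embedding
#   out  = fuse(feat, emb)             # -> (8, 512 + 128, 4, 4)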
class generator(nn.Module):
    """64x64 DCGAN generator conditioned on a projected text embedding."""

    def __init__(self):
        super(generator, self).__init__()
        self.image_size = 64
        self.num_channels = 3
        self.noise_dim = 100
        self.embed_dim = 768
        self.projected_embed_dim = 128
        self.latent_dim = self.noise_dim + self.projected_embed_dim
        self.ngf = 64

        # Projects the raw 768-d text embedding down to 128 dimensions.
        self.projection = nn.Sequential(
            nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
            nn.BatchNorm1d(num_features=self.embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
            nn.BatchNorm1d(num_features=self.embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=self.embed_dim, out_features=self.projected_embed_dim),
            nn.BatchNorm1d(num_features=self.projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True)
        )

        self.netG = nn.ModuleList([
            nn.ConvTranspose2d(self.latent_dim, self.ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(self.ngf, self.num_channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (num_channels) x 64 x 64
        ])

    def forward(self, embed_vector, z):
        # (B, embed_dim) -> (B, projected_embed_dim)
        projected_embed = self.projection(embed_vector)
        # z is expected as (B, noise_dim, 1, 1); concatenate along channels.
        out = torch.cat([projected_embed.unsqueeze(2).unsqueeze(3), z], 1)
        for m in self.netG:
            out = m(out)
        return out
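
# Usage sketch (illustrative batch size, not from the original file): the
# generator takes the raw 768-d embedding and a (B, noise_dim, 1, 1) noise
# tensor and produces a Tanh-scaled image.
#   netG = generator()
#   fake = netG(torch.randn(8, 768), torch.randn(8, 100, 1, 1))  # -> (8, 3, 64, 64)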
class discriminator(nn.Module):
    """64x64 DCGAN discriminator that fuses the text embedding at the 4x4 stage."""

    def __init__(self):
        super(discriminator, self).__init__()
        self.image_size = 64
        self.num_channels = 3
        self.embed_dim = 768
        self.projected_embed_dim = 128
        self.ndf = 64
        self.B_dim = 128  # not used in this file
        self.C_dim = 16   # not used in this file

        self.netD_1 = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(self.num_channels, self.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            # SelfAttention(self.ndf),
            nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Fuses the projected text embedding into the (ndf*8) x 4 x 4 feature map.
        self.projector = Concat_embed4(self.embed_dim, self.projected_embed_dim)

        self.netD_2 = nn.Sequential(
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(self.ndf * 8 + self.projected_embed_dim,
                      self.ndf * 8, 1, 1, 0, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, inp, embed):
        # Image features at the 4x4 stage: (B, ndf*8, 4, 4)
        x_intermediate = self.netD_1(inp)
        # Concatenate the projected embedding and predict real/fake.
        x = self.projector(x_intermediate, embed)
        x = self.netD_2(x)
        # Return per-image probabilities (B,) and the intermediate 4x4 features.
        return x.view(-1, 1).squeeze(1), x_intermediate
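

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # training code): run both networks on random tensors and check shapes.
    batch_size = 4
    netG = generator()
    netD = discriminator()
    embed = torch.randn(batch_size, 768)        # 768-d text embedding
    z = torch.randn(batch_size, 100, 1, 1)      # latent noise
    fake = netG(embed, z)                       # (4, 3, 64, 64)
    score, features = netD(fake, embed)         # (4,), (4, 512, 4, 4)
    print(fake.shape, score.shape, features.shape)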