import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import pdb
from torch.nn import functional as F
from torch.nn import init

# DCGAN-style generator and discriminator for 64 x 64 images, both
# conditioned on a dense embedding vector (768-dimensional by default).


class Concat_embed4(nn.Module):
    """Project an embedding and tile it over a feature map's spatial grid.

    The embedding goes through an MLP (``embed_dim`` -> ``embed_dim`` ->
    ``embed_dim`` -> ``projected_embed_dim``). It is then broadcast to the
    height and width of the incoming feature map and concatenated to it
    along the channel dimension.

    Args:
        embed_dim: Size of the last dimension of the embedding passed to
            :meth:`forward`.
        projected_embed_dim: Number of channels appended to the feature map.
    """

    def __init__(self, embed_dim, projected_embed_dim):
        super(Concat_embed4, self).__init__()
        self.projection = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=embed_dim, out_features=embed_dim),
            nn.BatchNorm1d(num_features=embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=embed_dim, out_features=projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, inp, embed):
        """Concatenate the projected, spatially tiled ``embed`` onto ``inp``.

        Args:
            inp: Feature map of shape ``(batch, C, H, W)``.
            embed: Embedding of shape ``(batch, embed_dim)``.

        Returns:
            Tensor of shape ``(batch, C + projected_embed_dim, H, W)``.
        """
        projected_embed = self.projection(embed)
        # Tile to inp's real spatial size rather than a hard-coded 4 x 4, so
        # feature maps of any resolution work. For a 4 x 4 input this equals
        # the former ``repeat(4, 4, 1, 1).permute(2, 3, 0, 1)``.
        height, width = inp.shape[2], inp.shape[3]
        replicated_embed = projected_embed[:, :, None, None].expand(
            -1, -1, height, width
        )
        return torch.cat([inp, replicated_embed], 1)


class generator(nn.Module):
    """Map an embedding plus a noise vector to a 64 x 64 image.

    The embedding is projected to ``projected_embed_dim`` features and
    concatenated with the noise along the channel dimension. The result is
    decoded by transposed convolutions from 1 x 1 up to 64 x 64, ending in
    ``tanh``.

    All constructor arguments are keyword-only. Their defaults are the
    values this class always hard-coded, so ``generator()`` builds the same
    network, with the same ``state_dict`` keys, as before.

    Args:
        embed_dim: Size of the last dimension of ``embed_vector``.
        noise_dim: Number of noise channels expected in ``z``.
        projected_embed_dim: Width the embedding is projected to.
        ngf: Base channel width of the decoder.
        num_channels: Number of output image channels.
    """

    def __init__(
        self,
        *,
        embed_dim=768,
        noise_dim=100,
        projected_embed_dim=128,
        ngf=64,
        num_channels=3,
    ):
        super(generator, self).__init__()
        # The decoder below is fixed to a 64 x 64 output; image_size is not
        # read anywhere inside this class.
        self.image_size = 64
        self.num_channels = num_channels
        self.noise_dim = noise_dim
        self.embed_dim = embed_dim
        self.projected_embed_dim = projected_embed_dim
        self.latent_dim = self.noise_dim + self.projected_embed_dim
        self.ngf = ngf

        self.projection = nn.Sequential(
            nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
            nn.BatchNorm1d(num_features=self.embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(in_features=self.embed_dim, out_features=self.embed_dim),
            nn.BatchNorm1d(num_features=self.embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(
                in_features=self.embed_dim,
                out_features=self.projected_embed_dim,
            ),
            nn.BatchNorm1d(num_features=self.projected_embed_dim),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

        self.netG = nn.ModuleList([
            # input: latent_dim x 1 x 1
            nn.ConvTranspose2d(self.latent_dim, self.ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(self.ngf, self.num_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (num_channels) x 64 x 64
        ])

    def forward(self, embed_vector, z):
        """Generate images from embeddings and noise.

        Args:
            embed_vector: Embedding of shape ``(batch, embed_dim)``.
            z: Noise of shape ``(batch, noise_dim, 1, 1)``, or the flat
                ``(batch, noise_dim)``, which is reshaped automatically.

        Returns:
            Images of shape ``(batch, num_channels, 64, 64)`` in ``[-1, 1]``.
        """
        projected_embed = self.projection(embed_vector)
        if z.dim() == 2:
            # Accept flat noise vectors; the decoder expects a 1 x 1 grid.
            z = z.reshape(z.size(0), -1, 1, 1)
        out = torch.cat([projected_embed.unsqueeze(2).unsqueeze(3), z], 1)
        for layer in self.netG:
            out = layer(out)
        return out


class discriminator(nn.Module):
    """Score 64 x 64 images for realism, conditioned on an embedding.

    A strided convolutional encoder reduces the image to a
    ``(ndf*8) x 4 x 4`` feature map. The projected embedding is then
    concatenated as extra channels, and a small head produces one sigmoid
    score per image.

    All constructor arguments are keyword-only. Their defaults are the
    values this class always hard-coded, so ``discriminator()`` builds the
    same network, with the same ``state_dict`` keys, as before.

    Args:
        embed_dim: Size of the last dimension of the conditioning embedding.
        projected_embed_dim: Width the embedding is projected to.
        ndf: Base channel width of the encoder.
        num_channels: Number of input image channels.
    """

    def __init__(
        self,
        *,
        embed_dim=768,
        projected_embed_dim=128,
        ndf=64,
        num_channels=3,
    ):
        super(discriminator, self).__init__()
        # image_size, B_dim and C_dim are not read anywhere inside this class;
        # they are kept as attributes in case other code references them.
        self.image_size = 64
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.projected_embed_dim = projected_embed_dim
        self.ndf = ndf
        self.B_dim = 128
        self.C_dim = 16

        self.netD_1 = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(self.num_channels, self.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.projector = Concat_embed4(self.embed_dim, self.projected_embed_dim)

        self.netD_2 = nn.Sequential(
            # state size. (ndf*8 + projected_embed_dim) x 4 x 4
            nn.Conv2d(
                self.ndf * 8 + self.projected_embed_dim,
                self.ndf * 8,
                1,
                1,
                0,
                bias=False,
            ),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 4 x 4 kernel with no padding collapses the 4 x 4 grid to 1 x 1.
            nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, inp, embed):
        """Score ``inp`` conditioned on ``embed``.

        Args:
            inp: Images of shape ``(batch, num_channels, 64, 64)``.
            embed: Embedding of shape ``(batch, embed_dim)``.

        Returns:
            A pair ``(scores, x_intermediate)``:

            - ``scores`` has shape ``(batch,)`` and values in ``[0, 1]``.
            - ``x_intermediate`` is the ``(batch, ndf*8, 4, 4)`` encoder
              feature map, taken before the embedding is concatenated.
        """
        x_intermediate = self.netD_1(inp)
        x = self.projector(x_intermediate, embed)
        x = self.netD_2(x)
        return x.view(-1, 1).squeeze(1), x_intermediate