# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
import math
import torch
import torch.nn as nn
import torchvision
import numpy as np


class VGG16Extractor(nn.Module):
    def __init__(self, space):
        super().__init__()
        # load pretrained model
        self.vgg_layers = torchvision.models.vgg16(
            weights=torchvision.models.VGG16_Weights.DEFAULT
        ).features
        for param in self.parameters():
            param.requires_grad = False
        # indices of the vgg16.features ReLU activations to capture
        # (relu1_1, relu1_2, relu2_1, relu2_2, relu3_1-3, relu4_3, relu5_3)
        self.capture_layers = [1, 3, 6, 8, 11, 13, 15, 22, 29]
        self.space = space

    def forward_base(self, x):
        # run through vgg16.features, keeping the input and each captured activation
        feat = [x]
        for i in range(len(self.vgg_layers)):
            x = self.vgg_layers[i](x)
            if i in self.capture_layers:
                feat.append(x)
        return feat

    def forward(self, x):
        if self.space != 'vgg':
            # map inputs from [-1, 1] to [0, 1], then apply ImageNet normalization
            x = (x + 1.) / 2.
            x = x - (torch.Tensor([0.485, 0.456, 0.406]).to(x.device).view(1, -1, 1, 1))
            x = x / (torch.Tensor([0.229, 0.224, 0.225]).to(x.device).view(1, -1, 1, 1))
        feat = self.forward_base(x)
        return feat

    def forward_samples_hypercolumn(self, X, samps=100):
        feat = self.forward(X)

        # pick `samps` random spatial locations
        xx, xy = np.meshgrid(np.arange(X.shape[2]), np.arange(X.shape[3]))
        xx = np.expand_dims(xx.flatten(), 1)
        xy = np.expand_dims(xy.flatten(), 1)
        xc = np.concatenate([xx, xy], 1)
        samples = min(samps, xc.shape[0])
        np.random.shuffle(xc)
        xx = xc[:samples, 0]
        yy = xc[:samples, 1]

        feat_samples = []
        for i in range(len(feat)):
            layer_feat = feat[i]

            # hack to detect lower resolution
            if i > 0 and feat[i].size(2) < feat[i - 1].size(2):
                xx = xx / 2.0
                yy = yy / 2.0

            xx = np.clip(xx, 0, layer_feat.shape[2] - 1).astype(np.int32)
            yy = np.clip(yy, 0, layer_feat.shape[3] - 1).astype(np.int32)

            # xx and yy already hold exactly `samples` entries
            features = layer_feat[:, :, xx, yy]
            feat_samples.append(features.clone().detach())

        feat = torch.cat(feat_samples, 1)
        return feat
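

# Usage sketch (illustrative, not part of the original module): extract a
# hypercolumn descriptor from an image in [-1, 1]. Any `space` other than
# 'vgg' (e.g. the hypothetical 'normal' below) enables the ImageNet
# normalization in forward(). The 2179 channels come from
# 3 + 2*64 + 2*128 + 3*256 + 2*512.
#
#   extractor = VGG16Extractor(space='normal').eval()
#   img = torch.rand(1, 3, 224, 224) * 2. - 1.
#   hc = extractor.forward_samples_hypercolumn(img, samps=1000)  # 1 x 2179 x 1000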


class StyleLoss:
    # note: a plain Python class (not an nn.Module); call .forward(...) explicitly

    def spatial_feature_extract(self, feat_result, feat_content, xx, xy):
        l2, l3 = [], []
        device = feat_result[0].device

        # for each extracted layer
        for i in range(len(feat_result)):
            fr = feat_result[i]
            fc = feat_content[i]

            # hack to detect reduced scale
            if i > 0 and feat_result[i - 1].size(2) > feat_result[i].size(2):
                xx = xx / 2.0
                xy = xy / 2.0

            # go back to ints and get residual
            xxm = np.floor(xx).astype(np.float32)
            xxr = xx - xxm
            xym = np.floor(xy).astype(np.float32)
            xyr = xy - xym

            # do bilinear resample: w00..w11 are the four corner weights
            # (1-dx)(1-dy), (1-dx)dy, dx(1-dy), dx*dy, which sum to 1
            w00 = torch.from_numpy((1. - xxr) * (1. - xyr)).float().view(1, 1, -1, 1).to(device)
            w01 = torch.from_numpy((1. - xxr) * xyr).float().view(1, 1, -1, 1).to(device)
            w10 = torch.from_numpy(xxr * (1. - xyr)).float().view(1, 1, -1, 1).to(device)
            w11 = torch.from_numpy(xxr * xyr).float().view(1, 1, -1, 1).to(device)

            xxm = np.clip(xxm.astype(np.int32), 0, fr.size(2) - 1)
            xym = np.clip(xym.astype(np.int32), 0, fr.size(3) - 1)

            # flattened indices of the four neighbouring positions
            s00 = xxm * fr.size(3) + xym
            s01 = xxm * fr.size(3) + np.clip(xym + 1, 0, fr.size(3) - 1)
            s10 = np.clip(xxm + 1, 0, fr.size(2) - 1) * fr.size(3) + xym
            s11 = np.clip(xxm + 1, 0, fr.size(2) - 1) * fr.size(3) + np.clip(xym + 1, 0, fr.size(3) - 1)

            fr = fr.view(1, fr.size(1), fr.size(2) * fr.size(3), 1)
            fr = fr[:, :, s00, :].mul_(w00).add_(fr[:, :, s01, :].mul_(w01)).add_(fr[:, :, s10, :].mul_(w10)).add_(
                fr[:, :, s11, :].mul_(w11))

            fc = fc.view(1, fc.size(1), fc.size(2) * fc.size(3), 1)
            fc = fc[:, :, s00, :].mul_(w00).add_(fc[:, :, s01, :].mul_(w01)).add_(fc[:, :, s10, :].mul_(w10)).add_(
                fc[:, :, s11, :].mul_(w11))

            l2.append(fr)
            l3.append(fc)

        x_st = torch.cat([li.contiguous() for li in l2], 1)
        c_st = torch.cat([li.contiguous() for li in l3], 1)

        # append the sampled (x, y) coordinates as two extra channels
        xx = torch.from_numpy(xx).view(1, 1, x_st.size(2), 1).float().to(device)
        yy = torch.from_numpy(xy).view(1, 1, x_st.size(2), 1).float().to(device)
        x_st = torch.cat([x_st, xx, yy], 1)
        c_st = torch.cat([c_st, xx, yy], 1)
        return x_st, c_st

    def rgb_to_yuv(self, rgb):
        # a fixed luminance/chrominance-style decomposition (note: the first row
        # is 1/sqrt(3), so this is not the standard broadcast YUV matrix)
        C = torch.Tensor(
            [[0.577350, 0.577350, 0.577350], [-0.577350, 0.788675, -0.211325], [-0.577350, -0.211325, 0.788675]]
        ).to(rgb.device)
        yuv = torch.mm(C, rgb)
        return yuv

    def pairwise_distances_cos(self, x, y):
        # pairwise cosine distance: 1 - <x_i, y_j> / (||x_i|| * ||y_j||)
        x_norm = torch.sqrt((x ** 2).sum(1).view(-1, 1))
        y_t = torch.transpose(y, 0, 1)
        y_norm = torch.sqrt((y ** 2).sum(1).view(1, -1))
        dist = 1. - torch.mm(x, y_t) / x_norm / y_norm
        return dist

    def pairwise_distances_sq_l2(self, x, y):
        # pairwise squared L2 distance, normalized by dimensionality
        x_norm = (x ** 2).sum(1).view(-1, 1)
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y ** 2).sum(1).view(1, -1)
        dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
        return torch.clamp(dist, 1e-5, 1e5) / x.size(1)
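
    # The matrix above uses the identity ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2
    # - 2 <x_i, y_j>, so all pairwise distances come from a single matmul; the
    # clamp guards against tiny negative values introduced by rounding.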

    def distmat(self, x, y, cos_d=True):
        if cos_d:
            M = self.pairwise_distances_cos(x, y)
        else:
            M = torch.sqrt(self.pairwise_distances_sq_l2(x, y))
        return M

    def style_loss(self, X, Y):
        d = X.shape[1]

        if d == 3:
            # compare colors in the decorrelated (YUV-like) space
            X = self.rgb_to_yuv(X.transpose(0, 1).contiguous().view(d, -1)).transpose(0, 1)
            Y = self.rgb_to_yuv(Y.transpose(0, 1).contiguous().view(d, -1)).transpose(0, 1)
        else:
            X = X.transpose(0, 1).contiguous().view(d, -1).transpose(0, 1)
            Y = Y.transpose(0, 1).contiguous().view(d, -1).transpose(0, 1)

        # Relaxed EMD
        CX_M = self.distmat(X, Y, cos_d=True)
        if d == 3:
            CX_M = CX_M + self.distmat(X, Y, cos_d=False)

        m1, _ = CX_M.min(1)  # for each row of X, distance to its nearest Y
        m2, _ = CX_M.min(0)  # for each column of Y, distance to its nearest X
        remd = torch.max(m1.mean(), m2.mean())
        return remd
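
    # The relaxed EMD above sidesteps the full optimal-transport problem:
    # max(mean_i min_j C_ij, mean_j min_i C_ij) lets each side match its
    # nearest neighbour on the other, a cheap surrogate for the true EMD.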

    def moment_loss(self, X, Y, moments=(1, 2)):
        loss = 0.
        X = X.squeeze().t()
        Y = Y.squeeze().t()

        mu_x = torch.mean(X, 0, keepdim=True)
        mu_y = torch.mean(Y, 0, keepdim=True)
        mu_d = torch.abs(mu_x - mu_y).mean()

        if 1 in moments:
            loss = loss + mu_d

        if 2 in moments:
            X_c = X - mu_x
            Y_c = Y - mu_y
            X_cov = torch.mm(X_c.t(), X_c) / (X.shape[0] - 1)
            Y_cov = torch.mm(Y_c.t(), Y_c) / (Y.shape[0] - 1)
            D_cov = torch.abs(X_cov - Y_cov).mean()
            loss = loss + D_cov

        return loss
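
    # moment_loss matches the first two moments of the sampled feature
    # distributions: channel means (moment 1) and the full channel covariance
    # matrix (moment 2), complementing the nearest-neighbour REMD term above.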

    def forward(self, feat_result, feat_content, feat_style, indices, content_weight, moment_weight=1.0):
        # spatial feature extract
        num_locations = 1024
        spatial_result, spatial_content = self.spatial_feature_extract(
            feat_result, feat_content, indices[0][:num_locations], indices[1][:num_locations]
        )
        # loss_content = content_loss(spatial_result, spatial_content)

        d = feat_style.shape[1]
        spatial_style = feat_style.view(1, d, -1, 1)
        feat_max = 3 + 2 * 64 + 128 * 2 + 256 * 3 + 512 * 2  # = 2179, total channels of the captured hypercolumn

        loss_remd = self.style_loss(spatial_result[:, :feat_max, :, :], spatial_style[:, :feat_max, :, :])
        loss_moment = self.moment_loss(spatial_result[:, :-2, :, :],
                                       spatial_style,
                                       moments=[1, 2])  # :-2 drops the two appended (x, y) coordinate channels

        # palette matching
        content_weight_frac = 1. / max(content_weight, 1.)
        loss_moment += content_weight_frac * self.style_loss(spatial_result[:, :3, :, :], spatial_style[:, :3, :, :])

        loss_style = loss_remd + moment_weight * loss_moment
        # print(f'Style: {loss_style.item():.3f}, Content: {loss_content.item():.3f}')

        style_weight = 1.0 + moment_weight
        loss_total = loss_style / (content_weight + style_weight)
        return loss_total


def sample_indices(feat_content, feat_style):
    const = 128 ** 2  # 16384; keep roughly this many sampled positions
    big_size = feat_content.shape[2] * feat_content.shape[3]  # number of feature pixels ("feaxels")

    stride_x = int(max(math.floor(math.sqrt(big_size // const)), 1))
    offset_x = np.random.randint(stride_x)
    stride_y = int(max(math.ceil(math.sqrt(big_size // const)), 1))
    offset_y = np.random.randint(stride_y)

    xx, xy = np.meshgrid(
        np.arange(feat_content.shape[2])[offset_x::stride_x],
        np.arange(feat_content.shape[3])[offset_y::stride_y]
    )
    xx = xx.flatten()
    xy = xy.flatten()
    return xx, xy
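

if __name__ == "__main__":
    # Smoke test (a minimal sketch, not part of the original pipeline): wire
    # the extractor, the index sampler, and the loss together on random
    # images. Inputs are assumed to be 1x3xHxW tensors in [-1, 1]; 'normal'
    # is a hypothetical `space` value, anything other than 'vgg' works.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    extractor = VGG16Extractor(space="normal").to(device).eval()

    content = torch.rand(1, 3, 224, 224, device=device) * 2. - 1.
    style = torch.rand(1, 3, 224, 224, device=device) * 2. - 1.
    result = content.clone().requires_grad_(True)

    feat_content = extractor(content)
    feat_result = extractor(result)
    feat_style = extractor.forward_samples_hypercolumn(style, samps=1000)

    # sample spatial locations on the full-resolution content feature map
    xx, xy = sample_indices(feat_content[0], feat_style)
    loss = StyleLoss().forward(
        feat_result, feat_content, feat_style, (xx, xy), content_weight=4.0
    )
    loss.backward()  # gradients flow back to `result`
    print(f"style loss: {loss.item():.4f}")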