# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: VGG16 hypercolumn feature extraction and a relaxed-EMD style loss.
import math

import torch
import torch.nn as nn
import torchvision
import numpy as np


class VGG16Extractor(nn.Module):
    def __init__(self, space):
        super().__init__()
        # load the pretrained VGG16 feature stack and freeze it
        self.vgg_layers = torchvision.models.vgg16(
            weights=torchvision.models.VGG16_Weights.DEFAULT
        ).features
        for param in self.parameters():
            param.requires_grad = False
        # indices of the ReLU activations to capture along the VGG16 stack
        self.capture_layers = [1, 3, 6, 8, 11, 13, 15, 22, 29]
        self.space = space

    def forward_base(self, x):
        # run the frozen VGG16 stack, keeping the input plus each captured activation
        feat = [x]
        for i in range(len(self.vgg_layers)):
            x = self.vgg_layers[i](x)
            if i in self.capture_layers:
                feat.append(x)
        return feat

    def forward(self, x):
        if self.space != 'vgg':
            # map from [-1, 1] to [0, 1], then apply ImageNet normalization
            x = (x + 1.) / 2.
            x = x - (torch.Tensor([0.485, 0.456, 0.406]).to(x.device).view(1, -1, 1, 1))
            x = x / (torch.Tensor([0.229, 0.224, 0.225]).to(x.device).view(1, -1, 1, 1))
        feat = self.forward_base(x)
        return feat

    def forward_samples_hypercolumn(self, X, samps=100):
        feat = self.forward(X)

        # build the list of all (x, y) pixel coordinates and shuffle it
        xx, xy = np.meshgrid(np.arange(X.shape[2]), np.arange(X.shape[3]))
        xx = np.expand_dims(xx.flatten(), 1)
        xy = np.expand_dims(xy.flatten(), 1)
        xc = np.concatenate([xx, xy], 1)

        samples = min(samps, xc.shape[0])

        np.random.shuffle(xc)
        xx = xc[:samples, 0]
        yy = xc[:samples, 1]

        feat_samples = []
        for i in range(len(feat)):
            layer_feat = feat[i]

            # hack to detect lower resolution: halve the coordinates at each downsampling
            if i > 0 and feat[i].size(2) < feat[i - 1].size(2):
                xx = xx / 2.0
                yy = yy / 2.0

            xx = np.clip(xx, 0, layer_feat.shape[2] - 1).astype(np.int32)
            yy = np.clip(yy, 0, layer_feat.shape[3] - 1).astype(np.int32)

            # gather the activations at the sampled positions
            features = layer_feat[:, :, xx[range(samples)], yy[range(samples)]]
            feat_samples.append(features.clone().detach())

        # concatenate over channels: one stacked hypercolumn per sampled location
        feat = torch.cat(feat_samples, 1)
        return feat
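
# A hedged usage sketch (not part of the original file): sampling style
# hypercolumns from an image. `style_img`, `space='uniform'`, and `samps=1000`
# below are illustrative assumptions.
#
#   extractor = VGG16Extractor(space='uniform')
#   style_img = torch.rand(1, 3, 256, 256) * 2. - 1.   # expected range [-1, 1]
#   with torch.no_grad():
#       feat_style = extractor.forward_samples_hypercolumn(style_img, samps=1000)
#   # feat_style: (1, 2179, 1000) - one stacked hypercolumn per sampled pixel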

class StyleLoss:

    def spatial_feature_extract(self, feat_result, feat_content, xx, xy):
        l2, l3 = [], []
        device = feat_result[0].device

        # for each extracted layer
        for i in range(len(feat_result)):
            fr = feat_result[i]
            fc = feat_content[i]

            # hack to detect reduced scale: halve the coordinates at each downsampling
            if i > 0 and feat_result[i - 1].size(2) > feat_result[i].size(2):
                xx = xx / 2.0
                xy = xy / 2.0

            # go back to ints and keep the fractional residual
            xxm = np.floor(xx).astype(np.float32)
            xxr = xx - xxm
            xym = np.floor(xy).astype(np.float32)
            xyr = xy - xym

            # bilinear interpolation weights for the four neighbouring positions
            w00 = torch.from_numpy((1. - xxr) * (1. - xyr)).float().view(1, 1, -1, 1).to(device)
            w01 = torch.from_numpy((1. - xxr) * xyr).float().view(1, 1, -1, 1).to(device)
            w10 = torch.from_numpy(xxr * (1. - xyr)).float().view(1, 1, -1, 1).to(device)
            w11 = torch.from_numpy(xxr * xyr).float().view(1, 1, -1, 1).to(device)

            # flat indices of the four neighbours, clamped to the feature map
            xxm = np.clip(xxm.astype(np.int32), 0, fr.size(2) - 1)
            xym = np.clip(xym.astype(np.int32), 0, fr.size(3) - 1)
            s00 = xxm * fr.size(3) + xym
            s01 = xxm * fr.size(3) + np.clip(xym + 1, 0, fr.size(3) - 1)
            s10 = np.clip(xxm + 1, 0, fr.size(2) - 1) * fr.size(3) + xym
            s11 = np.clip(xxm + 1, 0, fr.size(2) - 1) * fr.size(3) + np.clip(xym + 1, 0, fr.size(3) - 1)

            # do the bilinear resample on the flattened feature maps
            fr = fr.view(1, fr.size(1), fr.size(2) * fr.size(3), 1)
            fr = fr[:, :, s00, :].mul_(w00).add_(fr[:, :, s01, :].mul_(w01)).add_(
                fr[:, :, s10, :].mul_(w10)).add_(fr[:, :, s11, :].mul_(w11))

            fc = fc.view(1, fc.size(1), fc.size(2) * fc.size(3), 1)
            fc = fc[:, :, s00, :].mul_(w00).add_(fc[:, :, s01, :].mul_(w01)).add_(
                fc[:, :, s10, :].mul_(w10)).add_(fc[:, :, s11, :].mul_(w11))

            l2.append(fr)
            l3.append(fc)

        # stack the sampled hypercolumns and append the (x, y) coordinates as two extra channels
        x_st = torch.cat([li.contiguous() for li in l2], 1)
        c_st = torch.cat([li.contiguous() for li in l3], 1)

        xx = torch.from_numpy(xx).view(1, 1, x_st.size(2), 1).float().to(device)
        yy = torch.from_numpy(xy).view(1, 1, x_st.size(2), 1).float().to(device)
        x_st = torch.cat([x_st, xx, yy], 1)
        c_st = torch.cat([c_st, xx, yy], 1)

        return x_st, c_st

    def rgb_to_yuv(self, rgb):
        # fixed orthonormal color transform into a YUV-like opponent space
        C = torch.Tensor(
            [[0.577350, 0.577350, 0.577350], [-0.577350, 0.788675, -0.211325], [-0.577350, -0.211325, 0.788675]]
        ).to(rgb.device)
        yuv = torch.mm(C, rgb)
        return yuv

    def pairwise_distances_cos(self, x, y):
        # pairwise cosine distance matrix between the row vectors of x and y
        x_norm = torch.sqrt((x ** 2).sum(1).view(-1, 1))
        y_t = torch.transpose(y, 0, 1)
        y_norm = torch.sqrt((y ** 2).sum(1).view(1, -1))
        dist = 1. - torch.mm(x, y_t) / x_norm / y_norm
        return dist

    def pairwise_distances_sq_l2(self, x, y):
        # pairwise squared Euclidean distances, clamped and normalized by dimension
        x_norm = (x ** 2).sum(1).view(-1, 1)
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y ** 2).sum(1).view(1, -1)
        dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
        return torch.clamp(dist, 1e-5, 1e5) / x.size(1)

    def distmat(self, x, y, cos_d=True):
        if cos_d:
            M = self.pairwise_distances_cos(x, y)
        else:
            M = torch.sqrt(self.pairwise_distances_sq_l2(x, y))
        return M

    def style_loss(self, X, Y):
        d = X.shape[1]

        # flatten to (num_samples, d); 3-channel (color) features are compared in YUV space
        if d == 3:
            X = self.rgb_to_yuv(X.transpose(0, 1).contiguous().view(d, -1)).transpose(0, 1)
            Y = self.rgb_to_yuv(Y.transpose(0, 1).contiguous().view(d, -1)).transpose(0, 1)
        else:
            X = X.transpose(0, 1).contiguous().view(d, -1).transpose(0, 1)
            Y = Y.transpose(0, 1).contiguous().view(d, -1).transpose(0, 1)

        # Relaxed EMD: max of the two one-sided mean nearest-neighbour distances
        CX_M = self.distmat(X, Y, cos_d=True)
        if d == 3:
            CX_M = CX_M + self.distmat(X, Y, cos_d=False)

        m1, m1_inds = CX_M.min(1)
        m2, m2_inds = CX_M.min(0)
        remd = torch.max(m1.mean(), m2.mean())

        return remd
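
    # A quick illustrative check (an assumption, not from the original file):
    # the relaxed EMD between two random feature sets, shaped (1, C, N, 1)
    # as in forward() below.
    #
    #   sl = StyleLoss()
    #   A, B = torch.rand(1, 64, 512, 1), torch.rand(1, 64, 1024, 1)
    #   print(sl.style_loss(A, B))  # scalar tensor; larger = more dissimilar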

    def moment_loss(self, X, Y, moments=(1, 2)):
        # match the first (mean) and second (covariance) moments of the features
        loss = 0.
        X = X.squeeze().t()
        Y = Y.squeeze().t()

        mu_x = torch.mean(X, 0, keepdim=True)
        mu_y = torch.mean(Y, 0, keepdim=True)
        mu_d = torch.abs(mu_x - mu_y).mean()

        if 1 in moments:
            loss = loss + mu_d

        if 2 in moments:
            X_c = X - mu_x
            Y_c = Y - mu_y
            X_cov = torch.mm(X_c.t(), X_c) / (X.shape[0] - 1)
            Y_cov = torch.mm(Y_c.t(), Y_c) / (Y.shape[0] - 1)
            D_cov = torch.abs(X_cov - Y_cov).mean()
            loss = loss + D_cov

        return loss

    def forward(self, feat_result, feat_content, feat_style, indices, content_weight, moment_weight=1.0):
        # spatial feature extract
        num_locations = 1024
        spatial_result, spatial_content = self.spatial_feature_extract(
            feat_result, feat_content, indices[0][:num_locations], indices[1][:num_locations]
        )
        # loss_content = content_loss(spatial_result, spatial_content)

        d = feat_style.shape[1]
        spatial_style = feat_style.view(1, d, -1, 1)

        feat_max = 3 + 2 * 64 + 128 * 2 + 256 * 3 + 512 * 2  # sum of all captured channels (2179)
        loss_remd = self.style_loss(spatial_result[:, :feat_max, :, :], spatial_style[:, :feat_max, :, :])

        # drop the last two channels: they are the (x, y) coordinates appended by
        # spatial_feature_extract, which feat_style does not carry
        loss_moment = self.moment_loss(spatial_result[:, :-2, :, :], spatial_style, moments=[1, 2])

        # palette matching on the raw RGB channels, downweighted by the content weight
        content_weight_frac = 1. / max(content_weight, 1.)
        loss_moment += content_weight_frac * self.style_loss(spatial_result[:, :3, :, :], spatial_style[:, :3, :, :])

        loss_style = loss_remd + moment_weight * loss_moment
        # print(f'Style: {loss_style.item():.3f}, Content: {loss_content.item():.3f}')

        style_weight = 1.0 + moment_weight
        loss_total = loss_style / (content_weight + style_weight)
        return loss_total

def sample_indices(feat_content, feat_style):
    const = 128 ** 2  # cap on the number of sampled locations (~16k)
    big_size = feat_content.shape[2] * feat_content.shape[3]  # number of spatial positions

    # choose strides so that roughly `const` positions survive, with a random offset
    stride_x = int(max(math.floor(math.sqrt(big_size // const)), 1))
    offset_x = np.random.randint(stride_x)
    stride_y = int(max(math.ceil(math.sqrt(big_size // const)), 1))
    offset_y = np.random.randint(stride_y)

    xx, xy = np.meshgrid(
        np.arange(feat_content.shape[2])[offset_x::stride_x],
        np.arange(feat_content.shape[3])[offset_y::stride_y]
    )
    xx = xx.flatten()
    xy = xy.flatten()
    return xx, xy
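
# A minimal end-to-end sketch under stated assumptions: it wires the pieces
# above together the way one step of a STROTSS-style optimization might. The
# random images, `space='uniform'`, and `content_weight=8.0` are illustrative
# only, not values prescribed by this file.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    extractor = VGG16Extractor(space='uniform').to(device)
    criterion = StyleLoss()

    content_img = torch.rand(1, 3, 128, 128, device=device) * 2. - 1.  # in [-1, 1]
    style_img = torch.rand(1, 3, 128, 128, device=device) * 2. - 1.
    result_img = content_img.clone().requires_grad_(True)

    with torch.no_grad():
        feat_content = extractor(content_img)
        feat_style = extractor.forward_samples_hypercolumn(style_img, samps=1000)

    # sample spatial locations at the resolution of the input image (feat_content[0]),
    # then shuffle so the 1024 locations used inside forward() spread over the image
    xx, xy = sample_indices(feat_content[0], feat_style)
    perm = np.random.permutation(xx.shape[0])
    xx, xy = xx[perm], xy[perm]

    feat_result = extractor(result_img)
    loss = criterion.forward(feat_result, feat_content, feat_style, (xx, xy), content_weight=8.0)
    loss.backward()  # gradients flow back to result_img
    print(f'style loss: {loss.item():.4f}')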