Spaces:
Runtime error
Runtime error
Мясников Филипп Сергеевич
committed on
Commit
•
8c42239
1
Parent(s):
1816654
like in colab
Browse files
app.py
CHANGED
@@ -16,7 +16,6 @@ from tqdm import tqdm
|
|
16 |
import lpips
|
17 |
from model import *
|
18 |
|
19 |
-
|
20 |
#from e4e_projection import projection as e4e_projection
|
21 |
|
22 |
from copy import deepcopy
|
@@ -30,44 +29,103 @@ import torch
|
|
30 |
import torchvision.transforms as transforms
|
31 |
from argparse import Namespace
|
32 |
from e4e.models.psp import pSp
|
|
|
33 |
from util import *
|
34 |
from huggingface_hub import hf_hub_download
|
35 |
|
36 |
device= 'cpu'
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
|
64 |
def inference(img):
|
65 |
img.save('out.jpg')
|
66 |
aligned_face = align_face('out.jpg')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
imageio.imwrite('filename.jpeg', npimage)
|
72 |
return 'filename.jpeg'
|
73 |
|
|
|
16 |
import lpips
|
17 |
from model import *
|
18 |
|
|
|
19 |
#from e4e_projection import projection as e4e_projection
|
20 |
|
21 |
from copy import deepcopy
|
|
|
29 |
import torchvision.transforms as transforms
|
30 |
from argparse import Namespace
|
31 |
from e4e.models.psp import pSp
|
32 |
+
from models.encoders import psp_encoders
|
33 |
from util import *
|
34 |
from huggingface_hub import hf_hub_download
|
35 |
|
36 |
device= 'cpu'
|
37 |
+
ffhq_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512.pt")
|
38 |
+
|
39 |
+
ffhq_ckpt = torch.load(ffhq_model_path, map_location='cpu')
|
40 |
+
ffhq_latent_avg = ffhq_ckpt['latent_avg'].to('cuda:0')
|
41 |
+
ffhq_opts = ffhq_ckpt['opts']
|
42 |
+
ffhq_opts['checkpoint_path'] = ffhq_model_path
|
43 |
+
ffhq_opts= Namespace(**ffhq_opts)
|
44 |
+
|
45 |
+
ffhq_encoder = psp_encoders.Encoder4Editing(50, 'ir_se', ffhq_opts)
|
46 |
+
ffhq_e_filt = {k[len('encoder') + 1:]: v for k, v in ffhq_ckpt['state_dict'].items() if k[:len('encoder')] == 'encoder'}
|
47 |
+
ffhq_encoder.load_state_dict(ffhq_e_filt, strict=True)
|
48 |
+
ffhq_encoder.eval()
|
49 |
+
ffhq_encoder.to(device)
|
50 |
+
|
51 |
+
ffhq_decoder = Generator(512, 512, 8, channel_multiplier=2)
|
52 |
+
ffhq_d_filt = {k[len('decoder') + 1:]: v for k, v in ffhq_ckpt['state_dict'].items() if k[:len('decoder')] == 'decoder'}
|
53 |
+
ffhq_decoder.load_state_dict(ffhq_d_filt, strict=True)
|
54 |
+
ffhq_decoder.eval()
|
55 |
+
ffhq_decoder.to(device)
|
56 |
+
clear_output()
|
57 |
+
|
58 |
+
dog_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512_dog.pt")
|
59 |
+
|
60 |
+
dog_ckpt = torch.load(dog_model_path, map_location='cpu')
|
61 |
+
dog_latent_avg = dog_ckpt['latent_avg'].to('cuda:0')
|
62 |
+
dog_opts = dog_ckpt['opts']
|
63 |
+
dog_opts['checkpoint_path'] = dog_model_path
|
64 |
+
dog_opts= Namespace(**dog_opts)
|
65 |
+
|
66 |
+
dog_encoder = psp_encoders.Encoder4Editing(50, 'ir_se', dog_opts)
|
67 |
+
dog_e_filt = {k[len('encoder') + 1:]: v for k, v in dog_ckpt['state_dict'].items() if k[:len('encoder')] == 'encoder'}
|
68 |
+
dog_encoder.load_state_dict(dog_e_filt, strict=True)
|
69 |
+
dog_encoder.eval()
|
70 |
+
dog_encoder.to(device)
|
71 |
+
|
72 |
+
dog_decoder = Generator(512, 512, 8, channel_multiplier=2)
|
73 |
+
dog_d_filt = {k[len('decoder') + 1:]: v for k, v in dog_ckpt['state_dict'].items() if k[:len('decoder')] == 'decoder'}
|
74 |
+
dog_decoder.load_state_dict(dog_d_filt, strict=True)
|
75 |
+
dog_decoder.eval()
|
76 |
+
dog_decoder.to(device)
|
77 |
+
clear_output()
|
78 |
+
|
79 |
+
cat_model_path = hf_hub_download(repo_id="bankholdup/stylegan_petbreeder", filename="e4e_ffhq512_cat.pt")
|
80 |
+
|
81 |
+
cat_ckpt = torch.load(cat_model_path, map_location='cpu')
|
82 |
+
cat_latent_avg = cat_ckpt['latent_avg'].to('cuda:0')
|
83 |
+
cat_opts = cat_ckpt['opts']
|
84 |
+
cat_opts['checkpoint_path'] = cat_model_path
|
85 |
+
cat_opts= Namespace(**cat_opts)
|
86 |
+
|
87 |
+
cat_encoder = psp_encoders.Encoder4Editing(50, 'ir_se', cat_opts)
|
88 |
+
cat_e_filt = {k[len('encoder') + 1:]: v for k, v in cat_ckpt['state_dict'].items() if k[:len('encoder')] == 'encoder'}
|
89 |
+
cat_encoder.load_state_dict(cat_e_filt, strict=True)
|
90 |
+
cat_encoder.eval()
|
91 |
+
cat_encoder.to(device)
|
92 |
+
|
93 |
+
cat_decoder = Generator(512, 512, 8, channel_multiplier=2)
|
94 |
+
cat_d_filt = {k[len('decoder') + 1:]: v for k, v in cat_ckpt['state_dict'].items() if k[:len('decoder')] == 'decoder'}
|
95 |
+
cat_decoder.load_state_dict(cat_d_filt, strict=True)
|
96 |
+
cat_decoder.eval()
|
97 |
+
cat_decoder.to(device)
|
98 |
+
clear_output()
|
99 |
+
|
100 |
+
|
101 |
+
def gen_im(model_type='ffhq'):
|
102 |
+
if model_type=='ffhq':
|
103 |
+
imgs, _ = ffhq_decoder([ffhq_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
|
104 |
+
elif model_type=='dog':
|
105 |
+
imgs, _ = dog_decoder([dog_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
|
106 |
+
elif model_type=='cat':
|
107 |
+
imgs, _ = cat_decoder([cat_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
|
108 |
+
else:
|
109 |
+
imgs, _ = custom_decoder([custom_codes], input_is_latent=True, randomize_noise=False, return_latents=True)
|
110 |
+
return tensor2im(imgs[0])
|
111 |
|
112 |
|
113 |
def inference(img):
|
114 |
img.save('out.jpg')
|
115 |
aligned_face = align_face('out.jpg')
|
116 |
+
|
117 |
+
ffhq_codes = ffhq_encoder(aligned_face.unsqueeze(0).to("cuda").float())
|
118 |
+
ffhq_codes = ffhq_codes + ffhq_latent_avg.repeat(ffhq_codes.shape[0], 1, 1)
|
119 |
+
|
120 |
+
cat_codes = cat_encoder(aligned_face.unsqueeze(0).to("cuda").float())
|
121 |
+
cat_codes = cat_codes + ffhq_latent_avg.repeat(cat_codes.shape[0], 1, 1)
|
122 |
+
|
123 |
+
dog_codes = dog_encoder(aligned_face.unsqueeze(0).to("cuda").float())
|
124 |
+
dog_codes = dog_codes + ffhq_latent_avg.repeat(dog_codes.shape[0], 1, 1)
|
125 |
|
126 |
+
animal = "cat"
|
127 |
+
npimage = gen_im(animal)
|
128 |
+
|
129 |
imageio.imwrite('filename.jpeg', npimage)
|
130 |
return 'filename.jpeg'
|
131 |
|