Flux9665 commited on
Commit
a1551ba
β€’
1 Parent(s): db5766e

Update Modules/ControllabilityGAN/GAN.py

Browse files
Files changed (1) hide show
  1. Modules/ControllabilityGAN/GAN.py +11 -8
Modules/ControllabilityGAN/GAN.py CHANGED
@@ -5,7 +5,7 @@ from Modules.ControllabilityGAN.wgan.init_wgan import create_wgan
5
 
6
  class GanWrapper:
7
 
8
- def __init__(self, path_wgan, device):
9
  self.device = device
10
  self.path_wgan = path_wgan
11
 
@@ -20,15 +20,18 @@ class GanWrapper:
20
 
21
  self.z_list = list()
22
 
23
- for _ in range(1100):
24
- self.z_list.append(self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8))
 
 
 
 
 
 
25
  self.z = self.z_list[0]
26
 
27
  def set_latent(self, seed):
28
- self.z = self.z = self.z_list[seed]
29
-
30
- def reset_default_latent(self):
31
- self.z = self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8)
32
 
33
  def load_model(self, path):
34
  gan_checkpoint = torch.load(path, map_location="cpu")
@@ -53,7 +56,7 @@ class GanWrapper:
53
  self.mean = gan_checkpoint["dataset_mean"]
54
  self.std = gan_checkpoint["dataset_std"]
55
 
56
- def compute_controllability(self, n_samples=100000):
57
  _, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
58
  intermediate = intermediate.cpu()
59
  z = z.cpu()
 
5
 
6
  class GanWrapper:
7
 
8
+ def __init__(self, path_wgan, device, num_cached_voices=10):
9
  self.device = device
10
  self.path_wgan = path_wgan
11
 
 
20
 
21
  self.z_list = list()
22
 
23
+ while len(self.z_list) < num_cached_voices + 2:
24
+ z = self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8)
25
+ sims = [-1.0]
26
+ for other_z in self.z_list:
27
+ sims.append(torch.nn.functional.cosine_similarity(z, other_z))
28
+ print(max(sims), len(self.z_list))
29
+ if max(sims) < 0.25:
30
+ self.z_list.append(z)
31
  self.z = self.z_list[0]
32
 
33
  def set_latent(self, seed):
34
+ self.z = self.z_list[seed]
 
 
 
35
 
36
  def load_model(self, path):
37
  gan_checkpoint = torch.load(path, map_location="cpu")
 
56
  self.mean = gan_checkpoint["dataset_mean"]
57
  self.std = gan_checkpoint["dataset_std"]
58
 
59
+ def compute_controllability(self, n_samples=200000):
60
  _, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
61
  intermediate = intermediate.cpu()
62
  z = z.cpu()