Flux9665's picture
use explicit code instead of relying on release download
9e275b8
raw
history blame
No virus
1.47 kB
import json
import torch
import torch.nn as nn
from Preprocessing.Codec.env import AttrDict
from Preprocessing.Codec.models import Encoder
from Preprocessing.Codec.models import Generator
from Preprocessing.Codec.models import Quantizer
class VQVAE(nn.Module):
    """Codec wrapper holding a quantizer and generator, optionally an encoder.

    Sub-module hyperparameters come from a JSON config file. All weights are
    restored from one checkpoint, which is mapped to CPU on load.

    Args:
        config_path: Path to a UTF-8 JSON config. It is wrapped in ``AttrDict``
            and passed to the ``Quantizer``, ``Generator`` and ``Encoder``
            constructors.
        ckpt_path: Path to a ``torch.load``-able checkpoint. It must contain the
            ``'generator'`` and ``'quantizer'`` state dicts, plus ``'encoder'``
            when ``with_encoder`` is True.
        with_encoder: Also build and load the encoder so that :meth:`encode`
            can be used. Defaults to False, which gives a decode-only model.
    """

    def __init__(self,
                 config_path,
                 ckpt_path,
                 with_encoder=False):
        super(VQVAE, self).__init__()
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(config_path, encoding="utf-8") as f:
            json_config = json.load(f)
        self.h = AttrDict(json_config)
        self.quantizer = Quantizer(self.h)
        self.generator = Generator(self.h)
        self.generator.load_state_dict(ckpt['generator'])
        self.quantizer.load_state_dict(ckpt['quantizer'])
        if with_encoder:
            self.encoder = Encoder(self.h)
            self.encoder.load_state_dict(ckpt['encoder'])

    def forward(self, x):
        """Decode codebook indices through the quantizer embedding and generator.

        Args:
            x: Codebook indices of shape ``(B, T, Nq)``.

        Returns:
            The output of ``self.generator`` applied to the embedded codes.
        """
        quant_emb = self.quantizer.embed(x)
        return self.generator(quant_emb)

    def encode(self, x):
        """Encode a batch of inputs into codebook indices.

        Requires the model to have been built with ``with_encoder=True``.

        Args:
            x: Tensor of shape ``(B, L)`` or ``(B, L, 1)``. A trailing singleton
                axis is dropped. Presumably raw waveform samples (TODO confirm).

        Returns:
            Tensor of shape ``(B, T, Nq)``, where ``Nq`` is the number of code
            tensors the quantizer returns. This is the layout :meth:`forward`
            takes.

        Raises:
            AttributeError: If the encoder was not loaded (``with_encoder=False``).
        """
        encoder = getattr(self, "encoder", None)
        if encoder is None:
            raise AttributeError(
                "VQVAE.encode() requires an encoder; construct the model "
                "with with_encoder=True")
        batch_size = x.size(0)
        if x.dim() == 3 and x.shape[-1] == 1:
            x = x.squeeze(-1)
        # Insert a singleton axis at dim 1 (presumably the channel axis).
        c = encoder(x.unsqueeze(1))
        # The quantizer returns a 3-tuple (q, loss_q, codes); only the codes
        # are needed here.
        _, _, codes = self.quantizer(c)
        codes = [code.reshape(batch_size, -1) for code in codes]
        # Stack the per-codebook indices along the last axis: (B, T, Nq).
        return torch.stack(codes, -1)