Flux9665's picture
use explicit code instead of relying on release download
9e275b8
raw
history blame
No virus
2.02 kB
import math
import random
import numpy as np
import torch.nn as nn
from Preprocessing.Codec.seanet import SEANetDecoder
from Preprocessing.Codec.seanet import SEANetEncoder
from Preprocessing.Codec.vq import ResidualVectorQuantizer
# Generator
class EnCodec(nn.Module):
    """Encoder -> residual vector quantizer -> decoder model.

    The model chains a ``SEANetEncoder``, a ``ResidualVectorQuantizer`` and a
    ``SEANetDecoder``. ``forward`` runs the full chain with a randomly chosen
    target bandwidth, while ``encode`` / ``decode`` expose the two halves
    separately.

    Args:
        n_filters: Forwarded to the encoder and the decoder as ``n_filters``.
        D: Latent dimension, forwarded as ``dimension`` to the encoder, the
            quantizer and the decoder.
        target_bandwidths: Sequence of bandwidths the quantizer may be run
            at. The last entry sizes the number of quantizers and is the
            default for ``encode``. Defaults to ``[1, 1.5, 2, 4, 6, 12]``.
            Presumably expressed in kbps (the value is multiplied by 1000
            when sizing the quantizer) -- confirm against the quantizer.
        ratios: Forwarded to the encoder and the decoder as ``ratios``; their
            product is stored as ``hop_length``. Defaults to ``[8, 5, 4, 2]``.
        sample_rate: Divided by ``hop_length`` to derive ``frame_rate``.
        bins: Forwarded to the quantizer; ``log2(bins)`` is stored as
            ``bits_per_codebook``.
        normalize: Accepted for interface compatibility; not used by this
            class.
    """

    def __init__(self,
                 n_filters,
                 D,
                 target_bandwidths=None,
                 ratios=None,
                 sample_rate=16000,
                 bins=1024,
                 normalize=False):
        super().__init__()
        # Build the defaults per instance: ``target_bandwidths`` is stored on
        # ``self``, so a list literal in the signature would be shared (and
        # mutable) across every model created with default arguments.
        if target_bandwidths is None:
            target_bandwidths = [1, 1.5, 2, 4, 6, 12]
        if ratios is None:
            ratios = [8, 5, 4, 2]

        # Product of all ratios.
        self.hop_length = np.prod(ratios)
        # 50 with the default sample_rate and ratios (16000 / 320).
        self.frame_rate = math.ceil(sample_rate / self.hop_length)
        self.bits_per_codebook = int(math.log2(bins))
        self.target_bandwidths = target_bandwidths

        # Number of quantizers needed to reach the largest target bandwidth.
        # NOTE(review): the constant 10 matches log2(1024), i.e.
        # ``bits_per_codebook`` for the default ``bins`` -- confirm whether it
        # should follow ``bins`` before using a different codebook size.
        n_q = int(1000 * target_bandwidths[-1] // (self.frame_rate * 10))

        # Keep the construction order encoder -> quantizer -> decoder so the
        # submodule registration order stays stable.
        self.encoder = SEANetEncoder(n_filters=n_filters, dimension=D, ratios=ratios)
        self.quantizer = ResidualVectorQuantizer(dimension=D, n_q=n_q, bins=bins)
        self.decoder = SEANetDecoder(n_filters=n_filters, dimension=D, ratios=ratios)

    def get_last_layer(self):
        """Return the ``weight`` of the last entry in ``self.decoder.layers``."""
        return self.decoder.layers[-1].weight

    def forward(self, x):
        """Encode, quantize at a random target bandwidth, and decode ``x``.

        Returns:
            A 3-tuple ``(output, commit_loss, None)``: the decoder output, the
            commitment loss reported by the quantizer, and a ``None``
            placeholder in the third slot.
        """
        latent = self.encoder(x)
        # Pick one of the configured bandwidths uniformly at random per call.
        bw = random.choice(self.target_bandwidths)
        # The quantizer also returns the codes and the bandwidth; neither is
        # needed here.
        quantized, _, _, commit_loss = self.quantizer(latent, self.frame_rate, bw)
        return self.decoder(quantized), commit_loss, None

    def encode(self, x, target_bw=None, st=None):
        """Encode ``x`` and return the quantizer codes.

        Args:
            x: Input passed to the encoder.
            target_bw: Bandwidth handed to the quantizer; when ``None`` the
                last entry of ``self.target_bandwidths`` is used.
            st: Forwarded to ``quantizer.encode`` as its last argument; when
                ``None`` it defaults to ``0``.
        """
        latent = self.encoder(x)
        bw = self.target_bandwidths[-1] if target_bw is None else target_bw
        start = 0 if st is None else st
        return self.quantizer.encode(latent, self.frame_rate, bw, start)

    def decode(self, codes):
        """Turn quantizer ``codes`` back into decoder output."""
        return self.decoder(self.quantizer.decode(codes))