# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Core vector quantization implementation."""
import torch.nn.functional as F
from einops import rearrange
from einops import repeat
from torch import nn

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch distributed utilities.""" import typing as tp import torch def rank(): if torch.distributed.is_initialized(): return torch.distributed.get_rank() else: return 0 def world_size(): if torch.distributed.is_initialized(): return torch.distributed.get_world_size() else: return 1 def is_distributed(): return world_size() > 1 def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): if is_distributed(): return torch.distributed.all_reduce(tensor, op) def _is_complex_or_float(tensor): return torch.is_floating_point(tensor) or torch.is_complex(tensor) def _check_number_of_params(params: tp.List[torch.Tensor]): # utility function to check that the number of params in all workers is the same, # and thus avoid a deadlock with distributed all reduce. if not is_distributed() or not params: return # print('params[0].device ', params[0].device) tensor = torch.tensor( [len(params)], device=params[0].device, dtype=torch.long) all_reduce(tensor) if tensor.item() != len(params) * world_size(): # If not all the workers have the same number, for at least one of them, # this inequality will be verified. raise RuntimeError( f"Mismatch in number of params: ours is {len(params)}, " "at least one worker has a different one.") def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): """Broadcast the tensors from the given parameters to all workers. This can be used to ensure that all workers have the same model to start with. """ if not is_distributed(): return tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] _check_number_of_params(tensors) handles = [] for tensor in tensors: # src = int(rank()) # added code handle = torch.distributed.broadcast( tensor.data, src=src, async_op=True) handles.append(handle) for handle in handles: handle.wait() def sync_buffer(buffers, average=True): """ Sync grad for buffers. If average is False, broadcast instead of averaging. 
""" if not is_distributed(): return handles = [] for buffer in buffers: if torch.is_floating_point(buffer.data): if average: handle = torch.distributed.all_reduce( buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True) else: handle = torch.distributed.broadcast( buffer.data, src=0, async_op=True) handles.append((buffer, handle)) for buffer, handle in handles: handle.wait() if average: buffer.data /= world_size def sync_grad(params): """ Simpler alternative to DistributedDataParallel, that doesn't rely on any black magic. For simple models it can also be as fast. Just call this on your model parameters after the call to backward! """ if not is_distributed(): return handles = [] for p in params: if p.grad is not None: handle = torch.distributed.all_reduce( p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) handles.append((p, handle)) for p, handle in handles: handle.wait() p.grad.data /= world_size() def average_metrics(metrics: tp.Dict[str, float], count=1.): """Average a dictionary of metrics across all workers, using the optional `count` as unormalized weight. 
""" if not is_distributed(): return metrics keys, values = zip(*metrics.items()) device = 'cuda' if torch.cuda.is_available() else 'cpu' tensor = torch.tensor( list(values) + [1], device=device, dtype=torch.float32) tensor *= count all_reduce(tensor) averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() return dict(zip(keys, averaged)) def default(val: tp.Any, d: tp.Any) -> tp.Any: return val if val is not None else d def ema_inplace(moving_avg, new, decay: float): moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): return (x + epsilon) / (x.sum() + n_categories * epsilon) def uniform_init(*shape: int): t = torch.empty(shape) nn.init.kaiming_uniform_(t) return t def sample_vectors(samples, num: int): num_samples, device = samples.shape[0], samples.device if num_samples >= num: indices = torch.randperm(num_samples, device=device)[:num] else: indices = torch.randint(0, num_samples, (num,), device=device) return samples[indices] def kmeans(samples, num_clusters: int, num_iters: int = 10): dim, dtype = samples.shape[-1], samples.dtype means = sample_vectors(samples, num_clusters) for _ in range(num_iters): diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") dists = -(diffs ** 2).sum(dim=-1) buckets = dists.max(dim=-1).indices bins = torch.bincount(buckets, minlength=num_clusters) zero_mask = bins == 0 bins_min_clamped = bins.masked_fill(zero_mask, 1) new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) new_means = new_means / bins_min_clamped[..., None] means = torch.where(zero_mask[..., None], means, new_means) return means, bins class EuclideanCodebook(nn.Module): """Codebook with Euclidean distance. Args: dim (int): Dimension. codebook_size (int): Codebook size. kmeans_init (bool): Whether to use k-means to initialize the codebooks. 
If set to true, run the k-means algorithm on the first training batch and use the learned centroids as initialization. kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. decay (float): Decay for exponential moving average over the codebooks. epsilon (float): Epsilon value for numerical stability. threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. """ def __init__( self, dim: int, codebook_size: int, kmeans_init: int = False, kmeans_iters: int = 10, decay: float = 0.99, epsilon: float = 1e-5, threshold_ema_dead_code: int = 2, ): super().__init__() self.decay = decay init_fn: tp.Union[ tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros embed = init_fn(codebook_size, dim) self.codebook_size = codebook_size self.kmeans_iters = kmeans_iters self.epsilon = epsilon self.threshold_ema_dead_code = threshold_ema_dead_code self.register_buffer("inited", torch.Tensor([not kmeans_init])) self.register_buffer("cluster_size", torch.zeros(codebook_size)) self.register_buffer("embed", embed) self.register_buffer("embed_avg", embed.clone()) @torch.jit.ignore def init_embed_(self, data): if self.inited: return embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) self.embed.data.copy_(embed) self.embed_avg.data.copy_(embed.clone()) self.cluster_size.data.copy_(cluster_size) self.inited.data.copy_(torch.Tensor([True])) # Make sure all buffers across workers are in sync after initialization broadcast_tensors(self.buffers()) def replace_(self, samples, mask): modified_codebook = torch.where( mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) self.embed.data.copy_(modified_codebook) def expire_codes_(self, batch_samples): if self.threshold_ema_dead_code == 0: return expired_codes = self.cluster_size 
< self.threshold_ema_dead_code if not torch.any(expired_codes): return batch_samples = rearrange(batch_samples, "... d -> (...) d") self.replace_(batch_samples, mask=expired_codes) broadcast_tensors(self.buffers()) def preprocess(self, x): x = rearrange(x, "... d -> (...) d") return x def quantize(self, x): embed = self.embed.t() dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True)) embed_ind = dist.max(dim=-1).indices return embed_ind def postprocess_emb(self, embed_ind, shape): return embed_ind.view(*shape[:-1]) def dequantize(self, embed_ind): quantize = F.embedding(embed_ind, self.embed) return quantize def encode(self, x): shape = x.shape # pre-process x = self.preprocess(x) # quantize embed_ind = self.quantize(x) # post-process embed_ind = self.postprocess_emb(embed_ind, shape) return embed_ind def decode(self, embed_ind): quantize = self.dequantize(embed_ind) return quantize def forward(self, x): shape, dtype = x.shape, x.dtype x = self.preprocess(x) self.init_embed_(x) embed_ind = self.quantize(x) embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) embed_ind = self.postprocess_emb(embed_ind, shape) quantize = self.dequantize(embed_ind) if self.training: # We do the expiry of code at that point as buffers are in sync # and all the workers will take the same decision. self.expire_codes_(x) ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) embed_sum = x.t() @ embed_onehot ema_inplace(self.embed_avg, embed_sum.t(), self.decay) cluster_size = ( laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()) embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) self.embed.data.copy_(embed_normalized) return quantize, embed_ind class VectorQuantization(nn.Module): """Vector quantization implementation. Currently supports only euclidean distance. Args: dim (int): Dimension codebook_size (int): Codebook size codebook_dim (int): Codebook dimension. 
If not defined, uses the specified dimension in dim. decay (float): Decay for exponential moving average over the codebooks. epsilon (float): Epsilon value for numerical stability. kmeans_init (bool): Whether to use kmeans to initialize the codebooks. kmeans_iters (int): Number of iterations used for kmeans initialization. threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. commitment_weight (float): Weight for commitment loss. """ def __init__( self, dim: int, codebook_size: int, codebook_dim: tp.Optional[int] = None, decay: float = 0.99, epsilon: float = 1e-5, kmeans_init: bool = True, kmeans_iters: int = 50, threshold_ema_dead_code: int = 2, commitment_weight: float = 1., ): super().__init__() _codebook_dim: int = default(codebook_dim, dim) requires_projection = _codebook_dim != dim self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) self.epsilon = epsilon self.commitment_weight = commitment_weight self._codebook = EuclideanCodebook( dim=_codebook_dim, codebook_size=codebook_size, kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, decay=decay, epsilon=epsilon, threshold_ema_dead_code=threshold_ema_dead_code) self.codebook_size = codebook_size @property def codebook(self): return self._codebook.embed def encode(self, x): x = rearrange(x, "b d n -> b n d") x = self.project_in(x) embed_in = self._codebook.encode(x) return embed_in def decode(self, embed_ind): quantize = self._codebook.decode(embed_ind) quantize = self.project_out(quantize) if len(quantize.size()) < 3: quantize = quantize.unsqueeze(0) quantize = rearrange(quantize, "b n d -> b d n") return quantize def forward(self, x): device = x.device x = rearrange(x, "b d n -> b n d") x = self.project_in(x) 
quantize, embed_ind = self._codebook(x) if self.training: quantize = x + (quantize - x).detach() loss = torch.tensor([0.0], device=device, requires_grad=self.training) if self.training: if self.commitment_weight > 0: commit_loss = F.mse_loss(quantize.detach(), x) loss = loss + commit_loss * self.commitment_weight quantize = self.project_out(quantize) quantize = rearrange(quantize, "b n d -> b d n") return quantize, embed_ind, loss class ResidualVectorQuantization(nn.Module): """Residual vector quantization implementation. Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """ def __init__(self, *, num_quantizers, **kwargs): super().__init__() self.layers = nn.ModuleList( [VectorQuantization(**kwargs) for _ in range(num_quantizers)]) def forward(self, x, n_q: tp.Optional[int] = None): quantized_out = 0.0 residual = x all_losses = [] all_indices = [] n_q = n_q or len(self.layers) for layer in self.layers[:n_q]: quantized, indices, loss = layer(residual) residual = residual - quantized quantized_out = quantized_out + quantized all_indices.append(indices) all_losses.append(loss) out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) return quantized_out, out_indices, out_losses def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor: residual = x all_indices = [] n_q = n_q or len(self.layers) st = st or 0 for layer in self.layers[st:n_q]: # 设置解码的起止layer indices = layer.encode(residual) quantized = layer.decode(indices) residual = residual - quantized all_indices.append(indices) out_indices = torch.stack(all_indices) return out_indices def decode(self, q_indices: torch.Tensor) -> torch.Tensor: quantized_out = torch.tensor(0.0, device=q_indices.device) for i, indices in enumerate(q_indices): layer = self.layers[i] quantized = layer.decode(indices) quantized_out = quantized_out + quantized return quantized_out # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Residual vector quantizer implementation."""

from dataclasses import dataclass, field
import math
import typing as tp

import torch
from torch import nn


@dataclass
class QuantizedResult:
    quantized: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None
    metrics: dict = field(default_factory=dict)


class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.n_q = n_q
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
        )

    def forward(
        self,
        x: torch.Tensor,
        sample_rate: int,
        bandwidth: tp.Optional[float] = None,
    ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Residual vector quantization on the given input tensor.

        Args:
            x (torch.Tensor): Input tensor.
            sample_rate (int): Sample rate of the input tensor.
            bandwidth (float): Target bandwidth.
        Returns:
            tuple: ``(quantized, codes, bandwidth, penalty)``. ``bandwidth`` is
            ``n_q * get_bandwidth_per_quantizer(sample_rate)`` as a tensor with
            ``x``'s dtype and device. ``penalty`` is the mean of the per-layer
            commitment losses. These are the fields of ``QuantizedResult``, but a
            plain tuple is returned.
        """
        bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
        n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
        bw = torch.tensor(n_q * bw_per_q).to(x)
        return quantized, codes, bw, torch.mean(commit_loss)

    def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
        """Return n_q based on specified target bandwidth.

        Without a positive ``bandwidth`` all ``self.n_q`` quantizers are used.
        Otherwise the result is ``floor(bandwidth / bw_per_q)``, at least 1.
        NOTE(review): the result is not clamped to ``self.n_q``.
        """
        bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
        n_q = self.n_q
        if bandwidth and bandwidth > 0.:
            n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
        return n_q

    def get_bandwidth_per_quantizer(self, sample_rate: int):
        """Return bandwidth per quantizer for a given input sample rate.

        Each code carries ``log2(bins)`` bits, so this is
        ``log2(bins) * sample_rate / 1000``. It is in kb/s, presumably with
        ``sample_rate`` being the frame rate of the quantized sequence.
        """
        return math.log2(self.bins) * sample_rate / 1000

    def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None,
               st: tp.Optional[int] = None) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.

        The RVQ encode method sets the appropriate number of quantizer to use
        and returns indices for each quantizer. ``st`` (default 0) selects the
        first quantizer layer used.
        """
        n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
        st = st or 0
        codes = self.vq.encode(x, n_q=n_q, st=st)
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.

        NOTE(review): no start layer is forwarded to ``self.vq.decode``, so
        codes are decoded from the first quantizer layer onwards.
        """
        quantized = self.vq.decode(codes)
        return quantized