"""Conformer definition adjusted given the Lucidrain's repo. https://github.com/lucidrains/soundstorm-pytorch/blob/main/soundstorm_pytorch/soundstorm.py # noqa Copyright PolyAI Limited. """ from collections import namedtuple from functools import wraps from typing import Dict, Union import torch import torch.nn.functional as F from einops import rearrange, reduce from einops.layers.torch import EinMix, Rearrange from torch import einsum, nn # rotary embedding class RotaryEmbedding(nn.Module): def __init__(self, dim, theta = 10000): super().__init__() inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq, persistent = False) @property def device(self): return next(self.buffers()).device def forward(self, seq_len): t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq) freqs = torch.einsum('i , j -> i j', t, self.inv_freq) freqs = torch.cat((freqs, freqs), dim = -1) return freqs def rotate_half(x): x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(pos, t): return (t * pos.cos()) + (rotate_half(t) * pos.sin()) # constants EfficientAttentionConfig = namedtuple( 'EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'] ) # helpers def exists(val): return val is not None def default(val, d): return val if exists(val) else d def divisible_by(numer, denom): return (numer % denom) == 0 def calc_same_padding(kernel_size): pad = kernel_size // 2 return (pad, pad - (kernel_size + 1) % 2) def eval_decorator(fn): @wraps(fn) def inner(model, *args, **kwargs): was_training = model.training model.eval() out = fn(model, *args, **kwargs) model.train(was_training) return out return inner def once(fn): called = False @wraps(fn) def inner(x): nonlocal called if called: return called = True return fn(x) return inner print_once = once(print) # t5 relative positional bias class T5RelativePositionBias(nn.Module): def __init__( self, scale = 1., num_buckets = 32, max_distance = 128, heads = 8 ): super().__init__() self.scale = scale self.num_buckets = num_buckets self.max_distance = max_distance self.relative_attention_bias = nn.Embedding(num_buckets, heads) @staticmethod def _relative_position_bucket( relative_position, num_buckets = 32, max_distance = 128 ): ret = 0 n = -relative_position num_buckets //= 2 ret += (n < 0).long() * num_buckets n = torch.abs(n) max_exact = num_buckets // 2 is_small = n < max_exact val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log( max_distance / max_exact) * (num_buckets - max_exact) ).long() val_if_large = torch.min( val_if_large, torch.full_like(val_if_large, num_buckets - 1) ) ret += torch.where(is_small, n, val_if_large) return ret @property def device(self): return next(self.parameters()).device def forward(self, n): pos = torch.arange(n, device = self.device).long() rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1') rp_bucket = self._relative_position_bucket( rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) values = self.relative_attention_bias(rp_bucket) bias = rearrange(values, 'i j h -> h i j') return bias * self.scale # main class class Attend(nn.Module): def __init__( self, causal = False, dropout = 0., flash = False ): super().__init__() self.dropout = dropout self.attn_dropout = nn.Dropout(dropout) self.causal = causal self.flash = flash # determine efficient attention configs for cuda and cpu self.cpu_config = EfficientAttentionConfig(True, True, True) self.cuda_config = 
        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(
            torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')  # noqa
            self.cuda_config = EfficientAttentionConfig(True, True, True)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')  # noqa
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    def get_mask(self, i, j, device):
        return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)  # noqa

    def flash_attn(self, q, k, v, mask = None, attn_bias = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device  # noqa

        # single headed key / values
        if k.ndim == 3:
            k = rearrange(k, 'b n d -> b 1 n d')

        if v.ndim == 3:
            v = rearrange(v, 'b n d -> b 1 n d')

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L
        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # Check if there is a compatible device for flash attention
        config = self.cuda_config if is_cuda else self.cpu_config

        causal = self.causal

        # handle attention bias
        if exists(attn_bias):
            mask_value = -torch.finfo(q.dtype).max // 2
            causal_mask = self.get_mask(q_len, k_len, device)
            attn_bias = attn_bias.masked_fill(causal_mask, mask_value)

            if exists(mask):
                attn_bias = attn_bias.masked_fill(~mask, mask_value)

            mask = attn_bias
            causal = False

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        return out

    def forward(self, q, k, v, mask = None, attn_bias = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        if self.flash:
            assert not exists(attn_bias)
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # attention bias
        if exists(attn_bias):
            sim = sim + attn_bias

        # causal mask
        if self.causal:
            causal_mask = self.get_mask(q_len, k_len, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # key padding mask
        if exists(mask):
            if mask.ndim != 4:
                mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out


class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()


class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()


class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)


class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


# layernorm over the channel dim (dim=1) of (batch, channels, time) tensors
class ChanLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, dim, 1))

    def forward(self, x):
        eps = 1e-6 if x.dtype == torch.float32 else 1e-4
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        flash = True
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = Attend(
            flash = flash,
            dropout = dropout
        )

        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        rotary_emb = None,
        attn_bias = None
    ):
        n, device, h, has_context = (
            x.shape[-2], x.device, self.heads, exists(context))
        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(
            lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class ConformerConvModule(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        expansion_factor = 2,
        kernel_size = 31,
        dropout = 0.
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = (
            calc_same_padding(kernel_size)
            if not causal else (kernel_size - 1, 0)
        )

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(
                inner_dim,
                inner_dim,
                kernel_size = kernel_size,
                padding = padding
            ),
            Swish(),
            ChanLayerNorm(inner_dim),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


# Conformer Block


class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        attn_flash = True,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
        self.attn = Attention(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            dropout = attn_dropout,
            flash = attn_flash
        )
        self.conv = ConformerConvModule(
            dim = dim,
            causal = conv_causal,
            expansion_factor = conv_expansion_factor,
            kernel_size = conv_kernel_size,
            dropout = conv_dropout
        )
        self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        rotary_emb = None,
        attn_bias = None
    ):
        x = self.ff1(x) + x
        x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x  # noqa
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x


# Conformer


class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_layers,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False,
        attn_flash = True,
        t5_rel_pos_bias = False
    ):
        super().__init__()

        assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias'  # noqa

        self.dim = dim
        self.layers = nn.ModuleList([])

        self.rotary_emb = RotaryEmbedding(
            dim_head) if not t5_rel_pos_bias else None
        self.rel_pos_bias = T5RelativePositionBias(
            dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None

        for _ in range(num_layers):
            self.layers.append(ConformerBlock(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                conv_expansion_factor = conv_expansion_factor,
                conv_kernel_size = conv_kernel_size,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                conv_dropout = conv_dropout,
                conv_causal = conv_causal,
                attn_flash = attn_flash
            ))

    def forward(self, x, mask = None):
        seq_len = x.shape[-2]

        rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None  # noqa
        attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None  # noqa

        for block in self.layers:
            x = block(
                x,
                mask = mask,
                rotary_emb = rotary_emb,
                attn_bias = attn_bias
            )

        return x


# conformer with sum reduction across quantized tokens at the beginning,
# along with heads


class ConformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        codebook_size,
        num_quantizers,
        conformer: Union[Conformer, Dict[str, Any]],
        grouped_quantizers = 1
    ):
        super().__init__()
        self.conformer = conformer

        if isinstance(conformer, dict):
            self.conformer = Conformer(**self.conformer)

        dim = self.conformer.dim

        self.embedding_proj = nn.Sequential(
            nn.Linear(dim * grouped_quantizers, dim),
            nn.LayerNorm(dim)
        ) if grouped_quantizers > 1 else nn.Identity()

        num_codes_with_mask = codebook_size + 1
        num_effective_quantizers = num_quantizers * grouped_quantizers

        self.code_embeds = nn.Embedding(
            num_codes_with_mask * num_effective_quantizers, dim)

        self.register_buffer(
            'quantizer_offsets',
            torch.arange(num_effective_quantizers) * num_codes_with_mask,
            persistent = False
        )
        self.register_buffer(
            'mask_tokens',
            self.quantizer_offsets + num_codes_with_mask,
            persistent = False
        )

        self.dim = dim
        self.codebook_size = codebook_size

        self.num_codes_with_mask = num_codes_with_mask
        self.num_quantizers = num_quantizers
        self.grouped_quantizers = grouped_quantizers

        self.heads = nn.Sequential(
            nn.Linear(dim, dim * num_effective_quantizers),
            Rearrange('b n (h d) -> b (n h) d', h = num_effective_quantizers)
        )

        # each quantizer codebook would require its own logits weight
        # and bias matrices
        # the amazing einops makes this easy with 'EinMix'
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b (n gq) d -> b n gq d', gq = num_effective_quantizers),
            EinMix(
                'b n gq d -> b n gq l',
                weight_shape = 'gq d l',
                bias_shape = 'gq l',
                gq = num_effective_quantizers,
                l = codebook_size,
                d = dim
            ),
            Rearrange('b ... d -> b (...) d')
        )

    def forward(
        self,
        x,
        *,
        mask = None,
        cond = None,
        sum_embeds = None,
        return_embeddings = False,
        return_logits_and_embeddings = False
    ):
        """
        einops notation:
        b - batch
        n - sequence
        g - groups
        q - quantizers
        d - feature dimension
        """
        n, q, g = x.shape[-1], self.num_quantizers, self.grouped_quantizers
        assert divisible_by(n, g * q), 'sequence must be divisible by number of quantizers'  # noqa

        x = rearrange(x, 'b (n gq) -> b n gq', gq = g * q)
        x = x + self.quantizer_offsets
        x = self.code_embeds(x)

        x = reduce(x, 'b n (g q) d -> b n (g d)', 'sum', g = g)

        x = self.embedding_proj(x)

        if exists(sum_embeds):
            x = x + sum_embeds

        if exists(cond):
            if cond.ndim == 2:
                cond = rearrange(cond, 'b d -> b 1 d')

            x = x + cond

        x = self.conformer(x, mask = mask)
        embeds = self.heads(x)

        if return_embeddings or not exists(self.to_logits):
            return embeds

        logits = self.to_logits(embeds)

        if return_logits_and_embeddings:
            return logits, embeds

        return logits
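

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): shows the
# tensor shapes the two entry points expect. All hyperparameters below
# (dim=256, num_layers=2, codebook_size=1024, num_quantizers=4, the batch
# size and sequence length) are assumptions chosen for the example only;
# attn_flash is disabled so the sketch also runs on CPU-only installs.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # plain Conformer over continuous features, shaped (batch, seq_len, dim)
    conformer = Conformer(dim = 256, num_layers = 2, attn_flash = False)
    feats = torch.randn(2, 100, 256)
    enc = conformer(feats)  # -> (2, 100, 256)

    # ConformerWrapper over flattened codec tokens: each timestep packs
    # num_quantizers codes, so the flat sequence length must be divisible
    # by num_quantizers * grouped_quantizers
    wrapper = ConformerWrapper(
        codebook_size = 1024,
        num_quantizers = 4,
        conformer = dict(dim = 256, num_layers = 2, attn_flash = False),
    )
    tokens = torch.randint(0, 1024, (2, 100 * 4))
    logits = wrapper(tokens)  # -> (2, 400, 1024), one logit row per code
    print(enc.shape, logits.shape)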