MasonCrinr's picture
Upload 331 files
8026e91
raw
history blame contribute delete
No virus
10.9 kB
import functools
import numpy as np
import torch as t
import torch.nn as nn
import jukebox.utils.dist_adapter as dist
from jukebox.transformer.ops import Conv1D, ACT_FNS, LayerNorm
from jukebox.transformer.factored_attention import FactoredAttention
from jukebox.utils.checkpoint import checkpoint
def _convert_mlp_traced(l):
    """Replace ``l.mlp`` with a TorchScript trace recorded on a CUDA fp32 input.

    The example input used for tracing is ``t.randn(1, 1, l.n_in)`` on the
    current CUDA device. Any module that is not a ResAttnBlock is left
    unchanged. Takes a single module argument, so it is suitable for
    ``nn.Module.apply``.
    """
    if not isinstance(l, ResAttnBlock):
        return
    example_input = t.randn(1, 1, l.n_in).cuda()
    l.mlp = t.jit.trace(l.mlp, example_input)
def _convert_mlp_traced_fp16(l):
    """Replace ``l.mlp`` with a TorchScript trace recorded on a CUDA fp16 input.

    The example input used for tracing is ``t.randn(1, 1, l.n_in)`` moved to
    the current CUDA device and cast to half precision. Any module that is
    not a ResAttnBlock is left unchanged. Takes a single module argument, so
    it is suitable for ``nn.Module.apply``.
    """
    if not isinstance(l, ResAttnBlock):
        return
    example_input = t.randn(1, 1, l.n_in).cuda().half()
    l.mlp = t.jit.trace(l.mlp, example_input)
class MLP(nn.Module):
    """Feed-forward sub-layer: ``resid_dropout(c_proj(act(c_fc(x))))``.

    Args:
        n_in: input and output width.
        n_state: hidden width produced by ``c_fc``.
        resid_dropout: dropout probability applied to the output; values
            <= 0.0 disable dropout entirely.
        afn: key into ``ACT_FNS`` selecting the activation function.
        zero_out: forwarded to the output ``Conv1D`` (presumably zero-initialises
            its weights -- confirm in jukebox.transformer.ops).
        init_scale: forwarded to both ``Conv1D`` layers.
    """
    def __init__(self, n_in, n_state, resid_dropout=0.0, afn='quick_gelu', zero_out=False, init_scale=1.0):
        super().__init__()
        self.c_fc = Conv1D(n_in, n_state, init_scale=init_scale)
        self.c_proj = Conv1D(n_state, n_in, zero_out, init_scale=init_scale)
        self.act = ACT_FNS[afn]
        # nn.Identity rather than a lambda: a lambda attribute makes the whole
        # module unpicklable (e.g. torch.save(model)). Identity has no
        # parameters or buffers, so state_dict keys are unchanged.
        self.resid_dropout = nn.Dropout(resid_dropout) if resid_dropout > 0.0 else nn.Identity()

    def forward(self, x):
        """Apply c_fc -> activation -> c_proj -> residual dropout to ``x``."""
        m = self.act(self.c_fc(x))
        m = self.c_proj(m)
        return self.resid_dropout(m)
class ResAttnBlock(nn.Module):
    """Residual block: a FactoredAttention sub-layer followed by an MLP.

    LayerNorm is applied to each sub-layer's input (pre-LayerNorm layout):
    ``a = attn(ln_0(x))``, ``m = mlp(ln_1(x + a))``, and the block returns
    ``x + a + m`` (or ``x + res_scale * (a + m)`` when res_scale != 1.0).

    Args:
        n_in: model width.
        n_ctx, n_head, attn_dropout, scale, mask, blocks, spread,
        encoder_dims, prime_len, attn_func: forwarded to FactoredAttention.
        resid_dropout, zero_out, init_scale: forwarded to both
            FactoredAttention and MLP.
        afn: activation name for the MLP.
        res_scale: multiplier applied to the sum of both branches.
        m_attn: FactoredAttention state width is ``int(m_attn * n_in)``.
        m_mlp: MLP hidden width is ``int(m_mlp * n_in)``.
        checkpoint_attn: the attention call is checkpointed in this block when
            == 3; the value is also passed to FactoredAttention (presumably it
            handles the other levels internally -- confirm there).
        checkpoint_mlp: the MLP call is checkpointed when == 1.
    """
    def __init__(self, n_in, n_ctx, n_head,
                 attn_dropout=0.0, resid_dropout=0.0,
                 afn='quick_gelu', scale=True, mask=False,
                 zero_out=False, init_scale=1.0, res_scale=1.0,
                 m_attn = 0.25, m_mlp = 1.,
                 checkpoint_attn = 0, checkpoint_mlp = 0,
                 attn_func=0, blocks=None, spread=None,
                 encoder_dims=None, prime_len=None):
        super().__init__()
        self.attn = FactoredAttention(n_in=n_in, n_ctx=n_ctx, n_state=int(m_attn * n_in), n_head=n_head,
                                      attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                                      scale=scale, mask=mask,
                                      zero_out=zero_out, init_scale=init_scale,
                                      checkpoint_attn=checkpoint_attn,
                                      attn_func=attn_func, blocks=blocks, spread=spread,
                                      encoder_dims=encoder_dims, prime_len=prime_len)
        # ln_0 normalises the attention input; ln_1 normalises the MLP input.
        self.ln_0 = LayerNorm(n_in)
        self.mlp = MLP(n_in=n_in, n_state=int(m_mlp * n_in),
                       resid_dropout=resid_dropout,
                       afn=afn,
                       zero_out=zero_out, init_scale=init_scale)
        self.ln_1 = LayerNorm(n_in)
        self.res_scale = res_scale
        self.checkpoint_attn = checkpoint_attn
        self.checkpoint_mlp = checkpoint_mlp
        self.n_in = n_in
        # Read by forward() here and by Transformer to decide who gets encoder_kv.
        self.attn_func = attn_func

    def forward(self, x, encoder_kv, sample=False):
        """Apply the block to ``x``.

        Args:
            x: input activations.
            encoder_kv: on the non-sampling path it must be non-None when
                attn_func == 6 and None otherwise (both asserted). On the
                sampling path it is passed straight to the attention module.
            sample: when True, the sub-layers are called directly (no
                checkpointing) and sample=True is passed to the attention module.

        Returns:
            ``x + a + m``, or ``x + res_scale * (a + m)`` when res_scale != 1.0.
        """
        if sample:
            a = self.attn(self.ln_0(x), encoder_kv, sample)
            m = self.mlp(self.ln_1(x + a))
        else:
            # `sample` is bound as a default argument so that the function
            # handed to checkpoint() takes only the tensor inputs listed next.
            if self.attn_func == 6:
                assert encoder_kv is not None
                a = checkpoint(lambda _x,_enc_kv,_s=sample: self.attn(self.ln_0(_x),_enc_kv,_s),
                               (x,encoder_kv),
                               (*self.attn.parameters(), *self.ln_0.parameters()),
                               self.checkpoint_attn == 3)  # 2 recomputes after the projections, and 1 recomputes after head splitting.
            else:
                assert encoder_kv is None
                a = checkpoint(lambda _x,_enc_kv=None,_s=sample: self.attn(self.ln_0(_x),_enc_kv,_s),
                               (x,),
                               (*self.attn.parameters(), *self.ln_0.parameters()),
                               self.checkpoint_attn == 3)  # 2 recomputes after the projections, and 1 recomputes after head splitting.
            # The MLP consumes the attention-updated stream x + a.
            m = checkpoint(lambda _x: self.mlp(self.ln_1(_x)), (x + a,),
                           (*self.mlp.parameters(), *self.ln_1.parameters()),
                           self.checkpoint_mlp == 1)
        # Fast path skips the multiply when no residual scaling is configured.
        if self.res_scale == 1.0:
            h = x + a + m
        else:
            h = x + self.res_scale * (a + m)
        return h
class Transformer(nn.Module):
    """Stack of ``n_depth`` ResAttnBlock layers with a per-layer attention pattern.

    Args:
        n_in: model width (input and output feature size).
        n_ctx: context length.
        n_head: attention heads per layer.
        n_depth: number of ResAttnBlock layers.
        attn_dropout, resid_dropout, afn, scale, mask, zero_out, init_scale,
        m_attn, m_mlp, checkpoint_attn, checkpoint_mlp, spread, encoder_dims,
        prime_len: forwarded to every ResAttnBlock. The exception is zero_out,
            which is forced to True for layers whose attn_func is 6.
        res_scale: if truthy, every block scales its residual branches by
            ``1 / n_depth``; otherwise by 1.0.
        checkpoint_res: when 1, each whole block is run through
            ``checkpoint(..., True)`` during non-sampling forward passes.
        attn_order: key into the pattern table in ``__init__`` that maps a
            layer index to that layer's attn_func.
        blocks: if not None, must divide n_ctx, and ``self.block_ctx`` is set
            to ``n_ctx // blocks``. Also forwarded to every ResAttnBlock.
    """
    def __init__(self, n_in, n_ctx, n_head, n_depth,
                 attn_dropout=0.0, resid_dropout=0.0,
                 afn='quick_gelu', scale=True, mask=False,
                 zero_out=False, init_scale=1.0, res_scale=False,
                 m_attn=0.25, m_mlp=1.,
                 checkpoint_attn=0, checkpoint_mlp=0, checkpoint_res=0,
                 attn_order=0, blocks=None, spread=None,
                 encoder_dims=None, prime_len=None):
        super().__init__()
        self.n_in = n_in
        self.n_ctx = n_ctx
        self.encoder_dims = encoder_dims
        self.blocks = blocks
        if blocks is not None:
            assert n_ctx % blocks == 0
            self.block_ctx = n_ctx // blocks
        self.prime_len = prime_len
        self.n_head = n_head

        res_scale = 1.0 / n_depth if res_scale else 1.0

        # Orders of attn_func: attn_order -> (layer index d -> attn_func of layer d).
        # Layers with attn_func 6 are the only ones that receive encoder_kv in
        # forward(). n_depth is not required to be a multiple of the period.
        attn_func = {0: lambda d: 0,  # Complete dense attn
                     1: lambda d: [1,2][d%2],  # Alternate row and column attn
                     2: lambda d: [1,2,3][d % 3],  # Alternate row, column and previous row attn
                     3: lambda d: [1,4][d % 2],  # Alternate row and last column
                     4: lambda d: [1,5][d % 2],  # Alternate row and last k columns
                     5: lambda d: [1,4,1,1][d % 4],  # Alternate row, last column, row, row
                     6: lambda d: [1,2,3,6][d % 4],
                     7: lambda d: [*[1,2,3]*5,6][d%16],
                     8: lambda d: [1,2,3,1,2,3,1,2,3,6][d%10],  # Used by separated_enc_dec model with lyrics
                     9: lambda d: [1,2,3,0][d % 4],
                     10: lambda d: [*[1,2,3,1,2,3,1,2,3],*[1,2,3,1,2,3,1,2,3,6]*7][d%79],  # Used by large separated_enc_dec model with lyrics
                     11: lambda d: [6,6,0][d%3] if d%16 == 15 else [1,2,3][d%3],
                     12: lambda d: [7,7,0][d%3] if d%16 == 15 else [1,2,3][d%3],  # Used by single_enc_dec model with lyrics
                     }[attn_order]

        def attn_block(d):
            # Build the d-th layer; attn_func-6 layers always get zero_out=True.
            layer_attn_func = attn_func(d)
            return ResAttnBlock(n_in=n_in, n_ctx=n_ctx, n_head=n_head,
                                attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                                afn=afn, scale=scale, mask=mask,
                                zero_out=zero_out if layer_attn_func != 6 else True,
                                init_scale=init_scale, res_scale=res_scale,
                                m_attn=m_attn, m_mlp=m_mlp,
                                checkpoint_attn=checkpoint_attn, checkpoint_mlp=checkpoint_mlp,
                                attn_func=layer_attn_func, blocks=blocks, spread=spread,
                                encoder_dims=encoder_dims, prime_len=prime_len)

        self.checkpoint_res = checkpoint_res
        self._attn_mods = nn.ModuleList([attn_block(d) for d in range(n_depth)])
        # Attention weights appended by forward() for layers with record_attn set.
        self.ws = []

    def set_record_attn(self, record_attn):
        """
        Arguments:
            record_attn (bool or set): Makes forward prop dump self-attention
                softmaxes to self.ws. Either a set of layer indices indicating
                which layers to store, or a boolean value indicating whether to
                dump all.
        """
        def _should_record_attn(layer_idx):
            if isinstance(record_attn, bool):
                return record_attn
            return layer_idx in record_attn

        for i, layer in enumerate(self._attn_mods):
            layer.attn.record_attn = _should_record_attn(i)
        if record_attn:
            # Enabling recording requires that nothing has been collected yet.
            assert self.ws == []
            for layer in self._attn_mods:
                assert layer.attn.w is None
        else:
            # Disabling recording discards everything collected so far.
            self.ws = []
            for layer in self._attn_mods:
                layer.attn.w = None

    def forward(self, x, encoder_kv=None, sample=False, fp16=False, fp16_out=False):
        """Run ``x`` through every block in order.

        Args:
            x: input tensor.
            encoder_kv: passed only to layers whose attn_func is 6; every
                other layer receives None.
            sample: forwarded to every block; when True, block-level
                checkpointing is disabled.
            fp16: cast ``x`` to half precision before the first block.
            fp16_out: when False (default), the output is cast to float.

        Returns:
            The output of the last block.
        """
        if fp16:
            x = x.half()

        checkpoint_blocks = self.checkpoint_res == 1 and not sample
        for layer in self._attn_mods:
            uses_encoder = layer.attn_func == 6
            if checkpoint_blocks:
                if uses_encoder:
                    assert encoder_kv is not None
                    f = functools.partial(layer, sample=sample)
                    x = checkpoint(f, (x, encoder_kv), layer.parameters(), True)
                else:
                    f = functools.partial(layer, encoder_kv=None, sample=sample)
                    x = checkpoint(f, (x,), layer.parameters(), True)
            elif uses_encoder:
                x = layer(x, encoder_kv=encoder_kv, sample=sample)
            else:
                x = layer(x, encoder_kv=None, sample=sample)
            if layer.attn.record_attn:
                self.ws.append(layer.attn.w)
        if not fp16_out:
            x = x.float()
        return x

    def check_cache(self, n_samples, sample_t, fp16):
        """Call ``check_cache(n_samples, sample_t, fp16)`` on every layer's attention."""
        for layer in self._attn_mods:
            layer.attn.check_cache(n_samples, sample_t, fp16)

    def del_cache(self):
        """Call ``del_cache()`` on every layer's attention."""
        for layer in self._attn_mods:
            layer.attn.del_cache()

    def check_sample(self):
        """Check that one full-context sampling pass matches four chunked passes.

        Runs on CUDA with random inputs: x of shape (4, n_ctx, n_in) and
        encoder_kv of shape (4, encoder_dims, n_in). Raises AssertionError if
        any output element differs by more than 1e-6; the message lists the
        offending positions.
        """
        bs, n_ctx, enc_len, width = (4, self.n_ctx, self.encoder_dims, self.n_in)
        with t.no_grad():
            encoder_kv = t.randn(bs, enc_len, width).cuda()
            x = t.randn(bs, n_ctx, width).cuda()
            y_forw = self.forward(x, encoder_kv=encoder_kv, sample=True)

            # Replay the same input in four chunks, validating the cache
            # length before each chunk and after the last one.
            self.del_cache()
            x_chunks = t.chunk(x, 4, dim=1)
            y_chunks = []
            n = 0
            for x_chunk in x_chunks:
                self.check_cache(bs, n, False)
                y_chunk = self.forward(x_chunk, encoder_kv=encoder_kv, sample=True)
                y_chunks.append(y_chunk)
                n += x_chunk.shape[1]
            self.check_cache(bs, n, False)
            y_forw_in_chunks = t.cat(y_chunks, dim=1)

            abs_err = t.abs(y_forw - y_forw_in_chunks)
            max_err = t.max(abs_err)
            assert max_err <= 1e-6, f"Max err is {max_err} {[i for i in range(n_ctx) if t.max(abs_err[:, i, :]) > 1e-6]}"
if __name__ == '__main__':
    # Smoke test: for several attention orders, check that full-context
    # sampling matches chunked sampling. Requires CUDA (check_sample uses .cuda()).
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    setup_dist_from_mpi(port=29600)
    n_in = 16
    n_ctx = 192
    n_head = 4
    n_depth = 12
    blocks = 16
    encoder_dims_by_order = {0: 0, 2: 0, 6: 64}
    for attn_order in [0, 2, 6]:
        encoder_dims = encoder_dims_by_order[attn_order]
        prior = Transformer(n_in, n_ctx, n_head, n_depth, mask=True, attn_order=attn_order, encoder_dims=encoder_dims, blocks=blocks).cuda()
        # eval() sets training=False on every submodule; assigning
        # prior.training directly would only change the root module.
        prior.eval()
        prior.check_sample()
        print(f"Checked attn_order: {attn_order}")