Gertie01 committed
Commit 92ffb1c
1 Parent(s): d751bdd

Create musiclm_pytorch.py

Files changed (1)
  1. musiclm_pytorch.py +559 -0
musiclm_pytorch.py ADDED
@@ -0,0 +1,559 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum

from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking

from audiolm_pytorch import AudioLM
from audiolm_pytorch.utils import AudioConditionerBase

from x_clip.tokenizer import tokenizer
from vector_quantize_pytorch import ResidualVQ

from einops import rearrange, repeat, reduce, pack, unpack

from beartype.typing import List, Optional, Tuple
from beartype import beartype

# functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def round_down_nearest_multiple(n, divisor):
    return n // divisor * divisor

# tensor functions

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def l2norm(t):
    return F.normalize(t, p = 2, dim = -1)

# 2d sinusoidal positional embedding
# simple vit paper shows it is good enough compared to learned

def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    pe = pe.type(dtype)

    return rearrange(pe, '(h w) d -> h w d', h = h, w = w)
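# illustrative shape check for the embedding above (editor's sketch; dims assumed):
#   patches = torch.randn(1, 8, 8, 512)    # (batch, height, width, dim)
#   posemb_sincos_2d(patches).shape        # torch.Size([8, 8, 512]), broadcast over batch when added
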
# biasless layernorm

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# feedforward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x

def FeedForward(dim, mult = 4, dropout = 0.):
    dim_hidden = int(dim * mult * 2 / 3)

    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, dim_hidden * 2, bias = False),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim, bias = False)
    )
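# editor's note: the hidden dimension above is scaled by 2 / 3 so the gated (GEGLU)
# variant keeps roughly the same parameter count as a vanilla feedforward of width
# dim * mult, following the GLU variants paper (https://arxiv.org/abs/2002.05202)
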
# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.attn_dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        mask = None
    ):
        b, n, _, device = *x.shape, x.device

        # prenorm

        x = self.norm(x)

        # project for queries, keys, values

        q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)

        # split for multi-headed attention

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale

        # similarities

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
            ]))

    def forward(self, x, mask = None):

        for attn, ff in self.layers:
            x = attn(x, mask = mask) + x
            x = ff(x) + x

        return x

# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778

def pair(t):
    return (t, t) if not isinstance(t, tuple) else t

class AudioSpectrogramTransformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        patch_size = 16,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        spec_n_fft = 128,
        spec_power = 2,
        spec_win_length = 24,
        spec_hop_length = None,
        spec_pad = 0,
        spec_center = True,
        spec_pad_mode = 'reflect',
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80
    ):
        super().__init__()
        self.dim = dim

        self.patch_size = pair(patch_size)
        self.to_patch_tokens = nn.Conv2d(self.patch_size[0] * self.patch_size[1], dim, 1)

        self.spec = Spectrogram(
            n_fft = spec_n_fft,
            power = spec_power,
            win_length = spec_win_length,
            hop_length = spec_hop_length,
            pad = spec_pad,
            center = spec_center,
            pad_mode = spec_pad_mode
        )

        # SpecAugment - seems to be widely used in audio field https://arxiv.org/abs/1904.08779

        self.aug = torch.nn.Sequential(
            # the stretch factor belongs on fixed_rate - the first positional argument of
            # TimeStretch is hop_length, not a rate. note also that torchaudio's TimeStretch
            # operates on complex spectrograms (i.e. spec_power = None) to stretch correctly
            TimeStretch(
                fixed_rate = spec_aug_stretch_factor,
                n_freq = spec_n_fft // 2 + 1,
                hop_length = default(spec_hop_length, spec_win_length // 2)
            ),
            FrequencyMasking(freq_mask_param = spec_aug_freq_mask),
            TimeMasking(time_mask_param = spec_aug_time_mask),
        )

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_mult = ff_mult,
            ff_dropout = ff_dropout
        )

        self.norm = LayerNorm(dim)

    def forward(self, x):
        x = self.spec(x)

        if self.training:
            x = self.aug(x)

        # automatically crop if audio does not yield a 2d spectrogram that is divisible by patch sizes

        height, width = x.shape[-2:]
        patch_height, patch_width = self.patch_size

        rounded_height, rounded_width = map(lambda args: round_down_nearest_multiple(*args), ((height, patch_height), (width, patch_width)))

        if (height, width) != (rounded_height, rounded_width): # just keep printing to be annoying until it is fixed
            print(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')

        x = x[..., :rounded_height, :rounded_width]

        # to patches

        x = rearrange(x, 'b (h p1) (w p2) -> b (p1 p2) h w', p1 = patch_height, p2 = patch_width)
        x = self.to_patch_tokens(x)

        # 2d sinusoidal positional embedding

        x = rearrange(x, 'b c h w -> b h w c')
        x = x + posemb_sincos_2d(x)

        # attention, what else

        x = rearrange(x, 'b ... c -> b (...) c')

        x = self.transformer(x)

        # final global average and norm (most recent papers show this is superior to CLS token)

        x = reduce(x, 'b n d -> b d', 'mean')

        return self.norm(x)
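# usage sketch for the audio tower (editor's illustration; hyperparameters assumed):
#   ast = AudioSpectrogramTransformer(dim = 512, depth = 6)
#   wavs = torch.randn(2, 1024)    # (batch, num raw audio samples)
#   embeds = ast(wavs)             # (2, 512) - spectrogram patches, attended, then mean pooled
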
# text transformer

@beartype
class TextTransformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        num_tokens = tokenizer.vocab_size,
        max_seq_len = 256,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4,
        pad_id = 0
    ):
        super().__init__()
        self.dim = dim

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.cls_token = nn.Parameter(torch.randn(dim))

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_mult = ff_mult
        )

        self.pad_id = pad_id
        self.norm = LayerNorm(dim)

    def forward(
        self,
        x = None,
        raw_texts: Optional[List[str]] = None,
        mask = None
    ):
        assert exists(x) ^ exists(raw_texts), 'either tokenized text ids or raw text strings must be given, but not both'

        if exists(raw_texts):
            x = tokenizer.tokenize(raw_texts)

        if not exists(mask):
            mask = x != self.pad_id

        b, n, device = *x.shape, x.device

        # token embedding + positional embedding

        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(n, device = device))

        # cls tokens, as in bert

        cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
        x, ps = pack([cls_tokens, x], 'b * d')

        # account for attending to cls token with self attention mask

        mask = F.pad(mask, (1, 0), value = True)

        # attention

        x = self.transformer(x, mask = mask)

        # unpack the cls tokens

        cls_tokens, _ = unpack(x, ps, 'b * d')

        return self.norm(cls_tokens)
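# usage sketch for the text tower (editor's illustration; raw strings are tokenized internally):
#   tt = TextTransformer(dim = 512, depth = 6)
#   embeds = tt(raw_texts = ['the crystalline sounds of the piano'])    # (1, 512) - normed cls embedding
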
# main classes

@beartype
class MuLaN(nn.Module):
    def __init__(
        self,
        audio_transformer: AudioSpectrogramTransformer,
        text_transformer: TextTransformer,
        dim_latent = 128,                       # they use 128
        decoupled_contrastive_learning = True,  # think this was used, make it optional
    ):
        super().__init__()
        self.dim_latent = dim_latent

        self.audio = audio_transformer
        self.text = text_transformer

        self.temperature = nn.Parameter(torch.tensor(1.))

        self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
        self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)

        self.decoupled_contrastive_learning = decoupled_contrastive_learning

    def get_audio_latents(
        self,
        wavs
    ):
        audio_embeds = self.audio(wavs)
        audio_latents = self.audio_to_latents(audio_embeds)
        return l2norm(audio_latents)

    def get_text_latents(
        self,
        texts = None,
        raw_texts: Optional[List[str]] = None
    ):
        text_embeds = self.text(texts, raw_texts = raw_texts)  # forward raw_texts, which was previously dropped
        text_latents = self.text_to_latents(text_embeds)
        return l2norm(text_latents)

    def forward(
        self,
        wavs,
        texts = None,
        raw_texts: Optional[List[str]] = None,
        return_similarities = False
    ):
        batch, device = wavs.shape[0], wavs.device

        audio_latents = self.get_audio_latents(wavs)
        text_latents = self.get_text_latents(texts, raw_texts = raw_texts)

        cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents)

        assert cosine_sim.shape[0] == cosine_sim.shape[1], 'batch sizes for audio and text are not equal'

        if return_similarities:
            return cosine_sim

        cosine_sim = cosine_sim * self.temperature.exp()

        cosine_sim_exp = cosine_sim.exp()

        numerator = cosine_sim_exp.diag()

        if self.decoupled_contrastive_learning:
            eye = torch.eye(batch, device = device, dtype = torch.bool)  # masked_fill requires a boolean mask
            cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)

        denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum')

        contrastive_loss = -log(numerator / denominator)
        return contrastive_loss.mean()
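# the contrastive objective above, written out with tau = temperature.exp():
#
#   loss_i = -log( exp(sim_ii * tau) / sum_j exp(sim_ij * tau) )
#
# with decoupled contrastive learning (https://arxiv.org/abs/2110.06848) the
# positive term exp(sim_ii * tau) is dropped from the denominator (j != i),
# which is what zeroing the diagonal with masked_fill implements
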
# music lm

@beartype
class MuLaNEmbedQuantizer(AudioConditionerBase):
    def __init__(
        self,
        mulan: MuLaN,
        conditioning_dims: Tuple[int, ...],
        rq_num_quantizers = 8,
        rq_ema_decay = 0.9,
        codebook_size = 1024,
        namespaces: Tuple[str, ...] = ('semantic', 'coarse', 'fine'),
    ):
        super().__init__()
        self.mulan = mulan

        assert len(namespaces) > 0
        self.namespaces = namespaces
        self.conditioning_dims = conditioning_dims

        assert len(conditioning_dims) == len(namespaces), 'number of conditioning dimensions must be equal to number of namespaces'

        dim = mulan.dim_latent

        self.rq = ResidualVQ(
            dim = dim,
            num_quantizers = rq_num_quantizers,
            codebook_size = codebook_size,
            decay = rq_ema_decay,
            commitment_weight = 0,    # only use EMA to update codebooks
            kmeans_init = True,
            threshold_ema_dead_code = 2,
            quantize_dropout = False  # no quantize dropout
        )

        self.dim = dim
        self.num_codebooks = rq_num_quantizers

        self.cond_embeddings = nn.ParameterDict({})

        for namespace, conditioning_dim in zip(namespaces, conditioning_dims):
            cond_embeddings = nn.Parameter(torch.randn(rq_num_quantizers, codebook_size, conditioning_dim))
            nn.init.normal_(cond_embeddings, std = 0.02)

            self.cond_embeddings[namespace] = cond_embeddings

        self.set_default_namespace(namespaces[0])

    def parameters(self):
        # deliberately expose only the conditioning embeddings to the optimizer -
        # mulan stays frozen and the residual VQ codebooks update via EMA
        return self.cond_embeddings.parameters()

    def set_default_namespace(self, namespace):
        self._default_namespace = namespace

    def forward(
        self,
        wavs = None,
        texts = None,
        namespace = None
    ):
        assert exists(wavs) ^ exists(texts), 'either wavs or texts must be given, but not both'

        namespace = default(namespace, self._default_namespace)
        assert namespace in self.namespaces, f'namespace {namespace} not found'
        cond_embeddings = self.cond_embeddings[namespace]

        with torch.no_grad():
            self.mulan.eval()

            # sound and language live in joint embedding space because of contrastive learning

            if exists(wavs):
                latents = self.mulan.get_audio_latents(wavs)
            elif exists(texts):
                latents = self.mulan.get_text_latents(texts)

        _, indices, _ = self.rq(latents)

        batch, num_codebooks, dim = indices.shape[0], self.num_codebooks, cond_embeddings.shape[-1]

        cond_embeddings = repeat(cond_embeddings, 'q c d -> b q c d', b = batch)
        indices = repeat(indices, 'b q -> b q 1 d', q = num_codebooks, d = dim)

        cond_embeddings = cond_embeddings.gather(2, indices)
        return rearrange(cond_embeddings, 'b q 1 d -> b q d')
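# usage sketch (editor's illustration; dims assumed): with a trained mulan,
#   quantizer = MuLaNEmbedQuantizer(
#       mulan = mulan,
#       conditioning_dims = (1024, 1024, 1024),    # one per namespace, matching each stage's transformer dim
#       namespaces = ('semantic', 'coarse', 'fine')
#   )
#   conds = quantizer(texts = token_ids)    # (batch, rq_num_quantizers, 1024) conditioning embeddings
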
@beartype
class MusicLM(nn.Module):
    def __init__(
        self,
        audio_lm: AudioLM,
        mulan_embed_quantizer: MuLaNEmbedQuantizer
    ):
        super().__init__()
        assert not exists(audio_lm.audio_conditioner), 'mulan must not have been passed into AudioLM. it will be managed externally now, embedding the text into the joint embedding space for text-to-audio synthesis'

        self.mulan_embed_quantizer = mulan_embed_quantizer
        self.audio_lm = audio_lm

    @torch.no_grad()
    def forward(
        self,
        raw_texts: List[str],
        **audio_lm_kwargs
    ):
        self.eval()

        texts = tokenizer.tokenize(raw_texts)

        text_embeds = self.mulan_embed_quantizer(texts = texts)

        return self.audio_lm(text_embeds = text_embeds, **audio_lm_kwargs)
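
# ---------------------------------------------------------------------------
# minimal end-to-end smoke test (editor's sketch, not part of the original
# commit): all dimensions and tensor shapes below are illustrative assumptions.
# it exercises MuLaN contrastive training, then shows, commented out, how the
# quantizer and MusicLM would be wired together once an AudioLM is trained.
# ---------------------------------------------------------------------------

if __name__ == '__main__':
    audio_transformer = AudioSpectrogramTransformer(dim = 512, depth = 6)
    text_transformer = TextTransformer(dim = 512, depth = 6)

    mulan = MuLaN(
        audio_transformer = audio_transformer,
        text_transformer = text_transformer
    )

    # eval mode skips SpecAugment here, since torchaudio's TimeStretch needs a
    # complex spectrogram; gradients still flow through the loss below

    mulan.eval()

    wavs = torch.randn(2, 1024)                                 # (batch, num raw audio samples)
    texts = torch.randint(0, tokenizer.vocab_size, (2, 256))    # pre-tokenized captions

    loss = mulan(wavs, texts = texts)    # scalar contrastive loss
    loss.backward()

    # once mulan is trained, its quantized embeddings condition each AudioLM stage:
    #
    #   quantizer = MuLaNEmbedQuantizer(
    #       mulan = mulan,
    #       conditioning_dims = (1024, 1024, 1024),    # must match the stage transformer dims
    #       namespaces = ('semantic', 'coarse', 'fine')
    #   )
    #
    #   musiclm = MusicLM(audio_lm = trained_audio_lm, mulan_embed_quantizer = quantizer)
    #   music = musiclm(['the crystalline sounds of the piano in a ballroom'])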