HPARAMS_REGISTRY = {} DEFAULTS = {} class Hyperparams(dict): def __getattr__(self, attr): return self[attr] def __setattr__(self, attr, value): self[attr] = value def setup_hparams(hparam_set_names, kwargs): H = Hyperparams() if not isinstance(hparam_set_names, tuple): hparam_set_names = hparam_set_names.split(",") hparam_sets = [HPARAMS_REGISTRY[x.strip()] for x in hparam_set_names if x] + [kwargs] for k, v in DEFAULTS.items(): H.update(v) for hps in hparam_sets: for k in hps: if k not in H: raise ValueError(f"{k} not in default args") H.update(**hps) H.update(**kwargs) return H # Teeny for testing teeny = Hyperparams( ) HPARAMS_REGISTRY["teeny"] = teeny easy = Hyperparams( sr=22050, ) HPARAMS_REGISTRY["easy"] = easy REMOTE_PREFIX = 'https://openaipublic.azureedge.net/' # Model hps vqvae = Hyperparams( levels = 3, downs_t = (3, 2, 2), strides_t = (2, 2, 2), emb_width = 64, l_bins = 2048, l_mu = 0.99, commit = 0.02, spectral = 0.0, multispectral = 1.0, hvqvae_multipliers = (2, 1, 1), loss_fn = 'lmix', lmix_l2 = 1.0, lmix_linf=0.02, width = 32, depth = 4, m_conv = 1.0, dilation_growth_rate = 3, restore_vqvae=REMOTE_PREFIX + 'jukebox/models/5b/vqvae.pth.tar', ) HPARAMS_REGISTRY["vqvae"] = vqvae labels = Hyperparams( y_bins=(120, 4111), t_bins=128, max_bow_genre_size=5, n_vocab=80, ) upsamplers = Hyperparams( n_ctx=8192, prior_width=1920, prior_depth=72, heads=1, attn_order=2, blocks=128, init_scale=0.4, c_res=1, cond_width=1024, cond_depth=16, cond_dilation_growth_rate=3, cond_dilation_cycle=8, cond_c_res=1, use_tokens=False, prime_loss_fraction=0.0, fp16_params=False, ) upsamplers.update(labels) upsampler_level_0 = Hyperparams( level=0, restore_prior=REMOTE_PREFIX + 'jukebox/models/5b/prior_level_0.pth.tar' ) upsampler_level_0.update(upsamplers) HPARAMS_REGISTRY["upsampler_level_0"] = upsampler_level_0 upsampler_level_1 = Hyperparams( level=1, cond_res_scale=True, restore_prior=REMOTE_PREFIX + 'jukebox/models/5b/prior_level_1.pth.tar' ) upsampler_level_1.update(upsamplers) HPARAMS_REGISTRY["upsampler_level_1"] = upsampler_level_1 prior_5b = Hyperparams( level=2, n_ctx=8192, prior_width=4800, prior_depth=72, heads=8, attn_order=2, blocks=128, init_scale=0.1, c_res=1, beta2=0.925, min_duration=60.0, max_duration=600.0, use_tokens=False, n_tokens=0, prime_loss_fraction=0.0, merged_decoder=True, restore_prior=REMOTE_PREFIX + 'jukebox/models/5b/prior_level_2.pth.tar', fp16_params=True, ) prior_5b.update(labels) HPARAMS_REGISTRY["prior_5b"] = prior_5b prior_5b_lyrics = Hyperparams( level=2, n_ctx=8192, prior_width=4800, prior_depth=79, heads=8, attn_order=10, blocks=128, init_scale=0.1, c_res=1, prime_width=1280, prime_depth=18, prime_heads=4, prime_attn_order=2, prime_blocks=32, prime_init_scale=0.7, prime_c_res=1, min_duration=23.8, max_duration=600.0, use_tokens=True, n_tokens=512, prime_loss_fraction=0.4, merged_decoder=True, restore_prior=REMOTE_PREFIX + 'jukebox/models/5b_lyrics/prior_level_2.pth.tar', fp16_params=True, alignment_layer=68, alignment_head=2, ) prior_5b_lyrics.update(labels) HPARAMS_REGISTRY["prior_5b_lyrics"] = prior_5b_lyrics labels_v3 = Hyperparams( y_bins=(604, 7898), t_bins=64, max_bow_genre_size=1, n_vocab=79, ) prior_1b_lyrics = Hyperparams( level=2, n_ctx=6144, prior_width=2048, prior_depth=72, heads=2, attn_order=12, blocks=64, init_scale=0.2, c_res=1, labels_v3=True, min_duration=17.84, max_duration=600.0, use_tokens=True, n_tokens=384, prime_loss_fraction=0.4, single_enc_dec=True, restore_prior=REMOTE_PREFIX + 'jukebox/models/1b_lyrics/prior_level_2.pth.tar', fp16_params=False, alignment_layer=63, alignment_head=0, ) prior_1b_lyrics.update(labels_v3) HPARAMS_REGISTRY["prior_1b_lyrics"] = prior_1b_lyrics # Small models small_vqvae = Hyperparams( sr = 22050, levels = 2, downs_t = (5, 3), strides_t = (2, 2), emb_width = 64, l_bins = 1024, l_mu = 0.99, commit = 0.02, spectral = 0.0, multispectral = 1.0, loss_fn = 'l2', width = 32, depth = 4, m_conv = 1.0, dilation_growth_rate = 3, ) HPARAMS_REGISTRY["small_vqvae"] = small_vqvae small_prior = Hyperparams( n_ctx=8192, prior_width=1024, prior_depth=48, heads=1, c_res=1, attn_order=2, blocks=64, init_scale=0.7, ) HPARAMS_REGISTRY["small_prior"] = small_prior small_labelled_prior = Hyperparams( labels=True, labels_v3=True, y_bins=(10,100), # Set this to (genres, artists) for your dataset max_bow_genre_size=1, min_duration=60.0, max_duration=600.0, t_bins=64, ) small_labelled_prior.update(small_prior) HPARAMS_REGISTRY["small_labelled_prior"] = small_labelled_prior small_single_enc_dec_prior = Hyperparams( n_ctx=6144, prior_width=1024, prior_depth=48, heads=2, attn_order=12, blocks=64, init_scale=0.7, c_res=1, prime_loss_fraction=0.4, single_enc_dec=True, labels=True, labels_v3=True, y_bins=(10,100), # Set this to (genres, artists) for your dataset max_bow_genre_size=1, min_duration=60.0, max_duration=600.0, t_bins=64, use_tokens=True, n_tokens=384, n_vocab=79, ) HPARAMS_REGISTRY["small_single_enc_dec_prior"] = small_single_enc_dec_prior small_sep_enc_dec_prior = Hyperparams( n_ctx=6144, prior_width=1024, prior_depth=50, heads=2, attn_order=8, blocks=64, init_scale=0.7, c_res=1, prime_width=256, prime_depth=9, prime_heads=2, prime_attn_order=2, prime_blocks=32, prime_init_scale=0.7, prime_c_res=1, prime_loss_fraction=0.4, labels=True, labels_v3=True, y_bins=(10,100), # Set this to (genres, artists) for your dataset max_bow_genre_size=1, min_duration=60.0, max_duration=600.0, t_bins=64, use_tokens=True, n_tokens=384, n_vocab=79, ) HPARAMS_REGISTRY["small_sep_enc_dec_prior"] = small_sep_enc_dec_prior small_upsampler = Hyperparams( n_ctx=8192, prior_width=1024, prior_depth=48, heads=1, c_res=1, attn_order=2, blocks=64, init_scale=0.7, cond_width=512, cond_depth=16, cond_dilation_growth_rate=3, cond_dilation_cycle=8, cond_c_res=1, ) HPARAMS_REGISTRY["small_upsampler"] = small_upsampler all_fp16 = Hyperparams( fp16=True, fp16_params=True, fp16_opt=True, fp16_scale_window=250, ) HPARAMS_REGISTRY["all_fp16"] = all_fp16 cpu_ema = Hyperparams( ema=True, cpu_ema=True, cpu_ema_freq=100, ema_fused=False, ) HPARAMS_REGISTRY["cpu_ema"] = cpu_ema DEFAULTS["rcall"] = Hyperparams( rcall_command="", git_commit="", ) DEFAULTS["script"] = Hyperparams( name='', debug_mem=False, debug_eval_files=False, debug_speed=False, debug_iters=100, debug_batch=False, debug_grad_accum=False, debug_inputs=False, local_path='', local_logdir='logs', max_len=24, max_log=32, save=True, save_iters=20000, seed=0, prior=False, log_steps=100, func='', ) DEFAULTS["data"] = Hyperparams( audio_files_dir='', finetune='', english_only=False, bs=1, bs_sample=1, nworkers=1, aug_shift=False, aug_blend=False, train_test_split=0.9, train_shrink_factor=1.0, test_shrink_factor=1.0, p_unk=0.1, min_duration=None, max_duration=None, n_tokens=0, n_vocab=0, use_tokens=False, curr_epoch=-1, ) DEFAULTS["vqvae"] = Hyperparams( restore_vqvae='', levels=2, downs_t=(1,1), strides_t=(2,2), hvqvae_multipliers=None, revival_threshold=1.0, emb_width=64, l_bins=512, l_mu=0.99, commit=1.0, spectral=0.0, multispectral=1.0, loss_fn='l2', linf_k=2048, lmix_l1=0.0, lmix_l2=0.0, lmix_linf=0.0, use_bottleneck=True, ) DEFAULTS["vqvae_conv_block"] = Hyperparams( depth=3, width=128, m_conv=1.0, dilation_growth_rate=1, dilation_cycle=None, vqvae_reverse_decoder_dilation=True, ) DEFAULTS["prior"] = Hyperparams( restore_prior='', restore_prior_ddp=False, max_bow_genre_size=None, y_bins=0, level=0, cond_levels=None, t_bins=64, y_cond_as_bias=False, copy_input=False, merged_decoder=False, single_enc_dec=False, alignment_layer=None, alignment_head=None, ) DEFAULTS["prior_attn_block"] = Hyperparams( n_ctx=1024, prior_depth=3, prior_width=128, heads=1, attn_order=0, blocks=None, spread=None, attn_dropout=0.0, resid_dropout=0.0, emb_dropout=0.0, zero_out=False, res_scale=False, pos_init=False, init_scale=1.0, m_attn=0.25, m_mlp=1.0, c_res=0, c_attn=0, c_mlp=0, ) DEFAULTS["cond_conv_block"] = Hyperparams( cond_depth=3, cond_width=128, cond_m_conv=1.0, cond_zero_out=False, cond_res_scale=False, cond_dilation_growth_rate=1, cond_dilation_cycle=None, cond_c_res=0, ) DEFAULTS["sample"] = Hyperparams( primed_chunk_size=None, selected_artists='', temp_top=1.0, temp_rest=0.99, sample_length_in_seconds=24, total_sample_length_in_seconds=240, ) DEFAULTS["prime"] = Hyperparams( #encoder_kv_width=128, prime_loss_fraction=0.1, restore_decoder='', ) DEFAULTS["prime_attn_block"] = Hyperparams( prime_depth=3, prime_width=128, prime_heads=1, prime_attn_order=0, prime_blocks=None, prime_spread=None, prime_attn_dropout=0.0, prime_resid_dropout=0.0, prime_emb_dropout=0.0, prime_zero_out=False, prime_res_scale=False, prime_pos_init=False, prime_init_scale=1.0, prime_m_attn=0.25, prime_m_mlp=1.0, prime_c_res=0, prime_c_attn=0, prime_c_mlp=0, prime_rel_attn=False, prime_posemb_timescale=10000, ) DEFAULTS["opt"] = Hyperparams( epochs=10000, lr=0.0003, clip=1.0, beta1=0.9, beta2=0.999, ignore_grad_norm=0, weight_decay=0.0, eps=1e-08, lr_warmup=100.0, lr_decay=10000000000.0, lr_gamma=1.0, lr_scale=1.0, lr_use_linear_decay=False, lr_start_linear_decay=0, lr_use_cosine_decay=False, ) DEFAULTS["fp16"] = Hyperparams( fp16=False, fp16_params=False, fp16_loss_scale=None, fp16_scale_window=1000.0, fp16_opt=False, ) DEFAULTS["train_test_eval"] = Hyperparams( labels=True, labels_v3=False, dump=False, ema=True, ema_fused=True, cpu_ema=False, cpu_ema_freq=100, reset_best_loss=False, reset_step=False, reset_opt=False, reset_shd=False, train=False, test=False, sample=False, sampler='ancestral', codes_logdir='', date=None, labeller='top_genres', label_line=0, iters_before_update=1, grad_accum_iters=0, mu=None, piped=False, pipe_depth=8, break_train=1e10, break_test=1e10, exit_train=1e10, ) DEFAULTS["audio"] = Hyperparams( n_fft=1024, hop_length=256, window_size=1024, sr=44100, channels=2, wav='', n_inps=1, n_hops=2, n_segment=1, n_total_segment=1, n_segment_each=1, prime_chunks=4, sample_length=0, sample_hop_length=30000, max_silence_pad_length=0, ignore_boundaries=False, use_nonrelative_specloss=True, multispec_loss_n_fft=(2048,1024,512), multispec_loss_hop_length=(240,120,50), multispec_loss_window_size=(1200,600,240), ) DEFAULTS["distributed"] = Hyperparams( bucket=128 )