#=======================================================================================
# https://huggingface.co/spaces/asigalov61/Imagen-POP-Music-Medley-Diffusion-Transformer
#=======================================================================================

import os
import time as reqtime
import datetime
from pytz import timezone

import torch

from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset

import spaces
import gradio as gr

import numpy as np
import random
import tqdm

import TMIDIX
import TPLOTS

from midi_to_colab_audio import midi_to_colab_audio

# =================================================================================================

@spaces.GPU
def Generate_POP_Medley(input_num_medley_comps, input_melody_patch):

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    print('Req number of medley compositions:', input_num_medley_comps)
    print('Req melody MIDI patch number:', input_melody_patch)
    print('=' * 70)

    #===============================================================================
    # MIDI Images generation
    #===============================================================================

    print('Loading model...')

    DIM = 64
    CHANS = 1
    TSTEPS = 1000
    DEVICE = 'cuda' # 'cpu'

    unet = Unet(dim = DIM,
                dim_mults = (1, 2, 4, 8),
                num_resnet_blocks = 1,
                channels = CHANS,
                layer_attns = (False, False, False, True),
                layer_cross_attns = False
                )

    imagen = Imagen(condition_on_text = False, # this must be set to False for unconditional Imagen
                    unets = unet,
                    channels = CHANS,
                    image_sizes = 128,
                    timesteps = TSTEPS
                    )

    trainer = ImagenTrainer(imagen = imagen,
                            split_valid_from_train = True # whether to split the validation dataset from the training
                            ).to(DEVICE)

    print('=' * 70)
    print('Loading model checkpoint...')
    print('=' * 70)

    trainer.load('Imagen_POP909_64_dim_12638_steps_0.00983_loss.ckpt')

    print('=' * 70)
    print('Done!')
    print('=' * 70)
    print('Generating...')
    print('=' * 70)

    images = trainer.sample(batch_size = input_num_medley_comps, return_pil_images = True)

    print('=' * 70)
    print('Done!')
    print('=' * 70)
    print('Processing...')

    # Binarize the sampled grayscale images: pixels below the threshold become 0, all others 1
    threshold = 128

    imgs_array = []

    for i in images:
        arr = np.array(i)
        farr = np.where(arr < threshold, 0, 1)
        imgs_array.append(farr)

    print('Done!')

    #===============================================================================

    print('=' * 70)
    print('Converting images to scores...')

    medley_compositions_escores = []

    # Convert each binary image to a binary matrix, then to an enhanced score;
    # add a melody track when a melody MIDI patch is selected (patch > -1)
    for i in imgs_array:
        bmatrix = TPLOTS.images_to_binary_matrix([i])
        score = TMIDIX.binary_matrix_to_original_escore_notes(bmatrix)

        if input_melody_patch > -1:
            score = TMIDIX.add_melody_to_enhanced_score_notes(score, melody_patch=input_melody_patch)

        medley_compositions_escores.append(score)

    print('Done!')
    print('=' * 70)
    print('Creating medley score...')

    medley_labels = ['Imagen POP Medley Composition #' + str(i+1) for i in range(len(medley_compositions_escores))]

    medley_escore = TMIDIX.escore_notes_medley(medley_compositions_escores, medley_labels, pause_time_value=16)

    #===============================================================================

    print('Rendering results...')
    print('=' * 70)
    print('Sample INTs', medley_escore[:15])
    print('=' * 70)

    fn1 = "Imagen-POP-Music-Medley-Diffusion-Transformer-Composition"

    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(medley_escore)
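    # The medley score is written out as a standard MIDI file with TMIDIX and then
    # synthesized to 16 kHz audio with the bundled soundfont, so the app can return
    # both a playable waveform and a downloadable .mid file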
    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                                              output_signature = 'Imagen POP Music Medley',
                                                              output_file_name = fn1,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches,
                                                              timings_multiplier=256
                                                              )

    new_fn = fn1+'.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_midi_title = str(fn1)
    output_midi_summary = str(output_score[:3])
    output_midi = str(new_fn)
    output_audio = (16000, audio)

    output_plot = TMIDIX.plot_ms_SONG(output_score, plot_title=output_midi, return_plt=True, timings_multiplier=256)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot

# =================================================================================================
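# Gradio front end: two sliders (number of medley compositions, and an optional melody
# MIDI patch where -1 means no added melody track) feed Generate_POP_Medley, which
# returns the MIDI title, summary, .mid file, rendered audio and a score plot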

if __name__ == "__main__":

    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    app = gr.Blocks()

    with app:

        gr.Markdown("<h1 style='text-align: center;'>Imagen POP Music Medley Diffusion Transformer</h1>")
        gr.Markdown("<h1 style='text-align: center;'>Generate unique POP music medleys with the Imagen diffusion transformer</h1>")

        gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Imagen-POP-Music-Medley-Diffusion-Transformer&style=flat)\n\n"
                    "This is a demo for the MIDI Images dataset\n\n"
                    "Please see the [MIDI Images](https://huggingface.co/datasets/asigalov61/MIDI-Images) Hugging Face repo for more information\n\n"
                    )

        gr.Markdown("## Choose medley settings")

        input_num_medley_comps = gr.Slider(1, 10, value=5, step=1, label="Number of medley compositions")
        input_melody_patch = gr.Slider(-1, 127, value=40, step=1, label="Medley melody MIDI patch number")

        run_btn = gr.Button("Generate POP Medley", variant="primary")

        gr.Markdown("## Generation results")

        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        run_event = run_btn.click(Generate_POP_Medley,
                                  [input_num_medley_comps, input_melody_patch],
                                  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot]
                                  )

    app.queue().launch()