"""Gradio demo for EMusicGen: emotion-conditioned melody generation in ABC notation."""
import re
import os
import json
import time
import torch
import random
import shutil
import argparse
import warnings
import gradio as gr
import soundfile as sf
from transformers import GPT2Config
from model import Patchilizer, TunesFormer
from convert import abc2xml, xml2img, xml2, transpose_octaves_abc
from utils import (
    PATCH_NUM_LAYERS,
    PATCH_LENGTH,
    CHAR_NUM_LAYERS,
    PATCH_SIZE,
    SHARE_WEIGHTS,
    TEMP_DIR,
    WEIGHTS_DIR,
    DEVICE,
)


def get_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "-num_tunes",
        type=int,
        default=1,
        help="the number of independently computed returned tunes",
    )
    parser.add_argument(
        "-max_patch",
        type=int,
        default=128,
        help="integer to define the maximum length in tokens of each tune",
    )
    parser.add_argument(
        "-top_p",
        type=float,
        default=0.8,
        help="float to define the tokens that are within the sample operation of text generation",
    )
    parser.add_argument(
        "-top_k",
        type=int,
        default=8,
        help="integer to define the tokens that are within the sample operation of text generation",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.2,
        help="the temperature of the sampling operation",
    )
    parser.add_argument("-seed", type=int, default=None, help="seed for random state")
    parser.add_argument(
        "-show_control_code",
        type=bool,
        default=False,
        help="whether to show control code",
    )
    return parser.parse_args()


def get_abc_key_val(text: str, key="K"):
    # Return the value of an ABC header field (e.g. K: or Q:), or None if absent.
    pattern = re.escape(key) + r":(.*?)\n"
    match = re.search(pattern, text)
    if match:
        return match.group(1).strip()
    else:
        return None


def adjust_volume(in_audio: str, dB_change: int):
    # Scale the waveform by dB_change decibels and overwrite the file in place.
    y, sr = sf.read(in_audio)
    sf.write(in_audio, y * 10 ** (dB_change / 20), sr)


def generate_music(
    args,
    emo: str,
    weights: str,
    outdir=TEMP_DIR,
    fix_tempo=None,
    fix_pitch=None,
    fix_volume=None,
):
    patchilizer = Patchilizer()
    patch_config = GPT2Config(
        num_hidden_layers=PATCH_NUM_LAYERS,
        max_length=PATCH_LENGTH,
        max_position_embeddings=PATCH_LENGTH,
        vocab_size=1,
    )
    char_config = GPT2Config(
        num_hidden_layers=CHAR_NUM_LAYERS,
        max_length=PATCH_SIZE,
        max_position_embeddings=PATCH_SIZE,
        vocab_size=128,
    )
    model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
    checkpoint = torch.load(weights, map_location=DEVICE)
    model.load_state_dict(checkpoint["model"])
    model = model.to(DEVICE)
    model.eval()

    prompt = f"A:{emo}\n"
    tunes = ""
    num_tunes = args.num_tunes
    max_patch = args.max_patch
    top_p = args.top_p
    top_k = args.top_k
    temperature = args.temperature
    seed = args.seed
    show_control_code = args.show_control_code

    print(" Hyper params ".center(60, "#"), "\n")
    args_dict: dict = vars(args)
    for arg in args_dict.keys():
        print(f"{arg}: {str(args_dict[arg])}")

    print("\n", " Output tunes ".center(60, "#"))
    start_time = time.time()
    for i in range(num_tunes):
        title = f"T:{emo} Fragment\n"
        artist = f"C:Generated by AI\n"
        tune = f"X:{str(i + 1)}\n{title}{artist}{prompt}"
        lines = re.split(r"(\n)", tune)
        tune = ""
        skip = False
        for line in lines:
            if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
                if not skip:
                    print(line, end="")
                    tune += line

                skip = False

            else:
                skip = True

        input_patches = torch.tensor(
            [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
            device=DEVICE,
        )
        if tune == "":
            tokens = None

        else:
            prefix = patchilizer.decode(input_patches[0])
            remaining_tokens = prompt[len(prefix):]
            tokens = torch.tensor(
                [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
                device=DEVICE,
            )

        # Autoregressively generate bar patches until EOS or the patch limit is reached.
        while input_patches.shape[1] < max_patch:
            predicted_patch, seed = model.generate(
                input_patches,
                tokens,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                seed=seed,
            )
            tokens = None
            if predicted_patch[0] != patchilizer.eos_token_id:
                next_bar = patchilizer.decode([predicted_patch])
                if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
                    print(next_bar, end="")
                    tune += next_bar

                if next_bar == "":
                    break

                next_bar = remaining_tokens + next_bar
                remaining_tokens = ""
                predicted_patch = torch.tensor(
                    patchilizer.bar2patch(next_bar),
                    device=DEVICE,
                ).unsqueeze(0)
                input_patches = torch.cat(
                    [input_patches, predicted_patch.unsqueeze(0)],
                    dim=1,
                )

            else:
                break

        tunes += f"{tune}\n\n"

    print("\n")
    # fix tempo: use the requested BPM, otherwise draw one from the emotion's range
    if fix_tempo != None:
        tempo = f"Q:{fix_tempo}\n"

    else:
        tempo = f"Q:{random.randint(88, 132)}\n"
        if emo == "Q1":
            tempo = f"Q:{random.randint(160, 184)}\n"
        elif emo == "Q2":
            tempo = f"Q:{random.randint(184, 228)}\n"
        elif emo == "Q3":
            tempo = f"Q:{random.randint(40, 69)}\n"
        elif emo == "Q4":
            tempo = f"Q:{random.randint(40, 69)}\n"

    Q_val = get_abc_key_val(tunes, "Q")
    if Q_val:
        tunes = tunes.replace(f"Q:{Q_val}\n", "")

    K_val = get_abc_key_val(tunes)
    if K_val == "none":
        K_val = "C"
        tunes = tunes.replace("K:none\n", f"K:{K_val}\n")

    tunes = tunes.replace(f"A:{emo}\n", tempo)
    # fix mode: major for Q1/Q4, minor for Q2/Q3
    mode = "major" if emo == "Q1" or emo == "Q4" else "minor"
    if (mode == "major") and ("m" in K_val):
        tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.split('m')[0]}\n")

    elif (mode == "minor") and ("m" not in K_val):
        tunes = tunes.replace(
            f"\nK:{K_val}\n", f"\nK:{K_val.replace('dor', '')}min\n"
        )

    print("Generation time: {:.2f} seconds".format(time.time() - start_time))
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    try:
        # fix avg_pitch (octave)
        if fix_pitch != None:
            if fix_pitch:
                tunes, xml = transpose_octaves_abc(
                    tunes,
                    f"{outdir}/{timestamp}.musicxml",
                    fix_pitch,
                )
                tunes = tunes.replace(title + title, title)
                os.rename(xml, f"{outdir}/[{emo}]{timestamp}.musicxml")
                xml = f"{outdir}/[{emo}]{timestamp}.musicxml"

        else:
            if mode == "minor":
                offset = -12
                if emo == "Q2":
                    offset -= 12

                tunes, xml = transpose_octaves_abc(
                    tunes,
                    f"{outdir}/{timestamp}.musicxml",
                    offset,
                )
                tunes = tunes.replace(title + title, title)
                os.rename(xml, f"{outdir}/[{emo}]{timestamp}.musicxml")
                xml = f"{outdir}/[{emo}]{timestamp}.musicxml"

            else:
                xml = abc2xml(tunes, f"{outdir}/[{emo}]{timestamp}.musicxml")

        audio = xml2(xml, "wav")
        if fix_volume != None:
            if fix_volume:
                adjust_volume(audio, fix_volume)

        elif os.path.exists(audio):
            if emo == "Q1":
                adjust_volume(audio, 5)
            elif emo == "Q2":
                adjust_volume(audio, 10)

        mxl = xml2(xml, "mxl")
        midi = xml2(xml, "mid")
        pdf, jpg = xml2img(xml)
        return audio, midi, pdf, xml, mxl, tunes, jpg

    except Exception as e:
        # On failure, regenerate without the fixed tempo/pitch/volume controls.
        print(f"{e}")
        return generate_music(args, emo, weights)


def inference(dataset: str, v: str, a: str, add_chord: bool):
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)

    os.makedirs(TEMP_DIR, exist_ok=True)
    # Map valence/arousal levels onto an emotion quadrant (Q1-Q4).
    emotion = "Q1"
    if v == "Low" and a == "High":
        emotion = "Q2"

    elif v == "Low" and a == "Low":
        emotion = "Q3"

    elif v == "High" and a == "Low":
        emotion = "Q4"

    parser = argparse.ArgumentParser()
    args = get_args(parser)
    return generate_music(
        args,
        emo=emotion,
        weights=f"{WEIGHTS_DIR}/{dataset.lower()}/weights.pth",
    )


def infer(
    dataset: str,
    pitch_std: str,
    mode: str,
    tempo: int,
    octave: int,
    rms: int,
    add_chord: bool,
):
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)

    os.makedirs(TEMP_DIR, exist_ok=True)
    # Map the mode / pitch-SD feature controls onto an emotion quadrant (Q1-Q4).
    emotion = "Q1"
    if mode == "Minor" and pitch_std == "High":
        emotion = "Q2"

    elif mode == "Minor" and pitch_std == "Low":
        emotion = "Q3"

    elif mode == "Major" and pitch_std == "Low":
        emotion = "Q4"
    parser = argparse.ArgumentParser()
    args = get_args(parser)
    return generate_music(
        args,
        emo=emotion,
        weights=f"{WEIGHTS_DIR}/{dataset.lower()}/weights.pth",
        fix_tempo=tempo,
        fix_pitch=octave,
        fix_volume=rms,
    )


def feedback(fixed_emo: str, source_dir="./flagged", target_dir="./feedbacks"):
    if not fixed_emo:
        return "Please select feedback before submitting!"

    os.makedirs(target_dir, exist_ok=True)
    for root, _, files in os.walk(source_dir):
        for file in files:
            if file.endswith(".mxl"):
                # Generated files are named "[Qx]<timestamp>.mxl"; recover the prompted emotion.
                prompt_emo = file.split("]")[0][1:]
                if prompt_emo != fixed_emo:
                    file_path = os.path.join(root, file)
                    target_path = os.path.join(
                        target_dir, file.replace(".mxl", f"_{fixed_emo}.mxl")
                    )
                    shutil.copy(file_path, target_path)
                    return f"Copied {file_path} to {target_path}"

                else:
                    return "Thanks for your feedback!"

    return "No .mxl files found in the source directory."


def save_template(
    label: str,
    pitch_std: str,
    mode: str,
    tempo: int,
    octave: int,
    rms: int,
):
    if (
        label
        and pitch_std
        and mode
        and tempo != None
        and octave != None
        and rms != None
    ):
        json_str = json.dumps(
            {
                "label": label,
                "pitch_std": pitch_std == "High",
                "mode": mode == "Major",
                "tempo": tempo,
                "octave": octave,
                "volume": rms,
            }
        )
        # Ensure the feedback directory exists before appending.
        os.makedirs("./feedbacks", exist_ok=True)
        with open("./feedbacks/templates.jsonl", "a", encoding="utf-8") as file:
            file.write(json_str + "\n")


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    if os.path.exists("./flagged"):
        shutil.rmtree("./flagged")

    with gr.Blocks() as demo:
        gr.Markdown(
            "## Inference on the current CPU-based HuggingFace version is slow; a GPU-based mirror is available on [ModelScope](https://www.modelscope.cn/studios/monetjoe/EMusicGen)"
        )
        with gr.Row():
            with gr.Column():
                dataset_option = gr.Dropdown(
                    ["VGMIDI", "EMOPIA", "Rough4Q"],
                    label="Dataset",
                    value="Rough4Q",
                )
                gr.Markdown("# Generate by emotion condition")
                valence_radio = gr.Radio(
                    ["Low", "High"],
                    label="Valence (reflects negative-positive levels of emotion)",
                    value="High",
                )
                arousal_radio = gr.Radio(
                    ["Low", "High"],
                    label="Arousal (reflects the calmness-intensity of the emotion)",
                    value="High",
                )
                chord_check = gr.Checkbox(
                    label="Generate chords (Coming soon)",
                    value=False,
                )
                gen_btn = gr.Button("Generate")
                gr.Markdown("# Generate by feature control")
                std_option = gr.Radio(
                    ["Low", "High"],
                    label="Pitch SD",
                    value="High",
                )
                mode_option = gr.Radio(
                    ["Minor", "Major"],
                    label="Mode",
                    value="Major",
                )
                tempo_option = gr.Slider(
                    minimum=40,
                    maximum=228,
                    step=1,
                    value=120,
                    label="Tempo (BPM)",
                )
                octave_option = gr.Slider(
                    minimum=-24,
                    maximum=24,
                    step=12,
                    value=0,
                    label="Octave (±12)",
                )
                volume_option = gr.Slider(
                    minimum=-5,
                    maximum=10,
                    step=5,
                    value=0,
                    label="Volume (dB)",
                )
                chord_check_2 = gr.Checkbox(
                    label="Generate chords (Coming soon)",
                    value=False,
                )
                gen_btn_2 = gr.Button("Generate")
                template_radio = gr.Radio(
                    ["Q1", "Q2", "Q3", "Q4"],
                    label="The emotion to which the current template belongs",
                )
                save_btn = gr.Button("Save template")
                gr.Markdown(
                    """
## Cite
```bibtex
@article{Zhou2024EMusicGen,
  title     = {EMusicGen: Emotion-Conditioned Melody Generation in ABC Notation},
  author    = {Monan Zhou and Xiaobing Li and Feng Yu and Wei Li},
  month     = {Sep},
  year      = {2024},
  publisher = {GitHub},
  version   = {0.1},
  url       = {https://github.com/monetjoe/EMusicGen}
}
```
"""
                )

            with gr.Column():
                wav_audio = gr.Audio(label="Audio", type="filepath")
                midi_file = gr.File(label="Download MIDI")
                pdf_file = gr.File(label="Download PDF score")
                xml_file = gr.File(label="Download MusicXML")
                mxl_file = gr.File(label="Download MXL")
                abc_textbox = gr.Textbox(
                    label="ABC notation",
                    show_copy_button=True,
                )
                staff_img = gr.Image(label="Staff", type="filepath")

        with gr.Row():
            gr.Interface(
                fn=feedback,
                inputs=gr.Radio(
                    ["Q1", "Q2", "Q3", "Q4"],
                    label="Feedback: the emotion you believe the generated result should belong to",
                ),
                outputs=gr.Textbox(show_copy_button=False, show_label=False),
                allow_flagging="never",
            )

        gen_btn.click(
            fn=inference,
            inputs=[dataset_option, valence_radio, arousal_radio, chord_check],
            outputs=[
                wav_audio,
                midi_file,
                pdf_file,
                xml_file,
                mxl_file,
                abc_textbox,
                staff_img,
            ],
        )
        gen_btn_2.click(
            fn=infer,
            inputs=[
                dataset_option,
                std_option,
                mode_option,
                tempo_option,
                octave_option,
                volume_option,
                chord_check_2,
            ],
            outputs=[
                wav_audio,
                midi_file,
                pdf_file,
                xml_file,
                mxl_file,
                abc_textbox,
                staff_img,
            ],
        )
        save_btn.click(
            fn=save_template,
            inputs=[
                template_radio,
                std_option,
                mode_option,
                tempo_option,
                octave_option,
                volume_option,
            ],
        )

    demo.launch()