Ad / app.py
aach456's picture
Update app.py
0ea9241 verified
raw
history blame contribute delete
No virus
2.21 kB
import functools
import tempfile

import gradio as gr
import numpy as np
import scipy.io.wavfile
import torch
from transformers import MusicgenForConditionalGeneration, AutoProcessor
@functools.lru_cache(maxsize=1)
def _load_musicgen(model_id="facebook/musicgen-small"):
    """Load the MusicGen model and processor once and reuse them across calls.

    Without caching, every request would call ``from_pretrained`` again and
    re-read the whole checkpoint before generating anything.

    Args:
        model_id: Hugging Face Hub id passed to ``from_pretrained``.

    Returns:
        Tuple ``(model, processor, device)``; ``model`` has already been moved
        to ``device`` ("cuda:0" when CUDA is available, otherwise "cpu").
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = MusicgenForConditionalGeneration.from_pretrained(model_id)
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return model, processor, device


def generate_music(prompt, unconditional=False):
    """Generate a short music clip with MusicGen and save it as a 16-bit PCM WAV.

    Args:
        prompt: Text description of the desired music. Ignored when
            ``unconditional`` is true. A ``None`` or whitespace-only prompt
            also falls back to unconditional generation, since there is no
            text to condition on.
        unconditional: If true, sample without any text conditioning.

    Returns:
        Path to a newly created, uniquely named ``.wav`` file. A fresh file
        per call keeps overlapping requests from overwriting each other.
    """
    model, processor, device = _load_musicgen()

    if unconditional or not (prompt and prompt.strip()):
        unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
        audio_values = model.generate(
            **unconditional_inputs, do_sample=True, max_new_tokens=256
        )
    else:
        inputs = processor(text=prompt, padding=True, return_tensors="pt")
        audio_values = model.generate(
            **inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256
        )

    sampling_rate = model.config.audio_encoder.sampling_rate

    # Both branches produce a batch of one (num_samples=1 / a single prompt),
    # so take item 0; if it still has a channel axis, keep the first channel.
    audio_data = audio_values[0].cpu().numpy()
    if audio_data.ndim > 1:
        audio_data = audio_data[0]

    # Convert float samples to int16 PCM. NaN/inf are replaced first because
    # casting them to an integer dtype gives undefined values.
    audio_data = np.nan_to_num(audio_data, nan=0.0, posinf=1.0, neginf=-1.0)
    pcm = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)

    # delete=False: the file must outlive this function so the caller
    # (the Gradio audio output) can read it by path.
    with tempfile.NamedTemporaryFile(
        prefix="musicgen_", suffix=".wav", delete=False
    ) as tmp:
        scipy.io.wavfile.write(tmp, sampling_rate, pcm)
    return tmp.name
def interface(prompt, unconditional):
    """Gradio click handler that forwards the UI inputs to ``generate_music``.

    Args:
        prompt: Contents of the prompt textbox.
        unconditional: State of the "Generate Unconditional Music" checkbox.

    Returns:
        The WAV file path produced by ``generate_music``, which is routed to
        the ``gr.Audio`` output component.
    """
    return generate_music(prompt, unconditional)
# Gradio UI: a prompt textbox and an "unconditional" checkbox feed `interface`;
# the WAV path it returns is played in the audio output component.
with gr.Blocks() as demo:
    gr.Markdown("# AI-Powered Music Generation")
    with gr.Row():
        prompt_input = gr.Textbox(label="Enter the Music Prompt")
        unconditional_checkbox = gr.Checkbox(label="Generate Unconditional Music")
    generate_button = gr.Button("Generate Music")
    output_audio = gr.Audio(label="Output Music")
    # Event wiring stays inside the Blocks context so Gradio attaches it to `demo`.
    generate_button.click(
        interface,
        inputs=[prompt_input, unconditional_checkbox],
        outputs=output_audio,
        show_progress=True
    )

# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module (from tests or another script) builds `demo` without starting a
# server.
if __name__ == "__main__":
    demo.launch(share=True)