{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# FluxMusic sampling notebook\n",
        "\n",
        "Clones the `dev` branch of FluxMusic, downloads the `musicflow_b.pt` checkpoint, and writes one WAV file per line of the prompt file.\n",
        "\n",
        "Run the cells from top to bottom: the configuration cell defines every path that the setup and model-loading cells use."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "from pathlib import Path\n",
        "\n",
        "# Directory that contains (or will receive) the FluxMusic checkout.\n",
        "# Set the FLUXMUSIC_WORKDIR environment variable (absolute path) to relocate\n",
        "# everything; the fallback is the location this notebook was written for.\n",
        "WORK_DIR = Path(os.environ.get(\n",
        "    \"FLUXMUSIC_WORKDIR\",\n",
        "    \"C:/Users/Curt/Developer/AItools/AIaudio/AudioCreation/FluxMusicJupyter\",\n",
        ")).expanduser().resolve()\n",
        "\n",
        "# The git checkout. The project modules (utils, train, constants), the prompt\n",
        "# file and the checkpoint are all expected inside it.\n",
        "REPO_DIR = WORK_DIR / \"FluxMusic\"\n",
        "\n",
        "CHECKPOINT_URL = \"https://huggingface.co/audo/FluxMusic/resolve/main/musicflow_b.pt\"\n",
        "CHECKPOINT_PATH = REPO_DIR / \"musicflow_b.pt\"\n",
        "PROMPT_FILE = REPO_DIR / \"config\" / \"example.txt\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setup\n",
        "\n",
        "Fetches the code, the checkpoint and the Python dependencies. `aria2` is installed through `apt`, so that step assumes a Debian-like host such as Colab."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VjYy0F2gZIPR"
      },
      "outputs": [],
      "source": [
        "# Clone only when the checkout is missing so the cell can be re-run safely.\n",
        "if not REPO_DIR.exists():\n",
        "    !git clone -b dev https://github.com/camenduru/FluxMusic \"{REPO_DIR}\"\n",
        "\n",
        "!apt -y install -qq aria2\n",
        "# -c resumes/skips an existing download; the file lands next to the code in REPO_DIR.\n",
        "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M {CHECKPOINT_URL} -d \"{REPO_DIR}\" -o {CHECKPOINT_PATH.name}\n",
        "\n",
        "# %pip (not !pip) installs into the environment of the running kernel.\n",
        "%pip install transformers diffusers accelerate einops soundfile progressbar unidecode phonemizer torchlibrosa ftfy pandas timm matplotlib numpy==1.26.4 thop flash-attn==2.6.3 sentencepiece"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load the model\n",
        "\n",
        "Defines the `prepare` helper, then builds the model and loads the EMA weights from the checkpoint."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NoTEt9Wto70D"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import torch\n",
        "import argparse\n",
        "import math\n",
        "from einops import rearrange, repeat\n",
        "from PIL import Image\n",
        "from diffusers import AutoencoderKL\n",
        "from transformers import SpeechT5HifiGan\n",
        "\n",
        "# The project modules imported below are plain files inside the checkout:\n",
        "# work from there and make sure it is importable however the kernel was started.\n",
        "os.chdir(REPO_DIR)\n",
        "if str(REPO_DIR) not in sys.path:\n",
        "    sys.path.insert(0, str(REPO_DIR))\n",
        "\n",
        "from utils import load_t5, load_clap, load_ae\n",
        "from train import RF\n",
        "from constants import build_model\n",
        "\n",
        "\n",
        "def prepare(t5, clip, img, prompt):\n",
        "    \"\"\"Turn a latent batch and its text prompt(s) into the model inputs.\n",
        "\n",
        "    Args:\n",
        "        t5: Callable that maps a list of prompts to a tensor; dimension 0 is\n",
        "            treated as the batch and dimension 1 sizes ``txt_ids``.\n",
        "        clip: Callable that maps a list of prompts to a tensor with the batch\n",
        "            in dimension 0.\n",
        "        img: 4-D tensor (batch, channels, height, width). Height and width\n",
        "            must be even because the tensor is cut into 2x2 patches.\n",
        "        prompt: A single string or a list of strings.\n",
        "\n",
        "    Returns:\n",
        "        A pair ``(img, cond)``: ``img`` is the patched tensor of shape\n",
        "        (batch, h/2 * w/2, channels * 4) and ``cond`` is a dict with the keys\n",
        "        ``img_ids``, ``txt``, ``txt_ids`` and ``y``, all moved to ``img.device``.\n",
        "    \"\"\"\n",
        "    bs, _, h, w = img.shape\n",
        "    # A single latent may be paired with several prompts: the prompt count\n",
        "    # then becomes the batch size.\n",
        "    if bs == 1 and not isinstance(prompt, str):\n",
        "        bs = len(prompt)\n",
        "\n",
        "    img = rearrange(img, \"b c (h ph) (w pw) -> b (h w) (c ph pw)\", ph=2, pw=2)\n",
        "    if img.shape[0] == 1 and bs > 1:\n",
        "        img = repeat(img, \"1 ... -> bs ...\", bs=bs)\n",
        "\n",
        "    # Per-patch ids: channel 0 stays zero, channel 1 holds the patch row and\n",
        "    # channel 2 the patch column.\n",
        "    img_ids = torch.zeros(h // 2, w // 2, 3)\n",
        "    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]\n",
        "    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]\n",
        "    img_ids = repeat(img_ids, \"h w c -> b (h w) c\", b=bs)\n",
        "\n",
        "    if isinstance(prompt, str):\n",
        "        prompt = [prompt]\n",
        "    txt = t5(prompt)\n",
        "    if txt.shape[0] == 1 and bs > 1:\n",
        "        txt = repeat(txt, \"1 ... -> bs ...\", bs=bs)\n",
        "    # Text positions carry all-zero ids.\n",
        "    txt_ids = torch.zeros(bs, txt.shape[1], 3)\n",
        "\n",
        "    vec = clip(prompt)\n",
        "    if vec.shape[0] == 1 and bs > 1:\n",
        "        vec = repeat(vec, \"1 ... -> bs ...\", bs=bs)\n",
        "\n",
        "    print(img_ids.size(), txt.size(), vec.size())\n",
        "    return img, {\n",
        "        \"img_ids\": img_ids.to(img.device),\n",
        "        \"txt\": txt.to(img.device),\n",
        "        \"txt_ids\": txt_ids.to(img.device),\n",
        "        \"y\": vec.to(img.device),\n",
        "    }\n",
        "\n",
        "\n",
        "version = \"base\"\n",
        "seed = 2024\n",
        "prompt_file = str(PROMPT_FILE)\n",
        "\n",
        "print('generate with MusicFlux')\n",
        "torch.manual_seed(seed)\n",
        "torch.set_grad_enabled(False)\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "latent_size = (256, 16)\n",
        "\n",
        "model = build_model(version).to(device)\n",
        "# Load from the same location the setup cell downloads to.\n",
        "local_path = str(CHECKPOINT_PATH)\n",
        "# map_location=\"cpu\" keeps the loaded tensors on the CPU; load_state_dict\n",
        "# then copies them into the model on `device`.\n",
        "state_dict = torch.load(local_path, map_location=\"cpu\", weights_only=True)\n",
        "model.load_state_dict(state_dict['ema'])\n",
        "model.eval()  # important!\n",
        "diffusion = RF()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load the text encoders, VAE and vocoder"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B5ebyTmto70D"
      },
      "outputs": [],
      "source": [
        "t5 = load_t5(device, max_length=256)\n",
        "clap = load_clap(device, max_length=256)\n",
        "\n",
        "# The VAE and the vocoder are two subfolders of the same pretrained repository.\n",
        "vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder=\"vae\").to(device)\n",
        "vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder=\"vocoder\").to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Generate\n",
        "\n",
        "Reads the prompts, samples latents for them and decodes the result with the VAE."
      ]
    },
    {
"cell_type": "code", "execution_count": null, "metadata": { "id": "xqG8Px6xo70D" }, "outputs": [], "source": [ "prompt_file=\"C:/Users/Curt/Developer/AItools/AIaudio/AudioCreation/FluxMusicJupyter/config/example.txt\"\n", "\n", "with open(prompt_file, 'r') as f:\n", " conds_txt = f.readlines()\n", "L = len(conds_txt)\n", "unconds_txt = [\"low quality, gentle\"] * L\n", "print(L, conds_txt, unconds_txt)\n", "\n", "init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).cuda()\n", "\n", "STEPSIZE = 50\n", "img, conds = prepare(t5, clap, init_noise, conds_txt)\n", "_, unconds = prepare(t5, clap, init_noise, unconds_txt)\n", "with torch.autocast(device_type='cuda'):\n", " images = diffusion.sample_with_xps(model, img, conds=conds, null_cond=unconds, sample_steps = STEPSIZE, cfg = 7.0)\n", "\n", "print(images[-1].size(), )\n", "\n", "images = rearrange(\n", " images[-1],\n", " \"b (h w) (c ph pw) -> b c (h ph) (w pw)\",\n", " h=128,\n", " w=8,\n", " ph=2,\n", " pw=2,)\n", "# print(images.size())\n", "latents = 1 / vae.config.scaling_factor * images\n", "mel_spectrogram = vae.decode(latents).sample\n", "print(mel_spectrogram.size())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ytAXlAEdo70D" }, "outputs": [], "source": [ "!mkdir C:/Users/Curt/Developer/AItools/AIaudio/AudioCreation/FluxMusicJupyter/FluxMusic/b_output\n", "\n", "for i in range(L):\n", " x_i = mel_spectrogram[i]\n", " if x_i.dim() == 4:\n", " x_i = x_i.squeeze(1)\n", " waveform = vocoder(x_i)\n", " waveform = waveform[0].cpu().float().detach().numpy()\n", " print(waveform.shape)\n", " # import soundfile as sf\n", " # sf.write('reconstruct.wav', waveform, samplerate=16000)\n", " from scipy.io import wavfile\n", " wavfile.write('C:/Users/Curt/Developer/AItools/AIaudio/AudioCreation/FluxMusicJupyter/FluxMusic/b_output/sample_' + str(i) + '.wav', 16000, waveform)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { 
"display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }