#!/usr/bin/env python3
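"""Command-line entry point for AudioLDM2: generate audio (or speech) from text prompts."""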
import os
import torch
import logging
from audioldm2 import text_to_audio, build_model, save_wave, get_time, read_list
import argparse
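# Enable tokenizer parallelism in HuggingFace tokenizers and quiet matplotlib's
# verbose logging before generation starts.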
os.environ["TOKENIZERS_PARALLELISM"] = "true"
matplotlib_logger = logging.getLogger('matplotlib')
matplotlib_logger.setLevel(logging.WARNING)
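# Command-line arguments.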
parser = argparse.ArgumentParser()
parser.add_argument(
    "-t",
    "--text",
    type=str,
    required=False,
    default="",
    help="Text prompt to the model for audio generation",
)
parser.add_argument(
    "--transcription",
    type=str,
    required=False,
    default="",
    help="Transcription for text-to-speech",
)
parser.add_argument(
    "-tl",
    "--text_list",
    type=str,
    required=False,
    default="",
    help="A file containing text prompts for audio generation",
)
parser.add_argument(
    "-s",
    "--save_path",
    type=str,
    required=False,
    help="The path to save model output",
    default="./output",
)
parser.add_argument(
    "--model_name",
    type=str,
    required=False,
    help="The checkpoint to use",
    default="audioldm_48k",
    choices=[
        "audioldm_48k",
        "audioldm_16k_crossattn_t5",
        "audioldm2-full",
        "audioldm2-music-665k",
        "audioldm2-full-large-1150k",
        "audioldm2-speech-ljspeech",
        "audioldm2-speech-gigaspeech",
    ],
)
parser.add_argument(
    "-d",
    "--device",
    type=str,
    required=False,
    help="The device for computation. If not specified, the script will automatically choose the device based on your environment.",
    default="auto",
)
parser.add_argument(
    "-b",
    "--batchsize",
    type=int,
    required=False,
    default=1,
    help="How many samples to generate at the same time",
)
parser.add_argument(
    "--ddim_steps",
    type=int,
    required=False,
    default=200,
    help="The number of DDIM sampling steps",
)
parser.add_argument(
    "-gs",
    "--guidance_scale",
    type=float,
    required=False,
    default=3.5,
    help="Guidance scale (larger => better quality and relevance to text; smaller => better diversity)",
)
parser.add_argument(
    "-dur",
    "--duration",
    type=float,
    required=False,
    default=10.0,
    help="The duration of the generated samples in seconds",
)
parser.add_argument(
    "-n",
    "--n_candidate_gen_per_text",
    type=int,
    required=False,
    default=3,
    help="Automatic quality control: the number of candidates to generate per prompt (e.g., generate three audios and keep the best). A larger value usually leads to better quality with heavier computation",
)
parser.add_argument(
    "--seed",
    type=int,
    required=False,
    default=0,
    help="Changing this value (any integer) will lead to a different generation result.",
)
args = parser.parse_args()
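# "high" permits TF32 matmuls on supported GPUs, trading a little float32 precision for speed.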
torch.set_float32_matmul_precision("high")
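# Group each run's outputs into a timestamped subdirectory.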
save_path = os.path.join(args.save_path, get_time())
text = args.text
random_seed = args.seed
duration = args.duration
sample_rate = 16000
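# Model-specific constraints: AudioLDM2 checkpoints only support 10-second clips,
# and the 48 kHz checkpoint produces audio at 48 kHz instead of 16 kHz.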
if ("audioldm2" in args.model_name):
print(
"Warning: For AudioLDM2 we currently only support 10s of generation. Please use audioldm_48k or audioldm_16k_crossattn_t5 if you want a different duration.")
duration = 10
if ("48k" in args.model_name):
sample_rate = 48000
guidance_scale = args.guidance_scale
n_candidate_gen_per_text = args.n_candidate_gen_per_text
transcription = args.transcription
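# Text-to-speech mode: providing a transcription requires a speech checkpoint,
# plus a text prompt describing the speaker.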
if transcription:
    if "speech" not in args.model_name:
        print(
            "Warning: You chose to perform text-to-speech by providing a transcription, "
            "but did not choose a speech model (audioldm2-speech-gigaspeech or audioldm2-speech-ljspeech)."
        )
        print("Warning: Falling back to audioldm2-speech-gigaspeech by default.")
        args.model_name = "audioldm2-speech-gigaspeech"
    if not text:
        print(
            "Warning: You should provide a text prompt describing the speaker. "
            "Using the default: 'A female reporter is speaking full of emotion'."
        )
        text = "A female reporter is speaking full of emotion"
os.makedirs(save_path, exist_ok=True)
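# Build the latent diffusion model for the selected checkpoint and device.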
audioldm2 = build_model(model_name=args.model_name, device=args.device)
if args.text_list:
    print("Generating audio based on the text prompts in %s" % args.text_list)
    prompt_todo = read_list(args.text_list)
else:
    prompt_todo = [text]
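# Generate audio for each prompt. A prompt of the form "text|name" uses "name"
# as the output file name instead of the prompt text.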
for text in prompt_todo:
    if "|" in text:
        text, name = text.split("|")
    else:
        name = text[:128]
    if transcription:
        name += "-TTS-%s" % transcription
    waveform = text_to_audio(
        audioldm2,
        text,
        transcription=transcription,  # to avoid the model ignoring the last word
        seed=random_seed,
        duration=duration,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        n_candidate_gen_per_text=n_candidate_gen_per_text,
        batchsize=args.batchsize,
    )
    save_wave(waveform, save_path, name=name, samplerate=sample_rate)