File size: 3,742 Bytes
05fd390
4efa9a6
05fd390
 
 
62b04c4
463536c
4efa9a6
05fd390
 
 
 
 
 
 
463536c
05fd390
 
 
 
 
 
 
 
 
 
463536c
05fd390
463536c
05fd390
 
 
 
 
 
 
 
 
 
463536c
 
05fd390
62b04c4
 
05fd390
 
 
 
 
 
 
 
463536c
 
 
 
 
05fd390
463536c
05fd390
 
 
 
 
8f30316
 
 
 
05fd390
 
 
 
 
 
 
4efa9a6
05fd390
 
 
 
 
f559d19
 
463536c
05fd390
463536c
05fd390
 
 
463536c
 
 
 
 
 
05fd390
 
 
86cbf7f
 
463536c
 
 
 
 
 
 
86cbf7f
 
05fd390
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import os
import subprocess
from huggingface_hub import snapshot_download

# Hugging Face access token read from the HF_TOKEN environment variable;
# None when the variable is not set.
hf_token = os.environ.get("HF_TOKEN")


def set_accelerate_default_config():
    """Run ``accelerate config default`` and print the outcome.

    A non-zero exit status from the command is caught and printed
    instead of being propagated to the caller.
    """
    config_cmd = ["accelerate", "config", "default"]
    try:
        subprocess.run(config_cmd, check=True)
    except subprocess.CalledProcessError as err:
        print(f"An error occurred: {err}")
    else:
        print("Accelerate default config set successfully!")

def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps):
    """Launch ``train_dreambooth_lora_sdxl.py`` through ``accelerate launch``.

    Args:
        instance_data_dir: Value forwarded as ``--instance_data_dir``.
        lora_trained_xl_folder: Value forwarded as ``--output_dir``.
        instance_prompt: Value forwarded as ``--instance_prompt``.
        max_train_steps: Value forwarded as ``--max_train_steps``; coerced
            to ``int`` because UI number fields may deliver floats.
        checkpoint_steps: Value forwarded as ``--checkpointing_steps``;
            coerced to ``int`` for the same reason.

    Returns:
        ``True`` when the training process exits with status 0, ``False``
        when it fails or the ``accelerate`` executable cannot be started.
    """
    script_filename = "train_dreambooth_lora_sdxl.py"  # Assuming it's in the same folder

    command = [
        "accelerate",
        "launch",
        script_filename,  # Use the local script
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        f"--instance_data_dir={instance_data_dir}",
        f"--output_dir={lora_trained_xl_folder}",
        "--mixed_precision=fp16",
        f"--instance_prompt={instance_prompt}",
        "--resolution=1024",
        "--train_batch_size=2",
        "--gradient_accumulation_steps=2",
        "--gradient_checkpointing",
        "--learning_rate=1e-4",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--enable_xformers_memory_efficient_attention",
        "--use_8bit_adam",
        # int() turns e.g. 500.0 into 500 so the flag carries an integer.
        f"--max_train_steps={int(max_train_steps)}",
        f"--checkpointing_steps={int(checkpoint_steps)}",
        "--seed=0",
        "--push_to_hub",
    ]

    # Only forward a token when one is configured; otherwise the literal
    # string "None" would be sent as the token.
    # NOTE(review): a token passed on the command line may be visible to
    # other local users via the process list - confirm this is acceptable.
    if hf_token:
        command.append(f"--hub_token={hf_token}")

    try:
        subprocess.run(command, check=True)
    except FileNotFoundError as e:
        # Raised when the "accelerate" executable itself cannot be found.
        print(f"An error occurred: {e}")
        return False
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")
        return False

    print("Training is finished!")
    return True

def main(dataset_id,
         lora_trained_xl_folder,
         instance_prompt,
         max_train_steps,
         checkpoint_steps):
    """Download a dataset repo, configure accelerate and run the training.

    Args:
        dataset_id: Hub dataset repo id, e.g. ``diffusers/dog-example``.
        lora_trained_xl_folder: Output folder name handed to the trainer
            and echoed in the returned status message.
        instance_prompt: Concept prompt handed to the trainer.
        max_train_steps: Step count handed to the trainer.
        checkpoint_steps: Checkpoint interval handed to the trainer.

    Returns:
        A status message naming the output folder.

    Raises:
        gr.Error: If the dataset id or the output folder name is empty, or
            if the trainer explicitly reports failure.
    """
    # Strip whitespace and surrounding slashes so "user/name/" still yields
    # "name" as the last path component.
    dataset_repo = (dataset_id or "").strip().strip("/")
    if not dataset_repo:
        raise gr.Error("Please provide a dataset ID, e.g. diffusers/dog-example")
    if not lora_trained_xl_folder:
        raise gr.Error("Please provide an output model folder name")

    # Automatically set local_dir based on the last part of dataset_repo
    repo_name = dataset_repo.split("/")[-1]
    local_dir = f"./{repo_name}"

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(local_dir, exist_ok=True)

    gr.Info("Downloading dataset ...")

    snapshot_download(
        dataset_repo,
        local_dir=local_dir,
        repo_type="dataset",
        ignore_patterns=".gitattributes",
        token=hf_token
    )

    set_accelerate_default_config()

    gr.Info("Training begins ...")

    instance_data_dir = repo_name
    outcome = train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps)

    # Only an explicit False is treated as failure, so a trainer that
    # returns nothing still counts as completed.
    if outcome is False:
        raise gr.Error("Training failed, check the logs for details")

    return f"Done, your trained model has been stored in your models library: your_user_name/{lora_trained_xl_folder}"

# Build the training UI: text inputs for dataset / prompt / output folder,
# numeric inputs for the step counts, and a button that runs `main`.
with gr.Blocks() as demo:
    with gr.Column():
        dataset_id = gr.Textbox(label="Dataset ID", placeholder="diffusers/dog-example")
        instance_prompt = gr.Textbox(label="Concept prompt", info="concept prompt - use a unique, made up word to avoid collisions")
        model_output_folder = gr.Textbox(label="Output model folder name", placeholder="lora-trained-xl-folder")
        with gr.Row():
            # precision=0 makes the fields deliver integers instead of floats.
            max_train_steps = gr.Number(label="Max train steps", value=500, precision=0)
            checkpoint_steps = gr.Number(label="Checkpoint steps", value=100, precision=0)
        train_button = gr.Button("Train !")
        status = gr.Textbox(label="Training status")

    train_button.click(
        fn=main,
        # Inputs are passed positionally, so their order must follow main's
        # parameter order: dataset id, output folder, prompt, step counts.
        inputs=[
            dataset_id,
            model_output_folder,
            instance_prompt,
            max_train_steps,
            checkpoint_steps,
        ],
        outputs=[status],
    )

demo.queue().launch()