File size: 4,476 Bytes
d9da3b4
 
9225658
d9da3b4
 
96bc74a
d9da3b4
 
 
9225658
 
d9da3b4
 
 
 
 
 
 
 
 
9e1a19f
 
 
 
 
 
 
 
d9da3b4
 
 
 
 
9225658
d9da3b4
 
 
 
 
 
96bc74a
 
 
 
d9da3b4
 
 
 
 
9225658
72cb885
 
d9da3b4
 
9225658
 
 
 
d9da3b4
9225658
d9da3b4
 
9225658
d9da3b4
 
 
 
 
 
 
9225658
96bc74a
d9da3b4
 
 
 
 
 
 
 
 
96bc74a
d9da3b4
 
 
 
9225658
96bc74a
9225658
 
 
96bc74a
 
d9da3b4
9225658
 
 
 
d9da3b4
9225658
 
 
 
d9da3b4
9225658
 
 
 
 
 
 
 
d9da3b4
 
 
 
9225658
d9da3b4
 
 
 
 
 
9225658
9e1a19f
96bc74a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from __future__ import annotations

import shlex
import subprocess
from pathlib import Path
from tempfile import TemporaryDirectory
from textwrap import dedent

import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPTokenizer


def hex_to_rgb(s: str) -> tuple[int, int, int]:
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints 0-255.

    Accepts an optional leading ``#`` and either the 6-digit ``RRGGBB`` form
    (what ``st.color_picker`` returns) or the CSS 3-digit ``RGB`` shorthand,
    where each digit is doubled (``"#0f9"`` -> ``"00ff99"``).

    Raises:
        ValueError: if the string is not a valid 3- or 6-digit hex color.
    """
    value = s.lstrip("#")
    if len(value) == 3:
        # Expand CSS shorthand: each digit is repeated ("0f9" -> "00ff99").
        value = "".join(c * 2 for c in value)
    if len(value) != 6:
        # Fail loudly with a clear message instead of a confusing
        # int("") / int("xy", 16) error from the slicing below.
        raise ValueError(f"invalid hex color: {s!r}")
    return (int(value[:2], 16), int(value[2:4], 16), int(value[4:6], 16))


st.header("Color Textual Inversion")
# Collapsible info panel; contents come from a sibling info.txt file.
with st.expander(label="info"):
    with open("info.txt", "r", encoding="utf-8") as f:
        st.markdown(f.read())

# "Duplicate this Space" badge linking back to the Hugging Face Space.
duplicate_button = """<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Bingsu/color_textual_inversion?duplicate=true"><img style="margin: 0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>"""
st.markdown(duplicate_button, unsafe_allow_html=True)

# Color picker plus a read-only text field echoing the chosen hex value.
col1, col2 = st.columns([15, 85])
color = col1.color_picker("Pick a color", "#00f900")
col2.text_input("", color, disabled=True)

# Embedding name defaults to the hex digits (e.g. "00F900"); the user must
# supply a real initializer token themselves.
emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
init_token = st.text_input("Initializer token", "init token name")
rgb = hex_to_rgb(color)

# Build a 128x128 solid-color training image: fill each RGB channel plane.
img_array = np.zeros((128, 128, 3), dtype=np.uint8)
for i in range(3):
    img_array[..., i] = rgb[i]

# Temp dirs (created under the CWD) for the one-image dataset and the
# training output; cleaned up explicitly at the end of the script run.
dataset_temp = TemporaryDirectory(prefix="dataset_", dir=".")
dataset_path = Path(dataset_temp.name)
output_temp = TemporaryDirectory(prefix="output_", dir=".")
output_path = Path(output_temp.name)

img_path = dataset_path / f"{emb_name}.png"
Image.fromarray(img_array).save(img_path)

with st.sidebar:
    # Training hyperparameters for the textual-inversion run.
    model_name = st.text_input("Model name", "Linaqruf/anything-v3.0")
    steps = st.slider("Steps", 1, 100, value=1, step=1)
    learning_rate = st.text_input("Learning rate", "0.001")
    try:
        learning_rate = float(learning_rate)
    except ValueError:
        # A free-text field: warn and stop instead of crashing the app
        # with an unhandled ValueError (consistent with the checks below).
        st.warning(f"Learning rate must be a number, got {learning_rate!r}")
        st.stop()

# Only the tokenizer is needed here; the full model is loaded by the
# training subprocess launched later.
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")

# case 1: init_token is not a single token
token = tokenizer.tokenize(init_token)
if len(token) > 1:
    st.warning("Initializer token must be a single token")
    st.stop()

# case 2: emb_name already exists in the tokenizer
num_added_tokens = tokenizer.add_tokens(emb_name)
if num_added_tokens == 0:
    st.warning(f"The tokenizer already contains the token {emb_name}")
    st.stop()

# Build the training command directly as an argv list instead of
# interpolating user input into a shell-style string and shlex.split-ing
# it: an emb_name or init_token containing spaces/quotes would have
# mis-parsed (and the old dedent() after .strip() was a no-op anyway).
# The quoted values in the old template (e.g. --learnable_property="style")
# lost their quotes under shlex.split, so these literals are identical.
cmd = [
    "accelerate",
    "launch",
    "textual_inversion.py",
    f"--pretrained_model_name_or_path={model_name}",
    f"--train_data_dir={dataset_path.as_posix()}",
    "--learnable_property=style",
    f"--placeholder_token={emb_name}",
    f"--initializer_token={init_token}",
    "--resolution=128",
    "--train_batch_size=1",
    "--repeats=1",
    "--gradient_accumulation_steps=1",
    f"--max_train_steps={steps}",
    f"--learning_rate={learning_rate}",
    f"--output_dir={output_path.as_posix()}",
    "--only_save_embeds",
]

result_path = output_path / "learned_embeds.bin"
captured = ""

start_button = st.button("Start")
download_button = st.empty()

if start_button:
    with st.spinner("Training..."):
        placeholder = st.empty()
        # Merge stderr into stdout and drain a single pipe. The previous
        # code opened two PIPEs but only read stderr, so the child could
        # deadlock once the unread stdout buffer filled up.
        p = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            encoding="utf-8",
        )

        # Stream training logs into the UI line by line as they arrive.
        while line := p.stdout.readline():
            captured += line
            placeholder.code(captured, language="bash")
        # Reap the child so it does not linger as a zombie process.
        p.wait()

# Nothing to offer for download until a training run has produced output.
if not result_path.exists():
    st.stop()

# fix unknown file volume bug
# NOTE(review): torch.load deserializes pickled data; the file is produced
# by our own subprocess here, but consider weights_only=True if available.
trained_emb = torch.load(result_path, map_location="cpu")
for k, v in trained_emb.items():
    trained_emb[k] = torch.from_numpy(v.numpy())
torch.save(trained_emb, result_path)

file = result_path.read_bytes()
download_button.download_button(f"Download {emb_name}.pt", file, f"{emb_name}.pt")
st.download_button(f"Download {emb_name}.pt ", file, f"{emb_name}.pt")

dataset_temp.cleanup()
output_temp.cleanup()