Spaces:
Runtime error
Runtime error
feat: capture progressbar
Browse files
app.py
CHANGED
@@ -1,13 +1,14 @@
|
|
1 |
from __future__ import annotations
|
2 |
|
|
|
3 |
import shutil
|
4 |
import subprocess
|
5 |
from pathlib import Path
|
6 |
from textwrap import dedent
|
7 |
|
8 |
-
import torch
|
9 |
-
import streamlit as st
|
10 |
import numpy as np
|
|
|
|
|
11 |
from PIL import Image
|
12 |
from transformers import CLIPTokenizer
|
13 |
|
@@ -22,6 +23,7 @@ color = col1.color_picker("Pick a color", "#00f900")
|
|
22 |
col2.text_input("", color, disabled=True)
|
23 |
|
24 |
emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
|
|
|
25 |
rgb = hex_to_rgb(color)
|
26 |
|
27 |
img_array = np.zeros((128, 128, 3), dtype=np.uint8)
|
@@ -38,23 +40,22 @@ if output_path.exists():
|
|
38 |
dataset_path.mkdir()
|
39 |
img_path = dataset_path / f"{emb_name}.png"
|
40 |
Image.fromarray(img_array).save(img_path)
|
41 |
-
tokenizer = CLIPTokenizer.from_pretrained(
|
42 |
-
"Linaqruf/anything-v3.0", subfolder="tokenizer"
|
43 |
-
)
|
44 |
|
45 |
with st.sidebar:
|
46 |
-
|
47 |
steps = st.slider("Steps", 1, 100, 30, step=1)
|
48 |
learning_rate = st.text_input("Learning rate", "0.005")
|
49 |
learning_rate = float(learning_rate)
|
50 |
|
51 |
-
|
52 |
-
|
|
|
|
|
53 |
if len(token) > 1:
|
54 |
-
st.warning("
|
55 |
st.stop()
|
56 |
|
57 |
-
# case 2:
|
58 |
num_added_tokens = tokenizer.add_tokens(emb_name)
|
59 |
if num_added_tokens == 0:
|
60 |
st.warning(f"The tokenizer already contains the token {emb_name}")
|
@@ -62,7 +63,7 @@ if num_added_tokens == 0:
|
|
62 |
|
63 |
cmd = """
|
64 |
accelerate launch textual_inversion.py \
|
65 |
-
--pretrained_model_name_or_path=
|
66 |
--train_data_dir="dataset" \
|
67 |
--learnable_property="style" \
|
68 |
--placeholder_token="{emb_name}" \
|
@@ -78,22 +79,39 @@ accelerate launch textual_inversion.py \
|
|
78 |
""".strip()
|
79 |
|
80 |
cmd = dedent(cmd).format(
|
81 |
-
|
|
|
|
|
|
|
|
|
82 |
)
|
|
|
|
|
|
|
|
|
83 |
|
84 |
-
|
|
|
|
|
|
|
85 |
with st.spinner("Training..."):
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
-
result_path = Path("output") / "learned_embeds.bin"
|
89 |
if not result_path.exists():
|
90 |
st.stop()
|
91 |
|
92 |
-
# fix unknown
|
93 |
trained_emb = torch.load(result_path, map_location="cpu")
|
94 |
for k, v in trained_emb.items():
|
95 |
trained_emb[k] = torch.from_numpy(v.numpy())
|
96 |
torch.save(trained_emb, result_path)
|
97 |
|
98 |
file = result_path.read_bytes()
|
99 |
-
|
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
+
import shlex
|
4 |
import shutil
|
5 |
import subprocess
|
6 |
from pathlib import Path
|
7 |
from textwrap import dedent
|
8 |
|
|
|
|
|
9 |
import numpy as np
|
10 |
+
import streamlit as st
|
11 |
+
import torch
|
12 |
from PIL import Image
|
13 |
from transformers import CLIPTokenizer
|
14 |
|
|
|
23 |
col2.text_input("", color, disabled=True)
|
24 |
|
25 |
emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
|
26 |
+
init_token = st.text_input("Initializer token", "init token name")
|
27 |
rgb = hex_to_rgb(color)
|
28 |
|
29 |
img_array = np.zeros((128, 128, 3), dtype=np.uint8)
|
|
|
40 |
dataset_path.mkdir()
|
41 |
img_path = dataset_path / f"{emb_name}.png"
|
42 |
Image.fromarray(img_array).save(img_path)
|
|
|
|
|
|
|
43 |
|
44 |
with st.sidebar:
|
45 |
+
model_name = st.text_input("Model name", "Linaqruf/anything-v3.0")
|
46 |
steps = st.slider("Steps", 1, 100, 30, step=1)
|
47 |
learning_rate = st.text_input("Learning rate", "0.005")
|
48 |
learning_rate = float(learning_rate)
|
49 |
|
50 |
+
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
|
51 |
+
|
52 |
+
# case 1: init_token is not a single token
|
53 |
+
token = tokenizer.tokenize(init_token)
|
54 |
if len(token) > 1:
|
55 |
+
st.warning("Initializer token must be a single token")
|
56 |
st.stop()
|
57 |
|
58 |
+
# case 2: emb_name (the placeholder token) already exists in the tokenizer
|
59 |
num_added_tokens = tokenizer.add_tokens(emb_name)
|
60 |
if num_added_tokens == 0:
|
61 |
st.warning(f"The tokenizer already contains the token {emb_name}")
|
|
|
63 |
|
64 |
cmd = """
|
65 |
accelerate launch textual_inversion.py \
|
66 |
+
--pretrained_model_name_or_path={model_name} \
|
67 |
--train_data_dir="dataset" \
|
68 |
--learnable_property="style" \
|
69 |
--placeholder_token="{emb_name}" \
|
|
|
79 |
""".strip()
|
80 |
|
81 |
cmd = dedent(cmd).format(
|
82 |
+
model_name=model_name,
|
83 |
+
emb_name=emb_name,
|
84 |
+
init=init_token,
|
85 |
+
lr=learning_rate,
|
86 |
+
steps=steps,
|
87 |
)
|
88 |
+
cmd = shlex.split(cmd)
|
89 |
+
|
90 |
+
result_path = output_path / "learned_embeds.bin"
|
91 |
+
captured = ""
|
92 |
|
93 |
+
start_button = st.button("Start")
|
94 |
+
download_button = st.empty()
|
95 |
+
|
96 |
+
if start_button:
|
97 |
with st.spinner("Training..."):
|
98 |
+
placeholder = st.empty()
|
99 |
+
p = subprocess.Popen(
|
100 |
+
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8"
|
101 |
+
)
|
102 |
+
|
103 |
+
while line := p.stderr.readline():
|
104 |
+
captured += line
|
105 |
+
placeholder.code(captured, language="bash")
|
106 |
|
|
|
107 |
if not result_path.exists():
|
108 |
st.stop()
|
109 |
|
110 |
+
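# work around the unexpected output file size: round-trip each tensor through NumPy and re-save (presumably to detach it from a larger shared storage; TODO confirm)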
|
111 |
trained_emb = torch.load(result_path, map_location="cpu")
|
112 |
for k, v in trained_emb.items():
|
113 |
trained_emb[k] = torch.from_numpy(v.numpy())
|
114 |
torch.save(trained_emb, result_path)
|
115 |
|
116 |
file = result_path.read_bytes()
|
117 |
+
download_button.download_button(f"Download {emb_name}.pt", file, f"{emb_name}.pt")
|