Bingsu commited on
Commit
96bc74a
1 Parent(s): f9f9581

fix: use tempdir

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -1,9 +1,9 @@
1
  from __future__ import annotations
2
 
3
  import shlex
4
- import shutil
5
  import subprocess
6
  from pathlib import Path
 
7
  from textwrap import dedent
8
 
9
  import numpy as np
@@ -38,15 +38,11 @@ img_array = np.zeros((128, 128, 3), dtype=np.uint8)
38
  for i in range(3):
39
  img_array[..., i] = rgb[i]
40
 
41
- dataset_path = Path("dataset")
42
- output_path = Path("output")
43
- if dataset_path.exists():
44
- shutil.rmtree(dataset_path)
45
- if output_path.exists():
46
- shutil.rmtree(output_path)
47
 
48
- dataset_path.mkdir()
49
- output_path.mkdir()
50
  img_path = dataset_path / f"{emb_name}.png"
51
  Image.fromarray(img_array).save(img_path)
52
 
@@ -73,7 +69,7 @@ if num_added_tokens == 0:
73
  cmd = """
74
  accelerate launch textual_inversion.py \
75
  --pretrained_model_name_or_path={model_name} \
76
- --train_data_dir="dataset" \
77
  --learnable_property="style" \
78
  --placeholder_token="{emb_name}" \
79
  --initializer_token="{init}" \
@@ -83,16 +79,18 @@ accelerate launch textual_inversion.py \
83
  --gradient_accumulation_steps=1 \
84
  --max_train_steps={steps} \
85
  --learning_rate={lr} \
86
- --output_dir="output" \
87
  --only_save_embeds
88
  """.strip()
89
 
90
  cmd = dedent(cmd).format(
91
  model_name=model_name,
 
92
  emb_name=emb_name,
93
  init=init_token,
94
- lr=learning_rate,
95
  steps=steps,
 
 
96
  )
97
  cmd = shlex.split(cmd)
98
 
@@ -125,3 +123,6 @@ torch.save(trained_emb, result_path)
125
  file = result_path.read_bytes()
126
  download_button.download_button(f"Download {emb_name}.pt", file, f"{emb_name}.pt")
127
  st.download_button(f"Download {emb_name}.pt ", file, f"{emb_name}.pt")
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import shlex
 
4
  import subprocess
5
  from pathlib import Path
6
+ from tempfile import TemporaryDirectory
7
  from textwrap import dedent
8
 
9
  import numpy as np
 
38
  for i in range(3):
39
  img_array[..., i] = rgb[i]
40
 
41
+ dataset_temp = TemporaryDirectory(prefix="dataset_", dir=".")
42
+ dataset_path = Path(dataset_temp.name)
43
+ output_temp = TemporaryDirectory(prefix="output_", dir=".")
44
+ output_path = Path(output_temp.name)
 
 
45
 
 
 
46
  img_path = dataset_path / f"{emb_name}.png"
47
  Image.fromarray(img_array).save(img_path)
48
 
 
69
  cmd = """
70
  accelerate launch textual_inversion.py \
71
  --pretrained_model_name_or_path={model_name} \
72
+ --train_data_dir={dataset_path} \
73
  --learnable_property="style" \
74
  --placeholder_token="{emb_name}" \
75
  --initializer_token="{init}" \
 
79
  --gradient_accumulation_steps=1 \
80
  --max_train_steps={steps} \
81
  --learning_rate={lr} \
82
+ --output_dir={output_path} \
83
  --only_save_embeds
84
  """.strip()
85
 
86
  cmd = dedent(cmd).format(
87
  model_name=model_name,
88
+ dataset_path=dataset_path.as_posix(),
89
  emb_name=emb_name,
90
  init=init_token,
 
91
  steps=steps,
92
+ lr=learning_rate,
93
+ output_path=output_path.as_posix(),
94
  )
95
  cmd = shlex.split(cmd)
96
 
 
123
  file = result_path.read_bytes()
124
  download_button.download_button(f"Download {emb_name}.pt", file, f"{emb_name}.pt")
125
  st.download_button(f"Download {emb_name}.pt ", file, f"{emb_name}.pt")
126
+
127
+ dataset_temp.cleanup()
128
+ output_temp.cleanup()