tori29umai committed on
Commit
dce20cd
1 Parent(s): 7f60683

Add application file

Files changed (7)
  1. app.py +98 -0
  2. config.json +57 -0
  3. requirements.txt +24 -0
  4. utils/dl_utils.py +72 -0
  5. utils/image_utils.py +64 -0
  6. utils/prompt_utils.py +28 -0
  7. utils/tagger.py +137 -0
app.py ADDED
@@ -0,0 +1,98 @@
+ import spaces
+ import gradio as gr
+ import torch
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
+ from PIL import Image
+ import os
+ import time
+
+ from utils.dl_utils import dl_cn_model, dl_cn_config, dl_lora_model
+ from utils.image_utils import resize_image_aspect_ratio, base_generation
+ from utils.prompt_utils import remove_duplicates
+
+ path = os.getcwd()
+ cn_dir = f"{path}/controlnet"
+ lora_dir = f"{path}/lora"
+ os.makedirs(cn_dir, exist_ok=True)
+ os.makedirs(lora_dir, exist_ok=True)
+
+ # Download the ControlNet weights/config and the hand-fix LoRA once at startup.
+ dl_cn_model(cn_dir)
+ dl_cn_config(cn_dir)
+ dl_lora_model(lora_dir)
+
+ def load_model(lora_dir, cn_dir):
+     dtype = torch.float16
+     vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
+     controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
+
+     pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
+         "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=dtype
+     )
+     pipe.enable_model_cpu_offload()
+     pipe.load_lora_weights(lora_dir, weight_name="Fixhands_anime_bdsqlsz_V1.safetensors")
+     return pipe
+
+ @spaces.GPU(duration=120)
+ def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
+     pipe = load_model(lora_dir, cn_dir)
+     input_image = Image.open(input_image_path)
+     # Generate onto a white canvas of the same size; the input image is used only as the ControlNet condition.
+     base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
+     resize_image = resize_image_aspect_ratio(input_image)
+     resize_base_image = resize_image_aspect_ratio(base_image)
+     generator = torch.manual_seed(0)
+     last_time = time.time()
+     prompt = "masterpiece, best quality, simple background, white background, bald, nude, " + prompt
+     prompt = remove_duplicates(prompt)
+     print(prompt)
+
+     output_image = pipe(
+         image=resize_base_image,
+         control_image=resize_image,
+         strength=1.0,
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         controlnet_conditioning_scale=float(controlnet_scale),
+         generator=generator,
+         num_inference_steps=30,
+         eta=1.0,
+     ).images[0]
+     print(f"Time taken: {time.time() - last_time}")
+     output_image = output_image.resize(input_image.size, Image.LANCZOS)
+     return output_image
+
+ class Img2Img:
+     def __init__(self):
+         self.demo = self.layout()
+         self.tagger_model = None
+         self.input_image_path = None
+         self.canny_image = None
+
+     def layout(self):
+         css = """
+         #intro{
+             max-width: 32rem;
+             text-align: center;
+             margin: 0 auto;
+         }
+         """
+         with gr.Blocks(css=css) as demo:
+             with gr.Row():
+                 with gr.Column():
+                     self.input_image_path = gr.Image(label="input_image", type='filepath')
+                     self.prompt = gr.Textbox(label="prompt", lines=3)
+                     self.negative_prompt = gr.Textbox(label="negative_prompt", lines=3, value="nsfw, nipples, bad anatomy, liquid fingers, low quality, worst quality, out of focus, ugly, error, jpeg artifacts, lowres, blurry, bokeh")
+                     self.controlnet_scale = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.01, label="Stick_fidelity")
+                     generate_button = gr.Button("generate")
+                 with gr.Column():
+                     self.output_image = gr.Image(type="pil", label="output_image")
+
+             generate_button.click(
+                 fn=predict,
+                 inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
+                 outputs=self.output_image
+             )
+         return demo
+
+ img2img = Img2Img()
+ img2img.demo.launch(share=True)
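
Note: for a quick smoke test outside the Gradio UI, predict() can also be called directly once the downloads above have finished, in the Space environment that @spaces.GPU expects. A minimal sketch; the file name, prompt text, and scale value below are illustrative placeholders only.

# Hypothetical direct call for testing; "person.png" and the parameters are placeholders.
result = predict(
    "person.png",
    "1girl, upper body",
    "nsfw, low quality, worst quality",
    1.0,
)
result.save("output.png")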
config.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "_class_name": "ControlNetModel",
+   "_diffusers_version": "0.27.2",
+   "act_fn": "silu",
+   "addition_embed_type": "text_time",
+   "addition_embed_type_num_heads": 64,
+   "addition_time_embed_dim": 256,
+   "attention_head_dim": [
+     5,
+     10,
+     20
+   ],
+   "block_out_channels": [
+     320,
+     640,
+     1280
+   ],
+   "class_embed_type": null,
+   "conditioning_channels": 3,
+   "conditioning_embedding_out_channels": [
+     16,
+     32,
+     96,
+     256
+   ],
+   "controlnet_conditioning_channel_order": "rgb",
+   "cross_attention_dim": 2048,
+   "down_block_types": [
+     "DownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "encoder_hid_dim": null,
+   "encoder_hid_dim_type": null,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "global_pool_conditions": false,
+   "in_channels": 4,
+   "layers_per_block": 2,
+   "mid_block_scale_factor": 1,
+   "mid_block_type": "UNetMidBlock2DCrossAttn",
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_attention_heads": null,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "projection_class_embeddings_input_dim": 2816,
+   "resnet_time_scale_shift": "default",
+   "transformer_layers_per_block": [
+     1,
+     2,
+     10
+   ],
+   "upcast_attention": null,
+   "use_linear_projection": true
+ }
requirements.txt ADDED
@@ -0,0 +1,24 @@
+ gradio==4.29.0
+ accelerate
+ transformers
+ torchvision
+ xformers
+ invisible-watermark
+ huggingface-hub
+ hf-transfer
+ compel
+ opencv-python
+ numpy
+ diffusers==0.27.0
+ safetensors
+ hidiffusion==0.1.8
+ spaces
+ torch==2.2
+ controlnet-aux==0.0.9
+ onnx==1.16.1
+ onnxruntime==1.18.0
+ mediapipe==0.10.14
+ peft==0.11.1
utils/dl_utils.py ADDED
@@ -0,0 +1,72 @@
+ import os
+ import shutil
+
+ import requests
+
+
+ def dl_cn_model(model_dir):
+     # Download the merged pose ControlNet weights if they are not cached yet.
+     folder = model_dir
+     file_name = 'diffusion_pytorch_model.safetensors'
+     url = "https://huggingface.co/tori29umai/CN_pose3D_V7/resolve/main/CN_pose3D_V7_marged/CN_pose3D_V7_marged.safetensors"
+     file_path = os.path.join(folder, file_name)
+     if not os.path.exists(file_path):
+         response = requests.get(url, allow_redirects=True)
+         if response.status_code == 200:
+             with open(file_path, 'wb') as f:
+                 f.write(response.content)
+             print(f'Downloaded {file_name}')
+         else:
+             print(f'Failed to download {file_name}')
+     else:
+         print(f'{file_name} already exists.')
+
+
+ def dl_cn_config(model_dir):
+     # Copy the bundled ControlNet config.json next to the downloaded weights.
+     folder = model_dir
+     file_name = 'config.json'
+     file_path = os.path.join(folder, file_name)
+     if not os.path.exists(file_path):
+         config_path = os.path.join(os.getcwd(), file_name)
+         shutil.copy(config_path, file_path)
+
+
+ def dl_tagger_model(model_dir):
+     # Download the WD14 tagger files (SmilingWolf/wd-vit-tagger-v3) that are missing locally.
+     model_id = 'SmilingWolf/wd-vit-tagger-v3'
+     files = [
+         'config.json', 'model.onnx', 'selected_tags.csv', 'sw_jax_cv_config.json'
+     ]
+
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+
+     for file in files:
+         file_path = os.path.join(model_dir, file)
+         if not os.path.exists(file_path):
+             url = f'https://huggingface.co/{model_id}/resolve/main/{file}'
+             response = requests.get(url, allow_redirects=True)
+             if response.status_code == 200:
+                 with open(file_path, 'wb') as f:
+                     f.write(response.content)
+                 print(f'Downloaded {file}')
+             else:
+                 print(f'Failed to download {file}')
+         else:
+             print(f'{file} already exists.')
+
+
+ def dl_lora_model(model_dir):
+     # Download the hand-fix LoRA used by app.py if it is not cached yet.
+     file_name = 'Fixhands_anime_bdsqlsz_V1.safetensors'
+     file_path = os.path.join(model_dir, file_name)
+     if not os.path.exists(file_path):
+         url = "https://huggingface.co/bdsqlsz/stable-diffusion-xl-anime-V5/resolve/main/Fixhands_anime_bdsqlsz_V1.safetensors"
+         response = requests.get(url, allow_redirects=True)
+         if response.status_code == 200:
+             with open(file_path, 'wb') as f:
+                 f.write(response.content)
+             print(f'Downloaded {file_name}')
+         else:
+             print(f'Failed to download {file_name}')
+     else:
+         print(f'{file_name} already exists.')
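
Note: the downloads above buffer the whole response in memory before writing, which is noticeable for multi-GB safetensors files. An optional alternative sketch (not part of this commit) using requests' streaming API with the same already-exists check; the helper name dl_file_streamed is hypothetical.

import os
import requests

def dl_file_streamed(url, file_path, chunk_size=1024 * 1024):
    # Stream the response to disk in 1 MiB chunks instead of holding it all in memory.
    if os.path.exists(file_path):
        print(f'{os.path.basename(file_path)} already exists.')
        return
    with requests.get(url, stream=True, allow_redirects=True) as response:
        response.raise_for_status()
        with open(file_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    print(f'Downloaded {os.path.basename(file_path)}')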
utils/image_utils.py ADDED
@@ -0,0 +1,64 @@
+ from PIL import Image
+ import numpy as np
+ import cv2
+
+ def canny_process(image_path, threshold1, threshold2):
+     # Open the image as RGBA to keep its transparency information.
+     img = Image.open(image_path)
+     img = img.convert("RGBA")
+
+     canvas_image = Image.new('RGBA', img.size, (255, 255, 255, 255))
+
+     # Paste the image onto a white canvas so transparent areas become white.
+     canvas_image.paste(img, (0, 0), img)
+
+     # Convert RGBA to RGB, flattening transparency to white.
+     image_pil = canvas_image.convert("RGB")
+     image_np = np.array(image_pil)
+
+     # Grayscale conversion.
+     gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
+     # Canny edge detection.
+     edges = cv2.Canny(gray, threshold1, threshold2)
+
+     canny = Image.fromarray(edges)
+
+     return canny
+
+
+ def resize_image_aspect_ratio(image):
+     # Get the original image size.
+     original_width, original_height = image.size
+
+     # Compute the aspect ratio.
+     aspect_ratio = original_width / original_height
+
+     # Define the standard aspect-ratio target sizes.
+     sizes = {
+         1: (1024, 1024),    # square
+         4/3: (1152, 896),   # landscape
+         3/2: (1216, 832),
+         16/9: (1344, 768),
+         21/9: (1568, 672),
+         3/1: (1728, 576),
+         1/4: (512, 2048),   # portrait
+         1/3: (576, 1728),
+         9/16: (768, 1344),
+         2/3: (832, 1216),
+         3/4: (896, 1152)
+     }
+
+     # Find the closest aspect-ratio bucket.
+     closest_aspect_ratio = min(sizes.keys(), key=lambda x: abs(x - aspect_ratio))
+     target_width, target_height = sizes[closest_aspect_ratio]
+
+     # Resize to the bucket resolution.
+     resized_image = image.resize((target_width, target_height), Image.LANCZOS)
+
+     return resized_image
+
+
+ def base_generation(size, color):
+     canvas = Image.new("RGBA", size, color)
+     return canvas
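
Note: a worked example of the bucket selection above (not part of this commit); the input sizes are illustrative, and the expected outputs follow directly from the sizes table.

from PIL import Image
from utils.image_utils import resize_image_aspect_ratio

landscape = Image.new("RGB", (2000, 1500))  # ratio 1.33 -> closest bucket 4/3
portrait = Image.new("RGB", (1000, 2000))   # ratio 0.50 -> closest bucket 9/16
print(resize_image_aspect_ratio(landscape).size)  # (1152, 896)
print(resize_image_aspect_ratio(portrait).size)   # (768, 1344)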
utils/prompt_utils.py ADDED
@@ -0,0 +1,28 @@
+ def remove_duplicates(base_prompt):
+     # Remove duplicate tags (case-insensitive), keeping the first occurrence.
+     prompt_list = base_prompt.split(", ")
+     seen = set()
+     unique_tags = []
+     for tag in prompt_list:
+         tag_clean = tag.lower().strip()
+         if tag_clean not in seen and tag_clean != "":
+             unique_tags.append(tag)
+             seen.add(tag_clean)
+     return ", ".join(unique_tags)
+
+
+ def remove_color(base_prompt):
+     # Remove color-related tags.
+     prompt_list = base_prompt.split(", ")
+     color_list = ["pink", "red", "orange", "brown", "yellow", "green", "blue", "purple", "blonde", "colored skin", "white hair"]
+     # Drop any tag that contains one of the color words.
+     cleaned_tags = [tag for tag in prompt_list if all(color.lower() not in tag.lower() for color in color_list)]
+     return ", ".join(cleaned_tags)
+
+
+ def execute_prompt(execute_tags, base_prompt):
+     prompt_list = base_prompt.split(", ")
+     # Remove the tags listed in execute_tags.
+     filtered_tags = [tag for tag in prompt_list if tag not in execute_tags]
+     # Build the final prompt.
+     return ", ".join(filtered_tags)
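
Note: a few worked examples of the helpers above (not part of this commit); the tag strings are illustrative and the expected outputs are shown in comments. Tags are matched on the ", "-separated pieces, so remove_color drops any tag that merely contains a color word.

from utils.prompt_utils import remove_duplicates, remove_color, execute_prompt

print(remove_duplicates("masterpiece, best quality, Masterpiece, , 1girl"))
# -> "masterpiece, best quality, 1girl"   (case-insensitive, empty tags dropped)

print(remove_color("red hair, smile, blue eyes, standing"))
# -> "smile, standing"

print(execute_prompt(["smile"], "1girl, smile, solo"))
# -> "1girl, solo"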
utils/tagger.py ADDED
@@ -0,0 +1,137 @@
+ # -*- coding: utf-8 -*-
+ # https://github.com/kohya-ss/sd-scripts/blob/main/finetune/tag_images_by_wd14_tagger.py
+
+ import csv
+ import os
+ os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
+
+ from PIL import Image
+ import cv2
+ import numpy as np
+ import onnxruntime as ort
+
+ # from wd14 tagger
+ IMAGE_SIZE = 448
+
+ model = None  # Initialize model variable
+
+
+ def convert_array_to_bgr(array):
+     """
+     Convert a NumPy array image to BGR format regardless of its original format.
+
+     Parameters:
+     - array: NumPy array of the image.
+
+     Returns:
+     - A NumPy array representing the image in BGR format.
+     """
+     # Grayscale image (2-D array)
+     if array.ndim == 2:
+         # Expand grayscale to three identical channels.
+         bgr_array = np.stack((array,) * 3, axis=-1)
+     # RGB or RGBA image (3-D array)
+     elif array.ndim == 3:
+         # Drop the alpha channel if present.
+         if array.shape[2] == 4:
+             array = array[:, :, :3]
+         # Convert RGB to BGR.
+         bgr_array = array[:, :, ::-1]
+     else:
+         raise ValueError("Unsupported array shape.")
+
+     return bgr_array
+
+
+ def preprocess_image(image):
+     image = np.array(image)
+     image = convert_array_to_bgr(image)
+
+     # Pad to a square with white, then resize to the model's input size.
+     size = max(image.shape[0:2])
+     pad_x = size - image.shape[1]
+     pad_y = size - image.shape[0]
+     pad_l = pad_x // 2
+     pad_t = pad_y // 2
+     image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
+
+     interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
+     image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
+
+     image = image.astype(np.float32)
+     return image
+
+
+ def modelLoad(model_dir):
+     onnx_path = os.path.join(model_dir, "model.onnx")
+     # Restrict execution to the CPU provider.
+     providers = ['CPUExecutionProvider']
+     # Pass the provider list when creating the InferenceSession.
+     ort_session = ort.InferenceSession(onnx_path, providers=providers)
+     input_name = ort_session.get_inputs()[0].name
+
+     # Report which provider is actually in use.
+     actual_provider = ort_session.get_providers()[0]
+     print(f"Using provider: {actual_provider}")
+
+     return [ort_session, input_name]
+
+
+ def analysis(image_path, model_dir, model):
+     ort_session = model[0]
+     input_name = model[1]
+
+     with open(os.path.join(model_dir, "selected_tags.csv"), "r", encoding="utf-8") as f:
+         reader = csv.reader(f)
+         l = [row for row in reader]
+         header = l[0]  # tag_id,name,category,count
+         rows = l[1:]
+     assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
+
+     general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
+     character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
+
+     tag_freq = {}
+     undesired_tags = ["transparent background"]
+
+     image_pil = Image.open(image_path)
+     image_preprocessed = preprocess_image(image_pil)
+     image_preprocessed = np.expand_dims(image_preprocessed, axis=0)
+
+     # Run inference.
+     prob = ort_session.run(None, {input_name: image_preprocessed})[0][0]
+     # Build the tag string.
+     combined_tags = []
+     general_tag_text = ""
+     character_tag_text = ""
+     remove_underscore = True
+     caption_separator = ", "
+     general_threshold = 0.35
+     character_threshold = 0.35
+
+     # Skip the first 4 outputs (rating tags); the rest follow the order of selected_tags.csv.
+     for i, p in enumerate(prob[4:]):
+         if i < len(general_tags) and p >= general_threshold:
+             tag_name = general_tags[i]
+             if remove_underscore and len(tag_name) > 3:  # ignore emoji tags like >_< and ^_^
+                 tag_name = tag_name.replace("_", " ")
+
+             if tag_name not in undesired_tags:
+                 tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
+                 general_tag_text += caption_separator + tag_name
+                 combined_tags.append(tag_name)
+         elif i >= len(general_tags) and p >= character_threshold:
+             tag_name = character_tags[i - len(general_tags)]
+             if remove_underscore and len(tag_name) > 3:
+                 tag_name = tag_name.replace("_", " ")
+
+             if tag_name not in undesired_tags:
+                 tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
+                 character_tag_text += caption_separator + tag_name
+                 combined_tags.append(tag_name)
+
+     # Strip the leading separator.
+     if len(general_tag_text) > 0:
+         general_tag_text = general_tag_text[len(caption_separator):]
+     if len(character_tag_text) > 0:
+         character_tag_text = character_tag_text[len(caption_separator):]
+     tag_text = caption_separator.join(combined_tags)
+     return tag_text
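
Note: tagger.py and dl_tagger_model() are included in this commit but not yet wired into app.py. A minimal sketch of how they fit together; the tagger_model directory name and input.png are illustrative placeholders.

from utils.dl_utils import dl_tagger_model
from utils.tagger import modelLoad, analysis

tagger_dir = "tagger_model"    # hypothetical local cache directory
dl_tagger_model(tagger_dir)    # fetch model.onnx and selected_tags.csv from SmilingWolf/wd-vit-tagger-v3
model = modelLoad(tagger_dir)  # returns [ort_session, input_name], CPU execution provider
tags = analysis("input.png", tagger_dir, model)
print(tags)                    # e.g. "1girl, solo, smile, ..."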