tori29umai commited on
Commit
8c411d6
1 Parent(s): dce20cd

Delete utils/tagger.py

Browse files
Files changed (1) hide show
  1. utils/tagger.py +0 -137
utils/tagger.py DELETED
@@ -1,137 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # https://github.com/kohya-ss/sd-scripts/blob/main/finetune/tag_images_by_wd14_tagger.py
3
-
4
- import csv
5
- import os
6
- os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
7
-
8
- from PIL import Image
9
- import cv2
10
- import numpy as np
11
- from pathlib import Path
12
- import onnx
13
- import onnxruntime as ort
14
-
15
- # from wd14 tagger
16
- IMAGE_SIZE = 448
17
-
18
- model = None # Initialize model variable
19
-
20
-
21
- def convert_array_to_bgr(array):
22
- """
23
- Convert a NumPy array image to BGR format regardless of its original format.
24
-
25
- Parameters:
26
- - array: NumPy array of the image.
27
-
28
- Returns:
29
- - A NumPy array representing the image in BGR format.
30
- """
31
- # グレースケール画像(2次元配列)
32
- if array.ndim == 2:
33
- # グレースケールをBGRに変換(3チャンネルに拡張)
34
- bgr_array = np.stack((array,) * 3, axis=-1)
35
- # RGBAまたはRGB画像(3次元配列)
36
- elif array.ndim == 3:
37
- # RGBA画像の場合、アルファチャンネルを削除
38
- if array.shape[2] == 4:
39
- array = array[:, :, :3]
40
- # RGBをBGRに変換
41
- bgr_array = array[:, :, ::-1]
42
- else:
43
- raise ValueError("Unsupported array shape.")
44
-
45
- return bgr_array
46
-
47
-
48
- def preprocess_image(image):
49
- image = np.array(image)
50
- image = convert_array_to_bgr(image)
51
-
52
- size = max(image.shape[0:2])
53
- pad_x = size - image.shape[1]
54
- pad_y = size - image.shape[0]
55
- pad_l = pad_x // 2
56
- pad_t = pad_y // 2
57
- image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
58
-
59
- interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
60
- image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
61
-
62
- image = image.astype(np.float32)
63
- return image
64
-
65
- def modelLoad(model_dir):
66
- onnx_path = os.path.join(model_dir, "model.onnx")
67
- # 実行プロバイダーをCPUのみに指定
68
- providers = ['CPUExecutionProvider']
69
- # InferenceSessionの作成時にプロバイダーのリストを指定
70
- ort_session = ort.InferenceSession(onnx_path, providers=providers)
71
- input_name = ort_session.get_inputs()[0].name
72
-
73
- # 実際に使用されているプロバイダーを取得して表示
74
- actual_provider = ort_session.get_providers()[0] # 使用されているプロバイダー
75
- print(f"Using provider: {actual_provider}")
76
-
77
- return [ort_session, input_name]
78
-
79
- def analysis(image_path, model_dir, model):
80
- ort_session = model[0]
81
- input_name = model[1]
82
-
83
- with open(os.path.join(model_dir, "selected_tags.csv"), "r", encoding="utf-8") as f:
84
- reader = csv.reader(f)
85
- l = [row for row in reader]
86
- header = l[0] # tag_id,name,category,count
87
- rows = l[1:]
88
- assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
89
-
90
- general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
91
- character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
92
-
93
- tag_freq = {}
94
- undesired_tags = ["transparent background"]
95
-
96
- image_pil = Image.open(image_path)
97
- image_preprocessed = preprocess_image(image_pil)
98
- image_preprocessed = np.expand_dims(image_preprocessed, axis=0)
99
-
100
- # 推論を実行
101
- prob = ort_session.run(None, {input_name: image_preprocessed})[0][0]
102
- # タグを生成
103
- combined_tags = []
104
- general_tag_text = ""
105
- character_tag_text = ""
106
- remove_underscore = True
107
- caption_separator = ", "
108
- general_threshold = 0.35
109
- character_threshold = 0.35
110
-
111
- for i, p in enumerate(prob[4:]):
112
- if i < len(general_tags) and p >= general_threshold:
113
- tag_name = general_tags[i]
114
- if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
115
- tag_name = tag_name.replace("_", " ")
116
-
117
- if tag_name not in undesired_tags:
118
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
119
- general_tag_text += caption_separator + tag_name
120
- combined_tags.append(tag_name)
121
- elif i >= len(general_tags) and p >= character_threshold:
122
- tag_name = character_tags[i - len(general_tags)]
123
- if remove_underscore and len(tag_name) > 3:
124
- tag_name = tag_name.replace("_", " ")
125
-
126
- if tag_name not in undesired_tags:
127
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
128
- character_tag_text += caption_separator + tag_name
129
- combined_tags.append(tag_name)
130
-
131
- # 先頭のカンマを取る
132
- if len(general_tag_text) > 0:
133
- general_tag_text = general_tag_text[len(caption_separator) :]
134
- if len(character_tag_text) > 0:
135
- character_tag_text = character_tag_text[len(caption_separator) :]
136
- tag_text = caption_separator.join(combined_tags)
137
- return tag_text