drhead committed on
Commit
69fc921
1 Parent(s): 3b4f406

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -19
app.py CHANGED
@@ -137,7 +137,7 @@ class GatedHead(torch.nn.Module):
137
 
138
  model.head = GatedHead(min(model.head.weight.shape), 9083)
139
 
140
- safetensors.torch.load_model(model, "JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors")
141
  model.eval()
142
 
143
  with open("tagger_tags.json", "r") as file:
@@ -147,26 +147,36 @@ allowed_tags = list(tags.keys())
147
  for idx, tag in enumerate(allowed_tags):
148
  allowed_tags[idx] = tag.replace("_", " ")
149
 
 
 
150
  @spaces.GPU(duration=5)
151
- def create_tags(image, threshold):
 
152
  img = image.convert('RGB')
153
  tensor = transform(img).unsqueeze(0)
154
 
155
  with torch.no_grad():
156
- probits = model(tensor).squeeze()
157
- indices = torch.where(probits > threshold)[0]
158
- values = probits[indices]
159
 
160
- temp = []
161
  tag_score = dict()
162
  for i in range(indices.size(0)):
163
- temp.append([allowed_tags[indices[i]], values[i].item()])
164
  tag_score[allowed_tags[indices[i]]] = values[i].item()
165
- temp = [t[0] for t in temp]
166
- text_no_impl = ", ".join(temp)
167
- return text_no_impl, tag_score
 
 
 
 
 
 
 
 
 
 
 
168
 
169
- with gr.Blocks() as demo:
170
  gr.Markdown("""
171
  ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
172
  This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
@@ -175,14 +185,30 @@ with gr.Blocks() as demo:
175
 
176
  Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
177
  """)
178
- gr.Interface(
179
- create_tags,
180
- inputs=[gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'), gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")],
181
- outputs=[
182
- gr.Textbox(label="Tag String"),
183
- gr.Label(label="Tag Predictions", num_top_classes=200),
184
- ],
185
- allow_flagging="never",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  )
187
 
188
  if __name__ == "__main__":
 
137
 
138
  model.head = GatedHead(min(model.head.weight.shape), 9083)
139
 
140
+ safetensors.torch.load_model(model, "JTP_PILOT2-2-e3-vit_so400m_patch14_siglip_384.safetensors")
141
  model.eval()
142
 
143
  with open("tagger_tags.json", "r") as file:
 
147
  for idx, tag in enumerate(allowed_tags):
148
  allowed_tags[idx] = tag.replace("_", " ")
149
 
150
+ sorted_tag_score = {}
151
+
152
  @spaces.GPU(duration=5)
153
+ def run_classifier(image, threshold):
154
+ global sorted_tag_score
155
  img = image.convert('RGB')
156
  tensor = transform(img).unsqueeze(0)
157
 
158
  with torch.no_grad():
159
+ values, indices = torch.topk(model(tensor)[0], 250)
 
 
160
 
 
161
  tag_score = dict()
162
  for i in range(indices.size(0)):
 
163
  tag_score[allowed_tags[indices[i]]] = values[i].item()
164
+ sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
165
+
166
+ return create_tags(threshold)
167
+
168
+ def create_tags(threshold):
169
+ global sorted_tag_score
170
+ filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
171
+ text_no_impl = ", ".join(filtered_tag_score.keys())
172
+ return text_no_impl, filtered_tag_score
173
+
174
+ def clear_image():
175
+ global sorted_tag_score
176
+ sorted_tag_score = {}
177
+ return "", {}
178
 
179
+ with gr.Blocks(css=".output-class { display: none; }") as demo:
180
  gr.Markdown("""
181
  ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
182
  This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
 
185
 
186
  Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
187
  """)
188
+ with gr.Row():
189
+ with gr.Column():
190
+ image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
191
+ threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
192
+ with gr.Column():
193
+ tag_string = gr.Textbox(label="Tag String")
194
+ label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)
195
+
196
+ image_input.upload(
197
+ fn=run_classifier,
198
+ inputs=[image_input, threshold_slider],
199
+ outputs=[tag_string, label_box]
200
+ )
201
+
202
+ image_input.clear(
203
+ fn=clear_image,
204
+ inputs=[],
205
+ outputs=[tag_string, label_box]
206
+ )
207
+
208
+ threshold_slider.input(
209
+ fn=create_tags,
210
+ inputs=[threshold_slider],
211
+ outputs=[tag_string, label_box]
212
  )
213
 
214
  if __name__ == "__main__":