Tirath5504 commited on
Commit
9a88ef0
1 Parent(s): 62e71de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -14,14 +14,21 @@ def predict(image):
14
  with torch.no_grad():
15
  outputs = model(**inputs).logits
16
 
17
- predicted_class_idx = outputs.argmax(-1).item()
 
 
 
 
 
 
 
18
  predicted_class = class_names[predicted_class_idx]
19
 
20
- return predicted_class
21
 
22
  iface = gr.Interface(fn=predict,
23
  inputs=gr.Image(type="pil"),
24
- outputs=gr.Label(num_top_classes=1),
25
  title="Hateful Content Detection",
26
  description="Upload an image to classify hateful gestures or symbols")
27
 
 
14
  with torch.no_grad():
15
  outputs = model(**inputs).logits
16
 
17
+ # predicted_class_idx = outputs.argmax(-1).item()
18
+ # predicted_class = class_names[predicted_class_idx]
19
+
20
+ # return predicted_class
21
+
22
+ probabilities = F.softmax(outputs, dim=1)
23
+ confidence_score = probabilities[0][predicted_class_idx].item()
24
+ predicted_class_idx = probabilities.argmax(-1).item()
25
  predicted_class = class_names[predicted_class_idx]
26
 
27
+ return predicted_class, confidence_score
28
 
29
  iface = gr.Interface(fn=predict,
30
  inputs=gr.Image(type="pil"),
31
+ outputs=[gr.Label(num_top_classes=1),gr.Label()]
32
  title="Hateful Content Detection",
33
  description="Upload an image to classify hateful gestures or symbols")
34