Spaces:
Runtime error
Update app.py
Browse filesPrevious
import gradio as gr
import os
import torch
from model import create_ViT
from timeit import default_timer as timer
from typing import Tuple, Dict
# Setup class names
with open("class_names.txt", "r") as f:
class_names = [food_name.strip() for food_name in f.readlines()]
# Create model
ViT_model, ViT_transforms = create_ViT(
num_classes=126,
)
# Load saved weights
ViT_model.load_state_dict(
torch.load(
f="ViT.pth",
map_location=torch.device("cpu"),
)
)
# Create predict function
def predict(img) -> Tuple[Dict, float]:
start_time = timer()
# Transform the target image and add a batch dimension
img = ViT_transforms(img).unsqueeze(0)
# Put model into evaluation mode and turn on inference mode
ViT_model.eval()
with torch.inference_mode():
# Pass the transformed image through the model and turn the prediction logits into prediction probabilities
pred_probs = torch.softmax(ViT_model(img), dim=1)
# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
# Calculate the prediction time
pred_time = round(timer() - start_time, 5)
# Return the prediction dictionary and prediction time
return pred_labels_and_probs, pred_time
##GRADIO APP
# Create title, description and article strings
title = "FoodVision🍔🍟🍦"
description = "A Vision Transformer feature extractor computer vision model to classify images of food into 126 different classes."
article = "Created by [Rohit](https://github.com/ItsNotRohit02)."
# Create examples list from "examples/" directory
example_list = [["examples/" + example] for example in os.listdir("examples")]
# Create Gradio interface
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
outputs=[
gr.Label(num_top_classes=5, label="Predictions"),
gr.Number(label="Prediction time (s)"),
],
examples=example_list,
title=title,
description=description,
article=article,
)
# Launch the app!
demo.launch()
@@ -12,40 +12,37 @@ with open("class_names.txt", "r") as f:
|
|
12 |
|
13 |
|
14 |
# Create model
|
15 |
-
|
16 |
-
num_classes=126,
|
17 |
-
)
|
18 |
|
19 |
# Load saved weights
|
20 |
-
|
21 |
torch.load(
|
22 |
-
f="
|
23 |
map_location=torch.device("cpu"),
|
24 |
)
|
25 |
)
|
26 |
|
27 |
|
28 |
-
# Create predict function
|
29 |
def predict(img) -> Tuple[Dict, float]:
|
30 |
|
31 |
start_time = timer()
|
32 |
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
-
|
37 |
-
ViT_model.eval()
|
38 |
with torch.inference_mode():
|
39 |
-
|
40 |
-
pred_probs = torch.softmax(ViT_model(img), dim=1)
|
41 |
|
42 |
-
# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
|
43 |
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
|
44 |
|
45 |
-
# Calculate the prediction time
|
46 |
pred_time = round(timer() - start_time, 5)
|
47 |
|
48 |
-
# Return the prediction dictionary and prediction time
|
49 |
return pred_labels_and_probs, pred_time
|
50 |
|
51 |
|
|
|
12 |
|
13 |
|
14 |
# Create model
|
15 |
+
model = create_ViT()
|
|
|
|
|
16 |
|
17 |
# Load saved weights
|
18 |
+
model.load_state_dict(
|
19 |
torch.load(
|
20 |
+
f="ViTHg.pth",
|
21 |
map_location=torch.device("cpu"),
|
22 |
)
|
23 |
)
|
24 |
|
25 |
|
|
|
26 |
def predict(img) -> Tuple[Dict, float]:
|
27 |
|
28 |
start_time = timer()
|
29 |
|
30 |
+
preprocess = transforms.Compose([
|
31 |
+
transforms.Resize((224, 224)),
|
32 |
+
transforms.ToTensor(),
|
33 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
34 |
+
])
|
35 |
+
|
36 |
+
img = preprocess(img).unsqueeze(0) # Add batch dimension
|
37 |
|
38 |
+
model.eval()
|
|
|
39 |
with torch.inference_mode():
|
40 |
+
pred_probs = torch.softmax(model(img), dim=1)
|
|
|
41 |
|
|
|
42 |
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
|
43 |
|
|
|
44 |
pred_time = round(timer() - start_time, 5)
|
45 |
|
|
|
46 |
return pred_labels_and_probs, pred_time
|
47 |
|
48 |
|