ItsNotRohit commited on
Commit
7cbc13f
1 Parent(s): c903d33

Update app.py

Browse files

Previous


import gradio as gr
import os
import torch

from model import create_ViT
from timeit import default_timer as timer
from typing import Tuple, Dict

# Setup class names
with open("class_names.txt", "r") as f:
class_names = [food_name.strip() for food_name in f.readlines()]


# Create model
ViT_model, ViT_transforms = create_ViT(
num_classes=126,
)

# Load saved weights
ViT_model.load_state_dict(
torch.load(
f="ViT.pth",
map_location=torch.device("cpu"),
)
)


# Create predict function
def predict(img) -> Tuple[Dict, float]:

start_time = timer()

# Transform the target image and add a batch dimension
img = ViT_transforms(img).unsqueeze(0)

# Put model into evaluation mode and turn on inference mode
ViT_model.eval()
with torch.inference_mode():
# Pass the transformed image through the model and turn the prediction logits into prediction probabilities
pred_probs = torch.softmax(ViT_model(img), dim=1)

# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}

# Calculate the prediction time
pred_time = round(timer() - start_time, 5)

# Return the prediction dictionary and prediction time
return pred_labels_and_probs, pred_time


##GRADIO APP
# Create title, description and article strings
title = "FoodVision🍔🍟🍦"
description = "A Vision Transformer feature extractor computer vision model to classify images of food into 126 different classes."
article = "Created by [Rohit](https://github.com/ItsNotRohit02)."

# Create examples list from "examples/" directory
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Create Gradio interface
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
outputs=[
gr.Label(num_top_classes=5, label="Predictions"),
gr.Number(label="Prediction time (s)"),
],
examples=example_list,
title=title,
description=description,
article=article,
)

# Launch the app!
demo.launch()

Files changed (1) hide show
  1. app.py +12 -15
app.py CHANGED
@@ -12,40 +12,37 @@ with open("class_names.txt", "r") as f:
12
 
13
 
14
  # Create model
15
- ViT_model, ViT_transforms = create_ViT(
16
- num_classes=126,
17
- )
18
 
19
  # Load saved weights
20
- ViT_model.load_state_dict(
21
  torch.load(
22
- f="ViT.pth",
23
  map_location=torch.device("cpu"),
24
  )
25
  )
26
 
27
 
28
- # Create predict function
29
  def predict(img) -> Tuple[Dict, float]:
30
 
31
  start_time = timer()
32
 
33
- # Transform the target image and add a batch dimension
34
- img = ViT_transforms(img).unsqueeze(0)
 
 
 
 
 
35
 
36
- # Put model into evaluation mode and turn on inference mode
37
- ViT_model.eval()
38
  with torch.inference_mode():
39
- # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
40
- pred_probs = torch.softmax(ViT_model(img), dim=1)
41
 
42
- # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
43
  pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
44
 
45
- # Calculate the prediction time
46
  pred_time = round(timer() - start_time, 5)
47
 
48
- # Return the prediction dictionary and prediction time
49
  return pred_labels_and_probs, pred_time
50
 
51
 
 
12
 
13
 
14
  # Create model
15
+ model = create_ViT()
 
 
16
 
17
  # Load saved weights
18
+ model.load_state_dict(
19
  torch.load(
20
+ f="ViTHg.pth",
21
  map_location=torch.device("cpu"),
22
  )
23
  )
24
 
25
 
 
26
  def predict(img) -> Tuple[Dict, float]:
27
 
28
  start_time = timer()
29
 
30
+ preprocess = transforms.Compose([
31
+ transforms.Resize((224, 224)),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
34
+ ])
35
+
36
+ img = preprocess(img).unsqueeze(0) # Add batch dimension
37
 
38
+ model.eval()
 
39
  with torch.inference_mode():
40
+ pred_probs = torch.softmax(model(img), dim=1)
 
41
 
 
42
  pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
43
 
 
44
  pred_time = round(timer() - start_time, 5)
45
 
 
46
  return pred_labels_and_probs, pred_time
47
 
48