brxerq's picture
Update app.py
bb4d038 verified
raw
history blame
No virus
3.71 kB
import importlib.util
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
from PIL import Image

from common_detection import perform_detection, resize_image
def _read_labels(path):
    """Return the class names listed one per line in the label map at *path*.

    Each line is stripped of surrounding whitespace. A leading ``'???'``
    entry is dropped (presumably a background/unused placeholder — kept from
    the original behaviour).

    Raises:
        ValueError: if the file yields no labels, which would otherwise only
            surface later as an opaque IndexError.
    """
    with open(path, "r", encoding="utf-8") as f:
        labels = [line.strip() for line in f]
    # Check emptiness before indexing: labels[0] on an empty file would raise
    # an unhelpful IndexError.
    if labels and labels[0] == "???":
        del labels[0]
    if not labels:
        raise ValueError(f"Label map {path!r} contains no labels.")
    return labels


def load_model_and_labels(model_dir):
    """Load the TFLite detector and its label map from *model_dir*.

    Expects ``detect.tflite`` and ``labelmap.txt`` inside *model_dir*. Uses
    ``tflite_runtime`` when installed, otherwise TensorFlow's bundled
    interpreter.

    Returns:
        A 7-tuple ``(interpreter, labels, input_details, output_details,
        height, width, floating_model)``. ``height``/``width`` are taken from
        indices 1 and 2 of the first input tensor's shape, and
        ``floating_model`` is True when that tensor's dtype is ``np.float32``.

    Raises:
        ValueError: if the label map is empty.
    """
    # Prefer the lightweight runtime; fall back to full TensorFlow.
    if importlib.util.find_spec("tflite_runtime"):
        from tflite_runtime.interpreter import Interpreter
    else:
        from tensorflow.lite.python.interpreter import Interpreter

    path_to_ckpt = os.path.join(model_dir, "detect.tflite")
    path_to_labels = os.path.join(model_dir, "labelmap.txt")

    labels = _read_labels(path_to_labels)

    interpreter = Interpreter(model_path=path_to_ckpt)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_shape = input_details[0]["shape"]
    height = input_shape[1]
    width = input_shape[2]
    floating_model = input_details[0]["dtype"] == np.float32
    return interpreter, labels, input_details, output_details, height, width, floating_model
# Registry of the selectable detectors: display name (as offered in the
# model dropdown) -> directory that holds its detect.tflite and labelmap.txt.
models = dict(
    [
        ("Multi-class model", "model"),
        ("Empty class", "model_2"),
        ("Misalignment class", "model_3"),
    ]
)
def detect_image(model_choice, input_image):
    """Run the selected detection model on one uploaded image.

    Args:
        model_choice: A key of ``models`` (the dropdown's display name).
        input_image: The uploaded image — a PIL image from the
            ``gr.Image(type="pil")`` input (array-like input is also accepted).

    Returns:
        A PIL image built from the array returned by ``perform_detection``.

    Raises:
        gr.Error: if no model is selected or no image was uploaded, so the
            user sees a readable message instead of a KeyError or an opaque
            failure inside the resize/detection helpers.
    """
    if model_choice not in models:
        raise gr.Error("Please select a model first.")
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # The model is loaded fresh for every request; nothing is cached.
    (interpreter, labels, input_details, output_details,
     height, width, floating_model) = load_model_and_labels(models[model_choice])

    # Normalise palette / greyscale / RGBA uploads to 3-channel RGB.
    if isinstance(input_image, Image.Image):
        input_image = input_image.convert("RGB")
    image = np.array(input_image)
    # NOTE(review): this path hands RGB pixels to perform_detection, while
    # detect_video hands it BGR frames from OpenCV — confirm which channel
    # order perform_detection expects.
    resized_image = resize_image(image, size=640)
    result_image = perform_detection(
        resized_image, interpreter, labels, input_details, output_details,
        height, width, floating_model,
    )
    return Image.fromarray(result_image)
def detect_video(model_choice, input_video, fallback_fps=15.0):
    """Run the selected detection model on every frame of an uploaded video.

    Frames are read with OpenCV, resized, passed through
    ``perform_detection``, and streamed straight into an MP4 (``mp4v``)
    file. They are not buffered, so memory use does not grow with video
    length.

    Args:
        model_choice: A key of ``models`` (the dropdown's display name).
        input_video: Path of the uploaded video file.
        fallback_fps: Frame rate to use when the source does not report a
            valid one. Defaults to the 15 fps the app previously always used.

    Returns:
        Path to the annotated ``.mp4`` written in a fresh temporary directory.
        A per-call directory avoids clashes between concurrent uploads with the
        same file name.

    Raises:
        gr.Error: if no model is selected or no video was uploaded.
        ValueError: if no frames could be read from the video.
        RuntimeError: if OpenCV cannot open the output video writer.
    """
    if model_choice not in models:
        raise gr.Error("Please select a model first.")
    if not input_video:
        raise gr.Error("Please upload a video first.")

    (interpreter, labels, input_details, output_details,
     height, width, floating_model) = load_model_and_labels(models[model_choice])

    cap = cv2.VideoCapture(input_video)
    # Keep the source playback speed. Some containers report 0 (or NaN) fps;
    # `not fps > 0` catches both.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        fps = fallback_fps

    # Always emit .mp4: the mp4v codec cannot be muxed into every input
    # container (e.g. .webm).
    stem = os.path.splitext(os.path.basename(input_video))[0]
    output_video_path = os.path.join(tempfile.mkdtemp(), "result_" + stem + ".mp4")

    out = None
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # NOTE(review): OpenCV frames are BGR, whereas detect_image passes
            # RGB — confirm which order perform_detection expects.
            resized_frame = resize_image(frame, size=640)
            result_frame = perform_detection(
                resized_frame, interpreter, labels, input_details, output_details,
                height, width, floating_model,
            )
            if out is None:
                # Size the writer from the first processed frame.
                frame_height, frame_width = result_frame.shape[:2]
                out = cv2.VideoWriter(
                    output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps,
                    (frame_width, frame_height),
                )
                if not out.isOpened():
                    raise RuntimeError(
                        f"Could not open video writer for {output_video_path!r}."
                    )
            out.write(result_frame)
    finally:
        # Release native handles even if detection raises mid-video.
        cap.release()
        if out is not None:
            out.release()

    if out is None:
        raise ValueError("No frames were read from the video.")
    return output_video_path
# Gradio UI: one shared model selector plus separate image and video tabs.
with gr.Blocks() as app:
    gr.Markdown("## Object Detection using TensorFlow Lite Models")
    with gr.Row():
        # Derive choices from `models` so the two cannot drift apart. Preselect
        # the first entry; otherwise the handlers receive None when the user
        # submits before choosing a model.
        model_choice = gr.Dropdown(
            label="Select Model",
            choices=list(models),
            value=next(iter(models)),
        )
    with gr.Tab("Image Detection"):
        image_input = gr.Image(type="pil", label="Upload an image")
        image_output = gr.Image(type="pil", label="Detection Result")
        gr.Button("Submit Image").click(
            fn=detect_image, inputs=[model_choice, image_input], outputs=image_output
        )
    with gr.Tab("Video Detection"):
        video_input = gr.Video(label="Upload a video")
        video_output = gr.Video(label="Detection Result")
        gr.Button("Submit Video").click(
            fn=detect_video, inputs=[model_choice, video_input], outputs=video_output
        )
app.launch(share=True)