brxerq's picture
Update app.py
15b8b99 verified
raw
history blame contribute delete
No virus
2.67 kB
import importlib
import gradio as gr
from PIL import Image
import cv2
import os
# Bundled sample media, keyed by the label offered in the "Or select a
# sample ..." dropdowns. The paths are relative, so they resolve against the
# process's current working directory.
sample_images = dict(
    [
        ("Unorganized", "samples/unorganized.jpg"),
        ("Organized", "samples/organized.jpg"),
    ]
)
sample_videos = dict([("Sample Video", "samples/sample_video.mp4")])
def load_model(model_name):
    """Import and return the Python module named *model_name*.

    ``importlib.import_module`` caches modules in ``sys.modules``, so calling
    this repeatedly with the same name hands back the same module object
    rather than re-executing the module.
    """
    return importlib.import_module(model_name)
# Maps each label offered in the "Select Model" dropdown to the name of the
# module that implements it; the module is imported on demand via load_model.
_MODEL_ENTRIES = (
    ("Multi-class model", "model_1"),
    ("Empty class", "model_2"),
    ("Misalignment class", "model_3"),
)
models = {label: module_name for label, module_name in _MODEL_ENTRIES}
def detect_image(model_choice, input_image=None, sample_image_choice=None):
    """Run the selected model's image detection on an uploaded or sample image.

    Args:
        model_choice: Key of ``models`` naming the model module to use.
        input_image: PIL image uploaded by the user, or ``None``.
        sample_image_choice: Key of ``sample_images``. When truthy it takes
            precedence over ``input_image``.

    Returns:
        Whatever the chosen model module's ``detect_image`` returns.

    Raises:
        gr.Error: if neither an uploaded image nor a sample image was given.
    """
    if sample_image_choice:
        # Image.open is lazy and keeps the file open. Copy the pixels into
        # memory inside a ``with`` so the handle is closed right away.
        with Image.open(sample_images[sample_image_choice]) as sample:
            input_image = sample.copy()
    if input_image is None:
        # Show a readable message in the UI instead of letting the model
        # fail on a None input.
        raise gr.Error("Please upload an image or select a sample image.")
    model = load_model(models[model_choice])
    return model.detect_image(input_image)
def detect_video(model_choice, input_video=None, sample_video_choice=None):
    """Run the selected model's video detection on an uploaded or sample video.

    Args:
        model_choice: Key of ``models`` naming the model module to use.
        input_video: Value from the ``gr.Video`` input (presumably a file
            path), or ``None`` when nothing was uploaded.
        sample_video_choice: Key of ``sample_videos``. When truthy it takes
            precedence over ``input_video``.

    Returns:
        Whatever the chosen model module's ``detect_video`` returns. It is
        shown in a ``gr.Video`` output, so presumably it is a path to the
        rendered file.

    Raises:
        gr.Error: if neither an uploaded video nor a sample video was given.
    """
    if sample_video_choice:
        input_video = sample_videos[sample_video_choice]
    if not input_video:
        # Show a readable message in the UI instead of letting the model
        # fail on a missing input.
        raise gr.Error("Please upload a video or select a sample video.")
    model = load_model(models[model_choice])
    return model.detect_video(input_video)
# Build the UI: a shared model selector plus one tab each for image and video
# detection. Each tab's Submit button calls the matching detect_* handler.
with gr.Blocks() as app:
    gr.Markdown("## Object Detection using TensorFlow Lite Models")

    with gr.Row():
        model_choice = gr.Dropdown(
            label="Select Model",
            choices=list(models.keys()),
            value="Multi-class model",
        )

    with gr.Tab("Image Detection"):
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an image (or use sample)")
            # The leading None entry means "no sample": detect_image then
            # falls back to the uploaded image.
            sample_image_choice = gr.Dropdown(
                label="Or select a sample image",
                choices=[None] + list(sample_images.keys()),
                value=None,
            )
            image_output = gr.Image(type="pil", label="Detection Result")
            gr.Button("Submit Image").click(
                fn=detect_image,
                inputs=[model_choice, image_input, sample_image_choice],
                outputs=image_output,
            )

    with gr.Tab("Video Detection"):
        with gr.Column():
            video_input = gr.Video(label="Upload a video (or use sample)")
            # The leading None entry means "no sample": detect_video then
            # falls back to the uploaded video.
            sample_video_choice = gr.Dropdown(
                label="Or select a sample video",
                choices=[None] + list(sample_videos.keys()),
                value=None,
            )
            video_output = gr.Video(label="Detection Result")
            gr.Button("Submit Video").click(
                fn=detect_video,
                inputs=[model_choice, video_input, sample_video_choice],
                outputs=video_output,
            )

# Launch only when run as a script (e.g. `python app.py`). Importing this
# module, or loading it through Gradio's reload mode, then does not start a
# second server as a side effect.
if __name__ == "__main__":
    app.launch(share=True)