"""Gradio front end for TensorFlow Lite object detection on images and videos."""

import importlib.util
import os

import cv2
import gradio as gr
import numpy as np
from PIL import Image

from common_detection import perform_detection, resize_image

# Value passed as ``size`` to resize_image before detection.
DETECTION_SIZE = 640

# Frame rate used for the output video when the source does not report one.
DEFAULT_OUTPUT_FPS = 15


# Function to load the TensorFlow Lite model and labels
def load_model_and_labels(model_dir):
    """Load the TFLite interpreter and label list stored in *model_dir*.

    The directory must contain ``detect.tflite`` and ``labelmap.txt``.

    Args:
        model_dir: Path of the directory holding the model and label files.

    Returns:
        A tuple ``(interpreter, labels, input_details, output_details,
        height, width, floating_model)`` where ``height`` and ``width`` are
        read from axes 1 and 2 of the first input tensor's shape and
        ``floating_model`` is True when that tensor's dtype is ``np.float32``.

    Raises:
        OSError: If the label file cannot be opened.
    """
    # Use the standalone tflite_runtime package when it is installed,
    # otherwise fall back to the interpreter bundled with TensorFlow.
    if importlib.util.find_spec('tflite_runtime'):
        from tflite_runtime.interpreter import Interpreter
    else:
        from tensorflow.lite.python.interpreter import Interpreter

    model_path = os.path.join(model_dir, 'detect.tflite')
    labels_path = os.path.join(model_dir, 'labelmap.txt')

    with open(labels_path, 'r') as f:
        labels = [line.strip() for line in f]
    # Drop the '???' placeholder entry when present; the emptiness check
    # keeps an empty label file from raising IndexError.
    if labels and labels[0] == '???':
        del labels[0]

    interpreter = Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    input_shape = input_details[0]['shape']
    height = input_shape[1]
    width = input_shape[2]
    floating_model = (input_details[0]['dtype'] == np.float32)

    return interpreter, labels, input_details, output_details, height, width, floating_model


# Load models
# Maps the name shown in the UI to the directory holding that model's files.
models = {
    "Multi-class model": "model",
    "Empty class": "model_2",
    "Misalignment class": "model_3",
}


def _model_dir_for(model_choice):
    """Return the model directory for *model_choice*, or raise ``gr.Error``."""
    try:
        return models[model_choice]
    except KeyError:
        raise gr.Error(
            f"Unknown model {model_choice!r}. Please select a model."
        ) from None


# Function to perform image detection
def detect_image(model_choice, input_image):
    """Run detection on a single image and return the annotated result.

    Args:
        model_choice: A key of ``models`` selecting which model to load.
        input_image: The uploaded image (the UI supplies a PIL image).

    Returns:
        A ``PIL.Image.Image`` built from the array returned by
        ``perform_detection``.

    Raises:
        gr.Error: If no image was supplied or the model name is unknown.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before submitting.")

    # NOTE(review): the model is reloaded on every request. Caching would be
    # faster but would share one interpreter between requests — confirm that
    # is safe before adding it.
    model = load_model_and_labels(_model_dir_for(model_choice))

    image = np.array(input_image)
    resized_image = resize_image(image, size=DETECTION_SIZE)
    # The tuple from load_model_and_labels is in the same order as the
    # trailing positional parameters passed to perform_detection.
    result_image = perform_detection(resized_image, *model)
    return Image.fromarray(result_image)


# Function to perform video detection
def detect_video(model_choice, input_video):
    """Run detection on every frame of a video and write an annotated copy.

    Frames are written to the output file as they are processed rather than
    being collected in memory first. The output uses the frame rate reported
    by the source video, falling back to ``DEFAULT_OUTPUT_FPS`` when the
    source reports none.

    Args:
        model_choice: A key of ``models`` selecting which model to load.
        input_video: Path of the uploaded video file.

    Returns:
        The path of the written video, ``"result_" + basename(input_video)``,
        relative to the current working directory.

    Raises:
        gr.Error: If no video was supplied or the model name is unknown.
        ValueError: If no frames could be read from the video.
        RuntimeError: If the output video file could not be opened.
    """
    if not input_video:
        raise gr.Error("Please upload a video before submitting.")

    model = load_model_and_labels(_model_dir_for(model_choice))
    output_video_path = "result_" + os.path.basename(input_video)

    cap = cv2.VideoCapture(input_video)
    writer = None
    try:
        # The comparison also rejects NaN, so any unusable value falls back.
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        fps = source_fps if source_fps > 0 else DEFAULT_OUTPUT_FPS

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            resized_frame = resize_image(frame, size=DETECTION_SIZE)
            result_frame = perform_detection(resized_frame, *model)

            if writer is None:
                # The output size is only known once the first annotated
                # frame exists, so the writer is created lazily.
                out_height, out_width = result_frame.shape[:2]
                writer = cv2.VideoWriter(
                    output_video_path,
                    cv2.VideoWriter_fourcc(*'mp4v'),
                    fps,
                    (out_width, out_height),
                )
                if not writer.isOpened():
                    raise RuntimeError(
                        f"Could not open output video for writing: {output_video_path}"
                    )
            writer.write(result_frame)
    finally:
        # Release both handles even when detection fails part-way through.
        cap.release()
        if writer is not None:
            writer.release()

    if writer is None:
        raise ValueError("No frames were read from the video.")

    return output_video_path


app = gr.Blocks()

with app:
    gr.Markdown("## Object Detection using TensorFlow Lite Models")
    with gr.Row():
        # Choices come from ``models`` so the two cannot drift apart; the
        # default value keeps the handlers from receiving None.
        model_choice = gr.Dropdown(
            label="Select Model",
            choices=list(models),
            value=next(iter(models)),
        )
    with gr.Tab("Image Detection"):
        image_input = gr.Image(type="pil", label="Upload an image")
        image_output = gr.Image(type="pil", label="Detection Result")
        gr.Button("Submit Image").click(
            fn=detect_image,
            inputs=[model_choice, image_input],
            outputs=image_output,
        )
    with gr.Tab("Video Detection"):
        video_input = gr.Video(label="Upload a video")
        video_output = gr.Video(label="Detection Result")
        gr.Button("Submit Video").click(
            fn=detect_video,
            inputs=[model_choice, video_input],
            outputs=video_output,
        )

app.launch(share=True)