brxerq commited on
Commit
20ad1d4
1 Parent(s): 7623bfa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -29
app.py CHANGED
@@ -5,7 +5,10 @@ import importlib.util
5
  import gradio as gr
6
  from PIL import Image
7
 
8
- # Function to load the TensorFlow Lite model and labels
 
 
 
9
  def load_model(model_dir):
10
  GRAPH_NAME = 'detect.tflite'
11
  LABELMAP_NAME = 'labelmap.txt'
@@ -30,18 +33,18 @@ def load_model(model_dir):
30
  width = input_details[0]['shape'][2]
31
  floating_model = (input_details[0]['dtype'] == np.float32)
32
 
33
- return interpreter, labels, input_details, output_details, height, width, floating_model
 
 
 
 
34
 
35
- MODEL_DIRS = {
36
- "Multi-class model": 'model',
37
- "Empty class": 'model_2',
38
- "Misalignment class": 'model_3'
39
- }
40
 
41
- input_mean = 127.5
42
- input_std = 127.5
43
 
44
- def perform_detection(image, interpreter, labels, input_details, output_details, height, width, floating_model):
45
  imH, imW, _ = image.shape
46
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
47
  image_resized = cv2.resize(image_rgb, (width, height))
@@ -53,9 +56,9 @@ def perform_detection(image, interpreter, labels, input_details, output_details,
53
  interpreter.set_tensor(input_details[0]['index'], input_data)
54
  interpreter.invoke()
55
 
56
- boxes = interpreter.get_tensor(output_details[0]['index'])[0]
57
- classes = interpreter.get_tensor(output_details[1]['index'])[0]
58
- scores = interpreter.get_tensor(output_details[2]['index'])[0]
59
 
60
  detections = []
61
  for i in range(len(scores)):
@@ -79,16 +82,19 @@ def perform_detection(image, interpreter, labels, input_details, output_details,
79
  def resize_image(image, size=640):
80
  return cv2.resize(image, (size, size))
81
 
82
- def detect_image(input_image, model_selection):
83
- interpreter, labels, input_details, output_details, height, width, floating_model = load_model(MODEL_DIRS[model_selection])
 
 
84
  image = np.array(input_image)
85
  resized_image = resize_image(image, size=640) # Resize input image
86
- result_image = perform_detection(resized_image, interpreter, labels, input_details, output_details, height, width, floating_model)
87
  return Image.fromarray(result_image)
88
 
89
- def detect_video(input_video, model_selection):
90
- interpreter, labels, input_details, output_details, height, width, floating_model = load_model(MODEL_DIRS[model_selection])
91
-
 
92
  cap = cv2.VideoCapture(input_video)
93
  frames = []
94
 
@@ -98,7 +104,7 @@ def detect_video(input_video, model_selection):
98
  break
99
 
100
  resized_frame = resize_image(frame, size=640) # Resize each frame
101
- result_frame = perform_detection(resized_frame, interpreter, labels, input_details, output_details, height, width, floating_model)
102
  frames.append(result_frame)
103
 
104
  cap.release()
@@ -118,21 +124,29 @@ def detect_video(input_video, model_selection):
118
 
119
  return output_video_path
120
 
121
- app = gr.Blocks()
 
 
 
 
 
 
 
 
122
 
123
  with app:
124
  with gr.Tab("Image Detection"):
125
  gr.Markdown("Upload an image for object detection")
126
- model_selection = gr.Dropdown(choices=["Multi-class model", "Empty class", "Misalignment class"], label="Select Model")
127
- image_input = gr.Image(type="pil", label="Upload an image")
128
- image_output = gr.Image(type="pil", label="Detection Result")
129
- gr.Button("Submit").click(fn=detect_image, inputs=[image_input, model_selection], outputs=image_output)
130
 
131
  with gr.Tab("Video Detection"):
132
  gr.Markdown("Upload a video for object detection")
133
- model_selection = gr.Dropdown(choices=["Multi-class model", "Empty class", "Misalignment class"], label="Select Model")
134
- video_input = gr.Video(label="Upload a video")
135
- video_output = gr.Video(label="Detection Result")
136
- gr.Button("Submit").click(fn=detect_video, inputs=[video_input, model_selection], outputs=video_output)
137
 
138
  app.launch()
 
5
  import gradio as gr
6
  from PIL import Image
7
 
8
+ # Load the TensorFlow Lite models
9
+ MODEL_DIRS = ['model', 'model_2', 'model_3']
10
+ MODEL_NAMES = ['Multi-class model', 'One Empty class', 'Misalignment class']
11
+
12
  def load_model(model_dir):
13
  GRAPH_NAME = 'detect.tflite'
14
  LABELMAP_NAME = 'labelmap.txt'
 
33
  width = input_details[0]['shape'][2]
34
  floating_model = (input_details[0]['dtype'] == np.float32)
35
 
36
+ outname = output_details[0]['name']
37
+ if ('StatefulPartitionedCall' in outname):
38
+ boxes_idx, classes_idx, scores_idx = 1, 3, 0
39
+ else:
40
+ boxes_idx, classes_idx, scores_idx = 0, 1, 2
41
 
42
+ return interpreter, labels, input_details, output_details, height, width, floating_model, boxes_idx, classes_idx, scores_idx
 
 
 
 
43
 
44
+ # Load default model
45
+ interpreter, labels, input_details, output_details, height, width, floating_model, boxes_idx, classes_idx, scores_idx = load_model(MODEL_DIRS[0])
46
 
47
+ def perform_detection(image, interpreter, labels):
48
  imH, imW, _ = image.shape
49
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
50
  image_resized = cv2.resize(image_rgb, (width, height))
 
56
  interpreter.set_tensor(input_details[0]['index'], input_data)
57
  interpreter.invoke()
58
 
59
+ boxes = interpreter.get_tensor(output_details[boxes_idx]['index'])[0]
60
+ classes = interpreter.get_tensor(output_details[classes_idx]['index'])[0]
61
+ scores = interpreter.get_tensor(output_details[scores_idx]['index'])[0]
62
 
63
  detections = []
64
  for i in range(len(scores)):
 
82
  def resize_image(image, size=640):
83
  return cv2.resize(image, (size, size))
84
 
85
+ def detect_image(input_image, model_index=0):
86
+ global interpreter, labels, input_details, output_details, height, width, floating_model, boxes_idx, classes_idx, scores_idx
87
+ interpreter, labels, input_details, output_details, height, width, floating_model, boxes_idx, classes_idx, scores_idx = load_model(MODEL_DIRS[model_index])
88
+
89
  image = np.array(input_image)
90
  resized_image = resize_image(image, size=640) # Resize input image
91
+ result_image = perform_detection(resized_image, interpreter, labels)
92
  return Image.fromarray(result_image)
93
 
94
+ def detect_video(input_video, model_index=0):
95
+ global interpreter, labels, input_details, output_details, height, width, floating_model, boxes_idx, classes_idx, scores_idx
96
+ interpreter, labels, input_details, output_details, height, width, floating_model, boxes_idx, classes_idx, scores_idx = load_model(MODEL_DIRS[model_index])
97
+
98
  cap = cv2.VideoCapture(input_video)
99
  frames = []
100
 
 
104
  break
105
 
106
  resized_frame = resize_image(frame, size=640) # Resize each frame
107
+ result_frame = perform_detection(resized_frame, interpreter, labels)
108
  frames.append(result_frame)
109
 
110
  cap.release()
 
124
 
125
  return output_video_path
126
 
127
+ app = gr.Interface(
128
+ fn=None,
129
+ inputs=None,
130
+ outputs=None,
131
+ title="Object Detection",
132
+ description="Detect objects in images and videos.",
133
+ layout="blocks",
134
+ theme="compact",
135
+ )
136
 
137
  with app:
138
  with gr.Tab("Image Detection"):
139
  gr.Markdown("Upload an image for object detection")
140
+ image_input = gr.inputs.Image(type="pil", label="Upload an image")
141
+ image_output = gr.outputs.Image(type="pil", label="Detection Result")
142
+ model_dropdown = gr.inputs.Dropdown(choices=MODEL_NAMES, label="Select Model")
143
+ gr.Button("Submit").on_click(detect_image, inputs=[image_input, model_dropdown], outputs=image_output)
144
 
145
  with gr.Tab("Video Detection"):
146
  gr.Markdown("Upload a video for object detection")
147
+ video_input = gr.inputs.Video(label="Upload a video")
148
+ video_output = gr.outputs.Video(label="Detection Result")
149
+ model_dropdown = gr.inputs.Dropdown(choices=MODEL_NAMES, label="Select Model")
150
+ gr.Button("Submit").on_click(detect_video, inputs=[video_input, model_dropdown], outputs=video_output)
151
 
152
  app.launch()