brxerq commited on
Commit
38d944b
1 Parent(s): 7c3c7c7

Update model_1.py

Browse files
Files changed (1) hide show
  1. model_1.py +74 -7
model_1.py CHANGED
@@ -1,12 +1,11 @@
1
- # model_1.py
2
  import os
3
  import cv2
4
  import numpy as np
5
  import importlib.util
6
- from PIL import Image
7
  import gradio as gr
8
- from common_detection import perform_detection
9
 
 
10
  MODEL_DIR = 'model'
11
  GRAPH_NAME = 'detect.tflite'
12
  LABELMAP_NAME = 'labelmap.txt'
@@ -14,18 +13,22 @@ LABELMAP_NAME = 'labelmap.txt'
14
  pkg = importlib.util.find_spec('tflite_runtime')
15
  if pkg:
16
  from tflite_runtime.interpreter import Interpreter
 
17
  else:
18
  from tensorflow.lite.python.interpreter import Interpreter
 
19
 
20
  PATH_TO_CKPT = os.path.join(MODEL_DIR, GRAPH_NAME)
21
  PATH_TO_LABELS = os.path.join(MODEL_DIR, LABELMAP_NAME)
22
 
 
23
  with open(PATH_TO_LABELS, 'r') as f:
24
  labels = [line.strip() for line in f.readlines()]
25
 
26
  if labels[0] == '???':
27
  del(labels[0])
28
 
 
29
  interpreter = Interpreter(model_path=PATH_TO_CKPT)
30
  interpreter.allocate_tensors()
31
 
@@ -35,10 +38,57 @@ height = input_details[0]['shape'][1]
35
  width = input_details[0]['shape'][2]
36
  floating_model = (input_details[0]['dtype'] == np.float32)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def detect_image(input_image):
39
  image = np.array(input_image)
40
- resized_image = cv2.resize(image, (640, 640))
41
- result_image = perform_detection(resized_image, interpreter, labels, input_details, output_details, height, width, floating_model)
42
  return Image.fromarray(result_image)
43
 
44
  def detect_video(input_video):
@@ -50,8 +100,8 @@ def detect_video(input_video):
50
  if not ret:
51
  break
52
 
53
- resized_frame = cv2.resize(frame, (640, 640))
54
- result_frame = perform_detection(resized_frame, interpreter, labels, input_details, output_details, height, width, floating_model)
55
  frames.append(result_frame)
56
 
57
  cap.release()
@@ -70,3 +120,20 @@ def detect_video(input_video):
70
  out.release()
71
 
72
  return output_video_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import cv2
3
  import numpy as np
4
  import importlib.util
 
5
  import gradio as gr
6
+ from PIL import Image
7
 
8
+ # Load the TensorFlow Lite model
9
  MODEL_DIR = 'model'
10
  GRAPH_NAME = 'detect.tflite'
11
  LABELMAP_NAME = 'labelmap.txt'
 
13
  pkg = importlib.util.find_spec('tflite_runtime')
14
  if pkg:
15
  from tflite_runtime.interpreter import Interpreter
16
+ from tflite_runtime.interpreter import load_delegate
17
  else:
18
  from tensorflow.lite.python.interpreter import Interpreter
19
+ from tensorflow.lite.python.interpreter import load_delegate
20
 
21
  PATH_TO_CKPT = os.path.join(MODEL_DIR, GRAPH_NAME)
22
  PATH_TO_LABELS = os.path.join(MODEL_DIR, LABELMAP_NAME)
23
 
24
+ # Load the label map
25
  with open(PATH_TO_LABELS, 'r') as f:
26
  labels = [line.strip() for line in f.readlines()]
27
 
28
  if labels[0] == '???':
29
  del(labels[0])
30
 
31
+ # Load the TensorFlow Lite model
32
  interpreter = Interpreter(model_path=PATH_TO_CKPT)
33
  interpreter.allocate_tensors()
34
 
 
38
  width = input_details[0]['shape'][2]
39
  floating_model = (input_details[0]['dtype'] == np.float32)
40
 
41
+ input_mean = 127.5
42
+ input_std = 127.5
43
+
44
+ outname = output_details[0]['name']
45
+ if ('StatefulPartitionedCall' in outname):
46
+ boxes_idx, classes_idx, scores_idx = 1, 3, 0
47
+ else:
48
+ boxes_idx, classes_idx, scores_idx = 0, 1, 2
49
+
50
+ def perform_detection(image, interpreter, labels):
51
+ imH, imW, _ = image.shape
52
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
53
+ image_resized = cv2.resize(image_rgb, (width, height))
54
+ input_data = np.expand_dims(image_resized, axis=0)
55
+
56
+ if floating_model:
57
+ input_data = (np.float32(input_data) - input_mean) / input_std
58
+
59
+ interpreter.set_tensor(input_details[0]['index'], input_data)
60
+ interpreter.invoke()
61
+
62
+ boxes = interpreter.get_tensor(output_details[boxes_idx]['index'])[0]
63
+ classes = interpreter.get_tensor(output_details[classes_idx]['index'])[0]
64
+ scores = interpreter.get_tensor(output_details[scores_idx]['index'])[0]
65
+
66
+ detections = []
67
+ for i in range(len(scores)):
68
+ if ((scores[i] > 0.5) and (scores[i] <= 1.0)):
69
+ ymin = int(max(1, (boxes[i][0] * imH)))
70
+ xmin = int(max(1, (boxes[i][1] * imW)))
71
+ ymax = int(min(imH, (boxes[i][2] * imH)))
72
+ xmax = int(min(imW, (boxes[i][3] * imW)))
73
+
74
+ cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (10, 255, 0), 2)
75
+ object_name = labels[int(classes[i])]
76
+ label = '%s: %d%%' % (object_name, int(scores[i] * 100))
77
+ labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
78
+ label_ymin = max(ymin, labelSize[1] + 10)
79
+ cv2.rectangle(image, (xmin, label_ymin - labelSize[1] - 10), (xmin + labelSize[0], label_ymin + baseLine - 10), (255, 255, 255), cv2.FILLED)
80
+ cv2.putText(image, label, (xmin, label_ymin - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
81
+
82
+ detections.append([object_name, scores[i], xmin, ymin, xmax, ymax])
83
+ return image
84
+
85
+ def resize_image(image, size=640):
86
+ return cv2.resize(image, (size, size))
87
+
88
  def detect_image(input_image):
89
  image = np.array(input_image)
90
+ resized_image = resize_image(image, size=640) # Resize input image
91
+ result_image = perform_detection(resized_image, interpreter, labels)
92
  return Image.fromarray(result_image)
93
 
94
  def detect_video(input_video):
 
100
  if not ret:
101
  break
102
 
103
+ resized_frame = resize_image(frame, size=640) # Resize each frame
104
+ result_frame = perform_detection(resized_frame, interpreter, labels)
105
  frames.append(result_frame)
106
 
107
  cap.release()
 
120
  out.release()
121
 
122
  return output_video_path
123
+
124
+ app = gr.Blocks()
125
+
126
+ with app:
127
+ with gr.Tab("Image Detection"):
128
+ gr.Markdown("Upload an image for object detection")
129
+ image_input = gr.Image(type="pil", label="Upload an image")
130
+ image_output = gr.Image(type="pil", label="Detection Result")
131
+ gr.Button("Submit").click(fn=detect_image, inputs=image_input, outputs=image_output)
132
+
133
+ with gr.Tab("Video Detection"):
134
+ gr.Markdown("Upload a video for object detection")
135
+ video_input = gr.Video(label="Upload a video")
136
+ video_output = gr.Video(label="Detection Result")
137
+ gr.Button("Submit").click(fn=detect_video, inputs=video_input, outputs=video_output)
138
+
139
+ app.launch()