brxerq committed
Commit 9728877 · 1 parent: 2230f78

Update app.py

Files changed (1): app.py (+103, -59)
app.py CHANGED
@@ -1,32 +1,34 @@
-import streamlit as st
 import os
-import numpy as np
 import cv2
+import numpy as np
+import importlib.util
+import gradio as gr
 from PIL import Image
-import tempfile
-
-# TensorFlow imports
-from tensorflow.lite.python.interpreter import Interpreter
-if use_TPU:
-    from tensorflow.lite.python.interpreter import load_delegate
 
-# Setup the model and labels
-MODEL_NAME = 'model'
+# Load the TensorFlow Lite model
+MODEL_DIR = 'model'
 GRAPH_NAME = 'detect.tflite'
 LABELMAP_NAME = 'labelmap.txt'
-min_conf_threshold = 0.5
-use_TPU = False  # Change this based on your needs
 
-PATH_TO_CKPT = os.path.join('model', GRAPH_NAME)
-PATH_TO_LABELS = os.path.join('model', LABELMAP_NAME)
+pkg = importlib.util.find_spec('tflite_runtime')
+if pkg:
+    from tflite_runtime.interpreter import Interpreter
+    from tflite_runtime.interpreter import load_delegate
+else:
+    from tensorflow.lite.python.interpreter import Interpreter
+    from tensorflow.lite.python.interpreter import load_delegate
+
+PATH_TO_CKPT = os.path.join(MODEL_DIR, GRAPH_NAME)
+PATH_TO_LABELS = os.path.join(MODEL_DIR, LABELMAP_NAME)
 
-# Load labels
+# Load the label map
 with open(PATH_TO_LABELS, 'r') as f:
     labels = [line.strip() for line in f.readlines()]
+
 if labels[0] == '???':
     del(labels[0])
 
-# Load model
+# Load the TensorFlow Lite model
 interpreter = Interpreter(model_path=PATH_TO_CKPT)
 interpreter.allocate_tensors()
 
@@ -34,61 +36,103 @@ input_details = interpreter.get_input_details()
 output_details = interpreter.get_output_details()
 height = input_details[0]['shape'][1]
 width = input_details[0]['shape'][2]
+floating_model = (input_details[0]['dtype'] == np.float32)
+
+input_mean = 127.5
+input_std = 127.5
 
-# Streamlit interface
-st.title('Object Detection System')
-st.sidebar.title('Settings')
-uploaded_file = st.sidebar.file_uploader("Choose an image or video file", type=['jpg', 'png', 'jpeg', 'mp4'])
+outname = output_details[0]['name']
+if ('StatefulPartitionedCall' in outname):
+    boxes_idx, classes_idx, scores_idx = 1, 3, 0
+else:
+    boxes_idx, classes_idx, scores_idx = 0, 1, 2
 
-def detect_objects(image):
-    # Prepare image for detection
+def perform_detection(image, interpreter, labels):
+    imH, imW, _ = image.shape
     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
     image_resized = cv2.resize(image_rgb, (width, height))
     input_data = np.expand_dims(image_resized, axis=0)
-    input_data = (np.float32(input_data) - 127.5) / 127.5  # Normalize
 
-    # Perform detection
+    if floating_model:
+        input_data = (np.float32(input_data) - input_mean) / input_std
+
     interpreter.set_tensor(input_details[0]['index'], input_data)
     interpreter.invoke()
 
-    # Retrieve detection results
-    boxes = interpreter.get_tensor(output_details[0]['index'])[0]  # Bounding box coordinates of detected objects
-    classes = interpreter.get_tensor(output_details[1]['index'])[0]  # Class index of detected objects
-    scores = interpreter.get_tensor(output_details[2]['index'])[0]  # Confidence of detected objects
+    boxes = interpreter.get_tensor(output_details[boxes_idx]['index'])[0]
+    classes = interpreter.get_tensor(output_details[classes_idx]['index'])[0]
+    scores = interpreter.get_tensor(output_details[scores_idx]['index'])[0]
 
+    detections = []
     for i in range(len(scores)):
-        if scores[i] > min_conf_threshold and scores[i] <= 1.0:
-            # Draw bounding boxes and labels on the image
-            ymin, xmin, ymax, xmax = boxes[i]
-            (left, right, top, bottom) = (xmin * imW, xmax * imW, ymin * imH, ymax * imH)
-            cv2.rectangle(image, (int(left), int(top)), (int(right), int(bottom)), (10, 255, 0), 4)
+        if ((scores[i] > 0.5) and (scores[i] <= 1.0)):
+            ymin = int(max(1, (boxes[i][0] * imH)))
+            xmin = int(max(1, (boxes[i][1] * imW)))
+            ymax = int(min(imH, (boxes[i][2] * imH)))
+            xmax = int(min(imW, (boxes[i][3] * imW)))
+
+            cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (10, 255, 0), 2)
             object_name = labels[int(classes[i])]
-            label = '%s: %d%%' % (object_name, int(scores[i]*100))
-            cv2.putText(image, label, (int(left), int(top)-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
+            label = '%s: %d%%' % (object_name, int(scores[i] * 100))
+            labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
+            label_ymin = max(ymin, labelSize[1] + 10)
+            cv2.rectangle(image, (xmin, label_ymin - labelSize[1] - 10), (xmin + labelSize[0], label_ymin + baseLine - 10), (255, 255, 255), cv2.FILLED)
+            cv2.putText(image, label, (xmin, label_ymin - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
+
+            detections.append([object_name, scores[i], xmin, ymin, xmax, ymax])
     return image
 
-if uploaded_file is not None:
-    file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
-    if uploaded_file.type == "video/mp4":
-        # Handle video upload
-        tfile = tempfile.NamedTemporaryFile(delete=False)
-        tfile.write(uploaded_file.read())
-
-        cap = cv2.VideoCapture(tfile.name)
-
-        stframe = st.empty()
-
-        while cap.isOpened():
-            ret, frame = cap.read()
-            if not ret:
-                break
-            frame = detect_objects(frame)
-            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-            stframe.image(frame)
-    else:
-        # Handle image upload
-        image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
-        image = detect_objects(image)
-        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-        st.image(image, use_column_width=True)
+def detect_image(input_image):
+    image = np.array(input_image)
+    result_image = perform_detection(image, interpreter, labels)
+    return Image.fromarray(result_image)
+
+def detect_video(input_video):
+    cap = cv2.VideoCapture(input_video.name)
+    frames = []
+
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        result_frame = perform_detection(frame, interpreter, labels)
+        frames.append(result_frame)
+
+    cap.release()
+
+    height, width, layers = frames[0].shape
+    size = (width, height)
+    output_video_path = "result_" + input_video.name
+    out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'DIVX'), 15, size)
+
+    for frame in frames:
+        out.write(frame)
+
+    out.release()
+
+    return output_video_path
+
+image_input = gr.inputs.Image(type="pil", label="Upload an image")
+image_output = gr.outputs.Image(type="pil", label="Detection Result")
+
+video_input = gr.inputs.Video(type="file", label="Upload a video")
+video_output = gr.outputs.Video(label="Detection Result")
+
+app = gr.Interface(
+    fn=detect_image,
+    inputs=image_input,
+    outputs=image_output,
+    live=True,
+    description="Object Detection on Images"
+)
+
+app_video = gr.Interface(
+    fn=detect_video,
+    inputs=video_input,
+    outputs=video_output,
+    live=True,
+    description="Object Detection on Videos"
+)
 
+gr.TabbedInterface([app, app_video], ["Image Detection", "Video Detection"]).launch()
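
Note: the Gradio wiring added by this commit uses the legacy gr.inputs / gr.outputs namespace, which was deprecated in Gradio 3.x and removed in 4.x, and detect_video reads input_video.name, whereas newer Gradio passes an uploaded video to the function as a plain filepath string. Below is a minimal sketch, not part of the commit, of the same two interfaces against the current component API, assuming Gradio 4.x and reusing the perform_detection / detect_image functions defined above. The 'mp4v' codec and the "_result.mp4" output name are assumptions; depending on the OpenCV build, an H.264 fourcc such as 'avc1' may be needed for in-browser playback, since DIVX-encoded output generally will not play in an HTML5 video tag.

import os
import cv2
import gradio as gr

def detect_video(video_path):
    # Newer Gradio hands the upload to the function as a filepath string.
    cap = cv2.VideoCapture(video_path)
    out = None
    output_video_path = os.path.splitext(video_path)[0] + "_result.mp4"
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame = perform_detection(frame, interpreter, labels)
        if out is None:
            # Lazily open the writer once the frame size is known.
            h, w, _ = frame.shape
            out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 15, (w, h))
        out.write(frame)
    cap.release()
    if out is not None:
        out.release()
    return output_video_path

app = gr.Interface(
    fn=detect_image,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=gr.Image(type="pil", label="Detection Result"),
    description="Object Detection on Images",
)

app_video = gr.Interface(
    fn=detect_video,
    inputs=gr.Video(label="Upload a video"),
    outputs=gr.Video(label="Detection Result"),
    description="Object Detection on Videos",
)

gr.TabbedInterface([app, app_video], ["Image Detection", "Video Detection"]).launch()

Unlike the committed detect_video, this sketch writes each frame as it is decoded instead of buffering the whole clip in a Python list, so memory use stays bounded on long videos.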