eusholli committed
Commit 56a64f9
1 Parent(s): b184fdb

working webcam

Files changed (3)
  1. app.py +131 -115
  2. object_detection.py +2 -1
  3. sentiment.py +1 -1
app.py CHANGED
@@ -1,8 +1,3 @@
- """Object detection demo with MobileNet SSD.
- This model and code are based on
- https://github.com/robmarkcole/object-detection-app
- """
-
  import logging
  import queue
  from pathlib import Path
@@ -17,41 +12,24 @@ from streamlit_webrtc import WebRtcMode, webrtc_streamer
  from utils.download import download_file
  from utils.turn import get_ice_servers

- HERE = Path(__file__).parent
- ROOT = HERE.parent
+ from mtcnn import MTCNN
+ from PIL import Image, ImageDraw
+ from transformers import pipeline

- logger = logging.getLogger(__name__)

+ # Initialize the Hugging Face pipeline for facial emotion detection
+ emotion_pipeline = pipeline("image-classification", model="trpakov/vit-face-expression")

- MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel" # noqa: E501
- MODEL_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.caffemodel"
- PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt" # noqa: E501
- PROTOTXT_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.prototxt.txt"
-
- CLASSES = [
-     "background",
-     "aeroplane",
-     "bicycle",
-     "bird",
-     "boat",
-     "bottle",
-     "bus",
-     "car",
-     "cat",
-     "chair",
-     "cow",
-     "diningtable",
-     "dog",
-     "horse",
-     "motorbike",
-     "person",
-     "pottedplant",
-     "sheep",
-     "sofa",
-     "train",
-     "tvmonitor",
- ]
+ img_container = {"webcam": None,
+                  "analyzed": None}
+
+ # Initialize MTCNN for face detection
+ mtcnn = MTCNN()
+
+ HERE = Path(__file__).parent
+ ROOT = HERE.parent

+ logger = logging.getLogger(__name__)

  class Detection(NamedTuple):
      class_id: int
@@ -59,94 +37,128 @@ class Detection(NamedTuple):
      score: float
      box: np.ndarray

-
- @st.cache_resource # type: ignore
- def generate_label_colors():
-     return np.random.uniform(0, 255, size=(len(CLASSES), 3))
-
-
- COLORS = generate_label_colors()
-
- download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
- download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
-
-
- # Session-specific caching
- cache_key = "object_detection_dnn"
- if cache_key in st.session_state:
-     net = st.session_state[cache_key]
- else:
-     net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
-     st.session_state[cache_key] = net
-
- score_threshold = st.slider("Score threshold", 0.0, 1.0, 0.5, 0.05)
-
  # NOTE: The callback will be called in another thread,
  # so use a queue here for thread-safety to pass the data
  # from inside to outside the callback.
  # TODO: A general-purpose shared state object may be more useful.
  result_queue: "queue.Queue[List[Detection]]" = queue.Queue()

+ # Function to analyze sentiment
+ def analyze_sentiment(face):
+     # Convert face to RGB
+     rgb_face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
+     # Convert the face to a PIL image
+     pil_image = Image.fromarray(rgb_face)
+     # Analyze sentiment using the Hugging Face pipeline
+     results = emotion_pipeline(pil_image)
+     # Get the dominant emotion
+     dominant_emotion = max(results, key=lambda x: x['score'])['label']
+     return dominant_emotion
+
+ TEXT_SIZE = 1
+ LINE_SIZE = 2
+
+ # Function to detect faces, analyze sentiment, and draw a red box around them
+ def detect_and_draw_faces(frame):
+     # Detect faces using MTCNN
+     results = mtcnn.detect_faces(frame)
+
+     # Draw on the frame
+     for result in results:
+         x, y, w, h = result['box']
+         face = frame[y:y+h, x:x+w]
+         sentiment = analyze_sentiment(face)
+         cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 0, 255), LINE_SIZE) # Thicker red box
+
+         # Calculate position for the text background and the text itself
+         text_size = cv2.getTextSize(sentiment, cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, 2)[0]
+         text_x = x
+         text_y = y - 10
+         background_tl = (text_x, text_y - text_size[1])
+         background_br = (text_x + text_size[0], text_y + 5)
+
+         # Draw black rectangle as background
+         cv2.rectangle(frame, background_tl, background_br, (0, 0, 0), cv2.FILLED)
+         # Draw white text on top
+         cv2.putText(frame, sentiment, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, (255, 255, 255), 2)
+
+     result_queue.put(results)
+     return frame

  def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
-     image = frame.to_ndarray(format="bgr24")
+     img = frame.to_ndarray(format="bgr24")
+     img_container["webcam"] = img
+     frame_with_boxes = detect_and_draw_faces(img.copy())
+     img_container["analyzed"] = frame_with_boxes

-     # Run inference
-     blob = cv2.dnn.blobFromImage(
-         cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
-     )
-     net.setInput(blob)
-     output = net.forward()
-
-     h, w = image.shape[:2]
-
-     # Convert the output array into a structured form.
-     output = output.squeeze() # (1, 1, N, 7) -> (N, 7)
-     output = output[output[:, 2] >= score_threshold]
-     detections = [
-         Detection(
-             class_id=int(detection[1]),
-             label=CLASSES[int(detection[1])],
-             score=float(detection[2]),
-             box=(detection[3:7] * np.array([w, h, w, h])),
-         )
-         for detection in output
-     ]
-
-     # Render bounding boxes and captions
-     for detection in detections:
-         caption = f"{detection.label}: {round(detection.score * 100, 2)}%"
-         color = COLORS[detection.class_id]
-         xmin, ymin, xmax, ymax = detection.box.astype("int")
-
-         cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
-         cv2.putText(
-             image,
-             caption,
-             (xmin, ymin - 15 if ymin - 15 > 15 else ymin + 15),
-             cv2.FONT_HERSHEY_SIMPLEX,
-             0.5,
-             color,
-             2,
-         )
-
-     result_queue.put(detections)
-
-     return av.VideoFrame.from_ndarray(image, format="bgr24")
+     return frame
+     # return av.VideoFrame.from_ndarray(frame_with_boxes, format="bgr24")

  ice_servers = get_ice_servers()

- webrtc_ctx = webrtc_streamer(
-     key="object-detection",
-     mode=WebRtcMode.SENDRECV,
-     rtc_configuration=ice_servers,
-     video_frame_callback=video_frame_callback,
-     media_stream_constraints={"video": True, "audio": False},
-     async_processing=True,
+ # Streamlit UI
+ st.markdown(
+     """
+     <style>
+     .main {
+         background-color: #F7F7F7;
+         padding: 2rem;
+     }
+     h1, h2, h3 {
+         color: #333333;
+         font-family: 'Arial', sans-serif;
+     }
+     h1 {
+         font-weight: 700;
+         font-size: 2.5rem;
+     }
+     h2 {
+         font-weight: 600;
+         font-size: 2rem;
+     }
+     h3 {
+         font-weight: 500;
+         font-size: 1.5rem;
+     }
+     .stButton button {
+         background-color: #E60012;
+         color: white;
+         border-radius: 5px;
+         font-size: 16px;
+         padding: 0.5rem 1rem;
+     }
+     </style>
+     """,
+     unsafe_allow_html=True
  )

- if st.checkbox("Show the detected labels", value=True):
-     if webrtc_ctx.state.playing:
+ st.title("Computer Vision Test Lab")
+ st.subheader("Facial Sentiment Analysis")
+
+ # Columns for input and output streams
+ col1, col2 = st.columns(2)
+
+ with col1:
+     st.header("Input Stream")
+     st.subheader("Webcam")
+     webrtc_ctx = webrtc_streamer(
+         key="object-detection",
+         mode=WebRtcMode.SENDRECV,
+         rtc_configuration=ice_servers,
+         video_frame_callback=video_frame_callback,
+         media_stream_constraints={"video": True, "audio": False},
+         async_processing=True,
+     )
+
+ with col2:
+     st.header("Analysis")
+     st.subheader("Input Frame")
+     input_placeholder = st.empty()
+     st.subheader("Output Frame")
+     output_placeholder = st.empty()
+
+ if webrtc_ctx.state.playing:
+     if st.checkbox("Show the detected labels", value=True):
          labels_placeholder = st.empty()
          # NOTE: The video transformation with object detection and
          # this loop displaying the result labels are running
@@ -157,8 +169,12 @@ if st.checkbox("Show the detected labels", value=True):
              result = result_queue.get()
              labels_placeholder.table(result)

- st.markdown(
-     "This demo uses a model and code from "
-     "https://github.com/robmarkcole/object-detection-app. "
-     "Many thanks to the project."
- )
+             img = img_container["webcam"]
+             frame_with_boxes = img_container["analyzed"]
+
+             if img is None:
+                 continue
+
+             input_placeholder.image(img, channels="BGR")
+             output_placeholder.image(frame_with_boxes, channels="BGR")
+
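Note on the new app.py above: video_frame_callback writes both the raw and the annotated frame into img_container from the WebRTC worker thread, while the Streamlit script thread reads the same dictionary in the display loop. sentiment.py in this repo guards that dictionary with a threading.Lock(); below is a minimal sketch of the same lock-guarded pattern applied to this callback. The annotate() helper is a hypothetical stand-in for detect_and_draw_faces(), and returning the annotated frame (the line the commit leaves commented out) is shown as one option, not the committed behaviour.

import threading

import av
import numpy as np

# Shared between the WebRTC callback thread and the Streamlit script thread,
# mirroring the lock-guarded dictionary used in sentiment.py.
lock = threading.Lock()
img_container = {"webcam": None, "analyzed": None}


def annotate(img: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for detect_and_draw_faces() from app.py.
    return img.copy()


def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    img = frame.to_ndarray(format="bgr24")
    frame_with_boxes = annotate(img)

    # Hold the lock only while swapping the references.
    with lock:
        img_container["webcam"] = img
        img_container["analyzed"] = frame_with_boxes

    # Returning the annotated frame would also show the boxes in the live
    # video element; the commit returns the untouched input frame instead.
    return av.VideoFrame.from_ndarray(frame_with_boxes, format="bgr24")


def latest_frames():
    # Display side (Streamlit script thread): copy the references out under the lock.
    with lock:
        return img_container["webcam"], img_container["analyzed"]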
object_detection.py CHANGED
@@ -134,11 +134,12 @@ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:

      return av.VideoFrame.from_ndarray(image, format="bgr24")

+ ice_servers = get_ice_servers()

  webrtc_ctx = webrtc_streamer(
      key="object-detection",
      mode=WebRtcMode.SENDRECV,
-     rtc_configuration={"iceServers": get_ice_servers()},
+     rtc_configuration=ice_servers,
      video_frame_callback=video_frame_callback,
      media_stream_constraints={"video": True, "audio": False},
      async_processing=True,
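The object_detection.py change above mirrors app.py: the ICE servers are fetched once and the result is passed straight to rtc_configuration, which implies utils.turn.get_ice_servers() now returns a complete RTCConfiguration-style dict (with an "iceServers" key) rather than a bare server list. utils/turn.py is not part of this diff, so the following is only a sketch of such a helper under that assumption; the STUN URL is a placeholder default.

from typing import Any, Dict


def get_ice_servers() -> Dict[str, Any]:
    # Assumed shape: rtc_configuration is given a dict with an "iceServers"
    # list. A real deployment would typically add TURN credentials here too.
    return {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}


# Usage, matching the updated call site:
# webrtc_ctx = webrtc_streamer(..., rtc_configuration=get_ice_servers(), ...)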
sentiment.py CHANGED
@@ -14,7 +14,7 @@ logging.getLogger("transformers").setLevel(logging.ERROR)

  lock = threading.Lock()
  img_container = {"webcam": None,
-                  "analzyed": None}
+                  "analyzed": None}

  # Initialize the Hugging Face pipeline for facial emotion detection
  emotion_pipeline = pipeline("image-classification", model="trpakov/vit-face-expression")
  emotion_pipeline = pipeline("image-classification", model="trpakov/vit-face-expression")