yolo-v8 / app.py
eusholli's picture
Added YouTube URL option
d815415
raw
history blame
No virus
11.1 kB
import yt_dlp
from ultralytics import YOLO
import time
import os
import logging
import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from utils.download import download_file
from utils.turn import get_ice_servers
from PIL import Image, ImageDraw # Import PIL for image processing
from transformers import pipeline # Import Hugging Face transformers pipeline
import requests
from io import BytesIO # Import for handling byte streams
# CHANGE CODE BELOW HERE, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
# Update the string below to set the display title of the analysis;
# it is rendered on the page through st.subheader(ANALYSIS_TITLE).
# Default title - "Facial Sentiment Analysis"
ANALYSIS_TITLE = "YOLO-8 Object Detection Analysis"
# Load the YOLOv8 model once, from the "yolov8n.pt" weights;
# analyze_frame() reuses this single instance for every frame.
model = YOLO("yolov8n.pt")
# CHANGE THE CONTENTS OF THIS FUNCTION, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
#
# Results are handed to the UI through the module-level img_container dict:
# img_container["input"] - the input frame as received - of type np.ndarray
# img_container["analyzed"] - the frame with annotations drawn on it
# img_container["analysis_time"] - how long the analysis took, in milliseconds
# img_container["detections"] - list of per-object dicts (id, label, score, box_coords)
def analyze_frame(frame: np.ndarray):
    """Run YOLOv8 tracking on one frame and store the results in img_container.

    Args:
        frame: The image to analyze. Callers in this file pass RGB arrays.

    Returns:
        None. All output is written to the module-level ``img_container``.
    """
    started = time.time()
    img_container["input"] = frame

    # Track on a copy so the stored input frame is never handed to the model.
    working_copy = frame.copy()
    # persist=True keeps track state between successive calls.
    result = model.track(working_copy, persist=True)[0]

    # One record per detected box; ids are 1-based and restart on every frame.
    detections = [
        {
            "id": index,
            "label": model.names[int(box.cls)],
            "score": float(box.conf),
            "box_coords": [
                round(coord.item(), 2) for coord in box.xyxy.flatten()
            ],
        }
        for index, box in enumerate(result.boxes, start=1)
    ]

    # Let the model's result object draw its own annotations.
    annotated = result.plot()

    elapsed_ms = round((time.time() - started) * 1000, 2)

    img_container["analysis_time"] = elapsed_ms
    img_container["detections"] = detections
    img_container["analyzed"] = annotated
#
#
# DO NOT TOUCH THE BELOW CODE (NOT NEEDED)
#
#
# Request quiet FFmpeg logging through the environment
os.environ.update(FFMPEG_LOG_LEVEL="quiet")
# Only let errors through from the "streamlit" logger
_streamlit_logger = logging.getLogger("streamlit")
_streamlit_logger.setLevel(logging.ERROR)
# Shared hand-off between the analysis code and the UI: every slot starts
# as None and is filled in by analyze_frame()
img_container = dict.fromkeys(
    ("input", "analyzed", "analysis_time", "detections")
)
# Module logger for debugging and information
logger = logging.getLogger(__name__)
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Per-frame hook for the WebRTC stream.

    Converts the incoming frame to an RGB numpy array, passes it to
    analyze_frame(), and returns the original frame object unchanged.

    Args:
        frame: The video frame delivered by the WebRTC stream.

    Returns:
        The same ``frame`` object that was passed in.
    """
    analyze_frame(frame.to_ndarray(format="rgb24"))
    return frame
# ICE server configuration handed to the WebRTC streamer further down
ice_servers = get_ice_servers()

# Use the full browser width for the two-column layout
st.set_page_config(layout="wide")

# Page-wide CSS: main padding and heading typography
_PAGE_STYLE_HTML = """
<style>
.main {
padding: 2rem;
}
h1, h2, h3 {
font-family: 'Arial', sans-serif;
}
h1 {
font-weight: 700;
font-size: 2.5rem;
}
h2 {
font-weight: 600;
font-size: 2rem;
}
h3 {
font-weight: 500;
font-size: 1.5rem;
}
</style>
"""

# Pointer to the README shown directly under the page title
_README_LINK_HTML = """
<div style="text-align: left;">
<p>See the <a href="https://huggingface.co/spaces/eusholli/sentiment-analyzer/blob/main/README.md"
target="_blank">README</a> to learn how to use this code to help you start your computer vision exploration.</p>
</div>
"""

st.markdown(_PAGE_STYLE_HTML, unsafe_allow_html=True)
# Page title, README link, then the title of the configured analysis
st.title("Computer Vision Playground")
st.markdown(_README_LINK_HTML, unsafe_allow_html=True)
st.subheader(ANALYSIS_TITLE)
# Two side-by-side columns: inputs on the left, analysis output on the right
col1, col2 = st.columns(2)

# NOTE(review): the indentation of this file was lost; every input widget is
# assumed to belong to the left column - confirm against the original layout.
with col1:
    st.header("Input Stream")
    st.subheader("input")

    # Webcam input over WebRTC; each frame is passed to video_frame_callback
    webrtc_ctx = webrtc_streamer(
        key="input-webcam",
        mode=WebRtcMode.SENDRECV,
        rtc_configuration=ice_servers,
        video_frame_callback=video_frame_callback,
        media_stream_constraints=dict(video=True, audio=False),
        async_processing=True,
    )

    # Still image: upload a file ...
    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
    )

    # ... or point at one by URL
    st.subheader("Or Enter Image URL")
    image_url = st.text_input("Image URL")

    # YouTube video by URL
    st.subheader("Enter a YouTube URL")
    youtube_url = st.text_input("YouTube URL")

    # Video: upload a file ...
    st.subheader("Upload a Video")
    uploaded_video = st.file_uploader(
        "Choose a video...",
        type=["mp4", "avi", "mov", "mkv"],
    )

    # ... or give a direct download URL
    st.subheader("Or Enter Video Download URL")
    video_url = st.text_input("Video URL")
# Page footer: centered link to the playground's own README
_FOOTER_HTML = """
<div style="text-align: center; margin-top: 2rem;">
<p>If you want to set up your own computer vision playground see <a href="https://huggingface.co/spaces/eusholli/computer-vision-playground/blob/main/README.md" target="_blank">here</a>.</p>
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)
def analysis_init():
    """Build the analysis panel in the right-hand column.

    Creates the headers, the placeholders for the input frame, output frame,
    analysis time and labels table, and the "show labels" checkbox. The
    created widgets are stored in module-level globals so that
    publish_frame() can update them.
    """
    global input_placeholder, output_placeholder
    global analysis_time, show_labels, labels_placeholder

    # Widgets appear in the order they are created, so the order below is
    # the on-page order.
    with col2:
        st.header("Analysis")

        st.subheader("Input Frame")
        input_placeholder = st.empty()

        st.subheader("Output Frame")
        output_placeholder = st.empty()

        # Slot for the "Analysis Time: ... ms" line
        analysis_time = st.empty()

        # Toggle for the detections table, on by default
        show_labels = st.checkbox("Show the detected labels", value=True)
        labels_placeholder = st.empty()
def publish_frame():
    """Push the latest contents of img_container to the analysis panel.

    The slots are rendered in a fixed order: input frame, analyzed frame,
    analysis time, detections. Rendering stops at the first slot that is
    still None, so anything after it keeps its previous content. The
    detections table is only drawn while the "show labels" checkbox is
    ticked.
    """

    def _render_detections(rows):
        # Respect the checkbox created by analysis_init()
        if show_labels:
            labels_placeholder.table(rows)

    # (img_container key, how to display its value), in display order
    renderers = (
        ("input",
         lambda image: input_placeholder.image(image, channels="RGB")),
        ("analyzed",
         lambda image: output_placeholder.image(image, channels="RGB")),
        ("analysis_time",
         lambda elapsed_ms: analysis_time.text(
             f"Analysis Time: {elapsed_ms} ms")),
        ("detections", _render_detections),
    )

    for key, render in renderers:
        # Read each slot just before rendering it
        value = img_container[key]
        if value is None:
            return
        render(value)
# While the webcam stream is playing, keep refreshing the analysis panel.
# The loop has no exit condition of its own.
# NOTE(review): presumably Streamlit interrupts the script on the next rerun
# (e.g. when the stream is stopped) - confirm.
if webrtc_ctx.state.playing:
    analysis_init()  # Build the analysis panel once
    while True:
        publish_frame()  # Show whatever the frame callback stored last
        time.sleep(0.1)  # Pause 0.1 s between refreshes
# If an image is uploaded or a URL is provided, analyze that single image.
# An uploaded file takes precedence over the URL when both are present.
if uploaded_file is not None or image_url:
    analysis_init()  # Initialize the analysis UI
    if uploaded_file is not None:
        image = Image.open(uploaded_file)  # Open the uploaded image
    else:
        # The timeout keeps a stalled server from hanging the script, and
        # raise_for_status() reports HTTP errors directly instead of passing
        # an error page to PIL as if it were image data.
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    # Convert to an RGB array regardless of the source image mode
    img = np.array(image.convert("RGB"))
    analyze_frame(img)  # Analyze the image
    publish_frame()  # Publish the results
def process_video(video_path):
    """Analyze a video frame by frame and publish each result to the UI.

    Every frame read from the video is converted from BGR to RGB, passed to
    analyze_frame(), and then shown through publish_frame().

    Args:
        video_path: Anything cv2.VideoCapture accepts; callers in this file
            pass a local file path or a stream URL.
    """
    cap = cv2.VideoCapture(video_path)  # Open the video source
    try:
        while cap.isOpened():
            ret, frame = cap.read()  # Read the next frame
            if not ret:
                break  # No more frames (or the read failed)
            # OpenCV delivers BGR; the rest of the app works in RGB
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            analyze_frame(rgb_frame)
            publish_frame()  # Publish the results
    finally:
        # Release the capture even if analysis or publishing raises
        cap.release()
def get_youtube_stream_url(youtube_url):
    """Resolve a YouTube page URL to a media stream URL using yt-dlp.

    The video is not downloaded; only its metadata is extracted.

    Args:
        youtube_url: URL of the YouTube video.

    Returns:
        The value stored under the 'url' key of the extracted info.
    """
    # Request the best mp4 format and keep yt-dlp quiet
    options = {"format": "best[ext=mp4]", "quiet": True}
    with yt_dlp.YoutubeDL(options) as downloader:
        metadata = downloader.extract_info(youtube_url, download=False)
    return metadata["url"]
# If a YouTube URL is provided, resolve it to a stream URL and analyze it
if youtube_url:
    analysis_init()  # Initialize the analysis UI
    process_video(get_youtube_stream_url(youtube_url))
# If a video is uploaded or a download URL is provided, analyze the video.
# An uploaded file takes precedence over the URL when both are present.
if uploaded_video is not None or video_url:
    analysis_init()  # Initialize the analysis UI
    if uploaded_video is not None:
        # The file name is supplied by the client; keep only its final path
        # component so the write below stays in the working directory.
        video_path = os.path.basename(uploaded_video.name)
        with open(video_path, "wb") as f:
            # Save the uploaded video so OpenCV can open it by path
            f.write(uploaded_video.getbuffer())
    else:
        # Download the video from the URL
        video_path = download_file(video_url)
    process_video(video_path)  # Process the video