# capradeepgujaran's picture
# Create app.py
# 7b04d4e verified
# raw
# history blame
# 5.48 kB
import base64
import io
import os
import time
from typing import Optional

import cv2
import gradio as gr
import numpy as np
from groq import Groq
from PIL import Image
class SafetyMonitor:
    """Webcam safety monitor backed by a Groq-hosted vision model.

    Frames are captured with OpenCV. At most once every ``analysis_interval``
    seconds a frame is sent to the model with a fixed workplace-safety
    prompt. The most recent analysis text is drawn on every frame before it
    is yielded.
    """

    # Prompt sent alongside every analyzed frame.
    _PROMPT = (
        "Please analyze this image for workplace safety issues. Focus on:\n"
        "1. Required PPE usage (hard hats, safety glasses, reflective vests)\n"
        "2. Unsafe behaviors or positions\n"
        "3. Equipment and machinery safety\n"
        "4. Environmental hazards (spills, obstacles, poor lighting)\n"
        "5. Emergency exit accessibility\n"
        "Provide specific observations and any immediate safety concerns."
    )

    def __init__(self, api_key: str, model_name: str = "mixtral-8x7b-vision"):
        """
        Initialize the safety monitor with configurable model

        Args:
            api_key (str): Groq API key
            model_name (str): Name of the vision model to use.
                NOTE(review): confirm this model id is served by Groq;
                otherwise every call returns an "Analysis Error: ..." string.
        """
        self.client = Groq(api_key=api_key)
        self.model_name = model_name
        self.analysis_interval = 2  # seconds between model calls

    @staticmethod
    def _to_data_url(jpeg_bytes: bytes) -> str:
        """Return *jpeg_bytes* as a base64 ``data:`` URL for an ``image_url`` part."""
        encoded = base64.b64encode(jpeg_bytes).decode("ascii")
        return f"data:image/jpeg;base64,{encoded}"

    @staticmethod
    def _encode_frame(frame: np.ndarray) -> bytes:
        """JPEG-encode an OpenCV BGR frame (converted to RGB for PIL)."""
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        buffer = io.BytesIO()
        Image.fromarray(rgb).save(buffer, format="JPEG")
        return buffer.getvalue()

    @staticmethod
    def _display_lines(text: Optional[str], max_chars: int = 80) -> list:
        """Split *text* on newlines, truncating each line to *max_chars* characters.

        ``None`` is treated as an empty string so the overlay never crashes.
        """
        return [line[:max_chars] for line in (text or "").split("\n")]

    def analyze_frame(self, frame: np.ndarray) -> str:
        """
        Analyze a single frame using the configured vision model.

        Returns the model's text reply. If the API call fails, returns a
        string starting with "Analysis Error: ".
        """
        # Groq's chat API takes images as an ``image_url`` part; raw bytes are
        # not JSON-serialisable, so the frame is sent as a base64 data URL.
        image_url = self._to_data_url(self._encode_frame(frame))
        try:
            completion = self.client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": self._PROMPT},
                            {"type": "image_url", "image_url": {"url": image_url}},
                        ],
                    }
                ],
                model=self.model_name,
                max_tokens=200,
                temperature=0.2,  # Lower temperature for more focused safety analysis
            )
            # content can be None; normalise so callers can always split it.
            return completion.choices[0].message.content or ""
        except Exception as e:
            # Deliberate best-effort: the error text is shown in the overlay
            # instead of stopping the video stream.
            return f"Analysis Error: {str(e)}"

    def _annotate(self, frame: np.ndarray, analysis: str) -> np.ndarray:
        """Return a copy of *frame* with a translucent box and *analysis* drawn on it."""
        display_frame = frame.copy()
        height, width = display_frame.shape[:2]
        # Text baselines are at y = 60, 90, 120, ...; drop lines that would land below the frame.
        lines = [
            line
            for index, line in enumerate(self._display_lines(analysis))
            if 60 + 30 * index < height
        ]
        # Size the dark background to the frame width and to the lines actually drawn,
        # rather than a fixed 640x200 box that long text overflows.
        box_bottom = min(height - 1, 45 + 30 * len(lines))
        overlay = display_frame.copy()
        cv2.rectangle(overlay, (5, 5), (width - 5, box_bottom), (0, 0, 0), -1)
        cv2.addWeighted(overlay, 0.3, display_frame, 0.7, 0, display_frame)
        cv2.putText(display_frame, "Safety Analysis:", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        for index, line in enumerate(lines):
            cv2.putText(display_frame, line, (10, 60 + 30 * index),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
        return display_frame

    def process_video_stream(self, camera_index: int = 0):
        """
        Capture frames from a camera and yield them with the latest analysis drawn on top.

        Args:
            camera_index (int): OpenCV capture device index (0 = default webcam).

        Yields:
            Annotated copies of the captured frames, in OpenCV's BGR channel order.
        """
        cap = cv2.VideoCapture(camera_index)
        last_analysis_time = 0.0
        latest_analysis = "Initializing safety analysis..."
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                now = time.time()
                # Only call the model every analysis_interval seconds; other frames
                # reuse the previous result. NOTE(review): the call is synchronous,
                # so the stream pauses while a request is in flight.
                if now - last_analysis_time >= self.analysis_interval:
                    latest_analysis = self.analyze_frame(frame)
                    last_analysis_time = now
                yield self._annotate(frame, latest_analysis)
        finally:
            # Also runs when the consumer closes the generator early (e.g. a Stop
            # button), so the camera is never left open.
            cap.release()
def create_gradio_interface(monitor: SafetyMonitor, share: bool = True):
    """
    Create and launch the Gradio interface

    Args:
        monitor: SafetyMonitor whose video stream and analysis interval the UI controls.
        share: Forwarded to ``demo.launch``. Defaults to True, the previously
            hard-coded value, which requests a public Gradio share link to the
            live camera feed. Pass False to keep the app local.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Real-time Safety Monitoring System\n"
            f"Using model: {monitor.model_name}\n"
        )
        with gr.Row():
            video_output = gr.Image(label="Live Feed with Safety Analysis")
        with gr.Row():
            start_button = gr.Button("Start Monitoring", variant="primary")
            stop_button = gr.Button("Stop")
        with gr.Row():
            interval_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=monitor.analysis_interval,
                step=0.5,
                label="Analysis Interval (seconds)",
            )

        def update_interval(value):
            # process_video_stream reads monitor.analysis_interval on every frame,
            # so a running stream picks up the new value immediately.
            monitor.analysis_interval = value

        def start_monitoring():
            # A generator handler makes Gradio push each yielded frame to
            # video_output. Returning a generator inside gr.Image.update does not stream.
            stream = monitor.process_video_stream()
            try:
                for frame in stream:
                    # OpenCV frames are BGR; gr.Image interprets numpy arrays as RGB.
                    yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            finally:
                # Close the underlying generator promptly when this handler is cancelled.
                stream.close()

        start_event = start_button.click(fn=start_monitoring, outputs=[video_output])
        # Cancel the running stream (not just clear the image) when Stop is pressed.
        stop_button.click(fn=lambda: None, outputs=[video_output], cancels=[start_event])
        interval_slider.change(fn=update_interval, inputs=[interval_slider])

    # Generator (streaming) handlers and `cancels` both rely on the event queue.
    demo.queue()
    demo.launch(share=share)
def main():
    """Build a SafetyMonitor and launch the Gradio UI."""
    # Read the key from the environment instead of committing it to source.
    # The placeholder fallback keeps the previous behaviour when the variable is unset.
    groq_api_key = os.environ.get("GROQ_API_KEY", "YOUR_GROQ_API_KEY")
    # Initialize the safety monitor with desired model
    monitor = SafetyMonitor(
        api_key=groq_api_key,
        model_name="mixtral-8x7b-vision",  # Replace with your preferred model
    )
    # Launch the Gradio interface
    create_gradio_interface(monitor)


# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    main()