Spaces:
Sleeping
Sleeping
Create main.py
Browse files
main.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
from PIL import Image
|
6 |
+
import shutil
|
7 |
+
from ultralytics import YOLO
|
8 |
+
|
9 |
+
def load_models(models_dir='models', info_file='models_info.json'):
    """
    Build a registry of YOLO models described by a JSON info file.

    Args:
        models_dir (str): Directory holding one sub-directory per model.
        info_file (str): JSON file listing model metadata entries.

    Returns:
        dict: Maps model name -> {'model': YOLO, 'mAP': ..., 'num_images': ...}
              for every model whose weights loaded successfully.
    """
    with open(info_file, 'r') as handle:
        entries = json.load(handle)

    registry = {}
    for entry in entries:
        name = entry['model_name']
        # Each model directory is expected to contain a 'best.pt' checkpoint.
        weight_path = os.path.join(models_dir, name, 'best.pt')
        if not os.path.isfile(weight_path):
            print(f"Model weight file for '{name}' not found at '{weight_path}'. Skipping.")
            continue
        try:
            registry[name] = {
                'model': YOLO(weight_path),
                'mAP': entry.get('mAP_score', 'N/A'),
                'num_images': entry.get('num_images', 'N/A'),
            }
            print(f"Loaded model '{name}' from '{weight_path}'.")
        except Exception as e:
            print(f"Error loading model '{name}': {e}")
    return registry
|
42 |
+
|
43 |
+
def get_model_info(model_name, models):
    """
    Format a Markdown summary of the selected model's metadata.

    Args:
        model_name (str): The name of the model.
        models (dict): Registry mapping model names to their info dicts.

    Returns:
        str: Markdown-formatted model information, or a fallback message
             when the model is unknown.
    """
    info = models.get(model_name, {})
    if not info:
        return "Model information not available."
    sections = [
        f"**Model Name:** {model_name}",
        f"**mAP Score:** {info.get('mAP', 'N/A')}",
        f"**Number of Images Trained On:** {info.get('num_images', 'N/A')}",
    ]
    return "\n\n".join(sections)
|
63 |
+
|
64 |
+
def predict_image(model_name, image, models):
    """
    Perform prediction on an uploaded image using the selected YOLO model.

    Args:
        model_name (str): The name of the selected model.
        image (PIL.Image.Image): The uploaded image.
        models (dict): The dictionary containing models and their info.

    Returns:
        tuple: A status message, the processed image (or None on error),
               and the path to the output image (or None on error).
    """
    model = models.get(model_name, {}).get('model', None)
    if not model:
        return "Error: Model not found.", None, None
    try:
        # Save the uploaded image to a temporary path.
        input_image_path = f"temp/{model_name}_input_image.jpg"
        os.makedirs(os.path.dirname(input_image_path), exist_ok=True)
        # JPEG cannot store an alpha channel; normalize to RGB so PNG
        # uploads with transparency don't crash the save.
        image.convert('RGB').save(input_image_path)

        # Perform prediction; conf=0.25 drops low-confidence boxes.
        results = model(input_image_path, save=True, save_txt=False, conf=0.25)
        # BUGFIX: Results.save() returns the annotated image's file path as a
        # single string, not a list — the previous `.save()[0]` took only the
        # first character of the path, breaking Image.open below.
        output_image_path = results[0].save()

        # Open the annotated output image.
        output_image = Image.open(output_image_path)

        return "Prediction completed successfully.", output_image, output_image_path
    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None
|
96 |
+
|
97 |
+
def predict_video(model_name, video, models):
    """
    Perform prediction on an uploaded video using the selected YOLO model.

    Args:
        model_name (str): The name of the selected model.
        video (str | file-like): Path to the uploaded video file, or a
            tempfile-like object with a ``.name`` attribute (older Gradio).
        models (dict): The dictionary containing models and their info.

    Returns:
        tuple: A status message, the processed video path (or None on error),
               and the path to the output video (duplicated for the download
               component).
    """
    model = models.get(model_name, {}).get('model', None)
    if not model:
        return "Error: Model not found.", None, None
    try:
        # BUGFIX: modern Gradio passes the upload as a plain filepath string,
        # so `video.name` raised AttributeError; accept both forms. The old
        # `shutil.copy(video.name, input_video_path)` also copied a path onto
        # itself and has been removed.
        input_video_path = video if isinstance(video, str) else video.name
        if not os.path.isfile(input_video_path):
            return f"Error: Input video not found at '{input_video_path}'.", None, None

        # Perform prediction; save=True writes the annotated video into the
        # run's save directory (e.g. 'runs/detect/predict').
        results = model(input_video_path, save=True, save_txt=False, conf=0.25)
        # BUGFIX: Results.save() returns a single path string (and saves one
        # annotated frame, not the video) — `.save()[0]` yielded only its
        # first character. The saved video keeps the input's base name
        # inside results[0].save_dir.
        output_video_path = os.path.join(
            str(results[0].save_dir), os.path.basename(input_video_path)
        )

        return "Prediction completed successfully.", output_video_path, output_video_path
    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None
|
127 |
+
|
128 |
+
def main():
    """
    Build and launch the Gradio UI for testing YOLO models.

    Loads every model described by ``models_info.json`` via ``load_models``,
    then wires up a Blocks interface: a model dropdown with a live info
    panel, plus separate tabs for image and video prediction.
    """
    # Load the models and their information (model name -> {'model', 'mAP', 'num_images'}).
    models = load_models()

    # Initialize Gradio Blocks interface. Component creation order inside
    # the context managers defines the on-screen layout.
    with gr.Blocks() as demo:
        gr.Markdown("# 🧪 YOLO Model Tester")

        gr.Markdown(
            """
            Upload images or videos to test different YOLO models. Select a model from the dropdown to see its details.
            """
        )

        # Model selection and info panel, side by side.
        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=list(models.keys()),
                label="Select Model",
                value=None
            )
            model_info = gr.Markdown("**Model Information will appear here.**")

        # Update model_info whenever a model is selected; the lambda closes
        # over `models` so the handler only receives the dropdown value.
        model_dropdown.change(
            fn=lambda model_name: get_model_info(model_name, models) if model_name else "Please select a model.",
            inputs=model_dropdown,
            outputs=model_info
        )

        # Tabs for different input types.
        with gr.Tabs():
            # Image Prediction Tab
            with gr.Tab("🖼️ Image"):
                with gr.Column():
                    image_input = gr.Image(
                        type='pil',
                        label="Upload Image for Prediction",
                        # NOTE(review): the 'tool' parameter only exists in
                        # Gradio 3.x and was removed in 4.x — confirm the
                        # pinned gradio version before upgrading.
                        tool="editor"
                    )
                    image_predict_btn = gr.Button("🔍 Predict on Image")
                    image_status = gr.Markdown("**Status will appear here.**")
                    image_output = gr.Image(label="Predicted Image")
                    image_download_btn = gr.File(label="⬇️ Download Predicted Image")

                    # Thin wrapper so the click handler only takes UI values;
                    # `models` is captured from the enclosing scope.
                    def process_image(model_name, image):
                        return predict_image(model_name, image, models)

                    # Connect the predict button to (status, image, download).
                    image_predict_btn.click(
                        fn=process_image,
                        inputs=[model_dropdown, image_input],
                        outputs=[image_status, image_output, image_download_btn]
                    )

            # Video Prediction Tab
            with gr.Tab("🎥 Video"):
                with gr.Column():
                    video_input = gr.Video(
                        label="Upload Video for Prediction"
                    )
                    video_predict_btn = gr.Button("🔍 Predict on Video")
                    video_status = gr.Markdown("**Status will appear here.**")
                    video_output = gr.Video(label="Predicted Video")
                    video_download_btn = gr.File(label="⬇️ Download Predicted Video")

                    # Same wrapper pattern as the image tab.
                    def process_video(model_name, video):
                        return predict_video(model_name, video, models)

                    # Connect the predict button to (status, video, download).
                    video_predict_btn.click(
                        fn=process_video,
                        inputs=[model_dropdown, video_input],
                        outputs=[video_status, video_output, video_download_btn]
                    )

        gr.Markdown(
            """
            ---
            **Note:** Ensure that the YOLO models are correctly placed in the `models/` directory and that `models_info.json` is properly configured.
            """
        )

    # Launch the Gradio app (blocks until the server stops).
    demo.launch()
|
215 |
+
|
216 |
+
if __name__ == "__main__":
    # Start the app only when run as a script, not when imported.
    main()
|