wuhp commited on
Commit
b86b824
·
verified ·
1 Parent(s): 798cc07

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +217 -0
main.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import os
4
+ from pathlib import Path
5
+ from PIL import Image
6
+ import shutil
7
+ from ultralytics import YOLO
8
+
9
+ def load_models(models_dir='models', info_file='models_info.json'):
10
+ """
11
+ Load YOLO models and their information from the specified directory and JSON file.
12
+
13
+ Args:
14
+ models_dir (str): Path to the models directory.
15
+ info_file (str): Path to the JSON file containing model info.
16
+
17
+ Returns:
18
+ dict: A dictionary of models and their associated information.
19
+ """
20
+ with open(info_file, 'r') as f:
21
+ models_info = json.load(f)
22
+
23
+ models = {}
24
+ for model_info in models_info:
25
+ model_name = model_info['model_name']
26
+ model_path = os.path.join(models_dir, model_name, 'best.pt') # Assuming 'best.pt' as the weight file
27
+ if os.path.isfile(model_path):
28
+ try:
29
+ # Load the YOLO model
30
+ model = YOLO(model_path)
31
+ models[model_name] = {
32
+ 'model': model,
33
+ 'mAP': model_info.get('mAP_score', 'N/A'),
34
+ 'num_images': model_info.get('num_images', 'N/A')
35
+ }
36
+ print(f"Loaded model '{model_name}' from '{model_path}'.")
37
+ except Exception as e:
38
+ print(f"Error loading model '{model_name}': {e}")
39
+ else:
40
+ print(f"Model weight file for '{model_name}' not found at '{model_path}'. Skipping.")
41
+ return models
42
+
43
+ def get_model_info(model_name, models):
44
+ """
45
+ Retrieve model information for the selected model.
46
+
47
+ Args:
48
+ model_name (str): The name of the model.
49
+ models (dict): The dictionary containing models and their info.
50
+
51
+ Returns:
52
+ str: A formatted string containing model information.
53
+ """
54
+ model_info = models.get(model_name, {})
55
+ if not model_info:
56
+ return "Model information not available."
57
+ info_text = (
58
+ f"**Model Name:** {model_name}\n\n"
59
+ f"**mAP Score:** {model_info.get('mAP', 'N/A')}\n\n"
60
+ f"**Number of Images Trained On:** {model_info.get('num_images', 'N/A')}"
61
+ )
62
+ return info_text
63
+
64
+ def predict_image(model_name, image, models):
65
+ """
66
+ Perform prediction on an uploaded image using the selected YOLO model.
67
+
68
+ Args:
69
+ model_name (str): The name of the selected model.
70
+ image (PIL.Image.Image): The uploaded image.
71
+ models (dict): The dictionary containing models and their info.
72
+
73
+ Returns:
74
+ tuple: A status message, the processed image, and the path to the output image.
75
+ """
76
+ model = models.get(model_name, {}).get('model', None)
77
+ if not model:
78
+ return "Error: Model not found.", None, None
79
+ try:
80
+ # Save the uploaded image to a temporary path
81
+ input_image_path = f"temp/{model_name}_input_image.jpg"
82
+ os.makedirs(os.path.dirname(input_image_path), exist_ok=True)
83
+ image.save(input_image_path)
84
+
85
+ # Perform prediction
86
+ results = model(input_image_path, save=True, save_txt=False, conf=0.25)
87
+ # Ultralytics saves the result images in 'runs/detect/predict'
88
+ output_image_path = results[0].save()[0] # Get the path to the saved image
89
+
90
+ # Open the output image
91
+ output_image = Image.open(output_image_path)
92
+
93
+ return "Prediction completed successfully.", output_image, output_image_path
94
+ except Exception as e:
95
+ return f"Error during prediction: {str(e)}", None, None
96
+
97
+ def predict_video(model_name, video, models):
98
+ """
99
+ Perform prediction on an uploaded video using the selected YOLO model.
100
+
101
+ Args:
102
+ model_name (str): The name of the selected model.
103
+ video (str): Path to the uploaded video file.
104
+ models (dict): The dictionary containing models and their info.
105
+
106
+ Returns:
107
+ tuple: A status message, the processed video, and the path to the output video.
108
+ """
109
+ model = models.get(model_name, {}).get('model', None)
110
+ if not model:
111
+ return "Error: Model not found.", None, None
112
+ try:
113
+ # Ensure the video is saved in a temporary location
114
+ input_video_path = video.name
115
+ if not os.path.isfile(input_video_path):
116
+ # If the video is a temp file provided by Gradio
117
+ shutil.copy(video.name, input_video_path)
118
+
119
+ # Perform prediction
120
+ results = model(input_video_path, save=True, save_txt=False, conf=0.25)
121
+ # Ultralytics saves the result videos in 'runs/detect/predict'
122
+ output_video_path = results[0].save()[0] # Get the path to the saved video
123
+
124
+ return "Prediction completed successfully.", output_video_path, output_video_path
125
+ except Exception as e:
126
+ return f"Error during prediction: {str(e)}", None, None
127
+
128
+ def main():
129
+ # Load the models and their information
130
+ models = load_models()
131
+
132
+ # Initialize Gradio Blocks interface
133
+ with gr.Blocks() as demo:
134
+ gr.Markdown("# 🧪 YOLO Model Tester")
135
+
136
+ gr.Markdown(
137
+ """
138
+ Upload images or videos to test different YOLO models. Select a model from the dropdown to see its details.
139
+ """
140
+ )
141
+
142
+ # Model selection and info
143
+ with gr.Row():
144
+ model_dropdown = gr.Dropdown(
145
+ choices=list(models.keys()),
146
+ label="Select Model",
147
+ value=None
148
+ )
149
+ model_info = gr.Markdown("**Model Information will appear here.**")
150
+
151
+ # Update model_info when a model is selected
152
+ model_dropdown.change(
153
+ fn=lambda model_name: get_model_info(model_name, models) if model_name else "Please select a model.",
154
+ inputs=model_dropdown,
155
+ outputs=model_info
156
+ )
157
+
158
+ # Tabs for different input types
159
+ with gr.Tabs():
160
+ # Image Prediction Tab
161
+ with gr.Tab("🖼️ Image"):
162
+ with gr.Column():
163
+ image_input = gr.Image(
164
+ type='pil',
165
+ label="Upload Image for Prediction",
166
+ tool="editor"
167
+ )
168
+ image_predict_btn = gr.Button("🔍 Predict on Image")
169
+ image_status = gr.Markdown("**Status will appear here.**")
170
+ image_output = gr.Image(label="Predicted Image")
171
+ image_download_btn = gr.File(label="⬇️ Download Predicted Image")
172
+
173
+ # Define the image prediction function
174
+ def process_image(model_name, image):
175
+ return predict_image(model_name, image, models)
176
+
177
+ # Connect the predict button
178
+ image_predict_btn.click(
179
+ fn=process_image,
180
+ inputs=[model_dropdown, image_input],
181
+ outputs=[image_status, image_output, image_download_btn]
182
+ )
183
+
184
+ # Video Prediction Tab
185
+ with gr.Tab("🎥 Video"):
186
+ with gr.Column():
187
+ video_input = gr.Video(
188
+ label="Upload Video for Prediction"
189
+ )
190
+ video_predict_btn = gr.Button("🔍 Predict on Video")
191
+ video_status = gr.Markdown("**Status will appear here.**")
192
+ video_output = gr.Video(label="Predicted Video")
193
+ video_download_btn = gr.File(label="⬇️ Download Predicted Video")
194
+
195
+ # Define the video prediction function
196
+ def process_video(model_name, video):
197
+ return predict_video(model_name, video, models)
198
+
199
+ # Connect the predict button
200
+ video_predict_btn.click(
201
+ fn=process_video,
202
+ inputs=[model_dropdown, video_input],
203
+ outputs=[video_status, video_output, video_download_btn]
204
+ )
205
+
206
+ gr.Markdown(
207
+ """
208
+ ---
209
+ **Note:** Ensure that the YOLO models are correctly placed in the `models/` directory and that `models_info.json` is properly configured.
210
+ """
211
+ )
212
+
213
+ # Launch the Gradio app
214
+ demo.launch()
215
+
216
+ if __name__ == "__main__":
217
+ main()