wuhp commited on
Commit
b9b01da
·
verified ·
1 Parent(s): 6ff60a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -12
app.py CHANGED
@@ -117,13 +117,14 @@ def get_model_info(model_info):
117
  )
118
  return info_text
119
 
120
- def predict_image(model_name, image, models):
121
  """
122
  Perform prediction on an uploaded image using the selected YOLO model.
123
 
124
  Args:
125
  model_name (str): The name of the selected model.
126
  image (PIL.Image.Image): The uploaded image.
 
127
  models (dict): The dictionary containing models and their info.
128
 
129
  Returns:
@@ -142,8 +143,8 @@ def predict_image(model_name, image, models):
142
  input_image_path = os.path.join(TEMP_DIR, f"{model_name}_input_image.jpg")
143
  image.save(input_image_path)
144
 
145
- # Perform prediction
146
- results = model(input_image_path, save=True, save_txt=False, conf=0.25)
147
 
148
  # Determine the output path
149
  # Ultralytics YOLO saves the results in 'runs/detect/predict' by default
@@ -165,13 +166,14 @@ def predict_image(model_name, image, models):
165
  except Exception as e:
166
  return f"❌ Error during prediction: {str(e)}", None, None
167
 
168
- def predict_video(model_name, video, models):
169
  """
170
  Perform prediction on an uploaded video using the selected YOLO model.
171
 
172
  Args:
173
  model_name (str): The name of the selected model.
174
  video (str): Path to the uploaded video file.
 
175
  models (dict): The dictionary containing models and their info.
176
 
177
  Returns:
@@ -190,8 +192,8 @@ def predict_video(model_name, video, models):
190
  input_video_path = os.path.join(TEMP_DIR, f"{model_name}_input_video.mp4")
191
  shutil.copy(video, input_video_path)
192
 
193
- # Perform prediction
194
- results = model(input_video_path, save=True, save_txt=False, conf=0.25)
195
 
196
  # Determine the output path
197
  # Ultralytics YOLO saves the results in 'runs/detect/predict' by default
@@ -253,6 +255,17 @@ def main():
253
  outputs=model_info
254
  )
255
 
 
 
 
 
 
 
 
 
 
 
 
256
  # Tabs for different input types
257
  with gr.Tabs():
258
  # Image Prediction Tab
@@ -269,16 +282,16 @@ def main():
269
  image_download_btn = gr.File(label="⬇️ Download Predicted Image")
270
 
271
  # Define the image prediction function
272
- def process_image(selected_display_name, image):
273
  if not selected_display_name:
274
  return "❌ Please select a model.", None, None
275
  model_name = display_to_name.get(selected_display_name)
276
- return predict_image(model_name, image, models)
277
 
278
  # Connect the predict button
279
  image_predict_btn.click(
280
  fn=process_image,
281
- inputs=[model_dropdown, image_input],
282
  outputs=[image_status, image_output, image_download_btn]
283
  )
284
 
@@ -294,16 +307,16 @@ def main():
294
  video_download_btn = gr.File(label="⬇️ Download Predicted Video")
295
 
296
  # Define the video prediction function
297
- def process_video(selected_display_name, video):
298
  if not selected_display_name:
299
  return "❌ Please select a model.", None, None
300
  model_name = display_to_name.get(selected_display_name)
301
- return predict_video(model_name, video, models)
302
 
303
  # Connect the predict button
304
  video_predict_btn.click(
305
  fn=process_video,
306
- inputs=[model_dropdown, video_input],
307
  outputs=[video_status, video_output, video_download_btn]
308
  )
309
 
 
117
  )
118
  return info_text
119
 
120
+ def predict_image(model_name, image, confidence, models):
121
  """
122
  Perform prediction on an uploaded image using the selected YOLO model.
123
 
124
  Args:
125
  model_name (str): The name of the selected model.
126
  image (PIL.Image.Image): The uploaded image.
127
+ confidence (float): The confidence threshold for detections.
128
  models (dict): The dictionary containing models and their info.
129
 
130
  Returns:
 
143
  input_image_path = os.path.join(TEMP_DIR, f"{model_name}_input_image.jpg")
144
  image.save(input_image_path)
145
 
146
+ # Perform prediction with user-specified confidence
147
+ results = model(input_image_path, save=True, save_txt=False, conf=confidence)
148
 
149
  # Determine the output path
150
  # Ultralytics YOLO saves the results in 'runs/detect/predict' by default
 
166
  except Exception as e:
167
  return f"❌ Error during prediction: {str(e)}", None, None
168
 
169
+ def predict_video(model_name, video, confidence, models):
170
  """
171
  Perform prediction on an uploaded video using the selected YOLO model.
172
 
173
  Args:
174
  model_name (str): The name of the selected model.
175
  video (str): Path to the uploaded video file.
176
+ confidence (float): The confidence threshold for detections.
177
  models (dict): The dictionary containing models and their info.
178
 
179
  Returns:
 
192
  input_video_path = os.path.join(TEMP_DIR, f"{model_name}_input_video.mp4")
193
  shutil.copy(video, input_video_path)
194
 
195
+ # Perform prediction with user-specified confidence
196
+ results = model(input_video_path, save=True, save_txt=False, conf=confidence)
197
 
198
  # Determine the output path
199
  # Ultralytics YOLO saves the results in 'runs/detect/predict' by default
 
255
  outputs=model_info
256
  )
257
 
258
+ # Confidence Threshold Slider
259
+ with gr.Row():
260
+ confidence_slider = gr.Slider(
261
+ minimum=0.0,
262
+ maximum=1.0,
263
+ step=0.01,
264
+ value=0.25,
265
+ label="Confidence Threshold",
266
+ info="Adjust the minimum confidence required for detections to be displayed."
267
+ )
268
+
269
  # Tabs for different input types
270
  with gr.Tabs():
271
  # Image Prediction Tab
 
282
  image_download_btn = gr.File(label="⬇️ Download Predicted Image")
283
 
284
  # Define the image prediction function
285
+ def process_image(selected_display_name, image, confidence):
286
  if not selected_display_name:
287
  return "❌ Please select a model.", None, None
288
  model_name = display_to_name.get(selected_display_name)
289
+ return predict_image(model_name, image, confidence, models)
290
 
291
  # Connect the predict button
292
  image_predict_btn.click(
293
  fn=process_image,
294
+ inputs=[model_dropdown, image_input, confidence_slider],
295
  outputs=[image_status, image_output, image_download_btn]
296
  )
297
 
 
307
  video_download_btn = gr.File(label="⬇️ Download Predicted Video")
308
 
309
  # Define the video prediction function
310
+ def process_video(selected_display_name, video, confidence):
311
  if not selected_display_name:
312
  return "❌ Please select a model.", None, None
313
  model_name = display_to_name.get(selected_display_name)
314
+ return predict_video(model_name, video, confidence, models)
315
 
316
  # Connect the predict button
317
  video_predict_btn.click(
318
  fn=process_video,
319
+ inputs=[model_dropdown, video_input, confidence_slider],
320
  outputs=[video_status, video_output, video_download_btn]
321
  )
322