qinzheng_wang
commited on
Commit
•
3771205
1
Parent(s):
838840e
fix: bug
Browse files
index.py
CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
|
|
4 |
import model
|
5 |
import torch
|
6 |
|
7 |
-
gr.
|
8 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
9 |
model = model.BEN_Base().to(device).eval() #init pipeline
|
10 |
|
@@ -14,15 +14,15 @@ model_path = hf_hub_download(
|
|
14 |
filename="BEN_Base.pth",
|
15 |
cache_dir="./models" # 缓存目录,避免重复下载
|
16 |
)
|
17 |
-
gr.
|
18 |
model.loadcheckpoints(model_path)
|
19 |
|
20 |
|
21 |
|
22 |
def handler(input_image):
|
23 |
-
gr.
|
24 |
mask, foreground = model.inference(input_image)
|
25 |
-
gr.
|
26 |
|
27 |
return [mask, foreground]
|
28 |
|
|
|
4 |
import model
|
5 |
import torch
|
6 |
|
7 |
+
gr.Info(f"⏳ Downloading model from huggingface hub...")
|
8 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
9 |
model = model.BEN_Base().to(device).eval() #init pipeline
|
10 |
|
|
|
14 |
filename="BEN_Base.pth",
|
15 |
cache_dir="./models" # 缓存目录,避免重复下载
|
16 |
)
|
17 |
+
gr.Info(f"✅ Model downloaded successfully to {model_path}")
|
18 |
model.loadcheckpoints(model_path)
|
19 |
|
20 |
|
21 |
|
22 |
def handler(input_image):
|
23 |
+
gr.Info("🚀 Processing image...")
|
24 |
mask, foreground = model.inference(input_image)
|
25 |
+
gr.Info("✅ Image processing completed!")
|
26 |
|
27 |
return [mask, foreground]
|
28 |
|