from huggingface_hub import hf_hub_download
import gradio as gr
import model
import torch

gr.Info("⏳ Downloading model from huggingface hub...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.BEN_Base().to(device).eval()  # initialize the model (note: shadows the imported module)

# Download the model weights from Hugging Face
model_path = hf_hub_download(
    repo_id="PramaLLC/BEN",
    filename="BEN_Base.pth",
    cache_dir="./models"  # cache directory, avoids re-downloading on later runs
)
gr.Info(f"✅ Model downloaded successfully to {model_path}")
model.loadcheckpoints(model_path)
def handler(input_image):
    gr.Info("🚀 Processing image...")
    mask, foreground = model.inference(input_image)
    gr.Info("✅ Image processing completed!")
    return [mask, foreground]
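# Note: `model.inference` can also be invoked directly on a PIL image outside the
# UI, mirroring the handler above (illustrative sketch; "example.png" is a
# hypothetical path, not part of this app):
#     from PIL import Image
#     mask, foreground = model.inference(Image.open("example.png"))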
# Build the Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            input_image = gr.Image(type="pil", label="Input Image")
            with gr.Column():
                output_mask = gr.Image(type="pil", label="Mask")
                output_foreground = gr.Image(type="pil", label="Foreground")
        submit_btn = gr.Button("Process Image")
        submit_btn.click(
            fn=handler,
            inputs=[input_image],
            outputs=[output_mask, output_foreground]
        )
    return demo
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
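# Note: `launch()` accepts standard Gradio options; a possible variant for a
# queued, publicly shared deployment (values are illustrative, not from this app):
#     demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)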