# ben1/index.py
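"""Gradio demo for the PramaLLC/BEN background-removal model.

Downloads the BEN_Base checkpoint from the Hugging Face Hub, runs inference on an
uploaded image, and returns the predicted segmentation mask and the extracted foreground.

Run with: python index.py (Gradio serves the app at http://127.0.0.1:7860 by default).
"""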
from huggingface_hub import hf_hub_download
import gradio as gr
import model
import torch
print("⏳ Downloading model from the Hugging Face Hub...")  # gr.Info only displays inside event handlers, so plain print is used at module level
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.BEN_Base().to(device).eval()  # init pipeline (note: this rebinds the name 'model' from the module to the network instance)
# Download the model checkpoint from Hugging Face
model_path = hf_hub_download(
    repo_id="PramaLLC/BEN",
    filename="BEN_Base.pth",
    cache_dir="./models"  # cache directory, so the checkpoint is not re-downloaded
)
print(f"✅ Model downloaded successfully to {model_path}")
model.loadcheckpoints(model_path)  # load the downloaded weights into the network
def handler(input_image):
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    gr.Info("🚀 Processing image...")
    mask, foreground = model.inference(input_image)
    gr.Info("✅ Image processing completed!")
    return [mask, foreground]
# Build the Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            input_image = gr.Image(type="pil", label="Input Image")
            with gr.Column():
                output_mask = gr.Image(type="pil", label="Mask")
                output_foreground = gr.Image(type="pil", label="Foreground")
        submit_btn = gr.Button("Process Image")
        submit_btn.click(
            fn=handler,
            inputs=[input_image],
            outputs=[output_mask, output_foreground]
        )
    return demo
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
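# Note: this launches a local demo by default; demo.launch(share=True) would also create a
# temporary public Gradio link, and server_name="0.0.0.0" exposes it on the local network.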