"""Gradio demo for the ``BEN_Base`` model from the ``PramaLLC/BEN`` Hugging Face repo.

On import, the checkpoint is downloaded (or reused from the ``./models`` cache)
and loaded into a ``BEN_Base`` instance placed on CUDA when available, otherwise
the CPU. The UI takes one image and shows the mask and foreground returned by
``model.inference``.
"""

from huggingface_hub import hf_hub_download
import gradio as gr
import torch

import model as ben_model

# ``gr.Info`` can only show a toast inside a running Gradio event; at import
# time there is no event, so startup progress is reported with plain print.
print("⏳ Downloading model from huggingface hub...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network in eval mode on the chosen device. The module-level name
# ``model`` refers to this instance; the local module is imported as
# ``ben_model`` so the instance no longer shadows it.
model = ben_model.BEN_Base().to(device).eval()

# Download the weights from Hugging Face. Files are cached under ./models so
# restarts can reuse the local copy instead of downloading again.
model_path = hf_hub_download(
    repo_id="PramaLLC/BEN",
    filename="BEN_Base.pth",
    cache_dir="./models",
)
print(f"✅ Model downloaded successfully to {model_path}")
model.loadcheckpoints(model_path)


def handler(input_image):
    """Run the model on one image and return ``[mask, foreground]``.

    Args:
        input_image: PIL image from the input component, or ``None`` when
            the user has not uploaded anything.

    Returns:
        ``[mask, foreground]`` as produced by ``model.inference``, in the
        order of the wired output components.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable
            message instead of a traceback from inside the model.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before clicking 'Process Image'.")
    gr.Info("🚀 Processing image...")
    # Inference only: disable autograd tracking. This is harmless if
    # ``inference`` already does so internally (not visible from this file).
    with torch.no_grad():
        mask, foreground = model.inference(input_image)
    gr.Info("✅ Image processing completed!")
    return [mask, foreground]


def create_interface():
    """Build the Gradio Blocks UI: one input image, mask/foreground outputs, and a button.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:
        with gr.Row():
            input_image = gr.Image(type="pil", label="Input Image")
            with gr.Column():
                output_mask = gr.Image(type="pil", label="Mask")
                output_foreground = gr.Image(type="pil", label="Foreground")
        submit_btn = gr.Button("Process Image")
        submit_btn.click(
            fn=handler,
            inputs=[input_image],
            outputs=[output_mask, output_foreground],
        )
    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()