# ben1 / index.py
# Author: showhype. Last commit: 565830f (verified), "Update index.py". Size: 1.83 kB.
# import model
# from PIL import Image
# import torch
# def handler(event, context):
# # logger = logging.getLogger()
# # logger.info(event)
# # return event
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# file = "./image.png" # input image
# model = model.BEN_Base().to(device).eval() #init pipeline
# model.loadcheckpoints("./BEN_Base.pth")
# image = Image.open(file)
# mask, foreground = model.inference(image)
# mask.save("./mask.png")
# foreground.save("./foreground.png")
# return
from huggingface_hub import hf_hub_download
import os
import gradio as gr
import model
from PIL import Image
import torch
# Use the GPU when PyTorch reports CUDA is available, otherwise run on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialise the pipeline: build the BEN_Base network, move it to `device`
# and switch it to eval mode.
# NOTE(review): this rebinds the name `model` from the imported `model`
# module to the network instance; after this line the module is no longer
# reachable under that name. `handler` relies on `model` being the instance.
model = model.BEN_Base().to(device).eval()
# Download the model weights from Hugging Face.
model_path = hf_hub_download(
    repo_id="PramaLLC/BEN",
    filename="BEN_Base.pth",
    cache_dir="./models"  # cache directory, so the weights are not downloaded again on every start
)
# Load the downloaded checkpoint into the network instance.
# NOTE(review): weights are loaded after `.to(device)`; whether
# `loadcheckpoints` places the tensors on `device` is decided in model.py — confirm there.
model.loadcheckpoints(model_path)
def handler(input_image):
    """Run BEN on one image for the Gradio "Process Image" click event.

    Args:
        input_image: The PIL image delivered by the ``gr.Image(type="pil")``
            input component, or ``None`` when the user clicked the button
            without uploading an image.

    Returns:
        A two-element list ``[mask, foreground]``, one entry per output
        component. When no image was supplied, both entries are ``None`` so the
        outputs are cleared instead of calling the model with nothing to
        process.
    """
    # Gradio passes None for an empty image input; there is nothing to
    # segment, so skip the model call rather than hand None to inference.
    if input_image is None:
        return [None, None]
    # Process the input image with the module-level BEN model instance.
    mask, foreground = model.inference(input_image)
    return [mask, foreground]
# 创建 Gradio 界面
def create_interface():
    """Build and return the Gradio Blocks app for the BEN demo.

    The app holds one input image, two output images (the mask and the
    extracted foreground) and a button that wires ``handler`` from the input
    to both outputs.
    """
    with gr.Blocks() as app:
        with gr.Row():
            source = gr.Image(type="pil", label="Input Image")
            with gr.Column():
                mask_view = gr.Image(type="pil", label="Mask")
                foreground_view = gr.Image(type="pil", label="Foreground")
        run_button = gr.Button("Process Image")
        # One click runs handler on the input and fills mask, then foreground.
        run_button.click(
            fn=handler,
            inputs=[source],
            outputs=[mask_view, foreground_view],
        )
    return app
if __name__ == "__main__":
    # Script entry point: build the Blocks app and start serving it.
    # share=True additionally requests a public Gradio share link
    # (see Gradio's `launch` documentation).
    create_interface().launch(share=True)