# import model
# from PIL import Image
# import torch

# def handler(event, context):
#     # logger = logging.getLogger()
#     # logger.info(event)
#     # return event
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     file = "./image.png"  # input image
#     model = model.BEN_Base().to(device).eval()  # init pipeline
#     model.loadcheckpoints("./BEN_Base.pth")

#     image = Image.open(file)
#     mask, foreground = model.inference(image)

#     mask.save("./mask.png")
#     foreground.save("./foreground.png")
#     return
from huggingface_hub import hf_hub_download
import os
import gradio as gr
import model
from PIL import Image
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.BEN_Base().to(device).eval() #init pipeline

# Download the model weights from Hugging Face
model_path = hf_hub_download(
    repo_id="PramaLLC/BEN",
    filename="BEN_Base.pth",
    cache_dir="./models"  # cache directory to avoid repeated downloads
)
model.loadcheckpoints(model_path)

def handler(input_image):
    # Run inference on the input image and return the mask and foreground
    mask, foreground = model.inference(input_image)
    return [mask, foreground]
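
# A minimal sketch of calling the handler directly, bypassing the Gradio UI.
# The file names below are illustrative only; this assumes an "image.png"
# exists in the working directory.
# img = Image.open("./image.png")
# mask, foreground = handler(img)
# mask.save("./mask.png")
# foreground.save("./foreground.png")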

# Build the Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            input_image = gr.Image(type="pil", label="Input Image")
            with gr.Column():
                output_mask = gr.Image(type="pil", label="Mask")
                output_foreground = gr.Image(type="pil", label="Foreground")
        submit_btn = gr.Button("Process Image")
        submit_btn.click(
            fn=handler,
            inputs=[input_image],
            outputs=[output_mask, output_foreground]
        )
    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)