|
from huggingface_hub import hf_hub_download |
|
|
|
import gradio as gr |
|
import model |
|
import torch |
|
|
|
# --- one-time model setup (runs at import time) ---
# NOTE(review): gr.Info() outside a Gradio event handler runs before any UI
# exists, so these import-time toasts cannot be displayed to a user —
# consider plain print()/logging here; left as gr.Info pending confirmation.
gr.Info("⏳ Downloading model from huggingface hub...")  # fixed: was an f-string with no placeholders

# Prefer GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate the BEN network on the chosen device in inference mode.
# NOTE: this rebinds the name `model` from the imported module to the
# network instance; all later uses (e.g. handler) refer to the instance.
model = model.BEN_Base().to(device).eval()

# Fetch the pretrained weights from the Hugging Face Hub (cached under ./models,
# so repeat launches reuse the local copy instead of re-downloading).
model_path = hf_hub_download(
    repo_id="PramaLLC/BEN",
    filename="BEN_Base.pth",
    cache_dir="./models",
)

gr.Info(f"✅ Model downloaded successfully to {model_path}")

# Load the downloaded checkpoint into the network.
model.loadcheckpoints(model_path)
|
|
|
|
|
|
|
def handler(input_image):
    """Run BEN inference on *input_image*.

    Returns a two-element list ``[mask, foreground]`` as produced by
    ``model.inference`` — the segmentation mask and the extracted foreground.
    """
    gr.Info("🚀 Processing image...")
    seg_mask, fg_image = model.inference(input_image)
    gr.Info("✅ Image processing completed!")
    return [seg_mask, fg_image]
|
|
|
|
|
|
|
def create_interface():
    """Assemble the Gradio Blocks UI.

    Layout: a row holding the input image next to a column with the two
    outputs (mask + foreground), followed by a submit button wired to
    ``handler``. Returns the Blocks object ready to ``launch()``.
    """
    with gr.Blocks() as ui:
        with gr.Row():
            src_img = gr.Image(type="pil", label="Input Image")
            with gr.Column():
                mask_out = gr.Image(type="pil", label="Mask")
                fg_out = gr.Image(type="pil", label="Foreground")

        run_btn = gr.Button("Process Image")
        run_btn.click(
            fn=handler,
            inputs=[src_img],
            outputs=[mask_out, fg_out],
        )

    return ui
|
|
|
if __name__ == "__main__":
    # Build the UI and start the local Gradio server.
    create_interface().launch()
|
|
|
|