"""Run the BEN_Base model on a single image and save its two outputs.

The script downloads the ``BEN_Base.pth`` checkpoint from the
``PramaLLC/BEN`` Hugging Face repository (cached under ``./models``), loads
it into a ``BEN_Base`` instance, runs ``inference`` on ``./image.png`` and
writes the two returned images to ``./mask.png`` and ``./foreground.png``.
"""

import torch
from huggingface_hub import hf_hub_download
from PIL import Image

import model

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Path of the input image to process.
file = "./image.png"

# Bind the instance to its own name so the imported ``model`` module is not
# shadowed and stays usable after this line.
ben_model = model.BEN_Base().to(device).eval()

# Fetch the checkpoint from the Hugging Face Hub; ``cache_dir`` keeps the
# downloaded file under ./models. The call returns the local file path.
model_path = hf_hub_download(
    repo_id="PramaLLC/BEN",
    filename="BEN_Base.pth",
    cache_dir="./models",
)
ben_model.loadcheckpoints(model_path)

image = Image.open(file)

# ``inference`` returns two objects that are saved as images below.
# NOTE(review): presumably PIL images (mask and cut-out foreground) - confirm
# against the ``model`` module.
mask, foreground = ben_model.inference(image)

# Write each output exactly once.
mask.save("./mask.png")
foreground.save("./foreground.png")