Spaces:
Runtime error
Runtime error
from typing import Any | |
import torch | |
from huggingface_hub import hf_hub_download | |
from PIL import Image | |
import torchvision.transforms as transforms | |
import numpy as np | |
class DeepLabv3Pipeline:
    """Segmentation pipeline that whites out pixels the model scores at or below a threshold.

    The pipeline resizes an input image to 336x336, feeds it to a model loaded
    with ``torch.load`` and returns the resized image in which every position
    whose score is not above the threshold is replaced by white.
    """

    def __init__(self):
        """Select the compute device and build the preprocessing transform."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.transforms = transforms.Compose(
            [
                transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.NEAREST),
                transforms.ToTensor(),
            ]
        )
        # Populated by ``from_pretrained``; stays ``None`` until weights are loaded.
        self.model = None

    def _load_model(self, weights_path):
        """Load a pickled model from ``weights_path``, move it to the device, set eval mode."""
        # SECURITY: ``torch.load`` unpickles the file, which can execute arbitrary
        # code. Only load checkpoints from sources you trust.
        # NOTE(review): recent PyTorch releases default to ``weights_only=True``,
        # which rejects fully pickled modules -- confirm against the installed version.
        self.model = torch.load(weights_path, map_location='cpu')
        self.model.to(self.device)
        self.model.eval()

    @classmethod
    def from_pretrained(cls, model_path_hf: str = None, filename_hf: str = "weights.pt", local_model_path: str = None):
        """Create a pipeline and load its model from the Hub or from a local file.

        Args:
            model_path_hf: Hugging Face Hub repository id. Takes precedence over
                ``local_model_path`` when given together with ``filename_hf``.
            filename_hf: Name of the checkpoint file inside the Hub repository.
            local_model_path: Path to a checkpoint on the local filesystem; used
                only when no Hub location is given.

        Returns:
            A new pipeline instance. When neither a Hub location nor a local
            path is supplied, the instance is returned without a model.
        """
        dl = cls()
        if model_path_hf is not None and filename_hf is not None:
            dl._load_model(hf_hub_download(model_path_hf, filename=filename_hf))
        elif local_model_path is not None:
            dl._load_model(local_model_path)
        return dl

    def __call__(self, image: Image.Image, threshold: float = 0) -> Image.Image:
        """Return the resized image with below-threshold regions replaced by white.

        Args:
            image: Input image; converted to RGB before preprocessing.
            threshold: Model scores strictly greater than this value keep the
                original pixel; all other positions become white.

        Returns:
            An 8-bit image built from the 336x336 preprocessed input.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        if self.model is None:
            raise RuntimeError(
                "No model loaded; create the pipeline with "
                "DeepLabv3Pipeline.from_pretrained(...) first."
            )

        # Preprocess once and reuse the tensor for both inference and compositing.
        pixels = self.transforms(image.convert("RGB"))

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            output = self.model(pixels.unsqueeze(0).to(self.device))

        # NOTE(review): assumes the model returns a mapping with an 'out' entry
        # of shape (batch, channels, H, W) whose channel count broadcasts against
        # the 3 RGB channels -- confirm against the checkpoint in use.
        scores = output['out'][0].permute(1, 2, 0).detach().cpu()
        keep = (scores > threshold).numpy()

        # Channels-last float image in [0, 1]; rejected positions get 1 (white).
        composed = np.where(keep, pixels.permute(1, 2, 0).numpy(), 1)
        return Image.fromarray((255 * composed).astype(np.uint8))