Spaces:
Runtime error
Runtime error
File size: 1,544 Bytes
fd52b7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from typing import Any
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
class DeepLabv3Pipeline:
    """Mask an image with the output of a pickled segmentation model.

    The model is loaded with ``torch.load`` (see :meth:`from_pretrained`),
    either from a Hugging Face Hub repository or from a local file. Calling
    the pipeline on a PIL image returns the image resized to 336x336 in which
    every pixel whose model score is not above ``threshold`` is painted white.
    """

    def __init__(self):
        # Prefer the GPU when torch can see one, otherwise fall back to CPU.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Preprocessing: fixed 336x336 resize, then conversion to a float
        # tensor (ToTensor scales 8-bit images to [0, 1]).
        self.transforms = transforms.Compose(
            [
                transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.NEAREST),
                transforms.ToTensor(),
            ]
        )
        # Set by from_pretrained(); stays None until a model has been loaded.
        self.model = None

    @staticmethod
    def _load_model(path: str, device: str):
        """Unpickle a full model from ``path``, move it to ``device`` and set eval mode."""
        import inspect

        # SECURITY: the checkpoint is a pickled nn.Module, so loading it executes
        # arbitrary code from the file. Only load checkpoints you trust.
        # torch >= 2.6 defaults to weights_only=True, which refuses full-module
        # pickles, so opt out explicitly when the keyword exists (torch >= 1.13).
        if 'weights_only' in inspect.signature(torch.load).parameters:
            model = torch.load(path, map_location='cpu', weights_only=False)
        else:
            model = torch.load(path, map_location='cpu')
        model.to(device)
        model.eval()
        return model

    @classmethod
    def from_pretrained(cls, model_path_hf: str = None, filename_hf: str = "weights.pt", local_model_path: str = None):
        """Create a pipeline and load its model.

        The Hub source takes precedence. When both ``model_path_hf`` and
        ``filename_hf`` are set, ``filename_hf`` is fetched from the
        ``model_path_hf`` repository with ``hf_hub_download`` and loaded.
        Otherwise ``local_model_path`` is loaded if it is set. If no source is
        given, the pipeline is returned without a model and calling it raises
        ``RuntimeError``.

        Args:
            model_path_hf: Hugging Face Hub repository id holding the checkpoint.
            filename_hf: Checkpoint file name inside that repository.
            local_model_path: Path to a checkpoint on the local filesystem.

        Returns:
            A ``DeepLabv3Pipeline`` whose model has been moved to the pipeline's
            device and put in eval mode.
        """
        dl = cls()
        if model_path_hf is not None and filename_hf is not None:
            dl.model = cls._load_model(hf_hub_download(model_path_hf, filename=filename_hf), dl.device)
        elif local_model_path is not None:
            dl.model = cls._load_model(local_model_path, dl.device)
        return dl

    def __call__(self, image: Image.Image, threshold: float = 0) -> Image.Image:
        """Keep the pixels the model scores above ``threshold``; whiten the rest.

        Args:
            image: Input image; converted to RGB before preprocessing.
            threshold: A pixel is kept when its model score is strictly greater
                than this value.

        Returns:
            An RGB ``PIL.Image`` at the preprocessed (336x336) resolution.
            Kept pixels show the resized input; all other pixels are white.

        Raises:
            RuntimeError: if no model has been loaded.
        """
        if self.model is None:
            raise RuntimeError(
                "No model loaded; build the pipeline with DeepLabv3Pipeline.from_pretrained(...)"
            )
        image = image.convert("RGB")
        # Preprocess once. The same tensor is the model input and the pixel
        # source for the output.
        pixels = self.transforms(image)
        with torch.no_grad():
            output = self.model(pixels.unsqueeze(0).to(self.device))
        # Channel-last score map. Assumes output['out'] is (1, C, H, W) with
        # C == 1 (or 3) so the mask broadcasts against the RGB image.
        # TODO confirm the model's output channel count.
        scores = output['out'][0].permute(1, 2, 0).cpu().numpy()
        rgb = pixels.permute(1, 2, 0).numpy()
        masked = np.where(scores > threshold, rgb, 1)
        return Image.fromarray((255 * masked).astype(np.uint8))
|