File size: 1,151 Bytes
fd52b7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from torchvision import transforms

class Vits8Pipeline:
    """Run a TorchScript image model on a single image and return its output.

    Instances are normally built with :meth:`from_pretrained`, which loads
    TorchScript weights from either the Hugging Face Hub or a local file.
    """

    def __init__(self):
        # Use the GPU when torch can see one, otherwise fall back to the CPU.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = None  # Set by from_pretrained (a TorchScript module)
        # Converts the input image to a tensor before it is fed to the model.
        self.transform = transforms.ToTensor()

    @classmethod
    def from_pretrained(
        cls,
        model_path_hf: Optional[str] = None,
        filename_hf: Optional[str] = "weights.pt",
        local_model_path: Optional[str] = None,
    ) -> "Vits8Pipeline":
        """Build a pipeline and load TorchScript weights into it.

        Args:
            model_path_hf: Hugging Face Hub repo id to download the weights
                from. Takes precedence over ``local_model_path`` when
                ``filename_hf`` is also set.
            filename_hf: Name of the weights file inside the Hub repo.
            local_model_path: Path to a TorchScript file on disk. It is used
                only when no Hub source is given.

        Returns:
            A new pipeline whose model has been moved to ``device`` and put in
            eval mode. If neither source is usable, the pipeline is returned
            without a model, and calling it raises ``RuntimeError`` until
            ``model`` is assigned.
        """
        vit = cls()
        if model_path_hf is not None and filename_hf is not None:
            weights_path = hf_hub_download(model_path_hf, filename=filename_hf)
        elif local_model_path is not None:
            weights_path = local_model_path
        else:
            # Preserve the historical behavior: return an unloaded pipeline.
            return vit
        # map_location='cpu' remaps any saved device placement to the CPU.
        # The model is then moved to whichever device this host selected.
        vit.model = torch.jit.load(weights_path, map_location='cpu')
        vit.model.to(vit.device)
        vit.model.eval()
        return vit

    def __call__(self, image) -> torch.Tensor:
        """Run the model on one image and return the first output item on CPU.

        Args:
            image: Object with a PIL-style ``convert(mode)`` method (presumably
                a ``PIL.Image.Image`` - TODO confirm against callers). It is
                converted to RGB before the transform is applied.

        Returns:
            ``model(batch)[0]``, detached and moved to the CPU. The transformed
            image gets a leading batch dimension of size 1. For a model that
            returns a batched tensor, this is therefore the single image's
            output.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        if self.model is None:
            raise RuntimeError(
                "No model loaded; build the pipeline with "
                "Vits8Pipeline.from_pretrained(...) or assign .model first."
            )
        image = image.convert("RGB")
        img_tensor = self.transform(image).to(self.device).unsqueeze(0)
        # Inference only: don't record an autograd graph we would discard.
        with torch.no_grad():
            output = self.model(img_tensor)
        return output[0].detach().cpu()