Spaces:
Runtime error
Runtime error
File size: 1,092 Bytes
fd52b7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
from huggingface_hub import hf_hub_download
import torch
from torchvision.transforms.functional import to_tensor
class VaePipeline:
    """Wrap a pickled encoder module and run it on a single image.

    The encoder is loaded with ``torch.load`` (either from a file fetched
    through ``hf_hub_download`` or from a local path), moved to ``cuda`` when
    available (otherwise ``cpu``) and switched to eval mode. Calling the
    pipeline converts an image to RGB, resizes it to ``INPUT_SIZE``, feeds it
    through the encoder as a batch of one and returns ``output[0][0]`` on CPU.
    """

    # Size passed to ``image.resize`` before encoding. Kept as a class
    # attribute so a subclass can override it without touching ``__call__``.
    INPUT_SIZE = (118, 118)

    def __init__(self):
        """Create an empty pipeline; the encoder is attached by ``from_pretrained``."""
        self.encoder = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def _load_encoder(self, weights_path):
        """Load the encoder from *weights_path*, move it to the device, set eval mode."""
        # NOTE(review): torch.load unpickles arbitrary Python objects, so only
        # checkpoints from trusted sources should be loaded here. Recent
        # PyTorch releases also default to ``weights_only=True``, which may
        # reject a fully pickled module -- confirm against the deployed
        # torch version.
        self.encoder = torch.load(weights_path, map_location='cpu')
        self.encoder.to(self.device)
        self.encoder.eval()

    @classmethod
    def from_pretrained(cls, model_path_hf: str = None, filename_hf: str = "weights.pt", local_model_path: str = None):
        """Build a pipeline and load its encoder.

        Args:
            model_path_hf: First argument forwarded to ``hf_hub_download``
                (presumably a Hub repository id -- confirm with callers).
            filename_hf: Second argument forwarded to ``hf_hub_download``.
            local_model_path: Path handed to ``torch.load``; only used when
                the Hub source is not selected.

        Returns:
            A ``VaePipeline``. The Hub source takes priority when both
            ``model_path_hf`` and ``filename_hf`` are not ``None``; otherwise
            ``local_model_path`` is used. When neither source is given the
            pipeline is returned with ``encoder`` left as ``None``.
        """
        vae = cls()
        if model_path_hf is not None and filename_hf is not None:
            vae._load_encoder(hf_hub_download(model_path_hf, filename_hf))
        elif local_model_path is not None:
            vae._load_encoder(local_model_path)
        return vae

    def __call__(self, image) -> torch.Tensor:
        """Encode *image* and return the first item of the encoder's first output.

        Args:
            image: Object exposing ``convert`` and ``resize`` (presumably a
                ``PIL.Image.Image`` -- confirm with callers).

        Returns:
            ``encoder(batch)[0][0]`` detached and moved to CPU. Presumably the
            encoder returns a sequence whose first element is batched, so this
            is the entry for the single input image -- TODO confirm against
            the encoder's forward signature.

        Raises:
            RuntimeError: If no encoder has been loaded.
        """
        if self.encoder is None:
            # Fail with a clear message instead of "'NoneType' is not callable".
            raise RuntimeError(
                "VaePipeline has no encoder loaded; create it with "
                "from_pretrained(model_path_hf=...) or "
                "from_pretrained(local_model_path=...)."
            )

        rgb_image = image.convert("RGB")
        img_tensor = to_tensor(rgb_image.resize(self.INPUT_SIZE))
        batch = img_tensor.unsqueeze(0).to(self.device)

        # The result is detached anyway, so skip building the autograd graph.
        with torch.no_grad():
            output = self.encoder(batch)
        return output[0][0].detach().cpu()