"""Inference pipeline for a three-class image classifier.

A ResNet-18 backbone with a 3-way linear head labels an input image as one of
``forum``, ``general`` or ``marketplace``. The fine-tuned weights are read
from ``MODEL_PATH``.
"""
import os
from typing import Dict, List, Any

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

MODEL_PATH = 'website_classifier.pth'

# Built once at import time: the pipeline is stateless, so there is no reason
# to reconstruct it for every image.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def process_image(image):
    """Turn an image into a normalized, single-item batch for the model.

    Args:
        image: An object exposing ``convert("RGB")`` (e.g. a PIL image).

    Returns:
        The transformed image as a tensor with a leading batch dimension of
        size 1 (resize to 224x224, ``ToTensor``, per-channel normalization
        with mean 0.5 and std 0.5).
    """
    rgb = image.convert("RGB")
    return torch.unsqueeze(_PREPROCESS(rgb), 0)


class PreTrainedPipeline():
    """Callable wrapper around the fine-tuned ResNet-18 classifier."""

    def __init__(self, path=""):
        """Build the network and load the fine-tuned weights.

        Args:
            path: Directory that contains ``MODEL_PATH``. When the file is not
                found there (or ``path`` is empty), ``MODEL_PATH`` is resolved
                relative to the current working directory, which is what the
                pipeline did before ``path`` was honoured.
        """
        self.classes = ['forum', 'general', 'marketplace']
        # NOTE: attribute name (including its spelling) is kept as-is because
        # external code may already read it.
        self.classe_to_idx = {'forum': 0, 'general': 1, 'marketplace': 2}

        # No ImageNet download: load_state_dict() below is strict and
        # therefore overwrites every parameter and buffer anyway.
        self.model = torchvision.models.resnet18()
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, len(self.classes))

        self.transform = _PREPROCESS
        self.processor = process_image

        candidate = os.path.join(path or "", MODEL_PATH)
        weights_file = candidate if os.path.isfile(candidate) else MODEL_PATH
        # map_location="cpu" lets a checkpoint saved on a GPU load on
        # CPU-only hosts.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(weights_file, map_location="cpu")
        self.model.load_state_dict(state_dict)
        # Inference mode: BatchNorm must use its stored running statistics,
        # not the statistics of a single-image batch.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Classify one image.

        Args:
            data: Either a dict whose ``"inputs"`` entry is the image, or the
                image itself.

        Returns:
            ``{"class": <label>}`` where ``<label>`` is one of
            ``self.classes``.
        """
        if isinstance(data, dict):
            image = data.pop("inputs", data)
        else:
            image = data

        batch = self.processor(image)

        # Plain forward pass; gradients are not needed for prediction.
        with torch.no_grad():
            outputs = self.model(batch)

        _, predicted = torch.max(outputs, 1)
        label = self.classes[int(predicted[0])]
        return {"class": label}