File size: 1,731 Bytes
5be6671
 
 
 
 
 
 
 
 
 
bd05801
5be6671
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from typing import Dict, List, Any

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms

from PIL import Image



# File name of the saved state dict for the ResNet-18 classifier built in
# PreTrainedPipeline.
MODEL_PATH = "website_classifier.pth"

# Function to load an image and perform the necessary transformations
def process_image(image):
    # Load Image
    img = image.convert("RGB")
    
    # Apply transformations
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    img_t = transform(img)
    
    # Convert to a batch of 1
    img_u = torch.unsqueeze(img_t, 0)
    
    return img_u
    
class PreTrainedPipeline():
    """Classify a website image as 'forum', 'general' or 'marketplace'.

    Wraps a torchvision ResNet-18 whose final ``fc`` layer is replaced by a
    3-way linear head, with weights restored from the MODEL_PATH checkpoint.
    """

    def __init__(self, path=""):
        """Build the network and load its fine-tuned weights.

        Args:
            path: Directory that contains the MODEL_PATH checkpoint. The
                default empty string resolves MODEL_PATH relative to the
                current working directory.
        """
        # Build the bare architecture without pretrained weights: the strict
        # load_state_dict below overwrites every parameter and buffer, so
        # downloading pretrained weights first would be wasted work.
        self.model = torchvision.models.resnet18()
        num_ftrs = self.model.fc.in_features
        # Replace the stock classification head with one output per class.
        self.model.fc = nn.Linear(num_ftrs, 3)
        # Not used by __call__ (which goes through self.processor); kept for
        # callers that access it directly.
        self.transform = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        # map_location="cpu" lets a checkpoint saved from a GPU load on a
        # CPU-only host.
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from a trusted source.
        checkpoint = os.path.join(path, MODEL_PATH)
        self.model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
        # Switch to inference behaviour. In training mode, ResNet's BatchNorm
        # layers would use per-batch statistics and update their running
        # averages on every prediction.
        self.model.eval()
        self.processor = process_image
        # Index order must match the model's output units.
        self.classes = ['forum', 'general', 'marketplace']
        self.classe_to_idx = {'forum': 0, 'general': 1, 'marketplace': 2}
        # Correctly spelled alias; the misspelled name is kept for existing
        # callers.
        self.class_to_idx = self.classe_to_idx

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Predict the class of a single image.

        Args:
            data: Request payload. The image is taken from (and removed from)
                ``data["inputs"]``. If that key is absent, ``data`` itself is
                passed to the processor.

        Returns:
            ``{"class": label}``, where ``label`` is the full class name
            (an entry of ``self.classes``) of the top-scoring output.
        """
        image = data.pop("inputs", data)

        # self.processor (process_image by default) returns a batch of one.
        batch = self.processor(image)

        # Run the forward pass directly; nn.Module has no generate() method.
        # no_grad skips building the autograd graph during inference.
        with torch.no_grad():
            outputs = self.model(batch)

        # Index of the highest-scoring class for the first (only) image.
        _, predicted = torch.max(outputs, 1)
        prediction = self.classes[int(predicted[0])]
        return {"class": prediction}