Abhilashvj commited on
Commit
5be6671
1 Parent(s): 84facbd

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +58 -0
pipeline.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ import torchvision
4
+ import torch.nn as nn
5
+ import torchvision.transforms as transforms
6
+
7
+ from PIL import Image
8
+
9
+
10
+
11
+ MODEL_PATH = './website_classifier.pth'
12
+
13
+ # Function to load an image and perform the necessary transformations
14
+ def process_image(image):
15
+ # Load Image
16
+ img = image.convert("RGB")
17
+
18
+ # Apply transformations
19
+ transform = transforms.Compose([
20
+ transforms.Resize((224, 224)),
21
+ transforms.ToTensor(),
22
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
23
+ ])
24
+
25
+ img_t = transform(img)
26
+
27
+ # Convert to a batch of 1
28
+ img_u = torch.unsqueeze(img_t, 0)
29
+
30
+ return img_u
31
+
32
+ class PreTrainedPipeline():
33
+ def __init__(self, path=""):
34
+ self.model = torchvision.models.resnet18(pretrained=True)
35
+ num_ftrs = self.model.fc.in_features
36
+ self.model.fc = nn.Linear(num_ftrs, 3)
37
+ self.transform = transforms.Compose(
38
+ [transforms.Resize((224, 224)),
39
+ transforms.ToTensor(),
40
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
41
+ self.model.load_state_dict(torch.load(MODEL_PATH))
42
+ self.processor = process_image
43
+ self.classes = ['forum', 'general', 'marketplace']
44
+ self.classe_to_idx = {'forum': 0, 'general': 1, 'marketplace': 2}
45
+
46
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
47
+ image = data.pop("inputs", data)
48
+
49
+ # process image
50
+ image = self.processor(image)
51
+
52
+ # run prediction
53
+ outputs = self.model.generate(image)
54
+
55
+ # decode output
56
+ _, predicted = torch.max(outputs, 1)
57
+ prediction = self.classes[predicted[0]]
58
+ return {"class":prediction[0]}