File size: 1,722 Bytes
181e20c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad12b10
181e20c
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from PIL import Image
import numpy as np

from transformers import ViTImageProcessor, TFViTModel
import keras
import argparse

BASE_MODEL = "google/vit-base-patch16-224"
IMAGE_SIZE = 224

class Inference:
    """Predict an image's rotation with a ViT backbone plus a one-unit linear head.

    The Keras model is rebuilt around the ``BASE_MODEL`` checkpoint and the
    fine-tuned weights are then loaded from a Keras weights file.
    """

    def __init__(self, weights_path="weights.h5"):
        """Build the model and load the matching image preprocessor.

        Args:
            weights_path: Path to the fine-tuned Keras weights file. Defaults to
                ``"weights.h5"`` (relative to the current working directory),
                which was previously hard-coded.
        """
        self.vit_model = self._load_vit_model(weights_path)
        self.image_preprocessor = self._load_image_preprocessor()

    def predict_rotation(self, image_path):
        """Return the model's scalar output for the image at *image_path*.

        The ``__main__`` entry point reports this value as an angle in degrees;
        the actual unit depends on how the head was trained (not visible here).
        """
        X = self._preprocess(image_path)
        # predict() returns a (batch, 1) array; take the single prediction.
        y = self.vit_model.predict(X)[0][0]
        return y

    def _preprocess(self, image_path):
        """Load an image file and convert it into a one-image NumPy batch for the ViT.

        The batch is expected to be channels-first, matching the
        ``(3, IMAGE_SIZE, IMAGE_SIZE)`` model input. TODO confirm this against
        the processor's configured ``data_format``.
        """
        # The context manager closes the underlying file handle. Converting to
        # RGB guarantees 3 channels even for grayscale / RGBA / palette images,
        # which would otherwise not fit the 3-channel model input.
        with Image.open(image_path) as img:
            img = img.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
        pixels = np.array(img)

        # Request NumPy output directly. The model is TensorFlow/Keras, so
        # asking for PyTorch tensors ("pt") added an unnecessary torch dependency.
        X_vit = self.image_preprocessor.preprocess(images=[pixels], return_tensors="np")["pixel_values"]
        return np.asarray(X_vit)

    def _load_image_preprocessor(self):
        """Load the ViT image processor that matches ``BASE_MODEL``."""
        print("Loading Image Preprocessor")
        return ViTImageProcessor.from_pretrained(BASE_MODEL)

    def _load_vit_model(self, weights_path="weights.h5"):
        """Rebuild the regression model and load its trained weights.

        The architecture built here (the ViT main layer of ``BASE_MODEL``
        followed by a single linear ``Dense`` unit) is expected to match the
        one the weights were saved from.

        Args:
            weights_path: Path to the Keras weights file to load.

        Returns:
            The compiled-free ``keras.Model`` with weights loaded.
        """
        print("Loading Model")
        vit_base = TFViTModel.from_pretrained(BASE_MODEL)

        img_input = keras.layers.Input(shape=(3, IMAGE_SIZE, IMAGE_SIZE))
        x = vit_base.vit(img_input)
        # x[-1] is the last entry of the backbone's output. This is presumably
        # the pooled output; TODO confirm against the training script.
        y = keras.layers.Dense(1, activation="linear")(x[-1])

        model = keras.Model(inputs=img_input, outputs=y)
        # summary() prints the table itself and returns None. Wrapping it in
        # print() used to emit a stray "None" line.
        model.summary()

        print("Loading Weights")
        model.load_weights(weights_path)

        return model

if __name__ == "__main__":
    # CLI: predict the rotation of a single image given on the command line.
    cli = argparse.ArgumentParser()
    cli.add_argument("--image-path", type=str, required=True)
    cli_args = cli.parse_args()

    predictor = Inference()
    angle = predictor.predict_rotation(cli_args.image_path)
    print(f"Predicted angle is about '{angle}' degrees")