File size: 1,844 Bytes
12f775f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import functools
import pickle as pkl
import time
from pathlib import Path

import gradio as gr

from models import *


# Example inputs shown under the interface: one row per file in ./examples,
# each row holding just the image path. Path.glob yields entries in an
# arbitrary, filesystem-dependent order, so sort for a stable gallery; keep
# only regular files so a stray subdirectory is never offered as an image.
examples = [[str(path)] for path in sorted(Path("examples").glob("*")) if path.is_file()]

# Lookup used by predict_one to turn output positions into class labels.
# NOTE(review): the file is named "class_names_to_idx" but is consumed as an
# index -> name mapping (indexed by int position in predict_one); confirm the
# pickle's direction. Loaded with pickle, so the file must be trusted.
with open('class_names_to_idx.pkl', 'rb') as fp:
    class_idx_to_names = pkl.load(fp)

def predict_one(model, transforms, image, device, class_idx_to_names):
    """Run a single image through ``model`` and return per-class probabilities.

    Args:
        model: A torch module. It is switched to eval mode and moved to
            ``device`` before inference.
        transforms: Callable turning ``image`` into an unbatched tensor; a
            size-1 batch dimension is added here.
        image: The input handed to ``transforms``.
        device: Target device for both the model and the input tensor.
        class_idx_to_names: Lookup indexed by output position (0, 1, ...)
            that yields the label for that position.

    Returns:
        ``(predictions, seconds)``: ``predictions`` maps every class label to
        its softmax probability as a Python float; ``seconds`` is the time,
        measured with ``time.perf_counter``, spent on preprocessing, the
        forward pass and the softmax (not on moving the model or building
        the result dict).
    """
    model.eval()
    model = model.to(device)

    with torch.inference_mode():
        started = time.perf_counter()
        # Add a leading batch axis so the model sees a batch of one.
        batch = transforms(image).unsqueeze(dim=0).to(device)
        probabilities = torch.softmax(model(batch), dim=1)
        finished = time.perf_counter()

        # Only the first (and only) row of the batch is relevant.
        scores = {}
        for position, probability in enumerate(probabilities[0]):
            scores[class_idx_to_names[position]] = probability.item()

    return scores, finished - started

def predict(image, model_choice):
    """Classify ``image`` with the model selected in the UI.

    Args:
        image: The uploaded image as delivered by the ``gr.Image(type="pil")``
            input component.
        model_choice: Value of the model radio button. ``None`` (nothing
            selected) or ``"effnet_b2"`` selects EfficientNet-B2; any other
            value selects the Vision Transformer.

    Returns:
        ``(predictions, time_taken)`` exactly as produced by ``predict_one``
        on CPU: a label -> probability dict and the elapsed inference time
        in seconds.
    """
    # Collapse the radio value to one of two cache keys so that ``None`` and
    # "effnet_b2" share a single loaded EfficientNet instance.
    use_effnet = model_choice is None or model_choice == "effnet_b2"
    model, transforms = _load_model("effnet_b2" if use_effnet else "vit")

    predictions, time_taken = predict_one(model, transforms, image, "cpu", class_idx_to_names)
    return predictions, time_taken


@functools.lru_cache(maxsize=2)
def _load_model(model_key):
    """Return the ``(model, transforms)`` pair for ``model_key``, built once.

    The model factories are called at most once per key; later requests
    reuse the same objects instead of reconstructing the network for every
    prediction. Only two keys exist ("effnet_b2" and "vit"), so the cache
    is bounded to two entries.
    NOTE(review): assumes get_effnet_b2 / get_vit_16_base_transformer (from
    ``models``) are safe to call once and share across requests — confirm.
    """
    if model_key == "effnet_b2":
        return get_effnet_b2()
    return get_vit_16_base_transformer()


title = "Food Recognition πŸ•πŸ•"
desc = "A dual model app ft. EfficientNetB2 Feature Extractor and VisionTransformer. Now, bigger than ever. featuring 101 classes"

demo = gr.Interface(fn = predict, 
                    inputs = [gr.Image(type = "pil", label = "upload an Jpeg or Png"), gr.Radio(["effnet_b2", "ViT (Vision Transformer)"], label = "choose model (default on effnet)")],
                    outputs = [gr.Label(num_top_classes = 5, label = "predictions"), gr.Number(label = "Prediction Time in seconds")], 
                    examples = examples, 
                    title = title,
                    description=desc)

demo.launch(debug = False)