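"""Gradio app that classifies food images (pizza, steak, or sushi) using one of
two models: an EfficientNetB2 feature extractor or a ViT-Base/16 transformer.
"""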
import time
from pathlib import Path

import torch
import gradio as gr

from models import *  # expected to provide get_effnet_b2 and get_vit_16_base_transformer

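# Mapping from the models' output indices to human-readable class names.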
class_idx_to_names = {
    0: "pizza", 
    1: "steak", 
    2: "sushi"
}

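# Collect example images from the examples/ directory for the demo gallery.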
examples = [[str(path)] for path in Path("examples").glob("*")]

def predict_one(model, transforms, image, device, class_idx_to_names):
    model.eval()
    model = model.to(device)
    with torch.inference_mode():
        start_time = time.perf_counter()

        # Preprocess the PIL image and add a batch dimension: (C, H, W) -> (1, C, H, W).
        image_transformed = transforms(image).unsqueeze(dim=0).to(device)

        y_logits = model(image_transformed)

        # Softmax converts the raw logits into class probabilities.
        y_probs = torch.softmax(y_logits, dim=1)

        end_time = time.perf_counter()

        # Map each class name to its predicted probability for gr.Label.
        predictions = {class_idx_to_names[index]: prob.item() for index, prob in enumerate(y_probs[0])}

    return predictions, end_time - start_time
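
# A minimal sketch of calling predict_one directly, outside Gradio, assuming
# get_effnet_b2() returns a (model, transforms) pair as used below and that a
# hypothetical image exists at examples/pizza.jpg:
#
#     from PIL import Image
#     model, transforms = get_effnet_b2()
#     probs, seconds = predict_one(model, transforms, Image.open("examples/pizza.jpg"), "cpu", class_idx_to_names)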

def predict(image, model_choice):
    # Default to the lighter EfficientNetB2 model when no radio choice is made.
    if model_choice is None or model_choice == "effnet_b2":
        model, transforms = get_effnet_b2()
    else:
        model, transforms = get_vit_16_base_transformer()

    # Run inference on CPU.
    predictions, time_taken = predict_one(model, transforms, image, "cpu", class_idx_to_names)
    return predictions, time_taken


title = "Food Recognition πŸ•πŸ•"
desc = "A dual model app ft. EfficientNetB2 Feature Extractor and VisionTransformer."
article = '''
## Stats on the Two Models
---
| Model Name      | Train Loss | Test Loss | Train Accuracy | Test Accuracy | Parameters | Model Size |
|-----------------|------------|-----------|----------------|---------------|------------|------------|
| EfficientNet_b2 | 0.340270   | 0.301134  | 0.906250       | 0.953409      | 7,705,221  | 29.91 MB   |
| ViT_Base_16     | 0.040448   | 0.055140  | 0.995833       | 0.981250      | 85,800,963 | 327.39 MB  |
'''
demo = gr.Interface(fn=predict,
                    inputs=[gr.Image(type="pil", label="Upload a JPEG or PNG"),
                            gr.Radio(["effnet_b2", "ViT (Vision Transformer)"], label="Choose a model (defaults to effnet_b2)")],
                    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
                             gr.Number(label="Prediction time (seconds)")],
                    examples=examples,
                    title=title,
                    description=desc,
                    article=article)

demo.launch(debug=False)