Spaces:
Sleeping
Sleeping
File size: 2,339 Bytes
9d6c5d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import functools
import time
from pathlib import Path

import gradio as gr
import torch

from models import *
# Maps each index of the model's output vector to a human-readable label.
class_idx_to_names = {
    0: "pizza",
    1: "steak",
    2: "sushi",
}

# Gradio example rows: one single-element row (the image path) per file in the
# ./examples directory (resolved relative to the current working directory).
# Sorted so the example gallery order is stable across filesystems, and limited
# to visible regular files so sub-directories and dotfiles such as .gitkeep are
# not offered as images.
examples = [
    [str(path)]
    for path in sorted(Path("examples").glob("*"))
    if path.is_file() and not path.name.startswith(".")
]
def predict_one(model, transforms, image, device, class_idx_to_names):
    """Classify a single image and time the inference step.

    Args:
        model: Module to run. Its output is expected to be logits of shape
            ``(1, num_classes)`` for a batch of one; the softmax over ``dim=1``
            and the ``[0]`` indexing below rely on this.
        transforms: Callable turning *image* into an un-batched tensor; a batch
            dimension of size 1 is added here.
        image: Input accepted by *transforms* (the app passes a PIL image).
        device: Device the model and input tensor are moved to, e.g. ``"cpu"``.
        class_idx_to_names: Mapping from output index to class label.

    Returns:
        ``(predictions, seconds)``: *predictions* maps every class label to its
        softmax probability (a Python float), and *seconds* is the wall-clock
        time spent on transform, forward pass and softmax.
    """
    model.eval()
    model = model.to(device)
    with torch.inference_mode():
        start_time = time.perf_counter()
        # Add a leading batch dimension of size 1 before the forward pass.
        batch = transforms(image).unsqueeze(dim=0).to(device)
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        end_time = time.perf_counter()
    # Only one image is in the batch; convert its row to plain Python floats.
    predictions = {
        class_idx_to_names[idx]: prob
        for idx, prob in enumerate(probs[0].tolist())
    }
    return predictions, end_time - start_time
@functools.lru_cache(maxsize=None)
def _load_model(model_key):
    """Return ``(model, transforms)`` for *model_key*, building each at most once.

    ``predict`` normalises the key to ``"effnet_b2"`` or ``"vit"``, so the cache
    holds at most two entries. Caching avoids calling the (presumably
    weight-loading) ``get_*`` factories on every request.
    """
    if model_key == "effnet_b2":
        return get_effnet_b2()
    return get_vit_16_base_transformer()


def predict(image, model_choice):
    """Gradio handler: classify *image* with the model picked in the UI.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded.
        model_choice: ``"effnet_b2"``, ``"ViT (Vision Transformer)"``, or
            ``None`` when no radio option is selected (treated as effnet_b2).

    Returns:
        ``(predictions, time_taken)`` as produced by ``predict_one`` on CPU.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a traceback from the transforms.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")
    # Anything other than effnet_b2 (or no selection) selects the ViT model.
    model_key = "effnet_b2" if model_choice in (None, "effnet_b2") else "vit"
    model, transforms = _load_model(model_key)
    predictions, time_taken = predict_one(model, transforms, image, "cpu", class_idx_to_names)
    return predictions, time_taken
# NOTE(review): "ππ" looks like mis-encoded emoji from the original source;
# restore the intended characters.
title = "Food Recognition ππ"
desc = "A dual model app ft. EfficientNetB2 Feature Extractor and VisionTransformer."
# Markdown shown below the interface; the figures are static text.
article = '''
## Stats on different Models
---
| Model Name | Train Loss | Test Loss | Train Accuracy | Test Accuracy | Num Parameters | Model Size |
|-----------------|------------|-----------|----------------|---------------|----------------|------------|
| EfficientNet_b2 | 0.340270 | 0.301134 | 0.906250 | 0.953409 | 7705221 | 29.91 MB |
| ViT_Base_16 | 0.040448 | 0.055140 | 0.995833 | 0.981250 | 85800963 | 327.39 MB |
'''

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="upload a JPEG or PNG"),
        # Preselect effnet_b2 so the UI matches the label and predict()'s
        # fallback when no choice is made.
        gr.Radio(
            ["effnet_b2", "ViT (Vision Transformer)"],
            value="effnet_b2",
            label="choose model (default on effnet)",
        ),
    ],
    outputs=[
        gr.Label(num_top_classes=3, label="predictions"),
        gr.Number(label="Prediction Time in seconds"),
    ],
    examples=examples,
    title=title,
    description=desc,
    article=article,
)

# Launch only when run as a script, not when this module is imported.
if __name__ == "__main__":
    demo.launch(debug=False)