Spaces:
Runtime error
Runtime error
# Importing libraries for gradio app | |
import gradio as gr | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import models | |
import torchvision.transforms as tt | |
from PIL import Image | |
# Helper for moving tensors, modules, or lists/tuples of tensors to a device
def to_device(data, device):
    """Return *data* transferred to *device*.

    Lists and tuples are processed element by element (recursively) and are
    always returned as a list. Any other object is sent to ``device`` via its
    own ``.to(device, non_blocking=True)`` method -- in this file that is a
    tensor or an ``nn.Module``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(element, device) for element in data]
# Accuracy metric plus a base class that keeps only the validation helpers needed here
def accuracy(outputs, labels):
    """Return the fraction of rows whose arg-max over dim 1 equals the label.

    The result is a 0-dim tensor built from a Python float (correct / total).
    """
    predicted = outputs.max(dim=1).indices
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))
class ImageClassificationBase(nn.Module):
    """``nn.Module`` base class that carries only validation-time helpers.

    Subclasses provide ``forward``; the methods below turn a batch into
    per-batch metrics and aggregate those metrics over an epoch.
    """

    def validation_step(self, batch):
        """Score one ``(images, labels)`` batch.

        Returns a dict holding the detached cross-entropy loss under
        ``'val_loss'`` and the ``accuracy`` of the predictions under
        ``'val_acc'``.
        """
        inputs, targets = batch
        logits = self(inputs)
        batch_loss = F.cross_entropy(logits, targets)
        return {
            'val_loss': batch_loss.detach(),
            'val_acc': accuracy(logits, targets),
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch ``validation_step`` dicts into Python floats."""
        def _mean_of(key):
            # Stack the 0-dim per-batch tensors and reduce to a single float.
            return torch.stack([step[key] for step in outputs]).mean().item()

        return {'val_loss': _mean_of('val_loss'), 'val_acc': _mean_of('val_acc')}
# Fine-tuned ResNet-50 architecture with a custom classification layer
class IndianFoodModelResnet50(ImageClassificationBase):
    """torchvision ResNet-50 whose final ``fc`` layer is swapped for a
    ``num_classes``-way linear head.

    Args:
        num_classes: number of output classes of the new head.
        pretrained: forwarded unchanged to ``models.resnet50``.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        backbone = models.resnet50(pretrained=pretrained)
        # Keep the backbone's feature width; only the output size changes.
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, num_classes)
        self.network = backbone

    def forward(self, xb):
        """Run the wrapped ResNet-50 on the input batch ``xb``."""
        return self.network(xb)
# Evaluation helper: average validation loss and accuracy over a data loader
def evaluate(model, val_loader):
    """Compute aggregated validation metrics of *model* over *val_loader*.

    The model is switched to eval mode and autograd is disabled for the whole
    pass, so no computation graph is built for the forward passes.

    Args:
        model: object exposing ``eval``, ``validation_step`` and
            ``validation_epoch_end`` (e.g. an ``ImageClassificationBase``
            subclass).
        val_loader: iterable of batches accepted by ``model.validation_step``.

    Returns:
        Whatever ``model.validation_epoch_end`` returns for the collected
        per-batch outputs.
    """
    model.eval()
    # Evaluation never calls backward(); skipping graph construction saves
    # memory and time.
    with torch.no_grad():
        outputs = [model.validation_step(batch) for batch in val_loader]
        return model.validation_epoch_end(outputs)
# Initialising our model and moving it to the CPU
classes = ['burger', 'butter_naan', 'chai', 'chapati', 'chole_bhature',
           'dal_makhani', 'dhokla', 'fried_rice', 'idli', 'jalebi',
           'kaathi_rolls', 'kadai_paneer', 'kulfi', 'masala_dosa', 'momos',
           'paani_puri', 'pakode', 'pav_bhaji', 'pizza', 'samosa']
# pretrained=False: every parameter and buffer is overwritten by the strict
# load_state_dict() call below. Downloading the ImageNet weights first would be
# wasted startup time and a failure point when the network is unavailable.
model = IndianFoodModelResnet50(len(classes), pretrained=False)
device = 'cpu'
model = to_device(model, device)

# Loading the fine-tuned weights
ckp_path = 'indianFood-resnet50.pth'
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load(ckp_path, map_location=torch.device(device)))
model.eval()
# Image preprocessing applied before prediction.
# Per-channel (mean, std) pairs; these are the standard ImageNet statistics.
stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
img_tfms = tt.Compose([
    tt.Resize((224, 224)),                 # fixed 224x224 input size
    tt.ToTensor(),                         # PIL image -> float tensor in [0, 1]
    tt.Normalize(*stats, inplace=True),    # safe in place: ToTensor made a fresh tensor
])
def predict_image(image, model):
    """Return the class name *model* predicts for a single image tensor.

    Args:
        image: a preprocessed image tensor without a batch dimension (as
            produced by ``img_tfms``); a batch axis of size 1 is added here.
        model: the classifier; the arg-max is taken over dim 1 of its output.

    Returns:
        The entry of the module-level ``classes`` list with the highest score.
    """
    xb = to_device(image.unsqueeze(0), device)
    # Inference only: no backward pass follows, so skip building the autograd graph.
    with torch.no_grad():
        yb = model(xb)
    _, preds = torch.max(yb, dim=1)
    return classes[preds[0].item()]
# Function handling input, processing and output
def classify_image(path):
    """Classify the image stored at *path* and return its predicted label.

    Args:
        path: filesystem path of the image (the Gradio input component is
            configured with ``type="filepath"``).

    Returns:
        The predicted class name from ``predict_image``.

    Raises:
        ValueError: if *path* is ``None`` (no image was submitted).
    """
    if path is None:
        raise ValueError("No image provided. Please upload an image or pick an example.")
    # The context manager closes the file handle that Image.open keeps open lazily;
    # convert() returns a new, fully loaded image, so it is safe to use afterwards.
    with Image.open(path) as img:
        # Normalize() uses 3-channel stats, so grayscale/RGBA/palette inputs must be RGB.
        rgb_img = img.convert("RGB")
    return predict_image(img_tfms(rgb_img), model)
# Defining gradio interface components.
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4;
# the top-level gr.Image / gr.Label components are the supported equivalents.
# NOTE(review): the legacy `shape=(224, 224)` argument made Gradio crop and resize
# the upload before handing it over. That option no longer exists, so img_tfms now
# does the only resize (without cropping). Confirm this matches training preprocessing.
image = gr.Image(type="filepath")
label = gr.Label(num_top_classes=1)
# TODO(review): the DesiVisionNet link below points at a placeholder 'https://' URL.
article = "<p style='text-align: center'><a href='https://' target='_blank'>DesiVisionNet</a> | <a href='https://github.com/kunal-bhadra/DesiVisionNet' target='_blank'>GitHub Repo</a></p>"
# Bound to a name so the interface can be reused (e.g. by Gradio's reload mode).
demo = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=label,
    examples=[["idli.jpg"], ["naan.jpg"]],
    theme="huggingface",
    title="DesiVisionNet: Indian Food Vision with ResNet",
    description="This is a Gradio demo for multi-class image classification of Indian food amongst 20 classes. The DesiVisionNet achieved 90% accuracy on our test dataset, performing well for a relatively efficient model. See the GitHub project page for detailed information below. Here, we provide a demo for real-world food classification. To use it, simply upload your image, or click one of the examples to load them.",
    article=article,
)
demo.launch()