arnavkundalia's picture
Update app.py
40ccb7c
raw
history blame contribute delete
997 Bytes
import gradio as gr
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from PIL import Image
def _load_entire_model(path):
    """Load a fully pickled ``torch.nn.Module`` (not just a state_dict) onto the CPU.

    torch >= 2.6 changed the default of ``torch.load(weights_only=...)`` to
    True, which refuses to unpickle whole module objects. The flag is therefore
    passed explicitly as False whenever the installed torch accepts it, and
    omitted on older versions that do not know the keyword.

    NOTE(review): ``weights_only=False`` runs the full pickle machinery, so only
    call this on checkpoint files you trust (here: the file bundled with the app).
    """
    import inspect

    load_kwargs = {'map_location': 'cpu'}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)


# Relative path: resolved against the current working directory.
model = _load_entire_model('entire_model.pt')
# Switch to inference mode (affects layers such as dropout/batch-norm, if present).
model.eval()
# Class names indexed by the model's output position; the order must match the
# order used in training — TODO confirm against the training script.
labels = ['Healthy','Scab']
# Preprocessing pipeline derived from the model's own timm data config, so
# inputs are sized/normalised the way this model expects.
transform = create_transform(**resolve_data_config({},model = model))
def predict_fn(img, top_k=1):
    """Classify an apple image as healthy or scab.

    Args:
        img: PIL image supplied by the Gradio image input (``type='pil'``),
            or ``None`` when the user submits without uploading an image.
        top_k: Number of highest-scoring labels to return. Defaults to 1,
            matching the original behaviour; clamped to the number of classes
            the model outputs.

    Returns:
        dict mapping label name -> softmax probability (the format Gradio's
        ``'label'`` output renders), or ``None`` when no image was given.
    """
    if img is None:
        # Gradio calls the function with None for an empty input; without this
        # guard the .convert() below raises AttributeError.
        return None
    # Normalise to 3-channel RGB (uploads may be RGBA/greyscale/palette) and
    # add a leading batch dimension of 1.
    batch = transform(img.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # Only one image is in the batch: softmax over its class scores.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    k = min(int(top_k), probabilities.numel())
    values, indices = torch.topk(probabilities, k=k)
    return {labels[i]: v for i, v in zip(indices.tolist(), values.tolist())}
description = "Upload an image of an Apple and the model would predict if it is a healthy apple or scab apple."
title = "Apple scab detection"

# gr.Image replaces the deprecated gr.inputs.Image (the gr.inputs namespace
# was removed in Gradio 4); gr.Image is available from Gradio 3.x onward.
demo = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type='pil'),  # hand predict_fn a PIL.Image
    outputs='label',              # renders the {label: probability} dict
    description=description,
    title=title,
    allow_flagging='never',
)
# debug expects a bool; the original passed the string 'True', which only
# worked because any non-empty string is truthy.
demo.launch(debug=True)