planthealthapp / app.py
Brian Burns
Add application file, pytorch saved models, requirements file
ca9b0e0
raw
history blame contribute delete
No virus
929 Bytes
import datasets
import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# Dataset whose label feature supplies the human-readable class names.
dataset = datasets.load_dataset("beans")

# Preprocessor and classifier are both restored from the same local
# checkpoint directory.
extractor = AutoFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path="saved_model_files"
)
model = AutoModelForImageClassification.from_pretrained(
    pretrained_model_name_or_path="saved_model_files"
)

# Ordered class names from the training split's "labels" feature.
# classify() pairs position i here with model output index i.
# NOTE(review): confirm this order matches model.config.id2label.
labels = dataset["train"].features["labels"].names
def classify(im):
    """Classify a plant image and return per-class confidences.

    Args:
        im: The image handed over by the Gradio "image" input component.
            Presumably a NumPy array (Gradio's default for that component);
            TODO confirm against the installed Gradio version.

    Returns:
        dict mapping every name in the module-level ``labels`` list to its
        softmax probability as a plain Python float, which is the shape the
        "label" output component is given.
    """
    # Preprocess with the module-level extractor loaded from the checkpoint.
    features = extractor(im, return_tensors="pt")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Read the logits field explicitly instead of indexing the output
        # tuple, whose last element changes if hidden states/attentions
        # are ever returned.
        logits = model(pixel_values=features["pixel_values"]).logits

    # Softmax over the class dimension; [0] takes the single image in the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()
    return {label: float(probs[i]) for i, label in enumerate(labels)}
# Web UI: one image input feeding classify(), whose {class name: confidence}
# dict is passed to a "label" output component. launch() starts the app.
interface = gr.Interface(
    classify,
    "image",
    "label",
    title="Plant health classifier",
    description="Classifies plant health",
)
interface.launch()