betterdigits / app.py
im2
draw and test
81a6137
raw
history blame
1.69 kB
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
# Load the model and feature extractor from Hugging Face.
# These statements run once at import time, so every prediction reuses the
# same module-level `model` and `feature_extractor` objects.
model_name = "immartian/improved_digits_recognition"  # Identifier passed to both from_pretrained calls below
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
# Preprocessing function to transform the drawn image into a format the model can recognize
def preprocess_image(image):
    """Convert a sketchpad drawing into the model's ``pixel_values`` input.

    Args:
        image: The drawing as a NumPy array. The dictionary payload that
            newer Gradio sketchpad/ImageEditor components send is also
            accepted; its ``"composite"`` entry is used as the drawing.

    Returns:
        The ``pixel_values`` entry produced by ``feature_extractor`` when
        called with ``return_tensors="pt"``.

    Raises:
        ValueError: If no drawing was supplied (``None`` or a dictionary
            without a composite image).
    """
    # Newer Gradio versions wrap the canvas in a dict; the flattened drawing
    # lives under "composite".
    if isinstance(image, dict):
        image = image.get("composite")
    if image is None:
        raise ValueError("No image provided: draw a digit before requesting a prediction.")

    grayscale = Image.fromarray(image).convert('L')  # Convert to grayscale
    small = grayscale.resize((28, 28))  # Resize to 28x28 pixels
    rgb = small.convert('RGB')  # Model expects 3-channel images, so convert to RGB
    inputs = feature_extractor(images=rgb, return_tensors="pt")
    return inputs['pixel_values']
# Prediction function to classify the drawn digit
def predict_digit(image):
    """Classify the drawn digit and return a human-readable result.

    Args:
        image: The sketchpad drawing handed over by Gradio, or ``None`` when
            there is nothing drawn.

    Returns:
        ``"Predicted Digit: <n>"`` where ``<n>`` is the index of the highest
        logit, or an empty string when ``image`` is ``None``.
    """
    # The interface runs in live mode, so this callback can fire with no
    # drawing at all (e.g. an empty or cleared canvas). Show nothing instead
    # of raising inside preprocessing.
    if image is None:
        return ""

    # Preprocess the input image
    inputs = preprocess_image(image)

    # Make the prediction; gradients are not needed for inference
    with torch.no_grad():
        outputs = model(inputs)

    predicted_label = outputs.logits.argmax(-1).item()
    return f"Predicted Digit: {predicted_label}"
# Gradio interface for drawing the digit and displaying the prediction
demo = gr.Interface(
    title="MNIST Digit Recognition",
    description="Draw a digit (0-9) and let the model recognize it!",
    fn=predict_digit,
    inputs="sketchpad",  # Users draw the digit on a canvas
    outputs="text",
    live=True,  # The prediction updates while the user draws
)

# Launch the app only when this file is executed directly
if __name__ == "__main__":
    demo.launch()