"""Gradio demo: draw a digit and classify it with a Hugging Face image model."""

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Load the model and feature extractor from Hugging Face.
model_name = "immartian/improved_digits_recognition"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Shown instead of a prediction while the canvas holds no drawing.
_NO_INPUT_MESSAGE = "Draw a digit to get a prediction."


def _unwrap_sketch(image):
    """Return the drawn image from a sketchpad payload, or None if absent.

    Older Gradio releases hand the sketch to the callback as a bare array,
    while newer ones wrap it in a dict whose "composite" entry holds the
    flattened drawing. Both shapes are accepted here.
    """
    if isinstance(image, dict):
        return image.get("composite")
    return image


def preprocess_image(image):
    """Convert a sketchpad drawing into the model's ``pixel_values`` tensor.

    Args:
        image: The sketch as an array, a PIL image, or a sketchpad dict
            containing a "composite" entry.

    Returns:
        The ``pixel_values`` tensor produced by the feature extractor.

    Raises:
        ValueError: If the payload contains no image data.
    """
    sketch = _unwrap_sketch(image)
    if sketch is None:
        raise ValueError("No image data to preprocess.")

    pil_image = sketch if isinstance(sketch, Image.Image) else Image.fromarray(sketch)
    # NOTE(review): converting to "L" discards any alpha channel; if the
    # sketchpad emits strokes on a transparent background, confirm the
    # strokes remain distinguishable after this step.
    pil_image = pil_image.convert("L")      # Convert to grayscale
    pil_image = pil_image.resize((28, 28))  # Resize to 28x28 pixels
    pil_image = pil_image.convert("RGB")    # Feature extractor is given a 3-channel image

    inputs = feature_extractor(images=pil_image, return_tensors="pt")
    return inputs["pixel_values"]


def predict_digit(image):
    """Classify the drawn digit and return a human-readable result string.

    Args:
        image: The sketchpad payload passed in by Gradio; may be ``None``
            when the canvas is empty or has just been cleared.

    Returns:
        ``"Predicted Digit: <label>"`` for a drawing, or a short prompt when
        there is nothing to classify.
    """
    # With live=True the callback also fires for an empty/cleared canvas;
    # answer with a prompt rather than raising inside the UI.
    if _unwrap_sketch(image) is None:
        return _NO_INPUT_MESSAGE

    pixel_values = preprocess_image(image)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)

    predicted_label = outputs.logits.argmax(-1).item()
    return f"Predicted Digit: {predicted_label}"


# Gradio interface for drawing the digit and displaying the prediction.
demo = gr.Interface(
    fn=predict_digit,
    inputs="sketchpad",  # Allow users to draw a digit
    outputs="text",
    title="MNIST Digit Recognition",
    description="Draw a digit (0-9) and let the model recognize it!",
    live=True,  # The prediction updates while the user draws
)

# Launch the app.
if __name__ == "__main__":
    demo.launch()