"""Gradio app that recognizes a hand-drawn digit (0-9) with a small PyTorch CNN.

The pretrained weights are downloaded from the Hugging Face Hub when this
module is imported; the Gradio UI is launched only when the file is run as a
script.
"""

import gradio as gr
import torch
from torchvision import transforms
from PIL import Image

# Location of the pretrained state dict on the Hugging Face Hub.
model_path = "https://huggingface.co/immartian/improved_digits_recognition/resolve/main/pytorch_model.bin"


# Model architecture. It must match the one used during training, otherwise
# load_state_dict() below fails on missing/unexpected/mis-shaped keys.
class ImageClassifier(torch.nn.Module):
    """Three-conv-layer CNN mapping a 1-channel image to 10 class logits."""

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, (3, 3)),
            torch.nn.ReLU(),
            # Global average pooling collapses whatever spatial size remains
            # after the unpadded convolutions, so the linear head always
            # receives exactly 64 features.
            torch.nn.AdaptiveAvgPool2d((1, 1)),
            torch.nn.Flatten(),
            torch.nn.Linear(64, 10),
        )

    def forward(self, x):
        """Return raw (unnormalized) logits of shape (batch, 10) for *x*."""
        return self.model(x)


# Instantiate the model and load the pretrained weights.
model = ImageClassifier()
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only host; inference below runs on CPU tensors anyway.
# NOTE(review): the checkpoint is a pickle fetched over the network. On torch
# versions that support it, consider passing weights_only=True so that
# unpickling cannot execute arbitrary code.
model.load_state_dict(
    torch.hub.load_state_dict_from_url(model_path, map_location="cpu")
)
model.eval()  # switch to inference mode before serving predictions

# Preprocessing pipeline, built once at import instead of on every request.
_TRANSFORM = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1]
])


def _sketch_to_array(image):
    """Extract the drawing's pixel array from a Sketchpad value.

    Older Gradio releases pass the drawing directly as a NumPy array. Gradio 4+
    passes an editor dict whose flattened drawing is stored under the
    "composite" key. Returns None when no drawing is available.
    """
    if isinstance(image, dict):
        return image.get("composite")
    return image


def predict_digit(image):
    """Predict the digit drawn on the Gradio Sketchpad.

    Args:
        image: The Sketchpad value. This is a NumPy array, a Gradio 4+ editor
            dict holding the array under "composite", or None when nothing
            was submitted.

    Returns:
        The string "Predicted Label: <digit>", or a prompt asking the user to
        draw when no image was received.
    """
    pixels = _sketch_to_array(image)
    if pixels is None:
        return "Please draw a digit first."

    # Grayscale gives the single input channel the first Conv2d expects.
    # NOTE(review): .convert('L') discards any alpha channel; if the sketch
    # arrives as RGBA with a transparent background, confirm the resulting
    # stroke/background polarity matches what the model was trained on.
    pil_image = Image.fromarray(pixels).convert('L')
    img_tensor = _TRANSFORM(pil_image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():  # no autograd bookkeeping needed for inference
        output = model(img_tensor)
    predicted_label = torch.argmax(output, dim=1).item()
    return f"Predicted Label: {predicted_label}"


# Create the Gradio interface.
interface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Sketchpad(),  # Sketchpad for users to draw
    outputs="text",
    title="Digit Recognizer",
    description="Draw a digit (0-9) and the model will predict the number!"
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()