import gradio as gr
import torch
import numpy as np
import cv2  # For Gaussian blur, grayscale conversion, and CLAHE
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize, Resize

# URL of the pretrained model weights hosted on Hugging Face
model_path = "https://huggingface.co/immartian/improved_digits_recognition/resolve/main/pytorch_model.bin"

# Define the ImageClassifier model architecture (same as used during training)
class ImageClassifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, (3, 3)),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d((1, 1)),
            torch.nn.Flatten(),
            torch.nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.model(x)

# Instantiate the model, load the weights, and switch to inference mode
model = ImageClassifier()
model.load_state_dict(torch.hub.load_state_dict_from_url(model_path, map_location="cpu"))
model.eval()

# Gradio preprocessing and prediction pipeline
def predict_digit(image):
    # The ImageEditor sends a dict; the 'composite' key holds the drawn image
    if isinstance(image, dict):
        image = image.get('composite', None)

    if image is None:
        raise ValueError("No image data found in the input!")

    # Ensure the input is a uint8 numpy array
    image = np.array(image, dtype=np.uint8)

    # Apply Gaussian blur to reduce noise
    image = cv2.GaussianBlur(image, (5, 5), 0)

    # Convert to grayscale (the editor's composite is RGBA; handle plain RGB too)
    if image.ndim == 3:
        if image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    image = cv2.resize(image, (28, 28))

    # Optional: apply adaptive histogram equalization to improve contrast
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    image = clahe.apply(image)

    # Convert the numpy array back to a PIL Image for torchvision compatibility
    img_pil = Image.fromarray(image)
    # img_pil.show()  # Debug only: opens the preprocessed image in a local viewer

    transform = Compose([
        Resize((28, 28)),
        ToTensor(),
        Normalize((0.5,), (0.5,))
    ])
    img_tensor = transform(img_pil).unsqueeze(0)  # Add batch dimension

    # Debugging: print tensor shape and a few pixel values
    print(f"Input Tensor Shape: {img_tensor.shape}")
    print(f"First 5 pixels of the tensor: {img_tensor[0, 0, :5, :5]}")

    # Pass through the model
    with torch.no_grad():
        output = model(img_tensor)
        prediction = torch.argmax(output, dim=1).item()

    return f"Predicted Label: {prediction}"

# Create the Gradio interface using ImageEditor
with gr.Blocks() as demo:
    with gr.Row():
        im = gr.ImageEditor(type="numpy", crop_size="1:1")
        prediction_box = gr.Textbox(label="Predicted Digit")

    im.change(predict_digit, inputs=im, outputs=prediction_box, show_progress="hidden")

# Launch the app
if __name__ == "__main__":
    demo.launch()
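
# --- Optional local sanity check (a minimal sketch, not wired into the app) ---
# Assumes the weights above downloaded successfully. The dict below mimics the
# payload gr.ImageEditor emits: an RGBA "composite" canvas (all zeros here, i.e.
# an empty drawing; the 400x400 size is an assumption). It only verifies that
# the preprocessing and forward pass run end to end; the label is meaningless.
#
#   blank_canvas = {"composite": np.zeros((400, 400, 4), dtype=np.uint8)}
#   print(predict_digit(blank_canvas))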