Spaces:
Sleeping
Sleeping
import gradio as gr | |
from gradio import Brush | |
import torch | |
import numpy as np | |
import cv2 # For Gaussian blur | |
from PIL import Image | |
from torch import nn, save, load | |
from torchvision.transforms import Compose, ToTensor, Normalize, Resize | |
# URL of the pretrained weights (a PyTorch state dict) hosted on the Hugging Face Hub.
model_path = (
    "https://huggingface.co/immartian/improved_digits_recognition"
    "/resolve/main/pytorch_model.bin"
)
# Define your ImageClassifier model architecture (same as used during training) | |
class ImageClassifier(nn.Module):
    """Small CNN mapping a single-channel image to 10 class scores.

    Three unpadded 3x3 convolutions (1 -> 32 -> 64 -> 64 channels), each
    followed by ReLU, then global average pooling and a linear layer with
    10 outputs.  The layers live in ``self.model`` in this exact order so
    that state-dict keys (``model.0.weight`` ... ``model.8.bias``) match
    the saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(1, 32, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.ReLU(),
            # Collapse the spatial dimensions so any input size >= 7x7 works.
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, 10),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (un-normalized) class scores for batch ``x``."""
        scores = self.model(x)
        return scores
# Instantiate the network and load the pretrained weights from the Hub.
model = ImageClassifier()
# map_location="cpu" remaps tensor storages to the CPU, so the checkpoint
# loads on CPU-only hosts even if it was saved from a GPU.
model.load_state_dict(
    torch.hub.load_state_dict_from_url(model_path, map_location="cpu")
)
# The model is only used for inference; switch off training-mode behavior.
model.eval()
# Gradio preprocessing and prediction pipeline | |
def predict_digit(image):
    """Classify a drawn digit and return a human-readable label string.

    Args:
        image: Either the dict emitted by ``gr.ImageEditor`` (the drawing is
            read from its ``'composite'`` entry) or an array-like image.
            The array may be 2-D grayscale or 3-D with a trailing channel
            axis.

    Returns:
        The string ``"Predicted Label: <digit>"``.

    Raises:
        ValueError: If no image data is present (``None``, a dict without a
            composite image, or an empty array).
    """
    # ImageEditor delivers a dict; 'composite' holds the flattened drawing.
    if isinstance(image, dict):
        image = image.get('composite')
    if image is None:
        raise ValueError("No image data found in the input!")

    image = np.asarray(image, dtype=np.uint8)
    if image.size == 0:
        raise ValueError("No image data found in the input!")

    # Light blur to reduce stroke noise before down-sampling.
    image = cv2.GaussianBlur(image, (5, 5), 0)

    # Collapse to a single channel.  Gradio supplies RGB(A) channel order
    # (not OpenCV's BGR), and an already-grayscale input needs no conversion.
    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
        elif channels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            image = np.ascontiguousarray(image[:, :, 0])

    # The network input size used here is 28x28.
    image = cv2.resize(image, (28, 28))

    # Adaptive histogram equalization to improve contrast.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    image = clahe.apply(image)

    # ToTensor scales to [0, 1]; Normalize then maps to [-1, 1].
    # NOTE(review): assumes training used the same normalization — confirm.
    transform = Compose([
        ToTensor(),
        Normalize((0.5,), (0.5,)),
    ])
    img_tensor = transform(Image.fromarray(image)).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        output = model(img_tensor)

    prediction = int(torch.argmax(output, dim=1).item())
    return f"Predicted Label: {prediction}"
# Build the UI: a square-cropped drawing canvas next to a textbox that
# shows the predicted digit.
with gr.Blocks() as demo:
    with gr.Row():
        im = gr.ImageEditor(crop_size="1:1", type="numpy")
        prediction_box = gr.Textbox(label="Predicted Digit")

    # Re-run the classifier every time the canvas content changes.
    im.change(
        fn=predict_digit,
        inputs=im,
        outputs=prediction_box,
        show_progress="hidden",
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()