# betterdigits / app.py
# Page metadata captured with this file: im2 · improved · 933911c · 3.32 kB
import gradio as gr
from gradio import Brush
import torch
import numpy as np
import cv2 # For Gaussian blur
from PIL import Image
from torch import nn, save, load
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
# Hugging Face Hub URL of the trained ImageClassifier weights
# (a PyTorch state_dict file).
model_path = (
    "https://huggingface.co/immartian/improved_digits_recognition"
    "/resolve/main/pytorch_model.bin"
)
# Network architecture; must match the one used in training so the
# downloaded state_dict keys (model.0, model.2, model.4, model.8) line up.
class ImageClassifier(nn.Module):
    """Small CNN digit classifier.

    Three unpadded 3x3 convolutions (1 -> 32 -> 64 -> 64 channels), each
    followed by ReLU, then global average pooling and a linear layer that
    produces 10 raw scores (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        layers = []
        in_channels = 1
        # Conv/ReLU pairs occupy Sequential indices 0..5, in this order.
        for out_channels in (32, 64, 64):
            layers.append(nn.Conv2d(in_channels, out_channels, (3, 3)))
            layers.append(nn.ReLU())
            in_channels = out_channels
        # Head: indices 6 (pool), 7 (flatten), 8 (linear).
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        layers.append(nn.Flatten())
        layers.append(nn.Linear(in_channels, 10))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of 1-channel images (N, 1, H, W) to (N, 10) scores."""
        scores = self.model(x)
        return scores
# Instantiate the network and load the pretrained weights from the Hub.
# map_location="cpu" lets a checkpoint that was saved from a CUDA device be
# deserialized on a CPU-only host (the model itself lives on the CPU).
# load_state_dict is strict by default, so any key/shape mismatch between the
# checkpoint and ImageClassifier raises instead of being silently ignored.
model = ImageClassifier()
model.load_state_dict(
    torch.hub.load_state_dict_from_url(model_path, map_location="cpu")
)
# The model is only used for inference. The current layers have no dropout or
# batch-norm, so this does not change outputs today; it keeps inference
# correct if such layers are added later.
model.eval()
# Gradio preprocessing and prediction pipeline

# Tensor conversion applied after the OpenCV steps: ToTensor scales the 28x28
# uint8 image to [0, 1] and Normalize maps that to [-1, 1]. Built once at
# import time instead of on every canvas change.
_TENSOR_TRANSFORM = Compose([
    ToTensor(),
    Normalize((0.5,), (0.5,)),
])


def _to_grayscale(image):
    """Collapse an H x W (x C) uint8 array to a single H x W channel.

    2-D and single-channel inputs pass through unchanged. Colour inputs use
    RGB(A) weights: Gradio's numpy images are RGB-ordered rather than OpenCV's
    BGR (NOTE(review): confirm for the installed gradio version).
    """
    if image.ndim == 2:
        return image
    channels = image.shape[2]
    if channels == 1:
        return image[:, :, 0]
    if channels == 4:
        # Alpha is dropped, as the previous BGR2GRAY conversion also did.
        return cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)


def _preprocess(image):
    """Turn a raw canvas image into a (1, 1, 28, 28) float tensor."""
    image = np.array(image, dtype=np.uint8)
    # Smooth stroke edges and noise before downsampling.
    image = cv2.GaussianBlur(image, (5, 5), 0)
    image = _to_grayscale(image)
    image = cv2.resize(image, (28, 28))
    # Contrast-limited adaptive histogram equalisation to boost stroke contrast.
    # A fresh CLAHE object per call avoids sharing OpenCV state across
    # concurrently running event handlers.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    image = clahe.apply(image)
    # ToTensor on a 28x28 "L" image yields (1, 28, 28); add the batch dim.
    return _TENSOR_TRANSFORM(Image.fromarray(image)).unsqueeze(0)


def predict_digit(image):
    """Classify a hand-drawn digit and return a display string.

    Args:
        image: Either the dict emitted by ``gr.ImageEditor`` (the drawing is
            read from its ``'composite'`` key) or an array-like image.

    Returns:
        ``"Predicted Label: <d>"`` where ``<d>`` is the index of the highest
        score produced by the module-level ``model``, as a plain integer.

    Raises:
        ValueError: If there is no image data (input is ``None`` or the dict
            has no ``'composite'`` entry).
    """
    if isinstance(image, dict):
        image = image.get('composite')
    if image is None:
        raise ValueError("No image data found in the input!")

    img_tensor = _preprocess(image)

    with torch.no_grad():
        output = model(img_tensor)
    # .item() turns the index tensor into an int; formatting the tensor
    # directly renders as "tensor(3)" instead of "3".
    prediction = torch.argmax(output, dim=1).item()
    return f"Predicted Label: {prediction}"
# Gradio UI: a square drawing canvas beside a text box showing the model's
# prediction, which is recomputed every time the canvas content changes.
with gr.Blocks() as demo:
    with gr.Row():
        im = gr.ImageEditor(crop_size="1:1", type="numpy")
        prediction_box = gr.Textbox(label="Predicted Digit")
    im.change(
        fn=predict_digit,
        inputs=im,
        outputs=prediction_box,
        show_progress="hidden",
    )

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()