File size: 3,323 Bytes
81a6137
933911c
81a6137
56fd351
933911c
81a6137
933911c
 
 
81a6137
d1d4583
 
81a6137
d1d4583
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
933911c
81a6137
933911c
 
41a6271
d1d4583
02729d4
933911c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1d4583
933911c
 
 
 
 
 
81a6137
d1d4583
 
 
933911c
d1d4583
933911c
81a6137
933911c
 
 
 
 
 
81a6137
933911c
 
 
 
81a6137
 
933911c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
from gradio import Brush
import torch
import numpy as np
import cv2  # For Gaussian blur
from PIL import Image
from torch import nn, save, load
from torchvision.transforms import Compose, ToTensor, Normalize, Resize


# URL of the pretrained ImageClassifier weights (a state_dict) on the Hugging Face Hub;
# fetched with torch.hub.load_state_dict_from_url when the model is instantiated.
model_path = "https://huggingface.co/immartian/improved_digits_recognition/resolve/main/pytorch_model.bin"

# ImageClassifier architecture; must match the model used during training exactly.
class ImageClassifier(torch.nn.Module):
    """Small CNN mapping a 1-channel image to 10 class logits.

    Three unpadded 3x3 convolutions (1 -> 32 -> 64 -> 64 channels), each
    followed by ReLU, then global average pooling, flattening and a linear
    layer producing 10 outputs.

    The layers live in a single ``torch.nn.Sequential`` stored as
    ``self.model``; their order fixes the state_dict keys
    (``model.0``, ``model.2``, ``model.4``, ``model.8``) that the pretrained
    weights are loaded against, so the layer order must not change.
    """

    def __init__(self):
        super().__init__()
        stack = []
        channels_in = 1
        # Conv + ReLU pairs occupy Sequential indices 0..5.
        for channels_out in (32, 64, 64):
            stack.append(torch.nn.Conv2d(channels_in, channels_out, (3, 3)))
            stack.append(torch.nn.ReLU())
            channels_in = channels_out
        # Head: pool to 1x1 (index 6), flatten (7), classify (8).
        stack += [
            torch.nn.AdaptiveAvgPool2d((1, 1)),
            torch.nn.Flatten(),
            torch.nn.Linear(channels_in, 10),
        ]
        self.model = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw (unnormalized) class logits for batch ``x``."""
        logits = self.model(x)
        return logits

# Instantiate the model and load the pretrained weights from model_path.
# map_location="cpu" lets weights that were saved from a GPU session load on
# CPU-only hosts. eval() switches the network to inference mode; with the
# current layers (conv/ReLU/pool/linear only) this does not change outputs,
# but it keeps inference correct if dropout/batch-norm are ever added.
model = ImageClassifier()
model.load_state_dict(
    torch.hub.load_state_dict_from_url(model_path, map_location="cpu")
)
model.eval()

# Gradio preprocessing and prediction pipeline.

# Tensor conversion applied after the OpenCV preprocessing in predict_digit.
# The image is already 28x28 at that point, so no Resize step is needed.
# ToTensor maps uint8 [0, 255] to float [0, 1]; Normalize((0.5,), (0.5,))
# then maps that to [-1, 1]. Built once instead of on every canvas change.
_TRANSFORM = Compose([
    ToTensor(),
    Normalize((0.5,), (0.5,)),
])


def _to_grayscale(image):
    """Return a 2-D grayscale version of a uint8 image array.

    Accepts H x W (already grayscale, returned as-is), H x W x 1,
    H x W x 3 (RGB) and H x W x 4 (RGBA) arrays. Gradio delivers numpy
    images in RGB(A) channel order, so the RGB-based OpenCV codes are used.

    Raises:
        ValueError: if the array has any other shape.
    """
    if image.ndim == 2:
        return image
    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 1:
            # Contiguous copy so later OpenCV calls get a plain 2-D buffer.
            return np.ascontiguousarray(image[:, :, 0])
        if channels == 3:
            return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        if channels == 4:
            # NOTE(review): alpha is discarded. If the editor composite has a
            # transparent background whose RGB values are 0, that background
            # will read as black -- confirm against the canvas configuration.
            return cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    raise ValueError(f"Unsupported image shape: {image.shape}")


def predict_digit(image):
    """Classify a hand-drawn digit coming from the Gradio ImageEditor.

    Args:
        image: Either the ImageEditor value (a dict whose 'composite' entry
            holds the flattened drawing) or the image array itself.

    Returns:
        A string of the form "Predicted Label: <digit>", where <digit> is
        the argmax over the model's 10 output logits.

    Raises:
        ValueError: if no image data is present, or the array is not a
            supported grayscale/RGB/RGBA layout.
    """
    # The ImageEditor value is a dict; 'composite' is the flattened drawing.
    if isinstance(image, dict):
        image = image.get('composite', None)

    if image is None:
        raise ValueError("No image data found in the input!")

    image = np.array(image, dtype=np.uint8)

    # Apply Gaussian blur to reduce noise before downsampling.
    image = cv2.GaussianBlur(image, (5, 5), 0)

    # Collapse to one channel (grayscale input passes through), then shrink
    # to the 28x28 resolution the classifier is fed.
    image = _to_grayscale(image)
    image = cv2.resize(image, (28, 28))

    # Adaptive histogram equalization to improve stroke/background contrast.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    image = clahe.apply(image)

    # 2-D uint8 -> (1, 28, 28) float tensor -> add batch dim -> (1, 1, 28, 28).
    img_tensor = _TRANSFORM(Image.fromarray(image)).unsqueeze(0)

    with torch.no_grad():
        logits = model(img_tensor)
    prediction = int(logits.argmax(dim=1).item())

    return f"Predicted Label: {prediction}"

# Gradio UI: a square drawing canvas beside a textbox showing the prediction.
with gr.Blocks() as demo:
    with gr.Row():
        im = gr.ImageEditor(type="numpy", crop_size="1:1")
        prediction_box = gr.Textbox(label="Predicted Digit")

    # Re-classify whenever the canvas content changes; the progress
    # indicator is hidden so it does not flash on every brush stroke.
    im.change(
        fn=predict_digit,
        inputs=im,
        outputs=prediction_box,
        show_progress="hidden",
    )

# Start the web server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()