im2 committed on
Commit
933911c
1 Parent(s): 41a6271
Files changed (1) hide show
  1. app.py +60 -28
app.py CHANGED
@@ -1,8 +1,12 @@
1
  import gradio as gr
 
2
  import torch
3
  import numpy as np
4
- from torchvision import transforms
5
  from PIL import Image
 
 
 
6
 
7
  # Load the model using PyTorch
8
  model_path = "https://huggingface.co/immartian/improved_digits_recognition/resolve/main/pytorch_model.bin"
@@ -29,44 +33,72 @@ class ImageClassifier(torch.nn.Module):
29
  # Instantiate the model and load weights
30
  model = ImageClassifier()
31
  model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
32
- model.eval()
33
 
34
- # Function to process sketchpad input
35
- def sketchToNumpy(image):
36
- # Extract the 'composite' key from the sketchpad input dictionary
37
- imArray = image['composite'] # 'composite' contains the drawn image
38
- return imArray
39
 
40
  # Gradio preprocessing and prediction pipeline
41
  def predict_digit(image):
42
- # Convert the sketchpad input into a PIL Image
43
- image = Image.fromarray(image).convert('L') # Convert to grayscale
44
-
45
- # Preprocess: resize to 28x28 and normalize
46
- transform = transforms.Compose([
47
- transforms.Resize((28, 28)),
48
- transforms.ToTensor(),
49
- transforms.Normalize((0.5,), (0.5,))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  ])
51
-
52
- img_tensor = transform(image).unsqueeze(0) # Add batch dimension
 
 
 
 
53
 
54
  # Pass through the model
55
  with torch.no_grad():
56
  output = model(img_tensor)
57
- predicted_label = torch.argmax(output, dim=1).item()
58
 
59
- return f"Predicted Label: {predicted_label}"
60
 
61
- # Create Gradio Interface
62
- interface = gr.Interface(
63
- fn=lambda x: predict_digit(sketchToNumpy(x)),
64
- inputs=gr.Sketchpad(crop_size=(256,256), type='numpy', image_mode='L'),
65
- outputs="text",
66
- title="Digit Recognizer",
67
- description="Draw a digit (0-9) and the model will predict the number!"
68
- )
69
 
 
 
 
 
70
  # Launch the app
71
  if __name__ == "__main__":
72
- interface.launch()
 
1
  import gradio as gr
2
+ from gradio import Brush
3
  import torch
4
  import numpy as np
5
+ import cv2 # For Gaussian blur
6
  from PIL import Image
7
+ from torch import nn, save, load
8
+ from torchvision.transforms import Compose, ToTensor, Normalize, Resize
9
+
10
 
11
  # Load the model using PyTorch
12
  model_path = "https://huggingface.co/immartian/improved_digits_recognition/resolve/main/pytorch_model.bin"
 
33
  # Instantiate the model and load weights
34
  model = ImageClassifier()
35
  model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
36
+ # model.eval()
37
 
38
+ # with open('pytorch_model.bin', 'rb') as f:
39
+ # model.load_state_dict(load(f))
 
 
 
40
 
41
  # Gradio preprocessing and prediction pipeline
42
  def predict_digit(image):
43
+ # Extract the 'composite' key, which contains the drawn image
44
+ if isinstance(image, dict):
45
+ image = image.get('composite', None) # Use the composite image
46
+
47
+ if image is None:
48
+ raise ValueError("No image data found in the input!")
49
+
50
+ #print("Unique pixel values in the image array:", np.unique(image))
51
+
52
+ # Ensure the input is a numpy array
53
+ image = np.array(image, dtype=np.uint8)
54
+
55
+
56
+ # Apply Gaussian blur to reduce noise
57
+ image = cv2.GaussianBlur(image, (5, 5), 0)
58
+
59
+ # If the image has multiple channels (e.g., BGR), convert it to grayscale
60
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
61
+ image = cv2.resize(image, (28, 28))
62
+ # Optional: Apply adaptive histogram equalization to improve contrast
63
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
64
+ image = clahe.apply(image)
65
+
66
+
67
+
68
+ # Convert the numpy array back to a PIL Image for torchvision compatibility
69
+ img_pil = Image.fromarray(image)
70
+ img_pil.show()
71
+
72
+ transform = Compose([
73
+ Resize((28, 28)),
74
+ ToTensor(),
75
+ Normalize((0.5,), (0.5,))
76
  ])
77
+
78
+ img_tensor = transform(img_pil).unsqueeze(0) # Add batch dimension
79
+
80
+ # Debugging: Print tensor shape and some pixel values
81
+ print(f"Input Tensor Shape: {img_tensor.shape}")
82
+ print(f"First 5 pixels of the tensor: {img_tensor[0, 0, :5, :5]}")
83
 
84
  # Pass through the model
85
  with torch.no_grad():
86
  output = model(img_tensor)
87
+ prediction = torch.argmax(output)
88
 
89
+ return f"Predicted Label: {prediction}"
90
 
91
# Build the UI: an image editor to draw in, next to a text box that shows
# the predicted digit.
with gr.Blocks() as demo:
    with gr.Row():
        im = gr.ImageEditor(crop_size="1:1", type="numpy")
        prediction_box = gr.Textbox(label="Predicted Digit")

    # Re-run the classifier whenever the drawing changes; the progress
    # indicator is hidden.
    im.change(
        fn=predict_digit,
        inputs=im,
        outputs=prediction_box,
        show_progress="hidden",
    )

# Launch the app only when executed as a script.
if __name__ == "__main__":
    demo.launch()