im2 committed on
Commit
d1d4583
1 Parent(s): cbb8b31

remove transformers

Browse files
Files changed (2) hide show
  1. app.py +46 -28
  2. requirements.txt +1 -2
app.py CHANGED
@@ -2,44 +2,62 @@ import gradio as gr
2
  import torch
3
  from torchvision import transforms
4
  from PIL import Image
5
- from transformers import AutoModelForImageClassification, AutoFeatureExtractor
6
 
7
- # Load the model and feature extractor from Hugging Face
8
- model_name = "immartian/improved_digits_recognition"
9
- model = AutoModelForImageClassification.from_pretrained(model_name)
10
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
11
 
12
- # Preprocessing function to transform the drawn image into a format the model can recognize
13
- def preprocess_image(image):
14
- # Convert the image into a format suitable for the model
15
- image = Image.fromarray(image).convert('L') # Convert to grayscale
16
- image = image.resize((28, 28)) # Resize to 28x28 pixels
17
- image = image.convert('RGB') # Model expects 3-channel images, so convert to RGB
18
- inputs = feature_extractor(images=image, return_tensors="pt")
19
- return inputs['pixel_values']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- # Prediction function to classify the drawn digit
22
  def predict_digit(image):
23
- # Preprocess the input image
24
- inputs = preprocess_image(image)
 
 
 
 
 
25
 
26
- # Make the prediction
27
- with torch.no_grad():
28
- outputs = model(inputs)
29
- predicted_label = outputs.logits.argmax(-1).item()
30
 
31
- return f"Predicted Digit: {predicted_label}"
 
 
 
 
 
32
 
33
- # Gradio interface for drawing the digit and displaying the prediction
34
- demo = gr.Interface(
35
  fn=predict_digit,
36
- inputs="sketchpad", # Allow users to draw a digit
37
  outputs="text",
38
- title="MNIST Digit Recognition",
39
- description="Draw a digit (0-9) and let the model recognize it!",
40
- live=True # The prediction updates while the user draws
41
  )
42
 
43
  # Launch the app
44
  if __name__ == "__main__":
45
- demo.launch()
 
2
  import torch
3
  from torchvision import transforms
4
  from PIL import Image
 
5
 
6
+ # Load the model using PyTorch
7
+ model_path = "https://huggingface.co/immartian/improved_digits_recognition/resolve/main/pytorch_model.bin"
 
 
8
 
9
+ # Define your ImageClassifier model architecture (same as used during training)
10
+ class ImageClassifier(torch.nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+ self.model = torch.nn.Sequential(
14
+ torch.nn.Conv2d(1, 32, (3, 3)),
15
+ torch.nn.ReLU(),
16
+ torch.nn.Conv2d(32, 64, (3, 3)),
17
+ torch.nn.ReLU(),
18
+ torch.nn.Conv2d(64, 64, (3, 3)),
19
+ torch.nn.ReLU(),
20
+ torch.nn.AdaptiveAvgPool2d((1, 1)),
21
+ torch.nn.Flatten(),
22
+ torch.nn.Linear(64, 10)
23
+ )
24
+
25
+ def forward(self, x):
26
+ return self.model(x)
27
+
28
+ # Instantiate the model and load weights
29
+ model = ImageClassifier()
30
+ model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
31
+ model.eval()
32
 
33
+ # Gradio preprocessing and prediction pipeline
34
  def predict_digit(image):
35
+ # Preprocess the image: resize to 28x28, convert to grayscale, and normalize
36
+ image = image.convert('L') # Convert to grayscale
37
+ transform = transforms.Compose([
38
+ transforms.Resize((28, 28)),
39
+ transforms.ToTensor(),
40
+ transforms.Normalize((0.5,), (0.5,))
41
+ ])
42
 
43
+ img_tensor = transform(image).unsqueeze(0) # Add batch dimension
 
 
 
44
 
45
+ # Pass through the model
46
+ with torch.no_grad():
47
+ output = model(img_tensor)
48
+ predicted_label = torch.argmax(output, dim=1).item()
49
+
50
+ return f"Predicted Label: {predicted_label}"
51
 
52
+ # Create Gradio Interface
53
+ interface = gr.Interface(
54
  fn=predict_digit,
55
+ inputs=gr.Image(source="canvas", tool="editor", type="pil"), # User can draw on a canvas
56
  outputs="text",
57
+ title="Digit Recognizer",
58
+ description="Draw a digit (0-9) and the model will predict the number!"
 
59
  )
60
 
61
  # Launch the app
62
  if __name__ == "__main__":
63
+ interface.launch()
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  torch
2
  torchvision
3
  gradio
4
- transformers
5
- Pillow # Required for image processing
 
1
  torch
2
  torchvision
3
  gradio
4
+ Pillow