Spaces:
Sleeping
Sleeping
im2
commited on
Commit
•
d1d4583
1
Parent(s):
cbb8b31
remolve transformer
Browse files- app.py +46 -28
- requirements.txt +1 -2
app.py
CHANGED
@@ -2,44 +2,62 @@ import gradio as gr
|
|
2 |
import torch
|
3 |
from torchvision import transforms
|
4 |
from PIL import Image
|
5 |
-
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
|
6 |
|
7 |
-
# Load the model
|
8 |
-
|
9 |
-
model = AutoModelForImageClassification.from_pretrained(model_name)
|
10 |
-
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
#
|
22 |
def predict_digit(image):
|
23 |
-
# Preprocess the
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
#
|
27 |
-
with torch.no_grad():
|
28 |
-
outputs = model(inputs)
|
29 |
-
predicted_label = outputs.logits.argmax(-1).item()
|
30 |
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
-
# Gradio
|
34 |
-
|
35 |
fn=predict_digit,
|
36 |
-
inputs="
|
37 |
outputs="text",
|
38 |
-
title="
|
39 |
-
description="Draw a digit (0-9) and
|
40 |
-
live=True # The prediction updates while the user draws
|
41 |
)
|
42 |
|
43 |
# Launch the app
|
44 |
if __name__ == "__main__":
|
45 |
-
|
|
|
2 |
import torch
|
3 |
from torchvision import transforms
|
4 |
from PIL import Image
|
|
|
5 |
|
6 |
+
# Load the model using PyTorch
|
7 |
+
model_path = "https://huggingface.co/immartian/improved_digits_recognition/resolve/main/pytorch_model.bin"
|
|
|
|
|
8 |
|
9 |
+
# Define your ImageClassifier model architecture (same as used during training)
|
10 |
+
class ImageClassifier(torch.nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super().__init__()
|
13 |
+
self.model = torch.nn.Sequential(
|
14 |
+
torch.nn.Conv2d(1, 32, (3, 3)),
|
15 |
+
torch.nn.ReLU(),
|
16 |
+
torch.nn.Conv2d(32, 64, (3, 3)),
|
17 |
+
torch.nn.ReLU(),
|
18 |
+
torch.nn.Conv2d(64, 64, (3, 3)),
|
19 |
+
torch.nn.ReLU(),
|
20 |
+
torch.nn.AdaptiveAvgPool2d((1, 1)),
|
21 |
+
torch.nn.Flatten(),
|
22 |
+
torch.nn.Linear(64, 10)
|
23 |
+
)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
return self.model(x)
|
27 |
+
|
28 |
+
# Instantiate the model and load weights
|
29 |
+
model = ImageClassifier()
|
30 |
+
model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
|
31 |
+
model.eval()
|
32 |
|
33 |
+
# Gradio preprocessing and prediction pipeline
|
34 |
def predict_digit(image):
|
35 |
+
# Preprocess the image: resize to 28x28, convert to grayscale, and normalize
|
36 |
+
image = image.convert('L') # Convert to grayscale
|
37 |
+
transform = transforms.Compose([
|
38 |
+
transforms.Resize((28, 28)),
|
39 |
+
transforms.ToTensor(),
|
40 |
+
transforms.Normalize((0.5,), (0.5,))
|
41 |
+
])
|
42 |
|
43 |
+
img_tensor = transform(image).unsqueeze(0) # Add batch dimension
|
|
|
|
|
|
|
44 |
|
45 |
+
# Pass through the model
|
46 |
+
with torch.no_grad():
|
47 |
+
output = model(img_tensor)
|
48 |
+
predicted_label = torch.argmax(output, dim=1).item()
|
49 |
+
|
50 |
+
return f"Predicted Label: {predicted_label}"
|
51 |
|
52 |
+
# Create Gradio Interface
|
53 |
+
interface = gr.Interface(
|
54 |
fn=predict_digit,
|
55 |
+
inputs=gr.Image(source="canvas", tool="editor", type="pil"), # User can draw on a canvas
|
56 |
outputs="text",
|
57 |
+
title="Digit Recognizer",
|
58 |
+
description="Draw a digit (0-9) and the model will predict the number!"
|
|
|
59 |
)
|
60 |
|
61 |
# Launch the app
|
62 |
if __name__ == "__main__":
|
63 |
+
interface.launch()
|
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
torch
|
2 |
torchvision
|
3 |
gradio
|
4 |
-
|
5 |
-
Pillow # Required for image processing
|
|
|
1 |
torch
|
2 |
torchvision
|
3 |
gradio
|
4 |
+
Pillow
|
|