im2 commited on
Commit
56fd351
1 Parent(s): a0962d2

image data purified

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  from torchvision import transforms
4
  from PIL import Image
5
 
@@ -31,9 +32,14 @@ model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
31
  model.eval()
32
 
33
  # Gradio preprocessing and prediction pipeline
34
- def predict_digit(image):
35
- # Preprocess the image: resize to 28x28, convert to grayscale, and normalize
36
- image = Image.fromarray(image).convert('L') # Convert to grayscale
 
 
 
 
 
37
  transform = transforms.Compose([
38
  transforms.Resize((28, 28)),
39
  transforms.ToTensor(),
@@ -60,4 +66,4 @@ interface = gr.Interface(
60
 
61
  # Launch the app
62
  if __name__ == "__main__":
63
- interface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import numpy as np
4
  from torchvision import transforms
5
  from PIL import Image
6
 
 
32
  model.eval()
33
 
34
  # Gradio preprocessing and prediction pipeline
35
+ def predict_digit(image_dict):
36
+ # Extract the image array from the dict
37
+ image = image_dict["image"] # Extract the image data from the dictionary
38
+
39
+ # Convert the image to a numpy array, then to a PIL image, and preprocess
40
+ image = Image.fromarray(np.array(image)).convert('L') # Convert to grayscale
41
+
42
+ # Preprocess: resize to 28x28 and normalize
43
  transform = transforms.Compose([
44
  transforms.Resize((28, 28)),
45
  transforms.ToTensor(),
 
66
 
67
  # Launch the app
68
  if __name__ == "__main__":
69
+ interface.launch(share=True)