Spaces:
Runtime error
Runtime error
deleted unnecessary methods
Browse files
app.py
CHANGED
@@ -10,42 +10,7 @@ from PIL import Image
|
|
10 |
|
11 |
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
# Moving both Data and Model into GPU
|
16 |
-
|
17 |
-
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')
|
23 |
-
|
24 |
-
def to_device(data, device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to `device`.

    Lists and tuples are walked recursively and always come back as lists;
    any other object is moved through its own ``.to()`` method.
    """
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    # non_blocking=True asks for an asynchronous copy where the backend allows it.
    return data.to(device, non_blocking=True)
|
29 |
-
|
30 |
-
class DeviceDataLoader():
    """Wrap a dataloader so that every batch it yields is moved to a device."""

    def __init__(self, dl, device):
        # dl: the wrapped iterable of batches; device: target passed to to_device.
        self.dl = dl
        self.device = device

    def __iter__(self):
        """Yield each batch of the wrapped loader after moving it to the device."""
        for b in self.dl:
            yield to_device(b, self.device)

    def __len__(self):
        """Return the number of batches in the wrapped loader."""
        return len(self.dl)
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
# Defining our Class for just prediction
|
48 |
-
|
49 |
def accuracy(outputs, labels):
    """Return the fraction of rows in `outputs` whose arg-max matches `labels`.

    The arg-max is taken along dim 1 of `outputs`.  The result is a new scalar
    tensor built from a Python float.  An empty batch raises
    ZeroDivisionError because the count is divided by ``len(preds)``.
    """
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))
|
@@ -69,13 +34,10 @@ class ImageClassificationBase(nn.Module):
|
|
69 |
|
70 |
|
71 |
# Defining our finetuned Resnet50 Architecture with our Classification layer
|
72 |
-
|
73 |
class IndianFoodModelResnet50(ImageClassificationBase):
|
74 |
def __init__(self, num_classes, pretrained=True):
|
75 |
super().__init__()
|
76 |
-
# Use a pretrained model
|
77 |
self.network = models.resnet50(pretrained=pretrained)
|
78 |
-
# Replace last layer
|
79 |
self.network.fc = nn.Linear(self.network.fc.in_features, num_classes)
|
80 |
|
81 |
def forward(self, xb):
|
@@ -83,7 +45,7 @@ class IndianFoodModelResnet50(ImageClassificationBase):
|
|
83 |
|
84 |
|
85 |
|
86 |
-
#
|
87 |
@torch.no_grad()
|
88 |
def evaluate(model, val_loader):
|
89 |
model.eval()
|
@@ -92,7 +54,7 @@ def evaluate(model, val_loader):
|
|
92 |
|
93 |
|
94 |
|
95 |
-
#
|
96 |
classes = ['burger', 'butter_naan', 'chai', 'chapati', 'chole_bhature',
|
97 |
'dal_makhani', 'dhokla', 'fried_rice', 'idli', 'jalebi',
|
98 |
'kaathi_rolls', 'kadai_paneer', 'kulfi', 'masala_dosa', 'momos',
|
@@ -103,40 +65,37 @@ to_device(model, device);
|
|
103 |
|
104 |
|
105 |
|
106 |
-
#
|
107 |
ckp_path = 'indianFood-resnet50.pth'
|
108 |
model.load_state_dict(torch.load(ckp_path, map_location=torch.device('cpu')))
|
109 |
model.eval()
|
110 |
|
111 |
|
112 |
|
113 |
-
#
|
114 |
stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
115 |
img_tfms = tt.Compose([tt.Resize((224, 224)),
|
116 |
tt.ToTensor(),
|
117 |
tt.Normalize(*stats, inplace = True)])
|
118 |
|
119 |
def predict_image(image, model):
    """Return the class label that `model` predicts for a single image tensor.

    `image` is one preprocessed image without a batch dimension; the batch
    axis is added here.  The input is moved to the device that holds the
    model's parameters, so no separate device helper is required.
    """
    # Convert to a batch of 1
    xb = image.unsqueeze(0)
    first_param = next(model.parameters(), None)
    if first_param is not None:
        xb = xb.to(first_param.device)
    # Get predictions from model; gradients are not needed for inference
    with torch.no_grad():
        yb = model(xb)
    # Pick index with the highest score
    _, preds = torch.max(yb, dim=1)
    # Retrieve the class label
    return classes[preds[0].item()]
|
128 |
|
129 |
|
130 |
|
|
|
131 |
def classify_image(path):
    """Open the image at `path`, preprocess it and return the predicted label."""
    # Force three channels: the normalisation stats in img_tfms are per-RGB
    # channel, so grayscale or RGBA files would otherwise fail in Normalize.
    # The context manager closes the underlying file once the copy is made.
    with Image.open(path) as raw:
        img = raw.convert("RGB")
    img = img_tfms(img)
    label = predict_image(img, model)
    return label
|
138 |
|
139 |
|
|
|
|
|
140 |
image = gr.inputs.Image(shape=(224, 224), type="filepath")
|
141 |
label = gr.outputs.Label(num_top_classes=1)
|
142 |
|
@@ -149,7 +108,6 @@ gr.Interface(
|
|
149 |
outputs=label,
|
150 |
examples = [["idli.jpg"], ["naan.jpg"]],
|
151 |
theme = "huggingface",
|
152 |
-
layout = "horizontal",
|
153 |
title = "DesiVisionNet: Desi Food Vision with ResNet",
|
154 |
description = "This is a Gradio demo for multi-class image classification of Indian food amongst 20 classes. The DesiVisionNet achieved 90% accuracy on our test dataset, performing well for a relatively efficient model. See the GitHub project page for detailed information below. Here, we provide a demo for real-world food classification. To use it, simply upload your image, or click one of the examples to load them.",
|
155 |
article = article
|
|
|
10 |
|
11 |
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
# Defining our Class for just prediction
|
|
|
14 |
def accuracy(outputs, labels):
    """Return the fraction of rows in `outputs` whose arg-max matches `labels`.

    The arg-max is taken along dim 1 of `outputs`.  The result is a new scalar
    tensor built from a Python float.  An empty batch raises
    ZeroDivisionError because the count is divided by ``len(preds)``.
    """
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))
|
|
|
34 |
|
35 |
|
36 |
# Defining our finetuned Resnet50 Architecture with our Classification layer
|
|
|
37 |
class IndianFoodModelResnet50(ImageClassificationBase):
|
38 |
def __init__(self, num_classes, pretrained=True):
|
39 |
super().__init__()
|
|
|
40 |
self.network = models.resnet50(pretrained=pretrained)
|
|
|
41 |
self.network.fc = nn.Linear(self.network.fc.in_features, num_classes)
|
42 |
|
43 |
def forward(self, xb):
|
|
|
45 |
|
46 |
|
47 |
|
48 |
+
# Prediction method
|
49 |
@torch.no_grad()
|
50 |
def evaluate(model, val_loader):
|
51 |
model.eval()
|
|
|
54 |
|
55 |
|
56 |
|
57 |
+
# Initialising our model and moving it to CPU
|
58 |
classes = ['burger', 'butter_naan', 'chai', 'chapati', 'chole_bhature',
|
59 |
'dal_makhani', 'dhokla', 'fried_rice', 'idli', 'jalebi',
|
60 |
'kaathi_rolls', 'kadai_paneer', 'kulfi', 'masala_dosa', 'momos',
|
|
|
65 |
|
66 |
|
67 |
|
68 |
+
# Loading the model
|
69 |
ckp_path = 'indianFood-resnet50.pth'
|
70 |
model.load_state_dict(torch.load(ckp_path, map_location=torch.device('cpu')))
|
71 |
model.eval()
|
72 |
|
73 |
|
74 |
|
75 |
+
# Image preprocessing before prediction
|
76 |
stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
77 |
img_tfms = tt.Compose([tt.Resize((224, 224)),
|
78 |
tt.ToTensor(),
|
79 |
tt.Normalize(*stats, inplace = True)])
|
80 |
|
81 |
def predict_image(image, model):
    """Return the class label that `model` predicts for a single image tensor.

    `image` is one preprocessed image without a batch dimension; the batch
    axis is added here.  The input is moved to the device that holds the
    model's parameters, so no separate device helper is required.
    """
    # Convert to a batch of 1
    xb = image.unsqueeze(0)
    first_param = next(model.parameters(), None)
    if first_param is not None:
        xb = xb.to(first_param.device)
    # Get predictions from model; gradients are not needed for inference
    with torch.no_grad():
        yb = model(xb)
    # Pick index with the highest score
    _, preds = torch.max(yb, dim=1)
    # Retrieve the class label
    return classes[preds[0].item()]
|
86 |
|
87 |
|
88 |
|
89 |
+
# Function handling input, processing and output
|
90 |
def classify_image(path):
    """Open the image at `path`, preprocess it and return the predicted label."""
    # Force three channels: the normalisation stats in img_tfms are per-RGB
    # channel, so grayscale or RGBA files would otherwise fail in Normalize.
    # The context manager closes the underlying file once the copy is made.
    with Image.open(path) as raw:
        img = raw.convert("RGB")
    img = img_tfms(img)
    label = predict_image(img, model)
    return label
|
95 |
|
96 |
|
97 |
+
|
98 |
+
# Defining gradio interface functions
|
99 |
image = gr.inputs.Image(shape=(224, 224), type="filepath")
|
100 |
label = gr.outputs.Label(num_top_classes=1)
|
101 |
|
|
|
108 |
outputs=label,
|
109 |
examples = [["idli.jpg"], ["naan.jpg"]],
|
110 |
theme = "huggingface",
|
|
|
111 |
title = "DesiVisionNet: Desi Food Vision with ResNet",
|
112 |
description = "This is a Gradio demo for multi-class image classification of Indian food amongst 20 classes. The DesiVisionNet achieved 90% accuracy on our test dataset, performing well for a relatively efficient model. See the GitHub project page for detailed information below. Here, we provide a demo for real-world food classification. To use it, simply upload your image, or click one of the examples to load them.",
|
113 |
article = article
|