allispaul commited on
Commit
5f37d56
·
1 Parent(s): 3db2b38

initial commit

Browse files
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from timeit import default_timer as timer
3
+ from typing import Tuple
4
+ from pathlib import Path
5
+
6
+ import gradio as gr
7
+ import torch
8
+ from torch import nn
9
+ from torchvision import transforms
10
+
11
+ from model import create_effnetb2_model
12
+
13
+ class_names = ["pizza", "steak", "sushi"]
14
+ device = "cpu"
15
+
16
+ # Create model
17
+ effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=len(class_names))
18
+
19
+ # Load saved weights
20
+ effnetb2.load_state_dict(torch.load("effnetb2.pth"),
21
+ map_location=torch.device(device))
22
+
23
+ # Define predict function
24
+ def predict(img: Image) -> Tuple[dict, float]:
25
+ """Uses EffnetB2 model to transform and predict on img. Returns prediction
26
+ probabilities and time taken.
27
+
28
+ Args:
29
+ img (PIL.Image): Image to predict on.
30
+
31
+ Returns:
32
+ A tuple (pred_labels_and_probs, pred_time), where pred_labels_and_probs
33
+ is a dict mapping each class name to the probability the model assigns to
34
+ it, and pred_time is the time taken to predict (in seconds).
35
+ """
36
+ start_time = timer()
37
+ img = effnetb2_transforms(img).unsqueeze(0)
38
+ effnetb2.eval()
39
+ with torch.inference_mode():
40
+ pred_probs = torch.softmax(effnetb2(img), dim=1)
41
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i])
42
+ for i in range(len(class_names))}
43
+ pred_time = round(timer() - start_time, 4)
44
+ return pred_labels_and_probs, pred_time
45
+
46
+ # Initialize Gradio app
47
+ title = "FoodVision Mini"
48
+ description = "EfficientNetB2 feature extractor to classify images of food as pizza, steak, or sushi."
49
+ article = "From the [Zero to Mastery PyTorch tutorial](https://www.learnpytorch.io/09_pytorch_model_deployment/)"
50
+ examples = [list(example) for example in Path("examples").glob("*.jpg")]
51
+
52
+ demo = gr.Interface(
53
+ fn=predict,
54
+ inputs=gr.Image(type="pil"),
55
+ outputs=[gr.Label(num_top_classes=3, label="Predictions"),
56
+ gr.Number(label="Prediction time (s)")],
57
+ examples=example_list,
58
+ title=title,
59
+ description=description,
60
+ article=article,
61
+ )
62
+
63
+ demo.lauch()
effnetb2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:796022ec640571b749d822bb03ffaac90c49bded116726793cf9cc35e6b7109d
3
+ size 31294149
examples/.ipynb_checkpoints/1180001-checkpoint.jpg ADDED
examples/.ipynb_checkpoints/1280320-checkpoint.jpg ADDED
examples/.ipynb_checkpoints/705150-checkpoint.jpg ADDED
examples/1180001.jpg ADDED
examples/1280320.jpg ADDED
examples/705150.jpg ADDED
model.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torchvision
6
+
7
+ def create_effnetb2_model(num_classes: int = 3,
8
+ seed: int = 4,
9
+ ) -> Tuple[nn.Module, torchvision.Transforms]:
10
+ """Create an EfficientNetB2 feature extractor model and transforms.
11
+
12
+ Args:
13
+ num_classes: Number of classes to use for classification (default 3).
14
+ seed: Random seed for reproducibility (default 4).
15
+
16
+ Returns:
17
+ A tuple (model, transforms) of the model and its image transforms.
18
+ """
19
+ weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
20
+ transforms = effnet_b2_weights.transforms()
21
+ model = torchvision.models.efficientnet_b2(weights=weights)
22
+
23
+ # Freeze parameters below the head
24
+ for param in model.parameters():
25
+ param.requires_grad = False
26
+ # Replace the classifier head with one of appropriate size for the problem
27
+ torch.manual_seed(seed)
28
+ model.classifier = nn.Sequential(
29
+ nn.Dropout(p=0.3, inplace=True),
30
+ nn.Linear(in_features=1408, out_features=len(class_names))
31
+ )
32
+ return model, transforms
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio==3.37.0
2
+ torch==2.0.1
3
+ torchvision==0.15.2