Chaitanya Garg committed
Commit 9bb0389 · 1 Parent(s): 87886d5
Browse files
- ViT.py +19 -0
- ViTModel.pt +3 -0
- app.py +32 -0
- examples/example1.jpg +0 -0
- examples/example2.jpg +0 -0
- examples/example3.jpg +0 -0
- examples/example4.jpg +0 -0
- examples/example5.jpg +0 -0
- helper.py +18 -0
- model.py +17 -0
- partViT.py +70 -0
- predictor.py +24 -0
- requirements.txt +3 -0
ViT.py
ADDED
@@ -0,0 +1,19 @@
+import torch
+from torch import nn
+from partViT import patchNPositionalEmbeddingMaker, transformerEncoderBlock
+
+class ViT(nn.Module):
+    def __init__(self, inChannels, outChannels, patchSize, imgSize, hiddenLayer, numHeads, MLPdropOut, numTransformLayers, numClasses, embeddingDropOut=0.1, attnDropOut=0):
+        super().__init__()
+        self.EmbeddingMaker = patchNPositionalEmbeddingMaker(inChannels, outChannels, patchSize, imgSize)
+        # self.transformerEncodingBlock = transformerEncoderBlock(outChannels, hiddenLayer, numHeads, MLPdropOut, attnDropOut)
+        self.embeddingDrop = nn.Dropout(embeddingDropOut)
+        self.TransformEncoder = nn.Sequential(*[transformerEncoderBlock(outChannels, hiddenLayer, numHeads, MLPdropOut, attnDropOut) for _ in range(numTransformLayers)])
+        self.Classifier = nn.Sequential(nn.LayerNorm(normalized_shape=outChannels),
+                                        nn.Linear(outChannels, numClasses))
+    def forward(self, x):
+        x = self.EmbeddingMaker(x)
+        x = self.embeddingDrop(x)
+        x = self.TransformEncoder(x)
+        x = self.Classifier(x[:, 0])
+        return x
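For readers following the diff, here is a minimal sketch (not part of the committed files) of how this ViT class could be instantiated with the ViT-Base/16-style hyperparameters that model.py uses later in this commit; the dummy batch and the shape check are assumptions added purely for illustration.

# Illustrative sketch only -- not part of the commit.
# ViT-Base/16-style settings: 16x16 patches of a 224x224 RGB image,
# 768-dim embeddings, 3072 MLP units, 12 heads, 12 encoder layers, 101 classes.
import torch
from ViT import ViT

model = ViT(inChannels=3, outChannels=768, patchSize=16, imgSize=224,
            hiddenLayer=3072, numHeads=12, MLPdropOut=0.1,
            numTransformLayers=12, numClasses=101)
dummyBatch = torch.randn(2, 3, 224, 224)   # two fake RGB images
logits = model(dummyBatch)                 # classification head reads the class token
print(logits.shape)                        # expected: torch.Size([2, 101])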
ViTModel.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf9a58f8286e2d46f20877ab7bea7b38be359b1aa175273df92ac9150e8257d7
+size 343564994
app.py
ADDED
@@ -0,0 +1,32 @@
+### Imports for Modules ###
+import gradio as gr
+import os
+import torch
+from typing import Tuple, Dict
+from timeit import default_timer as timer
+
+### Functional Imports
+from predictor import predictionMaker
+
+exampleList = [["examples/" + example] for example in os.listdir("examples")]
+
+title = "Food Vision👀 on Food101 Using ViT"
+description = "Trained a ViT to classify images of food based on [Food101](https://pytorch.org/vision/main/generated/torchvision.datasets.Food101.html)."
+article = "Created by [Eternal Bliassard](https://github.com/EternalBlissard)."
+
+# Create the Gradio demo
+demo = gr.Interface(fn=predictionMaker,
+                    inputs=[gr.Image(type="pil")],
+                    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
+                             gr.Number(label="Prediction time (s)")],
+                    examples=exampleList,
+                    title=title,
+                    description=description,
+                    article=article)
+
+# Launch the demo!
+demo.launch()
examples/example1.jpg
ADDED
examples/example2.jpg
ADDED
examples/example3.jpg
ADDED
examples/example4.jpg
ADDED
examples/example5.jpg
ADDED
helper.py
ADDED
@@ -0,0 +1,18 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import torch
+import random
+import zipfile
+from pathlib import Path
+import requests
+
+def setAllSeeds(seed):
+    os.environ['MY_GLOBAL_SEED'] = str(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
model.py
ADDED
@@ -0,0 +1,17 @@
+import torch
+import torchvision
+from torch import nn
+from helper import setAllSeeds
+from ViT import ViT
+
+def getViT(seed, classNames, DEVICE):
+    setAllSeeds(seed)
+    ViTModel = ViT(3, 768, 16, 224, 3072, 12, 0.1, 12, len(classNames)).to(DEVICE)
+    vitWeights = torchvision.models.ViT_B_16_Weights.DEFAULT
+    vitTransforms = vitWeights.transforms()
+    vit = torchvision.models.vit_b_16(weights=vitWeights).to(DEVICE)
+    for param in vit.parameters():
+        param.requires_grad = False
+    vit.heads = nn.Linear(in_features=768, out_features=len(classNames)).to(DEVICE)
+    return vit, vitTransforms
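A minimal usage sketch for getViT (not part of the commit): the three-label class list, the "cpu" device, and the blank placeholder image below are assumptions chosen for illustration; the deployed app passes the full list of 101 Food101 class names instead.

# Illustrative sketch only -- placeholder labels and device, not the app's real inputs.
import torch
from PIL import Image
from model import getViT

classNames = ["Pizza", "Sushi", "Steak"]                 # stand-in for the 101 Food101 labels
model, transforms = getViT(seed=42, classNames=classNames, DEVICE="cpu")

# transforms is the pretrained ViT_B_16 preprocessing (resize, crop, normalize);
# model is the frozen torchvision ViT-B/16 with a freshly initialised 3-class head.
img = Image.new("RGB", (256, 256))                       # stand-in for a real photo
batch = transforms(img).unsqueeze(0)                     # shape: (1, 3, 224, 224)
model.eval()
with torch.inference_mode():
    print(model(batch).shape)                            # expected: torch.Size([1, 3])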
partViT.py
ADDED
@@ -0,0 +1,70 @@
+from torch import nn
+import torch
+
+class multiHeadSelfAttentionBlock(nn.Module):
+    def __init__(self, embeddingDim=768, numHeads=12, attnDropOut=0):
+        super().__init__()
+        self.layerNorm = nn.LayerNorm(normalized_shape=embeddingDim)
+        self.multiheadAttn = nn.MultiheadAttention(embed_dim=embeddingDim, num_heads=numHeads, dropout=attnDropOut, batch_first=True)
+
+    def forward(self, x):
+        layNorm = self.layerNorm(x)
+        attnOutPut, _ = self.multiheadAttn(query=layNorm, key=layNorm, value=layNorm)
+        return attnOutPut
+
+class MLPBlock(nn.Module):
+    def __init__(self, embeddingDim, hiddenLayer, dropOut=0.1):
+        super().__init__()
+        self.MLP = nn.Sequential(
+            nn.LayerNorm(normalized_shape=embeddingDim),
+            nn.Linear(embeddingDim, hiddenLayer),
+            nn.GELU(),
+            nn.Dropout(dropOut),
+            nn.Linear(hiddenLayer, embeddingDim),
+            nn.Dropout(dropOut)
+        )
+
+    def forward(self, x):
+        return self.MLP(x)
+
+class transformerEncoderBlock(nn.Module):
+    def __init__(self, embeddingDim, hiddenLayer, numHeads, MLPdropOut, attnDropOut=0):
+        super().__init__()
+        self.MSABlock = multiHeadSelfAttentionBlock(embeddingDim, numHeads, attnDropOut)
+        self.MLPBlock = MLPBlock(embeddingDim, hiddenLayer, MLPdropOut)
+
+    def forward(self, x):
+        x = self.MSABlock(x) + x
+        x = self.MLPBlock(x) + x
+        return x
+
+class patchNPositionalEmbeddingMaker(nn.Module):
+    def __init__(self, inChannels, outChannels, patchSize, imgSize):
+        super().__init__()
+        self.outChannels = outChannels
+
+        # outChannels is the same as embeddingDim
+        self.patchSize = patchSize
+        self.numPatches = int(imgSize**2 / patchSize**2)
+        self.patchMaker = nn.Conv2d(inChannels, outChannels, kernel_size=patchSize, stride=patchSize, padding=0)
+        self.flattener = nn.Flatten(start_dim=2, end_dim=3)
+        self.classEmbedding = nn.Parameter(torch.randn(1, 1, self.outChannels), requires_grad=True)
+        self.PositionalEmbedding = nn.Parameter(torch.randn(1, self.numPatches+1, self.outChannels), requires_grad=True)
+
+    def forward(self, x):
+        batchSize = x.shape[0]
+        imgRes = x.shape[-1]
+        assert imgRes % self.patchSize == 0, 'Input size must be div by patchSize'
+        x = self.patchMaker(x)
+        x = self.flattener(x)
+        x = x.permute(0, 2, 1)
+        classToken = self.classEmbedding.expand(batchSize, -1, -1)
+        x = torch.cat((classToken, x), dim=1)
+        x = x + self.PositionalEmbedding
+        # batchSize = x.shape[0]
+        # embeddingDim = x.shape[-1]
+        return x
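A shape-tracing sketch (not part of the commit) showing how these blocks compose: the batch size and image size below are arbitrary choices for illustration, using the same ViT-B/16 dimensions found elsewhere in the repo.

# Illustrative sketch only -- traces tensor shapes through the embedding and one encoder block.
import torch
from partViT import patchNPositionalEmbeddingMaker, transformerEncoderBlock

embedder = patchNPositionalEmbeddingMaker(inChannels=3, outChannels=768,
                                          patchSize=16, imgSize=224)
block = transformerEncoderBlock(embeddingDim=768, hiddenLayer=3072,
                                numHeads=12, MLPdropOut=0.1)

imgs = torch.randn(2, 3, 224, 224)   # fake batch of two 224x224 RGB images
tokens = embedder(imgs)              # (2, 197, 768): 14*14 = 196 patch tokens + 1 class token
out = block(tokens)                  # residual MSA + MLP blocks preserve the shape: (2, 197, 768)
print(tokens.shape, out.shape)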
predictor.py
ADDED
@@ -0,0 +1,24 @@
+### Imports for Modules ###
+import gradio as gr
+import os
+import torch
+from typing import Tuple, Dict
+from timeit import default_timer as timer
+
+### Functional Imports
+from model import getViT
+
+classNames = ['Apple Pie', 'Baby Back Ribs', 'Baklava', 'Beef Carpaccio', 'Beef Tartare', 'Beet Salad', 'Beignets', 'Bibimbap', 'Bread Pudding', 'Breakfast Burrito', 'Bruschetta', 'Caesar Salad', 'Cannoli', 'Caprese Salad', 'Carrot Cake', 'Ceviche', 'Cheese Plate', 'Cheesecake', 'Chicken Curry', 'Chicken Quesadilla', 'Chicken Wings', 'Chocolate Cake', 'Chocolate Mousse', 'Churros', 'Clam Chowder', 'Club Sandwich', 'Crab Cakes', 'Creme Brulee', 'Croque Madame', 'Cup Cakes', 'Deviled Eggs', 'Donuts', 'Dumplings', 'Edamame', 'Eggs Benedict', 'Escargots', 'Falafel', 'Filet Mignon', 'Fish And Chips', 'Foie Gras', 'French Fries', 'French Onion Soup', 'French Toast', 'Fried Calamari', 'Fried Rice', 'Frozen Yogurt', 'Garlic Bread', 'Gnocchi', 'Greek Salad', 'Grilled Cheese Sandwich', 'Grilled Salmon', 'Guacamole', 'Gyoza', 'Hamburger', 'Hot And Sour Soup', 'Hot Dog', 'Huevos Rancheros', 'Hummus', 'Ice Cream', 'Lasagna', 'Lobster Bisque', 'Lobster Roll Sandwich', 'Macaroni And Cheese', 'Macarons', 'Miso Soup', 'Mussels', 'Nachos', 'Omelette', 'Onion Rings', 'Oysters', 'Pad Thai', 'Paella', 'Pancakes', 'Panna Cotta', 'Peking Duck', 'Pho', 'Pizza', 'Pork Chop', 'Poutine', 'Prime Rib', 'Pulled Pork Sandwich', 'Ramen', 'Ravioli', 'Red Velvet Cake', 'Risotto', 'Samosa', 'Sashimi', 'Scallops', 'Seaweed Salad', 'Shrimp And Grits', 'Spaghetti Bolognese', 'Spaghetti Carbonara', 'Spring Rolls', 'Steak', 'Strawberry Shortcake', 'Sushi', 'Tacos', 'Takoyaki', 'Tiramisu', 'Tuna Tartare', 'Waffles']
+ViTModel, VitTransforms = getViT(42, classNames, torch.device("cpu"))  # CPU device matches the map_location below
+ViTModel.load_state_dict(torch.load(f="ViTModel.pt", map_location=torch.device("cpu")))
+
+def predictionMaker(img):
+    startTime = timer()
+    img = VitTransforms(img).unsqueeze(0)
+    ViTModel.eval()
+    with torch.inference_mode():
+        predProbs = torch.softmax(ViTModel(img), dim=1)
+    predDict = {classNames[i]: float(predProbs[0][i]) for i in range(len(classNames))}
+    endTime = timer()
+    predTime = round(endTime-startTime, 4)
+    return predDict, predTime
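A quick way to exercise predictionMaker outside Gradio (not part of the commit). It assumes ViTModel.pt and the examples folder added in this commit are present locally, and that the saved state dict matches the model returned by getViT.

# Illustrative sketch only -- runs the predictor on one of the committed example images.
from PIL import Image
from predictor import predictionMaker   # importing also builds the model and loads ViTModel.pt

img = Image.open("examples/example1.jpg")
predDict, predTime = predictionMaker(img)

# predDict maps all 101 class names to probabilities; show the three most likely.
top3 = sorted(predDict.items(), key=lambda kv: kv[1], reverse=True)[:3]
print(top3, f"predicted in {predTime}s")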
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+torch==2.2.0
+torchvision==0.17.0
+gradio==4.20.0