Chaitanya Garg committed
Commit 9bb0389 · 1 Parent(s): 87886d5
ViT.py ADDED
@@ -0,0 +1,19 @@
+ import torch
+ from torch import nn
+ from partViT import patchNPositionalEmbeddingMaker, transformerEncoderBlock
+
+ class ViT(nn.Module):
+     def __init__(self, inChannels, outChannels, patchSize, imgSize, hiddenLayer, numHeads, MLPdropOut, numTransformLayers, numClasses, embeddingDropOut=0.1, attnDropOut=0):
+         super().__init__()
+         self.EmbeddingMaker = patchNPositionalEmbeddingMaker(inChannels, outChannels, patchSize, imgSize)
+         self.embeddingDrop = nn.Dropout(embeddingDropOut)
+         self.TransformEncoder = nn.Sequential(*[transformerEncoderBlock(outChannels, hiddenLayer, numHeads, MLPdropOut, attnDropOut) for _ in range(numTransformLayers)])
+         self.Classifier = nn.Sequential(nn.LayerNorm(normalized_shape=outChannels),
+                                         nn.Linear(outChannels, numClasses))
+
+     def forward(self, x):
+         x = self.EmbeddingMaker(x)      # (B, numPatches+1, embeddingDim)
+         x = self.embeddingDrop(x)
+         x = self.TransformEncoder(x)
+         x = self.Classifier(x[:, 0])    # classify on the class token only
+         return x
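A quick shape check for this class (a minimal sketch: it assumes ViT.py and partViT.py sit side by side on the import path, and uses the ViT-B/16 hyperparameters that model.py also passes in):

```python
import torch
from ViT import ViT

# ViT-B/16: 768-dim tokens, 12 heads, 12 encoder layers, 3072-dim MLP
model = ViT(inChannels=3, outChannels=768, patchSize=16, imgSize=224,
            hiddenLayer=3072, numHeads=12, MLPdropOut=0.1,
            numTransformLayers=12, numClasses=101)
logits = model(torch.randn(2, 3, 224, 224))  # two 224x224 RGB images
print(logits.shape)  # torch.Size([2, 101])
```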
ViTModel.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bf9a58f8286e2d46f20877ab7bea7b38be359b1aa175273df92ac9150e8257d7
+ size 343564994
app.py ADDED
@@ -0,0 +1,25 @@
+ ### Imports for Modules ###
+ import gradio as gr
+ import os
+
+ ### Functional Imports ###
+ from predictor import predictionMaker
+
+ exampleList = [["examples/" + example] for example in os.listdir("examples")]
+
+ title = "Food Vision👀 on Food101 Using ViT"
+ description = "A ViT trained to classify images of food from the [Food101](https://pytorch.org/vision/main/generated/torchvision.datasets.Food101.html) dataset."
+ article = "Created by [Eternal Blissard](https://github.com/EternalBlissard)."
+
+ # Create the Gradio demo
+ demo = gr.Interface(fn=predictionMaker,
+                     inputs=[gr.Image(type="pil")],
+                     outputs=[gr.Label(num_top_classes=3, label="Predictions"),
+                              gr.Number(label="Prediction time (s)")],
+                     examples=exampleList,
+                     title=title,
+                     description=description,
+                     article=article)
+
+ # Launch the demo!
+ demo.launch()
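To try the demo locally, install the pinned requirements and run `python app.py`; Gradio serves the interface at http://127.0.0.1:7860 by default.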
examples/example1.jpg ADDED
examples/example2.jpg ADDED
examples/example3.jpg ADDED
examples/example4.jpg ADDED
examples/example5.jpg ADDED
helper.py ADDED
@@ -0,0 +1,12 @@
+ import numpy as np
+ import os
+ import random
+ import torch
+
+ def setAllSeeds(seed):
+     # Seed every RNG the project touches so runs are reproducible
+     os.environ['MY_GLOBAL_SEED'] = str(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
model.py ADDED
@@ -0,0 +1,15 @@
+ import torchvision
+ from torch import nn
+ from helper import setAllSeeds
+
+ def getViT(seed, classNames, DEVICE):
+     setAllSeeds(seed)
+     # Start from the pretrained torchvision ViT-B/16 and freeze its backbone
+     vitWeights = torchvision.models.ViT_B_16_Weights.DEFAULT
+     vitTransforms = vitWeights.transforms()
+     vit = torchvision.models.vit_b_16(weights=vitWeights).to(DEVICE)
+     for param in vit.parameters():
+         param.requires_grad = False
+     # Swap in a classification head sized for our classes
+     vit.heads = nn.Linear(in_features=768, out_features=len(classNames)).to(DEVICE)
+     return vit, vitTransforms
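getViT returns a model in which only the replacement head is trainable. A hedged usage sketch (the three-class classNames list here is a placeholder; predictor.py passes the full 101-class Food101 list):

```python
import torch
from model import getViT

classNames = ["pizza", "sushi", "steak"]  # placeholder class list
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
vit, vitTransforms = getViT(seed=42, classNames=classNames, DEVICE=DEVICE)

# Only the new head should remain trainable after the freeze
trainable = sum(p.numel() for p in vit.parameters() if p.requires_grad)
print(trainable)  # 768 * len(classNames) + len(classNames) = 2307
```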
partViT.py ADDED
@@ -0,0 +1,64 @@
+ from torch import nn
+ import torch
+
+ class multiHeadSelfAttentionBlock(nn.Module):
+     def __init__(self, embeddingDim=768, numHeads=12, attnDropOut=0):
+         super().__init__()
+         self.layerNorm = nn.LayerNorm(normalized_shape=embeddingDim)
+         self.multiheadAttn = nn.MultiheadAttention(embed_dim=embeddingDim, num_heads=numHeads, dropout=attnDropOut, batch_first=True)
+
+     def forward(self, x):
+         layNorm = self.layerNorm(x)
+         attnOutPut, _ = self.multiheadAttn(query=layNorm, key=layNorm, value=layNorm)
+         return attnOutPut
+
+ class MLPBlock(nn.Module):
+     def __init__(self, embeddingDim, hiddenLayer, dropOut=0.1):
+         super().__init__()
+         self.MLP = nn.Sequential(
+             nn.LayerNorm(normalized_shape=embeddingDim),
+             nn.Linear(embeddingDim, hiddenLayer),
+             nn.GELU(),
+             nn.Dropout(dropOut),
+             nn.Linear(hiddenLayer, embeddingDim),
+             nn.Dropout(dropOut)
+         )
+
+     def forward(self, x):
+         return self.MLP(x)
+
+ class transformerEncoderBlock(nn.Module):
+     def __init__(self, embeddingDim, hiddenLayer, numHeads, MLPdropOut, attnDropOut=0):
+         super().__init__()
+         self.MSABlock = multiHeadSelfAttentionBlock(embeddingDim, numHeads, attnDropOut)
+         self.MLPBlock = MLPBlock(embeddingDim, hiddenLayer, MLPdropOut)
+
+     def forward(self, x):
+         # Pre-norm encoder: attention and MLP each get a residual connection
+         x = self.MSABlock(x) + x
+         x = self.MLPBlock(x) + x
+         return x
+
+ class patchNPositionalEmbeddingMaker(nn.Module):
+     def __init__(self, inChannels, outChannels, patchSize, imgSize):
+         super().__init__()
+         # outChannels is the same as embeddingDim
+         self.outChannels = outChannels
+         self.patchSize = patchSize
+         self.numPatches = int(imgSize**2 / patchSize**2)
+         self.patchMaker = nn.Conv2d(inChannels, outChannels, kernel_size=patchSize, stride=patchSize, padding=0)
+         self.flattener = nn.Flatten(start_dim=2, end_dim=3)
+         self.classEmbedding = nn.Parameter(torch.randn(1, 1, self.outChannels), requires_grad=True)
+         self.PositionalEmbedding = nn.Parameter(torch.randn(1, self.numPatches + 1, self.outChannels), requires_grad=True)
+
+     def forward(self, x):
+         batchSize = x.shape[0]
+         imgRes = x.shape[-1]
+         assert imgRes % self.patchSize == 0, 'Input size must be divisible by patchSize'
+         x = self.patchMaker(x)                  # (B, embeddingDim, H/P, W/P)
+         x = self.flattener(x)                   # (B, embeddingDim, numPatches)
+         x = x.permute(0, 2, 1)                  # (B, numPatches, embeddingDim)
+         classToken = self.classEmbedding.expand(batchSize, -1, -1)
+         x = torch.cat((classToken, x), dim=1)   # prepend class token
+         x = x + self.PositionalEmbedding
+         return x
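Tracing the shapes through patchNPositionalEmbeddingMaker makes the token layout concrete; a minimal sketch under the ViT-B/16 settings used elsewhere in this repo:

```python
import torch
from partViT import patchNPositionalEmbeddingMaker

embed = patchNPositionalEmbeddingMaker(inChannels=3, outChannels=768,
                                       patchSize=16, imgSize=224)
tokens = embed(torch.randn(1, 3, 224, 224))
# 224/16 = 14 patches per side -> 14*14 = 196 patches, +1 class token = 197
print(tokens.shape)  # torch.Size([1, 197, 768])
```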
predictor.py ADDED
@@ -0,0 +1,23 @@
+ ### Imports for Modules ###
+ import torch
+ from timeit import default_timer as timer
+
+ ### Functional Imports ###
+ from model import getViT
+
+ classNames = ['Apple Pie', 'Baby Back Ribs', 'Baklava', 'Beef Carpaccio', 'Beef Tartare', 'Beet Salad', 'Beignets', 'Bibimbap', 'Bread Pudding', 'Breakfast Burrito', 'Bruschetta', 'Caesar Salad', 'Cannoli', 'Caprese Salad', 'Carrot Cake', 'Ceviche', 'Cheese Plate', 'Cheesecake', 'Chicken Curry', 'Chicken Quesadilla', 'Chicken Wings', 'Chocolate Cake', 'Chocolate Mousse', 'Churros', 'Clam Chowder', 'Club Sandwich', 'Crab Cakes', 'Creme Brulee', 'Croque Madame', 'Cup Cakes', 'Deviled Eggs', 'Donuts', 'Dumplings', 'Edamame', 'Eggs Benedict', 'Escargots', 'Falafel', 'Filet Mignon', 'Fish And Chips', 'Foie Gras', 'French Fries', 'French Onion Soup', 'French Toast', 'Fried Calamari', 'Fried Rice', 'Frozen Yogurt', 'Garlic Bread', 'Gnocchi', 'Greek Salad', 'Grilled Cheese Sandwich', 'Grilled Salmon', 'Guacamole', 'Gyoza', 'Hamburger', 'Hot And Sour Soup', 'Hot Dog', 'Huevos Rancheros', 'Hummus', 'Ice Cream', 'Lasagna', 'Lobster Bisque', 'Lobster Roll Sandwich', 'Macaroni And Cheese', 'Macarons', 'Miso Soup', 'Mussels', 'Nachos', 'Omelette', 'Onion Rings', 'Oysters', 'Pad Thai', 'Paella', 'Pancakes', 'Panna Cotta', 'Peking Duck', 'Pho', 'Pizza', 'Pork Chop', 'Poutine', 'Prime Rib', 'Pulled Pork Sandwich', 'Ramen', 'Ravioli', 'Red Velvet Cake', 'Risotto', 'Samosa', 'Sashimi', 'Scallops', 'Seaweed Salad', 'Shrimp And Grits', 'Spaghetti Bolognese', 'Spaghetti Carbonara', 'Spring Rolls', 'Steak', 'Strawberry Shortcake', 'Sushi', 'Tacos', 'Takoyaki', 'Tiramisu', 'Tuna Tartare', 'Waffles']
+
+ DEVICE = torch.device("cpu")
+ ViTModel, VitTransforms = getViT(42, classNames, DEVICE)
+ ViTModel.load_state_dict(torch.load(f="ViTModel.pt", map_location=DEVICE))
+
+ def predictionMaker(img):
+     startTime = timer()
+     img = VitTransforms(img).unsqueeze(0)
+     ViTModel.eval()
+     with torch.inference_mode():
+         predProbs = torch.softmax(ViTModel(img), dim=1)
+     predDict = {classNames[i]: float(predProbs[0][i]) for i in range(len(classNames))}
+     endTime = timer()
+     predTime = round(endTime - startTime, 4)
+     return predDict, predTime
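A local smoke test for predictionMaker (this assumes ViTModel.pt has been fetched via Git LFS and the examples/ folder is present):

```python
from PIL import Image
from predictor import predictionMaker

img = Image.open("examples/example1.jpg")
predDict, predTime = predictionMaker(img)
top3 = sorted(predDict.items(), key=lambda kv: kv[1], reverse=True)[:3]
print(top3, f"in {predTime}s")
```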
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==2.2.0
+ torchvision==0.17.0
+ gradio==4.20.0