Ajibola committed
Commit 5d27bb5 · 1 Parent(s): 1f8160d

Upload PaViTs.py

Files changed (1)
  1. Model/PaViTs.py +121 -0
Model/PaViTs.py ADDED
@@ -0,0 +1,121 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # In[2]:
+
+
+ import tensorflow as tf
+ from tensorflow.keras.layers import (Add, BatchNormalization, Dense, Dropout,
+                                      Embedding, Flatten, GlobalAveragePooling1D,
+                                      Input, Layer, MultiHeadAttention, Reshape)
+ from tensorflow.keras.models import Model
+ from tensorflow.image import extract_patches
+
+
+ # In[3]:
+
+
+ num_patches = 224 // 3  # unused below
+
+
+ # In[3]:
+
+
+ class patches(Layer):
+     """Split images into non-overlapping patch_size x patch_size patches."""
+     def __init__(self, patch_size):
+         super().__init__()
+         self.patch_size = patch_size
+
+     def call(self, x):
+         assert x.shape[1] % self.patch_size == 0, 'Image size should be divisible by patch_size'
+         # Promote (H, W) or (H, W, C) inputs to a 4-D (batch, H, W, C) tensor.
+         if len(x.shape) == 2:
+             x = tf.expand_dims(x, axis=-1)
+         if len(x.shape) == 3:
+             x = tf.expand_dims(x, axis=0)
+         patch = extract_patches(images=x,
+                                 sizes=[1, self.patch_size, self.patch_size, 1],
+                                 strides=[1, self.patch_size, self.patch_size, 1],
+                                 rates=[1, 1, 1, 1],
+                                 padding='VALID')
+         return patch
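+
+ # Shape check (illustrative): with a 224x224x3 input and patch_size=32,
+ # extract_patches tiles a 7x7 grid of patches, each flattened to
+ # 32*32*3 = 3072 values:
+ #   patches(patch_size=32)(tf.zeros((1, 224, 224, 3))).shape -> (1, 7, 7, 3072)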
+
+
+
+
+ # In[9]:
+
+
+ def encoder(x, dim=32, pos: bool = True):
+     """Linearly project flattened patches; optionally add a learned positional embedding."""
+     lin_proj = Dense(dim, activation='relu')
+     if pos:
+         pos_emb = Embedding(x.shape[1], dim)
+         position = tf.range(0, x.shape[1])
+         return lin_proj(x) + pos_emb(position)
+     else:
+         return lin_proj(x)
+
+ def Mlp(x, n: int = 8, dim=32):
+     """Pool the token axis, then apply n pairs of Dense layers."""
+     x = GlobalAveragePooling1D()(x)
+     for i in range(n):
+         x = Dense(dim, activation='relu')(x)
+         x = Dense(dim, activation='relu')(x)
+
+     return x
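+
+ # Note: GlobalAveragePooling1D collapses the token axis, so Mlp returns a
+ # (batch, dim) tensor while the attention branch in PaViT.model() keeps
+ # (batch, tokens, dim); summing the two in Add() relies on broadcasting the
+ # pooled vector across the token axis.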
+
+
+ # In[79]:
+
+
+ class PaViT:
+     def __init__(self, shape=(224, 224, 3), num_heads=12, patch_size=32, dim=126,
+                  pos_emb: bool = False, mlp_it=8, attn_drop: float = .3, dropout: bool = True):
+         self.dropout = dropout
+         self.shape = shape
+         self.num_heads = num_heads
+         self.patch_size = patch_size
+         self.dim = dim
+         self.attn_drop = attn_drop
+         self.pos_emb = pos_emb
+         self.mlp_it = mlp_it
+
+     def model(self, output_class=None, output=15, activation='softmax'):
+         inp = Input(shape=self.shape, name='Input')
+         patch = patches(patch_size=self.patch_size)(inp)
+         reshape = Reshape((-1, patch.shape[-1]))(patch)
+         encode = encoder(reshape, dim=self.dim, pos=self.pos_emb)  # honor the pos_emb flag
+         x = BatchNormalization()(encode)
+         drop = self.attn_drop if self.attn_drop else 0.0
+         attn = MultiHeadAttention(num_heads=self.num_heads, key_dim=self.dim, dropout=drop)(x, x)
+         mlp = Mlp(x, n=self.mlp_it, dim=self.dim)
+         add = Add()([mlp, attn])
+         norm = BatchNormalization()(add)
+         if self.dropout:
+             norm = Dropout(.3)(norm)
+
+         flat = Flatten()(norm)
+         if not output_class:
+             out = Dense(output, activation=activation)(flat)
+         else:
+             out = output_class(flat)
+
+         self.without_head = Model(inp, norm)
+         return Model(inp, out)
+
+
+     def remove_head(self):
+         try:
+             return self.without_head
+         except AttributeError:
+             print("Can't load the model without the last layer.\nInitialize the model first by calling model().")
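+
+ # Illustrative feature-extractor usage (names are placeholders):
+ #   p = PaViT()
+ #   clf = p.model()             # full classifier
+ #   backbone = p.remove_head()  # same graph, ending at the pre-Flatten tensor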
+
+ model = PaViT()
+ mox = model.model()
+ mox.summary()
+ #mox.load_weights('C:\\Users\\Emmanuel\\Downloads\\PAVIT_weights.h5')
+
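+ # A minimal training sketch (illustrative; `x_train`/`y_train` are placeholder
+ # arrays of shape (N, 224, 224, 3) and (N, 15), not part of the original script):
+ #   mox.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
+ #               loss='categorical_crossentropy', metrics=['accuracy'])
+ #   mox.fit(x_train, y_train, batch_size=32, epochs=10, validation_split=0.1)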