#!/usr/bin/env python
# coding: utf-8

# In[2]:


# All Keras imports come from tensorflow.keras; mixing in the standalone
# `keras` package breaks layer graphs on TF 2.x.
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import *
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.image import *
from tensorflow.keras.activations import swish
from tensorflow.image import extract_patches
from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer


# In[3]:


num_patches = 224 // 32  # patches per side (image_size // patch_size); not used below


# In[3]:


class Patches(Layer):
    """Splits a batch of images into flattened, non-overlapping square patches."""

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, x):
        # Promote greyscale (H, W) or unbatched (H, W, C) inputs to 4-D.
        if x.shape.rank == 2:
            x = tf.expand_dims(x, axis=-1)
        if x.shape.rank == 3:
            x = tf.expand_dims(x, axis=0)
        if x.shape[1] is not None:
            assert x.shape[1] % self.patch_size == 0, \
                'Image size must be divisible by patch_size'
        # Returns (batch, rows, cols, patch_size * patch_size * channels).
        return extract_patches(images=x,
                               sizes=[1, self.patch_size, self.patch_size, 1],
                               strides=[1, self.patch_size, self.patch_size, 1],
                               rates=[1, 1, 1, 1],
                               padding='VALID')



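# In[ ]:


# Quick shape check for the Patches layer (a minimal sketch; the random tensor
# below is a stand-in image, not data from this notebook). A 224x224x3 image
# with patch_size=32 gives a 7x7 grid of patches, each flattened to
# 32 * 32 * 3 = 3072 values.
demo_img = tf.random.uniform((1, 224, 224, 3))  # hypothetical input
print(Patches(patch_size=32)(demo_img).shape)   # (1, 7, 7, 3072)

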
# In[9]:


def encoder(x, dim=32, pos: bool = True):
    """Projects each flattened patch to `dim` features and, if `pos` is set,
    adds a learned positional embedding per patch index."""
    lin_proj = Dense(dim, activation='relu')
    if pos:
        pos_emb = Embedding(x.shape[1], dim)
        position = tf.range(0, x.shape[1])
        return lin_proj(x) + pos_emb(position)
    return lin_proj(x)


def Mlp(x, n: int = 8, dim=32):
    """Pools over the patch axis, then applies `n` pairs of Dense layers."""
    x = GlobalAveragePooling1D()(x)
    for _ in range(n):
        x = Dense(dim, activation='relu')(x)
        x = Dense(dim, activation='relu')(x)
    return x

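# In[ ]:


# Shape sketch for encoder() and Mlp() on a dummy patch sequence (assumed
# shapes only: 49 patches of 3072 values each, as produced above for a
# 224x224 image with 32x32 patches).
dummy_patches = tf.random.uniform((1, 49, 3072))  # hypothetical input
print(encoder(dummy_patches, dim=126).shape)      # (1, 49, 126)
print(Mlp(dummy_patches, n=2, dim=126).shape)     # (1, 126)
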

# In[79]:


class PaViT:
    """Patch-based Vision Transformer: patch extraction, linear projection with
    optional positional embedding, multi-head self-attention alongside an MLP
    branch, and a configurable classification head."""

    def __init__(self, shape=(224, 224, 3), num_heads=12, patch_size=32, dim=126,
                 pos_emb: bool = True, mlp_it=8, attn_drop: float = 0.3,
                 dropout: bool = True):
        self.dropout = dropout
        self.shape = shape
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.dim = dim
        self.attn_drop = attn_drop
        self.pos_emb = pos_emb
        self.mlp_it = mlp_it

    def model(self, output_class=None, output=15, activation='softmax'):
        inp = Input(shape=self.shape, name='Input')
        patch = Patches(patch_size=self.patch_size)(inp)
        # Flatten the patch grid into a sequence: (batch, rows * cols, patch_dim).
        reshape = Reshape((-1, patch.shape[-1]))(patch)
        encode = encoder(reshape, dim=self.dim, pos=self.pos_emb)
        x = BatchNormalization()(encode)
        attn = MultiHeadAttention(num_heads=self.num_heads, key_dim=self.dim,
                                  dropout=self.attn_drop or 0.0)(x, x)
        mlp = Mlp(x, n=self.mlp_it, dim=self.dim)
        # The pooled MLP output (batch, dim) broadcasts over the patch axis of
        # the attention output (batch, num_patches, dim).
        add = Add()([mlp, attn])
        norm = BatchNormalization()(add)
        if self.dropout:
            norm = Dropout(0.3)(norm)

        flat = Flatten()(norm)
        if output_class is None:
            out = Dense(output, activation=activation)(flat)
        else:
            out = output_class(flat)

        self.without_head = Model(inp, norm)
        return Model(inp, out)

    def remove_head(self):
        """Returns the model without its classification head."""
        try:
            return self.without_head
        except AttributeError:
            print("Can't return the model without the last layer.\nCall model() first.")
    
model = PaViT()
mox = model.model()
mox.summary()
#mox.load_weights('C:\\Users\\Emmanuel\\Downloads\\PAVIT_weights.h5')
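

# In[ ]:


# A minimal compile-and-fit sketch (an assumption, not part of the original
# notebook): the random tensors below stand in for a real image dataset, and
# the 15-way one-hot labels match model()'s default `output=15` softmax head.
mox.compile(optimizer=Adam(learning_rate=1e-4),
            loss='categorical_crossentropy',
            metrics=['accuracy'])
x_train = tf.random.uniform((8, 224, 224, 3))  # placeholder images
y_train = tf.one_hot(tf.random.uniform((8,), maxval=15, dtype=tf.int32), 15)
mox.fit(x_train, y_train, epochs=1, batch_size=4)

# The backbone without the classification head, e.g. for feature extraction.
backbone = model.remove_head()
print(backbone.output_shape)  # (None, 49, 126)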