new models
- gcvit/__init__.py +1 -1
- gcvit/models/__init__.py +1 -1
- gcvit/models/gcvit.py +76 -23
gcvit/__init__.py
CHANGED
@@ -1,2 +1,2 @@
-from .models import GCViT, GCViTTiny, GCViTSmall, GCViTBase
+from .models import GCViT, GCViTXXTiny, GCViTXTiny, GCViTTiny, GCViTSmall, GCViTBase, GCViTLarge
 from .version import __version__
gcvit/models/__init__.py
CHANGED
@@ -1 +1 @@
-from .gcvit import GCViT, GCViTTiny, GCViTSmall, GCViTBase
+from .gcvit import GCViT, GCViTXXTiny, GCViTXTiny, GCViTTiny, GCViTSmall, GCViTBase, GCViTLarge
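With both __init__.py files updated, the new variants are importable from the package root as well as from the models submodule. A minimal sketch of the resulting import surface (assuming the package is installed under its usual name, gcvit):

import gcvit
from gcvit import GCViTXXTiny, GCViTLarge    # new variants exposed at the package root
from gcvit.models import GCViTXTiny          # equivalently via the models submodule
print(gcvit.__version__)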
gcvit/models/gcvit.py
CHANGED
@@ -1,16 +1,29 @@
 import numpy as np
 import tensorflow as tf
 
-from ..layers import
-
+from ..layers import Stem, GCViTLevel, Identity
 
+
 BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
-TAG = 'v1.
+TAG = 'v1.1.1'
 NAME2CONFIG = {
+    'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
+                     'dim': 64,
+                     'depths': (2, 2, 6, 2),
+                     'num_heads': (2, 4, 8, 16),
+                     'mlp_ratio': 3.,
+                     'path_drop': 0.2},
+    'gcvit_xtiny': {'window_size': (7, 7, 14, 7),
+                    'dim': 64,
+                    'depths': (3, 4, 6, 5),
+                    'num_heads': (2, 4, 8, 16),
+                    'mlp_ratio': 3.,
+                    'path_drop': 0.2},
     'gcvit_tiny': {'window_size': (7, 7, 14, 7),
                    'dim': 64,
                    'depths': (3, 4, 19, 5),
-                   'num_heads': (2, 4, 8, 16),
+                   'num_heads': (2, 4, 8, 16),
+                   'mlp_ratio': 3.,
                    'path_drop': 0.2,},
     'gcvit_small': {'window_size': (7, 7, 14, 7),
                     'dim': 96,
@@ -26,6 +39,13 @@ NAME2CONFIG = {
                     'mlp_ratio': 2.,
                     'path_drop': 0.5,
                     'layer_scale': 1e-5,},
+    'gcvit_large': {'window_size': (7, 7, 14, 7),
+                    'dim': 192,
+                    'depths': (3, 4, 19, 5),
+                    'num_heads': (6, 12, 24, 48),
+                    'mlp_ratio': 2.,
+                    'path_drop': 0.5,
+                    'layer_scale': 1e-5,},
 }
 
 @tf.keras.utils.register_keras_serializable(package='gcvit')
@@ -50,14 +70,14 @@ class GCViT(tf.keras.Model):
         self.num_classes = num_classes
         self.head_act = head_act
 
-        self.patch_embed =
+        self.patch_embed = Stem(dim=dim, name='patch_embed')
         self.pos_drop = tf.keras.layers.Dropout(drop_rate, name='pos_drop')
         path_drops = np.linspace(0., path_drop, sum(depths))
         keep_dims = [(False, False, False),(False, False),(True,),(True,),]
         self.levels = []
         for i in range(len(depths)):
             path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
-            level =
+            level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
                                downsample=(i < len(depths) - 1), mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query,
                                name=f'levels/{i}')
@@ -71,16 +91,14 @@
             self.pool = Identity(name='pool')
         else:
             raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
-        self.head =
-                     tf.keras.layers.Activation(head_act, name='head/act')]
+        self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
 
-    def reset_classifier(self, num_classes, head_act, global_pool=None):
+    def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
-        self.head
-
-        super().build((1, 224, 224, 3))
+        self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
+        super().build((1, 224, 224, in_channels))  # for head we only need info from the input channel
 
     def forward_features(self, inputs):
         x = self.patch_embed(inputs)
@@ -96,8 +114,7 @@
         if self.global_pool in ['avg', 'max']:
             x = self.pool(x)
         if not pre_logits:
-
-                x = layer(x)
+            x = self.head(x)
         return x
 
     def call(self, inputs, **kwargs):
@@ -110,35 +127,71 @@
         x = tf.keras.Input(shape=input_shape)
         return tf.keras.Model(inputs=[x], outputs=self.call(x), name=self.name)
 
+    def summary(self, input_shape=(224, 224, 3)):
+        return self.build_graph(input_shape).summary()
+
 # load standard models
-def
+def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
+    name = 'gcvit_xxtiny'
+    config = NAME2CONFIG[name]
+    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
+    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
+    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
+    if pretrain:
+        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
+        model.load_weights(ckpt_path)
+    return model
+
+def GCViTXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
+    name = 'gcvit_xtiny'
+    config = NAME2CONFIG[name]
+    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
+    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
+    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
+    if pretrain:
+        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
+        model.load_weights(ckpt_path)
+    return model
+
+def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
     name = 'gcvit_tiny'
     config = NAME2CONFIG[name]
     ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
-    model = GCViT(name=name, **config,
-    model(tf.random.uniform(shape=
+    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
+    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
     if pretrain:
         ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
         model.load_weights(ckpt_path)
     return model
 
-def GCViTSmall(pretrain=False, **kwargs):
+def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
     name = 'gcvit_small'
     config = NAME2CONFIG[name]
     ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
-    model = GCViT(name=name, **config, **kwargs)
-    model(tf.random.uniform(shape=
+    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
+    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
     if pretrain:
         ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
         model.load_weights(ckpt_path)
     return model
 
-def GCViTBase(pretrain=False, **kwargs):
+def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
     name = 'gcvit_base'
     config = NAME2CONFIG[name]
     ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
-    model = GCViT(name=name, **config, **kwargs)
-    model(tf.random.uniform(shape=
+    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
+    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
+    if pretrain:
+        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
+        model.load_weights(ckpt_path)
+    return model
+
+def GCViTLarge(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
+    name = 'gcvit_large'
+    config = NAME2CONFIG[name]
+    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
+    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
+    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
     if pretrain:
         ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
         model.load_weights(ckpt_path)
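Taken together, each new factory function follows the same pattern as the existing ones: look up the variant's settings in NAME2CONFIG, build a GCViT, trace it once on a dummy tensor of input_shape, and optionally fetch the matching _weights.h5 asset from the v1.1.1 release tag. A minimal usage sketch under those assumptions (pretrain=True only works if the corresponding weight file has actually been published for that release):

import tensorflow as tf
from gcvit import GCViTXXTiny

# build the smallest new variant without downloading weights
model = GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False)
model.summary(input_shape=(224, 224, 3))   # new helper: builds the functional graph and prints it

# forward pass on a dummy batch; output shape is (1, num_classes)
dummy = tf.random.uniform((1, 224, 224, 3))
logits = model(dummy)

# swap the classification head, e.g. for a 10-class task; in_channels is the new argument
model.reset_classifier(num_classes=10, head_act='softmax', in_channels=3)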