awsaf49 committed
Commit 8778bc5 · 1 parent: 094461a

new models

gcvit/__init__.py CHANGED
@@ -1,2 +1,2 @@
-from .models import GCViT, GCViTTiny, GCViTSmall, GCViTBase
+from .models import GCViT, GCViTXXTiny, GCViTXTiny, GCViTTiny, GCViTSmall, GCViTBase, GCViTLarge
 from .version import __version__
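
With this change the new variants are exported from the package root. A minimal import sketch (assuming the package is installed as gcvit; only names added by this commit are used, and the printed class name is illustrative):

# Hypothetical quick check of the new top-level exports added in this commit
from gcvit import GCViTXXTiny, GCViTXTiny, GCViTLarge

model = GCViTXXTiny(pretrain=False)  # build the smallest new variant without downloading weights
print(type(model).__name__)          # every factory returns a GCViT instance
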
gcvit/models/__init__.py CHANGED
@@ -1 +1 @@
-from .gcvit import GCViT, GCViTTiny, GCViTSmall, GCViTBase
+from .gcvit import GCViT, GCViTXXTiny, GCViTXTiny, GCViTTiny, GCViTSmall, GCViTBase, GCViTLarge
gcvit/models/gcvit.py CHANGED
@@ -1,16 +1,29 @@
 import numpy as np
 import tensorflow as tf
 
-from ..layers import PatchEmbed, GCViTLayer, Identity
-
+from ..layers import Stem, GCViTLevel, Identity
 
+
 BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
-TAG = 'v1.0.4'
+TAG = 'v1.1.1'
 NAME2CONFIG = {
+    'gcvit_xxtiny': {'window_size': (7, 7, 14, 7),
+                     'dim': 64,
+                     'depths': (2, 2, 6, 2),
+                     'num_heads': (2, 4, 8, 16),
+                     'mlp_ratio': 3.,
+                     'path_drop': 0.2},
+    'gcvit_xtiny': {'window_size': (7, 7, 14, 7),
+                    'dim': 64,
+                    'depths': (3, 4, 6, 5),
+                    'num_heads': (2, 4, 8, 16),
+                    'mlp_ratio': 3.,
+                    'path_drop': 0.2},
     'gcvit_tiny': {'window_size': (7, 7, 14, 7),
                    'dim': 64,
                    'depths': (3, 4, 19, 5),
-                   'num_heads': (2, 4, 8, 16),
+                   'num_heads': (2, 4, 8, 16),
+                   'mlp_ratio': 3.,
                    'path_drop': 0.2,},
     'gcvit_small': {'window_size': (7, 7, 14, 7),
                     'dim': 96,
@@ -26,6 +39,13 @@ NAME2CONFIG = {
                    'mlp_ratio': 2.,
                    'path_drop': 0.5,
                    'layer_scale': 1e-5,},
+    'gcvit_large': {'window_size': (7, 7, 14, 7),
+                    'dim':192,
+                    'depths': (3, 4, 19, 5),
+                    'num_heads': (6, 12, 24, 48),
+                    'mlp_ratio': 2.,
+                    'path_drop': 0.5,
+                    'layer_scale': 1e-5,},
 }
 
 @tf.keras.utils.register_keras_serializable(package='gcvit')
@@ -50,14 +70,14 @@ class GCViT(tf.keras.Model):
         self.num_classes = num_classes
         self.head_act = head_act
 
-        self.patch_embed = PatchEmbed(dim=dim, name='patch_embed')
+        self.patch_embed = Stem(dim=dim, name='patch_embed')
         self.pos_drop = tf.keras.layers.Dropout(drop_rate, name='pos_drop')
         path_drops = np.linspace(0., path_drop, sum(depths))
         keep_dims = [(False, False, False),(False, False),(True,),(True,),]
         self.levels = []
         for i in range(len(depths)):
             path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
-            level = GCViTLayer(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
+            level = GCViTLevel(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
                                downsample=(i < len(depths) - 1), mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query,
                                name=f'levels/{i}')
@@ -71,16 +91,14 @@
             self.pool = Identity(name='pool')
         else:
             raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
-        self.head = [tf.keras.layers.Dense(num_classes, name='head/fc'),
-                     tf.keras.layers.Activation(head_act, name='head/act')]
+        self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act)
 
-    def reset_classifier(self, num_classes, head_act, global_pool=None):
+    def reset_classifier(self, num_classes, head_act, global_pool=None, in_channels=3):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
-        self.head[0] = tf.keras.layers.Dense(num_classes, name='head/fc') if num_classes else Identity(name='head/fc')
-        self.head[1] = tf.keras.layers.Activation(head_act, name='head/act') if head_act else Identity(name='head/act')
-        super().build((1, 224, 224, 3))
+        self.head = tf.keras.layers.Dense(num_classes, name='head', activation=head_act) if num_classes else Identity(name='head')
+        super().build((1, 224, 224, in_channels))  # for head we only need info from the input channel
 
     def forward_features(self, inputs):
         x = self.patch_embed(inputs)
@@ -96,8 +114,7 @@
         if self.global_pool in ['avg', 'max']:
             x = self.pool(x)
         if not pre_logits:
-            for layer in self.head:
-                x = layer(x)
+            x = self.head(x)
         return x
 
     def call(self, inputs, **kwargs):
@@ -110,35 +127,71 @@
         x = tf.keras.Input(shape=input_shape)
         return tf.keras.Model(inputs=[x], outputs=self.call(x), name=self.name)
 
+    def summary(self, input_shape=(224, 224, 3)):
+        return self.build_graph(input_shape).summary()
+
 # load standard models
-def GCViTTiny(pretrain=False, **kwargs):
+def GCViTXXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
+    name = 'gcvit_xxtiny'
+    config = NAME2CONFIG[name]
+    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
+    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
+    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
+    if pretrain:
+        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
+        model.load_weights(ckpt_path)
+    return model
+
+def GCViTXTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
+    name = 'gcvit_xtiny'
+    config = NAME2CONFIG[name]
+    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
+    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
+    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
+    if pretrain:
+        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
+        model.load_weights(ckpt_path)
+    return model
+
+def GCViTTiny(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
     name = 'gcvit_tiny'
     config = NAME2CONFIG[name]
     ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
-    model = GCViT(name=name, **config, **kwargs)
-    model(tf.random.uniform(shape=(1, 224, 224, 3)))
+    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
+    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model
 
-def GCViTSmall(pretrain=False, **kwargs):
+def GCViTSmall(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
     name = 'gcvit_small'
     config = NAME2CONFIG[name]
     ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
-    model = GCViT(name=name, **config, **kwargs)
-    model(tf.random.uniform(shape=(1, 224, 224, 3)))
+    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
+    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
    if pretrain:
        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
        model.load_weights(ckpt_path)
    return model
 
-def GCViTBase(pretrain=False, **kwargs):
+def GCViTBase(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
     name = 'gcvit_base'
     config = NAME2CONFIG[name]
     ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
-    model = GCViT(name=name, **config, **kwargs)
-    model(tf.random.uniform(shape=(1, 224, 224, 3)))
+    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
+    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
+    if pretrain:
+        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
+        model.load_weights(ckpt_path)
+    return model
+
+def GCViTLarge(input_shape=(224, 224, 3), pretrain=False, resize_query=False, **kwargs):
+    name = 'gcvit_large'
+    config = NAME2CONFIG[name]
+    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
+    model = GCViT(name=name, resize_query=resize_query, **config, **kwargs)
+    model(tf.random.uniform(shape=input_shape)[tf.newaxis,])
     if pretrain:
         ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
         model.load_weights(ckpt_path)
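
All of the new factory functions follow the same pattern: look up the config in NAME2CONFIG, build a GCViT, trace it once on a random input of input_shape, and optionally fetch the v1.1.1 release weights. A usage sketch based only on the signatures in this diff (the 10-class softmax head is an illustrative value, not something this commit ships):

import tensorflow as tf
from gcvit import GCViTXXTiny

# Build the smallest variant; pretrain=True downloads the release weights
# from BASE_URL/TAG, so it needs network access.
model = GCViTXXTiny(input_shape=(224, 224, 3), pretrain=True)

# New helper added in this commit: prints the layer graph for a given input shape.
model.summary(input_shape=(224, 224, 3))

# Swap the classification head before fine-tuning; reset_classifier now rebuilds
# the single Dense head and re-builds the model for the given input channels.
model.reset_classifier(num_classes=10, head_act='softmax')
preds = model(tf.random.uniform((1, 224, 224, 3)), training=False)  # shape (1, 10)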