tensorgirl committed on
Commit
93c26b6
·
1 Parent(s): 30c4649

Update augvit_model.py

Browse files
Files changed (1) hide show
  1. augvit_model.py +31 -4
augvit_model.py CHANGED
@@ -98,6 +98,33 @@ class Transformer(Layer):
98
  x = attn(x, training=training) + x + aug_attn(x, training=training)
99
  x = mlp(x, training=training) + x + augs(x, training=training)
100
  return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  class AUGViT(Model):
103
  def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,name='augvit',
@@ -117,11 +144,11 @@ class AUGViT(Model):
117
  self.patch_den= nn.Dense(units=dim,name='patchden')
118
 
119
 
120
-
121
  self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
122
  self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
123
- self.pos_embedding = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.06)(
124
- shape=(1, num_patches + 1, dim)),name='pos_emb',trainable=True)
125
  self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=dropout,name='trans')
126
 
127
  self.pool = pool
@@ -142,7 +169,7 @@ class AUGViT(Model):
142
  )
143
  x = tf.concat([cls_tokens, x], axis=1)
144
  # print(x.shape,cls_tokens.shape )
145
- x= x+self.pos_embedding
146
 
147
  # print(x.shape,pos.shape,self.pos_embedding.shape)
148
  x = self.dropout(x, training=training)
 
98
  x = attn(x, training=training) + x + aug_attn(x, training=training)
99
  x = mlp(x, training=training) + x + augs(x, training=training)
100
  return x
101
+
102
+ @tf.keras.utils.register_keras_serializable()
103
+ class AddPositionEmbs(tf.keras.layers.Layer):
104
+
105
+ def build(self, input_shape):
106
+ assert (
107
+ len(input_shape) == 3
108
+ ), f"Number of dimensions should be 3, got {len(input_shape)}"
109
+ self.pe = tf.Variable(
110
+ name="pos_embedding",
111
+ initial_value=tf.random_normal_initializer(stddev=0.06)(
112
+ shape=(1, input_shape[1], input_shape[2])
113
+ ),
114
+ dtype="float32",
115
+ trainable=True,
116
+ )
117
+
118
+ def call(self, inputs):
119
+ return inputs + tf.cast(self.pe, dtype=inputs.dtype)
120
+
121
+ def get_config(self):
122
+ config = super().get_config()
123
+ return config
124
+
125
+ @classmethod
126
+ def from_config(cls, config):
127
+ return cls(**config)
128
 
129
  class AUGViT(Model):
130
  def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,name='augvit',
 
144
  self.patch_den= nn.Dense(units=dim,name='patchden')
145
 
146
 
147
+ self.pos_embedding = AddPositionEmbs(name="Transformer/posembed_input")
148
  self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
149
  self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
150
+ # self.pos_embedding = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.06)(
151
+ # shape=(1, num_patches + 1, dim)),name='pos_emb',trainable=True)
152
  self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=dropout,name='trans')
153
 
154
  self.pool = pool
 
169
  )
170
  x = tf.concat([cls_tokens, x], axis=1)
171
  # print(x.shape,cls_tokens.shape )
172
+ x= self.pos_embedding(x)
173
 
174
  # print(x.shape,pos.shape,self.pos_embedding.shape)
175
  x = self.dropout(x, training=training)