tensorgirl
committed on
Commit
·
93c26b6
1
Parent(s):
30c4649
Update augvit_model.py
Browse files- augvit_model.py +31 -4
augvit_model.py
CHANGED
@@ -98,6 +98,33 @@ class Transformer(Layer):
|
|
98 |
x = attn(x, training=training) + x + aug_attn(x, training=training)
|
99 |
x = mlp(x, training=training) + x + augs(x, training=training)
|
100 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
class AUGViT(Model):
|
103 |
def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,name='augvit',
|
@@ -117,11 +144,11 @@ class AUGViT(Model):
|
|
117 |
self.patch_den= nn.Dense(units=dim,name='patchden')
|
118 |
|
119 |
|
120 |
-
|
121 |
self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
|
122 |
self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
|
123 |
-
self.pos_embedding = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.06)(
|
124 |
-
|
125 |
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=dropout,name='trans')
|
126 |
|
127 |
self.pool = pool
|
@@ -142,7 +169,7 @@ class AUGViT(Model):
|
|
142 |
)
|
143 |
x = tf.concat([cls_tokens, x], axis=1)
|
144 |
# print(x.shape,cls_tokens.shape )
|
145 |
-
x=
|
146 |
|
147 |
# print(x.shape,pos.shape,self.pos_embedding.shape)
|
148 |
x = self.dropout(x, training=training)
|
|
|
98 |
x = attn(x, training=training) + x + aug_attn(x, training=training)
|
99 |
x = mlp(x, training=training) + x + augs(x, training=training)
|
100 |
return x
|
101 |
+
|
102 |
+
@tf.keras.utils.register_keras_serializable()
class AddPositionEmbs(tf.keras.layers.Layer):
    """Add a trainable positional embedding to a rank-3 input.

    The embedding weight is created in ``build`` with shape
    ``(1, input_shape[1], input_shape[2])`` and is added to the inputs in
    ``call``, broadcasting over the leading axis.
    """

    def build(self, input_shape):
        """Create the positional-embedding weight for ``input_shape``.

        Args:
            input_shape: Shape of the layer input; must have exactly three
                dimensions.

        Raises:
            ValueError: If ``input_shape`` does not have three dimensions.
        """
        # Raise explicitly rather than ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        if len(input_shape) != 3:
            raise ValueError(
                f"Number of dimensions should be 3, got {len(input_shape)}"
            )
        # Create the weight through the Keras ``add_weight`` API (the
        # documented way to create layer weights in ``build``) instead of a
        # raw ``tf.Variable``; name, shape, initializer, dtype and
        # trainability match the previous definition.
        self.pe = self.add_weight(
            name="pos_embedding",
            shape=(1, input_shape[1], input_shape[2]),
            initializer=tf.random_normal_initializer(stddev=0.06),
            dtype="float32",
            trainable=True,
        )

    def call(self, inputs):
        """Return ``inputs`` plus the positional embedding.

        The embedding is cast to ``inputs.dtype`` before the addition so the
        result keeps the dtype of the inputs.
        """
        return inputs + tf.cast(self.pe, dtype=inputs.dtype)

    def get_config(self):
        """Return the base layer config; this layer adds no options."""
        config = super().get_config()
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer from a config produced by ``get_config``."""
        return cls(**config)
|
128 |
|
129 |
class AUGViT(Model):
|
130 |
def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,name='augvit',
|
|
|
144 |
self.patch_den= nn.Dense(units=dim,name='patchden')
|
145 |
|
146 |
|
147 |
+
self.pos_embedding = AddPositionEmbs(name="Transformer/posembed_input")
|
148 |
self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
|
149 |
self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
|
150 |
+
# self.pos_embedding = tf.Variable(initial_value=tf.random_normal_initializer(stddev=0.06)(
|
151 |
+
# shape=(1, num_patches + 1, dim)),name='pos_emb',trainable=True)
|
152 |
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=dropout,name='trans')
|
153 |
|
154 |
self.pool = pool
|
|
|
169 |
)
|
170 |
x = tf.concat([cls_tokens, x], axis=1)
|
171 |
# print(x.shape,cls_tokens.shape )
|
172 |
+
x= self.pos_embedding(x)
|
173 |
|
174 |
# print(x.shape,pos.shape,self.pos_embedding.shape)
|
175 |
x = self.dropout(x, training=training)
|