file added
- app.py +30 -4
- gcvit/__init__.py +2 -0
- gcvit/__pycache__/__init__.cpython-38.pyc +0 -0
- gcvit/layers/__init__.py +7 -0
- gcvit/layers/__pycache__/__init__.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/attention.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/block.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/drop.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/embedding.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/feature.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/level.cpython-38.pyc +0 -0
- gcvit/layers/__pycache__/window.cpython-38.pyc +0 -0
- gcvit/layers/attention.py +96 -0
- gcvit/layers/block.py +99 -0
- gcvit/layers/drop.py +40 -0
- gcvit/layers/embedding.py +27 -0
- gcvit/layers/feature.py +202 -0
- gcvit/layers/level.py +93 -0
- gcvit/layers/window.py +15 -0
- gcvit/models/__init__.py +1 -0
- gcvit/models/__pycache__/__init__.cpython-38.pyc +0 -0
- gcvit/models/__pycache__/gcvit.cpython-38.pyc +0 -0
- gcvit/models/gcvit.py +145 -0
- gcvit/utils/__init__.py +1 -0
- gcvit/utils/gradcam.py +69 -0
- gcvit/version.py +1 -0
- requirements.txt +5 -0
- setup.py +50 -0
app.py
CHANGED
@@ -1,7 +1,33 @@
+import tensorflow as tf
 import gradio as gr
+import gcvit
+from gcvit.utils import get_gradcam_model, get_gradcam_prediction
 
-def
-
+def predict_fn(image, model_name):
+    """A predict function that will be invoked by gradio."""
+    model = getattr(gcvit, model_name)(pretrain=True)
+    gradcam_model = get_gradcam_model(model)
+    preds, overlay = get_gradcam_prediction(image, gradcam_model, cmap='jet', alpha=0.4, pred_index=None)
+    preds = {x[1]: x[2] for x in preds}
+    return [preds, overlay]
 
-
-
+demo = gr.Interface(
+    fn=predict_fn,
+    inputs=[
+        gr.inputs.Image(label="Input Image"),
+        gr.Radio(['GCViTTiny', 'GCViTSmall', 'GCViTBase'], value='GCViTTiny', label='Model Size')
+    ],
+    outputs=[
+        gr.outputs.Label(label="Prediction"),
+        gr.outputs.Image(label="GradCAM"),
+    ],
+    title="Global Context Vision Transformer (GCViT) Demo",
+    description="ImageNet Pretrain.",
+    examples=[
+        ["example/african_elephant.png"],
+        ["example/chelsea.png"],
+        ["example/german_shepherd.jpg"],
+        ["example/panda.jpg"],
+    ],
+)
+demo.launch()
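
Not part of the commit: a quick sanity check of the handler outside the Gradio UI, as a minimal sketch that assumes one of the Space's bundled example images is on disk:

    import numpy as np
    from PIL import Image

    # 'example/panda.jpg' is one of the demo's bundled examples
    image = np.array(Image.open('example/panda.jpg'))
    preds, overlay = predict_fn(image, 'GCViTTiny')  # {class_name: prob} dict and a PIL overlay
    print(sorted(preds.items(), key=lambda kv: kv[1], reverse=True)[:3])

Note that predict_fn rebuilds the model on every call, which is fine for a demo but slow for batch use.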
gcvit/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .models import GCViT, GCViTTiny, GCViTSmall, GCViTBase
+from .version import __version__
gcvit/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (228 Bytes).
gcvit/layers/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from .window import window_partition, window_reverse
+from .attention import WindowAttention
+from .drop import DropPath, Identity
+from .embedding import PatchEmbed
+from .feature import Mlp, FeatExtract, ReduceSize, SE, Resizing
+from .block import GCViTBlock
+from .level import GCViTLayer
gcvit/layers/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (530 Bytes).
gcvit/layers/__pycache__/attention.cpython-38.pyc
ADDED
Binary file (3.58 kB).
gcvit/layers/__pycache__/block.cpython-38.pyc
ADDED
Binary file (3 kB).
gcvit/layers/__pycache__/drop.cpython-38.pyc
ADDED
Binary file (1.8 kB).
gcvit/layers/__pycache__/embedding.cpython-38.pyc
ADDED
Binary file (1.39 kB).
gcvit/layers/__pycache__/feature.cpython-38.pyc
ADDED
Binary file (5.5 kB).
gcvit/layers/__pycache__/level.cpython-38.pyc
ADDED
Binary file (3 kB).
gcvit/layers/__pycache__/window.cpython-38.pyc
ADDED
Binary file (801 Bytes).
gcvit/layers/attention.py
ADDED
@@ -0,0 +1,96 @@
+import tensorflow as tf
+import tensorflow_addons as tfa
+
+
+
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class WindowAttention(tf.keras.layers.Layer):
+    def __init__(self, window_size, num_heads, global_query, qkv_bias=True, qk_scale=None, attn_dropout=0., proj_dropout=0.,
+                 **kwargs):
+        super().__init__(**kwargs)
+        window_size = (window_size, window_size)
+        self.window_size = window_size
+        self.num_heads = num_heads
+        self.global_query = global_query
+        self.qkv_bias = qkv_bias
+        self.qk_scale = qk_scale
+        self.attn_dropout = attn_dropout
+        self.proj_dropout = proj_dropout
+
+    def build(self, input_shape):
+        dim = input_shape[0][-1]
+        head_dim = dim // self.num_heads
+        self.scale = self.qk_scale or head_dim ** -0.5
+        self.qkv_size = 3 - int(self.global_query)
+        self.qkv = tf.keras.layers.Dense(dim * self.qkv_size, use_bias=self.qkv_bias, name='qkv')
+        self.relative_position_bias_table = self.add_weight(
+            'relative_position_bias_table',
+            shape=[(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads],
+            initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
+            trainable=True,
+            dtype=self.dtype)
+        self.attn_drop = tf.keras.layers.Dropout(self.attn_dropout, name='attn_drop')
+        self.proj = tf.keras.layers.Dense(dim, name='proj')
+        self.proj_drop = tf.keras.layers.Dropout(self.proj_dropout, name='proj_drop')
+        self.softmax = tf.keras.layers.Activation('softmax', name='softmax')
+        self.relative_position_index = self.get_relative_position_index()
+        super().build(input_shape)
+
+    def get_relative_position_index(self):
+        coords_h = tf.range(self.window_size[0])
+        coords_w = tf.range(self.window_size[1])
+        coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing='ij'), axis=0)
+        coords_flatten = tf.reshape(coords, [2, -1])
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0])
+        relative_coords_xx = (relative_coords[:, :, 0] + self.window_size[0] - 1)
+        relative_coords_yy = (relative_coords[:, :, 1] + self.window_size[1] - 1)
+        relative_coords_xx = relative_coords_xx * (2 * self.window_size[1] - 1)
+        relative_position_index = (relative_coords_xx + relative_coords_yy)
+        return relative_position_index
+
+    def call(self, inputs, **kwargs):
+        if self.global_query:
+            inputs, q_global = inputs
+            B = tf.shape(q_global)[0]  # B, N, C
+        else:
+            inputs = inputs[0]
+        B_, N, C = tf.unstack(tf.shape(inputs), num=3)  # B*num_window, num_tokens, channels
+        qkv = self.qkv(inputs)
+        qkv = tf.reshape(qkv, [B_, N, self.qkv_size, self.num_heads, C // self.num_heads])
+        qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])
+        if self.global_query:
+            k, v = tf.unstack(qkv, num=2, axis=0)  # for an unknown shape, num=None would throw an error
+            q_global = tf.repeat(q_global, repeats=B_ // B, axis=0)  # num_windows = B_//B => q_global is shared by all windows of an image
+            q = tf.reshape(q_global, shape=[B_, N, self.num_heads, C // self.num_heads])
+            q = tf.transpose(q, perm=[0, 2, 1, 3])
+        else:
+            q, k, v = tf.unstack(qkv, num=3, axis=0)
+        q = q * self.scale
+        attn = (q @ tf.transpose(k, perm=[0, 1, 3, 2]))
+        relative_position_bias = tf.gather(self.relative_position_bias_table, tf.reshape(self.relative_position_index, shape=[-1]))
+        relative_position_bias = tf.reshape(relative_position_bias,
+                                            shape=[self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1])
+        relative_position_bias = tf.transpose(relative_position_bias, perm=[2, 0, 1])
+        attn = attn + relative_position_bias[tf.newaxis,]
+        attn = self.softmax(attn)
+        attn = self.attn_drop(attn)
+
+        x = tf.transpose((attn @ v), perm=[0, 2, 1, 3])  # B_, num_tokens, num_heads, channels_per_head
+        x = tf.reshape(x, shape=[B_, N, C])
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            'window_size': self.window_size,
+            'num_heads': self.num_heads,
+            'global_query': self.global_query,
+            'qkv_bias': self.qkv_bias,
+            'qk_scale': self.qk_scale,
+            'attn_dropout': self.attn_dropout,
+            'proj_dropout': self.proj_dropout
+        })
+        return config
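
A rough shape check for the layer above, with illustrative values only; note that call expects its inputs wrapped in a list, since it unpacks inputs[0]:

    import tensorflow as tf
    from gcvit.layers import WindowAttention

    attn = WindowAttention(window_size=7, num_heads=4, global_query=False)
    x = tf.random.normal((8, 49, 96))  # (batch*num_windows, tokens_per_window, channels)
    y = attn([x])                      # local self-attention with relative position bias
    print(y.shape)                     # (8, 49, 96)

With global_query=True the layer only projects k and v from its input (qkv_size becomes 2) and takes q from the stage-level global query instead.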
gcvit/layers/block.py
ADDED
@@ -0,0 +1,99 @@
+import tensorflow as tf
+
+from .attention import WindowAttention
+from .drop import DropPath
+from .window import window_partition, window_reverse
+from .feature import Mlp, FeatExtract
+
+
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class GCViTBlock(tf.keras.layers.Layer):
+    def __init__(self, window_size, num_heads, global_query, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0.,
+                 attn_drop=0., path_drop=0., act_layer='gelu', layer_scale=None, **kwargs):
+        super().__init__(**kwargs)
+        self.window_size = window_size
+        self.num_heads = num_heads
+        self.global_query = global_query
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.qk_scale = qk_scale
+        self.drop = drop
+        self.attn_drop = attn_drop
+        self.path_drop = path_drop
+        self.act_layer = act_layer
+        self.layer_scale = layer_scale
+
+    def build(self, input_shape):
+        B, H, W, C = input_shape[0]
+        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm1')
+        self.attn = WindowAttention(window_size=self.window_size,
+                                    num_heads=self.num_heads,
+                                    global_query=self.global_query,
+                                    qkv_bias=self.qkv_bias,
+                                    qk_scale=self.qk_scale,
+                                    attn_dropout=self.attn_drop,
+                                    proj_dropout=self.drop,
+                                    name='attn')
+        self.drop_path1 = DropPath(self.path_drop)
+        self.drop_path2 = DropPath(self.path_drop)
+        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm2')
+        self.mlp = Mlp(hidden_features=int(C * self.mlp_ratio), dropout=self.drop, act_layer=self.act_layer, name='mlp')
+        if self.layer_scale is not None:
+            self.gamma1 = self.add_weight(
+                'gamma1',
+                shape=[C],
+                initializer=tf.keras.initializers.Constant(self.layer_scale),
+                trainable=True,
+                dtype=self.dtype)
+            self.gamma2 = self.add_weight(
+                'gamma2',
+                shape=[C],
+                initializer=tf.keras.initializers.Constant(self.layer_scale),
+                trainable=True,
+                dtype=self.dtype)
+        else:
+            self.gamma1 = 1.0
+            self.gamma2 = 1.0
+        self.num_windows = int(H // self.window_size) * int(W // self.window_size)
+        super().build(input_shape)
+
+    def call(self, inputs, **kwargs):
+        if self.global_query:
+            inputs, q_global = inputs
+        else:
+            inputs = inputs[0]
+        B, H, W, C = tf.unstack(tf.shape(inputs), num=4)
+        x = self.norm1(inputs)
+        # create windows and concatenate them along the batch axis
+        x = window_partition(x, self.window_size)  # (B_, win_h, win_w, C)
+        # flatten patch
+        x = tf.reshape(x, shape=[-1, self.window_size * self.window_size, C])  # (B_, N, C) => (batch*num_win, num_token, feature)
+        # attention
+        if self.global_query:
+            x = self.attn([x, q_global])
+        else:
+            x = self.attn([x])
+        # reverse window partition
+        x = window_reverse(x, self.window_size, H, W, C)
+        # FFN
+        x = inputs + self.drop_path1(x * self.gamma1)
+        x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x)))
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            'window_size': self.window_size,
+            'num_heads': self.num_heads,
+            'global_query': self.global_query,
+            'mlp_ratio': self.mlp_ratio,
+            'qkv_bias': self.qkv_bias,
+            'qk_scale': self.qk_scale,
+            'drop': self.drop,
+            'attn_drop': self.attn_drop,
+            'path_drop': self.path_drop,
+            'act_layer': self.act_layer,
+            'layer_scale': self.layer_scale,
+            'num_windows': self.num_windows,
+        })
+        return config
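
A minimal local-attention sketch with hypothetical sizes; H and W must be divisible by window_size, and inputs are again passed as a list:

    import tensorflow as tf
    from gcvit.layers import GCViTBlock

    blk = GCViTBlock(window_size=7, num_heads=4, global_query=False)
    x = tf.random.normal((1, 28, 28, 96))
    y = blk([x])    # partition -> window attention -> reverse -> MLP, each with a residual
    print(y.shape)  # (1, 28, 28, 96)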
gcvit/layers/drop.py
ADDED
@@ -0,0 +1,40 @@
+import tensorflow as tf
+
+
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class Identity(tf.keras.layers.Layer):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def call(self, x):
+        return tf.identity(x)
+
+    def get_config(self):
+        config = super().get_config()
+        return config
+
+
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class DropPath(tf.keras.layers.Layer):
+    def __init__(self, drop_prob=0., scale_by_keep=True, **kwargs):
+        super().__init__(**kwargs)
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def call(self, x, training=None):
+        if self.drop_prob == 0. or not training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+        random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+        random_tensor = tf.floor(random_tensor)
+        if keep_prob > 0.0 and self.scale_by_keep:
+            x = (x / keep_prob)
+        return x * random_tensor
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "drop_prob": self.drop_prob,
+            "scale_by_keep": self.scale_by_keep
+        })
+        return config
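
DropPath implements stochastic depth: at inference it is the identity, while in training it zeroes whole samples and rescales the survivors by 1/keep_prob. A small sketch:

    import tensorflow as tf
    from gcvit.layers import DropPath

    dp = DropPath(drop_prob=0.5)
    x = tf.ones((4, 3))
    print(dp(x, training=False))  # unchanged
    print(dp(x, training=True))   # roughly half the rows zeroed, the rest scaled to 2.0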
gcvit/layers/embedding.py
ADDED
@@ -0,0 +1,27 @@
+import tensorflow as tf
+
+from .feature import ReduceSize
+
+
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class PatchEmbed(tf.keras.layers.Layer):
+    def __init__(self, dim, **kwargs):
+        super().__init__(**kwargs)
+        self.dim = dim
+
+    def build(self, input_shape):
+        self.pad = tf.keras.layers.ZeroPadding2D(1, name='pad')
+        self.proj = tf.keras.layers.Conv2D(self.dim, kernel_size=3, strides=2, name='proj')
+        self.conv_down = ReduceSize(keep_dim=True, name='conv_down')
+        super().build(input_shape)
+
+    def call(self, inputs, **kwargs):
+        x = self.pad(inputs)
+        x = self.proj(x)
+        x = self.conv_down(x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({'dim': self.dim})
+        return config
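
PatchEmbed downsamples by 4x overall: the strided 3x3 conv halves the resolution and the ReduceSize(keep_dim=True) stem halves it again without changing the channel count. A sketch with assumed sizes:

    import tensorflow as tf
    from gcvit.layers import PatchEmbed

    pe = PatchEmbed(dim=64)
    x = tf.random.normal((1, 224, 224, 3))
    print(pe(x).shape)  # (1, 56, 56, 64)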
gcvit/layers/feature.py
ADDED
@@ -0,0 +1,202 @@
+import tensorflow as tf
+import tensorflow_addons as tfa
+
+H_AXIS = -3
+W_AXIS = -2
+
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class Mlp(tf.keras.layers.Layer):
+    def __init__(self, hidden_features=None, out_features=None, act_layer='gelu', dropout=0., **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_features = hidden_features
+        self.out_features = out_features
+        self.act_layer = act_layer
+        self.dropout = dropout
+
+    def build(self, input_shape):
+        self.in_features = input_shape[-1]
+        self.hidden_features = self.hidden_features or self.in_features
+        self.out_features = self.out_features or self.in_features
+        self.fc1 = tf.keras.layers.Dense(self.hidden_features, name="fc1")
+        self.act = tf.keras.layers.Activation(self.act_layer, name="act")
+        self.fc2 = tf.keras.layers.Dense(self.out_features, name="fc2")
+        self.drop1 = tf.keras.layers.Dropout(self.dropout, name="drop1")
+        self.drop2 = tf.keras.layers.Dropout(self.dropout, name="drop2")
+        super().build(input_shape)
+
+    def call(self, inputs, **kwargs):
+        x = self.fc1(inputs)
+        x = self.act(x)
+        x = self.drop1(x)
+        x = self.fc2(x)
+        x = self.drop2(x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "hidden_features": self.hidden_features,
+            "out_features": self.out_features,
+            "act_layer": self.act_layer,
+            "dropout": self.dropout
+        })
+        return config
+
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class SE(tf.keras.layers.Layer):
+    def __init__(self, oup=None, expansion=0.25, **kwargs):
+        super().__init__(**kwargs)
+        self.expansion = expansion
+        self.oup = oup
+
+    def build(self, input_shape):
+        inp = input_shape[-1]
+        self.oup = self.oup or inp
+        self.avg_pool = tfa.layers.AdaptiveAveragePooling2D(1, name="avg_pool")
+        self.fc = [
+            tf.keras.layers.Dense(int(inp * self.expansion), use_bias=False, name='fc/0'),
+            tf.keras.layers.Activation('gelu', name='fc/1'),
+            tf.keras.layers.Dense(self.oup, use_bias=False, name='fc/2'),
+            tf.keras.layers.Activation('sigmoid', name='fc/3')
+        ]
+        super().build(input_shape)
+
+    def call(self, inputs, **kwargs):
+        b, _, _, c = tf.unstack(tf.shape(inputs), num=4)
+        x = tf.reshape(self.avg_pool(inputs), (b, c))
+        for layer in self.fc:
+            x = layer(x)
+        x = tf.reshape(x, (b, 1, 1, c))
+        return x * inputs
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            'expansion': self.expansion,
+            'oup': self.oup,
+        })
+        return config
+
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class ReduceSize(tf.keras.layers.Layer):
+    def __init__(self, keep_dim=False, **kwargs):
+        super().__init__(**kwargs)
+        self.keep_dim = keep_dim
+
+    def build(self, input_shape):
+        dim = input_shape[-1]
+        dim_out = dim if self.keep_dim else 2 * dim
+        self.pad1 = tf.keras.layers.ZeroPadding2D(1, name='pad1')
+        self.pad2 = tf.keras.layers.ZeroPadding2D(1, name='pad2')
+        self.conv = [
+            tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv/0'),
+            tf.keras.layers.Activation('gelu', name='conv/1'),
+            SE(name='conv/2'),
+            tf.keras.layers.Conv2D(dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv/3')
+        ]
+        self.reduction = tf.keras.layers.Conv2D(dim_out, kernel_size=3, strides=2, padding='valid', use_bias=False,
+                                                name='reduction')
+        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm1')  # eps like PyTorch
+        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm2')
+        super().build(input_shape)
+
+    def call(self, inputs, **kwargs):
+        x = self.norm1(inputs)
+        xr = self.pad1(x)  # if pad had weights it would've thrown an error with .save_weights()
+        for layer in self.conv:
+            xr = layer(xr)
+        x = x + xr
+        x = self.pad2(x)
+        x = self.reduction(x)
+        x = self.norm2(x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "keep_dim": self.keep_dim,
+        })
+        return config
+
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class FeatExtract(tf.keras.layers.Layer):
+    def __init__(self, keep_dim=False, **kwargs):
+        super().__init__(**kwargs)
+        self.keep_dim = keep_dim
+
+    def build(self, input_shape):
+        dim = input_shape[-1]
+        self.pad1 = tf.keras.layers.ZeroPadding2D(1, name='pad1')
+        self.pad2 = tf.keras.layers.ZeroPadding2D(1, name='pad2')
+        self.conv = [
+            tf.keras.layers.DepthwiseConv2D(kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv/0'),
+            tf.keras.layers.Activation('gelu', name='conv/1'),
+            SE(name='conv/2'),
+            tf.keras.layers.Conv2D(dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name='conv/3')
+        ]
+        if not self.keep_dim:
+            self.pool = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='valid', name='pool')
+        # else:
+        #     self.pool = tf.keras.layers.Activation('linear', name='identity')  # hack for PyTorch nn.Identity layer ;)
+        super().build(input_shape)
+
+    def call(self, inputs, **kwargs):
+        x = inputs
+        xr = self.pad1(x)
+        for layer in self.conv:
+            xr = layer(xr)
+        x = x + xr  # if pad had weights it would've thrown an error with .save_weights()
+        if not self.keep_dim:
+            x = self.pad2(x)
+            x = self.pool(x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "keep_dim": self.keep_dim,
+        })
+        return config
+
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class Resizing(tf.keras.layers.Layer):
+    def __init__(self,
+                 height,
+                 width,
+                 interpolation='bilinear',
+                 **kwargs):
+        self.height = height
+        self.width = width
+        self.interpolation = interpolation
+        super().__init__(**kwargs)
+
+    def call(self, inputs):
+        # tf.image.resize will always output float32 and operate more efficiently on
+        # float32 unless interpolation is nearest, in which case output type matches
+        # input type.
+        if self.interpolation == 'nearest':
+            input_dtype = self.compute_dtype
+        else:
+            input_dtype = tf.float32
+        inputs = tf.cast(inputs, dtype=input_dtype)
+        size = [self.height, self.width]
+        outputs = tf.image.resize(
+            inputs,
+            size=size,
+            method=self.interpolation)
+        return tf.cast(outputs, self.compute_dtype)
+
+    def compute_output_shape(self, input_shape):
+        input_shape = tf.TensorShape(input_shape).as_list()
+        input_shape[H_AXIS] = self.height
+        input_shape[W_AXIS] = self.width
+        return tf.TensorShape(input_shape)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            'height': self.height,
+            'width': self.width,
+            'interpolation': self.interpolation,
+        })
+        return config
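
ReduceSize and FeatExtract are the two convolutional downsamplers: both apply an SE-gated residual conv block, then ReduceSize halves the resolution with a strided conv (doubling channels unless keep_dim=True) while FeatExtract uses max-pooling, or no pooling at all when keep_dim=True. A shape sketch with assumed sizes:

    import tensorflow as tf
    from gcvit.layers import ReduceSize, FeatExtract

    x = tf.random.normal((1, 56, 56, 64))
    print(ReduceSize(keep_dim=False)(x).shape)   # (1, 28, 28, 128)
    print(FeatExtract(keep_dim=False)(x).shape)  # (1, 28, 28, 64)
    print(FeatExtract(keep_dim=True)(x).shape)   # (1, 56, 56, 64)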
gcvit/layers/level.py
ADDED
@@ -0,0 +1,93 @@
+import tensorflow as tf
+
+from .feature import FeatExtract, ReduceSize, Resizing
+from .block import GCViTBlock
+
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class GCViTLayer(tf.keras.layers.Layer):
+    def __init__(self, depth, num_heads, window_size, keep_dims, downsample=True, mlp_ratio=4., qkv_bias=True,
+                 qk_scale=None, drop=0., attn_drop=0., path_drop=0., layer_scale=None, resize_query=False, **kwargs):
+        super().__init__(**kwargs)
+        self.depth = depth
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.keep_dims = keep_dims
+        self.downsample = downsample
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.qk_scale = qk_scale
+        self.drop = drop
+        self.attn_drop = attn_drop
+        self.path_drop = path_drop
+        self.layer_scale = layer_scale
+        self.resize_query = resize_query
+
+    def build(self, input_shape):
+        path_drop = [self.path_drop] * self.depth if not isinstance(self.path_drop, list) else self.path_drop
+        self.blocks = [
+            GCViTBlock(window_size=self.window_size,
+                       num_heads=self.num_heads,
+                       global_query=bool(i % 2),
+                       mlp_ratio=self.mlp_ratio,
+                       qkv_bias=self.qkv_bias,
+                       qk_scale=self.qk_scale,
+                       drop=self.drop,
+                       attn_drop=self.attn_drop,
+                       path_drop=path_drop[i],
+                       layer_scale=self.layer_scale,
+                       name=f'blocks/{i}')
+            for i in range(self.depth)]
+        self.down = ReduceSize(keep_dim=False, name='downsample')
+        self.to_q_global = [
+            FeatExtract(keep_dim, name=f'to_q_global/{i}')
+            for i, keep_dim in enumerate(self.keep_dims)]
+        self.resize = Resizing(self.window_size, self.window_size, interpolation='bicubic')
+        super().build(input_shape)
+
+    def call(self, inputs, **kwargs):
+        height, width = tf.unstack(tf.shape(inputs)[1:3], num=2)
+        # pad to a multiple of window_size
+        h_pad = (self.window_size - height % self.window_size) % self.window_size
+        w_pad = (self.window_size - width % self.window_size) % self.window_size
+        x = tf.pad(inputs, [[0, 0],
+                            [h_pad // 2, (h_pad // 2 + h_pad % 2)],  # padding in both directions, unlike tfgcvit
+                            [w_pad // 2, (w_pad // 2 + w_pad % 2)],
+                            [0, 0]])
+        # generate the global query
+        q_global = x  # (B, H, W, C)
+        for layer in self.to_q_global:
+            q_global = layer(q_global)  # official impl issue: https://github.com/NVlabs/GCVit/issues/13
+        # resize query to fit key-value, but this results in poor scores with the official weights?
+        if self.resize_query:
+            q_global = self.resize(q_global)  # to avoid a mismatch between feat_map and q_global: https://github.com/NVlabs/GCVit/issues/9
+        # feature_map -> windows -> window_attention -> feature_map
+        for i, blk in enumerate(self.blocks):
+            if i % 2:
+                x = blk([x, q_global])
+            else:
+                x = blk([x])
+        x = x[:, :height, :width, :]  # https://github.com/NVlabs/GCVit/issues/9
+        # set shape for [B, ?, ?, C]
+        x.set_shape(inputs.shape)  # `tf.reshape` creates a new tensor with new_shape
+        # downsample
+        if self.downsample:
+            x = self.down(x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            'depth': self.depth,
+            'num_heads': self.num_heads,
+            'window_size': self.window_size,
+            'keep_dims': self.keep_dims,
+            'downsample': self.downsample,
+            'mlp_ratio': self.mlp_ratio,
+            'qkv_bias': self.qkv_bias,
+            'qk_scale': self.qk_scale,
+            'drop': self.drop,
+            'attn_drop': self.attn_drop,
+            'path_drop': self.path_drop,
+            'layer_scale': self.layer_scale
+        })
+        return config
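
A stage-level sketch with assumed sizes: for window_size=7 and a 28x28 input, keep_dims=(False, False) makes the two FeatExtract layers shrink the global-query map 28 -> 14 -> 7 so it matches the 7x7 window tokens, and the final ReduceSize halves the map and doubles the channels:

    import tensorflow as tf
    from gcvit.layers import GCViTLayer

    level = GCViTLayer(depth=2, num_heads=4, window_size=7, keep_dims=(False, False))
    x = tf.random.normal((1, 28, 28, 96))
    print(level(x).shape)  # (1, 14, 14, 192); block 0 is local, block 1 uses the global query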
gcvit/layers/window.py
ADDED
@@ -0,0 +1,15 @@
+import tensorflow as tf
+
+def window_partition(x, window_size):
+    B, H, W, C = tf.unstack(tf.shape(x), num=4)
+    x = tf.reshape(x, shape=[-1, H // window_size, window_size, W // window_size, window_size, C])
+    x = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5])
+    windows = tf.reshape(x, shape=[-1, window_size, window_size, C])
+    return windows
+
+
+def window_reverse(windows, window_size, H, W, C):
+    x = tf.reshape(windows, shape=[-1, H // window_size, W // window_size, window_size, window_size, C])
+    x = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5])
+    x = tf.reshape(x, shape=[-1, H, W, C])
+    return x
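
The two helpers are exact inverses as long as H and W are multiples of window_size; a quick round-trip check with illustrative sizes:

    import tensorflow as tf
    from gcvit.layers import window_partition, window_reverse

    x = tf.random.normal((2, 14, 14, 32))
    w = window_partition(x, 7)            # (8, 7, 7, 32): four 7x7 windows per image
    y = window_reverse(w, 7, 14, 14, 32)  # back to (2, 14, 14, 32)
    print(w.shape, bool(tf.reduce_all(x == y)))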
gcvit/models/__init__.py
ADDED
@@ -0,0 +1 @@
+from .gcvit import GCViT, GCViTTiny, GCViTSmall, GCViTBase
gcvit/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (234 Bytes).
gcvit/models/__pycache__/gcvit.cpython-38.pyc
ADDED
Binary file (4.08 kB).
gcvit/models/gcvit.py
ADDED
@@ -0,0 +1,145 @@
+import numpy as np
+import tensorflow as tf
+
+from ..layers import PatchEmbed, GCViTLayer, Identity
+
+
+BASE_URL = 'https://github.com/awsaf49/gcvit-tf/releases/download'
+TAG = 'v1.0.0'
+NAME2CONFIG = {
+    'gcvit_tiny': {'window_size': (7, 7, 14, 7),
+                   'dim': 64,
+                   'depths': (3, 4, 19, 5),
+                   'num_heads': (2, 4, 8, 16),
+                   'path_drop': 0.2,},
+    'gcvit_small': {'window_size': (7, 7, 14, 7),
+                    'dim': 96,
+                    'depths': (3, 4, 19, 5),
+                    'num_heads': (3, 6, 12, 24),
+                    'mlp_ratio': 2.,
+                    'path_drop': 0.3,
+                    'layer_scale': 1e-5,},
+    'gcvit_base': {'window_size': (7, 7, 14, 7),
+                   'dim': 128,
+                   'depths': (3, 4, 19, 5),
+                   'num_heads': (4, 8, 16, 32),
+                   'mlp_ratio': 2.,
+                   'path_drop': 0.5,
+                   'layer_scale': 1e-5,},
+}
+
+@tf.keras.utils.register_keras_serializable(package='gcvit')
+class GCViT(tf.keras.Model):
+    def __init__(self, window_size, dim, depths, num_heads,
+                 drop_rate=0., mlp_ratio=3., qkv_bias=True, qk_scale=None, attn_drop=0., path_drop=0.1, layer_scale=None, resize_query=False,
+                 global_pool='avg', num_classes=1000, head_act='softmax', **kwargs):
+        super().__init__(**kwargs)
+        self.window_size = window_size
+        self.dim = dim
+        self.depths = depths
+        self.num_heads = num_heads
+        self.drop_rate = drop_rate
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.qk_scale = qk_scale
+        self.attn_drop = attn_drop
+        self.path_drop = path_drop
+        self.layer_scale = layer_scale
+        self.resize_query = resize_query
+        self.global_pool = global_pool
+        self.num_classes = num_classes
+        self.head_act = head_act
+
+        self.patch_embed = PatchEmbed(dim=dim, name='patch_embed')
+        self.pos_drop = tf.keras.layers.Dropout(drop_rate, name='pos_drop')
+        path_drops = np.linspace(0., path_drop, sum(depths))
+        keep_dims = [(False, False, False), (False, False), (True,), (True,),]
+        self.levels = []
+        for i in range(len(depths)):
+            path_drop = path_drops[sum(depths[:i]):sum(depths[:i + 1])].tolist()
+            level = GCViTLayer(depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], keep_dims=keep_dims[i],
+                               downsample=(i < len(depths) - 1), mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                               drop=drop_rate, attn_drop=attn_drop, path_drop=path_drop, layer_scale=layer_scale, resize_query=resize_query,
+                               name=f'levels/{i}')
+            self.levels.append(level)
+        self.norm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-05, name='norm')
+        if global_pool == 'avg':
+            self.pool = tf.keras.layers.GlobalAveragePooling2D(name='pool')
+        elif global_pool == 'max':
+            self.pool = tf.keras.layers.GlobalMaxPooling2D(name='pool')
+        elif global_pool is None:
+            self.pool = Identity(name='pool')
+        else:
+            raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {global_pool}')
+        self.head = [tf.keras.layers.Dense(num_classes, name='head/fc'),
+                     tf.keras.layers.Activation(head_act, name='head/act')]
+
+    def reset_classifier(self, num_classes, head_act, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is not None:
+            self.global_pool = global_pool
+        self.head[0] = tf.keras.layers.Dense(num_classes, name='head/fc') if num_classes else Identity(name='head/fc')
+        self.head[1] = tf.keras.layers.Activation(head_act, name='head/act') if head_act else Identity(name='head/act')
+        super().build((1, 224, 224, 3))
+
+    def forward_features(self, inputs):
+        x = self.patch_embed(inputs)
+        x = self.pos_drop(x)
+        x = tf.cast(x, dtype=tf.float32)
+        for level in self.levels:
+            x = level(x)
+        x = self.norm(x)
+        return x
+
+    def forward_head(self, inputs, pre_logits=False):
+        x = inputs
+        if self.global_pool in ['avg', 'max']:
+            x = self.pool(x)
+        if not pre_logits:
+            for layer in self.head:
+                x = layer(x)
+        return x
+
+    def call(self, inputs, **kwargs):
+        x = self.forward_features(inputs)
+        x = self.forward_head(x)
+        return x
+
+    def build_graph(self, input_shape=(224, 224, 3)):
+        """https://www.kaggle.com/code/ipythonx/tf-hybrid-efficientnet-swin-transformer-gradcam"""
+        x = tf.keras.Input(shape=input_shape)
+        return tf.keras.Model(inputs=[x], outputs=self.call(x), name=self.name)
+
+# load standard models
+def GCViTTiny(pretrain=False, **kwargs):
+    name = 'gcvit_tiny'
+    config = NAME2CONFIG[name]
+    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
+    model = GCViT(name=name, **config, **kwargs)
+    model(tf.random.uniform(shape=(1, 224, 224, 3)))
+    if pretrain:
+        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
+        model.load_weights(ckpt_path)
+    return model
+
+def GCViTSmall(pretrain=False, **kwargs):
+    name = 'gcvit_small'
+    config = NAME2CONFIG[name]
+    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
+    model = GCViT(name=name, **config, **kwargs)
+    model(tf.random.uniform(shape=(1, 224, 224, 3)))
+    if pretrain:
+        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
+        model.load_weights(ckpt_path)
+    return model
+
+def GCViTBase(pretrain=False, **kwargs):
+    name = 'gcvit_base'
+    config = NAME2CONFIG[name]
+    ckpt_link = '{}/{}/{}_weights.h5'.format(BASE_URL, TAG, name)
+    model = GCViT(name=name, **config, **kwargs)
+    model(tf.random.uniform(shape=(1, 224, 224, 3)))
+    if pretrain:
+        ckpt_path = tf.keras.utils.get_file('{}_weights.h5'.format(name), ckpt_link)
+        model.load_weights(ckpt_path)
+    return model
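
End-to-end usage is then one call per factory; a minimal sketch (the pretrained weights download from the GitHub release on first use):

    import tensorflow as tf
    import gcvit

    model = gcvit.GCViTTiny(pretrain=True)  # builds the model and loads gcvit_tiny weights
    x = tf.random.uniform((1, 224, 224, 3))
    probs = model(x)                        # (1, 1000) softmax probabilities
    print(int(tf.argmax(probs, axis=-1)[0]))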
gcvit/utils/__init__.py
ADDED
@@ -0,0 +1 @@
+from .gradcam import process_image, get_gradcam_model, get_gradcam_prediction
gcvit/utils/gradcam.py
ADDED
@@ -0,0 +1,69 @@
+import tensorflow as tf
+import matplotlib.cm as cm
+import numpy as np
+try:
+    from tensorflow.keras.utils import array_to_img, img_to_array
+except ImportError:  # older TF keeps these under keras.preprocessing
+    from tensorflow.keras.preprocessing.image import array_to_img, img_to_array
+
+def process_image(img, size=(224, 224)):
+    img_array = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
+    img_array = tf.image.resize(img_array, size,)[None,]
+    return img_array
+
+def get_gradcam_model(model):
+    inp = tf.keras.Input(shape=(224, 224, 3))
+    feats = model.forward_features(inp)
+    preds = model.forward_head(feats)
+    return tf.keras.models.Model(inp, [preds, feats])
+
+def get_gradcam_prediction(img, grad_model, process=True, decode=True, pred_index=None, cmap='jet', alpha=0.4):
+    """Grad-CAM for a single image
+
+    Args:
+        img (np.ndarray): processed or raw image without a batch dim, e.g. (224, 224, 3)
+        grad_model (tf.keras.Model): model that outputs both the feature map and the prediction
+        process (bool, optional): apply ImageNet pre-processing. Defaults to True.
+        decode (bool, optional): decode predictions to ImageNet labels. Defaults to True.
+        pred_index (int, optional): target a particular class. Defaults to None.
+        cmap (str, optional): colormap. Defaults to 'jet'.
+        alpha (float, optional): opacity. Defaults to 0.4.
+
+    Returns:
+        preds_decode: top-5 predictions
+        overlay: Grad-CAM heatmap overlaid on the input image
+    """
+    # process image for inference
+    if process:
+        img_array = process_image(img)
+    else:
+        img_array = tf.convert_to_tensor(img)[None,]
+    if img.min() != img.max():
+        img = (img - img.min()) / (img.max() - img.min())
+        img = np.uint8(img * 255.0)
+    # get prediction
+    with tf.GradientTape(persistent=True) as tape:
+        preds, feats = grad_model(img_array)
+        if pred_index is None:
+            pred_index = tf.argmax(preds[0])
+        class_channel = preds[:, pred_index]
+    # compute heatmap
+    grads = tape.gradient(class_channel, feats)
+    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
+    feats = feats[0]
+    heatmap = feats @ pooled_grads[..., tf.newaxis]
+    heatmap = tf.squeeze(heatmap)
+    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
+    heatmap = heatmap.numpy()
+    heatmap = np.uint8(255 * heatmap)
+    # colorize heatmap
+    cmap = cm.get_cmap(cmap)
+    colors = cmap(np.arange(256))[:, :3]
+    heatmap = colors[heatmap]
+    heatmap = array_to_img(heatmap)
+    heatmap = heatmap.resize((img.shape[1], img.shape[0]))
+    heatmap = img_to_array(heatmap)
+    overlay = img + heatmap * alpha
+    overlay = array_to_img(overlay)
+    # decode prediction
+    preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
+    return preds_decode, overlay
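
Putting the pieces together, a sketch of the full Grad-CAM path used by app.py (the image path is hypothetical):

    import numpy as np
    from PIL import Image
    import gcvit
    from gcvit.utils import get_gradcam_model, get_gradcam_prediction

    model = gcvit.GCViTTiny(pretrain=True)
    grad_model = get_gradcam_model(model)            # outputs (predictions, last feature map)
    img = np.array(Image.open('example/panda.jpg'))  # hypothetical path, RGB uint8
    preds, overlay = get_gradcam_prediction(img, grad_model)
    print(preds[:2])                                 # top (class_id, name, prob) tuples
    overlay.save('gradcam_overlay.png')              # PIL image with the heatmap blended in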
gcvit/version.py
ADDED
@@ -0,0 +1 @@
+__version__ = "1.0.3"
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+tensorflow==2.4.1
+tensorflow_addons==0.14.0
+gradio==3.1.0
+numpy
+matplotlib
setup.py
ADDED
@@ -0,0 +1,50 @@
+from setuptools import setup, find_packages
+from codecs import open
+from os import path
+
+here = path.abspath(path.dirname(__file__))
+
+# Get the long description from the README file
+with open(path.join(here, "README.md"), encoding="utf-8") as f:
+    long_description = f.read()
+
+with open(path.join(here, 'requirements.txt')) as f:
+    install_requires = [x for x in f.read().splitlines() if len(x)]
+
+exec(open("gcvit/version.py").read())
+
+setup(
+    name="gcvit",
+    version=__version__,
+    description="Tensorflow 2.0 Implementation of GCViT: Global Context Vision Transformer. https://github.com/awsaf49/gcvit-tf",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    url="https://github.com/awsaf49/gcvit-tf",
+    author="Awsaf",
+    author_email="[email protected]",
+    classifiers=[
+        # How mature is this project? Common values are
+        #   3 - Alpha
+        #   4 - Beta
+        #   5 - Production/Stable
+        "Development Status :: 3 - Alpha",
+        "Intended Audience :: Developers",
+        "Intended Audience :: Science/Research",
+        "License :: OSI Approved :: Apache Software License",
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Topic :: Scientific/Engineering",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Topic :: Software Development",
+        "Topic :: Software Development :: Libraries",
+        "Topic :: Software Development :: Libraries :: Python Modules",
+    ],
+    # Note that this is a string of words separated by whitespace, not a list.
+    keywords="tensorflow computer_vision image classification transformer",
+    packages=find_packages(exclude=["tests"]),
+    include_package_data=True,
+    install_requires=install_requires,
+    python_requires=">=3.6",
+    license="MIT",
+)