Create README.md
Browse files
README.md
CHANGED
@@ -1,13 +1,146 @@
|
|
1 |
-
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: 'GCViT: Global Context Vision Transformer'
|
3 |
+
colorFrom: indigo
|
4 |
+
---
|
5 |
+
<h1 align="center">
|
6 |
+
<p><a href='https://arxiv.org/pdf/2206.09959v1.pdf'>GCViT: Global Context Vision Transformer</a></p>
|
7 |
+
</h1>
|
8 |
+
<div align=center><img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_arch.PNG" width=800></div>
|
9 |
+
<p align="center">
|
10 |
+
<a href="https://github.com/awsaf49/gcvit-tf/blob/main/LICENSE.md">
|
11 |
+
<img src="https://img.shields.io/badge/License-MIT-yellow.svg">
|
12 |
+
</a>
|
13 |
+
<img alt="python" src="https://img.shields.io/badge/python-%3E%3D3.6-blue?logo=python">
|
14 |
+
<img alt="tensorflow" src="https://img.shields.io/badge/tensorflow-%3E%3D2.4.1-orange?logo=tensorflow">
|
15 |
+
<div align=center><p>
|
16 |
+
<a target="_blank" href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="https://img.shields.io/badge/π€%20Hugging%20Face-Spaces-yellow.svg"></a>
|
17 |
+
<a href="https://colab.research.google.com/github/awsaf49/gcvit-tf/blob/main/notebooks/GCViT_Flower_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
|
18 |
+
<a href="https://www.kaggle.com/awsaf49/flower-classification-gcvit-global-context-vit"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open In Kaggle"></a>
|
19 |
+
</p></div>
|
20 |
+
<h2 align="center">
|
21 |
+
<p>Tensorflow 2.0 Implementation of GCViT</p>
|
22 |
+
</h2>
|
23 |
+
</p>
|
24 |
+
<p align="center">
|
25 |
+
This library implements <b>GCViT</b> using Tensorflow 2.0 specifically in <code>tf.keras.Model</code> manner to get PyTorch flavor.
|
26 |
+
</p>
|
27 |
+
|
28 |
+
## Update
|
29 |
+
* **15 Jan 2023** : `GCViTLarge` model added with ckpt.
|
30 |
+
* **3 Sept 2022** : Annotated [kaggle-notebook](https://www.kaggle.com/code/awsaf49/gcvit-global-context-vision-transformer) based on this project won [Kaggle ML Research Spotlight: August 2022](https://www.kaggle.com/discussions/general/349817).
|
31 |
+
* **19 Aug 2022** : This project got acknowledged by [Official](https://github.com/NVlabs/GCVit) repo [here](https://github.com/NVlabs/GCVit#third-party-implementations-and-resources)
|
32 |
+
|
33 |
+
## Model
|
34 |
+
* Architecture:
|
35 |
+
|
36 |
+
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/arch.PNG">
|
37 |
+
|
38 |
+
* Local Vs Global Attention:
|
39 |
+
|
40 |
+
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_msa.PNG">
|
41 |
+
|
42 |
+
## Result
|
43 |
+
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/result.PNG" width=900>
|
44 |
+
|
45 |
+
Official codebase had some issue which has been fixed recently (12 August 2022). Here's the result of ported weights on **ImageNetV2-Test** data,
|
46 |
+
|
47 |
+
| Model | Acc@1 | Acc@5 | #Params |
|
48 |
+
|--------------|-------|-------|---------|
|
49 |
+
| GCViT-XXTiny | 0.663 | 0.873 | 12M |
|
50 |
+
| GCViT-XTiny | 0.685 | 0.885 | 20M |
|
51 |
+
| GCViT-Tiny | 0.708 | 0.899 | 28M |
|
52 |
+
| GCViT-Small | 0.720 | 0.901 | 51M |
|
53 |
+
| GCViT-Base | 0.731 | 0.907 | 90M |
|
54 |
+
| GCViT-Large | 0.734 | 0.913 | 202M |
|
55 |
+
|
56 |
+
## Installation
|
57 |
+
```bash
|
58 |
+
pip install -U gcvit
|
59 |
+
# or
|
60 |
+
# pip install -U git+https://github.com/awsaf49/gcvit-tf
|
61 |
+
```
|
62 |
+
|
63 |
+
## Usage
|
64 |
+
Load model using following codes,
|
65 |
+
```py
|
66 |
+
from gcvit import GCViTTiny
|
67 |
+
model = GCViTTiny(pretrain=True)
|
68 |
+
```
|
69 |
+
Simple code to check model's prediction,
|
70 |
+
```py
|
71 |
+
from skimage.data import chelsea
|
72 |
+
img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch') # Chelsea the cat
|
73 |
+
img = tf.image.resize(img, (224, 224))[None,] # resize & create batch
|
74 |
+
pred = model(img).numpy()
|
75 |
+
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])
|
76 |
+
```
|
77 |
+
Prediction:
|
78 |
+
```py
|
79 |
+
[('n02124075', 'Egyptian_cat', 0.9194835),
|
80 |
+
('n02123045', 'tabby', 0.009686623),
|
81 |
+
('n02123159', 'tiger_cat', 0.0061576385),
|
82 |
+
('n02127052', 'lynx', 0.0011503297),
|
83 |
+
('n02883205', 'bow_tie', 0.00042479983)]
|
84 |
+
```
|
85 |
+
For feature extraction:
|
86 |
+
```py
|
87 |
+
model = GCViTTiny(pretrain=True) # when pretrain=True, num_classes must be 1000
|
88 |
+
model.reset_classifier(num_classes=0, head_act=None)
|
89 |
+
feature = model(img)
|
90 |
+
print(feature.shape)
|
91 |
+
```
|
92 |
+
Feature:
|
93 |
+
```py
|
94 |
+
(None, 512)
|
95 |
+
```
|
96 |
+
For feature map:
|
97 |
+
```py
|
98 |
+
model = GCViTTiny(pretrain=True) # when pretrain=True, num_classes must be 1000
|
99 |
+
feature = model.forward_features(img)
|
100 |
+
print(feature.shape)
|
101 |
+
```
|
102 |
+
Feature map:
|
103 |
+
```py
|
104 |
+
(None, 7, 7, 512)
|
105 |
+
```
|
106 |
+
|
107 |
+
## Live-Demo
|
108 |
+
* For live demo on Image Classification & Grad-CAM, with **ImageNet** weights, click <a target="_blank" href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="https://img.shields.io/badge/Try%20on-Gradio-orange"></a> powered by π€ Space and Gradio. here's an example,
|
109 |
+
|
110 |
+
<a href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="image/gradio_demo.JPG" height=500></a>
|
111 |
+
|
112 |
+
## Example
|
113 |
+
For working training example checkout these notebooks on **Google Colab** <a href="https://colab.research.google.com/github/awsaf49/gcvit-tf/blob/main/notebooks/GCViT_Flower_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> & **Kaggle** <a href="https://www.kaggle.com/awsaf49/flower-classification-gcvit-global-context-vit"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open In Kaggle"></a>.
|
114 |
+
|
115 |
+
Here is grad-cam result after training on Flower Classification Dataset,
|
116 |
+
|
117 |
+
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/flower_gradcam.PNG" height=500>
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
## To Do
|
122 |
+
- [ ] Segmentation Pipeline
|
123 |
+
- [x] New updated weights have been added.
|
124 |
+
- [x] Working training example in Colab & Kaggle.
|
125 |
+
- [x] GradCAM showcase.
|
126 |
+
- [x] Gradio Demo.
|
127 |
+
- [x] Build model with `tf.keras.Model`.
|
128 |
+
- [x] Port weights from official repo.
|
129 |
+
- [x] Support for `TPU`.
|
130 |
+
|
131 |
+
## Acknowledgement
|
132 |
+
* [GCVit](https://github.com/NVlabs/GCVit) (Official)
|
133 |
+
* [Swin-Transformer-TF](https://github.com/rishigami/Swin-Transformer-TF)
|
134 |
+
* [tfgcvit](https://github.com/shkarupa-alex/tfgcvit/tree/develop/tfgcvit)
|
135 |
+
* [keras_cv_attention_models](https://github.com/leondgarse/keras_cv_attention_model)
|
136 |
+
|
137 |
+
|
138 |
+
## Citation
|
139 |
+
```bibtex
|
140 |
+
@article{hatamizadeh2022global,
|
141 |
+
title={Global Context Vision Transformers},
|
142 |
+
author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
|
143 |
+
journal={arXiv preprint arXiv:2206.09959},
|
144 |
+
year={2022}
|
145 |
+
}
|
146 |
+
```
|