awsaf49 committed ed43681 (parent 4092407): Create README.md

Files changed (1): README.md (+146, −13)
---
title: 'GCViT: Global Context Vision Transformer'
colorFrom: indigo
---
<h1 align="center">
<p><a href='https://arxiv.org/pdf/2206.09959v1.pdf'>GCViT: Global Context Vision Transformer</a></p>
</h1>
<div align=center><img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_arch.PNG" width=800></div>
<p align="center">
<a href="https://github.com/awsaf49/gcvit-tf/blob/main/LICENSE.md">
<img src="https://img.shields.io/badge/License-MIT-yellow.svg">
</a>
<img alt="python" src="https://img.shields.io/badge/python-%3E%3D3.6-blue?logo=python">
<img alt="tensorflow" src="https://img.shields.io/badge/tensorflow-%3E%3D2.4.1-orange?logo=tensorflow">
</p>
<div align=center><p>
<a target="_blank" href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="https://img.shields.io/badge/🤗%20Hugging%20Face-Spaces-yellow.svg"></a>
<a href="https://colab.research.google.com/github/awsaf49/gcvit-tf/blob/main/notebooks/GCViT_Flower_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<a href="https://www.kaggle.com/awsaf49/flower-classification-gcvit-global-context-vit"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open In Kaggle"></a>
</p></div>
<h2 align="center">
<p>TensorFlow 2 Implementation of GCViT</p>
</h2>
<p align="center">
This library implements <b>GCViT</b> in TensorFlow 2, built as a <code>tf.keras.Model</code> so the API keeps a PyTorch-like flavor.
</p>

## Update
* **15 Jan 2023**: `GCViTLarge` model added with pretrained checkpoint.
* **3 Sept 2022**: An annotated [Kaggle notebook](https://www.kaggle.com/code/awsaf49/gcvit-global-context-vision-transformer) based on this project won the [Kaggle ML Research Spotlight: August 2022](https://www.kaggle.com/discussions/general/349817).
* **19 Aug 2022**: This project was acknowledged by the [official repo](https://github.com/NVlabs/GCVit), [here](https://github.com/NVlabs/GCVit#third-party-implementations-and-resources).

## Model
* Architecture:

<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/arch.PNG">

* Local vs Global Attention:

<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_msa.PNG">

## Result
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/result.PNG" width=900>

The official codebase had an issue that was fixed recently (12 August 2022). Here are the results of the ported weights on the **ImageNetV2-Test** data:

| Model        | Acc@1 | Acc@5 | #Params |
|--------------|-------|-------|---------|
| GCViT-XXTiny | 0.663 | 0.873 | 12M     |
| GCViT-XTiny  | 0.685 | 0.885 | 20M     |
| GCViT-Tiny   | 0.708 | 0.899 | 28M     |
| GCViT-Small  | 0.720 | 0.901 | 51M     |
| GCViT-Base   | 0.731 | 0.907 | 90M     |
| GCViT-Large  | 0.734 | 0.913 | 202M    |
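As a sanity check, the top-1/top-5 numbers can be approximated against TFDS's `imagenet_v2` test split; this is a hedged sketch, not the script that produced the table, and the preprocessing is assumed to match the usage example below:

```py
import tensorflow as tf
import tensorflow_datasets as tfds
from gcvit import GCViTTiny

def preprocess(example):
    # Same torch-style normalization as the usage example below (assumed).
    img = tf.image.resize(example['image'], (224, 224))
    img = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
    return img, example['label']

ds = tfds.load('imagenet_v2', split='test').map(preprocess).batch(64)

model = GCViTTiny(pretrain=True)
model.compile(
    loss='sparse_categorical_crossentropy',
    metrics=[
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1, name='acc@1'),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='acc@5'),
    ],
)
model.evaluate(ds)  # expect numbers roughly in line with the GCViT-Tiny row above
```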
## Installation
```bash
pip install -U gcvit
# or install the latest from GitHub
# pip install -U git+https://github.com/awsaf49/gcvit-tf
```
## Usage
Load a model with the following code:
```py
from gcvit import GCViTTiny

model = GCViTTiny(pretrain=True)
```
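The other sizes in the results table load the same way; a quick sketch (every constructor name here except `GCViTTiny` and `GCViTLarge` is inferred from the table's naming pattern, so verify against the package):

```py
# Hypothetical variant names inferred from the results table; only
# GCViTTiny and GCViTLarge are explicitly confirmed in this README.
from gcvit import GCViTXXTiny, GCViTBase, GCViTLarge

xxtiny = GCViTXXTiny(pretrain=True)  # ~12M params, fastest
base = GCViTBase(pretrain=True)      # ~90M params
large = GCViTLarge(pretrain=True)    # ~202M params, best Acc@1
```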
A quick check of the model's prediction:
```py
import tensorflow as tf  # needed for the preprocessing utilities below
from skimage.data import chelsea

img = tf.keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='torch')  # Chelsea the cat
img = tf.image.resize(img, (224, 224))[None,]  # resize & create batch
pred = model(img).numpy()
print(tf.keras.applications.imagenet_utils.decode_predictions(pred)[0])
```
Prediction:
```py
[('n02124075', 'Egyptian_cat', 0.9194835),
 ('n02123045', 'tabby', 0.009686623),
 ('n02123159', 'tiger_cat', 0.0061576385),
 ('n02127052', 'lynx', 0.0011503297),
 ('n02883205', 'bow_tie', 0.00042479983)]
```
For feature extraction:
```py
model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
model.reset_classifier(num_classes=0, head_act=None)  # drop the classifier head
feature = model(img)
print(feature.shape)
```
Feature:
```py
(None, 512)
```
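Because the headless backbone is a plain `tf.keras.Model` emitting `(None, 512)` features, it can be dropped into the Keras functional API for linear probing; a minimal sketch (the 5-class head is a placeholder, not from this repo):

```py
import tensorflow as tf
from gcvit import GCViTTiny

# Headless, frozen backbone for linear probing on a new task.
backbone = GCViTTiny(pretrain=True)
backbone.reset_classifier(num_classes=0, head_act=None)  # now outputs (None, 512)
backbone.trainable = False

inputs = tf.keras.Input(shape=(224, 224, 3))
features = backbone(inputs, training=False)  # run the frozen backbone in inference mode
outputs = tf.keras.layers.Dense(5, activation='softmax')(features)  # placeholder 5-class head
probe = tf.keras.Model(inputs, outputs)
probe.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
```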
For a feature map:
```py
model = GCViTTiny(pretrain=True)  # when pretrain=True, num_classes must be 1000
feature = model.forward_features(img)
print(feature.shape)
```
Feature map:
```py
(None, 7, 7, 512)
```

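This spatial map is what a Grad-CAM overlay (see the demo below) is computed from. A rough sketch, assuming the model also exposes a timm-style `forward_head` counterpart to `forward_features` (hypothetical; adapt to the package's actual head API):

```py
import tensorflow as tf

def grad_cam(model, img):
    """Grad-CAM heatmap over the final 7x7 feature map (sketch)."""
    with tf.GradientTape() as tape:
        fmap = model.forward_features(img)   # (1, 7, 7, 512)
        tape.watch(fmap)
        preds = model.forward_head(fmap)     # ASSUMED API: head-only forward pass
        top_score = tf.reduce_max(preds[0])  # score of the predicted class
    grads = tape.gradient(top_score, fmap)           # d(score)/d(feature map)
    weights = tf.reduce_mean(grads, axis=(1, 2))     # GAP of gradients -> (1, 512)
    cam = tf.nn.relu(tf.einsum('bhwc,bc->bhw', fmap, weights))
    cam = cam / (tf.reduce_max(cam) + 1e-8)          # normalize to [0, 1]
    return cam[0].numpy()                            # (7, 7); upsample to overlay on the image
```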
## Live Demo
* For a live demo of image classification & Grad-CAM with **ImageNet** weights, click <a target="_blank" href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="https://img.shields.io/badge/Try%20on-Gradio-orange"></a>, powered by 🤗 Spaces and Gradio. Here's an example:

<a href="https://huggingface.co/spaces/awsaf49/gcvit-tf"><img src="image/gradio_demo.JPG" height=500></a>

## Example
For a working training example, check out these notebooks on **Google Colab** <a href="https://colab.research.google.com/github/awsaf49/gcvit-tf/blob/main/notebooks/GCViT_Flower_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> & **Kaggle** <a href="https://www.kaggle.com/awsaf49/flower-classification-gcvit-global-context-vit"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open In Kaggle"></a>. In outline, they follow the standard Keras training loop, sketched below.
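A condensed sketch of that loop (the dataset pipeline and class count are placeholders, not the notebooks' actual code; `head_act` accepting a Keras activation name is assumed from the `reset_classifier` usage above):

```py
import tensorflow as tf
from gcvit import GCViTTiny

# Placeholder pipeline: substitute real images/labels from the notebooks.
train_ds = tf.data.Dataset.from_tensor_slices((
    tf.random.uniform((8, 224, 224, 3)),
    tf.random.uniform((8,), maxval=5, dtype=tf.int32),
)).batch(4)

model = GCViTTiny(pretrain=True)                           # 1000-class head at load time
model.reset_classifier(num_classes=5, head_act='softmax')  # assumed: swap in a 5-class head
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
model.fit(train_ds, epochs=1)
```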
Here is a Grad-CAM result after training on a flower-classification dataset:

<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/flower_gradcam.PNG" height=500>

## To Do
- [ ] Segmentation pipeline
- [x] New updated weights have been added.
- [x] Working training example in Colab & Kaggle.
- [x] Grad-CAM showcase.
- [x] Gradio demo.
- [x] Build model with `tf.keras.Model`.
- [x] Port weights from official repo.
- [x] Support for `TPU` (see the sketch below).
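Since the model is a standard `tf.keras.Model`, TPU use is just the stock TensorFlow distribution-strategy setup; a sketch (not code from this repo):

```py
import tensorflow as tf
from gcvit import GCViTTiny

try:  # standard TF TPU bootstrap; falls back to the default strategy
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
except ValueError:
    strategy = tf.distribute.get_strategy()

with strategy.scope():  # create variables (incl. pretrained weights) under the strategy
    model = GCViTTiny(pretrain=True)
```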
## Acknowledgement
* [GCVit](https://github.com/NVlabs/GCVit) (official)
* [Swin-Transformer-TF](https://github.com/rishigami/Swin-Transformer-TF)
* [tfgcvit](https://github.com/shkarupa-alex/tfgcvit/tree/develop/tfgcvit)
* [keras_cv_attention_models](https://github.com/leondgarse/keras_cv_attention_models)

## Citation
```bibtex
@article{hatamizadeh2022global,
  title={Global Context Vision Transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
  journal={arXiv preprint arXiv:2206.09959},
  year={2022}
}
```