Update README.md

README.md CHANGED

[**MambaVision: A Hybrid Mamba-Transformer Vision Backbone**](https://arxiv.org/abs/2407.08083).

## Model Overview

We introduce a novel mixer block by creating a symmetric path without SSM to enhance the modeling of global context. MambaVision has a hierarchical architecture that employs both self-attention and mixer blocks.
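
For intuition, here is a schematic sketch of that mixer block, assuming the channel-split design described in the paper; the class name, layer sizes, and the `ssm` operator are placeholders, not the actual implementation shipped in the `mambavision` package.

```Python
import torch
import torch.nn as nn

class MixerBlockSketch(nn.Module):
    """Schematic MambaVision-style mixer: half of the channels go through the
    SSM branch, the other half through a symmetric conv + SiLU path without SSM,
    and the two halves are concatenated before the output projection."""
    def __init__(self, dim, ssm):
        super().__init__()
        self.in_proj = nn.Linear(dim, dim)
        self.conv_x = nn.Conv1d(dim // 2, dim // 2, kernel_size=3, padding=1, groups=dim // 2)
        self.conv_z = nn.Conv1d(dim // 2, dim // 2, kernel_size=3, padding=1, groups=dim // 2)
        self.ssm = ssm  # placeholder for the selective state-space operator
        self.act = nn.SiLU()
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, tokens):  # tokens: (batch, num_tokens, dim)
        x, z = self.in_proj(tokens).chunk(2, dim=-1)
        x = self.act(self.conv_x(x.transpose(1, 2)).transpose(1, 2))
        x = self.ssm(x)  # SSM branch
        z = self.act(self.conv_z(z.transpose(1, 2)).transpose(1, 2))  # symmetric path without SSM
        return self.out_proj(torch.cat([x, z], dim=-1))
```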

## Model Performance

MambaVision demonstrates strong performance, achieving a new SOTA Pareto front in terms of Top-1 accuracy and throughput.
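
As a rough illustration of how throughput can be measured on your own hardware (a minimal sketch, not the benchmarking protocol behind the paper's numbers; batch size, resolution, and iteration counts are arbitrary choices here):

```Python
import time
import torch
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)
model.cuda().eval()

batch = torch.randn(128, 3, 224, 224, device="cuda")  # random batch, chosen arbitrarily

with torch.inference_mode():
    for _ in range(10):  # warm-up iterations
        model(batch)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(50):  # timed iterations
        model(batch)
    torch.cuda.synchronize()
    elapsed = time.time() - start

print(f"Throughput: {128 * 50 / elapsed:.1f} images/sec")
```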

## Model Usage

It is highly recommended to install the requirements for MambaVision by running the following:
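
```Bash
pip install mambavision
```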

For each model, we offer two variants, for image classification and feature extraction, that can be imported with one line of code.
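
For reference, the two variants map to the following one-line imports (a minimal sketch; both entry points also appear in the full examples below, and the variable names are only illustrative):

```Python
from transformers import AutoModel, AutoModelForImageClassification

# image classification variant: returns logits over the ImageNet-1K classes
classifier = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)

# feature extraction variant: returns per-stage feature maps and averaged-pool features
backbone = AutoModel.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)
```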

### Image Classification

In the following example, we demonstrate how MambaVision can be used for image classification.

Given the following image from the [COCO dataset](https://cocodataset.org/#home) validation set as an input:

<p align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64414b62603214724ebd2636/4duSnqLf4lrNiAHczSmAN.jpeg" width=42% height=42%
class="center">
</p>

The following snippet can be used for image classification:

```Python
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)

# eval mode for inference
model.cuda().eval()

# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)  # MambaVision supports any input resolution but was originally trained on (3, 224, 224)

transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)

inputs = transform(image).unsqueeze(0).cuda()

# model inference
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
```

The predicted label is `brown bear, bruin, Ursus arctos`.
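
To inspect more than the top prediction, the logits from the snippet above can be turned into class probabilities (a small follow-up sketch; it reuses `logits` and `model` from the example and assumes a standard softmax over the ImageNet-1K classes):

```Python
import torch

# convert the logits from the example above into probabilities and list the top-5 classes
probabilities = torch.softmax(logits, dim=-1)[0]
top5 = torch.topk(probabilities, k=5)
for prob, idx in zip(top5.values.tolist(), top5.indices.tolist()):
    print(f"{model.config.id2label[idx]}: {prob:.3f}")
```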

### Feature Extraction

MambaVision can also be used as a generic feature extractor.

Specifically, we can extract the outputs of each of the model's 4 stages, as well as the final average-pooled features, which are flattened.

The following snippet can be used for feature extraction:

```Python
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

model = AutoModel.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)

# eval mode for inference
model.cuda().eval()

# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)  # MambaVision supports any input resolution but was originally trained on (3, 224, 224)

transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)

inputs = transform(image).unsqueeze(0).cuda()

# model inference
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size())  # torch.Size([1, 640])
print("Number of stages in extracted features:", len(features))  # 4 stages
print("Size of extracted features in stage 1:", features[0].size())  # torch.Size([1, 80, 56, 56])
print("Size of extracted features in stage 4:", features[3].size())  # torch.Size([1, 640, 7, 7])
```
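
The averaged-pool output can also serve as a global image embedding, for example to compare two images with cosine similarity. Below is a minimal sketch reusing `model`, `transform`, and `image` from the snippet above; the horizontally flipped copy simply stands in for a second image.

```Python
import torch
from PIL import ImageOps

# embed an image with the averaged-pool features and L2-normalize it
def embed(img):
    with torch.inference_mode():
        avg_pool, _ = model(transform(img).unsqueeze(0).cuda())
    return torch.nn.functional.normalize(avg_pool, dim=-1)

emb_a = embed(image)                   # the COCO image loaded above
emb_b = embed(ImageOps.mirror(image))  # its flipped copy, as a stand-in second image
similarity = (emb_a @ emb_b.T).item()
print(f"Cosine similarity: {similarity:.3f}")
```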