---
license: other
license_name: nvclv1
license_link: LICENSE
datasets:
- ILSVRC/imagenet-1k
pipeline_tag: image-feature-extraction
---


[**MambaVision: A Hybrid Mamba-Transformer Vision Backbone**](https://arxiv.org/abs/2407.08083).

## Model Overview

We introduce a novel mixer block that adds a symmetric, SSM-free path to enhance the modeling of global context. MambaVision has a hierarchical architecture that employs both self-attention and mixer blocks.
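
To make the structure concrete, here is a minimal, purely illustrative PyTorch sketch of a mixer with two symmetric channel paths, one of which omits the SSM. The block name, dimensions, and layer choices are assumptions for illustration; in particular, the SSM branch is stubbed out with an ordinary convolution rather than a Mamba-style selective scan, so this is not the released implementation.

```Python
import torch
import torch.nn as nn

class ToyMixerBlock(nn.Module):
    """Illustrative sketch only: split channels into two symmetric paths
    (one standing in for the SSM branch, one without SSM), then concatenate."""
    def __init__(self, dim: int):
        super().__init__()
        half = dim // 2
        self.in_proj = nn.Linear(dim, dim)
        # placeholder for the Mamba-style selective-scan branch
        self.ssm_branch = nn.Sequential(nn.Conv1d(half, half, 3, padding=1), nn.SiLU())
        # symmetric path without SSM
        self.conv_branch = nn.Sequential(nn.Conv1d(half, half, 3, padding=1), nn.SiLU())
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (batch, tokens, dim)
        x = self.in_proj(x)
        a, b = x.chunk(2, dim=-1)  # split channels across the two paths
        a = self.ssm_branch(a.transpose(1, 2)).transpose(1, 2)
        b = self.conv_branch(b.transpose(1, 2)).transpose(1, 2)
        return self.out_proj(torch.cat([a, b], dim=-1))

block = ToyMixerBlock(dim=64)
tokens = torch.randn(1, 196, 64)   # e.g. a 14x14 grid of tokens
print(block(tokens).shape)         # torch.Size([1, 196, 64])
```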


## Model Performance

MambaVision demonstrates strong performance, achieving a new SOTA Pareto front in terms of Top-1 accuracy and throughput. 

<p align="center">
<img src="https://github.com/NVlabs/MambaVision/assets/26806394/79dcf841-3966-4b77-883d-76cd5e1d4320" width=50% height=50% 
class="center">
</p>


## Model Usage

It is highly recommended to install the requirements for MambaVision by running the following:


```Bash
pip install mambavision
```

For each model, we offer two variants, one for image classification and one for feature extraction, each of which can be imported with a single line of code. 
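
For example, the two variants of MambaVision-T-1K used in the snippets below are loaded as follows (only the chosen `Auto*` class differs):

```Python
from transformers import AutoModel, AutoModelForImageClassification

# classification variant (with the classification head)
classifier = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)

# feature-extraction variant (backbone only)
backbone = AutoModel.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)
```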

### Image Classification

In the following example, we demonstrate how MambaVision can be used for image classification. 

Given the following image from the [COCO dataset](https://cocodataset.org/#home) validation set as an input:


<p align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64414b62603214724ebd2636/4duSnqLf4lrNiAHczSmAN.jpeg" width=50% height=50% 
class="center">
</p>


The following snippet can be used for image classification:

```Python
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)

# eval mode for inference
model.cuda().eval()

# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)  # MambaVision supports any input resolution

transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)

inputs = transform(image).unsqueeze(0).cuda()
# model inference
outputs = model(inputs)
logits = outputs['logits'] 
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
```

The predicted label is `brown bear, bruin, Ursus arctos`.
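
If you also want the runner-up classes, a short extension of the snippet above (re-using `model` and `logits`) prints the top-5 predictions with softmax probabilities:

```Python
import torch

# top-5 predictions with softmax probabilities (continues the snippet above)
probs = torch.softmax(logits, dim=-1)
top5 = probs.topk(5, dim=-1)
for p, idx in zip(top5.values[0].tolist(), top5.indices[0].tolist()):
    print(f"{model.config.id2label[idx]}: {p:.3f}")
```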

### Feature Extraction

MambaVision can also be used as a generic feature extractor. 

Specifically, we can extract the outputs of each of the model's 4 stages as well as the final average-pooled features, which are flattened. 

The following snippet can be used for feature extraction:

```Python
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

model = AutoModel.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)

# eval mode for inference
model.cuda().eval()

# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)  # MambaVision supports any input resolution

transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
# model inference
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size())  # torch.Size([1, 640])
print("Number of stages in extracted features:", len(features)) # 4 stages
print("Size of extracted features in stage 1:", features[0].size()) # torch.Size([1, 80, 56, 56])
print("Size of extracted features in stage 4:", features[3].size()) # torch.Size([1, 640, 7, 7])
```
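
Because the backbone accepts other input resolutions (see the comment in the snippet above), the same code can be re-run at a different size. Below is a hedged sketch re-using `model`, `image`, and `create_transform` from the previous snippet, with 448x448 chosen arbitrarily for illustration:

```Python
# re-run feature extraction at a different, arbitrarily chosen resolution
transform_448 = create_transform(input_size=(3, 448, 448),
                                 is_training=False,
                                 mean=model.config.mean,
                                 std=model.config.std,
                                 crop_mode=model.config.crop_mode,
                                 crop_pct=model.config.crop_pct)
inputs_448 = transform_448(image).unsqueeze(0).cuda()
out_avg_pool, features = model(inputs_448)
print("Size of the averaged pool features:", out_avg_pool.size())  # channel count is unchanged
for i, f in enumerate(features):
    print(f"Size of extracted features in stage {i + 1}:", f.size())  # spatial size scales with the input
```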


### License: 

[NVIDIA Source Code License-NC](https://huggingface.co/nvidia/MambaVision-T-1K/blob/main/LICENSE)