ahatamiz committed
Commit d958429 · verified · 1 Parent(s): c846d36

Update README.md

Files changed (1):
  1. README.md +51 -34
README.md CHANGED

[**MambaVision: A Hybrid Mamba-Transformer Vision Backbone**](https://arxiv.org/abs/2407.08083).

## Model Overview

We introduce a novel mixer block by creating a symmetric path without SSM to enhance the modeling of global context. MambaVision has a hierarchical architecture that employs both self-attention and mixer blocks.
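
For intuition, the sketch below illustrates this split-path idea in PyTorch. It is our simplified illustration, not the released implementation: all names (`MixerBlockSketch`, `conv_ssm`, `conv_sym`) are ours, and the SSM itself is stubbed with `nn.Identity`. The point is only the structure: half of the channels pass through an SSM path and half through a symmetric path without SSM, and the two are concatenated before the output projection.

```Python
import torch
import torch.nn as nn

class MixerBlockSketch(nn.Module):
    """Rough illustration of the symmetric-path idea; not the released code."""
    def __init__(self, dim, ssm=None):  # dim assumed even
        super().__init__()
        self.proj_ssm = nn.Linear(dim, dim // 2)   # SSM branch input projection
        self.proj_sym = nn.Linear(dim, dim // 2)   # symmetric (non-SSM) branch
        self.conv_ssm = nn.Conv1d(dim // 2, dim // 2, kernel_size=3, padding=1)
        self.conv_sym = nn.Conv1d(dim // 2, dim // 2, kernel_size=3, padding=1)
        self.ssm = ssm if ssm is not None else nn.Identity()  # stub for the SSM scan
        self.act = nn.SiLU()
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):  # x: (batch, tokens, dim)
        # SSM path: project, conv over the token sequence, activate, then SSM mixing
        a = self.act(self.conv_ssm(self.proj_ssm(x).transpose(1, 2)).transpose(1, 2))
        a = self.ssm(a)
        # symmetric path without SSM: same shape, no sequential state-space model
        b = self.act(self.conv_sym(self.proj_sym(x).transpose(1, 2)).transpose(1, 2))
        # concatenate both paths and project back
        return self.out_proj(torch.cat([a, b], dim=-1))

x = torch.randn(1, 196, 80)            # e.g. 14x14 tokens, 80 channels
print(MixerBlockSketch(80)(x).shape)   # torch.Size([1, 196, 80])
```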

## Model Performance

MambaVision demonstrates strong performance, achieving a new SOTA Pareto front in terms of Top-1 accuracy and throughput.

<p align="center">
[Pareto front of Top-1 accuracy vs. throughput]
</p>

## Model Usage

It is highly recommended to install the requirements for MambaVision by running the following:

```Bash
pip install mambavision
```

For each model, we offer two variants, one for image classification and one for feature extraction, each of which can be imported with one line of code, as shown below.
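
Concretely, the two variants load with one line each; these are the same `from_pretrained` calls used in the full snippets below:

```Python
from transformers import AutoModel, AutoModelForImageClassification

# variant 1: classification head on top of the backbone
classifier = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)

# variant 2: bare backbone for feature extraction
backbone = AutoModel.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)
```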

### Image Classification

In the following example, we demonstrate how MambaVision can be used for image classification. The model outputs logits for an input image; if labels are also provided, the cross-entropy loss between the predictions and the labels is computed.

Given the following image from the [COCO dataset](https://cocodataset.org/#home) validation set as input:

<p align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/64414b62603214724ebd2636/4duSnqLf4lrNiAHczSmAN.jpeg" width=42% height=42%
class="center">
</p>

The following snippet can be used for image classification:

```Python
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)

# move the model to GPU and set eval mode for inference
model.cuda().eval()

# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# MambaVision supports any input resolution but was originally trained on (3, 224, 224)
input_resolution = (3, 224, 224)

# build the same preprocessing the model was validated with
transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)

inputs = transform(image).unsqueeze(0).cuda()

# model inference
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
```

The predicted label is `brown bear, bruin, Ursus arctos`.
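
If you also want the model's confidence in its prediction, a minimal follow-on sketch (continuing from the snippet above, so `logits` and `model` are already in scope) applies a softmax and reads off the top-5 classes:

```Python
import torch

# class probabilities from the logits computed above
probabilities = torch.nn.functional.softmax(logits, dim=-1)

# top-5 predicted classes and their probabilities
values, indices = torch.topk(probabilities, 5)
for prob, idx in zip(values[0], indices[0]):
    print(f"{model.config.id2label[idx.item()]}: {prob.item():.4f}")
```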
 
### Feature Extraction

MambaVision can also be used as a generic feature extractor. Specifically, we can extract the outputs of each of the model's four stages as well as the final average-pooled, flattened features.

The following snippet can be used for feature extraction:

```Python
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

model = AutoModel.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)

# move the model to GPU and set eval mode for inference
model.cuda().eval()

# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# MambaVision supports any input resolution but was originally trained on (3, 224, 224)
input_resolution = (3, 224, 224)

# build the same preprocessing the model was validated with
transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)

inputs = transform(image).unsqueeze(0).cuda()

# model inference
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size())    # torch.Size([1, 640])
print("Number of stages in extracted features:", len(features))      # 4 stages
print("Size of extracted features in stage 1:", features[0].size())  # torch.Size([1, 80, 56, 56])
print("Size of extracted features in stage 4:", features[3].size())  # torch.Size([1, 640, 7, 7])
```
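
As one illustrative use of these features (a sketch of ours, not part of the card), the pooled embedding can be compared across images, e.g. for retrieval. Continuing from the snippet above and embedding a second COCO val image:

```Python
import torch.nn.functional as F

# second image from the COCO val set, preprocessed with the same transform
url_b = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image_b = Image.open(requests.get(url_b, stream=True).raw)
inputs_b = transform(image_b).unsqueeze(0).cuda()

# embed both images with the backbone
feat_a, _ = model(inputs)
feat_b, _ = model(inputs_b)

# cosine similarity between the average-pooled embeddings
print("similarity:", F.cosine_similarity(feat_a, feat_b).item())
```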