Update README.md
README.md (CHANGED)

You must first log in to HuggingFace to pull the model:
```Bash
huggingface-cli login
```

It is highly recommended to install the requirements for MambaVision by running the following:

```Bash
pip install mambavision
```

For each model, we offer two variants, for image classification and for feature extraction, that can be imported with one line of code.

The model can be imported as follows:

```Python
from transformers import AutoModelForImageClassification
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)
```

The model outputs logits when an image is passed. If a label is additionally provided, the cross-entropy loss between the output prediction and the label is computed.
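As a standalone illustration of that loss computation, here is a minimal sketch using a dummy logits tensor (placeholder values, not actual model output), assuming standard `torch.nn.functional.cross_entropy` semantics:

```Python
import torch
import torch.nn.functional as F

# dummy logits for a batch of one image over 1000 ImageNet classes
logits = torch.randn(1, 1000)

# ground-truth class index for the single image in the batch
label = torch.tensor([42])

# cross-entropy between the predicted distribution and the label
loss = F.cross_entropy(logits, label)

print(loss.item())  # a positive scalar
```

Cross-entropy is the negative log-probability assigned to the true class, so it is always positive when the model is uncertain.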
The following demonstrates a minimal example of how to use the model:

```Python
from transformers import AutoModelForImageClassification
from PIL import Image
import requests
import timm
import torch

# import the mambavision model
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)

# switch to eval mode for inference
model.eval()

# prepare an image for the model
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# define a transform and add a batch dimension
transforms = timm.data.create_transform((3, 224, 224))
image = transforms(image).unsqueeze(0)

# put both the model and the image on cuda
model = model.cuda()
image = image.cuda()

# forward pass
outputs = model(image)

# extract the predicted probabilities by applying softmax over the class dimension
probabilities = torch.nn.functional.softmax(outputs['logits'], dim=-1)

# find the top-5 predicted class indices and their corresponding probabilities
values, indices = torch.topk(probabilities, 5)
```
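The softmax and top-k post-processing at the end of the example can be shown on its own with a small dummy tensor (illustrative values, not model output):

```Python
import torch

# dummy logits for one image over 5 classes (illustrative values only)
logits = torch.tensor([[2.0, 0.5, 1.0, -1.0, 0.0]])

# softmax over the class dimension converts logits to probabilities
probabilities = torch.nn.functional.softmax(logits, dim=-1)

# top-3 probabilities and their class indices, sorted in descending order
values, indices = torch.topk(probabilities, 3)

print(indices.tolist())  # [[0, 2, 1]]: classes ordered by decreasing probability
```

Note that using `dim=0` here would normalize over the batch dimension of size 1 and return all ones, which is why the class dimension must be used.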
### License: