---
license: mit
---
# Model Card for ViT Fine-Tuned on CIFAR-10

<!-- Provide a quick summary of what the model is/does. -->

This is a toy experiment in fine-tuning a Vision Transformer (ViT) with Hugging Face Transformers.

## Model Details

The model was fine-tuned on CIFAR-10 for 1,000 steps and achieves 98.7% accuracy on the CIFAR-10 test split.
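The card does not say exactly how the accuracy figure was computed. Assuming it is the usual top-1 accuracy, it is the fraction of test images whose highest-scoring class matches the ground-truth label. The sketch below shows that computation in plain Python on a made-up toy batch; `top1_accuracy` is a hypothetical helper, not part of Transformers. Over the full CIFAR-10 test split (10,000 images), the same comparison would be applied to the model's logits for every image.

```python
from typing import Sequence


def argmax(scores: Sequence[float]) -> int:
    """Return the index of the largest score (the first one on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def top1_accuracy(batch_logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Hypothetical helper: fraction of rows whose argmax equals the label."""
    if len(batch_logits) != len(labels):
        raise ValueError("batch_logits and labels must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy on an empty batch")
    correct = sum(argmax(row) == label for row, label in zip(batch_logits, labels))
    return correct / len(labels)


# Toy batch of 4 samples over 3 classes (made-up numbers, not model output)
toy_logits = [
    [2.0, 0.1, -1.0],  # predicts class 0
    [0.3, 1.5, 0.2],   # predicts class 1
    [0.0, 0.2, 3.1],   # predicts class 2
    [1.2, 0.9, 0.4],   # predicts class 0, but the label is 1
]
toy_labels = [0, 1, 2, 1]

print(f"top-1 accuracy: {top1_accuracy(toy_logits, toy_labels):.2%}")
```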

### Model Description

<!-- Provide a longer summary of what this model is. -->



- **Developed by:** verypro
- **Model type:** Vision Transformer
- **License:** MIT
- **Finetuned from model:** google/vit-base-patch16-224
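The base checkpoint name encodes its input geometry: `patch16` means each image is split into 16x16-pixel patches, and `224` means the model expects 224x224-pixel inputs. CIFAR-10 images are only 32x32, so, assuming the image processor keeps the base checkpoint's 224x224 resize setting, each image is upsampled before it reaches the model. A small sketch of the resulting token count, assuming the standard ViT layout of one [CLS] token prepended to the patch sequence (`vit_sequence_length` is a hypothetical helper):

```python
def vit_sequence_length(image_size: int, patch_size: int, num_cls_tokens: int = 1) -> int:
    """Hypothetical helper: tokens a ViT encoder sees for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + num_cls_tokens


CIFAR_SIZE = 32       # native CIFAR-10 resolution
VIT_INPUT_SIZE = 224  # input resolution of google/vit-base-patch16-224
PATCH_SIZE = 16       # patch size of google/vit-base-patch16-224

upscale = VIT_INPUT_SIZE / CIFAR_SIZE
num_patches = (VIT_INPUT_SIZE // PATCH_SIZE) ** 2
seq_len = vit_sequence_length(VIT_INPUT_SIZE, PATCH_SIZE)

print(f"upscale factor: {upscale:g}x")            # 7x
print(f"patches: {num_patches}")                  # 196
print(f"sequence length incl. [CLS]: {seq_len}")  # 197
```

The position embeddings of the base checkpoint are sized for this 197-token sequence.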

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

```python
from transformers import ViTImageProcessor, ViTForImageClassification
from torchvision import datasets

# Initialize the image processor and the fine-tuned model
image_processor = ViTImageProcessor.from_pretrained('verypro/vit-base-patch16-224-cifar10')
model = ViTForImageClassification.from_pretrained('verypro/vit-base-patch16-224-cifar10')

# Load the CIFAR10 test split
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True)

# Each sample is a (PIL image, integer label) pair
image, gt_label = test_dataset[0]

# Save the original image and print its ground-truth label
image.save("original.png")
print(f"Ground truth class: '{test_dataset.classes[gt_label]}'")

# Resize and normalize the image, then run a forward pass
inputs = image_processor(image, return_tensors="pt")
outputs = model(**inputs)

# Raw, unnormalized class scores with shape (1, 10)
logits = outputs.logits
print(logits)

predicted_class_idx = logits.argmax(-1).item()
predicted_class_label = test_dataset.classes[predicted_class_idx]
# The printed score is the raw logit of the top class, not a probability
print(f"Predicted class: '{predicted_class_label}', logit: {logits[0, predicted_class_idx]:.2f}")
```

The code above should print output similar to:

```text
Ground truth class: 'cat'
tensor([[-1.1497, -0.1080, -0.7349,  9.2517, -1.3094,  0.5403, -0.9521, -1.0223,
         -1.4102, -1.5389]], grad_fn=<AddmmBackward0>)
Predicted class: 'cat', logit: 9.25
```
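The score printed next to the predicted class is the raw logit of the top class, not a probability. To turn the ten logits into probabilities, apply a softmax over them. A minimal plain-Python sketch using the example logits (copied by hand from the output); with PyTorch tensors, `logits.softmax(-1)` performs the same computation:

```python
import math
from typing import List

# CIFAR-10 class names in torchvision's index order
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

# Logits copied from the example output
logits = [-1.1497, -0.1080, -0.7349, 9.2517, -1.3094,
          0.5403, -0.9521, -1.0223, -1.4102, -1.5389]


def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(logits)
top = max(range(len(probs)), key=lambda i: probs[i])
print(f"Predicted class: '{CLASSES[top]}', probability: {probs[top]:.4f}")
```

Softmax preserves the ranking of the logits, so the predicted class is unchanged; only the scale of the reported score differs.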