verypro commited on
Commit
2c0f741
·
1 Parent(s): 9d28229

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +58 -0
README.md ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+ # Model Card for Model ViT fine tuning on CiFAR10
5
+
6
+ <!-- Provide a quick summary of what the model is/does. -->
7
+
8
+ It's a toy experiemnt of fine tuning ViT by using huggingface transformers.
9
+
10
+ ## Model Details
11
+
12
+ It's fine tuned on CiFAR10 for 1000 steps, and achieved accuracy of 98.7% on test split.
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** verypro
21
+ - **Model type:** Vision Transformer
22
+ - **License:** MIT
23
+ - **Finetuned from model [optional]:** google/vit-base-patch16-224
24
+
25
+ ## Uses
26
+
27
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
28
+
29
+ ```
30
+ from transformers import ViTImageProcessor, ViTForImageClassification
31
+ from torchvision import datasets
32
+
33
+ # # 初始化模型和特征提取器
34
+ image_processor = ViTImageProcessor.from_pretrained('verypro/vit-base-patch16-224-cifar10')
35
+ model = ViTForImageClassification.from_pretrained('verypro/vit-base-patch16-224-cifar10')
36
+
37
+
38
+ # 加载 CIFAR10 数据集
39
+ test_dataset = datasets.CIFAR10(root='./data', train=False, download=True)
40
+
41
+ sample = test_dataset[0]
42
+ image = sample[0]
43
+ gt_label = sample[1]
44
+
45
+ # 保存原始图像,并打印其标签
46
+ image.save("original.png")
47
+ print(f"Ground truth class: '{test_dataset.classes[gt_label]}'")
48
+
49
+ inputs = image_processor(image, return_tensors="pt")
50
+ outputs = model(**inputs)
51
+
52
+ logits = outputs.logits
53
+ print(logits)
54
+
55
+ predicted_class_idx = logits.argmax(-1).item()
56
+ predicted_class_label = test_dataset.classes[predicted_class_idx]
57
+ print(f"Predicted class: '{predicted_class_label}', confidence: {logits[0, predicted_class_idx]:.2f}")
58
+ ```