mstftmk commited on
Commit
2f3d409
·
1 Parent(s): dc28ae7

Add all files

Browse files
Files changed (3) hide show
  1. README.md +16 -3
  2. app.py +30 -0
  3. requirements.txt +4 -0
README.md CHANGED
@@ -1,13 +1,26 @@
1
  ---
2
- title: ImageClassification With ViT
3
- emoji: 🦀
4
  colorFrom: yellow
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.11.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Image Classification with ViT
3
+ emoji: 🖼️
4
  colorFrom: yellow
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.11.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ tags:
12
+ - image-classification
13
+ - vision
14
+ - transformers
15
+ - vit
16
+ - deep-learning
17
+ - gradio
18
+ datasets:
19
+ - imagenet-1k
20
+ - cifar10
21
+ - cifar100
22
+ models:
23
+ - google/vit-base-patch16-224
24
  ---
25
 
26
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
3
+ from PIL import Image
4
+
5
+ # Load model and feature extractor
6
+ model_name = "google/vit-base-patch16-224"
7
+ model = AutoModelForImageClassification.from_pretrained(model_name)
8
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
9
+
10
+ # Define the prediction function
11
+ def classify_image(image):
12
+ inputs = feature_extractor(images=image, return_tensors="pt")
13
+ outputs = model(**inputs)
14
+ logits = outputs.logits
15
+ predicted_class_idx = logits.argmax(-1).item()
16
+ label = model.config.id2label[predicted_class_idx]
17
+ return f"Predicted Class: {label}"
18
+
19
+ # Create Gradio interface
20
+ interface = gr.Interface(
21
+ fn=classify_image,
22
+ inputs=gr.Image(type="pil"),
23
+ outputs="text",
24
+ title="Image Classification App",
25
+ description="Upload an image to classify it using the Vision Transformer model.",
26
+ )
27
+
28
+ # Launch the app
29
+ if __name__ == "__main__":
30
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ transformers
4
+ pillow