Daryl Lim commited on
Commit
9637988
Β·
1 Parent(s): e37a2ae

Add application file

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module provides an interface for classifying images using the ResNet-18 model.
3
+ The interface allows users to upload an image and receive the top 3 predicted labels.
4
+ """
5
+
6
+ import gradio as gr
7
+ import spaces
8
+ import torch
9
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
10
+ from datasets import load_dataset
11
+ from PIL import Image
12
+
13
+ # Load dataset and get test image
14
+ dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
15
+ test_image = dataset["test"]["image"][0]
16
+
17
+ # Initialize the image processor and model
18
+ image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-18")
19
+ model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-18").to("cuda")
20
+
21
+ @spaces.GPU
22
+ def predict(image: Image, top_k: int = 3) -> dict:
23
+ """
24
+ Predicts the top 'top_k' labels for an image using the ResNet-18 model.
25
+
26
+ Args:
27
+ image (Image): The input image as a PIL Image object.
28
+ top_k (int): The number of top predictions to return.
29
+
30
+ Returns:
31
+ dict: A dictionary with the top 'top_k' labels and their probabilities.
32
+ """
33
+ inputs = image_processor(image, return_tensors="pt").to("cuda")
34
+ with torch.no_grad():
35
+ logits = model(**inputs).logits
36
+
37
+ # Apply softmax to logits to get probabilities
38
+ probabilities = torch.softmax(logits, dim=-1)
39
+
40
+ # Get the top 'top_k' probabilities and their corresponding indices
41
+ top_k_probs, top_k_indices = torch.topk(input=probabilities, k=top_k, dim=-1)
42
+
43
+ # Map the indices to labels and probabilities
44
+ predicted_labels = {
45
+ model.config.id2label[idx.item()]: prob.item()
46
+ for idx, prob in zip(top_k_indices[0], top_k_probs[0])
47
+ }
48
+
49
+ return predicted_labels
50
+
51
+ # Define the Gradio interface
52
+ demo = gr.Interface(
53
+ fn=predict,
54
+ inputs=gr.Image(type="pil"),
55
+ outputs=gr.Label(num_top_classes=3),
56
+ title="Classifying Images with ResNet-18",
57
+ description="Upload an image to predict the top 3 labels.",
58
+ examples=[test_image]
59
+ )
60
+
61
+ # Launch the Gradio interface
62
+ demo.launch()