""" This module provides an interface for classifying images using the ResNet-18 model. The interface allows users to upload an image and receive the top 3 predicted labels. """ import gradio as gr import spaces import torch from transformers import AutoImageProcessor, AutoModelForImageClassification from datasets import load_dataset from PIL import Image # Load dataset and get test image dataset = load_dataset("huggingface/cats-image", trust_remote_code=True) test_image = dataset["test"]["image"][0] # Initialize the image processor and model image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-18") model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-18").to("cuda") @spaces.GPU def predict(image: Image, top_k: int = 3) -> dict: """ Predicts the top 'top_k' labels for an image using the ResNet-18 model. Args: image (Image): The input image as a PIL Image object. top_k (int): The number of top predictions to return. Returns: dict: A dictionary with the top 'top_k' labels and their probabilities. """ inputs = image_processor(image, return_tensors="pt").to("cuda") with torch.no_grad(): logits = model(**inputs).logits # Apply softmax to logits to get probabilities probabilities = torch.softmax(logits, dim=-1) # Get the top 'top_k' probabilities and their corresponding indices top_k_probs, top_k_indices = torch.topk(input=probabilities, k=top_k, dim=-1) # Map the indices to labels and probabilities predicted_labels = { model.config.id2label[idx.item()]: prob.item() for idx, prob in zip(top_k_indices[0], top_k_probs[0]) } return predicted_labels # Define the Gradio interface demo = gr.Interface( fn=predict, inputs=gr.Image(type="pil"), outputs=gr.Label(num_top_classes=3), title="Classifying Images with ResNet-18", description="Upload an image to predict the top 3 labels.", examples=[test_image] ) # Launch the Gradio interface demo.launch()