"""Gradio demo that classifies an uploaded image with a pretrained ViT model."""

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Load model and image processor once at import time so every request reuses them.
model_name = "google/vit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
# AutoImageProcessor supersedes the deprecated AutoFeatureExtractor for vision
# models; the module-level name `feature_extractor` is kept for compatibility.
feature_extractor = AutoImageProcessor.from_pretrained(model_name)


def classify_image(image):
    """Return the top-1 label the model predicts for *image*.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or ``None``
            when the user submits without uploading anything.

    Returns:
        A human-readable string naming the predicted class (looked up in
        ``model.config.id2label``), or a prompt to upload an image when
        *image* is ``None``.
    """
    if image is None:
        # Gradio calls the function with None when no image was provided;
        # the processor would raise on it.
        return "Please upload an image."

    # The processor normalises 3 channels; RGBA / grayscale / palette images
    # must be converted first or preprocessing fails or misbehaves.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Pure inference: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    predicted_class_idx = outputs.logits.argmax(-1).item()
    label = model.config.id2label[predicted_class_idx]
    return f"Predicted Class: {label}"


# Create Gradio interface
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Image Classification App",
    description="Upload an image to classify it using the Vision Transformer model.",
)

# Launch the app
if __name__ == "__main__":
    interface.launch()