"""Gradio demo that serves a fine-tuned ResNet-50 five-class image classifier."""

import torch
from torch import nn
import torchvision
import gradio as gr

# Define and load my resnet50 model.
# The stock classifier head is replaced so the architecture matches the
# state dict stored on disk (dropout followed by a 5-way linear layer).
model = torchvision.models.resnet50()
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    # Add dropout layer with 50% probability
    nn.Dropout(0.5),
    # Add a linear layer in order to deal with 5 classes
    nn.Linear(num_ftrs, 5),
)
# map_location pins the weights to the CPU so the demo also runs on hosts
# without a GPU.
model.load_state_dict(
    torch.load("model/final_model_state_dict.pth", map_location=torch.device("cpu"))
)
# Inference mode: disables dropout and uses the batch-norm running statistics.
model.eval()

# Define the labels; position i names output unit i of the classifier head.
labels = ["bird", "cat", "dog", "horse", "sheep"]

# Built once instead of on every request.
_to_tensor = torchvision.transforms.ToTensor()


# Define the predict function
def predict(inp):
    """Classify an image and return the confidence of every class.

    Args:
        inp: PIL image delivered by the Gradio ``Image`` input
            (configured with ``type="pil"``).

    Returns:
        A dict mapping each entry of ``labels`` to its softmax probability.
        ``gr.Label`` needs this mapping form to show ranked confidences;
        a bare label string would leave ``num_top_classes`` without effect.

    Raises:
        ValueError: If no image was supplied (``inp`` is ``None``).
    """
    if inp is None:
        raise ValueError("No image was provided to classify.")

    # Force three channels: grayscale, RGBA or palette images would otherwise
    # yield a tensor whose channel count the ResNet input layer rejects.
    batch = _to_tensor(inp.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    # Single-image batch: normalise its row of logits into probabilities.
    probabilities = torch.softmax(logits[0], dim=0)

    # Map prediction to label
    return {label: float(prob) for label, prob in zip(labels, probabilities)}


# Define the gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    examples=[
        ["input_imgs/bird.jpeg"],
        ["input_imgs/cat.jpeg"],
        ["input_imgs/dog.jpeg"],
        ["input_imgs/horse.jpeg"],
        ["input_imgs/sheep.jpeg"],
    ],
    title="Image Object Classifier",
    description="This is a demo of a resnet50 model trained on COCO dataset, which can classify 5 classes: bird, cat, dog, horse, sheep.",
)

if __name__ == "__main__":
    interface.launch()