import torch import torchvision.transforms as transforms from PIL import Image import gradio as gr from resnet_model import ResNet50 from utils import load_checkpoint import ast # Load the model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = ResNet50() model = torch.nn.DataParallel(model) model = model.to(device) # Load the checkpoint checkpoint_path = "checkpoint.pth" model, _, _, _ = load_checkpoint(model, None, checkpoint_path) model.eval() # Define the image transformation transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # Load class labels from the file with open("imagenet1000_clsidx_to_labels.txt") as f: class_labels = ast.literal_eval(f.read()) # Define the prediction function def predict(image): image = transform(image).unsqueeze(0).to(device) with torch.no_grad(): outputs = model(image) probabilities = torch.nn.functional.softmax(outputs, dim=1)[0] top5_prob, top5_catid = torch.topk(probabilities, 5) results = "