Spaces:
Runtime error
Runtime error
import gradio as gr | |
import timm | |
import torch | |
import torchvision.transforms as transforms | |
# Build the Swin-L (patch 4, window 7, 224 px) architecture with a 101-way classifier head,
# without timm's own pretrained weights, then fill it from the bundled model.pth state dict.
inference_model = timm.create_model(
    "swin_large_patch4_window7_224", pretrained=False, num_classes=101
)
# map_location="cpu": if model.pth was saved from CUDA tensors, plain torch.load fails on a
# CPU-only host ("Attempting to deserialize object on a CUDA device ..."). The freshly created
# model lives on CPU, and load_state_dict copies into its parameters, so this changes nothing
# when a GPU is present.
# NOTE(review): torch.load unpickles the file. That is fine for this bundled artifact, but never
# point it at user-supplied paths; on torch>=1.13 consider weights_only=True.
inference_model.load_state_dict(torch.load("model.pth", map_location="cpu"))
# Put layers with train-only behaviour (e.g. dropout / drop-path) into inference mode.
inference_model.eval()
# Class names, one per line: line i of labels.txt is the name reported for model output index i.
# Decode explicitly as UTF-8 so the result does not depend on the host's locale encoding.
with open("labels.txt", encoding="utf-8") as f:
    idx_to_class = [line.strip() for line in f]
# Per-channel normalisation statistics (the standard ImageNet mean/std). Presumably these
# match how model.pth was trained — confirm if the weights came from a custom pipeline.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Eval-time image pipeline:
#   1. scale the shorter side to 256 px;
#   2. take the central 224x224 crop (the input size named in the Swin model id);
#   3. convert to a CHW float tensor;
#   4. normalise each channel with the statistics above.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
def inference(input_image, top_k=5):
    """Classify a single image and return its ``top_k`` most probable labels.

    Args:
        input_image: PIL image delivered by ``gr.Image(type="pil")``, or ``None`` when the
            user submits without providing an image.
        top_k: Number of classes to report. Clamped to the number of model outputs.
            Defaults to 5 to match ``gr.Label(num_top_classes=5)``.

    Returns:
        Dict mapping class name -> softmax probability (Python float), in descending
        probability order, as consumed by ``gr.Label``. Returns ``None`` when no image
        was given, instead of raising inside ``preprocess``.
    """
    if input_image is None:
        return None

    # preprocess yields a single CHW tensor; add the leading batch dimension.
    input_batch = preprocess(input_image).unsqueeze(0)
    if torch.cuda.is_available():
        # Module.to() works in place and leaves tensors already on the device untouched,
        # so the host->GPU transfer of the weights only happens on the first call.
        inference_model.to("cuda")
        input_batch = input_batch.to("cuda")

    with torch.inference_mode():
        logits = inference_model(input_batch)

    # Batch of one: take row 0 and convert logits to class probabilities.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    k = min(top_k, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)

    # .tolist() copies to host-side Python floats/ints, so no explicit .cpu()/.item() needed.
    return {
        idx_to_class[idx]: prob
        for prob, idx in zip(top_prob.tolist(), top_idx.tolist())
    }
# Web UI: one PIL image in, a ranked label/probability panel (top 5) out.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=5)
iface = gr.Interface(fn=inference, inputs=_image_input, outputs=_label_output)
# Starts the Gradio server when this module runs.
iface.launch()