Spaces:
Runtime error
Runtime error
File size: 1,337 Bytes
428957e 0a2ae0a 428957e 0a2ae0a 428957e 0a2ae0a 428957e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import gradio as gr
import timm
import torch
import torchvision.transforms as transforms
# Build the Swin-Large classifier with a 101-way head and load the fine-tuned
# weights from the local checkpoint. pretrained=False because load_state_dict
# (strict by default) overwrites every parameter anyway.
inference_model = timm.create_model('swin_large_patch4_window7_224', pretrained=False, num_classes=101)
# map_location='cpu': if the checkpoint was saved from GPU tensors, a plain
# torch.load raises on CPU-only hardware. Loading to CPU first always works.
inference_model.load_state_dict(torch.load('model.pth', map_location='cpu'))
# Move to the GPU once at startup (when one exists) rather than per request.
if torch.cuda.is_available():
    inference_model.to('cuda')
inference_model.eval()
# Class names, one per line. List position i is the name reported for model
# output index i (the top-k lookup in `inference` indexes this list).
# Explicit UTF-8 so decoding does not depend on the host locale.
with open('labels.txt', 'r', encoding='utf-8') as f:
    idx_to_class = [line.strip() for line in f]
# Evaluation-time pipeline applied to each PIL image before inference:
# resize, center-crop to the 224x224 input size named in the model id,
# convert to a float tensor, then normalize each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        # NOTE(review): these are the widely used ImageNet channel statistics;
        # presumably they match how model.pth was trained — confirm.
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
def inference(input_image):
    """Classify one image and return its most probable labels.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        dict mapping class name -> softmax probability for the (up to) five
        highest-scoring classes; an empty dict when no image was given.
    """
    if input_image is None:
        # Nothing to classify; avoid crashing inside the transform pipeline.
        return {}
    # The Normalize step uses 3-channel statistics, so RGBA / grayscale /
    # palette images would fail there unless converted first.
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Compare device *types* ('cuda' vs 'cuda:0' differ as torch.device
    # objects) and only move the model when it is not already there.
    if next(inference_model.parameters()).device.type != device.type:
        inference_model.to(device)

    # Add a leading batch dimension of size 1.
    input_batch = preprocess(input_image).unsqueeze(0).to(device)
    with torch.inference_mode():
        output = inference_model(input_batch)
    # output[0] is the logit vector of the single image in the batch.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    k = min(5, probabilities.numel())  # never ask topk for more than exist
    top_prob, top_catid = torch.topk(probabilities, k)
    # Label: probability (plain Python floats, as gr.Label expects)
    return {
        idx_to_class[idx]: prob
        for prob, idx in zip(top_prob.tolist(), top_catid.tolist())
    }
# Single-image web UI: a PIL image in, the five most probable labels out.
iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
)
# Start the Gradio server (runs at module level, as in the original).
iface.launch()