Spaces:
Runtime error
Runtime error
File size: 2,131 Bytes
428957e 0a2ae0a 428957e 0a2ae0a 458d932 0a2ae0a 428957e 0a2ae0a bfb4d20 0a2ae0a 8a62832 428957e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import gradio as gr
import timm
import torch
import torchvision.transforms as transforms
# Swin-L (patch 4, window 7, 224px input) with a freshly sized classification
# head; the checkpoint below provides all weights, so no pretrained download.
NUM_CLASSES = 101
inference_model = timm.create_model(
    'swin_large_patch4_window7_224', pretrained=False, num_classes=NUM_CLASSES
)
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. On torch >= 1.13, weights_only=True restricts it to tensors.
# Loaded onto the CPU first so the app also starts on machines without a GPU.
inference_model.load_state_dict(
    torch.load('model.pth', map_location=torch.device('cpu'))
)
inference_model.eval()
# Move the model to the GPU once at startup rather than on the first request.
if torch.cuda.is_available():
    inference_model.to('cuda')

# One class name per line; line i names output index i of the model.
with open('labels.txt', 'r', encoding='utf-8') as f:
    idx_to_class = [s.strip() for s in f.readlines()]

# Resize the short side to 256, center-crop to the model's 224x224 input,
# convert to a float tensor, and normalize with the standard ImageNet
# per-channel mean/std values.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def inference(input_image):
    """Classify an image and return the most likely labels with probabilities.

    Args:
        input_image: a PIL image (the Gradio input is declared with
            ``type="pil"``), or ``None`` if no image was provided.

    Returns:
        dict mapping label name -> softmax probability for the top-5 classes
        (fewer if the model has fewer classes), as consumed by ``gr.Label``.
        An empty dict when ``input_image`` is ``None``.
    """
    if input_image is None:
        return {}
    # `preprocess` normalizes exactly 3 channels; RGBA PNGs, grayscale or
    # palette images would otherwise fail or be misprocessed.
    input_image = input_image.convert('RGB')
    input_batch = preprocess(input_image).unsqueeze(0)  # add batch dimension
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        # Ensures model and input share a device; no-op once already moved.
        inference_model.to('cuda')
    with torch.inference_mode():
        output = inference_model(input_batch)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    k = min(5, probabilities.numel())
    top_prob, top_catid = torch.topk(probabilities, k)
    # Label -> probability
    return {
        idx_to_class[int(idx)]: prob.item()
        for prob, idx in zip(top_prob.cpu(), top_catid.cpu())
    }
# Text shown around the demo UI.
title = "See Food 101"
description = "Gradio demo for See Food 101, the expansion edition of See Food from Silicon Valley. Simply upload your image, or click on the example(s) to load them. Read more at the links below for architecture used."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2103.14030'>Swin Transformer: Hierarchical Vision Transformer using Shifted Windows</a> | <a href='https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/'>Data</a></p>"
# Each example is one list of inputs; here a single image file path.
examples = [
    ['Screenshot 2023-05-05 085533.png']
]
# Image in (as PIL, matching what `inference` expects), top-5 label chart out.
iface = gr.Interface(fn=inference,
                     inputs=gr.Image(type="pil"),
                     outputs=gr.Label(num_top_classes=5),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples,
                     analytics_enabled=False)
iface.launch()