SeeFood101v1 / app.py
HangenYuu
added first version of app.py
0a2ae0a
raw
history blame
1.34 kB
import gradio as gr
import timm
import torch
import torchvision.transforms as transforms
# Build the Swin-L (patch 4, window 7, 224 px input) architecture with a
# 101-way classification head. pretrained=False skips downloading weights
# because the state dict is loaded from model.pth right after.
inference_model = timm.create_model(
    "swin_large_patch4_window7_224", pretrained=False, num_classes=101
)
# map_location="cpu" lets a checkpoint saved from a CUDA session load on a
# CPU-only host (without it torch.load restores tensors onto the device they
# were saved from, which fails when no GPU is present). The model can still be
# moved to a GPU afterwards.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints;
# consider weights_only=True if the installed torch version supports it.
inference_model.load_state_dict(torch.load("model.pth", map_location="cpu"))
# Evaluation mode: disables training-only behaviour such as dropout.
inference_model.eval()
# Class names, one per line; the name on line i is the label for model output
# index i. Read as UTF-8 explicitly instead of relying on the platform locale,
# and skip blank lines so a stray empty line cannot become a bogus "" class.
with open("labels.txt", "r", encoding="utf-8") as f:
    idx_to_class = [line.strip() for line in f if line.strip()]
# Evaluation-time transform. It resizes the shorter side to 256 px, takes the
# central 224x224 patch (the input size named in the model id), and converts
# to a float tensor in [0, 1]. It then standardizes each RGB channel with the
# usual ImageNet mean / std.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
def inference(input_image):
    """Classify one image and return its most likely labels.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the form is submitted without an image.

    Returns:
        dict mapping label name -> softmax probability (float) for the top
        (at most 5) predictions, or ``None`` when no image was provided.
    """
    if input_image is None:
        # Preprocessing would raise on None; return an empty result instead.
        return None
    # RGBA, grayscale or palette images would give ToTensor() a channel count
    # that does not match the 3-channel Normalize() statistics.
    if getattr(input_image, "mode", "RGB") != "RGB":
        input_image = input_image.convert("RGB")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inference_model.to(device)  # cheap when the model is already on `device`
    # unsqueeze(0) adds the batch dimension the model expects.
    input_batch = preprocess(input_image).unsqueeze(0).to(device)

    with torch.inference_mode():
        logits = inference_model(input_batch)
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)
        # topk raises if k exceeds the number of classes, so clamp it.
        k = min(5, probabilities.numel())
        top_prob, top_idx = torch.topk(probabilities, k)

    # Label: probability, highest first (topk returns sorted values).
    return {
        idx_to_class[idx]: prob
        for prob, idx in zip(top_prob.tolist(), top_idx.tolist())
    }
# Web UI: one PIL image upload feeds inference(), and the returned
# {label: probability} dict is rendered as the five highest-scoring classes.
iface = gr.Interface(
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    fn=inference,
)
# Serve the UI.
iface.launch()