File size: 2,131 Bytes
428957e
0a2ae0a
 
 
428957e
0a2ae0a
458d932
0a2ae0a
428957e
0a2ae0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfb4d20
 
 
 
 
 
 
 
 
0a2ae0a
 
8a62832
 
 
 
 
 
428957e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr
import timm
import torch
import torchvision.transforms as transforms

# Build the Swin-L (patch 4, window 7, 224px input) architecture with a
# 101-way classification head, then restore trained weights from model.pth.
# NOTE(review): torch.load unpickles the file. That is acceptable for the
# checkpoint bundled with this app, but never point it at untrusted files
# (on torch>=1.13, weights_only=True restricts loading to plain tensors).
inference_model = timm.create_model(
    'swin_large_patch4_window7_224',
    pretrained=False,  # weights come from model.pth below, not from timm
    num_classes=101,
)
inference_model.load_state_dict(
    # Map every tensor to CPU so start-up works on machines without a GPU.
    torch.load('model.pth', map_location='cpu')
)
# Inference only: switch training-time layers (e.g. dropout) to eval behaviour.
inference_model.eval()

# Class names, one per line; entry i is the label for model output index i.
# Decode explicitly as UTF-8 so the result doesn't depend on the host
# locale's default encoding (e.g. cp1252 on Windows).
with open('labels.txt', encoding='utf-8') as f:
    idx_to_class = [line.strip() for line in f]

# Per-channel RGB mean/std used for normalization. These are the widely used
# ImageNet statistics; presumably what model.pth was trained with (confirm).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize the shorter side to 256, center-crop to the model's 224x224 input,
# convert the PIL image to a float tensor in [0, 1], then normalize per channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

def inference(input_image):
    """Classify a food photo and return its top predicted labels.

    Args:
        input_image: PIL image supplied by the ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits without uploading anything.

    Returns:
        dict mapping label name (from ``idx_to_class``) to softmax probability
        for the up-to-5 highest-scoring classes, in descending probability order.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a generic error.
    """
    if input_image is None:
        raise gr.Error("Please upload an image to classify.")

    # Normalize() in `preprocess` expects exactly 3 channels; PNGs (like the
    # bundled screenshot example) can be RGBA or palette-based.
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')

    # Add a leading batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
    input_batch = preprocess(input_image).unsqueeze(0)

    if torch.cuda.is_available():
        # Effectively a no-op after the first call, once the weights are on the GPU.
        inference_model.to('cuda')
        input_batch = input_batch.to('cuda')

    with torch.inference_mode():
        logits = inference_model(input_batch)[0]
        probabilities = torch.nn.functional.softmax(logits, dim=0)
        # topk raises if k exceeds the number of classes, so clamp it.
        top_prob, top_catid = torch.topk(probabilities, min(5, probabilities.numel()))

    # Label -> probability; tolist() copies to CPU and yields plain Python numbers.
    return {
        idx_to_class[idx]: prob
        for prob, idx in zip(top_prob.tolist(), top_catid.tolist())
    }


# Static text shown around the demo: header, intro paragraph, and footer links.
title = "See Food 101"
description = "Gradio demo for See Food 101, the expansion edition of See Food from Silicon Valley. Simply upload your image, or click on the example(s) to load them. Read more at the links below for architecture used."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2103.14030'>Swin Transformer: Hierarchical Vision Transformer using Shifted Windows</a> | <a href='https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/'>Data</a></p>"

# One example row per clickable sample; each row holds one value per input.
examples = [['Screenshot 2023-05-05 085533.png']]

# Image in (as a PIL image), top-5 label/probability pairs out.
iface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title=title,
    description=description,
    article=article,
    examples=examples,
    analytics_enabled=False,
)
iface.launch()