File size: 1,337 Bytes
428957e
0a2ae0a
 
 
428957e
0a2ae0a
 
 
428957e
0a2ae0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428957e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gradio as gr
import timm
import torch
import torchvision.transforms as transforms

# Swin-L backbone with a 101-way classification head; the weights come from
# the local checkpoint below rather than timm's pretrained download.
inference_model = timm.create_model('swin_large_patch4_window7_224', pretrained=False, num_classes=101)
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only host; inference() moves the model to CUDA when one is available.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
inference_model.load_state_dict(torch.load('model.pth', map_location='cpu'))
inference_model.eval()

# One label per line; line i names class index i of the model output.
with open('labels.txt', 'r', encoding='utf-8') as f:
    idx_to_class = [line.strip() for line in f]

# Resize the short side to 256, center-crop to the model's 224x224 input,
# convert to a tensor and normalize with the ImageNet channel mean/std.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def inference(input_image):
    """Classify an image and return its most likely labels.

    Args:
        input_image: A PIL image from the Gradio ``Image(type="pil")`` input,
            or ``None`` when the user submits without providing an image.

    Returns:
        A ``{label: probability}`` dict with up to five entries, suitable for
        ``gr.Label``; an empty dict when no image was given.
    """
    if input_image is None:
        return {}

    # Normalize below uses three-channel statistics, so make sure the image
    # is RGB even if a non-RGB image reaches this function directly.
    input_batch = preprocess(input_image.convert('RGB')).unsqueeze(0)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Module.to() does not copy weights that are already on the target
    # device, so repeating it on every call is cheap.
    inference_model.to(device)
    input_batch = input_batch.to(device)

    with torch.inference_mode():
        output = inference_model(input_batch)

    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # Never ask topk for more classes than the model produces.
    k = min(5, probabilities.numel())
    top_prob, top_catid = torch.topk(probabilities, k)
    # Label -> probability.
    return {
        idx_to_class[int(idx)]: prob.item()
        for prob, idx in zip(top_prob.cpu(), top_catid.cpu())
    }
# Gradio front end: a PIL image input feeding inference(), rendered as a
# label widget listing the five highest-scoring classes.
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=5)
iface = gr.Interface(fn=inference, inputs=image_input, outputs=label_output)
iface.launch()