File size: 4,284 Bytes
699342a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import gradio as gr
import numpy as np
import torch
from typing import Tuple, Optional, Dict, List
import glob
from collections import defaultdict

from transformers import (AutoImageProcessor,
                          ResNetForImageClassification)

from labelmap import DR_LABELMAP


class App:
    """Gradio demo that predicts a retinopathy level for a retina scan.

    The constructor loads a ResNet image classifier and its image processor
    from ``release_ckpts/<ckpt_name>/inference/``. It also collects example
    images from ``demo_data/train/<class_id>/*.jpeg`` and builds the Blocks UI.
    Call :meth:`launch` to start serving it.
    """

    def __init__(self) -> None:
        ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34/"

        path = f"release_ckpts/{ckpt_name}/inference/"

        self.image_processor = AutoImageProcessor.from_pretrained(path)

        self.model = ResNetForImageClassification.from_pretrained(path)

        # Actually run on the GPU when one is available. Previously the UI
        # reported "GPU" while the model silently stayed on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        device_name = 'GPU' if self.device.type == "cuda" else 'CPU'

        example_lists = self._load_example_lists()

        css = ".output-image, .input-image, .image-preview {height: 600px !important}"

        with gr.Blocks(css=css) as ui:
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Row():
                        predict_btn = gr.Button("Predict", size="lg")
                    with gr.Row():
                        gr.Markdown(f"Running on {device_name}")
                with gr.Column(scale=4):
                    output = gr.Label(num_top_classes=len(DR_LABELMAP),
                                      label="Retinopathy level prediction")
                with gr.Column(scale=4):
                    gr.Markdown("![](https://media.githubusercontent.com/media/Obs01ete/retinopathy/master/media/logo1.png)")
            with gr.Row():
                with gr.Column(scale=9, min_width=100):
                    image = gr.Image(label="Retina scan")
                with gr.Column(scale=1, min_width=150):
                    # Iterate over the class ids actually found on disk, in
                    # ascending order. The former range(len(...)) assumed the
                    # ids were exactly 0..K-1. With a gap, it created empty
                    # tabs (defaultdict) and dropped the highest classes.
                    for cls_id in sorted(example_lists):
                        label = DR_LABELMAP[cls_id]
                        with gr.Tab(f"{cls_id} : {label}"):
                            gr.Examples(
                                example_lists[cls_id],
                                inputs=[image],
                                outputs=[output],
                                fn=self.predict,
                                examples_per_page=10,
                                run_on_click=True)

            predict_btn.click(
                fn=self.predict,
                inputs=image,
                outputs=output,
                api_name="predict")

        self.ui = ui

    def launch(self, share: bool = True) -> None:
        """Start the Gradio server with request queuing enabled.

        Args:
            share: Forwarded to ``launch``. The default ``True`` matches the
                previous hard-coded behavior.
        """
        self.ui.queue().launch(share=share)

    def predict(self, image: Optional[np.ndarray]) -> Dict[str, float]:
        """Gradio callback: map each class to its predicted probability.

        Args:
            image: The array delivered by the ``gr.Image`` component, or
                ``None`` when no image has been provided.

        Returns:
            A ``{"<id> - <label>": probability}`` dict for ``gr.Label``.
            It is empty when ``image`` is ``None``.
        """
        if image is None:
            return dict()
        cls_name, prob, probs = self._infer(image)
        message = f"Predicted class={cls_name}, prob={prob:.3f}"
        print(message)
        return {f"{i} - {DR_LABELMAP[i]}": float(v)
                for i, v in enumerate(probs)}

    def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
        """Classify a single image.

        Args:
            image_chw: Image array handed straight to the image processor.
                NOTE(review): despite the name, with gr.Image's numpy output
                this is presumably (H, W, C). The processor is relied on to
                handle the layout; confirm.

        Returns:
            ``(class name from model.config.id2label, its probability,
            softmax probabilities for all classes as a CPU numpy array)``.
        """
        assert isinstance(self.model, ResNetForImageClassification)

        # Put the pixel tensors on the same device as the model (CPU or GPU).
        inputs = self.image_processor(image_chw, return_tensors="pt")
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            output = self.model(**inputs)

        logits_batch = output.logits
        # One image in, so the logits must be (1, num_classes).
        assert len(logits_batch.shape) == 2
        assert logits_batch.shape[0] == 1
        # Move to CPU so .numpy() works even when inference ran on the GPU.
        probs = torch.softmax(logits_batch[0], dim=-1).cpu()
        predicted_label = int(probs.argmax(-1).item())
        prob = probs[predicted_label].item()
        cls_name = self.model.config.id2label[predicted_label]
        return cls_name, prob, probs.numpy()

    @staticmethod
    def _load_example_lists() -> Dict[int, List[str]]:
        """Group demo images by the class id encoded in their parent directory.

        Files matching ``demo_data/train/<subdir>/*.jpeg`` are collected
        relative to the current working directory. A file whose ``<subdir>``
        does not parse as an ``int`` is reported and skipped. Paths are sorted
        because ``glob`` returns them in filesystem-dependent order, and the
        example gallery should be deterministic.

        Returns:
            A ``defaultdict(list)`` from class id to its image paths. It
            contains only ids for which at least one image was found.
        """
        # Non-recursive glob: a "**" here would behave exactly like "*".
        example_flat_list = sorted(glob.glob("demo_data/train/*/*.jpeg"))

        example_lists: Dict[int, List[str]] = defaultdict(list)
        for path in example_flat_list:
            subdir = os.path.basename(os.path.dirname(path))
            try:
                cls_id = int(subdir)
            except ValueError:
                print(f"Cannot parse path {path}")
                continue
            example_lists[cls_id].append(path)
        return example_lists


def main():
    """Construct the retinopathy demo and serve it until interrupted."""
    App().launch()


if __name__ == "__main__":
    main()