File size: 4,284 Bytes
699342a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import os
import gradio as gr
import numpy as np
import torch
from typing import Tuple, Optional, Dict, List
import glob
from collections import defaultdict
from transformers import (AutoImageProcessor,
ResNetForImageClassification)
from labelmap import DR_LABELMAP
class App:
    """Gradio demo that predicts retinopathy level from a retina scan image.

    Loads a ResNet image classifier and its image processor from a release
    checkpoint directory, builds a Blocks UI (a "Predict" button, a
    probability label, and per-class tabs of clickable example images),
    and serves it via :meth:`launch`.
    """

    def __init__(self) -> None:
        ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34"
        # os.path.join avoids the doubled "//" that the previous
        # trailing-slash checkpoint name produced inside an f-string.
        path = os.path.join("release_ckpts", ckpt_name, "inference")
        self.image_processor = AutoImageProcessor.from_pretrained(path)
        self.model = ResNetForImageClassification.from_pretrained(path)
        use_cuda = torch.cuda.is_available()
        # Actually place the model on the accelerator the UI reports;
        # previously "GPU" was displayed while inference always ran on CPU.
        self.model.to(torch.device("cuda" if use_cuda else "cpu"))
        device_name = "GPU" if use_cuda else "CPU"
        example_lists = self._load_example_lists()
        self.ui = self._build_ui(example_lists, device_name)

    def _build_ui(self, example_lists: Dict[int, List[str]],
                  device_name: str):
        """Assemble the Gradio Blocks interface and wire its callbacks.

        Args:
            example_lists: Example image paths grouped by class id; one tab
                of clickable examples is created for each class id present.
            device_name: Human-readable device label shown in the UI.

        Returns:
            The constructed ``gr.Blocks`` object (not yet launched).
        """
        css = ".output-image, .input-image, .image-preview {height: 600px !important}"
        with gr.Blocks(css=css) as ui:
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Row():
                        predict_btn = gr.Button("Predict", size="lg")
                    with gr.Row():
                        gr.Markdown(f"Running on {device_name}")
                with gr.Column(scale=4):
                    output = gr.Label(num_top_classes=len(DR_LABELMAP),
                                      label="Retinopathy level prediction")
                with gr.Column(scale=4):
                    gr.Markdown("![](https://media.githubusercontent.com/media/Obs01ete/retinopathy/master/media/logo1.png)")
            with gr.Row():
                with gr.Column(scale=9, min_width=100):
                    image = gr.Image(label="Retina scan")
                with gr.Column(scale=1, min_width=150):
                    # Iterate the class ids that actually have examples:
                    # range(len(...)) dropped the highest ids and created an
                    # empty tab whenever some class directory was missing.
                    for cls_id, paths in sorted(example_lists.items()):
                        label = DR_LABELMAP[cls_id]
                        with gr.Tab(f"{cls_id} : {label}"):
                            gr.Examples(
                                paths,
                                inputs=[image],
                                outputs=[output],
                                fn=self.predict,
                                examples_per_page=10,
                                run_on_click=True)
            # Event listeners must be registered inside the Blocks context.
            predict_btn.click(
                fn=self.predict,
                inputs=image,
                outputs=output,
                api_name="predict")
        return ui

    def launch(self) -> None:
        """Start serving the UI with request queueing and a public share link."""
        self.ui.queue().launch(share=True)

    def predict(self, image: Optional[np.ndarray]) -> Dict[str, float]:
        """Classify one image and return per-class probabilities for gr.Label.

        Args:
            image: Image array from the ``gr.Image`` component, or ``None``
                when the component is empty.

        Returns:
            Mapping ``"<id> - <label name>"`` -> probability; an empty dict
            when no image was provided.
        """
        if image is None:
            return {}
        cls_name, prob, probs = self._infer(image)
        message = f"Predicted class={cls_name}, prob={prob:.3f}"
        print(message)
        return {f"{i} - {DR_LABELMAP[i]}": float(v)
                for i, v in enumerate(probs)}

    def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
        """Run the classifier on a single image.

        Args:
            image_chw: Image array passed straight to the image processor.
                NOTE(review): despite the name, gr.Image with its default
                settings supplies an HWC array; the processor is relied on
                to handle the layout — confirm.

        Returns:
            Tuple of (predicted class name from ``model.config.id2label``,
            its softmax probability, full probability vector as NumPy).
        """
        assert isinstance(self.model, ResNetForImageClassification)
        inputs = self.image_processor(image_chw, return_tensors="pt")
        # Inputs must live on the same device as the model weights.
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            output = self.model(**inputs)
        logits_batch = output.logits
        # Exactly one image in, so logits are (1, num_classes).
        assert len(logits_batch.shape) == 2
        assert logits_batch.shape[0] == 1
        probs = torch.softmax(logits_batch[0], dim=-1)
        predicted_label = int(probs.argmax(-1).item())
        prob = probs[predicted_label].item()
        cls_name = self.model.config.id2label[predicted_label]
        # .cpu() is required before .numpy() when the model runs on CUDA.
        return cls_name, prob, probs.cpu().numpy()

    @staticmethod
    def _load_example_lists() -> Dict[int, List[str]]:
        """Group demo images by the integer class id of their parent directory.

        Matches ``demo_data/train/**/*.jpeg``; without ``recursive=True`` the
        ``**`` spans exactly one directory level. Files whose parent directory
        name is not an integer are reported and skipped.

        Returns:
            defaultdict mapping class id -> sorted list of image paths.
        """
        example_lists: Dict[int, List[str]] = defaultdict(list)
        # glob order is filesystem-dependent; sort for a stable example order.
        for path in sorted(glob.glob("demo_data/train/**/*.jpeg")):
            parent_name = os.path.basename(os.path.dirname(path))
            try:
                cls_id = int(parent_name)
            except ValueError:
                print(f"Cannot parse path {path}")
                continue
            example_lists[cls_id].append(path)
        return example_lists
def main():
    """Entry point: construct the demo application and launch its web UI."""
    application = App()
    application.launch()


if __name__ == "__main__":
    main()
|