vicellst-att / app.py
polejowska's picture
Update app.py
f5a0872
raw
history blame
5.11 kB
import pathlib
from constants import MODELS_REPO, MODELS_NAMES
import gradio as gr
import torch
from transformers import (AutoFeatureExtractor, DetrForObjectDetection,)
from visualization import visualize_attention_map, visualize_prediction
from style import css, description, title
from PIL import Image
def make_prediction(img, feature_extractor, model):
    """Run the detector on one image and return detections and attention maps.

    Args:
        img: Image object exposing ``size``; the UI supplies a PIL image,
            whose ``size`` is ``(width, height)``.
        feature_extractor: Callable preprocessor that also provides
            ``post_process(outputs, target_sizes)``.
        model: Callable detection model that accepts the preprocessor's
            tensors as keyword arguments plus ``output_attentions``.

    Returns:
        A 4-tuple ``(detections, decoder_attentions, encoder_attentions,
        cross_attentions)`` where ``detections`` is the first (and only)
        entry of the post-processed batch.
    """
    inputs = feature_extractor(img, return_tensors="pt")

    # Inference only: no autograd graph is needed, and the attention tensors
    # are requested explicitly instead of relying on the checkpoint's config.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # ``img.size`` is reversed before being handed over as the target size
    # (presumably (height, width) is what ``post_process`` expects).
    target_sizes = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, target_sizes)

    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
        outputs["cross_attentions"],
    )
# Loaded (feature_extractor, model, details) triples keyed by model name, so
# repeated clicks reuse the weights instead of reloading them every request.
_DETECTOR_CACHE = {}


def _load_detector(model_name):
    """Return ``(feature_extractor, model, details)`` for *model_name*, loading it once."""
    if model_name not in _DETECTOR_CACHE:
        if "DETR" not in model_name:
            # Only DETR checkpoints are handled; fail with a clear message
            # rather than an unbound-variable error further down.
            raise ValueError(f"Unsupported model: {model_name!r}")
        repo_id = MODELS_REPO[model_name]
        _DETECTOR_CACHE[model_name] = (
            AutoFeatureExtractor.from_pretrained(repo_id),
            DetrForObjectDetection.from_pretrained(repo_id),
            "DETR details",
        )
    return _DETECTOR_CACHE[model_name]


def _open_mask(image_path):
    """Return the mask image stored next to *image_path*, or ``None`` if there is none."""
    if not image_path:
        # Uploaded images have no known location on disk, hence no mask.
        return None
    # The UI passes the path as a string; normalise it before using Path attributes.
    image_path = pathlib.Path(image_path)
    mask_path = image_path.parent / (
        image_path.stem.replace("_HE", "_mask") + image_path.suffix
    )
    if not mask_path.is_file():
        return None
    return Image.open(mask_path)


def detect_objects(model_name, image_input, image_path, threshold, display_mask=False):
    """Detect objects in an image and render the predictions and attention maps.

    Args:
        model_name: Key into ``MODELS_REPO``; must contain ``"DETR"``.
        image_input: Image to analyse (the UI supplies a PIL image).
        image_path: Path of the image on disk (``str`` or ``pathlib.Path``),
            or a falsy value when the image has no known location.
        threshold: Score threshold forwarded to ``visualize_prediction``.
        display_mask: Whether to overlay the mask stored beside the image
            (same file name with ``_HE`` replaced by ``_mask``).

    Returns:
        ``(detections_image, decoder_attention_image, encoder_attention_image,
        cross_attention_image, model_details)``.

    Raises:
        ValueError: If no image is given or the model name is not supported.
    """
    if image_input is None:
        raise ValueError("No image provided")

    feature_extractor, model, model_details = _load_detector(model_name)

    (
        processed_outputs,
        decoder_attention_map,
        encoder_attention_map,
        cross_attention_map,
    ) = make_prediction(image_input, feature_extractor, model)

    # Fall back to a plain rendering when the mask was requested but cannot
    # be found, instead of failing the whole request.
    mask_pil_image = _open_mask(image_path) if display_mask else None
    show_mask = mask_pil_image is not None

    viz_img = visualize_prediction(
        image_input,
        processed_outputs,
        threshold,
        model.config.id2label,
        show_mask,
        mask_pil_image,
    )
    decoder_attention_map_img = visualize_attention_map(
        image_input, decoder_attention_map
    )
    encoder_attention_map_img = visualize_attention_map(
        image_input, encoder_attention_map
    )
    cross_attention_map_img = visualize_attention_map(image_input, cross_attention_map)

    return (
        viz_img,
        decoder_attention_map_img,
        encoder_attention_map_img,
        cross_attention_map_img,
        model_details,
    )
def set_example_image(example: list):
    """Handle a click on an example sample.

    The first entry of ``example`` is used twice: as the new value of the
    image widget (wrapped in an update object) and as the second return value.
    """
    selected = example[0]
    image_update = gr.Image.update(value=selected)
    return image_update, selected
# Build the UI: one tab for upload/detection, one for the attention maps and
# one for the model details. Event handlers are wired inside the Blocks context.
with gr.Blocks(css=css) as app:
    gr.Markdown(title)
    gr.Markdown(description)

    # Invisible per-session slot holding the path of the example image that
    # was last clicked; it is written by set_example_image and read by
    # detect_objects. It was previously referenced without being defined.
    img_path = gr.State()

    with gr.Tabs():
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type="pil")
                with gr.Column():
                    options = gr.Dropdown(
                        value=MODELS_NAMES[0],
                        choices=MODELS_NAMES,
                        label="Select an object detection model",
                        show_label=True,
                    )
            with gr.Row():
                with gr.Column():
                    slider_input = gr.Slider(
                        minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
                    )
                with gr.Column():
                    # ``value`` is the initial state; ``default`` is an
                    # outdated keyword that is ignored with a warning.
                    display_mask = gr.Checkbox(label="Display masks", value=False)
            detect_button = gr.Button("Detect leukocytes")
            with gr.Row():
                # One sample per *_HE.png file found under the test image folder.
                example_images = gr.Dataset(
                    components=[img_input],
                    samples=[
                        [path.as_posix()]
                        for path in sorted(
                            pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
                        )
                    ],
                )
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """The selected image with detected bounding boxes by the model"""
                    )
                    img_output_from_upload = gr.Image(shape=(850, 850))

        with gr.TabItem("Attentions visualization"):
            gr.Markdown("""Encoder attentions""")
            with gr.Row():
                encoder_att_map_output = gr.Image(shape=(850, 850))
            gr.Markdown("""Decoder attentions""")
            with gr.Row():
                decoder_att_map_output = gr.Image(shape=(850, 850))
            gr.Markdown("""Cross attentions""")
            with gr.Row():
                cross_att_map_output = gr.Image(shape=(850, 850))

        with gr.TabItem("Model details"):
            with gr.Row():
                model_details = gr.Markdown(""" """)

    # The output order must match the tuple returned by detect_objects.
    detect_button.click(
        detect_objects,
        inputs=[options, img_input, img_path, slider_input, display_mask],
        outputs=[
            img_output_from_upload,
            decoder_att_map_output,
            encoder_att_map_output,
            cross_att_map_output,
            model_details,
        ],
        queue=True,
    )
    example_images.click(
        fn=set_example_image, inputs=[example_images], outputs=[img_input, img_path]
    )

app.launch(enable_queue=True)