"""Gradio demo: object detection with DETR plus attention-map visualisation."""

import pathlib

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    DetrForObjectDetection,
)

from constants import MODELS_NAMES, MODELS_REPO
from style import css, description, title
from visualization import visualize_attention_map, visualize_prediction


def make_prediction(img, feature_extractor, model):
    """Run ``model`` on ``img`` and return detections plus attention maps.

    Args:
        img: Input image; must expose a PIL-style ``size`` attribute.
        feature_extractor: Object used both to build the model inputs and to
            post-process the raw model outputs.
        model: Detection model called with the extracted inputs.

    Returns:
        A 4-tuple ``(detections, decoder_attentions, encoder_attentions,
        cross_attentions)`` where ``detections`` is the first (only) entry of
        the post-processed outputs.
    """
    inputs = feature_extractor(img, return_tensors="pt")

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); reversing yields (height, width) for
    # post-processing.
    img_size = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, img_size)

    # NOTE(review): these keys are presumably only present when the model
    # config enables attention outputs -- confirm for every repo in
    # MODELS_REPO.
    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
        outputs["cross_attentions"],
    )


def detect_objects(model_name, image_input, threshold, display_mask=False):
    """Detect objects in ``image_input`` and build all visualisations.

    Args:
        model_name: Key into ``MODELS_REPO``; must contain ``"DETR"``.
        image_input: Image to analyse (the UI supplies a PIL image).
        threshold: Value forwarded to ``visualize_prediction``.
        display_mask: When true, a companion mask image is loaded and passed
            to ``visualize_prediction``.

    Returns:
        A 5-tuple ``(prediction_image, decoder_attention_image,
        encoder_attention_image, cross_attention_image, model_details)``.

    Raises:
        ValueError: If no image was supplied or ``model_name`` does not refer
            to a supported (DETR) model.
    """
    if image_input is None:
        raise ValueError("No image provided: upload or select an image first.")

    # Previously a non-DETR name left `model` unbound and failed later with
    # an opaque UnboundLocalError; fail early with a clear message instead.
    if "DETR" not in model_name:
        raise ValueError(f"Unsupported model: {model_name!r}")

    feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS_REPO[model_name])
    model = DetrForObjectDetection.from_pretrained(MODELS_REPO[model_name])
    model_details = "DETR details"

    (
        processed_outputs,
        decoder_attention_map,
        encoder_attention_map,
        cross_attention_map,
    ) = make_prediction(image_input, feature_extractor, model)

    mask_pil_image = None
    if display_mask:
        # The mask sits next to the image, with "_HE" replaced by "_mask"
        # in the file stem.
        # NOTE(review): this relies on `image_input.name`; a PIL image built
        # by the UI may not carry that attribute -- confirm how the source
        # path is meant to reach this function.
        image_path = pathlib.Path(image_input.name)
        mask_path = image_path.parent / (
            image_path.stem.replace("_HE", "_mask") + image_path.suffix
        )
        mask_pil_image = Image.open(mask_path)

    viz_img = visualize_prediction(
        image_input,
        processed_outputs,
        threshold,
        model.config.id2label,
        display_mask,
        mask_pil_image,
    )
    decoder_attention_map_img = visualize_attention_map(
        image_input, decoder_attention_map
    )
    encoder_attention_map_img = visualize_attention_map(
        image_input, encoder_attention_map
    )
    cross_attention_map_img = visualize_attention_map(image_input, cross_attention_map)

    return (
        viz_img,
        decoder_attention_map_img,
        encoder_attention_map_img,
        cross_attention_map_img,
        model_details,
    )


def set_example_image(example: list) -> dict:
    """Return an update that loads the clicked example into the image input."""
    return gr.Image.update(value=example[0])


# NOTE(review): the layout nesting below was reconstructed from a flattened
# source; the set of components and their order are unchanged, but confirm
# the Row/Column hierarchy against the intended design.
with gr.Blocks(css=css) as app:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tabs():
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type="pil")
                with gr.Column():
                    options = gr.Dropdown(
                        value=MODELS_NAMES[0],
                        choices=MODELS_NAMES,
                        label="Select an object detection model",
                        show_label=True,
                    )
                    with gr.Row():
                        with gr.Column():
                            slider_input = gr.Slider(
                                minimum=0.2,
                                maximum=1,
                                value=0.7,
                                label="Prediction threshold",
                            )
                        with gr.Column():
                            # `value` is the component's initial-state
                            # keyword; `default` is the legacy spelling.
                            display_mask = gr.Checkbox(
                                label="Display masks", value=False
                            )
                    detect_button = gr.Button("Detect leukocytes")

            with gr.Row():
                example_images = gr.Dataset(
                    components=[img_input],
                    samples=[
                        [path.as_posix()]
                        for path in sorted(
                            pathlib.Path("cd45rb_test_imgs").rglob("*_HE.png")
                        )
                    ],
                )

            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """The selected image with detected bounding boxes by the model"""
                    )
                    img_output_from_upload = gr.Image(shape=(850, 850))

        with gr.TabItem("Attentions visualization"):
            gr.Markdown("""Encoder attentions""")
            with gr.Row():
                encoder_att_map_output = gr.Image(shape=(850, 850))

            gr.Markdown("""Decoder attentions""")
            with gr.Row():
                decoder_att_map_output = gr.Image(shape=(850, 850))

            gr.Markdown("""Cross attentions""")
            with gr.Row():
                cross_att_map_output = gr.Image(shape=(850, 850))

        with gr.TabItem("Model details"):
            with gr.Row():
                model_details = gr.Markdown(""" """)

    # Output order must match the tuple returned by `detect_objects`.
    detect_button.click(
        detect_objects,
        inputs=[options, img_input, slider_input, display_mask],
        outputs=[
            img_output_from_upload,
            decoder_att_map_output,
            encoder_att_map_output,
            cross_att_map_output,
            model_details,
        ],
        queue=True,
    )
    example_images.click(
        fn=set_example_image, inputs=[example_images], outputs=[img_input]
    )

app.launch(enable_queue=True)