File size: 6,359 Bytes
06257c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16c60f0
 
 
 
06257c8
 
 
 
 
 
 
 
 
c932a4f
06257c8
 
 
16c60f0
06257c8
 
 
 
 
16c60f0
 
06257c8
 
 
 
16c60f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from pathlib import Path

import gradio as gr
import numpy as np
from matplotlib import pyplot as plt

from descriptors import disease_descriptors_chexpert, disease_descriptors_chestxray14
from model import InferenceModel


def plot_bars(model_output):
    # sort model_output by overall_probability
    model_output = {k: v for k, v in sorted(model_output.items(), key=lambda item: item[1]['overall_probability'], reverse=True)}

    # Create a figure with as many subplots as there are diseases, arranged vertically
    fig, axs = plt.subplots(len(model_output), 1, figsize=(10, 5 * len(model_output)))
    # axs is not iterable if only one subplot is created, so make it a list
    if len(model_output) == 1:
        axs = [axs]

    for ax, (disease, data) in zip(axs, model_output.items()):
        desc_probs = list(data['descriptor_probabilities'].items())
        # sort descending
        desc_probs = sorted(desc_probs, key=lambda item: item[1], reverse=True)

        my_probs = [p[1] for p in desc_probs]
        min_prob = min(my_probs)
        max_prob = max(my_probs)
        my_labels = [p[0] for p in desc_probs]

        # Convert probabilities to differences from 0.5
        diffs = np.abs(np.array(my_probs) - 0.5)

        # Set colors based on sign of difference
        colors = ['red' if p < 0.5 else 'forestgreen' for p in my_probs]

        # Plot bars with appropriate colors and left offsets
        left = [p if p < 0.5 else 0.5 for p in my_probs]
        bars = ax.barh(my_labels, diffs, left=left, color=colors, alpha=0.3)

        for i, bar in enumerate(bars):
            ax.text(min_prob - 0.04, bar.get_y() + bar.get_height() / 2, my_labels[i], ha='left', va='center', color='black', fontsize=15)

        ax.set_xlim(min(min_prob - 0.05, 0.49), max(max_prob + 0.05, 0.51))

        # Invert the y-axis to show bars with values less than 0.5 to the left of the center
        ax.invert_yaxis()

        ax.set_yticks([])

        # Add a title for the disease
        if data['overall_probability'] >= 0.5:
            ax.set_title(f"{disease} : score of {data['overall_probability']:.2f}")
        else:
            ax.set_title(f"No {disease} : score of {data['overall_probability']:.2f}")

        # make title larger and bold
        ax.title.set_fontsize(15)
        ax.title.set_fontweight(600)

    # Save the plot
    plt.tight_layout()  # Adjust subplot parameters to give specified padding
    file_path = 'plot.png'
    plt.savefig(file_path)
    plt.close(fig)

    return file_path


def classify_image(inference_model, image_path, diseases_to_predict):
    descriptors_with_indication = [d + " indicating " + disease for disease, descriptors in diseases_to_predict.items() for d in descriptors]
    probs, negative_probs = inference_model.get_descriptor_probs(image_path=Path(image_path), descriptors=descriptors_with_indication,
                                                                 do_negative_prompting=True, demo=True)

    disease_probs, negative_disease_probs = inference_model.get_diseases_probs(diseases_to_predict, pos_probs=probs, negative_probs=negative_probs)

    model_output = {}
    for idx, disease in enumerate(diseases_to_predict.keys()):
        model_output[disease] = {
            'overall_probability': disease_probs[disease],
            'descriptor_probabilities': {descriptor: probs[f'{descriptor} indicating {disease}'].item() for descriptor in
                                         diseases_to_predict[disease]}
        }

    file_path = plot_bars(model_output)
    return file_path


# Define the function you want to wrap
def process_input(image_path, prompt_names: list, disease_name: str, descriptors: str):
    diseases_to_predict = {}

    for prompt in prompt_names:
        if prompt == 'Custom':
            diseases_to_predict[disease_name] = descriptors.split('\n')
        else:
            if prompt in disease_descriptors_chexpert:
                diseases_to_predict[prompt] = disease_descriptors_chexpert[prompt]
            else:  # only chestxray14
                diseases_to_predict[prompt] = disease_descriptors_chestxray14[prompt]

    # classify
    model = InferenceModel()
    output = classify_image(model, image_path, diseases_to_predict)

    return output

with open("article.md", "r") as f:
    article = f.read()
with open("description.md", "r") as f:
    description = f.read()

# Define the Gradio interface
iface = gr.Interface(
    fn=process_input,
    examples = [['examples/enlarged_cardiomediastinum.jpg', ['Enlarged Cardiomediastinum'], '', ''],['examples/edema.jpg', ['Edema'], '', ''],
                ['examples/support_devices.jpg', ['Custom'], 'Pacemaker', 'metalic object\nimplant on the left side of the chest\nimplanted cardiac device']],
    inputs=[gr.inputs.Image(type="filepath"), gr.inputs.CheckboxGroup(
        choices=['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
                 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices',
                 'Infiltration', 'Mass', 'Nodule', 'Emphysema', 'Fibrosis', 'Pleural Thickening', 'Hernia',
                 'Custom'],
        default=['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
                 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices'],
        label='Select to use predefined disease descriptors. Select "Custom" to define your own observations.'),
            gr.inputs.Textbox(lines=2, placeholder="Name of pathology for which you want to define custom observations", label='Pathology:'),
            gr.inputs.Textbox(lines=2, placeholder="Add your custom (positive) observations separated by a new line"
                                                   "\n Note: Each descriptor will automatically be embedded into our prompt format: There is/are (no) <observation> indicating <pathology>"
                                                   "\n Example:\n\n Opacity\nPleural Effusion\nConsolidation"
                              , label='Custom Observations:')],
    article=article,
    description=description,
    outputs=gr.outputs.Image(type="filepath")
)

# Launch the interface
iface.launch()