File size: 6,359 Bytes
06257c8 16c60f0 06257c8 c932a4f 06257c8 16c60f0 06257c8 16c60f0 06257c8 16c60f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from pathlib import Path
import gradio as gr
import numpy as np
from matplotlib import pyplot as plt
from descriptors import disease_descriptors_chexpert, disease_descriptors_chestxray14
from model import InferenceModel
def plot_bars(model_output, file_path='plot.png'):
    """Render per-disease descriptor probabilities as horizontal bar charts and save them.

    One subplot is drawn per disease, stacked vertically and ordered by
    ``overall_probability`` (highest first). Inside a subplot every descriptor is
    a bar spanning from its probability to 0.5: red when the probability is below
    0.5, green otherwise. The subplot title reports the overall score, prefixed
    with "No " when that score is below 0.5.

    Args:
        model_output: Mapping ``disease -> {'overall_probability': score,
            'descriptor_probabilities': {descriptor: score}}``.
        file_path: Where the PNG is written. Defaults to ``'plot.png'``; pass a
            unique path when several requests can render at the same time,
            otherwise they overwrite each other's image.

    Returns:
        ``file_path``, after the figure has been saved to it.

    Raises:
        ValueError: If ``model_output`` is empty (there is nothing to plot).
    """
    if not model_output:
        raise ValueError("model_output is empty: no diseases to plot")

    # Most probable disease first.
    ordered = sorted(model_output.items(), key=lambda item: item[1]['overall_probability'], reverse=True)

    n_rows = len(ordered)
    # squeeze=False always yields a 2-D array of Axes, so one disease needs no special case.
    fig, axs = plt.subplots(n_rows, 1, figsize=(10, 5 * n_rows), squeeze=False)

    for ax, (disease, data) in zip(axs[:, 0], ordered):
        desc_probs = sorted(data['descriptor_probabilities'].items(), key=lambda item: item[1], reverse=True)
        my_labels = [label for label, _ in desc_probs]
        my_probs = [prob for _, prob in desc_probs]
        # Fall back to the 0.5 midpoint so a disease without descriptors still gets a sane axis.
        min_prob = min(my_probs, default=0.5)
        max_prob = max(my_probs, default=0.5)

        if desc_probs:
            # Each bar covers [min(p, 0.5), max(p, 0.5)]: its length is |p - 0.5| and it
            # starts at p when p < 0.5, otherwise at 0.5.
            diffs = np.abs(np.array(my_probs) - 0.5)
            colors = ['red' if p < 0.5 else 'forestgreen' for p in my_probs]
            left = [p if p < 0.5 else 0.5 for p in my_probs]
            bars = ax.barh(my_labels, diffs, left=left, color=colors, alpha=0.3)

            # Descriptor names are drawn inside the plot area because y tick labels are hidden below.
            for label, bar in zip(my_labels, bars):
                ax.text(min_prob - 0.04, bar.get_y() + bar.get_height() / 2, label,
                        ha='left', va='center', color='black', fontsize=15)

        # Always keep the 0.5 midpoint visible, with a small margin around the data.
        ax.set_xlim(min(min_prob - 0.05, 0.49), max(max_prob + 0.05, 0.51))
        # Categorical y positions increase upwards; invert so the first (most probable)
        # descriptor is drawn at the top.
        ax.invert_yaxis()
        ax.set_yticks([])

        overall = data['overall_probability']
        prefix = '' if overall >= 0.5 else 'No '
        ax.set_title(f"{prefix}{disease} : score of {overall:.2f}", fontsize=15, fontweight=600)

    # Operate on this figure explicitly instead of pyplot's implicit "current figure".
    fig.tight_layout()
    fig.savefig(file_path)
    plt.close(fig)
    return file_path
def classify_image(inference_model, image_path, diseases_to_predict):
    """Score an image against disease descriptors and return the path of the rendered plot.

    Every descriptor becomes the prompt ``"<descriptor> indicating <disease>"``.
    All prompts are scored by ``inference_model`` with negative prompting enabled,
    aggregated into one probability per disease, and drawn with ``plot_bars``.

    Args:
        inference_model: Object exposing ``get_descriptor_probs`` and
            ``get_diseases_probs`` (an ``InferenceModel`` in this app).
        image_path: Path (``str`` or ``Path``) of the image to classify.
        diseases_to_predict: Mapping ``disease -> list of descriptor strings``.

    Returns:
        The image file path returned by ``plot_bars``.
    """
    def to_prompt(descriptor, disease):
        # Single source of truth for the prompt text: it is used both to build the
        # model input and as the key for reading the scores back.
        return f'{descriptor} indicating {disease}'

    prompts = [to_prompt(d, disease) for disease, descriptors in diseases_to_predict.items() for d in descriptors]
    probs, negative_probs = inference_model.get_descriptor_probs(
        image_path=Path(image_path), descriptors=prompts, do_negative_prompting=True, demo=True)
    # The per-disease negative scores are also returned but are not plotted.
    disease_probs, _ = inference_model.get_diseases_probs(
        diseases_to_predict, pos_probs=probs, negative_probs=negative_probs)

    model_output = {
        disease: {
            'overall_probability': disease_probs[disease],
            # .item() converts each score to a Python scalar -- presumably a 0-d tensor; TODO confirm.
            'descriptor_probabilities': {d: probs[to_prompt(d, disease)].item() for d in descriptors},
        }
        for disease, descriptors in diseases_to_predict.items()
    }
    return plot_bars(model_output)
# Lazily created InferenceModel shared by all requests (see _get_inference_model).
_INFERENCE_MODEL = None


def _get_inference_model():
    """Return the shared InferenceModel, constructing it on first use."""
    # Building the model once avoids re-creating it (and presumably reloading its
    # weights) on every request.
    # NOTE(review): assumes InferenceModel keeps no per-request state -- confirm.
    global _INFERENCE_MODEL
    if _INFERENCE_MODEL is None:
        _INFERENCE_MODEL = InferenceModel()
    return _INFERENCE_MODEL


def process_input(image_path, prompt_names: list, disease_name: str, descriptors: str):
    """Gradio callback: collect the selected disease descriptors, classify the image, return the plot path.

    Args:
        image_path: File path of the uploaded image.
        prompt_names: Selected checkbox labels. 'Custom' means "use ``disease_name``
            and ``descriptors``"; any other label is looked up in
            ``disease_descriptors_chexpert`` first, then ``disease_descriptors_chestxray14``.
        disease_name: Pathology name for the 'Custom' entry.
        descriptors: Custom observations, one per line; blank lines and surrounding
            whitespace are ignored.

    Returns:
        File path of the rendered probability plot.

    Raises:
        ValueError: If nothing is selected, or 'Custom' is selected without a
            pathology name or without at least one observation.
        KeyError: If a selected label is in neither descriptor table.
    """
    diseases_to_predict = {}
    for prompt in prompt_names or ():
        if prompt == 'Custom':
            name = (disease_name or '').strip()
            # splitlines() also handles '\r\n'; empty lines would become empty prompts.
            custom = [line.strip() for line in (descriptors or '').splitlines() if line.strip()]
            if not name or not custom:
                raise ValueError('"Custom" requires a pathology name and at least one observation.')
            diseases_to_predict[name] = custom
        elif prompt in disease_descriptors_chexpert:
            diseases_to_predict[prompt] = disease_descriptors_chexpert[prompt]
        else:  # label only present in the ChestX-ray14 descriptor table
            diseases_to_predict[prompt] = disease_descriptors_chestxray14[prompt]

    if not diseases_to_predict:
        raise ValueError('Select at least one pathology (or "Custom" with a name and observations).')

    return classify_image(_get_inference_model(), image_path, diseases_to_predict)
with open("article.md", "r") as f:
article = f.read()
with open("description.md", "r") as f:
description = f.read()
# Define the Gradio interface
iface = gr.Interface(
fn=process_input,
examples = [['examples/enlarged_cardiomediastinum.jpg', ['Enlarged Cardiomediastinum'], '', ''],['examples/edema.jpg', ['Edema'], '', ''],
['examples/support_devices.jpg', ['Custom'], 'Pacemaker', 'metalic object\nimplant on the left side of the chest\nimplanted cardiac device']],
inputs=[gr.inputs.Image(type="filepath"), gr.inputs.CheckboxGroup(
choices=['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices',
'Infiltration', 'Mass', 'Nodule', 'Emphysema', 'Fibrosis', 'Pleural Thickening', 'Hernia',
'Custom'],
default=['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices'],
label='Select to use predefined disease descriptors. Select "Custom" to define your own observations.'),
gr.inputs.Textbox(lines=2, placeholder="Name of pathology for which you want to define custom observations", label='Pathology:'),
gr.inputs.Textbox(lines=2, placeholder="Add your custom (positive) observations separated by a new line"
"\n Note: Each descriptor will automatically be embedded into our prompt format: There is/are (no) <observation> indicating <pathology>"
"\n Example:\n\n Opacity\nPleural Effusion\nConsolidation"
, label='Custom Observations:')],
article=article,
description=description,
outputs=gr.outputs.Image(type="filepath")
)
# Launch the interface
iface.launch()
|