Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from torch import nn | |
from transformers import BertTokenizer, BertModel | |
class BertClassifier(nn.Module):
    """Multi-label classification head on top of a pretrained BERT encoder.

    The encoder's pooled output is projected to ``num_classes`` logits and
    passed through a sigmoid, so every class receives an independent score
    in ``[0, 1]``.
    """

    def __init__(self, bert: "BertModel", num_classes: int):
        """Wrap ``bert`` with a linear classification layer.

        Args:
            bert: Pretrained encoder; must expose ``config.hidden_size`` and
                return an object with a ``pooler_output`` attribute.
            num_classes: Number of output labels.
        """
        super().__init__()
        self.bert = bert
        self.classifier = nn.Linear(bert.config.hidden_size, num_classes)
        # NOTE(review): sigmoid followed by BCELoss is less numerically stable
        # than BCEWithLogitsLoss applied to raw logits. BCELoss is kept because
        # external training code may call `criterion` with probabilities.
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        """Score a batch and optionally compute the training loss.

        Args:
            input_ids: Token ids, forwarded unchanged to ``self.bert``.
            attention_mask, token_type_ids, position_ids, head_mask:
                Optional tensors, forwarded unchanged to ``self.bert``.
            labels: Optional multi-hot targets with the same shape as the
                returned probabilities. Integer or bool tensors are accepted
                and cast to the probabilities' dtype, as ``BCELoss`` needs
                floating-point targets.

        Returns:
            ``(loss, probs)``. ``probs`` holds the per-class sigmoid scores.
            ``loss`` is the BCE loss tensor when ``labels`` is given, and the
            integer ``0`` otherwise.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        logits = self.classifier(outputs.pooler_output)
        probs = torch.sigmoid(logits)

        loss = 0  # int 0 (not None) when unlabeled, kept for backward compatibility
        if labels is not None:
            # BCELoss rejects integer targets ("expected Float"), so cast
            # multi-hot label tensors to the same dtype as the predictions.
            loss = self.criterion(probs, labels.to(probs.dtype))
        return loss, probs
# Tokenizer and encoder backbone come from the same public base checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")

# Seven outputs; this must match the seven category names listed in predict().
model = BertClassifier(bert_model, num_classes=7)

# Restore the fine-tuned weights, remapping every tensor onto the CPU.
# NOTE(review): torch.load unpickles this file, so only ship trusted
# checkpoints here (consider weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load("bert_classifier_icd.pkl", map_location="cpu"))
model.eval()  # switch submodules (e.g. dropout) to inference behavior
def predict(text):
    """Return per-category confidence scores for a single clinical note.

    The note is tokenized with the module-level ``tokenizer`` (truncated to
    512 tokens) and scored by the module-level ``model``. The resulting
    sigmoid outputs are then paired, by position, with human-readable
    category names.

    Args:
        text: Free-text clinical note.

    Returns:
        dict mapping each category name to its float score.
    """
    category_names = [
        'Cardiovascular', 'Respiratory', 'Neurological', 'Infectious',
        'Endocrine', 'Musculoskeletal', 'Gastrointestinal',
    ]

    token_ids = tokenizer.encode(
        text,
        add_special_tokens=True,
        max_length=512,
        truncation=True,
    )
    # A batch containing one sequence: shape (1, seq_len).
    batch = torch.tensor([token_ids])
    # Attend to every non-padding position. A lone unpadded sequence normally
    # contains no pad ids, so this is usually all ones.
    attention = batch.ne(tokenizer.pad_token_id).float()

    with torch.no_grad():
        _loss, probabilities = model(batch, attention_mask=attention)

    # Drop the batch dimension, then pair names and scores positionally.
    return dict(zip(category_names, probabilities[0].tolist()))
# Sample notes shown as one-click inputs in the UI. Each inner list holds the
# value for the interface's single Textbox input.
# NOTE(review): the kidney-injury and depression notes do not obviously belong
# to any of the seven categories scored by predict(); confirm they are intended.
examples = [
    [note]
    for note in (
        "Patient admitted with chest pain, shortness of breath, and abnormal ECG findings.",
        "Elderly patient presented with symptoms of confusion, fever, and elevated white blood cell count.",
        "Patient complains of persistent cough, wheezing, and history of asthma.",
        "Admitted with severe abdominal pain, nausea, and vomiting. Suspected appendicitis.",
        "Patient has a history of diabetes mellitus and presented with high blood glucose levels and dehydration.",
        "Patient admitted following a fall, showing signs of fracture in the left femur.",
        "Patient experiencing severe headaches, dizziness, and a history of epilepsy.",
        "Acute kidney injury suspected due to elevated creatinine and reduced urine output.",
        "Patient diagnosed with major depressive disorder, experiencing prolonged sadness and loss of interest in activities.",
        "Presented with a bacterial skin infection requiring intravenous antibiotics.",
    )
]
# Build the web UI: a free-text clinical note goes in, per-category scores come out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=10, placeholder="Enter clinical notes here..."),
    # Display every one of the seven category scores returned by predict().
    outputs=gr.Label(num_top_classes=7),
    examples=examples,
    title="MIMIC-IV ICD Code Classification",
    description="Predict ICD code categories based on clinical text using a BERT-based model. The model outputs confidence scores for seven common ICD code categories.",
)

# Start the server only when this file runs as a script (e.g. `python app.py`),
# so importing the module (tests, notebooks, reload tooling) has no launch side effect.
if __name__ == "__main__":
    iface.launch()