# Page header captured along with the file (not part of the program):
#   SrujayReddy31's picture
#   Update app.py
#   6759637 verified
import gradio as gr
import torch
from torch import nn
from transformers import BertTokenizer, BertModel
# Multi-label classifier: a BERT encoder followed by a sigmoid-activated linear head
class BertClassifier(nn.Module):
    """BERT encoder with a linear head that emits one sigmoid score per class.

    Scores are independent per class (sigmoid rather than softmax), and the
    loss is binary cross-entropy computed on those scores.
    """

    def __init__(self, bert: "BertModel", num_classes: int):
        """Wrap ``bert`` and attach a ``hidden_size -> num_classes`` linear head.

        Args:
            bert: Encoder to wrap. Its ``config.hidden_size`` sets the input
                width of the head, and its output must expose
                ``pooler_output``.
            num_classes: Number of scores produced per input.
        """
        super().__init__()
        self.bert = bert
        self.classifier = nn.Linear(bert.config.hidden_size, num_classes)
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        """Run the encoder and head, returning ``(loss, scores)``.

        Args:
            input_ids: Forwarded to the encoder unchanged.
            attention_mask: Forwarded to the encoder unchanged.
            token_type_ids: Forwarded to the encoder unchanged.
            position_ids: Forwarded to the encoder unchanged.
            head_mask: Forwarded to the encoder unchanged.
            labels: Optional targets for the loss, same shape as the scores.
                Integer or boolean targets are accepted and cast to the
                scores' dtype.

        Returns:
            A ``(loss, scores)`` tuple. ``loss`` is the integer ``0`` when
            ``labels`` is ``None``, otherwise the binary cross-entropy between
            ``scores`` and ``labels``. ``scores`` is the sigmoid of the linear
            head applied to the encoder's ``pooler_output``.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        logits = self.classifier(outputs.pooler_output)
        cls_output = torch.sigmoid(logits)

        loss = 0
        if labels is not None:
            # BCELoss requires targets in the same floating dtype as its
            # input; 0/1 labels stored as integers would otherwise raise.
            targets = torch.as_tensor(labels, dtype=cls_output.dtype)
            loss = self.criterion(cls_output, targets)
        return loss, cls_output
# Tokenizer and encoder are both created from the same pretrained checkpoint name.
pretrained_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_name)
bert_model = BertModel.from_pretrained(pretrained_name)
# Classifier with 7 outputs built around the encoder.
model = BertClassifier(bert_model, num_classes=7)
# Restore the saved state dict onto the CPU, then switch to evaluation mode.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load('bert_classifier_icd.pkl', map_location='cpu')
)
model.eval()
# Inference entry point handed to the Gradio interface
def predict(text):
    """Score ``text`` against the seven category labels.

    The text is encoded with the module-level ``tokenizer`` (special tokens
    added, truncated to at most 512 tokens), passed through the module-level
    ``model`` as a batch of one with gradient tracking disabled, and the
    first row of the model's second return value is paired, in order, with
    the label names.

    Args:
        text: Input text to classify.

    Returns:
        A dict mapping each label name to its score.
    """
    category_names = ['Cardiovascular', 'Respiratory', 'Neurological', 'Infectious', 'Endocrine', 'Musculoskeletal', 'Gastrointestinal']

    token_ids = tokenizer.encode(text, add_special_tokens=True, max_length=512, truncation=True)
    # Wrap in an outer list so the tensor carries a leading batch dimension of 1.
    input_ids = torch.tensor([token_ids])
    # 1.0 wherever the token is not the tokenizer's pad token, 0.0 elsewhere.
    attention_mask = input_ids.ne(tokenizer.pad_token_id).float()

    with torch.no_grad():
        _, scores = model(input_ids, attention_mask=attention_mask)

    # Row 0 is the only item in the batch; pair its scores with the labels in order.
    return dict(zip(category_names, scores[0].tolist()))
# Sample inputs for the interface; each row is a one-element list holding one note.
examples = [[note] for note in (
    "Patient admitted with chest pain, shortness of breath, and abnormal ECG findings.",
    "Elderly patient presented with symptoms of confusion, fever, and elevated white blood cell count.",
    "Patient complains of persistent cough, wheezing, and history of asthma.",
    "Admitted with severe abdominal pain, nausea, and vomiting. Suspected appendicitis.",
    "Patient has a history of diabetes mellitus and presented with high blood glucose levels and dehydration.",
    "Patient admitted following a fall, showing signs of fracture in the left femur.",
    "Patient experiencing severe headaches, dizziness, and a history of epilepsy.",
    "Acute kidney injury suspected due to elevated creatinine and reduced urine output.",
    "Patient diagnosed with major depressive disorder, experiencing prolonged sadness and loss of interest in activities.",
    "Presented with a bacterial skin infection requiring intravenous antibiotics.",
)]
# Build the Gradio UI: a multi-line text box feeding predict(), with a label widget for the scores.
notes_box = gr.Textbox(lines=10, placeholder="Enter clinical notes here...")
scores_label = gr.Label(num_top_classes=7)
iface = gr.Interface(
    fn=predict,
    inputs=notes_box,
    outputs=scores_label,
    examples=examples,
    title="MIMIC-IV ICD Code Classification",
    description="Predict ICD code categories based on clinical text using a BERT-based model. The model outputs confidence scores for seven common ICD code categories.",
)
# Start serving the interface.
iface.launch()