"""Gradio demo that classifies clinical notes into ICD code categories.

A BERT encoder with a linear head (weights loaded from
``bert_classifier_icd.pkl``) produces an independent sigmoid score per
category; the scores are displayed with a Gradio ``Label`` component.
"""

import gradio as gr
import torch
from torch import nn
from transformers import BertTokenizer, BertModel

# Output categories, in the order of the classifier's output units. The
# classifier width and the displayed class count are derived from this list so
# they cannot drift apart.
# NOTE(review): this order must match the label order used when the checkpoint
# was trained — confirm against the training code.
LABELS = [
    'Cardiovascular',
    'Respiratory',
    'Neurological',
    'Infectious',
    'Endocrine',
    'Musculoskeletal',
    'Gastrointestinal',
]

# Inputs longer than this many tokens (special tokens included) are truncated.
MAX_LENGTH = 512


class BertClassifier(nn.Module):
    """Multi-label classifier: a BERT encoder followed by a linear head.

    Args:
        bert: Pretrained BERT encoder; its ``pooler_output`` feeds the head.
        num_classes: Number of output categories.
    """

    def __init__(self, bert: BertModel, num_classes: int):
        super().__init__()
        self.bert = bert
        self.classifier = nn.Linear(bert.config.hidden_size, num_classes)
        # The loss is computed on raw logits. BCEWithLogitsLoss fuses the
        # sigmoid into the loss (log-sum-exp trick), so it is numerically more
        # stable than sigmoid + BCELoss while computing the same quantity. It
        # has no parameters, so existing checkpoints still load.
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        """Run the encoder and classification head.

        Args:
            input_ids: Token ids, forwarded to the BERT encoder.
            attention_mask: Optional mask forwarded to the encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.
            position_ids: Optional position ids forwarded to the encoder.
            head_mask: Optional attention-head mask forwarded to the encoder.
            labels: Optional float targets with the same shape as the output
                probabilities (presumably multi-hot); when given, the binary
                cross-entropy loss is computed.

        Returns:
            ``(loss, probs)``: ``probs`` are per-class sigmoid probabilities;
            ``loss`` is the BCE loss, or ``0`` when ``labels`` is ``None``.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        logits = self.classifier(outputs.pooler_output)
        probs = torch.sigmoid(logits)
        loss = 0
        if labels is not None:
            loss = self.criterion(logits, labels)
        return loss, probs


# Load the tokenizer and model.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
model = BertClassifier(bert_model, num_classes=len(LABELS))

# Load the fine-tuned weights onto the CPU.
# SECURITY NOTE(review): torch.load unpickles the file, which can execute
# arbitrary code — only load trusted checkpoints. If the file was written with
# torch.save(model.state_dict()), consider passing weights_only=True
# (torch >= 1.13) to restrict unpickling to tensors.
model.load_state_dict(torch.load('bert_classifier_icd.pkl', map_location=torch.device('cpu')))
model.eval()


def predict(text: str) -> dict:
    """Score a clinical note against every ICD category.

    Args:
        text: Free-text clinical note. Notes longer than ``MAX_LENGTH`` tokens
            are truncated.

    Returns:
        Mapping from category name to its sigmoid probability. Scores are
        computed independently per class, so they need not sum to 1.
    """
    encoded = tokenizer(
        text,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        # Use the tokenizer's own attention mask rather than comparing ids to
        # pad_token_id, which would also mask a literal "[PAD]" in the note.
        _, probs = model(encoded['input_ids'], attention_mask=encoded['attention_mask'])
    # A single note is encoded, so row 0 holds its per-class scores.
    confidence_scores = probs[0].tolist()
    return dict(zip(LABELS, confidence_scores))


# Example notes offered below the input box.
examples = [
    ["Patient admitted with chest pain, shortness of breath, and abnormal ECG findings."],
    ["Elderly patient presented with symptoms of confusion, fever, and elevated white blood cell count."],
    ["Patient complains of persistent cough, wheezing, and history of asthma."],
    ["Admitted with severe abdominal pain, nausea, and vomiting. Suspected appendicitis."],
    ["Patient has a history of diabetes mellitus and presented with high blood glucose levels and dehydration."],
    ["Patient admitted following a fall, showing signs of fracture in the left femur."],
    ["Patient experiencing severe headaches, dizziness, and a history of epilepsy."],
    ["Acute kidney injury suspected due to elevated creatinine and reduced urine output."],
    ["Patient diagnosed with major depressive disorder, experiencing prolonged sadness and loss of interest in activities."],
    ["Presented with a bacterial skin infection requiring intravenous antibiotics."],
]

# Create and launch the Gradio interface.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=10, placeholder="Enter clinical notes here..."),
    # Show every category, however many LABELS defines.
    outputs=gr.Label(num_top_classes=len(LABELS)),
    examples=examples,
    title="MIMIC-IV ICD Code Classification",
    description="Predict ICD code categories based on clinical text using a BERT-based model. The model outputs confidence scores for seven common ICD code categories.",
)
iface.launch()