Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -31,10 +31,10 @@ class BertClassifier(nn.Module):
|
|
31 |
# Load the tokenizer and model
|
32 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
33 |
bert_model = BertModel.from_pretrained('bert-base-uncased')
|
34 |
-
model = BertClassifier(bert_model, num_classes=
|
35 |
|
36 |
# Load the model weights from the .pkl file
|
37 |
-
model.load_state_dict(torch.load('
|
38 |
model.eval()
|
39 |
|
40 |
# Define prediction function
|
@@ -50,38 +50,33 @@ def predict(text):
|
|
50 |
confidence_scores = outputs[0].tolist()
|
51 |
|
52 |
# Convert to a dictionary mapping labels to confidence scores
|
53 |
-
labels = ['
|
54 |
prediction = {label: score for label, score in zip(labels, confidence_scores)}
|
55 |
|
56 |
return prediction
|
57 |
|
58 |
# Add example texts
|
59 |
examples = [
|
60 |
-
["
|
61 |
-
["
|
62 |
-
["
|
63 |
-
["
|
64 |
-
["
|
65 |
-
["
|
66 |
-
["
|
67 |
-
["
|
68 |
-
["
|
69 |
-
["
|
70 |
-
["What a motherfucking piece of crap those fuckheads for blocking us!"],
|
71 |
-
["Get a life you animal fucker. Shut up you fucking nigger. Fuck off and shit your pants full of all the shit you can fill them with. 144.131.176.126"],
|
72 |
-
["HOPE YOUR HEAD GETS CUT OFF AND SOMEONE WIPS THERE ASS WITH IT AND THEN STABS YOU IN YOUR HEART"],
|
73 |
-
["you people are pretty overzealous with this whole free thing. get a fucking life, you fucking niggers !!!23 16!!!"],
|
74 |
-
["Stupid peace of shit stop deleting my stuff asshole go die and fall in a hole go to hell!"],
|
75 |
-
["Bye! Don't look, come or think of comming back! Tosser."]
|
76 |
]
|
|
|
77 |
# Create Gradio interface
|
78 |
iface = gr.Interface(
|
79 |
fn=predict,
|
80 |
-
inputs=gr.Textbox(lines=10, placeholder="Enter
|
81 |
-
outputs=gr.Label(num_top_classes=
|
82 |
examples=examples,
|
83 |
-
title="
|
84 |
-
description="
|
85 |
)
|
86 |
|
87 |
iface.launch()
|
|
|
31 |
# Load the tokenizer and model
|
32 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
33 |
bert_model = BertModel.from_pretrained('bert-base-uncased')
|
34 |
+
model = BertClassifier(bert_model, num_classes=10)
|
35 |
|
36 |
# Load the model weights from the .pkl file
|
37 |
+
model.load_state_dict(torch.load('mimic4_icd_classifier.pkl', map_location=torch.device('cpu')))
|
38 |
model.eval()
|
39 |
|
40 |
# Define prediction function
|
|
|
50 |
confidence_scores = outputs[0].tolist()
|
51 |
|
52 |
# Convert to a dictionary mapping labels to confidence scores
|
53 |
+
labels = ['Cardiovascular', 'Respiratory', 'Neurological', 'Infectious', 'Endocrine', 'Musculoskeletal', 'Gastrointestinal', 'Renal', 'Psychiatric', 'Other']
|
54 |
prediction = {label: score for label, score in zip(labels, confidence_scores)}
|
55 |
|
56 |
return prediction
|
57 |
|
58 |
# Add example texts
|
59 |
examples = [
|
60 |
+
["Patient admitted with chest pain, shortness of breath, and abnormal ECG findings."],
|
61 |
+
["Elderly patient presented with symptoms of confusion, fever, and elevated white blood cell count."],
|
62 |
+
["Patient complains of persistent cough, wheezing, and history of asthma."],
|
63 |
+
["Admitted with severe abdominal pain, nausea, and vomiting. Suspected appendicitis."],
|
64 |
+
["Patient has a history of diabetes mellitus and presented with high blood glucose levels and dehydration."],
|
65 |
+
["Patient admitted following a fall, showing signs of fracture in the left femur."],
|
66 |
+
["Patient experiencing severe headaches, dizziness, and a history of epilepsy."],
|
67 |
+
["Acute kidney injury suspected due to elevated creatinine and reduced urine output."],
|
68 |
+
["Patient diagnosed with major depressive disorder, experiencing prolonged sadness and loss of interest in activities."],
|
69 |
+
["Presented with a bacterial skin infection requiring intravenous antibiotics."]
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
]
|
71 |
+
|
72 |
# Create Gradio interface
|
73 |
iface = gr.Interface(
|
74 |
fn=predict,
|
75 |
+
inputs=gr.Textbox(lines=10, placeholder="Enter clinical notes here..."),
|
76 |
+
outputs=gr.Label(num_top_classes=10),
|
77 |
examples=examples,
|
78 |
+
title="MIMIC-IV ICD Code Prediction",
|
79 |
+
description="Predict ICD code categories based on clinical text using a BERT-based model. The model outputs confidence scores for ten common ICD code categories.",
|
80 |
)
|
81 |
|
82 |
iface.launch()
|