Tonic committed
Commit 15ec37f · unverified · 1 Parent(s): b768cbf

add verifier

Files changed (5)
  1. app.py +3 -4
  2. requirements.txt +1 -0
  3. templates/oneclick.html +10 -3
  4. utils/oneclick.py +28 -18
  5. utils/verifier.py +70 -0
app.py CHANGED
@@ -1,5 +1,4 @@
 # app.py
-
 from flask import Flask, render_template, request, send_file, redirect, url_for
 import os
 import logging
@@ -7,7 +6,6 @@ from utils.meldrx import MeldRxAPI
 from utils.oneclick import generate_discharge_paper_one_click
 from huggingface_hub import InferenceClient
 
-
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
@@ -30,7 +28,6 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 if not HF_TOKEN:
     raise ValueError("HF_TOKEN environment variable not set.")
 client = InferenceClient(api_key=HF_TOKEN)
-MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct"
 
 @app.route('/')
 def index():
@@ -77,7 +74,7 @@ def one_click():
 
     logger.info(f"One-click request - ID: {patient_id}, First: {first_name}, Last: {last_name}, Action: {action}")
 
-    pdf_path, status, basic_summary, ai_summary = generate_discharge_paper_one_click(
+    pdf_path, status, basic_summary, ai_summary, verified_summary = generate_discharge_paper_one_click(
         meldrx_api, client, patient_id, first_name, last_name
     )
 
@@ -86,6 +83,7 @@ def one_click():
         status=status,
         basic_summary=basic_summary.replace('\n', '<br>') if basic_summary else None,
         ai_summary=ai_summary.replace('\n', '<br>') if ai_summary else None,
+        verified_summary=verified_summary if verified_summary else None,
         patient_id=patient_id,
         first_name=first_name,
         last_name=last_name)
@@ -97,6 +95,7 @@ def one_click():
         status=status,
         basic_summary=basic_summary.replace('\n', '<br>') if basic_summary else None,
         ai_summary=ai_summary.replace('\n', '<br>') if ai_summary else None,
+        verified_summary=verified_summary if verified_summary else None,
         patient_id=patient_id,
         first_name=first_name,
         last_name=last_name)
requirements.txt CHANGED
@@ -13,3 +13,4 @@ gradio
 huggingface_hub
 lxml
 reportlab
+lettucedetect
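
The only new dependency is lettucedetect, which provides the HallucinationDetector class imported by the new utils/verifier.py below. A quick sanity check that the import path resolves after installation (a sketch; only the package and import path come from this diff):

    # sanity check for the new dependency (standalone snippet, not part of the commit)
    from lettucedetect.models.inference import HallucinationDetector
    print(HallucinationDetector.__name__)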
templates/oneclick.html CHANGED
@@ -2,9 +2,9 @@
 {% block content %}
 <h2>One-Click Discharge Summary</h2>
 <form method="POST">
-    <input type="text" name="patient_id" placeholder="Patient ID (Optional)">
-    <input type="text" name="first_name" placeholder="First Name (Optional)">
-    <input type="text" name="last_name" placeholder="Last Name (Optional)"><br><br>
+    <input type="text" name="patient_id" placeholder="Patient ID (Optional)" value="{{ patient_id or '' }}">
+    <input type="text" name="first_name" placeholder="First Name (Optional)" value="{{ first_name or '' }}">
+    <input type="text" name="last_name" placeholder="Last Name (Optional)" value="{{ last_name or '' }}"><br><br>
     <input type="submit" name="action" value="Display Summary" class="cyberpunk-button">
     <input type="submit" name="action" value="Generate PDF" class="cyberpunk-button">
 </form>
@@ -27,6 +27,13 @@
     </div>
     {% endif %}
 
+    {% if verified_summary %}
+    <div class="summary-container">
+        <h3>Verified AI Discharge Summary (Hallucinations Highlighted)</h3>
+        <div class="summary-content">{{ verified_summary | safe }}</div>
+    </div>
+    {% endif %}
+
 <style>
     .status-message {
         margin: 20px 0;
utils/oneclick.py CHANGED
@@ -3,10 +3,11 @@ from typing import Tuple, Optional, Dict
 from .meldrx import MeldRxAPI
 from .responseparser import PatientDataExtractor
 from .pdfutils import PDFGenerator
+from .verifier import DischargeVerifier  # Import the verifier
 import logging
 import json
 from huggingface_hub import InferenceClient
-import os
+import os
 
 logger = logging.getLogger(__name__)
 
@@ -15,9 +16,10 @@ if not HF_TOKEN:
     raise ValueError("HF_TOKEN environment variable not set.")
 client = InferenceClient(api_key=HF_TOKEN)
 MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct"
+verifier = DischargeVerifier()  # Initialize the verifier
 
-def generate_ai_discharge_summary(patient_dict: Dict[str, str], client) -> Optional[str]:
-    """Generate a discharge summary using AI based on extracted patient data."""
+def generate_ai_discharge_summary(patient_dict: Dict[str, str], client) -> Tuple[Optional[str], Optional[str]]:
+    """Generate a discharge summary using AI and verify it for hallucinations."""
     try:
         formatted_summary = format_discharge_summary(patient_dict)
 
@@ -49,12 +51,22 @@ def generate_ai_discharge_summary(patient_dict: Dict[str, str], client) -> Optio
         if content:
             discharge_summary += content
 
+        discharge_summary = discharge_summary.strip()
         logger.info("AI discharge summary generated successfully")
-        return discharge_summary.strip()
+
+        # Verify the summary for hallucinations
+        question = "Provide a complete discharge summary based on the patient information."
+        verified_summary = verifier.verify_discharge_summary(
+            context=formatted_summary,
+            question=question,
+            answer=discharge_summary
+        )
+
+        return discharge_summary, verified_summary
 
     except Exception as e:
         logger.error(f"Error generating AI discharge summary: {str(e)}", exc_info=True)
-        return None
+        return None, None
 
 def generate_discharge_paper_one_click(
     api: MeldRxAPI,
@@ -62,12 +74,12 @@ def generate_discharge_paper_one_click(
     patient_id: str = "",
     first_name: str = "",
     last_name: str = ""
-) -> Tuple[Optional[str], str, Optional[str], Optional[str]]:
+) -> Tuple[Optional[str], str, Optional[str], Optional[str], Optional[str]]:
     try:
         patients_data = api.get_patients()
         if not patients_data or "entry" not in patients_data:
             logger.error("No patient data received from MeldRx API")
-            return None, "Failed to fetch patient data from MeldRx API", None, None
+            return None, "Failed to fetch patient data from MeldRx API", None, None, None
 
         logger.debug(f"Raw patient data from API: {patients_data}")
 
@@ -75,7 +87,7 @@ def generate_discharge_paper_one_click(
 
         if not extractor.patients:
             logger.error("No patients found in the parsed data")
-            return None, "No patients found in the data", None, None
+            return None, "No patients found in the data", None, None, None
 
         logger.info(f"Found {len(extractor.patients)} patients in the data")
 
@@ -102,10 +114,8 @@ def generate_discharge_paper_one_click(
             logger.debug(f"Comparing - Input: ID={patient_id_input}, First={first_name_input}, Last={last_name_input}")
 
             matches = True
-            # Only enforce ID match if both input and data have non-empty IDs
             if patient_id_input and patient_id_from_data and patient_id_input != patient_id_from_data:
                 matches = False
-            # Use exact match for names if provided, ignoring case
             if first_name_input and first_name_input != first_name_from_data:
                 matches = False
             if last_name_input and last_name_input != last_name_from_data:
@@ -123,28 +133,28 @@ def generate_discharge_paper_one_click(
             logger.info(f"Available patient names: {all_patient_names}")
             return None, (f"No patients found matching criteria: {search_criteria}\n"
                           f"Available IDs: {', '.join(all_patient_ids)}\n"
-                          f"Available Names: {', '.join(all_patient_names)}"), None, None
-        logger.debug(f"Raw patient data from API: {json.dumps(patients_data, indent=2)}")
+                          f"Available Names: {', '.join(all_patient_names)}"), None, None, None
+
         patient_data = matching_patients[0]
         logger.info(f"Selected patient data: {patient_data}")
 
         basic_summary = format_discharge_summary(patient_data)
-        ai_summary = generate_ai_discharge_summary(patient_data, client)
+        ai_summary, verified_summary = generate_ai_discharge_summary(patient_data, client)
 
-        if not ai_summary:
-            return None, "Failed to generate AI summary", basic_summary, None
+        if not ai_summary or not verified_summary:
+            return None, "Failed to generate or verify AI summary", basic_summary, None, None
 
         pdf_gen = PDFGenerator()
         filename = f"discharge_{patient_data.get('id', 'unknown')}_{patient_data.get('last_name', 'patient')}.pdf"
         pdf_path = pdf_gen.generate_pdf_from_text(ai_summary, filename)
 
         if pdf_path:
-            return pdf_path, "Discharge summary generated successfully", basic_summary, ai_summary
-        return None, "Failed to generate PDF file", basic_summary, ai_summary
+            return pdf_path, "Discharge summary generated and verified successfully", basic_summary, ai_summary, verified_summary
+        return None, "Failed to generate PDF file", basic_summary, ai_summary, verified_summary
 
     except Exception as e:
         logger.error(f"Error in one-click discharge generation: {str(e)}", exc_info=True)
-        return None, f"Error generating discharge summary: {str(e)}", None, None
+        return None, f"Error generating discharge summary: {str(e)}", None, None, None
 
 def format_discharge_summary(patient_data: dict) -> str:
     """Format patient data into a discharge summary text."""
utils/verifier.py ADDED
@@ -0,0 +1,70 @@
+# utils/verifier.py
+from lettucedetect.models.inference import HallucinationDetector
+import logging
+from typing import List, Dict, Optional
+
+logger = logging.getLogger(__name__)
+
+class DischargeVerifier:
+    def __init__(self):
+        """Initialize the hallucination detector."""
+        try:
+            self.detector = HallucinationDetector(
+                method="transformer",
+                model_path="KRLabsOrg/lettucedect-base-modernbert-en-v1",
+            )
+            logger.info("Hallucination detector initialized successfully")
+        except Exception as e:
+            logger.error(f"Failed to initialize hallucination detector: {str(e)}")
+            raise
+
+    def create_interactive_text(self, text: str, spans: List[Dict[str, int | float]]) -> str:
+        """Create interactive HTML with highlighting and hover effects."""
+        html_text = text
+
+        for span in sorted(spans, key=lambda x: x["start"], reverse=True):
+            span_text = text[span["start"]:span["end"]]
+            highlighted_span = (
+                f'<span class="hallucination" title="Confidence: {span["confidence"]:.3f}">{span_text}</span>'
+            )
+            html_text = (
+                html_text[:span["start"]] + highlighted_span + html_text[span["end"]:]
+            )
+
+        return f"""
+        <style>
+        .container {{
+            font-family: Arial, sans-serif;
+            font-size: 16px;
+            line-height: 1.6;
+            padding: 20px;
+        }}
+        .hallucination {{
+            background-color: rgba(255, 99, 71, 0.3);
+            padding: 2px;
+            border-radius: 3px;
+            cursor: help;
+        }}
+        .hallucination:hover {{
+            background-color: rgba(255, 99, 71, 0.5);
+        }}
+        </style>
+        <div class="container">{html_text}</div>
+        """
+
+    def verify_discharge_summary(
+        self, context: str, question: str, answer: str
+    ) -> Optional[str]:
+        """Verify the discharge summary for hallucinations and return highlighted HTML."""
+        try:
+            predictions = self.detector.predict(
+                context=[context],
+                question=question,
+                answer=answer,
+                output_format="spans"
+            )
+            logger.debug(f"Hallucination predictions: {predictions}")
+            return self.create_interactive_text(answer, predictions)
+        except Exception as e:
+            logger.error(f"Error verifying discharge summary: {str(e)}")
+            return None
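
As a standalone smoke test, the new verifier can be exercised directly. This is a sketch that uses only the API added in this commit, with made-up patient text; the model weights are fetched from the Hub on first use:

    # sketch: running DischargeVerifier outside the Flask app
    from utils.verifier import DischargeVerifier

    verifier = DischargeVerifier()  # loads KRLabsOrg/lettucedect-base-modernbert-en-v1
    context = "Patient Jane Doe, admitted 2024-01-02 with pneumonia, treated with IV antibiotics."
    question = "Provide a complete discharge summary based on the patient information."
    answer = "Jane Doe was admitted with pneumonia and discharged on insulin."  # 'insulin' is unsupported by the context

    html = verifier.verify_discharge_summary(context=context, question=question, answer=answer)
    if html:
        with open("verified_check.html", "w", encoding="utf-8") as f:
            f.write(html)  # unsupported spans render highlighted with confidence tooltips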