sileod commited on
Commit
4800973
·
verified ·
1 Parent(s): df02321

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -113,6 +113,7 @@ def create_analysis_html(sentence_results, global_label, global_confidence):
113
  html += "</table>"
114
  return html
115
 
 
116
  def process_input(text_input, labels_or_premise, mode):
117
  if mode == "Zero-Shot Classification":
118
  labels = [label.strip() for label in labels_or_premise.split(',')]
@@ -126,8 +127,9 @@ def process_input(text_input, labels_or_premise, mode):
126
  else: # Long Context NLI
127
  # Global prediction
128
  global_pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
129
- global_results = {pred['label']: pred['score'] for pred in global_pred}
130
- global_label, global_confidence = max(global_results.items(), key=lambda x: x[1])
 
131
 
132
  # Sentence-level analysis
133
  sentences = sent_tokenize(text_input)
@@ -135,15 +137,18 @@ def process_input(text_input, labels_or_premise, mode):
135
 
136
  for sentence in sentences:
137
  sent_pred = nli_classifier([{"text": sentence, "text_pair": labels_or_premise}], return_all_scores=True)[0]
138
- sent_scores = {pred['label']: pred['score'] for pred in sent_pred}
139
- max_label = max(sent_scores.items(), key=lambda x: x[1])[0]
 
 
 
140
  sentence_results.append({
141
  'sentence': sentence,
142
  'prediction': max_label,
143
- 'scores': sent_scores
144
  })
145
 
146
- analysis_html = create_analysis_html(sentence_results, global_label,global_confidence)
147
  return global_results, analysis_html
148
 
149
  def update_interface(mode):
 
113
  html += "</table>"
114
  return html
115
 
116
+
117
  def process_input(text_input, labels_or_premise, mode):
118
  if mode == "Zero-Shot Classification":
119
  labels = [label.strip() for label in labels_or_premise.split(',')]
 
127
  else: # Long Context NLI
128
  # Global prediction
129
  global_pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
130
+ global_results = {p['label']: p['score'] for p in global_pred}
131
+ global_label = max(global_results.items(), key=lambda x: x[1])[0]
132
+ global_confidence = max(global_results.values())
133
 
134
  # Sentence-level analysis
135
  sentences = sent_tokenize(text_input)
 
137
 
138
  for sentence in sentences:
139
  sent_pred = nli_classifier([{"text": sentence, "text_pair": labels_or_premise}], return_all_scores=True)[0]
140
+ # Get the prediction and confidence for the sentence
141
+ pred_scores = [(p['label'], p['score']) for p in sent_pred]
142
+ max_pred = max(pred_scores, key=lambda x: x[1])
143
+ max_label, confidence = max_pred
144
+
145
  sentence_results.append({
146
  'sentence': sentence,
147
  'prediction': max_label,
148
+ 'confidence': confidence
149
  })
150
 
151
+ analysis_html = create_analysis_html(sentence_results, global_label, global_confidence)
152
  return global_results, analysis_html
153
 
154
  def update_interface(mode):