Update app.py
Browse files
app.py
CHANGED
@@ -113,6 +113,7 @@ def create_analysis_html(sentence_results, global_label, global_confidence):
|
|
113 |
html += "</table>"
|
114 |
return html
|
115 |
|
|
|
116 |
def process_input(text_input, labels_or_premise, mode):
|
117 |
if mode == "Zero-Shot Classification":
|
118 |
labels = [label.strip() for label in labels_or_premise.split(',')]
|
@@ -126,8 +127,9 @@ def process_input(text_input, labels_or_premise, mode):
|
|
126 |
else: # Long Context NLI
|
127 |
# Global prediction
|
128 |
global_pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
|
129 |
-
global_results = {
|
130 |
-
global_label
|
|
|
131 |
|
132 |
# Sentence-level analysis
|
133 |
sentences = sent_tokenize(text_input)
|
@@ -135,15 +137,18 @@ def process_input(text_input, labels_or_premise, mode):
|
|
135 |
|
136 |
for sentence in sentences:
|
137 |
sent_pred = nli_classifier([{"text": sentence, "text_pair": labels_or_premise}], return_all_scores=True)[0]
|
138 |
-
|
139 |
-
|
|
|
|
|
|
|
140 |
sentence_results.append({
|
141 |
'sentence': sentence,
|
142 |
'prediction': max_label,
|
143 |
-
'
|
144 |
})
|
145 |
|
146 |
-
analysis_html = create_analysis_html(sentence_results, global_label,global_confidence)
|
147 |
return global_results, analysis_html
|
148 |
|
149 |
def update_interface(mode):
|
|
|
113 |
html += "</table>"
|
114 |
return html
|
115 |
|
116 |
+
|
117 |
def process_input(text_input, labels_or_premise, mode):
|
118 |
if mode == "Zero-Shot Classification":
|
119 |
labels = [label.strip() for label in labels_or_premise.split(',')]
|
|
|
127 |
else: # Long Context NLI
|
128 |
# Global prediction
|
129 |
global_pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
|
130 |
+
global_results = {p['label']: p['score'] for p in global_pred}
|
131 |
+
global_label = max(global_results.items(), key=lambda x: x[1])[0]
|
132 |
+
global_confidence = max(global_results.values())
|
133 |
|
134 |
# Sentence-level analysis
|
135 |
sentences = sent_tokenize(text_input)
|
|
|
137 |
|
138 |
for sentence in sentences:
|
139 |
sent_pred = nli_classifier([{"text": sentence, "text_pair": labels_or_premise}], return_all_scores=True)[0]
|
140 |
+
# Get the prediction and confidence for the sentence
|
141 |
+
pred_scores = [(p['label'], p['score']) for p in sent_pred]
|
142 |
+
max_pred = max(pred_scores, key=lambda x: x[1])
|
143 |
+
max_label, confidence = max_pred
|
144 |
+
|
145 |
sentence_results.append({
|
146 |
'sentence': sentence,
|
147 |
'prediction': max_label,
|
148 |
+
'confidence': confidence
|
149 |
})
|
150 |
|
151 |
+
analysis_html = create_analysis_html(sentence_results, global_label, global_confidence)
|
152 |
return global_results, analysis_html
|
153 |
|
154 |
def update_interface(mode):
|