Update app.py

app.py CHANGED
@@ -13,16 +13,16 @@ if torch.cuda.is_available():
     model.cuda()
 
 validation_results = json.load(open('validation_results.json'))
-scores, thresholds, precisions = validation_results['scores'], validation_results['thresholds'], validation_results['precisions']
+scores, thresholds, precisions, recalls = validation_results['scores'], validation_results['thresholds'], validation_results['precisions'], validation_results['recalls']
 
 def get_threshold_precision(score_):
-    for score, threshold, precision in zip(scores, thresholds, precisions):
+    for score, threshold, precision, recall in zip(scores, thresholds, precisions, recalls):
         if score_ < score:
             break
-        prev_score, prev_threshold, prev_precision = score, threshold, precision
+        prev_score, prev_threshold, prev_precision, prev_recall = score, threshold, precision, recall
     if prev_threshold == prev_score:
         prev_threshold = score_
-    return prev_threshold, prev_precision
+    return prev_threshold, prev_precision, prev_recall
 
 def normalize_spaces(text):
     return re.sub(r'\s+', ' ', text).strip()
@@ -57,9 +57,9 @@ def predict(title, authors, abstract):
     #selected = [d for d in validation_data if d['score'] >= score]
     #true_positives = sum(1 for d in selected if d['label'] == 1)
     #precision = true_positives / len(selected) if selected else 0
-    threshold, precision = get_threshold_precision(score)
+    threshold, precision, recall = get_threshold_precision(score)
 
-    result = f"Your score: {score:.2f}.\nFor papers with score>={threshold:.2f}, {precision * 100:.2f}% are selected by AK."
+    result = f"Your score: {score:.2f}.\nFor papers with score>={threshold:.2f}, {precision * 100:.2f}% are selected by AK.\nFor papers selected by AK, {recall * 100:.2f}% have score>={threshold:.2f}"
 
     return score, result
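For context, a minimal runnable sketch of the updated lookup. Only the function body is taken from the diff; the toy grid and the sample call are made-up stand-ins for the real values in validation_results.json:

# Toy stand-in for validation_results.json (assumed, not the real data):
# scores are sorted ascending, and each entry pairs a decision threshold
# with the precision and recall measured at that threshold.
validation_results = {
    'scores':     [0.2, 0.4, 0.6, 0.8],
    'thresholds': [0.2, 0.4, 0.6, 0.8],
    'precisions': [0.50, 0.65, 0.80, 0.92],
    'recalls':    [0.95, 0.85, 0.60, 0.30],
}
scores, thresholds, precisions, recalls = (
    validation_results['scores'],
    validation_results['thresholds'],
    validation_results['precisions'],
    validation_results['recalls'],
)

def get_threshold_precision(score_):
    # Walk the ascending grid and keep the last entry whose score does
    # not exceed score_. Assumes score_ >= scores[0]; otherwise the
    # prev_* variables would be unbound when the loop breaks immediately.
    for score, threshold, precision, recall in zip(scores, thresholds, precisions, recalls):
        if score_ < score:
            break
        prev_score, prev_threshold, prev_precision, prev_recall = score, threshold, precision, recall
    # When the recorded threshold coincides with the grid score, report
    # the caller's own score as the cutoff instead.
    if prev_threshold == prev_score:
        prev_threshold = score_
    return prev_threshold, prev_precision, prev_recall

threshold, precision, recall = get_threshold_precision(0.7)
print(f"For papers with score>={threshold:.2f}, {precision * 100:.2f}% are selected by AK.\n"
      f"For papers selected by AK, {recall * 100:.2f}% have score>={threshold:.2f}")
# -> picks the 0.6 grid entry, substitutes the caller's 0.70 as the cutoff,
#    and reports precision 80.00% and recall 60.00%

This is what the commit changes in predict: the result string now reports both numbers, the precision of the score>=threshold cutoff and the share of AK-selected papers that the cutoff recovers.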