Update app.py
app.py CHANGED
@@ -12,14 +12,21 @@ model.eval()
 if torch.cuda.is_available():
     model.cuda()

+validation_results = json.load(open('validation_results.json'))
+scores, thresholds, precisions = validation_results['scores'], validation_results['thresholds'], validation_results['precisions']
+
+def get_threshold_precision(score_):
+    for score, threshold, precision in zip(scores, thresholds, precisions):
+        if score_ < score:
+            break
+        prev_score, prev_threshold, prev_precision = score, threshold, precision
+    if prev_threshold == prev_score:
+        prev_threshold = score_
+    return prev_threshold, prev_precision

 def normalize_spaces(text):
     return re.sub(r'\s+', ' ', text).strip()

-# Load your validation set
-#with open('validation_data.json', 'r') as file:
-#    validation_data = json.load(file)
-
 def fill_template(title, authors, abstract):
     title = normalize_spaces(title.replace('\n', ' '))
     authors = ', '.join([author.strip() for author in authors.split(',')])
@@ -50,9 +57,9 @@ def predict(title, authors, abstract):
     #selected = [d for d in validation_data if d['score'] >= score]
     #true_positives = sum(1 for d in selected if d['label'] == 1)
     #precision = true_positives / len(selected) if selected else 0
-    precision =
+    threshold, precision = get_threshold_precision(score)

-    result = f"Your score: {score:.2f}.\nFor papers with score >= {
+    result = f"Your score: {score:.2f}.\nFor papers with score >= {threshold:.2f}, {precision * 100:.2f}% are selected by AK."

     return score, result
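Note on the new lookup: get_threshold_precision does a linear scan over three parallel arrays, assumed sorted by score, and keeps the last bucket whose score is <= score_; when that bucket's threshold equals its score, the user's own score is reported as the threshold. If score_ falls below the first bucket, the loop breaks before prev_score is ever assigned and the function raises UnboundLocalError. A minimal standalone sketch of the same lookup with that case clamped — the bucket values here are invented, not the Space's real validation_results.json:

import bisect

# Invented example buckets standing in for validation_results.json.
scores     = [0.10, 0.30, 0.50, 0.70, 0.90]   # bucket lower bounds, sorted ascending
thresholds = [0.10, 0.25, 0.50, 0.65, 0.90]   # display threshold per bucket
precisions = [0.05, 0.15, 0.40, 0.60, 0.85]   # fraction of papers selected per bucket

def get_threshold_precision(score_):
    # Last bucket whose score is <= score_, clamped to the first bucket
    # so inputs below scores[0] do not crash the lookup.
    i = max(bisect.bisect_right(scores, score_) - 1, 0)
    threshold, precision = thresholds[i], precisions[i]
    if threshold == scores[i]:
        # Same special case as the app: report the user's own score
        # when the bucket's threshold coincides with its score.
        threshold = score_
    return threshold, precision

print(get_threshold_precision(0.62))   # -> (0.62, 0.40): the 0.50 bucket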
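The commented-out block in predict documents where the numbers come from: at a given score cutoff, precision is the share of validation papers scoring at or above the cutoff whose label is 1. A hedged sketch of the offline step that could have produced validation_results.json in the shape the app loads — the file name and the 'scores'/'thresholds'/'precisions' keys are taken from the diff, the {'score', 'label'} record format from the commented-out code, and the cutoff grid plus the thresholds-mirror-the-cutoffs choice are assumptions:

import json

# Precompute the lookup table offline so the app only reads it at startup.
# Assumes validation_data.json is a list of {"score": float, "label": 0 or 1}.
with open('validation_data.json', 'r') as file:
    validation_data = json.load(file)

scores, thresholds, precisions = [], [], []
for cutoff in [i / 20 for i in range(20)]:      # 0.00, 0.05, ..., 0.95 (arbitrary grid)
    selected = [d for d in validation_data if d['score'] >= cutoff]
    true_positives = sum(1 for d in selected if d['label'] == 1)
    precision = true_positives / len(selected) if selected else 0
    scores.append(cutoff)
    thresholds.append(cutoff)                   # assumption: threshold mirrors the cutoff
    precisions.append(precision)

with open('validation_results.json', 'w') as out:
    json.dump({'scores': scores, 'thresholds': thresholds, 'precisions': precisions}, out)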