yuntian-deng commited on
Commit
4daf261
·
verified ·
1 Parent(s): 31dc2a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -13,16 +13,16 @@ if torch.cuda.is_available():
13
  model.cuda()
14
 
15
  validation_results = json.load(open('validation_results.json'))
16
- scores, thresholds, precisions = validation_results['scores'], validation_results['thresholds'], validation_results['precisions']
17
 
18
  def get_threshold_precision(score_):
19
- for score, threshold, precision in zip(scores, thresholds, precisions):
20
  if score_ < score:
21
  break
22
- prev_score, prev_threshold, prev_precision = score, threshold, precision
23
  if prev_threshold == prev_score:
24
  prev_threshold = score_
25
- return prev_threshold, prev_precision
26
 
27
  def normalize_spaces(text):
28
  return re.sub(r'\s+', ' ', text).strip()
@@ -57,9 +57,9 @@ def predict(title, authors, abstract):
57
  #selected = [d for d in validation_data if d['score'] >= score]
58
  #true_positives = sum(1 for d in selected if d['label'] == 1)
59
  #precision = true_positives / len(selected) if selected else 0
60
- threshold, precision = get_threshold_precision(score)
61
 
62
- result = f"Your score: {score:.2f}.\nFor papers with score >= {threshold:.2f}, {precision * 100:.2f}% are selected by AK."
63
 
64
  return score, result
65
 
 
13
  model.cuda()
14
 
15
  validation_results = json.load(open('validation_results.json'))
16
+ scores, thresholds, precisions, recalls = validation_results['scores'], validation_results['thresholds'], validation_results['precisions'], validation_results['recalls']
17
 
18
  def get_threshold_precision(score_):
19
+ for score, threshold, precision, recall in zip(scores, thresholds, precisions, recalls):
20
  if score_ < score:
21
  break
22
+ prev_score, prev_threshold, prev_precision, prev_recall = score, threshold, precision, recall
23
  if prev_threshold == prev_score:
24
  prev_threshold = score_
25
+ return prev_threshold, prev_precision, prev_recall
26
 
27
  def normalize_spaces(text):
28
  return re.sub(r'\s+', ' ', text).strip()
 
57
  #selected = [d for d in validation_data if d['score'] >= score]
58
  #true_positives = sum(1 for d in selected if d['label'] == 1)
59
  #precision = true_positives / len(selected) if selected else 0
60
+ threshold, precision, recall = get_threshold_precision(score)
61
 
62
+ result = f"Your score: {score:.2f}.\nFor papers with score>={threshold:.2f}, {precision * 100:.2f}% are selected by AK.\nFor papers selected by AK, {recall * 100:.2f}% have score>={threshold:.2f}"
63
 
64
  return score, result
65