yuntian-deng commited on
Commit
40eb9ab
·
verified ·
1 Parent(s): 40b904c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -12,14 +12,21 @@ model.eval()
12
  if torch.cuda.is_available():
13
  model.cuda()
14
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def normalize_spaces(text):
17
  return re.sub(r'\s+', ' ', text).strip()
18
 
19
- # Load your validation set
20
- #with open('validation_data.json', 'r') as file:
21
- # validation_data = json.load(file)
22
-
23
  def fill_template(title, authors, abstract):
24
  title = normalize_spaces(title.replace('\n', ' '))
25
  authors = ', '.join([author.strip() for author in authors.split(',')])
@@ -50,9 +57,9 @@ def predict(title, authors, abstract):
50
  #selected = [d for d in validation_data if d['score'] >= score]
51
  #true_positives = sum(1 for d in selected if d['label'] == 1)
52
  #precision = true_positives / len(selected) if selected else 0
53
- precision = 0.2
54
 
55
- result = f"Your score: {score:.2f}.\nFor papers with score >= {score:.2f}, {precision * 100:.2f}% are selected by AK."
56
 
57
  return score, result
58
 
 
12
  if torch.cuda.is_available():
13
  model.cuda()
14
 
15
+ validation_results = json.load(open('validation_results.json'))
16
+ scores, thresholds, precisions = validation_results['scores'], validation_results['thresholds'], validation_results['precisions']
17
+
18
+ def get_threshold_precision(score_):
19
+ for score, threshold, precision in zip(all_scores, thresholds, precisions):
20
+ if score_ < score:
21
+ break
22
+ prev_score, prev_threshold, prev_precision = score, threshold, precision
23
+ if prev_threshold == prev_score:
24
+ prev_threshold = score_
25
+ return prev_threshold, prev_precision
26
 
27
  def normalize_spaces(text):
28
  return re.sub(r'\s+', ' ', text).strip()
29
 
 
 
 
 
30
  def fill_template(title, authors, abstract):
31
  title = normalize_spaces(title.replace('\n', ' '))
32
  authors = ', '.join([author.strip() for author in authors.split(',')])
 
57
  #selected = [d for d in validation_data if d['score'] >= score]
58
  #true_positives = sum(1 for d in selected if d['label'] == 1)
59
  #precision = true_positives / len(selected) if selected else 0
60
+ threshold, precision = get_threshold_precision(score)
61
 
62
+ result = f"Your score: {score:.2f}.\nFor papers with score >= {threshold:.2f}, {precision * 100:.2f}% are selected by AK."
63
 
64
  return score, result
65