Intae commited on
Commit
17b7b4b
·
1 Parent(s): 6b01608

Fix app.py

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. app.py +43 -10
  3. best.data-00000-of-00001 +2 -2
  4. best.index +1 -1
  5. requirements.txt +0 -1
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .idea
app.py CHANGED
@@ -2,12 +2,10 @@ import os
2
  import numpy as np
3
  import tensorflow as tf
4
  import pandas as pd
 
5
  import time
6
 
7
  from recommenders.models.sasrec.model import SASREC
8
- from tabulate import tabulate
9
-
10
- import streamlit as st
11
 
12
 
13
  class SASREC_Vessl(SASREC):
@@ -56,6 +54,16 @@ class SASREC_Vessl(SASREC):
56
  return predictions
57
 
58
 
 
 
 
 
 
 
 
 
 
 
59
  def load_model():
60
  model_config = {
61
  "MAXLEN": 50,
@@ -81,19 +89,44 @@ def load_model():
81
  num_neg_test=model_config.get("NUM_NEG_TEST"),
82
  )
83
 
84
- if os.path.isfile('best.index') and os.path.isfile('best.data-00000-of-00001'):
 
85
  model.load_weights('best').expect_partial()
86
 
87
  return model
88
 
89
 
90
- def main():
91
- st.title('Self-Attentive Sequential Recommendation(SASRec)')
92
- model = load_model()
93
- st.write(model)
94
 
95
- numbers = st.text_input
96
- st.write(numbers)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
 
99
  if __name__ == '__main__':
 
2
  import numpy as np
3
  import tensorflow as tf
4
  import pandas as pd
5
+ import streamlit as st
6
  import time
7
 
8
  from recommenders.models.sasrec.model import SASREC
 
 
 
9
 
10
 
11
  class SASREC_Vessl(SASREC):
 
54
  return predictions
55
 
56
 
57
+ def elapsed_time(fn, *args):
58
+ start = time.time()
59
+ output = fn(*args)
60
+ end = time.time()
61
+
62
+ elapsed = f'{end - start:.2f}'
63
+
64
+ return elapsed, output
65
+
66
+
67
  def load_model():
68
  model_config = {
69
  "MAXLEN": 50,
 
89
  num_neg_test=model_config.get("NUM_NEG_TEST"),
90
  )
91
 
92
+ if os.path.isfile('best.index') and os.path.isfile(
93
+ 'best.data-00000-of-00001'):
94
  model.load_weights('best').expect_partial()
95
 
96
  return model
97
 
98
 
99
+ def postprocess_data(data):
100
+ predictions = -1 * data
101
+ rec_items = predictions.argsort()[:5]
 
102
 
103
+ dic_result = {
104
+ "Rank": [i for i in range(1, 6)],
105
+ "ItemID": list(rec_items + 1),
106
+ "Similarity Score": -1 * predictions[rec_items]
107
+ }
108
+ result = pd.DataFrame(dic_result)
109
+
110
+ time.sleep(0.5)
111
+
112
+ best_item = rec_items[0] + 1
113
+
114
+ return result, best_item
115
+
116
+
117
+ def main():
118
+ st.title("Self-Attentive Sequential Recommendation(SASRec)")
119
+ elapsed, model = elapsed_time(load_model)
120
+ st.write(f"Model is loaded in {elapsed} seconds!")
121
+
122
+ numbers = st.text_input(
123
+ label="Please write input items separated by comma. (e.g. 80, 70, 100, 1)")
124
+ if numbers:
125
+ integer_numbers = np.array(list(map(int, numbers.split(","))))
126
+ result = model.predict_next(integer_numbers)
127
+ table, best_item = postprocess_data(result)
128
+ st.table(table)
129
+ st.write(f"Best item is {best_item}")
130
 
131
 
132
  if __name__ == '__main__':
best.data-00000-of-00001 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:214d4120e8eb750023ccec9ef22d615cd0d9ae6028b605c622a2d452eec491b7
3
- size 5389888
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:faac3d239691b156d73bf8d6ae5fcfc7fdd9c1fe7952c43501c30109bc9a36d8
3
+ size 5390136
best.index CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:90232741029421cb9a2bf31e3eff972467f8613017d1af46ac0f36225bb34f9c
3
  size 2055
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d507d7b032877b335818718a0381ef2028512e01a90570c47ab3f678a75e5076
3
  size 2055
requirements.txt CHANGED
@@ -1,3 +1,2 @@
1
  tensorflow
2
- tabulate
3
  recommenders
 
1
  tensorflow
 
2
  recommenders