Spaces:
Runtime error
Runtime error
- app.py +2 -11
- lrt/utils/functions.py +22 -1
- requirements.txt +2 -1
app.py
CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
|
|
2 |
from widgets import *
|
3 |
|
4 |
|
5 |
-
|
6 |
|
7 |
# sidebar content
|
8 |
platforms, number_papers, start_year, end_year, hyperparams = render_sidebar()
|
@@ -29,16 +29,7 @@ if submitted:
|
|
29 |
show_preview, start_year, end_year,
|
30 |
hyperparams,
|
31 |
hyperparams['standardization'])
|
32 |
-
|
33 |
-
# bar = (
|
34 |
-
# Bar()
|
35 |
-
# .add_xaxis(["Cluster 1", "Cluster 2", "Cluster 3", 'Cluster 4', 'Cluster 5'])
|
36 |
-
# .add_yaxis("numbers", [23, 16, 13, 12, 5])
|
37 |
-
# .set_global_opts(title_opts=opts.TitleOpts(title="Fake Data"))
|
38 |
-
# )
|
39 |
-
#
|
40 |
-
# components.html(generate_html_pyecharts(bar, 'tmp.html'), height=500, width=1000)
|
41 |
-
# '''
|
42 |
|
43 |
|
44 |
|
|
|
2 |
from widgets import *
|
3 |
|
4 |
|
5 |
+
|
6 |
|
7 |
# sidebar content
|
8 |
platforms, number_papers, start_year, end_year, hyperparams = render_sidebar()
|
|
|
29 |
show_preview, start_year, end_year,
|
30 |
hyperparams,
|
31 |
hyperparams['standardization'])
|
32 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
|
35 |
|
lrt/utils/functions.py
CHANGED
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,Text2TextGeneratio
|
|
7 |
from inference_hf import InferenceHF
|
8 |
from .dimension_reduction import PCA
|
9 |
from unsupervised_learning.clustering import GaussianMixture
|
|
|
10 |
|
11 |
class Template:
|
12 |
def __init__(self):
|
@@ -114,7 +115,7 @@ def __create_model__(model_ckpt):
|
|
114 |
|
115 |
return ret
|
116 |
|
117 |
-
elif model_ckpt == '
|
118 |
model_ckpt = template.keywords_extraction[model_ckpt]
|
119 |
def ret(texts: List[str]):
|
120 |
# first try inference API
|
@@ -154,6 +155,26 @@ def __create_model__(model_ckpt):
|
|
154 |
return results
|
155 |
|
156 |
return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
else:
|
158 |
raise RuntimeError(f'The model {model_ckpt} is not supported. Please open an issue on the GitHub about the model.')
|
159 |
|
|
|
7 |
from inference_hf import InferenceHF
|
8 |
from .dimension_reduction import PCA
|
9 |
from unsupervised_learning.clustering import GaussianMixture
|
10 |
+
from models import KeyBartAdapter
|
11 |
|
12 |
class Template:
|
13 |
def __init__(self):
|
|
|
115 |
|
116 |
return ret
|
117 |
|
118 |
+
elif model_ckpt == 'KeyBart':
|
119 |
model_ckpt = template.keywords_extraction[model_ckpt]
|
120 |
def ret(texts: List[str]):
|
121 |
# first try inference API
|
|
|
155 |
return results
|
156 |
|
157 |
return ret
|
158 |
+
|
159 |
+
elif model_ckpt == 'KeyBartAdapter':
|
160 |
+
def ret(texts: List[str]):
|
161 |
+
model = KeyBartAdapter.from_pretrained('Adapting/KeyBartAdapter',revision='3aee5ecf1703b9955ab0cd1b23208cc54eb17fce', adapter_hid_dim=32)
|
162 |
+
tokenizer = AutoTokenizer.from_pretrained("bloomberg/KeyBART")
|
163 |
+
pipe = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer)
|
164 |
+
|
165 |
+
tmp = pipe(texts)
|
166 |
+
results = [
|
167 |
+
set(
|
168 |
+
map(str.strip,
|
169 |
+
x['generated_text'].split(';') # [str...]
|
170 |
+
)
|
171 |
+
)
|
172 |
+
for x in tmp] # [{str...}...]
|
173 |
+
|
174 |
+
return results
|
175 |
+
return ret
|
176 |
+
|
177 |
+
|
178 |
else:
|
179 |
raise RuntimeError(f'The model {model_ckpt} is not supported. Please open an issue on the GitHub about the model.')
|
180 |
|
requirements.txt
CHANGED
@@ -11,4 +11,5 @@ transformers==4.22.1
|
|
11 |
textdistance==4.5.0
|
12 |
datasets==2.5.2
|
13 |
bokeh==2.4.1
|
14 |
-
ml-leoxiang66
|
|
|
|
11 |
textdistance==4.5.0
|
12 |
datasets==2.5.2
|
13 |
bokeh==2.4.1
|
14 |
+
ml-leoxiang66
|
15 |
+
KeyBartAdapter
|