Thomas De Decker committed
Commit · f2f4fc6 · 1 Parent(s): 099b1c5

Add max input length
app.py CHANGED

@@ -12,9 +12,9 @@ from pipelines.keyphrase_generation_pipeline import KeyphraseGenerationPipeline
 @st.cache(allow_output_mutation=True, show_spinner=False)
 def load_pipeline(chosen_model):
     if "keyphrase-extraction" in chosen_model:
-        return KeyphraseExtractionPipeline(chosen_model)
+        return KeyphraseExtractionPipeline(chosen_model, truncation=True)
     elif "keyphrase-generation" in chosen_model:
-        return KeyphraseGenerationPipeline(chosen_model)
+        return KeyphraseGenerationPipeline(chosen_model, truncation=True)
 
 
 def extract_keyphrases():
@@ -159,7 +159,12 @@ with st.form("keyphrase-extraction-form"):
     )
 
     st.session_state.input_text = (
-        st.text_area(
+        st.text_area(
+            "✍ Input",
+            st.session_state.config.get("example_text"),
+            height=250,
+            max_chars=2500,
+        )
         .replace("\n", " ")
         .strip()
     )
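For context, a minimal sketch of the input cap added above, assuming only the stock Streamlit API; the example text and variable names below are illustrative, not taken from the Space's config. max_chars stops the widget at 2,500 characters, and the chained replace/strip mirrors the cleanup app.py applies before handing the text to a pipeline.

    import streamlit as st

    # Illustrative stand-in for the Space's example_text config value.
    EXAMPLE_TEXT = "Paste or type the document you want keyphrases for."

    # max_chars caps the widget at 2,500 characters; height only sizes the box.
    text = st.text_area("✍ Input", EXAMPLE_TEXT, height=250, max_chars=2500)

    # Same cleanup as app.py: fold newlines into spaces and trim the ends.
    text = text.replace("\n", " ").strip()
    st.write(f"{len(text)} characters will be sent to the selected pipeline.")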
pipelines/keyphrase_extraction_pipeline.py CHANGED

@@ -11,9 +11,7 @@ class KeyphraseExtractionPipeline(TokenClassificationPipeline):
     def __init__(self, model, *args, **kwargs):
         super().__init__(
             model=AutoModelForTokenClassification.from_pretrained(model),
-            tokenizer=AutoTokenizer.from_pretrained(
-                model, truncate=True
-            ),
+            tokenizer=AutoTokenizer.from_pretrained(model),
             *args,
             **kwargs
         )
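A likely reason for dropping the flag: `truncate` is not the name of the tokenizer's truncation switch. Truncation is requested per call (as `truncation=True`), not when the tokenizer is loaded, so the old keyword presumably had no effect. A minimal sketch, using a generic checkpoint name purely for illustration:

    from transformers import AutoTokenizer

    # "bert-base-cased" is only a stand-in checkpoint for this sketch.
    tok = AutoTokenizer.from_pretrained("bert-base-cased")

    long_text = "keyphrase extraction " * 1000
    uncut = tok(long_text)["input_ids"]                 # full length, no truncation
    cut = tok(long_text, truncation=True)["input_ids"]  # clipped to tok.model_max_length
    print(len(uncut), len(cut), tok.model_max_length)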
pipelines/keyphrase_generation_pipeline.py CHANGED

@@ -1,14 +1,17 @@
 import string
 
-from transformers import (
-    …
+from transformers import (
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    Text2TextGenerationPipeline,
+)
 
 
 class KeyphraseGenerationPipeline(Text2TextGenerationPipeline):
     def __init__(self, model, keyphrase_sep_token=";", *args, **kwargs):
         super().__init__(
             model=AutoModelForSeq2SeqLM.from_pretrained(model),
-            tokenizer=AutoTokenizer.from_pretrained(model
+            tokenizer=AutoTokenizer.from_pretrained(model),
             *args,
             **kwargs
         )
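As a sanity check on the new call pattern, here is a minimal sketch of constructing a Text2TextGenerationPipeline directly and passing truncation=True, which the pipeline accepts as a preprocessing argument and forwards to the tokenizer so over-long inputs are clipped rather than rejected. "t5-small" is a stand-in checkpoint; the Space's actual keyphrase-generation model is chosen at runtime in app.py.

    from transformers import (
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        Text2TextGenerationPipeline,
    )

    name = "t5-small"  # stand-in checkpoint for this sketch
    pipe = Text2TextGenerationPipeline(
        model=AutoModelForSeq2SeqLM.from_pretrained(name),
        tokenizer=AutoTokenizer.from_pretrained(name),
        truncation=True,  # forwarded to the tokenizer call, so long inputs are clipped
    )

    # A deliberately over-long input: without truncation this would exceed the
    # model's maximum sequence length.
    print(pipe("summarize: " + "long article text " * 400)[0]["generated_text"])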