Split sentences for models that require it. Default beams to 1
Browse files- app.py +3 -3
- generator.py +9 -2
app.py
CHANGED
@@ -66,9 +66,9 @@ It was a quite young girl, unknown to me, with a hood over her head, and with la
|
|
66 |
st.session_state["text"] = st.text_area(
|
67 |
"Enter text", st.session_state.prompt_box, height=300
|
68 |
)
|
69 |
-
num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=
|
70 |
num_beam_groups = st.sidebar.number_input(
|
71 |
-
"Num beam groups", min_value=1, max_value=10, value=
|
72 |
)
|
73 |
length_penalty = st.sidebar.number_input(
|
74 |
"Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
|
@@ -97,7 +97,7 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
|
|
97 |
time_end = time.time()
|
98 |
time_diff = time_end - time_start
|
99 |
|
100 |
-
st.write(result
|
101 |
text_line = ", ".join([f"{k}={v}" for k, v in params_used.items()])
|
102 |
st.markdown(f" π *generated in {time_diff:.2f}s, `{text_line}`*")
|
103 |
|
|
|
66 |
st.session_state["text"] = st.text_area(
|
67 |
"Enter text", st.session_state.prompt_box, height=300
|
68 |
)
|
69 |
+
num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
|
70 |
num_beam_groups = st.sidebar.number_input(
|
71 |
+
"Num beam groups", min_value=1, max_value=10, value=1
|
72 |
)
|
73 |
length_penalty = st.sidebar.number_input(
|
74 |
"Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
|
|
|
97 |
time_end = time.time()
|
98 |
time_diff = time_end - time_start
|
99 |
|
100 |
+
st.write(result.replace("\n", " \n"))
|
101 |
text_line = ", ".join([f"{k}={v}" for k, v in params_used.items()])
|
102 |
st.markdown(f" π *generated in {time_diff:.2f}s, `{text_line}`*")
|
103 |
|
generator.py
CHANGED
@@ -43,6 +43,7 @@ class Generator:
|
|
43 |
self.model_name = model_name
|
44 |
self.task = task
|
45 |
self.desc = desc
|
|
|
46 |
self.tokenizer = None
|
47 |
self.model = None
|
48 |
self.prefix = ""
|
@@ -92,8 +93,14 @@ class Generator:
|
|
92 |
def generate(self, text: str, **generate_kwargs) -> (str, dict):
|
93 |
# Replace two or more newlines with a single newline in text
|
94 |
text = re.sub(r"\n{2,}", "\n", text)
|
95 |
-
|
96 |
generate_kwargs = {**self.gen_kwargs, **generate_kwargs}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
batch_encoded = self.tokenizer(
|
98 |
self.prefix + text,
|
99 |
max_length=generate_kwargs["max_length"],
|
@@ -115,7 +122,7 @@ class Generator:
|
|
115 |
pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
|
116 |
for pred in decoded_preds
|
117 |
]
|
118 |
-
return decoded_preds, generate_kwargs
|
119 |
|
120 |
def __str__(self):
|
121 |
return self.desc
|
|
|
43 |
self.model_name = model_name
|
44 |
self.task = task
|
45 |
self.desc = desc
|
46 |
+
self.split_sentences = split_sentences
|
47 |
self.tokenizer = None
|
48 |
self.model = None
|
49 |
self.prefix = ""
|
|
|
93 |
def generate(self, text: str, **generate_kwargs) -> (str, dict):
|
94 |
# Replace two or more newlines with a single newline in text
|
95 |
text = re.sub(r"\n{2,}", "\n", text)
|
|
|
96 |
generate_kwargs = {**self.gen_kwargs, **generate_kwargs}
|
97 |
+
|
98 |
+
# if there are newlines in the text, and the model needs line-splitting, split the text
|
99 |
+
if re.search(r"\n", text) and self.split_sentences:
|
100 |
+
lines = text.splitlines()
|
101 |
+
translated = [self.generate(line, **generate_kwargs)[0] for line in lines]
|
102 |
+
return "\n".join(translated), generate_kwargs
|
103 |
+
|
104 |
batch_encoded = self.tokenizer(
|
105 |
self.prefix + text,
|
106 |
max_length=generate_kwargs["max_length"],
|
|
|
122 |
pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
|
123 |
for pred in decoded_preds
|
124 |
]
|
125 |
+
return decoded_preds[0], generate_kwargs
|
126 |
|
127 |
def __str__(self):
|
128 |
return self.desc
|