yhavinga committed on
Commit
a19a543
β€’
1 Parent(s): 3f553b1

Split sentences for models that require it. Default beams to 1

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. generator.py +9 -2
app.py CHANGED
@@ -66,9 +66,9 @@ It was a quite young girl, unknown to me, with a hood over her head, and with la
66
  st.session_state["text"] = st.text_area(
67
  "Enter text", st.session_state.prompt_box, height=300
68
  )
69
- num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=6)
70
  num_beam_groups = st.sidebar.number_input(
71
- "Num beam groups", min_value=1, max_value=10, value=3
72
  )
73
  length_penalty = st.sidebar.number_input(
74
  "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
@@ -97,7 +97,7 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
97
  time_end = time.time()
98
  time_diff = time_end - time_start
99
 
100
- st.write(result[0].replace("\n", " \n"))
101
  text_line = ", ".join([f"{k}={v}" for k, v in params_used.items()])
102
  st.markdown(f" πŸ•™ *generated in {time_diff:.2f}s, `{text_line}`*")
103
 
 
66
  st.session_state["text"] = st.text_area(
67
  "Enter text", st.session_state.prompt_box, height=300
68
  )
69
+ num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
70
  num_beam_groups = st.sidebar.number_input(
71
+ "Num beam groups", min_value=1, max_value=10, value=1
72
  )
73
  length_penalty = st.sidebar.number_input(
74
  "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
 
97
  time_end = time.time()
98
  time_diff = time_end - time_start
99
 
100
+ st.write(result.replace("\n", " \n"))
101
  text_line = ", ".join([f"{k}={v}" for k, v in params_used.items()])
102
  st.markdown(f" πŸ•™ *generated in {time_diff:.2f}s, `{text_line}`*")
103
 
generator.py CHANGED
@@ -43,6 +43,7 @@ class Generator:
43
  self.model_name = model_name
44
  self.task = task
45
  self.desc = desc
 
46
  self.tokenizer = None
47
  self.model = None
48
  self.prefix = ""
@@ -92,8 +93,14 @@ class Generator:
92
  def generate(self, text: str, **generate_kwargs) -> (str, dict):
93
  # Replace two or more newlines with a single newline in text
94
  text = re.sub(r"\n{2,}", "\n", text)
95
-
96
  generate_kwargs = {**self.gen_kwargs, **generate_kwargs}
 
 
 
 
 
 
 
97
  batch_encoded = self.tokenizer(
98
  self.prefix + text,
99
  max_length=generate_kwargs["max_length"],
@@ -115,7 +122,7 @@ class Generator:
115
  pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
116
  for pred in decoded_preds
117
  ]
118
- return decoded_preds, generate_kwargs
119
 
120
  def __str__(self):
121
  return self.desc
 
43
  self.model_name = model_name
44
  self.task = task
45
  self.desc = desc
46
+ self.split_sentences = split_sentences
47
  self.tokenizer = None
48
  self.model = None
49
  self.prefix = ""
 
93
  def generate(self, text: str, **generate_kwargs) -> (str, dict):
94
  # Replace two or more newlines with a single newline in text
95
  text = re.sub(r"\n{2,}", "\n", text)
 
96
  generate_kwargs = {**self.gen_kwargs, **generate_kwargs}
97
+
98
+ # if there are newlines in the text, and the model needs line-splitting, split the text
99
+ if re.search(r"\n", text) and self.split_sentences:
100
+ lines = text.splitlines()
101
+ translated = [self.generate(line, **generate_kwargs)[0] for line in lines]
102
+ return "\n".join(translated), generate_kwargs
103
+
104
  batch_encoded = self.tokenizer(
105
  self.prefix + text,
106
  max_length=generate_kwargs["max_length"],
 
122
  pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
123
  for pred in decoded_preds
124
  ]
125
+ return decoded_preds[0], generate_kwargs
126
 
127
  def __str__(self):
128
  return self.desc