Iker commited on
Commit
12ea223
·
verified ·
1 Parent(s): a3284de

Upload 7 files

Browse files
Files changed (7) hide show
  1. README.md +13 -6
  2. app.py +255 -0
  3. cache_system.py +61 -0
  4. header.py +57 -0
  5. prompts.py +31 -0
  6. requirements.txt +6 -0
  7. utils.py +54 -0
README.md CHANGED
@@ -1,13 +1,20 @@
1
  ---
2
- title: NoticIA Demo
3
- emoji: 🐢
4
- colorFrom: pink
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.24.0
8
- app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
 
 
 
 
 
 
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: NotiCIA
3
+ emoji: 📰
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
 
 
7
  pinned: false
8
  license: apache-2.0
9
+ suggested_hardware: t4-small
10
+ suggested_storage: small
11
+ app_file: app.py
12
+ fullWidth: true
13
+ models:
14
+ - somosnlp/NoticIA-7B
15
+ tags:
16
+ - summarization
17
+ - clickbait
18
  ---
19
 
20
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from cache_system import CacheHandler
6
+ from header import article, header
7
+ from newspaper import Article
8
+ from prompts import summarize_clickbait_short_prompt
9
+ from transformers import (
10
+ AutoModelForCausalLM,
11
+ AutoTokenizer,
12
+ BitsAndBytesConfig,
13
+ GenerationConfig,
14
+ LogitsProcessorList,
15
+ TextStreamer,
16
+ )
17
+ from utils import StopAfterTokenIsGenerated
18
+
19
+ total_runs = 0
20
+
21
+ # Cargar el tokenizador
22
+ tokenizer = AutoTokenizer.from_pretrained("somosnlp/NoticIA-7B")
23
+
24
+ # Cargamos el modelo en 4 bits para usar menos VRAM
25
+ # Usamos bitsandbytes por que es lo más sencillo de implementar para la demo aunque no es ni lo más rápido ni lo más eficiente
26
+ quantization_config = BitsAndBytesConfig(
27
+ load_in_4bit=True,
28
+ bnb_4bit_compute_dtype=torch.bfloat16,
29
+ bnb_4bit_use_double_quant=True,
30
+ )
31
+
32
+ model = AutoModelForCausalLM.from_pretrained(
33
+ "somosnlp/NoticIA-7B",
34
+ torch_dtype=torch.bfloat16,
35
+ device_map="auto",
36
+ quantization_config=quantization_config,
37
+ )
38
+
39
+ # Parámetros de generación.
40
+ generation_config = GenerationConfig(
41
+ max_new_tokens=128, # Los resúmenes son cortos, no necesitamos más tokens
42
+ min_new_tokens=1, # No queremos resúmenes vacíos
43
+ do_sample=True, # Un poquito mejor que greedy sampling
44
+ num_beams=1,
45
+ use_cache=True, # Eficiencia
46
+ top_k=40,
47
+ top_p=0.1,
48
+ repetition_penalty=1.1, # Ayuda a evitar que el modelo entre en bucles
49
+ encoder_repetition_penalty=1.1, # Favorecemos que el modelo cite el texto original
50
+ resumenerature=0.15, # resumeneratura baja para evitar que el modelo genere texto muy creativo.
51
+ )
52
+
53
+ # Stop words, para evitar que el modelo genere tokens que no queremos.
54
+ stop_words = [
55
+ "<s>",
56
+ "</s>",
57
+ "\\n",
58
+ "[/INST]",
59
+ "[INST]",
60
+ "### User:",
61
+ "### Assistant:",
62
+ "###",
63
+ "<start_of_turn>" "<end_of_turn>" "<end_of_turn>\n" "<end_of_turn>\\n",
64
+ "<eos>",
65
+ ]
66
+
67
+ # Creamos un logits processor para detener la generación cuando el modelo genere un stop word
68
+ stop_criteria = LogitsProcessorList(
69
+ [
70
+ StopAfterTokenIsGenerated(
71
+ stops=[
72
+ torch.tensor(tokenizer.encode(stop_word, add_special_tokens=False))
73
+ for stop_word in stop_words.copy()
74
+ ],
75
+ eos_token_id=tokenizer.eos_token_id,
76
+ )
77
+ ]
78
+ )
79
+
80
+
81
+ def generate_text(url: str) -> (str, str):
82
+ """
83
+ Dada una URL de una noticia, genera un resumen de una sola frase que revela la verdad detrás del titular.
84
+
85
+ Args:
86
+ url (str): URL de la noticia.
87
+
88
+ Returns:
89
+ str: Titular de la noticia.
90
+ str: Resumen de la noticia.
91
+ """
92
+ global cache_handler
93
+ global total_runs
94
+
95
+ total_runs += 1
96
+ print(f"Total runs: {total_runs}. Last run: {datetime.datetime.now()}")
97
+
98
+ url = url.strip()
99
+
100
+ if url.startswith("https://twitter.com/") or url.startswith("https://x.com/"):
101
+ yield (
102
+ "🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.",
103
+ "❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌",
104
+ "Error",
105
+ )
106
+ return (
107
+ "🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.",
108
+ "❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌",
109
+ "Error",
110
+ )
111
+
112
+ # 1) Download the article
113
+
114
+ # progress(0, desc="🤖 Accediendo a la noticia")
115
+
116
+ # First, check if the URL is in the cache
117
+ headline, text, resumen = cache_handler.get_from_cache(url, 0)
118
+ if headline is not None and text is not None and resumen is not None:
119
+ yield headline, resumen
120
+ return headline, resumen
121
+ else:
122
+ try:
123
+ article = Article(url)
124
+ article.download()
125
+ article.parse()
126
+ headline = article.title
127
+ text = article.text
128
+ except Exception as e:
129
+ print(e)
130
+ headline = None
131
+ text = None
132
+
133
+ if headline is None or text is None:
134
+ yield (
135
+ "🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.",
136
+ "❌❌❌ Inténtalo de nuevo ❌❌❌",
137
+ "Error",
138
+ )
139
+ return (
140
+ "🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.",
141
+ "❌❌❌ Inténtalo de nuevo ❌❌❌",
142
+ "Error",
143
+ )
144
+
145
+ # progress(0.5, desc="🤖 Leyendo noticia")
146
+
147
+ try:
148
+ prompt = summarize_clickbait_short_prompt(headline=headline, body=text)
149
+
150
+ formatted_prompt = tokenizer.apply_chat_template(
151
+ [{"role": "user", "content": prompt}],
152
+ tokenize=False,
153
+ add_generation_prompt=True,
154
+ )
155
+
156
+ model_inputs = tokenizer(
157
+ [formatted_prompt], return_tensors="pt", add_special_tokens=False
158
+ )
159
+
160
+ streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True)
161
+
162
+ model_output = model.generate(
163
+ **model_inputs.to(model.device),
164
+ streamer=streamer,
165
+ generation_config=generation_config,
166
+ logits_processor=stop_criteria,
167
+ )
168
+
169
+ yield headline, streamer
170
+
171
+ resumen = tokenizer.batch_decode(
172
+ model_output,
173
+ skip_special_tokens=True,
174
+ clean_up_tokenization_spaces=True,
175
+ )[0].replace("<|end_of_turn|>", "")
176
+
177
+ resumen = resumen.split("GPT4 Correct Assistant:")[-1]
178
+
179
+ except Exception as e:
180
+ print(e)
181
+ yield (
182
+ "🤖 Error en la generación.",
183
+ "❌❌❌ Inténtalo de nuevo más tarde ❌❌❌",
184
+ "Error",
185
+ )
186
+ return (
187
+ "🤖 Error en la generación.",
188
+ "❌❌❌ Inténtalo de nuevo más tarde ❌❌❌",
189
+ "Error",
190
+ )
191
+
192
+ cache_handler.add_to_cache(
193
+ url=url, title=headline, text=text, summary_type=0, summary=resumen
194
+ )
195
+ yield headline, resumen
196
+
197
+ hits, misses, cache_len = cache_handler.get_cache_stats()
198
+ print(
199
+ f"Hits: {hits}, misses: {misses}, cache length: {cache_len}. Percent hits: {round(hits/(hits+misses)*100,2)}%."
200
+ )
201
+ return headline, resumen
202
+
203
+
204
+ # Usamos una cache para guardar las últimas URL procesadas
205
+ # Los usuarios seguramente introducirán en un mismo día la misma URL varias veces, por que
206
+ # diferentes personas querrán ver el resumen de la misma noticia.
207
+ # La cache se encarga de guardar los resúmenes de las noticias para que no tengamos que volver a generarlos.
208
+ # La cache tiene un tamaño máximo de 1000 elementos, cuando se llena, se elimina el elemento más antiguo.
209
+ cache_handler = CacheHandler(max_cache_size=1000)
210
+
211
+
212
+ demo = gr.Interface(
213
+ generate_text,
214
+ inputs=[
215
+ gr.Textbox(
216
+ label="🌐 URL de la noticia",
217
+ info="Introduce la URL de la noticia que deseas resumir.",
218
+ value="https://somosnlp.org/",
219
+ interactive=True,
220
+ )
221
+ ],
222
+ outputs=[
223
+ gr.Textbox(
224
+ label="📰 Titular de la noticia",
225
+ interactive=False,
226
+ placeholder="Aquí aparecerá el título de la noticia",
227
+ ),
228
+ gr.Textbox(
229
+ label="🗒️ Resumen",
230
+ interactive=False,
231
+ placeholder="Aquí aparecerá el resumen de la noticia.",
232
+ ),
233
+ ],
234
+ # headline="⚔️ Clickbait Fighter! ⚔️",
235
+ thumbnail="https://huggingface.co/datasets/Iker/NoticIA/resolve/main/assets/logo.png",
236
+ theme="JohnSmith9982/small_and_pretty",
237
+ description=header,
238
+ article=article,
239
+ cache_examples=False,
240
+ concurrency_limit=1,
241
+ examples=[
242
+ "https://www.huffingtonpost.es/virales/le-compra-abrigo-abuela-97nos-reaccion-fantasia.html",
243
+ "https://emisorasunidas.com/2023/12/29/que-pasara-el-15-de-enero-de-2024/",
244
+ "https://www.huffingtonpost.es/virales/llega-espana-le-llama-atencion-nombres-propios-persona.html",
245
+ "https://www.infobae.com/que-puedo-ver/2023/11/19/la-comedia-familiar-y-navidena-que-ya-esta-en-netflix-y-puedes-ver-en-estas-fiestas/",
246
+ "https://www.cope.es/n/1610984",
247
+ ],
248
+ submit_btn="Generar resumen",
249
+ stop_btn="Detener generación",
250
+ clear_btn="Limpiar",
251
+ allow_flagging=False,
252
+ )
253
+
254
+ demo.queue(max_size=None)
255
+ demo.launch(share=False)
cache_system.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from datetime import datetime
3
+ from typing import Optional
4
+
5
+
6
+ class CacheHandler:
7
+ def __init__(self, max_cache_size: int = 1000):
8
+ # Using OrderedDict to maintain the order of insertion for efficient removal of oldest items
9
+ self.cache = OrderedDict()
10
+ self.max_cache_size = max_cache_size
11
+ self.misses = 0
12
+ self.hits = 0
13
+
14
+ def add_to_cache(
15
+ self, url: str, title: str, text: str, summary_type: int, summary: str
16
+ ):
17
+ # If URL already exists, update it and move it to the end to mark it as the most recently used
18
+ if url in self.cache:
19
+ self.cache.move_to_end(url)
20
+ self.cache[url][f"summary_{summary_type}"] = summary
21
+ self.cache[url]["date"] = datetime.now()
22
+ else:
23
+ # Add new entry to the cache
24
+ self.cache[url] = {
25
+ "title": title,
26
+ "text": text,
27
+ "date": datetime.now(),
28
+ "summary_0": summary if summary_type == 0 else None,
29
+ "summary_50": summary if summary_type == 50 else None,
30
+ "summary_100": summary if summary_type == 100 else None,
31
+ }
32
+ # Remove the oldest item if cache exceeds max size
33
+ if len(self.cache) > self.max_cache_size:
34
+ self.cache.move_to_end(
35
+ "https://ikergarcia1996.github.io/Iker-Garcia-Ferrero/"
36
+ ) # This is the default value in the demo, so we don't want to remove it
37
+ self.cache.popitem(last=False) # pop the oldest item
38
+
39
+ def get_from_cache(
40
+ self, url: str, summary_type: int, second_try: bool = False
41
+ ) -> Optional[tuple]:
42
+ if url in self.cache and self.cache[url][f"summary_{summary_type}"] is not None:
43
+ # Move the accessed item to the end to mark it as recently used
44
+ self.cache.move_to_end(url)
45
+ self.hits += 1
46
+ if second_try:
47
+ # In the first try we didn't get the cache hit, probably because it was a shortened URL
48
+ # So me decrease the number of misses, because we got the cache hit in the end
49
+ self.misses -= 1
50
+ return (
51
+ self.cache[url]["title"],
52
+ self.cache[url]["text"],
53
+ self.cache[url][f"summary_{summary_type}"],
54
+ )
55
+ else:
56
+ if not second_try:
57
+ self.misses += 1
58
+ return None, None, None
59
+
60
+ def get_cache_stats(self):
61
+ return self.hits, self.misses, len(self.cache)
header.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ header = """
2
+
3
+
4
+ <p align="center">
5
+ <img src="https://huggingface.co/datasets/Iker/NoticIA/resolve/main/assets/logo.png" style="width: 50%;">
6
+ </p>
7
+
8
+ <p align="justify">
9
+ Los artículos 🖱️Clickbait buscan captar la atención de los lectores mediante la curiosidad, utilizando titulares que plantean preguntas o afirmaciones incompletas, sensacionalistas, exageradas o directamente engañosas. Estos titulares a menudo esconden la respuesta al clickbait hasta el final del artículo, obligando al lector a desplazarse a través de un sinfín de contenido irrelevante. El verdadero objetivo es atraer visitas a la página para exponer al usuario a una cantidad máxima de publicidad, sacrificando la calidad y el valor informativo en el proceso.
10
+ </p>
11
+
12
+ ### ¿Por qué representa un problema?
13
+ <p align="justify">
14
+ La práctica del 🖱️Clickbait erosiona la confianza del público en las fuentes de noticias digitales y perjudica los ingresos publicitarios de los productores de contenido legítimo, que pueden experimentar una disminución en su tráfico web como resultado.
15
+ </p>
16
+
17
+ ### ¿Qué acciones hemos tomado para abordar este desafío?
18
+
19
+ - 📰 Hemos desarrollado NoticIA, una colección que incluye 850 artículos de noticias en español caracterizados por titulares clickbait. Cada artículo está acompañado de un resumen generativo de alta calidad y concisión, redactado por expertos humanos. Explora [🤗NoticIA-it](https://huggingface.co/datasets/somosnlp/NoticIA-it).
20
+ - 📈 Evaluamos decenas de modelos de inteligencia artificial en este conjunto de datos. Los resultados se pueden consultar aquí: [NoticIA Benchmark](https://huggingface.co/somosnlp/Resumen_Noticias_Clickbait/resolve/main/Results_finetune.png).
21
+ - 🤖 Entrenamos un avanzado modelo de lenguaje con 7 billones de parámetros específicamente con nuestro dataset, [🤗NoticIA-7B](https://huggingface.co/somosnlp/NoticIA-7B).
22
+
23
+ <p align="justify">
24
+ NoticIA ofrece un escenario ideal para probar la habilidad de los modelos de lenguaje en la comprensión de textos en español. Esta tarea es compleja que discernir la pregunta oculta en un titular clickbait o identificar la información que realmente busca el usuario. Este reto implica filtrar grandes volúmenes de contenido superfluo para hallar y resumir de manera precisa y sucinta la información relevante.
25
+ </p>
26
+
27
+ ## ¿Cómo funciona esta demo?
28
+ <p align="justify">
29
+ Solo introduce la URL de un artículo clickbait en el campo de texto y haz clic en el botón "Generar resumen" para probarla.
30
+ </p>
31
+
32
+ ## Mirando hacia el futuro
33
+ - 📚 Planeamos expandir NoticIA con aún más artículos clickbait.
34
+ - 🔮 Introduciremos etiquetas adicionales al conjunto de datos, incluyendo métricas que cuantifiquen el grado de clickbait de los artículos.
35
+ - 📔 Estamos preparando un artículo para profundizar en los hallazgos y metodologías de nuestro proyecto.
36
+
37
+ """.strip()
38
+
39
+
40
+ article = """
41
+
42
+ Esta demo ha sido creada por [Iker García-Ferrero](https://ikergarcia1996.github.io/Iker-Garcia-Ferrero/) y [Begoña Altuna](https://www.linkedin.com/in/bego%C3%B1a-altuna-78014139). Somos investigadores en PLN en la Universidad del País Vasco, dentro del grupo de investigación [IXA](https://www.ixa.eus/) y formamos parte de [HiTZ, el Centro Vasco de Tecnología de la Lengua](https://www.hitz.eus/es).
43
+
44
+
45
+ <div style="display: flex; justify-content: space-around; width: 100%;">
46
+ <div style="width: 50%;" align="left">
47
+ <a href="http://ixa.si.ehu.es/">
48
+ <img src="https://raw.githubusercontent.com/ikergarcia1996/Iker-Garcia-Ferrero/master/icons/ixa.png" width="50" height="50" alt="Ixa NLP Group">
49
+ </a>
50
+ </div>
51
+ <div style="width: 50%;" align="right">
52
+ <a href="http://www.hitz.eus/">
53
+ <img src="https://raw.githubusercontent.com/ikergarcia1996/Iker-Garcia-Ferrero/master/icons/Hitz.png" width="300" height="50" alt="HiTZ Basque Center for Language Technologies">
54
+ </a>
55
+ </div>
56
+ </div>
57
+ """.strip()
prompts.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def summarize_clickbait_short_prompt(
2
+ headline: str,
3
+ body: str,
4
+ ) -> str:
5
+ """
6
+ Generate the prompt for the model.
7
+
8
+ Args:
9
+ headline (`str`):
10
+ The headline of the article.
11
+ body (`str`):
12
+ The body of the article.
13
+ Returns:
14
+ `str`: The formatted prompt.
15
+ """
16
+
17
+ return (
18
+ f"Ahora eres una Inteligencia Artificial experta en desmontar titulares sensacionalistas o clickbait. "
19
+ f"Tu tarea consiste en analizar noticias con titulares sensacionalistas y "
20
+ f"generar un resumen de una sola frase que revele la verdad detrás del titular.\n"
21
+ f"Este es el titular de la noticia: {headline}\n"
22
+ f"El titular plantea una pregunta o proporciona información incompleta. "
23
+ f"Debes buscar en el cuerpo de la noticia una frase que responda lo que se sugiere en el título. "
24
+ f"Siempre que puedas cita el texto original, especialmente si se trata de una frase que alguien ha dicho. "
25
+ f"Si citas una frase que alguien ha dicho, usa comillas para indicar que es una cita. "
26
+ f"Usa siempre las mínimas palabras posibles. No es necesario que la respuesta sea una oración completa. "
27
+ f"Puede ser sólo el foco de la pregunta. "
28
+ f"Recuerda responder siempre en Español.\n"
29
+ f"Este es el cuerpo de la noticia:\n"
30
+ f"{body}"
31
+ )
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ setuptools
2
+ gradio
3
+ transformers
4
+ numpy
5
+ bitsandbytes
6
+ newspaper3k
utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List
3
+
4
+ import torch
5
+ from transformers import (
6
+ LogitsProcessor,
7
+ )
8
+
9
+
10
+ class StopAfterTokenIsGenerated(LogitsProcessor):
11
+ def __init__(self, stops: List[torch.tensor], eos_token_id: int):
12
+ super().__init__()
13
+
14
+ self.stops = stops
15
+ self.eos_token_id = eos_token_id
16
+ logging.info(f"Stopping criteria words ids: {self.stops}")
17
+ self.first_batch = True
18
+
19
+ def __call__(
20
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
21
+ ) -> torch.FloatTensor:
22
+ """
23
+ Args:
24
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
25
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
26
+ scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
27
+ Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
28
+ search or log softmax for each vocabulary token when using beam search
29
+
30
+ Return:
31
+ `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
32
+
33
+ """
34
+ if self.first_batch:
35
+ self.first_batch = False
36
+ return scores
37
+
38
+ for seq_no, seq in enumerate(input_ids):
39
+ # logging.info(seq_no)
40
+ for stop in self.stops:
41
+ stop = stop.to(device=seq.device, dtype=seq.dtype)
42
+ if (
43
+ len(seq) >= len(stop)
44
+ and torch.all((stop == seq[-len(stop) :])).item()
45
+ ):
46
+ scores[seq_no, :] = -float("inf")
47
+ scores[seq_no, self.eos_token_id] = 0
48
+ # logging.info(f"Stopping criteria found: {stop}")
49
+ break
50
+
51
+ return scores
52
+
53
+ def reset(self):
54
+ self.first_batch = True