NoticIA-demo / app.py
Iker's picture
Update app.py
b23da3d verified
import datetime
import gradio as gr
import torch
from cache_system import CacheHandler
from header import article, header
from newspaper import Article
from prompts import summarize_clickbait_short_prompt
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
LogitsProcessorList,
TextStreamer,
)
from utils import StopAfterTokenIsGenerated
total_runs = 0
# Cargar el tokenizador
tokenizer = AutoTokenizer.from_pretrained("somosnlp/NoticIA-7B")
# Cargamos el modelo en 4 bits para usar menos VRAM
# Usamos bitsandbytes por que es lo más sencillo de implementar para la demo aunque no es ni lo más rápido ni lo más eficiente
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
"somosnlp/NoticIA-7B",
torch_dtype=torch.bfloat16,
device_map="auto",
#quantization_config=quantization_config,
)
print(f"Model loaded in {model.device}")
# Parámetros de generación.
generation_config = GenerationConfig(
max_new_tokens=128, # Los resúmenes son cortos, no necesitamos más tokens
min_new_tokens=1, # No queremos resúmenes vacíos
do_sample=True, # Un poquito mejor que greedy sampling
num_beams=1,
use_cache=True, # Eficiencia
top_k=40,
top_p=0.1,
repetition_penalty=1.1, # Ayuda a evitar que el modelo entre en bucles
encoder_repetition_penalty=1.1, # Favorecemos que el modelo cite el texto original
temperature=0.15, # temperature baja para evitar que el modelo genere texto muy creativo.
)
# Stop words, para evitar que el modelo genere tokens que no queremos.
stop_words = [
"<s>",
"</s>",
"\\n",
"[/INST]",
"[INST]",
"### User:",
"### Assistant:",
"###",
"<start_of_turn>",
"<end_of_turn>",
"<end_of_turn>\\n",
"<eos>",
]
# Creamos un logits processor para detener la generación cuando el modelo genere un stop word
stop_criteria = LogitsProcessorList(
[
StopAfterTokenIsGenerated(
stops=[
torch.tensor(tokenizer.encode(stop_word, add_special_tokens=False))
for stop_word in stop_words.copy()
],
eos_token_id=tokenizer.eos_token_id,
)
]
)
def generate_text(url: str) -> (str, str):
"""
Dada una URL de una noticia, genera un resumen de una sola frase que revela la verdad detrás del titular.
Args:
url (str): URL de la noticia.
Returns:
str: Titular de la noticia.
str: Resumen de la noticia.
"""
global cache_handler
global total_runs
total_runs += 1
print(f"Total runs: {total_runs}. Last run: {datetime.datetime.now()}")
url = url.strip()
if url.startswith("https://twitter.com/") or url.startswith("https://x.com/"):
yield (
"🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.",
"❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌",
"Error",
)
return (
"🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.",
"❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌",
"Error",
)
# 1) Download the article
# progress(0, desc="🤖 Accediendo a la noticia")
# First, check if the URL is in the cache
headline, text, resumen = cache_handler.get_from_cache(url, 0)
if headline is not None and text is not None and resumen is not None:
yield headline, resumen
return headline, resumen
else:
try:
article = Article(url)
article.download()
article.parse()
headline = article.title
text = article.text
except Exception as e:
print(e)
headline = None
text = None
if headline is None or text is None:
yield (
"🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.",
"❌❌❌ Inténtalo de nuevo ❌❌❌",
"Error",
)
return (
"🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.",
"❌❌❌ Inténtalo de nuevo ❌❌❌",
"Error",
)
# progress(0.5, desc="🤖 Leyendo noticia")
try:
prompt = summarize_clickbait_short_prompt(headline=headline, body=text)
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
)
model_inputs = tokenizer(
[formatted_prompt], return_tensors="pt", add_special_tokens=False
)
streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True)
model_output = model.generate(
**model_inputs.to(model.device),
streamer=streamer,
generation_config=generation_config,
logits_processor=stop_criteria,
)
yield headline, streamer
resumen = tokenizer.batch_decode(
model_output,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)[0].replace("<|end_of_turn|>", "")
resumen = resumen.split("GPT4 Correct Assistant:")[-1]
except Exception as e:
print(e)
yield (
"🤖 Error en la generación.",
"❌❌❌ Inténtalo de nuevo más tarde ❌❌❌",
"Error",
)
return (
"🤖 Error en la generación.",
"❌❌❌ Inténtalo de nuevo más tarde ❌❌❌",
"Error",
)
cache_handler.add_to_cache(
url=url, title=headline, text=text, summary_type=0, summary=resumen
)
yield headline, resumen
hits, misses, cache_len = cache_handler.get_cache_stats()
print(
f"Hits: {hits}, misses: {misses}, cache length: {cache_len}. Percent hits: {round(hits/(hits+misses)*100,2)}%."
)
return headline, resumen
# Usamos una cache para guardar las últimas URL procesadas
# Los usuarios seguramente introducirán en un mismo día la misma URL varias veces, por que
# diferentes personas querrán ver el resumen de la misma noticia.
# La cache se encarga de guardar los resúmenes de las noticias para que no tengamos que volver a generarlos.
# La cache tiene un tamaño máximo de 1000 elementos, cuando se llena, se elimina el elemento más antiguo.
cache_handler = CacheHandler(max_cache_size=1000)
demo = gr.Interface(
generate_text,
inputs=[
gr.Textbox(
label="🌐 URL de la noticia",
info="Introduce la URL de la noticia que deseas resumir.",
value="https://somosnlp.org/",
interactive=True,
)
],
outputs=[
gr.Textbox(
label="📰 Titular de la noticia",
interactive=False,
placeholder="Aquí aparecerá el título de la noticia",
),
gr.Textbox(
label="🗒️ Resumen",
interactive=False,
placeholder="Aquí aparecerá el resumen de la noticia.",
),
],
# headline="⚔️ Clickbait Fighter! ⚔️",
thumbnail="https://huggingface.co/datasets/Iker/NoticIA/resolve/main/assets/logo.png",
theme="JohnSmith9982/small_and_pretty",
description=header,
article=article,
cache_examples=False,
concurrency_limit=1,
examples=[
"https://www.huffingtonpost.es/virales/le-compra-abrigo-abuela-97nos-reaccion-fantasia.html",
"https://emisorasunidas.com/2023/12/29/que-pasara-el-15-de-enero-de-2024/",
"https://www.huffingtonpost.es/virales/llega-espana-le-llama-atencion-nombres-propios-persona.html",
"https://www.infobae.com/que-puedo-ver/2023/11/19/la-comedia-familiar-y-navidena-que-ya-esta-en-netflix-y-puedes-ver-en-estas-fiestas/",
"https://www.cope.es/n/1610984",
],
submit_btn="Generar resumen",
stop_btn="Detener generación",
clear_btn="Limpiar",
allow_flagging=False,
)
demo.queue(max_size=None)
demo.launch(share=False)