import datetime import gradio as gr import torch from cache_system import CacheHandler from header import article, header from newspaper import Article from prompts import summarize_clickbait_short_prompt from transformers import ( AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig, LogitsProcessorList, TextStreamer, ) from utils import StopAfterTokenIsGenerated total_runs = 0 # Cargar el tokenizador tokenizer = AutoTokenizer.from_pretrained("somosnlp/NoticIA-7B") # Cargamos el modelo en 4 bits para usar menos VRAM # Usamos bitsandbytes por que es lo más sencillo de implementar para la demo aunque no es ni lo más rápido ni lo más eficiente quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) model = AutoModelForCausalLM.from_pretrained( "somosnlp/NoticIA-7B", torch_dtype=torch.bfloat16, device_map="auto", quantization_config=quantization_config, ) print(f"Model loaded in {model.device}") # Parámetros de generación. generation_config = GenerationConfig( max_new_tokens=128, # Los resúmenes son cortos, no necesitamos más tokens min_new_tokens=1, # No queremos resúmenes vacíos do_sample=True, # Un poquito mejor que greedy sampling num_beams=1, use_cache=True, # Eficiencia top_k=40, top_p=0.1, repetition_penalty=1.1, # Ayuda a evitar que el modelo entre en bucles encoder_repetition_penalty=1.1, # Favorecemos que el modelo cite el texto original resumenerature=0.15, # resumeneratura baja para evitar que el modelo genere texto muy creativo. ) # Stop words, para evitar que el modelo genere tokens que no queremos. stop_words = [ "", "", "\\n", "[/INST]", "[INST]", "### User:", "### Assistant:", "###", "", "", "\\n", "", ] # Creamos un logits processor para detener la generación cuando el modelo genere un stop word stop_criteria = LogitsProcessorList( [ StopAfterTokenIsGenerated( stops=[ torch.tensor(tokenizer.encode(stop_word, add_special_tokens=False)) for stop_word in stop_words.copy() ], eos_token_id=tokenizer.eos_token_id, ) ] ) def generate_text(url: str) -> (str, str): """ Dada una URL de una noticia, genera un resumen de una sola frase que revela la verdad detrás del titular. Args: url (str): URL de la noticia. Returns: str: Titular de la noticia. str: Resumen de la noticia. """ global cache_handler global total_runs total_runs += 1 print(f"Total runs: {total_runs}. Last run: {datetime.datetime.now()}") url = url.strip() if url.startswith("https://twitter.com/") or url.startswith("https://x.com/"): yield ( "🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.", "❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌", "Error", ) return ( "🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.", "❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌", "Error", ) # 1) Download the article # progress(0, desc="🤖 Accediendo a la noticia") # First, check if the URL is in the cache headline, text, resumen = cache_handler.get_from_cache(url, 0) if headline is not None and text is not None and resumen is not None: yield headline, resumen return headline, resumen else: try: article = Article(url) article.download() article.parse() headline = article.title text = article.text except Exception as e: print(e) headline = None text = None if headline is None or text is None: yield ( "🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.", "❌❌❌ Inténtalo de nuevo ❌❌❌", "Error", ) return ( "🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.", "❌❌❌ Inténtalo de nuevo ❌❌❌", "Error", ) # progress(0.5, desc="🤖 Leyendo noticia") try: prompt = summarize_clickbait_short_prompt(headline=headline, body=text) formatted_prompt = tokenizer.apply_chat_template( [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True, ) model_inputs = tokenizer( [formatted_prompt], return_tensors="pt", add_special_tokens=False ) streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True) model_output = model.generate( **model_inputs.to(model.device), streamer=streamer, generation_config=generation_config, logits_processor=stop_criteria, ) yield headline, streamer resumen = tokenizer.batch_decode( model_output, skip_special_tokens=True, clean_up_tokenization_spaces=True, )[0].replace("<|end_of_turn|>", "") resumen = resumen.split("GPT4 Correct Assistant:")[-1] except Exception as e: print(e) yield ( "🤖 Error en la generación.", "❌❌❌ Inténtalo de nuevo más tarde ❌❌❌", "Error", ) return ( "🤖 Error en la generación.", "❌❌❌ Inténtalo de nuevo más tarde ❌❌❌", "Error", ) cache_handler.add_to_cache( url=url, title=headline, text=text, summary_type=0, summary=resumen ) yield headline, resumen hits, misses, cache_len = cache_handler.get_cache_stats() print( f"Hits: {hits}, misses: {misses}, cache length: {cache_len}. Percent hits: {round(hits/(hits+misses)*100,2)}%." ) return headline, resumen # Usamos una cache para guardar las últimas URL procesadas # Los usuarios seguramente introducirán en un mismo día la misma URL varias veces, por que # diferentes personas querrán ver el resumen de la misma noticia. # La cache se encarga de guardar los resúmenes de las noticias para que no tengamos que volver a generarlos. # La cache tiene un tamaño máximo de 1000 elementos, cuando se llena, se elimina el elemento más antiguo. cache_handler = CacheHandler(max_cache_size=1000) demo = gr.Interface( generate_text, inputs=[ gr.Textbox( label="🌐 URL de la noticia", info="Introduce la URL de la noticia que deseas resumir.", value="https://somosnlp.org/", interactive=True, ) ], outputs=[ gr.Textbox( label="📰 Titular de la noticia", interactive=False, placeholder="Aquí aparecerá el título de la noticia", ), gr.Textbox( label="🗒️ Resumen", interactive=False, placeholder="Aquí aparecerá el resumen de la noticia.", ), ], # headline="⚔️ Clickbait Fighter! ⚔️", thumbnail="https://huggingface.co/datasets/Iker/NoticIA/resolve/main/assets/logo.png", theme="JohnSmith9982/small_and_pretty", description=header, article=article, cache_examples=False, concurrency_limit=1, examples=[ "https://www.huffingtonpost.es/virales/le-compra-abrigo-abuela-97nos-reaccion-fantasia.html", "https://emisorasunidas.com/2023/12/29/que-pasara-el-15-de-enero-de-2024/", "https://www.huffingtonpost.es/virales/llega-espana-le-llama-atencion-nombres-propios-persona.html", "https://www.infobae.com/que-puedo-ver/2023/11/19/la-comedia-familiar-y-navidena-que-ya-esta-en-netflix-y-puedes-ver-en-estas-fiestas/", "https://www.cope.es/n/1610984", ], submit_btn="Generar resumen", stop_btn="Detener generación", clear_btn="Limpiar", allow_flagging=False, ) demo.queue(max_size=None) demo.launch(share=False)