import os
import re
import threading

import bm25s
import gradio as gr
import gradio_iframe
import spaces
from bm25s.hf import BM25HF
from citations import (
    inseq_citation,
    inseq_xai_citation,
    lxt_citation,
    mirage_citation,
    pecore_citation,
)
from examples import examples
from lxt.functional import add2, mul2, softmax
from lxt.models.llama import LlamaForCausalLM, attnlrp
from rerankers import Reranker
from style import custom_css
from tqdm import tqdm
from transformers import AutoTokenizer

from inseq import load_model, register_step_function
from inseq.attr import StepFunctionArgs
from inseq.commands.attribute_context import visualize_attribute_context
from inseq.commands.attribute_context.attribute_context import (
    AttributeContextArgs,
    attribute_context_with_model,
)
from inseq.utils.contrast_utils import _setup_contrast_args

# Currently loaded Inseq model (set lazily in `generate`).
model = None
ranker = Reranker("answerdotai/answerai-colbert-small-v1", model_type="colbert")
retriever = BM25HF.load_from_hub("xhluca/bm25s-nq-index", load_corpus=True, mmap=True)

# Model registry to store loaded models
model_registry = {}


def get_model(model_size):
    """Load (or fetch from the registry) a SmolLM model wrapped for Inseq attribution."""
    model_id = f"HuggingFaceTB/SmolLM-{model_size}-Instruct"
    if model_id not in model_registry:
        hf_model = LlamaForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Patch the model with AttnLRP rules so LXT can propagate relevance.
        attnlrp.register(hf_model)
        model = load_model(hf_model, "saliency", tokenizer=tokenizer)
        model.bos_token = "<|endoftext|>"
        model.bos_token_id = 0
        model_registry[model_id] = model
    return model_registry[model_id]


def lxt_probability_fn(args: StepFunctionArgs):
    """Probability of the target token, computed with LXT's gradient-safe softmax."""
    logits = args.attribution_model.output2logits(args.forward_output)
    target_ids = args.target_ids.reshape(logits.shape[0], 1).to(logits.device)
    probs = softmax(logits, dim=-1)
    return probs.gather(-1, target_ids).squeeze(-1)


def lxt_contrast_prob_fn(
    args: StepFunctionArgs,
    contrast_sources=None,
    contrast_targets=None,
    contrast_targets_alignments: list[list[tuple[int, int]]] | None = None,
    contrast_force_inputs: bool = False,
    skip_special_tokens: bool = False,
):
    """Probability of the target token under a contrastive (e.g. contextless) input."""
    c_args = _setup_contrast_args(
        args,
        contrast_sources=contrast_sources,
        contrast_targets=contrast_targets,
        contrast_targets_alignments=contrast_targets_alignments,
        contrast_force_inputs=contrast_force_inputs,
        skip_special_tokens=skip_special_tokens,
    )
    return lxt_probability_fn(c_args)


def lxt_contrast_prob_diff_fn(
    args: StepFunctionArgs,
    contrast_sources=None,
    contrast_targets=None,
    contrast_targets_alignments: list[list[tuple[int, int]]] | None = None,
    contrast_force_inputs: bool = False,
    skip_special_tokens: bool = False,
):
    """Difference between regular and contrastive target probabilities, combined with
    LXT's add2/mul2 primitives so relevance propagation stays consistent."""
    model_probs = lxt_probability_fn(args)
    contrast_probs = lxt_contrast_prob_fn(
        args=args,
        contrast_sources=contrast_sources,
        contrast_targets=contrast_targets,
        contrast_targets_alignments=contrast_targets_alignments,
        contrast_force_inputs=contrast_force_inputs,
        skip_special_tokens=skip_special_tokens,
    ).to(model_probs.device)
    return add2(model_probs, mul2(contrast_probs, -1))
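# A minimal sketch of how the step function above can be invoked directly once it
# has been registered (the `register_step_function` call appears further below).
# The model handle `m` and the query text are illustrative, not part of the demo flow:
#
#   m = get_model("360M")
#   out = m.attribute(
#       "What is the capital of France?",
#       attributed_fn="lxt_contrast_prob_diff",
#   )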
def set_interactive_settings(rag_setting, retrieve_k, top_k, custom_context):
    """Toggle which controls are editable depending on the selected RAG mode."""
    if rag_setting in ("Retrieve with BM25", "Rerank with ColBERT"):
        return (
            gr.Slider(interactive=True),
            gr.Slider(interactive=True),
            gr.Textbox(
                placeholder="Context will be retrieved automatically. "
                "Change mode to 'Use Custom Context' to specify your own.",
                interactive=False,
            ),
        )
    elif rag_setting == "Use Custom Context":
        return (
            gr.Slider(interactive=False),
            gr.Slider(interactive=False),
            gr.Textbox(placeholder="Insert a custom context...", interactive=True),
        )


@spaces.GPU()
def generate(
    query,
    max_new_tokens,
    top_p,
    temperature,
    retrieve_k,
    top_k,
    rag_setting,
    custom_context,
    model_size,
    progress=gr.Progress(track_tqdm=True),
):
    global model
    if rag_setting == "Use Custom Context":
        docs = custom_context.split("\n\n")
        progress(0.1, desc="Using custom context...")
    else:
        if not query:
            raise gr.Error("Please enter a query.")
        progress(0, desc="Retrieving with BM25...")
        q = bm25s.tokenize(query)
        results = retriever.retrieve(q, k=retrieve_k)
        if rag_setting == "Rerank with ColBERT":
            progress(0.1, desc="Reranking with ColBERT...")
            docs = [x["text"] for x in results.documents[0]]
            out = ranker.rank(query=query, docs=docs)
            docs = [out.results[i].document.text for i in range(top_k)]
        else:
            docs = [results.documents[0][i]["text"] for i in range(top_k)]
        # Strip Wikipedia-style reference markers such as [1] from retrieved passages.
        docs = [re.sub(r"\[\d+\]", "", doc) for doc in docs]
    curr_model_id = f"HuggingFaceTB/SmolLM-{model_size}-Instruct"
    if model is None or model.model_name != curr_model_id:
        progress(0.2, desc="Loading model...")
        model = get_model(model_size)
    estimated_time = 20
    tstep = 1
    lm_rag_prompting_example = AttributeContextArgs(
        model_name_or_path=curr_model_id,
        input_context_text="\n\n".join(docs),
        input_current_text=query,
        output_template="{current}",
        attributed_fn="lxt_contrast_prob_diff",
        input_template="<|im_start|>user\n### Context\n{context}\n\n### Query\n{current}<|im_end|>\n<|im_start|>assistant\n",
        contextless_input_current_text="<|im_start|>user\n### Query\n{current}<|im_end|>\n<|im_start|>assistant\n",
        attribution_method="saliency",
        show_viz=False,
        show_intermediate_outputs=False,
        context_sensitivity_std_threshold=1,
        decoder_input_output_separator=" ",
        special_tokens_to_keep=["<|im_start|>", "<|endoftext|>"],
        generation_kwargs={
            "max_new_tokens": max_new_tokens,
            "top_p": top_p,
            "temperature": temperature,
        },
        attribution_aggregators=["sum"],
        rescale_attributions=True,
        save_path=os.path.join(os.path.dirname(__file__), "outputs/output.json"),
        viz_path=os.path.join(os.path.dirname(__file__), "outputs/output.html"),
    )

    # Run the attribution in a worker thread so the progress bar can tick while it runs.
    ret = [None]

    def run_attribute_context():
        ret[0] = attribute_context_with_model(lm_rag_prompting_example, model)

    thread = threading.Thread(target=run_attribute_context)
    pbar = tqdm(total=estimated_time, desc="Attributing with LXT...")
    thread.start()
    while thread.is_alive():
        thread.join(timeout=tstep)
        pbar.update(tstep)
    pbar.close()
    out = ret[0]
    html = visualize_attribute_context(out, show_viz=False, return_html=True)
    return [
        gradio_iframe.iFrame(html, height=500, visible=True),
        gr.DownloadButton(
            label="📂 Download output",
            value=os.path.join(os.path.dirname(__file__), "outputs/output.json"),
            visible=True,
        ),
        gr.DownloadButton(
            label="🔍 Download HTML",
            value=os.path.join(os.path.dirname(__file__), "outputs/output.html"),
            visible=True,
        ),
    ]


register_step_function(
    lxt_contrast_prob_diff_fn, "lxt_contrast_prob_diff", overwrite=True
)
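# Illustrative, uncalled helper mirroring the retrieval step inside `generate` above,
# kept as a self-contained reference for the BM25 -> ColBERT rerank pipeline. It uses
# the module-level `retriever` and `ranker`; the example query is hypothetical.
def _example_retrieve_and_rerank(query: str, retrieve_k: int = 100, top_k: int = 3) -> list[str]:
    """E.g. `_example_retrieve_and_rerank("who wrote hamlet")` -> top-3 reranked passages."""
    q = bm25s.tokenize(query)
    results = retriever.retrieve(q, k=retrieve_k)
    docs = [x["text"] for x in results.documents[0]]
    out = ranker.rank(query=query, docs=docs)
    return [out.results[i].document.text for i in range(top_k)]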
with gr.Blocks(css=custom_css) as demo:
    with gr.Row():
        with gr.Column(min_width=500):
            gr.HTML(
                ""  # header banner markup elided
            )
            text = gr.Markdown(
                "This demo showcases an end-to-end use of model internals for RAG answer "
                "attribution with the PECoRe framework, as described in our MIRAGE paper. "
                "Insert a query to retrieve relevant contexts, generate an answer, and "
                "attribute its context-sensitive components. An interactive Treescope "
                "visualization will appear in the green square. "
                "📋 Retrieval is performed on Natural Questions using BM25S, with optional "
                "reranking via ColBERT. SmolLM models are used for generation, while Inseq "
                "and LXT are used for attribution. "
                "➡️ For more details, see also our PECoRe Demo.",
            )
    with gr.Row():
        with gr.Column():
            query = gr.Textbox(
                placeholder="Insert a query for the language model...",
                label="Model query",
                interactive=True,
                lines=2,
            )
            btn = gr.Button("Submit", variant="primary")
            attribute_input_examples = gr.Examples(
                examples,
                inputs=[query],
                examples_per_page=2,
            )
            with gr.Accordion("⚙️ Parameters", open=False):
                with gr.Row():
                    model_size = gr.Radio(
                        ["135M", "360M", "1.7B"],
                        value="360M",
                        label="Model size",
                        interactive=True,
                    )
                with gr.Row():
                    rag_setting = gr.Radio(
                        [
                            "Retrieve with BM25",
                            "Rerank with ColBERT",
                            "Use Custom Context",
                        ],
                        value="Rerank with ColBERT",
                        label="Mode",
                        interactive=True,
                    )
                with gr.Row():
                    retrieve_k = gr.Slider(
                        1,
                        500,
                        value=100,
                        step=1,
                        label="# Docs to Retrieve",
                        interactive=True,
                    )
                    top_k = gr.Slider(
                        1,
                        10,
                        value=3,
                        step=1,
                        label="# Docs in Context",
                        interactive=True,
                    )
                custom_context = gr.Textbox(
                    placeholder="Context will be retrieved automatically. "
                    "Change mode to 'Use Custom Context' to specify your own.",
                    label="Custom context",
                    interactive=False,
                    lines=4,
                )
                with gr.Row():
                    max_new_tokens = gr.Slider(
                        0,
                        500,
                        value=50,
                        step=5,
                        label="Max new tokens",
                        interactive=True,
                    )
                    top_p = gr.Slider(
                        0, 1, value=1, step=0.01, label="Top P", interactive=True
                    )
                    temperature = gr.Slider(
                        0, 1, value=0, step=0.01, label="Temperature", interactive=True
                    )
            with gr.Accordion("📝 Citation", open=False):
                gr.Markdown(
                    "Using PECoRe for model internals-based RAG answer attribution is discussed in:"
                )
                gr.Code(
                    mirage_citation,
                    interactive=False,
                    label="MIRAGE (Qi, Sarti et al., 2024)",
                )
                gr.Markdown("To refer to the original PECoRe paper, cite:")
                gr.Code(
                    pecore_citation,
                    interactive=False,
                    label="PECoRe (Sarti et al., 2024)",
                )
                gr.Markdown(
                    "The Inseq implementation used in this work (`inseq attribute-context`, "
                    "including this demo) can be cited with:"
                )
                gr.Code(
                    inseq_citation,
                    interactive=False,
                    label="Inseq (Sarti et al., 2023)",
                )
                gr.Code(
                    inseq_xai_citation,
                    interactive=False,
                    label="Inseq v0.6 (Sarti et al., 2024)",
                )
                gr.Markdown(
                    "The AttnLRP attribution method used in this demo via the LXT library can be cited with:"
                )
                gr.Code(
                    lxt_citation,
                    interactive=False,
                    label="AttnLRP (Achtibat et al., 2024)",
                )
        with gr.Column():
            attribute_context_out = gradio_iframe.iFrame(height=400, visible=True)
            with gr.Row(equal_height=True):
                download_output_file_button = gr.DownloadButton(
                    "📂 Download output",
                    visible=False,
                )
                download_output_html_button = gr.DownloadButton(
                    "🔍 Download HTML",
                    visible=False,
                    value=os.path.join(
                        os.path.dirname(__file__), "outputs/output.html"
                    ),
                )
    with gr.Row(elem_classes="footer-container"):
        with gr.Column():
            gr.Markdown("")  # footer logo markup elided
        with gr.Column():
            with gr.Row(elem_classes="footer-custom-block"):
                with gr.Column(scale=0.30, min_width=150):
                    gr.Markdown(
                        """Built by Gabriele Sarti with the support of"""
                    )
                with gr.Column(scale=0.30, min_width=120):
                    gr.Markdown("")  # supporter logo markup elided
                with gr.Column(scale=0.30, min_width=120):
                    gr.Markdown("")  # supporter logo markup elided

    rag_setting.change(
        fn=set_interactive_settings,
        inputs=[rag_setting, retrieve_k, top_k, custom_context],
        outputs=[retrieve_k, top_k, custom_context],
    )
    btn.click(
        fn=generate,
        inputs=[
            query,
            max_new_tokens,
            top_p,
            temperature,
            retrieve_k,
            top_k,
            rag_setting,
            custom_context,
            model_size,
        ],
        outputs=[
            attribute_context_out,
            download_output_file_button,
            download_output_html_button,
        ],
    )

demo.queue(api_open=False, max_size=20).launch(
    allowed_paths=["img/", "outputs/"], show_api=False
)