import os
import re
import threading

import bm25s
import gradio as gr
import gradio_iframe
import spaces
from bm25s.hf import BM25HF
from citations import inseq_citation, inseq_xai_citation, lxt_citation, mirage_citation, pecore_citation
from examples import examples
from lxt.functional import add2, mul2, softmax
from lxt.models.llama import LlamaForCausalLM, attnlrp
from rerankers import Reranker
from style import custom_css
from tqdm import tqdm
from transformers import AutoTokenizer

from inseq import load_model, register_step_function
from inseq.attr import StepFunctionArgs
from inseq.commands.attribute_context import visualize_attribute_context
from inseq.commands.attribute_context.attribute_context import (
    AttributeContextArgs,
    attribute_context_with_model,
)
from inseq.utils.contrast_utils import _setup_contrast_args

model = None
model_id = "HuggingFaceTB/SmolLM-360M-Instruct"
ranker = Reranker("answerdotai/answerai-colbert-small-v1", model_type="colbert")
retriever = BM25HF.load_from_hub("xhluca/bm25s-nq-index", load_corpus=True, mmap=True)

# Model registry caching loaded models across generations
model_registry = {}


def get_model(model_size):
    model_id = f"HuggingFaceTB/SmolLM-{model_size}-Instruct"
    if model_id not in model_registry:
        hf_model = LlamaForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Patch the model in-place with AttnLRP so that backward passes
        # propagate LRP relevance, which the "saliency" method then reads out.
        attnlrp.register(hf_model)
        model = load_model(hf_model, "saliency", tokenizer=tokenizer)
        model.bos_token = "<|endoftext|>"
        model.bos_token_id = 0
        model_registry[model_id] = model
    return model_registry[model_id]


def lxt_probability_fn(args: StepFunctionArgs):
    """Probability of the current target token, using LXT's LRP-aware softmax."""
    logits = args.attribution_model.output2logits(args.forward_output)
    target_ids = args.target_ids.reshape(logits.shape[0], 1).to(logits.device)
    probs = softmax(logits, dim=-1)
    return probs.gather(-1, target_ids).squeeze(-1)


def lxt_contrast_prob_fn(
    args: StepFunctionArgs,
    contrast_sources=None,
    contrast_targets=None,
    contrast_targets_alignments: list[list[tuple[int, int]]] | None = None,
    contrast_force_inputs: bool = False,
    skip_special_tokens: bool = False,
):
    """Probability of the target token under a contrastive (e.g. contextless) input."""
    c_args = _setup_contrast_args(
        args,
        contrast_sources=contrast_sources,
        contrast_targets=contrast_targets,
        contrast_targets_alignments=contrast_targets_alignments,
        contrast_force_inputs=contrast_force_inputs,
        skip_special_tokens=skip_special_tokens,
    )
    return lxt_probability_fn(c_args)


def lxt_contrast_prob_diff_fn(
    args: StepFunctionArgs,
    contrast_sources=None,
    contrast_targets=None,
    contrast_targets_alignments: list[list[tuple[int, int]]] | None = None,
    contrast_force_inputs: bool = False,
    skip_special_tokens: bool = False,
):
    """Difference between regular and contrastive target probabilities.

    add2 and mul2 are LXT's LRP-compatible replacements for addition and
    multiplication, keeping the whole step function attributable.
    """
    model_probs = lxt_probability_fn(args)
    contrast_probs = lxt_contrast_prob_fn(
        args=args,
        contrast_sources=contrast_sources,
        contrast_targets=contrast_targets,
        contrast_targets_alignments=contrast_targets_alignments,
        contrast_force_inputs=contrast_force_inputs,
        skip_special_tokens=skip_special_tokens,
    ).to(model_probs.device)
    # Equivalent to model_probs - contrast_probs, written with LXT ops
    return add2(model_probs, mul2(contrast_probs, -1))
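
# Illustrative sketch (not part of the app flow): once registered under the
# name "lxt_contrast_prob_diff" (see register_step_function below), the step
# function above can be used like any built-in inseq step function, e.g.:
#
#   pecore_args = AttributeContextArgs(
#       model_name_or_path="HuggingFaceTB/SmolLM-360M-Instruct",
#       input_context_text="Some retrieved passage...",
#       input_current_text="Some question...",
#       attributed_fn="lxt_contrast_prob_diff",
#       attribution_method="saliency",
#   )
#   out = attribute_context_with_model(pecore_args, get_model("360M"))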
def set_interactive_settings(rag_setting, retrieve_k, top_k, custom_context):
    if rag_setting in ("Retrieve with BM25", "Rerank with ColBERT"):
        return (
            gr.Slider(interactive=True),
            gr.Slider(interactive=True),
            gr.Textbox(
                placeholder="Context will be retrieved automatically.\nChange mode to 'Use Custom Context' to specify your own.",
                interactive=False,
            ),
        )
    elif rag_setting == "Use Custom Context":
        return (
            gr.Slider(interactive=False),
            gr.Slider(interactive=False),
            gr.Textbox(placeholder="Insert a custom context...", interactive=True),
        )


@spaces.GPU()
def generate(
    query,
    max_new_tokens,
    top_p,
    temperature,
    retrieve_k,
    top_k,
    rag_setting,
    custom_context,
    model_size,
    progress=gr.Progress(track_tqdm=True),
):
    global model, model_id
    if rag_setting == "Use Custom Context":
        docs = custom_context.split("\n\n")
        progress(0.1, desc="Using custom context...")
    else:
        if not query:
            raise gr.Error("Please enter a query.")
        progress(0, desc="Retrieving with BM25...")
        q = bm25s.tokenize(query)
        results = retriever.retrieve(q, k=retrieve_k)
        if rag_setting == "Rerank with ColBERT":
            progress(0.1, desc="Reranking with ColBERT...")
            docs = [x["text"] for x in results.documents[0]]
            out = ranker.rank(query=query, docs=docs)
            docs = [out.results[i].document.text for i in range(top_k)]
        else:
            docs = [results.documents[0][i]["text"] for i in range(top_k)]
        # Strip Wikipedia-style reference markers (e.g. [1]) from retrieved passages
        docs = [re.sub(r"\[\d+\]", "", doc) for doc in docs]
    curr_model_id = f"HuggingFaceTB/SmolLM-{model_size}-Instruct"
    if model is None or model.model_name != curr_model_id:
        progress(0.2, desc="Loading model...")
        model = get_model(model_size)
        # Keep the global model id in sync with the loaded checkpoint, so the
        # AttributeContextArgs below reference the model actually in use.
        model_id = curr_model_id
    estimated_time = 20
    tstep = 1
    lm_rag_prompting_example = AttributeContextArgs(
        model_name_or_path=model_id,
        input_context_text="\n\n".join(docs),
        input_current_text=query,
        output_template="{current}",
        attributed_fn="lxt_contrast_prob_diff",
        input_template="<|im_start|>user\n### Context\n{context}\n\n### Query\n{current}<|im_end|>\n<|im_start|>assistant\n",
        contextless_input_current_text="<|im_start|>user\n### Query\n{current}<|im_end|>\n<|im_start|>assistant\n",
        attribution_method="saliency",
        show_viz=False,
        show_intermediate_outputs=False,
        context_sensitivity_std_threshold=1,
        decoder_input_output_separator=" ",
        special_tokens_to_keep=["<|im_start|>", "<|endoftext|>"],
        generation_kwargs={
            "max_new_tokens": max_new_tokens,
            "top_p": top_p,
            "temperature": temperature,
        },
        attribution_aggregators=["sum"],
        rescale_attributions=True,
        save_path=os.path.join(os.path.dirname(__file__), "outputs/output.json"),
        viz_path=os.path.join(os.path.dirname(__file__), "outputs/output.html"),
    )
    # Run the attribution in a worker thread so the tqdm-tracked progress bar
    # can keep ticking while attribute_context_with_model blocks.
    ret = [None]

    def run_attribute_context():
        ret[0] = attribute_context_with_model(lm_rag_prompting_example, model)

    thread = threading.Thread(target=run_attribute_context)
    pbar = tqdm(total=estimated_time, desc="Attributing with LXT...")
    thread.start()
    while thread.is_alive():
        thread.join(timeout=tstep)
        pbar.update(tstep)
    pbar.close()
    out = ret[0]
    html = visualize_attribute_context(out, show_viz=False, return_html=True)
    return [
        gradio_iframe.iFrame(html, height=500, visible=True),
        gr.DownloadButton(
            label="📂 Download output",
            value=os.path.join(os.path.dirname(__file__), "outputs/output.json"),
            visible=True,
        ),
        gr.DownloadButton(
            label="🔍 Download HTML",
            value=os.path.join(os.path.dirname(__file__), "outputs/output.html"),
            visible=True,
        ),
    ]


register_step_function(
    lxt_contrast_prob_diff_fn, "lxt_contrast_prob_diff", overwrite=True
)
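
# Illustrative sketch: the retrieval pipeline wrapped by `generate` can also
# be exercised on its own, e.g.:
#
#   results = retriever.retrieve(bm25s.tokenize("who wrote hamlet"), k=10)
#   docs = [d["text"] for d in results.documents[0]]
#   reranked = ranker.rank(query="who wrote hamlet", docs=docs)
#   top_docs = [r.document.text for r in reranked.results[:3]]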

with gr.Blocks(css=custom_css) as demo:
    with gr.Row():
        with gr.Column(min_width=500):
            gr.HTML("")  # header banner
            text = gr.Markdown(
                "This demo showcases the end-to-end usage of model internals"
                " for RAG answer attribution with the PECoRe framework, as"
                " described in our MIRAGE paper."
            )
            gr.Code(mirage_citation, interactive=False, label="MIRAGE (Qi et al., 2024)")
            gr.Code(pecore_citation, interactive=False, label="PECoRe (Sarti et al., 2024)")
            gr.Markdown(
                "The Inseq library (powering <code>inseq attribute-context</code>,"
                " including this demo) can be cited with:"
            )
            gr.Code(
                inseq_citation,
                interactive=False,
                label="Inseq (Sarti et al., 2023)",
            )
            gr.Code(
                inseq_xai_citation,
                interactive=False,
                label="Inseq v0.6 (Sarti et al., 2024)",
            )
            gr.Markdown(
                "The AttnLRP attribution method used in this demo via the LXT library can be cited with:"
            )
            gr.Code(
                lxt_citation,
                interactive=False,
                label="AttnLRP (Achtibat et al., 2024)",
            )
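        # Output column: renders the PECoRe attribution as an embedded HTML
        # visualization, with buttons to download the raw JSON and standalone
        # HTML files written by `generate` above.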
        with gr.Column():
            attribute_context_out = gradio_iframe.iFrame(height=400, visible=True)
            with gr.Row(equal_height=True):
                download_output_file_button = gr.DownloadButton(
                    "📂 Download output",
                    visible=False,
                )
                download_output_html_button = gr.DownloadButton(
                    "🔍 Download HTML",
                    visible=False,
                    value=os.path.join(
                        os.path.dirname(__file__), "outputs/output.html"
                    ),
                )
    with gr.Row(elem_classes="footer-container"):
        with gr.Column():
            gr.Markdown(""" """)
        with gr.Column():
            with gr.Row(elem_classes="footer-custom-block"):
                with gr.Column(scale=0.30, min_width=150):
                    gr.Markdown("""Built by Gabriele Sarti