"""Gradio demo that highlights token-level differences between two texts.

Two comparison methods from the local ``recognizers`` package are offered
(``DiffAlign`` and ``DiffDel``). Both run on top of a Hugging Face
``feature-extraction`` pipeline. Their raw per-token scores are rescaled to
integer levels 0..10 and rendered to HTML with ``result_template.html``.
"""
import functools
from pathlib import Path

import gradio as gr
from jinja2 import Environment
from tokenizers.pre_tokenizers import Whitespace
from transformers import pipeline

from recognizers import DiffAlign, DiffDel

# Directory holding this script and its companion template/markdown files.
APP_DIR = Path(__file__).parent

# Comparison methods selectable in the UI, keyed by their dropdown label.
DIFF_METHODS = {"DiffAlign": DiffAlign, "DiffDel": DiffDel}

# Empirical (min, max) raw-score bounds per method, used to rescale scores to
# [0, 1]. The additive offsets look hand-tuned for display.
# NOTE(review): confirm their provenance before changing them.
SCORE_RANGES = {
    "DiffAlign": (0.3758048415184021 - 0.1, 1.045647144317627 - 0.1),
    "DiffDel": (0.4864141941070556, 0.5012983083724976 + 0.025),
}

# Highest integer level a token label is quantized to (levels are 0..MAX_LEVEL).
MAX_LEVEL = 10


def load_pipeline(model_name_or_path: str = "ZurichNLP/unsup-simcse-xlm-roberta-base"):
    """Build a Hugging Face ``feature-extraction`` pipeline for the given model."""
    return pipeline("feature-extraction", model=model_name_or_path)


@functools.lru_cache(maxsize=None)
def _load_template():
    """Read and compile ``result_template.html`` once and reuse it for every request."""
    source = (APP_DIR / "result_template.html").read_text(encoding="utf-8")
    # Autoescape: the rendered tokens come from user-typed text, so any markup
    # in them must be escaped rather than injected into the page.
    return Environment(autoescape=True).from_string(source)


def _to_levels(labels, min_value, max_value):
    """Rescale raw scores against [min_value, max_value] and quantize to 0..MAX_LEVEL.

    Scores outside the empirical range are clamped at both ends, so every
    level is a valid integer in 0..MAX_LEVEL.
    """
    span = max_value - min_value
    return tuple(
        round(min(MAX_LEVEL, max(0, (label - min_value) / span * MAX_LEVEL)))
        for label in labels
    )


def generate_diff(text_a: str, text_b: str, method: str):
    """Compare two texts and return an HTML highlighting for each of them.

    Args:
        text_a: First text to compare.
        text_b: Second text to compare.
        method: Comparison method name; one of the keys of ``DIFF_METHODS``.

    Returns:
        A pair ``(html_a, html_b)`` of rendered HTML strings, one per input text.

    Raises:
        ValueError: If ``method`` is not a known comparison method.
    """
    global my_pipeline
    if my_pipeline is None:
        # Fallback in case the module-level preload has not run yet.
        my_pipeline = load_pipeline()

    if method not in DIFF_METHODS:
        raise ValueError(f"Unknown method: {method}")
    diff = DIFF_METHODS[method](pipeline=my_pipeline)
    min_value, max_value = SCORE_RANGES[method]

    # Pre-tokenize into word/punctuation pieces (each paired with its character
    # offsets) and score the texts re-joined with single spaces.
    encoding_a = tokenizer.pre_tokenize_str(text_a)
    encoding_b = tokenizer.pre_tokenize_str(text_b)
    result = diff.predict(
        a=" ".join(token for token, _ in encoding_a),
        b=" ".join(token for token, _ in encoding_b),
    )
    # Hands the original pre-tokenization (with offsets) back to the result,
    # presumably to restore the input's spacing.
    # NOTE(review): confirm in recognizers.
    result.add_whitespace(encoding_a, encoding_b)

    # Normalize with the empirical bounds, then quantize to levels 0, 1, ..., 10.
    result.labels_a = _to_levels(result.labels_a, min_value, max_value)
    result.labels_b = _to_levels(result.labels_b, min_value, max_value)

    template = _load_template()
    html_a = template.render(token_labels=result.token_labels_a)
    html_b = template.render(token_labels=result.token_labels_b)
    return html_a, html_b


my_pipeline = None
tokenizer = Whitespace()

with gr.Blocks() as demo:
    preamble = (APP_DIR / "preamble.md").read_text(encoding="utf-8")
    gr.Markdown(preamble)
    with gr.Row():
        text_a = gr.Textbox(label="Text A", value="We'll meet Steve on Wednesday.", lines=2)
        text_b = gr.Textbox(label="Text B", value="We are going to see Mary on Friday.", lines=2)
    with gr.Row():
        method = gr.Dropdown(choices=["DiffAlign", "DiffDel"], label="Comparison Method", value="DiffAlign")
    with gr.Row():
        with gr.Column(variant="panel"):
            output_a = gr.HTML(label="Result for text A", show_label=True)
        with gr.Column(variant="panel"):
            output_b = gr.HTML(label="Result for text B", show_label=True)
    with gr.Row():
        # A Button's caption is its ``value``; ``label`` does not set it.
        submit_btn = gr.Button(value="Generate Diff")
        submit_btn.click(
            fn=generate_diff,
            inputs=[text_a, text_b, method],
            outputs=[output_a, output_b],
        )

    description = (APP_DIR / "description.md").read_text(encoding="utf-8")
    gr.Markdown(description)

# Load the model at startup rather than on the first request.
if my_pipeline is None:
    my_pipeline = load_pipeline()
demo.launch()