|
from pathlib import Path |
|
|
|
import gradio as gr |
|
from jinja2 import Environment |
|
from tokenizers.pre_tokenizers import Whitespace |
|
from transformers import pipeline |
|
|
|
from recognizers import DiffAlign, DiffDel |
|
|
|
|
|
def load_pipeline(model_name_or_path: str = "ZurichNLP/unsup-simcse-xlm-roberta-base"):
    """Create a ``transformers`` feature-extraction pipeline.

    Args:
        model_name_or_path: Model identifier (or local path) forwarded as the
            ``model`` argument of ``transformers.pipeline``.

    Returns:
        The pipeline object built by ``transformers.pipeline``.
    """
    task_name = "feature-extraction"
    extractor = pipeline(task_name, model=model_name_or_path)
    return extractor
|
|
|
|
|
def _scale_labels(labels, min_value: float, max_value: float) -> tuple:
    """Rescale raw scores into integer buckets 0..10.

    Each score is mapped linearly so that ``min_value`` -> 0 and
    ``max_value`` -> 10, then clamped to [0, 10] and rounded. Clamping the
    lower end keeps scores below ``min_value`` from producing negative buckets.
    """
    span = max_value - min_value
    return tuple(
        round(max(0.0, min(10.0, (label - min_value) / span * 10)))
        for label in labels
    )


def generate_diff(text_a: str, text_b: str, method: str):
    """Compare two texts and render their per-token scores as HTML.

    Both texts are pre-tokenized with the module-level ``tokenizer``, scored by
    the selected recognizer, rescaled to integer buckets 0..10 and rendered
    through ``result_template.html`` (next to this file).

    Args:
        text_a: First text to compare.
        text_b: Second text to compare.
        method: Recognizer name, ``"DiffAlign"`` or ``"DiffDel"``.

    Returns:
        A ``(html_a, html_b)`` tuple of rendered HTML strings.

    Raises:
        ValueError: If ``method`` is not a supported recognizer name.
    """
    global my_pipeline
    if my_pipeline is None:
        # Fallback lazy load; the script normally loads the model before launch.
        my_pipeline = load_pipeline()

    if method == "DiffAlign":
        diff = DiffAlign(pipeline=my_pipeline)
        # Hand-tuned score range used to rescale DiffAlign output.
        min_value = 0.3758048415184021 - 0.1
        max_value = 1.045647144317627 - 0.1
    elif method == "DiffDel":
        diff = DiffDel(pipeline=my_pipeline)
        # Hand-tuned score range used to rescale DiffDel output.
        min_value = 0.4864141941070556
        max_value = 0.5012983083724976 + 0.025
    else:
        raise ValueError(f"Unknown method: {method}")

    encoding_a = tokenizer.pre_tokenize_str(text_a)
    encoding_b = tokenizer.pre_tokenize_str(text_b)

    # The recognizer sees the pre-tokens joined by single spaces; the full
    # pre-tokenizer output is handed to add_whitespace afterwards, presumably
    # so the original spacing can be restored.
    result = diff.predict(
        a=" ".join(token[0] for token in encoding_a),
        b=" ".join(token[0] for token in encoding_b),
    )
    result.add_whitespace(encoding_a, encoding_b)

    result.labels_a = _scale_labels(result.labels_a, min_value, max_value)
    result.labels_b = _scale_labels(result.labels_b, min_value, max_value)

    template_path = Path(__file__).parent / "result_template.html"
    # autoescape=True: token_labels presumably carry the user's own text, which
    # must be HTML-escaped before it is shown in the gr.HTML outputs.
    template = Environment(autoescape=True).from_string(
        template_path.read_text(encoding="utf-8")
    )

    html_a = template.render(token_labels=result.token_labels_a)
    html_b = template.render(token_labels=result.token_labels_b)
    return str(html_a), str(html_b)
|
|
|
|
|
# Module-level state read by generate_diff(): the feature-extraction pipeline
# (loaded below, before launch) and the pre-tokenizer applied to both texts.
my_pipeline = None
tokenizer = Whitespace()

APP_DIR = Path(__file__).parent


with gr.Blocks() as demo:
    preamble = (APP_DIR / "preamble.md").read_text(encoding="utf-8")
    gr.Markdown(preamble)

    with gr.Row():
        text_a = gr.Textbox(label="Text A", value="We'll meet Steve on Wednesday.", lines=2)
        text_b = gr.Textbox(label="Text B", value="We are going to see Mary on Friday.", lines=2)

    with gr.Row():
        method = gr.Dropdown(choices=["DiffAlign", "DiffDel"], label="Comparison Method", value="DiffAlign")

    with gr.Row():
        with gr.Column(variant="panel"):
            output_a = gr.HTML(label="Result for text A", show_label=True)
        with gr.Column(variant="panel"):
            output_b = gr.HTML(label="Result for text B", show_label=True)

    with gr.Row():
        # A Button's caption is its ``value``; ``label`` does not set the
        # visible text (and is rejected by newer Gradio versions).
        submit_btn = gr.Button(value="Generate Diff")

    submit_btn.click(
        fn=generate_diff,
        inputs=[text_a, text_b, method],
        outputs=[output_a, output_b],
    )

    description = (APP_DIR / "description.md").read_text(encoding="utf-8")
    gr.Markdown(description)


# Load the model before serving so the first request does not pay for it.
my_pipeline = load_pipeline()

demo.launch()
|
|