from difflib import ndiff

import gradio as gr
from datasets import load_dataset
from model2vec import StaticModel
from semhash import SemHash
from semhash.datamodels import DeduplicationResult

# Default parameters
default_dataset_name = "SetFit/amazon_massive_scenario_en-US"
default_dataset1_split = "train"
default_dataset2_split = "test"
default_text_column = "text"
default_threshold = 0.9

# Load the Model2Vec encoder that SemHash uses to embed texts
model = StaticModel.from_pretrained("minishlab/potion-base-8M")

def display_word_differences(x: str, y: str) -> str:
    """
    Display the word-level differences between two texts, formatted to avoid
    misinterpretation of Markdown syntax.
    """
    diff = ndiff(x.split(), y.split())
    formatted_diff = "\n".join(word for word in diff if word.startswith(("+", "-")))
    return f"```\n{formatted_diff}\n```"
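
# Illustrative example (not executed): for x = "the quick brown fox" and
# y = "the fast brown fox", only the changed words survive the filter:
#   - quick
#   + fast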

def load_dataset_texts(dataset_name: str, dataset_split: str, text_column: str) -> list[str]:
    """Load texts from a specified dataset split."""
    ds = load_dataset(dataset_name, split=dataset_split)
    return [example[text_column] for example in ds]
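
# Example call, using the defaults defined above:
#   texts = load_dataset_texts("SetFit/amazon_massive_scenario_en-US", "train", "text")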

def deduplicate_single_dataset(texts: list[str], threshold: float) -> DeduplicationResult:
    """Deduplicate within a single dataset using SemHash, treating each text as a raw string record."""
    # Build a SemHash index from the raw texts
    semhash = SemHash.from_records(records=texts, model=model)
    # Deduplicate the entire dataset
    return semhash.self_deduplicate(threshold=threshold)
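
# Note on `threshold`: it is the similarity cutoff; roughly, pairs of texts whose
# embedding similarity exceeds it are flagged as near-duplicates. The default of
# 0.9 is fairly strict; lower values also catch looser paraphrases.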

def deduplicate_two_datasets(texts1: list[str], texts2: list[str], threshold: float) -> DeduplicationResult:
    """Deduplicate dataset2 against dataset1, both as raw strings, using SemHash."""
    # Build SemHash index on dataset1
    semhash = SemHash.from_records(records=texts1, model=model)
    # Deduplicate texts2 against dataset1
    return semhash.deduplicate(records=texts2, threshold=threshold)
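
# Shape of the result, as consumed in perform_deduplication below:
#   result.deduplicated -- the records kept after deduplication
#   result.duplicates   -- the removed records; each entry carries `.record` (the
#                          duplicate text) and `.duplicates`, a list of
#                          (matched_text, similarity_score) pairs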

def perform_deduplication(
    deduplication_type: str,
    dataset1_name: str,
    dataset1_split: str,
    dataset1_text_column: str,
    dataset2_name: str = "",
    dataset2_split: str = "",
    dataset2_text_column: str = "",
    threshold: float = default_threshold,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """
    Perform deduplication on one or two datasets using SemHash. This function
    streams status updates to Gradio for user feedback.
    """
    try:
        threshold = float(threshold)

        # Load Dataset 1
        yield "Loading Dataset 1...", ""
        texts1 = load_dataset_texts(dataset1_name, dataset1_split, dataset1_text_column)

        if deduplication_type == "Single dataset":
            # Single-dataset deduplication
            yield "Deduplicating within Dataset 1 (SemHash)...", ""
            result = deduplicate_single_dataset(texts1, threshold=threshold)

            # Sort each record's matches so its highest-scoring duplicate comes first
            for duprec in result.duplicates:
                duprec.duplicates.sort(key=lambda x: x[1], reverse=True)

            # Summarize results
            num_duplicates = len(result.duplicates)
            deduplicated_count = len(result.deduplicated)
            total_docs = len(texts1)
            result_text = (
                f"**Total documents (Dataset 1):** {total_docs}\n\n"
                f"**Duplicates found:** {num_duplicates}\n\n"
                f"**Unique documents after deduplication:** {deduplicated_count}\n\n"
                + "-" * 50 + "\n\n"
            )

            # Show example duplicates
            if num_duplicates > 0:
                result_text += "**Example duplicates:**\n\n"
                # Only show duplicates that actually have near-duplicate records
                duplicates_with_data = [duprec for duprec in result.duplicates if duprec.duplicates]
                if duplicates_with_data:
                    for duprec in duplicates_with_data[:5]:
                        dup_text = duprec.record
                        orig_text, score = duprec.duplicates[0]
                        differences = display_word_differences(orig_text, dup_text)
                        result_text += (
                            f"**Original:**\n{orig_text}\n\n"
                            f"**Duplicate:**\n{dup_text}\n\n"
                            f"**Similarity Score:** {score:.4f}\n"
                            f"**Differences:**\n{differences}\n"
                            + "-" * 50 + "\n\n"
                        )
                else:
                    result_text += "No near-duplicate details available.\n\n"
            else:
                result_text += "No duplicates found."

            yield "Deduplication completed.", result_text

        else:
            # Cross-dataset deduplication
            yield "Loading Dataset 2...", ""
            texts2 = load_dataset_texts(dataset2_name, dataset2_split, dataset2_text_column)

            yield "Deduplicating Dataset 2 against Dataset 1 (SemHash)...", ""
            result = deduplicate_two_datasets(texts1, texts2, threshold=threshold)

            # Sort each record's matches so its highest-scoring duplicate comes first
            for duprec in result.duplicates:
                duprec.duplicates.sort(key=lambda x: x[1], reverse=True)

            num_duplicates = len(result.duplicates)
            total_docs2 = len(texts2)
            deduplicated_count = len(result.deduplicated)
            result_text = (
                f"**Total documents in {dataset2_name}/{dataset2_split}:** {total_docs2}\n\n"
                f"**Duplicates found in Dataset 2:** {num_duplicates}\n\n"
                f"**Unique documents after deduplication:** {deduplicated_count}\n\n"
                + "-" * 50 + "\n\n"
            )

            if num_duplicates > 0:
                result_text += "**Example duplicates from Dataset 2:**\n\n"
                # Again, only show duplicates that actually have near-duplicate records
                duplicates_with_data = [duprec for duprec in result.duplicates if duprec.duplicates]
                if duplicates_with_data:
                    for duprec in duplicates_with_data[:5]:
                        dup_text = duprec.record  # the "duplicate" text from dataset2
                        orig_text, score = duprec.duplicates[0]
                        differences = display_word_differences(orig_text, dup_text)
                        result_text += (
                            f"**Original (Dataset 1):**\n{orig_text}\n\n"
                            f"**Duplicate (Dataset 2):**\n{dup_text}\n\n"
                            f"**Similarity Score:** {score:.4f}\n"
                            f"**Differences:**\n{differences}\n"
                            + "-" * 50 + "\n\n"
                        )
                else:
                    result_text += "No near-duplicate details available.\n\n"
            else:
                result_text += "No duplicates found."

            yield "Deduplication completed.", result_text

    except Exception as e:
        yield f"An error occurred: {e}", ""
        raise  # re-raise so the error also surfaces in Gradio's logs


# --- Gradio App ---
with gr.Blocks(theme=gr.themes.Ocean(), css="#status_output { height: 50px; overflow: auto; }") as demo:
    gr.Markdown("# Semantic Text Deduplication Using SemHash")
    gr.Markdown(
        """
        This demo showcases **semantic deduplication** of Hugging Face datasets using
        [SemHash](https://github.com/MinishLab/semhash) with a
        [Model2Vec](https://github.com/MinishLab/model2vec) encoder.
        It can be used to identify duplicate texts within a **single dataset** or across **two datasets**.
        You can adjust the similarity threshold to control the strictness of the deduplication.

        **NOTE**: This demo runs on a free CPU backend, so it may be slow for large datasets.
        For faster results, please run the code locally.
        """
    )

    deduplication_type = gr.Radio(
        choices=["Cross-dataset", "Single dataset"],
        label="Deduplication Type",
        value="Cross-dataset",  # default
    )

    with gr.Row():
        dataset1_name = gr.Textbox(value=default_dataset_name, label="Dataset 1 Name")
        dataset1_split = gr.Textbox(value=default_dataset1_split, label="Dataset 1 Split")
        dataset1_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

    # Dataset 2 inputs; hidden when "Single dataset" is selected
    dataset2_inputs = gr.Column(visible=True)
    with dataset2_inputs:
        with gr.Row():
            dataset2_name = gr.Textbox(value=default_dataset_name, label="Dataset 2 Name")
            dataset2_split = gr.Textbox(value=default_dataset2_split, label="Dataset 2 Split")
            dataset2_text_column = gr.Textbox(value=default_text_column, label="Text Column Name")

    threshold = gr.Slider(0.0, 1.0, value=default_threshold, label="Similarity Threshold")

    with gr.Row():
        compute_button = gr.Button("Deduplicate")

    status_output = gr.Markdown(elem_id="status_output")
    result_output = gr.Markdown()

    def update_visibility(choice: str):
        # Show the Dataset 2 column only for cross-dataset deduplication
        return gr.update(visible=(choice == "Cross-dataset"))

    deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=dataset2_inputs)

    compute_button.click(
        fn=perform_deduplication,
        inputs=[
            deduplication_type,
            dataset1_name,
            dataset1_split,
            dataset1_text_column,
            dataset2_name,
            dataset2_split,
            dataset2_text_column,
            threshold,
        ],
        outputs=[status_output, result_output],
    )

demo.launch()
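
# To run this app locally (assuming this file is saved as app.py):
#   pip install gradio datasets semhash model2vec
#   python app.py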