# kenken999's picture
# dfa
# e1aa577
import random
from sentence_transformers import SentenceTransformer
import faiss
import pandas as pd
class Dedup:
    """Semantic de-duplication of text records based on embedding distance.

    Texts are embedded with a SentenceTransformer model, stored in a flat
    FAISS L2 index, greedily grouped into clusters of close neighbours, and
    one representative per cluster is then sampled.

    The index, the embeddings and the clusters are cached on the instance
    after they are first computed and are NOT rebuilt when a different
    dataset is passed later. Use one instance per dataset; ``copy()`` returns
    a fresh, cache-free instance with the same configuration.
    """

    # Upper bound on the number of neighbours fetched per record.
    # TODO: derive it from the batch size
    _MAX_NEIGHBORS = 30

    def __init__(self, config=None):
        """
        Create a deduplicator.
        input: config - optional dict with the keys
                        'dedup_threshold' (default 0.5) and
                        'embeddings_model' (default 'all-MiniLM-L6-v2')
        """
        config = config or {}
        self.index = None     # FAISS index, built lazily by cluster_data()
        self.xb = None        # embeddings that were added to the index
        self.clusters = None  # cached result of cluster_data(), set by sample()
        # Maximum index distance for two records to count as duplicates.
        # IndexFlatL2 reports squared Euclidean distances (see FAISS docs).
        self.th = config.get("dedup_threshold", 0.5)
        self.model_name = config.get("embeddings_model", "all-MiniLM-L6-v2")

    def copy(self):
        """
        Return a new Dedup with the same threshold and model name.
        The index, embeddings and clusters are not copied.
        """
        return Dedup(
            {
                "dedup_threshold": self.th,
                "embeddings_model": self.model_name,
            }
        )

    def generate_embeddings(self, texts):
        """
        Generate embeddings for the given texts using the SentenceTransformer model.
        input: texts - a list of strings
        output: embeddings - whatever model.encode returns for the texts
        Note: the model is instantiated on every call.
        """
        model = SentenceTransformer(self.model_name)
        return model.encode(texts, show_progress_bar=True)

    def build_index(self, records):
        """
        Build the FAISS index for the given dataset.
        input: records - a pandas dataframe with a 'text' column
        output: index - the FAISS index
                embeddings - the embeddings of the dataset
        """
        # Generate embeddings for the dataset
        embeddings = self.generate_embeddings(records["text"].tolist())
        # Build the FAISS index; its dimensionality is the embedding width
        embeddings_dim = embeddings.shape[1]
        index = faiss.IndexFlatL2(embeddings_dim)
        index.add(embeddings)
        return index, embeddings

    def cluster_data(self, records):
        """
        Cluster the given dataset.
        input: records - a pandas dataframe with a 'text' column; only used
                         when no index has been built on this instance yet
        output: clusters - a list of clusters, where each cluster is a set of
                           integer row positions; every position appears in
                           exactly one cluster
        An empty dataset yields an empty list.
        """
        if self.index is None:
            if len(records) == 0:
                # Nothing to embed; building an index would fail on shape[1].
                return []
            self.index, self.xb = self.build_index(records)

        total = len(self.xb)
        if total == 0:
            return []

        # Never ask for more neighbours than there are indexed vectors.
        k = min(self._MAX_NEIGHBORS, total)
        distances, indices = self.index.search(self.xb, k)

        clusters = []
        visited = set()
        for i, (row_indices, row_distances) in enumerate(zip(indices, distances)):
            if i in visited:
                continue
            # Mark the seed itself: it is not guaranteed to show up in its own
            # neighbour list, and must not be claimed again by a later cluster.
            visited.add(i)
            new_cluster = {i}
            # Add all not-yet-claimed neighbors within the threshold
            for idx, distance in zip(row_indices, row_distances):
                neighbor = int(idx)
                # FAISS pads missing results with the label -1; skip those.
                if neighbor >= 0 and distance <= self.th and neighbor not in visited:
                    visited.add(neighbor)
                    new_cluster.add(neighbor)
            clusters.append(new_cluster)
        return clusters

    def sample(self, records: pd.DataFrame, operation_function=random.choice):
        """
        Sample the given dataset, keeping one record per cluster.
        input: records - a pandas dataframe with a 'text' column
               operation_function - a function that receives a cluster (as a
                                    list of row positions) and returns one of them
        output: a pandas dataframe with the sampled records, in row order
        raises: ValueError - if operation_function is not callable
        Clusters are computed on the first call and reused afterwards.
        """
        if not callable(operation_function):
            raise ValueError("The 'operation_function' must be a callable function.")
        if self.clusters is None:
            self.clusters = self.cluster_data(records)
        samples = [operation_function(list(cluster)) for cluster in self.clusters]
        return records.iloc[sorted(samples)]