cluster-summ / utils /clustering.py
jaisidhsingh's picture
add code
f98f59d
import torch
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
class ClusterEmbeddings:
    """Cluster embeddings and group the matching sentences per cluster.

    The clustering estimator is fitted inside ``__init__``; afterwards
    ``self.labels`` holds the cluster label the estimator assigned to each
    row of ``embeddings``.
    """

    def __init__(
        self,
        cluster_estimate,
        cluster_fn,
        embeddings,
        sentences,
        words
    ):
        """Fit the selected clustering algorithm on ``embeddings``.

        Args:
            cluster_estimate: Number of clusters, forwarded as ``n_clusters``
                to the estimator.
            cluster_fn: ``"agglo"`` for ``AgglomerativeClustering`` or
                ``"kmeans"`` for ``KMeans``.
            embeddings: Data passed to the estimator's ``fit``
                (presumably one row per sentence -- confirm with caller).
            sentences: Sequence of sentence strings, aligned with
                ``embeddings`` so that ``sentences[i]`` gets ``labels[i]``.
            words: Stored on the instance; not used by any method visible
                in this class.

        Raises:
            ValueError: If ``cluster_fn`` is not a supported name.
        """
        self.cluster_estimate = cluster_estimate
        self.embeddings = embeddings
        self.sentences = sentences
        self.words = words
        self.cluster_fn = cluster_fn

        if self.cluster_fn == "agglo":
            self.clustering_algo = AgglomerativeClustering(n_clusters=self.cluster_estimate)
        elif self.cluster_fn == "kmeans":
            self.clustering_algo = KMeans(n_clusters=self.cluster_estimate)
        else:
            # Fail early with a clear message instead of an AttributeError
            # on the missing ``clustering_algo`` attribute further down.
            raise ValueError(
                f"Unsupported cluster_fn {cluster_fn!r}; expected 'agglo' or 'kmeans'."
            )
        self.num_clusters = cluster_estimate

        self.cluster = self.clustering_algo.fit(embeddings)
        self.labels = self.cluster.labels_

    def get_sentence_clusters(self):
        """Return one concatenated text chunk per cluster.

        Returns:
            A NumPy array of ``num_clusters`` strings. Entry ``k`` is every
            sentence labelled ``k``, in original order, each followed by a
            single space. A cluster with no sentences yields ``""``.
        """
        # Coerce locally so boolean-mask indexing also works when the caller
        # supplied plain Python lists; the stored attributes are left as-is.
        sentences = np.asarray(self.sentences)
        labels = np.asarray(self.labels)

        sent_clusters = [
            "".join(sent + " " for sent in sentences[labels == lbl])
            for lbl in range(self.num_clusters)
        ]
        return np.array(sent_clusters)

    def make_plot(self):
        """Show a 2-D t-SNE scatter plot of the embeddings, one series per cluster."""
        # scikit-learn requires perplexity < n_samples, and its default (30)
        # is rejected for short inputs, so clamp it below the sample count.
        # Inputs with more than 30 rows keep the library default.
        n_samples = len(self.embeddings)
        projector = TSNE(
            n_components=2,
            learning_rate="auto",
            init="random",
            perplexity=min(30.0, max(n_samples - 1, 1)),
        )
        proj_embeddings = np.array(
            projector.fit_transform(self.embeddings)
        )

        for lbl in range(self.num_clusters):
            xs = proj_embeddings[self.labels == lbl]
            plt.scatter(xs[:, 0], xs[:, 1], label=f"Cluster {lbl}")

        plt.legend()
        plt.xlabel("x1")
        plt.ylabel("x2")
        plt.show()