# NOTE: page-header residue from the hosting site ("Spaces: Runtime error"); not part of the module.
import torch | |
from sklearn.cluster import AgglomerativeClustering, KMeans | |
from sklearn.manifold import TSNE | |
import numpy as np | |
import matplotlib.pyplot as plt | |
class ClusterEmbeddings:
    """Cluster embeddings and group the matching sentences by cluster.

    The clustering estimator is fitted once, at construction time, and the
    resulting per-sample labels are stored in ``self.labels``.
    """

    def __init__(
        self,
        cluster_estimate,
        cluster_fn,
        embeddings,
        sentences,
        words
    ):
        """Fit the selected clustering algorithm on ``embeddings``.

        Args:
            cluster_estimate: Number of clusters passed to the estimator as
                ``n_clusters``.
            cluster_fn: ``"agglo"`` for ``AgglomerativeClustering`` or
                ``"kmeans"`` for ``KMeans``.
            embeddings: Samples handed to the estimator's ``fit`` and later
                to t-SNE in ``make_plot``.
            sentences: Sequence of sentences, selected per cluster by
                ``get_sentence_clusters``; assumed to be aligned one-to-one
                with ``embeddings``.
            words: Stored on the instance; not used by this class.

        Raises:
            ValueError: If ``cluster_fn`` is not a supported name.
        """
        self.cluster_estimate = cluster_estimate
        self.embeddings = embeddings
        self.sentences = sentences
        self.words = words
        self.cluster_fn = cluster_fn
        self.num_clusters = cluster_estimate

        if self.cluster_fn == "agglo":
            self.clustering_algo = AgglomerativeClustering(
                n_clusters=self.cluster_estimate
            )
        elif self.cluster_fn == "kmeans":
            self.clustering_algo = KMeans(n_clusters=self.cluster_estimate)
        else:
            # Fail here with a clear message instead of hitting an
            # AttributeError on the missing ``clustering_algo`` below.
            raise ValueError(
                f"Unsupported cluster_fn {cluster_fn!r}; "
                "expected 'agglo' or 'kmeans'."
            )

        self.cluster = self.clustering_algo.fit(embeddings)
        self.labels = self.cluster.labels_

    def get_sentence_clusters(self):
        """Return one concatenated text chunk per cluster.

        Returns:
            A NumPy array with ``self.num_clusters`` strings. Entry ``i``
            holds the sentences labelled ``i``, each followed by a single
            space; a cluster without sentences yields an empty string.
        """
        # Coerce to arrays so boolean-mask selection also works when the
        # sentences were supplied as a plain Python list.
        sentences = np.asarray(self.sentences)
        labels = np.asarray(self.labels)
        sent_clusters = [
            "".join(sent + " " for sent in sentences[labels == lbl])
            for lbl in range(self.num_clusters)
        ]
        return np.array(sent_clusters)

    def make_plot(self):
        """Show a 2-D t-SNE scatter plot of the embeddings, one colour per cluster."""
        # scikit-learn requires perplexity < n_samples; keep its default of
        # 30 whenever possible and shrink it only for small inputs.
        n_samples = len(self.embeddings)
        perplexity = 30.0 if n_samples > 30 else max(float(n_samples - 1), 1.0)
        projector = TSNE(
            n_components=2,
            learning_rate="auto",
            init="random",
            perplexity=perplexity,
        )
        proj_embeddings = np.array(
            projector.fit_transform(self.embeddings)
        )
        labels = np.asarray(self.labels)
        for lbl in range(self.num_clusters):
            xs = proj_embeddings[labels == lbl]
            plt.scatter(xs[:, 0], xs[:, 1], label=f"Cluster {lbl}")
        plt.legend()
        plt.xlabel("x1")
        plt.ylabel("x2")
        plt.show()