from typing import List
from sentence_transformers import SentenceTransformer
from kmeans_pytorch import kmeans
import torch
from sklearn.cluster import KMeans
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Text2TextGenerationPipeline,
)


class Template:
    '''
    Registry of the component names understood by ``__create_model__``.

    Each attribute is a plain dict keyed by the short name a caller passes in.
    Values are either a Hugging Face Hub identifier, a clustering callable/class,
    or ``None`` for entries that have no implementation attached in this file.
    '''

    def __init__(self):
        # Sentence-embedding checkpoints (short name -> Hub identifier).
        self.PLM = {
            'sentence-transformer-mini': '''sentence-transformers/all-MiniLM-L6-v2''',
            'sentence-t5-xxl': '''sentence-transformers/sentence-t5-xxl''',
            'all-mpnet-base-v2': '''sentence-transformers/all-mpnet-base-v2'''
        }

        # Dimension-reduction methods; only the names are registered here.
        self.dimension_reduction = {
            'pca': None,
            'vae': None,
            'cnn': None
        }

        # Clustering back-ends: kmeans_pytorch's function and sklearn's class.
        self.clustering = {
            'kmeans-cosine': kmeans,
            'kmeans-euclidean': KMeans,
            'gmm': None
        }

        # Keyphrase-generation checkpoints (short name -> Hub identifier).
        self.keywords_extraction = {
            'keyphrase-transformer': '''snrspeaks/KeyPhraseTransformer''',
            'KeyBartAdapter': '''Adapting/KeyBartAdapter''',
            'KeyBart': '''bloomberg/KeyBART'''
        }


template = Template()

# Names that map to a SentenceTransformer checkpoint in ``template.PLM``.
_SENTENCE_ENCODER_NAMES = (
    'sentence-transformer-mini',
    'sentence-t5-xxl',
    'all-mpnet-base-v2',
)

# Keyphrase models whose generated text separates phrases with ';'.
_SEMICOLON_KEYPHRASE_NAMES = ('KeyBartAdapter', 'KeyBart')


def _build_cosine_kmeans():
    '''Return a ``(x, k)`` callable running kmeans_pytorch with cosine distance.'''

    def ret(x, k):
        # The registry is consulted on every call, so a replacement registered
        # in ``template.clustering`` after creation is still honoured.
        outcome = template.clustering['kmeans-cosine'](
            X=torch.from_numpy(x),
            num_clusters=k,
            distance='cosine',
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        # Both tensors are moved to the CPU and handed back as numpy arrays.
        return outcome[0].cpu().detach().numpy(), outcome[1].cpu().detach().numpy()

    return ret


def _build_euclidean_kmeans():
    '''Return a ``(x, k)`` callable fitting sklearn ``KMeans`` with a fixed seed.'''

    def ret(x, k):
        fitted = KMeans(n_clusters=k, random_state=50).fit(x)
        return fitted.labels_, fitted.cluster_centers_

    return ret


def _build_keyphrase_extractor(checkpoint: str, separator: str):
    '''
    Load a seq2seq checkpoint and wrap it as a keyphrase extractor.

    :param checkpoint: Hub identifier passed to ``from_pretrained``
    :param separator: character on which each generated text is split
    :return: callable mapping a list of texts to a list of keyphrase sets
    '''
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    pipe = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer)

    def ret(texts: List[str]):
        generated = pipe(texts)
        # One set per pipeline output: split on the separator, strip whitespace.
        return [
            set(map(str.strip, item['generated_text'].split(separator)))
            for item in generated
        ]  # [{str...}...]

    return ret


def __create_model__(model_ckpt):
    '''
    Build the model or function registered under ``model_ckpt``.

    :param model_ckpt: keys in Template class, or ``'none'``
    :return: model/function: callable
        - sentence-encoder names: a ``SentenceTransformer`` instance
        - ``'none'``: ``None``
        - ``'kmeans-cosine'`` / ``'kmeans-euclidean'``: a function ``(x, k)``
          returning a pair of numpy arrays
        - keyphrase names: a function ``(texts)`` returning one set of
          stripped strings per pipeline output
    :raises RuntimeError: if ``model_ckpt`` is not one of the supported names
    '''
    if model_ckpt in _SENTENCE_ENCODER_NAMES:
        return SentenceTransformer(template.PLM[model_ckpt])

    if model_ckpt == 'none':
        return None

    if model_ckpt == 'kmeans-cosine':
        return _build_cosine_kmeans()

    if model_ckpt == 'kmeans-euclidean':
        return _build_euclidean_kmeans()

    if model_ckpt == 'keyphrase-transformer':
        # This checkpoint's generated text is split on '|'.
        return _build_keyphrase_extractor(template.keywords_extraction[model_ckpt], '|')

    if model_ckpt in _SEMICOLON_KEYPHRASE_NAMES:
        return _build_keyphrase_extractor(template.keywords_extraction[model_ckpt], ';')

    raise RuntimeError(f'The model {model_ckpt} is not supported. Please open an issue on the GitHub about the model.')