|
import streamlit as st |
|
import pandas as pd |
|
import vec2text |
|
from transformers import AutoModel, AutoTokenizer |
|
from sklearn.decomposition import PCA |
|
from utils import file_cache |
|
|
|
|
|
|
|
@st.cache_resource
def load_corrector():
    """Load and cache the pretrained vec2text corrector for the GTR base model.

    Cached with ``st.cache_resource`` so the (large) corrector is constructed
    once per Streamlit server process rather than on every rerun.
    """
    corrector = vec2text.load_pretrained_corrector("gtr-base")
    return corrector
|
|
|
|
|
@st.cache_data
def load_data():
    """Download and cache the reddit-syac-urls training split as a DataFrame.

    Cached with ``st.cache_data`` so the CSV is fetched from the Hugging Face
    Hub only once per session instead of on every rerun.
    """
    dataset_url = "https://huggingface.co/datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    return pd.read_csv(dataset_url)
|
|
|
|
|
@st.cache_resource
def vector_compressor_from_config(n_components=2):
    """Build and cache the dimensionality reducer used for the embedding plot.

    Despite the name, the original implementation had no configuration at all —
    the output dimensionality was hard-coded. It is now a parameter so callers
    can request a different projection size; the default of 2 preserves the
    previous behavior for existing call sites.

    Args:
        n_components: Target dimensionality of the PCA projection (default 2,
            i.e. a 2-D scatter-plot-friendly projection).

    Returns:
        An *unfitted* ``sklearn.decomposition.PCA`` instance. Note that
        ``st.cache_resource`` shares this object across sessions, so fitting it
        mutates the shared instance.
    """
    return PCA(n_components=n_components)
|
|
|
|
|
@st.cache_data
@file_cache(".cache/reducer_embeddings.pickle")
def reduce_embeddings(embeddings):
    """Fit the configured reducer on ``embeddings`` and project them.

    Results are cached twice: on disk via ``file_cache`` (survives restarts)
    and in memory via ``st.cache_data`` (per-session).

    Args:
        embeddings: Array-like of embedding vectors to project.

    Returns:
        A tuple ``(reduced, reducer)`` of the projected embeddings and the
        now-fitted reducer instance.
    """
    compressor = vector_compressor_from_config()
    reduced = compressor.fit_transform(embeddings)
    return reduced, compressor
|
|
|
|
|
@st.cache_resource
def load_model_and_tokenizer(device="cpu"):
    """Load and cache the GTR-T5 encoder and its tokenizer.

    Only the encoder half of the seq2seq model is kept, since embeddings are
    produced by encoding alone.

    Args:
        device: Torch device string to move the encoder to (default ``"cpu"``).

    Returns:
        A tuple ``(encoder, tokenizer)``.
    """
    model_name = "sentence-transformers/gtr-t5-base"
    full_model = AutoModel.from_pretrained(model_name)
    encoder = full_model.encoder.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return encoder, tokenizer