"""Embed Hub dataset-viewer previews and upsert them into a Chroma collection.

Run as a script to prepare viewer data via ``prep_data``, embed each row
through a Hugging Face Inference Endpoint, and upsert the resulting vectors
into the ``dataset-viewer-descriptions`` collection.
"""

import asyncio
import logging

import chromadb
import requests
import stamina
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import InferenceClient
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map

from prep_viewer_data import prep_data
from utils import get_chroma_client

# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Sentence-transformers model attached to the Chroma collection as its
# embedding function, pinned to an exact commit revision.
EMBEDDING_MODEL_NAME = "davanstrien/query-to-dataset-viewer-descriptions"
EMBEDDING_MODEL_REVISION = "07c71d97861a73695f0c53cd6b4b32980007d908"
# Inference Endpoint used to compute the embeddings that get upserted.
# NOTE(review): presumably serves the same model as EMBEDDING_MODEL_NAME so
# stored vectors and query vectors live in the same space — confirm.
INFERENCE_MODEL_URL = (
    "https://ecg0by60w2vo9j8h.us-east-1.aws.endpoints.huggingface.cloud"
)


def initialize_clients():
    """Create the Chroma client and the Inference Endpoint client.

    Returns:
        A ``(chroma_client, inference_client)`` tuple.
    """
    logger.info("Initializing clients")
    chroma_client = get_chroma_client()
    inference_client = InferenceClient(
        INFERENCE_MODEL_URL,
    )
    return chroma_client, inference_client


def create_collection(chroma_client):
    """Create (or fetch, if it already exists) the viewer-descriptions collection.

    The collection uses cosine distance for its HNSW index and is configured
    with a sentence-transformers embedding function for the pinned model.

    Args:
        chroma_client: A Chroma client exposing ``create_collection``.

    Returns:
        The ``dataset-viewer-descriptions`` collection.
    """
    logger.info("Creating or getting collection")
    embedding_function = SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL_NAME,
        trust_remote_code=True,
        revision=EMBEDDING_MODEL_REVISION,
    )
    logger.info(f"Embedding function: {embedding_function}")
    logger.info(f"Embedding model name: {EMBEDDING_MODEL_NAME}")
    logger.info(f"Embedding model revision: {EMBEDDING_MODEL_REVISION}")
    return chroma_client.create_collection(
        name="dataset-viewer-descriptions",
        get_or_create=True,
        embedding_function=embedding_function,
        metadata={"hnsw:space": "cosine"},
    )


@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text, client):
    """Embed ``text`` with the inference endpoint, retrying on HTTP errors.

    Up to 3 attempts are made on ``requests.HTTPError``, with an initial
    back-off of 10 seconds.

    Args:
        text: The document to embed; only its first 8192 characters are sent.
        client: An ``InferenceClient`` pointed at the embedding endpoint.

    Returns:
        Whatever ``client.feature_extraction`` returns — presumably a numpy
        array of shape (1, dim); see the ``[0]`` indexing in the caller.
    """
    # Truncate by characters to keep request payloads bounded.
    text = text[:8192]
    return client.feature_extraction(text)


def embed_and_upsert_datasets(
    dataset_rows_and_ids: list[dict[str, str]],
    collection: chromadb.Collection,
    inference_client: InferenceClient,
    batch_size: int = 100,
):
    """Embed each row's formatted prompt and upsert the vectors into ``collection``.

    Rows are processed in batches of ``batch_size``. Within a batch, the
    embedding requests run concurrently on a thread pool. Only ids and
    embeddings are upserted; the document text itself is not stored.

    Args:
        dataset_rows_and_ids: Rows that each carry a ``"dataset_id"`` key
            (used as the Chroma id) and a ``"formatted_prompt"`` key (the
            text to embed).
        collection: Destination Chroma collection.
        inference_client: Client used by ``embed_card`` to compute vectors.
        batch_size: Number of rows embedded and upserted per batch.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    logger.info(
        f"Embedding and upserting {len(dataset_rows_and_ids)} datasets for viewer data"
    )
    for i in tqdm(range(0, len(dataset_rows_and_ids), batch_size)):
        batch = dataset_rows_and_ids[i : i + batch_size]
        ids = [item["dataset_id"] for item in batch]
        # The prefix tags the text as a dataset-viewer preview for the model.
        documents = [
            f"HUB_DATASET_PREVIEW: {item['formatted_prompt']}" for item in batch
        ]
        results = thread_map(
            lambda doc: embed_card(doc, inference_client), documents, leave=False
        )
        logger.info("Results: %d", len(results))
        collection.upsert(
            ids=ids,
            # Each result is assumed to hold a single row; keep that row.
            embeddings=[embedding.tolist()[0] for embedding in results],
        )
        logger.debug("Processed batch %d", i // batch_size + 1)


async def refresh_viewer_data(sample_size=200_000, min_likes=2):
    """Prepare viewer data, save it to parquet, then embed and upsert it.

    If ``prep_data`` returns ``None``, a warning is logged and nothing is
    written or upserted.

    Args:
        sample_size: Forwarded to ``prep_data``.
        min_likes: Forwarded to ``prep_data``.
    """
    logger.info(
        f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
    )
    chroma_client, inference_client = initialize_clients()
    collection = create_collection(chroma_client)
    logger.info("Collection created successfully")

    logger.info("Preparing data")
    df = await prep_data(sample_size=sample_size, min_likes=min_likes)
    # Check for None BEFORE touching df: the original wrote the parquet file
    # first, which raised AttributeError instead of reaching this guard.
    if df is None:
        logger.warning("No data prepared; skipping embedding and upsert")
        return

    # NOTE(review): df looks like a polars DataFrame (write_parquet/to_dicts)
    # — confirm against prep_viewer_data.
    df.write_parquet("viewer_data.parquet")
    logger.info("Data prepared successfully")
    logger.info(f"Data: {df}")

    dataset_rows_and_ids = df.to_dicts()
    logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
    embed_and_upsert_datasets(dataset_rows_and_ids, collection, inference_client)
    logger.info("Refresh completed successfully")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(refresh_viewer_data())