"""Refresh the dataset-viewer ChromaDB collection with fresh embeddings.

Prepares dataset rows (via ``prep_viewer_data.prep_data``), embeds each
row's formatted prompt through a dedicated Hugging Face Inference
Endpoint, and upserts the resulting vectors into a persistent ChromaDB
collection configured for cosine similarity.
"""

import asyncio
import logging
from functools import partial

import chromadb
import httpx
import requests
import stamina
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import InferenceClient
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map

from prep_viewer_data import prep_data

# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Dedicated embedding endpoint; update here if the deployment moves.
EMBEDDING_ENDPOINT_URL = (
    "https://bm143rfir2on1bkw.us-east-1.aws.endpoints.huggingface.cloud"
)

# Input limit (characters) applied before sending text to the endpoint.
MAX_EMBED_CHARS = 8192


def initialize_clients(endpoint_url=EMBEDDING_ENDPOINT_URL):
    """Return a ``(chroma_client, inference_client)`` pair.

    Args:
        endpoint_url: Hugging Face Inference Endpoint URL used for
            embeddings. Defaults to the project's dedicated deployment.
    """
    logger.info("Initializing clients")
    chroma_client = chromadb.PersistentClient()
    inference_client = InferenceClient(endpoint_url)
    return chroma_client, inference_client


def create_collection(chroma_client):
    """Create (or fetch) the cosine-space collection for viewer data.

    ``get_or_create=True`` makes repeated runs idempotent: an existing
    collection is reused rather than raising.
    """
    logger.info("Creating or getting collection")
    embedding_function = SentenceTransformerEmbeddingFunction(
        model_name="davanstrien/dataset-viewer-descriptions-processed-st",
        trust_remote_code=True,
    )
    return chroma_client.create_collection(
        name="dataset-viewer-descriptions",
        get_or_create=True,
        embedding_function=embedding_function,
        metadata={"hnsw:space": "cosine"},
    )


@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text, client):
    """Embed ``text`` via the inference endpoint, retrying on HTTP errors.

    The text is truncated to ``MAX_EMBED_CHARS`` characters to stay within
    the endpoint's input limit. ``huggingface_hub``'s ``HfHubHTTPError``
    subclasses ``requests.HTTPError``, so stamina retries endpoint
    failures up to 3 times with exponential backoff.
    """
    return client.feature_extraction(text[:MAX_EMBED_CHARS])


def embed_and_upsert_datasets(
    dataset_rows_and_ids, collection, inference_client, batch_size=10
):
    """Embed dataset prompts in batches and upsert them into ``collection``.

    Args:
        dataset_rows_and_ids: dicts carrying ``"dataset_id"`` and
            ``"formatted_prompt"`` keys.
        collection: target ChromaDB collection.
        inference_client: client passed through to :func:`embed_card`.
        batch_size: number of rows embedded per upsert call.
    """
    total = len(dataset_rows_and_ids)
    logger.info("Embedding and upserting %d datasets", total)
    # Bind the client once, outside the loop, instead of a per-batch lambda.
    embed = partial(embed_card, client=inference_client)
    for start in tqdm(range(0, total, batch_size)):
        batch = dataset_rows_and_ids[start : start + batch_size]
        ids = [row["dataset_id"] for row in batch]
        documents = [
            f"HUB_DATASET_PREVIEW: {row['formatted_prompt']}" for row in batch
        ]
        # Embed the batch concurrently; thread_map preserves input order.
        results = thread_map(embed, documents, leave=False)
        collection.upsert(
            ids=ids,
            # NOTE(review): assumes the endpoint returns a (1, dim) array per
            # text, so [0] selects the single embedding row — confirm if the
            # endpoint or model changes.
            embeddings=[embedding.tolist()[0] for embedding in results],
        )
        logger.debug("Processed batch %d", start // batch_size + 1)


async def refresh_viewer_data(sample_size=100_000, min_likes=2):
    """Prepare viewer data and refresh the embedding collection end-to-end.

    Args:
        sample_size: maximum number of datasets to sample for embedding.
        min_likes: minimum like count a dataset needs to be included.
    """
    logger.info(
        "Refreshing viewer data with sample_size=%s and min_likes=%s",
        sample_size,
        min_likes,
    )
    chroma_client, inference_client = initialize_clients()
    collection = create_collection(chroma_client)
    logger.info("Preparing data")
    df = await prep_data(sample_size=sample_size, min_likes=min_likes)
    dataset_rows_and_ids = df.to_dicts()
    # embed_and_upsert_datasets logs the dataset count itself, so no
    # duplicate count log here.
    embed_and_upsert_datasets(dataset_rows_and_ids, collection, inference_client)
    logger.info("Refresh completed successfully")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(refresh_viewer_data())