Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import logging | |
from contextlib import asynccontextmanager | |
from typing import List, Optional | |
import chromadb | |
from cashews import cache | |
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction | |
from fastapi import FastAPI, HTTPException, Query | |
from httpx import AsyncClient | |
from huggingface_hub import DatasetCard | |
from pydantic import BaseModel | |
from starlette.responses import RedirectResponse | |
from starlette.status import ( | |
HTTP_403_FORBIDDEN, | |
HTTP_404_NOT_FOUND, | |
HTTP_500_INTERNAL_SERVER_ERROR, | |
) | |
from load_card_data import card_embedding_function, refresh_card_data | |
from load_viewer_data import refresh_viewer_data | |
from utils import get_save_path, get_collection, get_chroma_client | |
# ---------------------------------------------------------------------------
# Module-level setup (executed once, at import time)
# ---------------------------------------------------------------------------

# Root logging: INFO and above, each record prefixed with timestamp and level.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# In-memory cashews cache backend, configured through its URL-style settings.
# NOTE(review): no visible code applies `cache` as a decorator or calls it, so
# the configured cache currently looks unused — confirm whether that is intended.
cache.setup("mem://?check_interval=10&size=1000")

# Shared clients used by the functions below.
client = get_chroma_client()  # Chroma vector store client
async_client = AsyncClient(follow_redirects=True)  # HTTP client for Hub downloads
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Refresh the search data on startup and log on shutdown.

    Used as the FastAPI ``lifespan`` handler. Startup does two refreshes in
    order: first the dataset-card data, then the dataset-viewer data. Any
    exception from either refresh is logged and swallowed, so the app still
    starts, possibly with stale or incomplete data. The code after ``yield``
    runs at shutdown.

    Args:
        app: The application being started. The lifespan signature requires
            it; this handler does not use it.
    """
    logger.info("Starting up the application")
    try:
        logger.info("Starting refresh of card data")
        # Called synchronously (no ``await``), so it blocks the event loop
        # until it returns. That is tolerable only because this is startup.
        refresh_card_data()
        logger.info("Card data refresh completed")

        logger.info("Starting refresh of viewer data")
        await refresh_viewer_data()
        logger.info("Viewer data refresh completed")

        logger.info("Data refresh completed successfully")
    except Exception as e:
        # Deliberately best-effort: a failed refresh must not block startup.
        # logger.exception also records the traceback.
        logger.exception(f"Error during startup: {str(e)}")
        logger.warning("Application starting with potential data issues")

    try:
        yield
    finally:
        # Shutdown: logged even if the lifespan is exited with an exception.
        logger.info("Shutting down the application")
# The ASGI application; data refresh and shutdown logging are handled by
# ``lifespan``.
app = FastAPI(lifespan=lifespan)


@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare root URL to FastAPI's default docs page at ``/docs``."""
    # Without the route decorator this handler was never registered on ``app``.
    # It is hidden from the OpenAPI schema because it is only a convenience
    # redirect.
    return RedirectResponse(url="/docs")
async def try_get_card(hub_id: str) -> Optional[str]:
    """Best-effort download of a dataset card's text from the Hugging Face Hub.

    Fetches the raw ``README.md`` from the dataset repo's ``main`` branch and
    parses it with ``DatasetCard``.

    Args:
        hub_id: Dataset repo id, placed directly into the URL path. It is
            presumably ``"owner/name"``.
            NOTE(review): it is neither validated nor escaped.

    Returns:
        The parsed card's ``text`` when the request answers HTTP 200, otherwise
        ``None``. Any exception raised while fetching or parsing is logged and
        also results in ``None``.
    """
    card_text: Optional[str] = None
    try:
        readme_url = f"https://huggingface.co/datasets/{hub_id}/raw/main/README.md"
        response = await async_client.get(readme_url)
        if response.status_code != 200:
            return None
        card_text = DatasetCard(response.text).text
    except Exception as e:
        logger.error(f"Error fetching card for hub_id {hub_id}: {str(e)}")
    return card_text
class QueryResult(BaseModel):
    # One search hit returned by the query handlers in this module.
    # Documented with comments rather than a docstring because pydantic copies
    # a class docstring into the generated JSON schema / OpenAPI description.
    dataset_id: str  # id of the matching record in the Chroma collection
    similarity: float  # set by the handlers to 1 - Chroma distance (higher = closer)
class QueryResponse(BaseModel):
    # Response body of the query handlers: the hits, in the order Chroma
    # returned them.
    results: List[QueryResult]
class DatasetCardNotFoundError(HTTPException):
    """HTTP 404 error: no dataset card could be obtained for a dataset.

    Args:
        dataset_id: Id of the dataset, echoed back in the error ``detail``.
    """

    def __init__(self, dataset_id: str):
        message = f"No dataset card available for dataset: {dataset_id}"
        super().__init__(detail=message, status_code=HTTP_404_NOT_FOUND)
class DatasetNotForAllAudiencesError(HTTPException):
    """HTTP 403 error: the dataset is not for all audiences and is refused.

    Args:
        dataset_id: Id of the rejected dataset, echoed back in the ``detail``.
    """

    def __init__(self, dataset_id: str):
        message = (
            f"Dataset {dataset_id} is not for all audiences "
            "and not supported in this service."
        )
        super().__init__(detail=message, status_code=HTTP_403_FORBIDDEN)
async def _embed_and_store_dataset(collection, embedding_function, dataset_id: str):
    """Fetch, embed and upsert the card of a dataset missing from ``collection``.

    Args:
        collection: Chroma collection to upsert the new record into.
        embedding_function: Chroma-style embedding function (a list of
            documents in, one embedding per document out).
        dataset_id: Hub dataset id; also used as the record id.

    Returns:
        The embedding that was stored for ``dataset_id``.

    Raises:
        DatasetCardNotFoundError: if no card could be fetched, or if embedding
            or storing it failed. That failure is logged and chained.
    """
    # NOTE(review): DatasetNotForAllAudiencesError is not raised on this path.
    # Looking for a "not-for-all-audiences" key in a Chroma get() result cannot
    # work, because that result has no such key. A working check needs the
    # card's metadata tags, which try_get_card does not return, and it must
    # run before the upsert.
    try:
        card = await try_get_card(dataset_id)
        if card is None:
            raise DatasetCardNotFoundError(dataset_id)
        # Embedding functions take a *list* of documents. A bare string can be
        # iterated character by character by implementations that call
        # list(input).
        embedding = embedding_function([card])[0]
        collection.upsert(ids=[dataset_id], embeddings=[embedding])
        logger.info(f"Dataset {dataset_id} added to collection")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(
            f"Error adding dataset {dataset_id} to collection: {str(e)}"
        )
        raise DatasetCardNotFoundError(dataset_id) from e
    return embedding


async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
    """Find the datasets whose card embeddings are closest to a given dataset's.

    If the dataset is not yet in the ``dataset_cards`` collection, its card is
    first fetched from the Hub, embedded and stored.

    Args:
        dataset_id: Hub id of the dataset to find neighbours for.
        n: Number of neighbours to return (1-100).

    Returns:
        QueryResponse with one QueryResult per hit, ``similarity = 1 - distance``.

    Raises:
        DatasetCardNotFoundError: (404) the dataset's card could not be
            fetched or indexed.
        HTTPException: 404 when the query yields no hits; 500 on any other
            unexpected error.
    """
    # NOTE(review): no route decorator is visible here, so as shown this
    # handler is not registered on ``app``. The ``Query(...)`` default implies
    # it is meant to be an endpoint; confirm the intended path.
    try:
        logger.info(f"Querying dataset: {dataset_id}")
        embedding_function = card_embedding_function()
        collection = get_collection(client, embedding_function, "dataset_cards")

        # Test ``ids`` instead of the truthiness of ``embeddings``. Some Chroma
        # versions return embeddings as a NumPy array, and a NumPy array cannot
        # be evaluated in a boolean context.
        result = collection.get(ids=[dataset_id], include=["embeddings"])
        if result["ids"]:
            embedding = result["embeddings"][0]
        else:
            logger.info(f"Dataset not found: {dataset_id}")
            embedding = await _embed_and_store_dataset(
                collection, embedding_function, dataset_id
            )

        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        # query() returns one list of hits per query embedding. Only one
        # embedding was sent, so "no hits" looks like [[]], not [].
        hit_ids = query_result["ids"][0] if query_result["ids"] else []
        if not hit_ids:
            logger.info(f"No similar datasets found for: {dataset_id}")
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )

        # NOTE(review): 1 - distance is a similarity only for cosine distance;
        # confirm the collection's distance space. The queried dataset itself is
        # stored in the collection, so it will typically appear among the hits.
        results = [
            QueryResult(dataset_id=str(hit_id), similarity=float(1 - distance))
            for hit_id, distance in zip(hit_ids, query_result["distances"][0])
        ]
        logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
        return QueryResponse(results=results)
    except HTTPException:
        # Covers DatasetCardNotFoundError and the 404 above: pass them through.
        raise
    except Exception as e:
        logger.exception(f"Error querying dataset {dataset_id}: {str(e)}")
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e
async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
    """Find the datasets whose card embeddings are closest to a free-text query.

    Args:
        query: Free-text search string, embedded with the card embedding function.
        n: Number of results to return (1-100).

    Returns:
        QueryResponse with one QueryResult per hit, ``similarity = 1 - distance``.

    Raises:
        HTTPException: 404 when the collection returns no hits; 500 on any
            unexpected error.
    """
    # NOTE(review): no route decorator is visible here, so as shown this
    # handler is not registered on ``app``; confirm the intended path.
    try:
        logger.info(f"Querying datasets by text: {query}")
        collection = client.get_collection(
            name="dataset_cards", embedding_function=card_embedding_function()
        )
        query_result = collection.query(
            query_texts=[query], n_results=n, include=["distances"]
        )
        logger.debug(f"Raw query result for {query!r}: {query_result}")

        # query() returns one list of hits per query text. Only one text was
        # sent, so "no hits" looks like [[]], not [].
        hit_ids = query_result["ids"][0] if query_result["ids"] else []
        if not hit_ids:
            logger.info(f"No similar datasets found for query: {query}")
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )

        results = [
            QueryResult(dataset_id=str(hit_id), similarity=float(1 - distance))
            for hit_id, distance in zip(hit_ids, query_result["distances"][0])
        ]
        logger.info(f"Found {len(results)} similar datasets for query: {query}")
        return QueryResponse(results=results)
    except HTTPException:
        # Without this, the deliberate 404 above would be turned into a 500.
        raise
    except Exception as e:
        logger.exception(f"Error querying datasets by text {query}: {str(e)}")
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e
async def api_search_viewer(query: str, n: int = Query(default=10, ge=1, le=100)):
    """Search the dataset-viewer descriptions collection with a free-text query.

    Args:
        query: Free-text search string. It is prefixed with ``"USER_QUERY: "``
            before embedding.
        n: Number of results to return (1-100).

    Returns:
        QueryResponse with one QueryResult per hit, ``similarity = 1 - distance``.

    Raises:
        HTTPException: 404 when the collection returns no hits; 500 on any
            unexpected error.
    """
    # NOTE(review): no route decorator is visible here, so as shown this
    # handler is not registered on ``app``; confirm the intended path.
    try:
        embedding_function = SentenceTransformerEmbeddingFunction(
            model_name="davanstrien/query-to-dataset-viewer-descriptions",
            # NOTE(review): trust_remote_code=True lets the model repository
            # run its own Python code when loaded. Keep it only if that
            # repository is trusted.
            trust_remote_code=True,
        )
        collection = client.get_collection(
            name="dataset-viewer-descriptions",
            embedding_function=embedding_function,
        )
        # The prefix presumably matches the input format this model was
        # trained on (TODO confirm against the model card).
        prefixed_query = f"USER_QUERY: {query}"
        query_result = collection.query(
            query_texts=[prefixed_query], n_results=n, include=["distances"]
        )
        logger.debug(f"Raw viewer query result for {query!r}: {query_result}")

        # query() returns one list of hits per query text. Only one text was
        # sent, so "no hits" looks like [[]], not [].
        hit_ids = query_result["ids"][0] if query_result["ids"] else []
        if not hit_ids:
            logger.info(f"No similar datasets found for query: {query}")
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )

        results = [
            QueryResult(dataset_id=str(hit_id), similarity=float(1 - distance))
            for hit_id, distance in zip(hit_ids, query_result["distances"][0])
        ]
        logger.info(f"Found {len(results)} similar datasets for query: {query}")
        return QueryResponse(results=results)
    except HTTPException:
        # Without this, the deliberate 404 above would be turned into a 500.
        raise
    except Exception as e:
        logger.exception(f"Error querying datasets by text {query}: {str(e)}")
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces.
    import uvicorn

    server_host, server_port = "0.0.0.0", 8000
    uvicorn.run(app, host=server_host, port=server_port)