# Provenance (captured from the Hugging Face Hub file viewer):
# author: davanstrien (HF staff) | commit b5f94b5 "refactor" | size 8.87 kB
import functools
import logging
from contextlib import asynccontextmanager
from typing import List, Optional

import chromadb
from cashews import cache
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from fastapi import FastAPI, HTTPException, Query
from httpx import AsyncClient
from huggingface_hub import DatasetCard
from pydantic import BaseModel
from starlette.responses import RedirectResponse
from starlette.status import (
    HTTP_403_FORBIDDEN,
    HTTP_404_NOT_FOUND,
    HTTP_500_INTERNAL_SERVER_ERROR,
)

from load_card_data import card_embedding_function, refresh_card_data
from load_viewer_data import refresh_viewer_data
from utils import get_save_path, get_collection
# Logging: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# In-process memory backend for the @cache decorators used by the endpoints.
# `size` caps the number of entries and `check_interval` is the expiry sweep
# interval (semantics per cashews' mem:// backend — confirm against its docs).
cache.setup("mem://?check_interval=10&size=1000")

# On-disk Chroma vector store; the location comes from utils.get_save_path().
SAVE_PATH = get_save_path()
client = chromadb.PersistentClient(path=SAVE_PATH)

# One shared async HTTP client for Hub requests; redirects are followed.
async_client = AsyncClient(follow_redirects=True)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup data refreshes and shutdown cleanup for the FastAPI app.

    Startup refreshes the dataset-card data and then the dataset-viewer data.
    A failure in either step is logged with its traceback but does not abort
    startup, so the service may come up with stale or partial data.

    Shutdown closes the module-level ``async_client`` so its pooled HTTP
    connections are released. The ``try/finally`` around ``yield`` makes the
    cleanup run even when the application exits with an error.
    """
    logger.info("Starting up the application")
    try:
        logger.info("Starting refresh of card data")
        # Called without ``await``: runs synchronously on the event loop.
        refresh_card_data()
        logger.info("Card data refresh completed")

        logger.info("Starting refresh of viewer data")
        await refresh_viewer_data()
        logger.info("Viewer data refresh completed")

        logger.info("Data refresh completed successfully")
    except Exception as e:
        # Best-effort startup: keep the full traceback, but let the app start.
        logger.exception(f"Error during startup: {e}")
        logger.warning("Application starting with potential data issues")

    try:
        yield
    finally:
        logger.info("Shutting down the application")
        # Release the shared httpx connection pool.
        await async_client.aclose()
# The FastAPI application; `lifespan` performs the startup data refresh and
# shutdown logging.
app = FastAPI(lifespan=lifespan)
@app.get("/", include_in_schema=False)
def root():
    """Send requests for the bare root URL to the interactive API docs."""
    docs_url = "/docs"
    return RedirectResponse(docs_url)
async def try_get_card(hub_id: str) -> Optional[str]:
    """Fetch the text of a dataset card from the Hugging Face Hub.

    Downloads ``README.md`` from the ``main`` revision of the dataset repo
    ``hub_id``, parses it as a ``DatasetCard`` and returns ``card.text``
    (the card body without its YAML metadata header).

    Args:
        hub_id: Dataset repo id on the Hub.

    Returns:
        The card text, or ``None`` when the Hub does not answer with HTTP 200
        or when fetching/parsing fails. Failures are logged, not raised, so
        callers can treat a missing card as an ordinary outcome.
    """
    url = f"https://huggingface.co/datasets/{hub_id}/raw/main/README.md"
    try:
        response = await async_client.get(url)
        if response.status_code != 200:
            # Missing, private or gated repos land here; record why.
            logger.info(
                f"Could not fetch card for hub_id {hub_id}: "
                f"HTTP {response.status_code}"
            )
            return None
        return DatasetCard(response.text).text
    except Exception as e:
        logger.exception(f"Error fetching card for hub_id {hub_id}: {e}")
        return None
# One hit in a similarity search response. The endpoints fill `similarity`
# with `1 - distance` taken from the Chroma query result.
class QueryResult(BaseModel):
    dataset_id: str  # id of the matching dataset, as stored in the Chroma collection
    similarity: float  # 1 - Chroma distance; larger means closer
# Response body shared by the /similar, /similar-text and /search-viewer
# endpoints.
class QueryResponse(BaseModel):
    results: List[QueryResult]  # hits in the order returned by the Chroma query
class DatasetCardNotFoundError(HTTPException):
    """HTTP 404 raised when no dataset card can be obtained for a dataset."""

    def __init__(self, dataset_id: str):
        message = f"No dataset card available for dataset: {dataset_id}"
        super().__init__(HTTP_404_NOT_FOUND, detail=message)
class DatasetNotForAllAudiencesError(HTTPException):
    """HTTP 403 raised for datasets this service refuses to handle (not for all audiences)."""

    def __init__(self, dataset_id: str):
        message = f"Dataset {dataset_id} is not for all audiences and not supported in this service."
        super().__init__(HTTP_403_FORBIDDEN, detail=message)
async def _fetch_dataset_card(hub_id: str) -> Optional[DatasetCard]:
    """Download and parse the full dataset card (metadata and text) for ``hub_id``.

    Mirrors ``try_get_card`` but returns the parsed ``DatasetCard``, so the
    card's metadata (e.g. its tags) can be inspected before indexing.
    Returns ``None`` when the Hub does not answer with HTTP 200 or when
    fetching/parsing fails; failures are logged.
    """
    try:
        response = await async_client.get(
            f"https://huggingface.co/datasets/{hub_id}/raw/main/README.md"
        )
        if response.status_code == 200:
            return DatasetCard(response.text)
        logger.info(
            f"Could not fetch card for hub_id {hub_id}: HTTP {response.status_code}"
        )
    except Exception as e:
        logger.exception(f"Error fetching card for hub_id {hub_id}: {e}")
    return None


def _is_not_for_all_audiences(card: DatasetCard) -> bool:
    """Return True if the card's metadata tags include ``not-for-all-audiences``."""
    tags = getattr(card.data, "tags", None) or []
    if isinstance(tags, str):
        # A single tag may be written in YAML as a scalar instead of a list.
        tags = [tags]
    return "not-for-all-audiences" in tags


async def _embed_and_store_card(dataset_id: str, collection, embedding_function):
    """Fetch, vet, embed and upsert the card of a dataset missing from ``collection``.

    Returns:
        The embedding that was stored for ``dataset_id``.

    Raises:
        DatasetCardNotFoundError: no card could be fetched, or embedding or
            storing it failed.
        DatasetNotForAllAudiencesError: the card is tagged
            ``not-for-all-audiences``. The check runs before the upsert, so
            such datasets never enter the index.
    """
    try:
        card = await _fetch_dataset_card(dataset_id)
        if card is None:
            raise DatasetCardNotFoundError(dataset_id)
        if _is_not_for_all_audiences(card):
            raise DatasetNotForAllAudiencesError(dataset_id)
        # Chroma embedding functions take a *list* of documents; a bare string
        # would be iterated character by character by list-based implementations.
        embedding = embedding_function([card.text])[0]
        collection.upsert(ids=[dataset_id], embeddings=[embedding])
        logger.info(f"Dataset {dataset_id} added to collection")
        return embedding
    except (DatasetCardNotFoundError, DatasetNotForAllAudiencesError):
        raise
    except Exception as e:
        logger.exception(f"Error adding dataset {dataset_id} to collection: {e}")
        raise DatasetCardNotFoundError(dataset_id) from e


@app.get("/similar", response_model=QueryResponse)
@cache(ttl="1h")
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
    """Find datasets whose cards are most similar to the card of ``dataset_id``.

    If the dataset is not indexed yet, its card is fetched from the Hub,
    embedded and added to the index first. In that case a missing card gives
    404, and a card tagged ``not-for-all-audiences`` gives 403. Returns up to
    ``n`` results scored as ``1 - distance``.
    """
    embedding_function = card_embedding_function()
    collection = get_collection(client, embedding_function, "dataset_cards")
    try:
        logger.info(f"Querying dataset: {dataset_id}")
        stored = collection.get(ids=[dataset_id], include=["embeddings"]).get(
            "embeddings"
        )
        # Chroma may return embeddings as a NumPy array, whose truth value is
        # ambiguous, so test for emptiness explicitly.
        if stored is not None and len(stored) > 0:
            embedding = stored[0]
        else:
            logger.info(f"Dataset not found: {dataset_id}")
            embedding = await _embed_and_store_card(
                dataset_id, collection, embedding_function
            )

        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        # `ids` is one list per query embedding; an empty result is `[[]]`.
        ids = query_result["ids"][0] if query_result["ids"] else []
        if not ids:
            logger.info(f"No similar datasets found for: {dataset_id}")
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )

        results = [
            QueryResult(dataset_id=str(match_id), similarity=float(1 - distance))
            for match_id, distance in zip(ids, query_result["distances"][0])
        ]
        logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
        return QueryResponse(results=results)
    except HTTPException:
        # Includes the 404/403 subclasses above; propagate unchanged.
        raise
    except Exception as e:
        logger.exception(f"Error querying dataset {dataset_id}: {e}")
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e
@app.get("/similar-text", response_model=QueryResponse)
@cache(ttl="1h")
async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
    """Find datasets whose cards are most similar to a free-text query.

    Returns up to ``n`` dataset ids from the ``dataset_cards`` collection,
    scored as ``1 - distance``. Responds 404 when nothing matches.
    """
    try:
        logger.info(f"Querying datasets by text: {query}")
        collection = client.get_collection(
            name="dataset_cards", embedding_function=card_embedding_function()
        )
        query_result = collection.query(
            query_texts=[query], n_results=n, include=["distances"]
        )
        # `ids` is one list per query text; an empty result is `[[]]`.
        ids = query_result["ids"][0] if query_result["ids"] else []
        if not ids:
            logger.info(f"No similar datasets found for query: {query}")
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )

        results = [
            QueryResult(dataset_id=str(match_id), similarity=float(1 - distance))
            for match_id, distance in zip(ids, query_result["distances"][0])
        ]
        logger.info(f"Found {len(results)} similar datasets for query: {query}")
        return QueryResponse(results=results)
    except HTTPException:
        # Let the intended 404 through instead of turning it into a 500.
        raise
    except Exception as e:
        logger.exception(f"Error querying datasets by text {query}: {e}")
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e
@functools.lru_cache(maxsize=1)
def _viewer_embedding_function() -> SentenceTransformerEmbeddingFunction:
    """Return the embedding function for the viewer-description collection.

    Constructing it loads the sentence-transformers model, so it is built once
    on first use and reused by later requests. If construction raises, nothing
    is cached and the next call retries.
    """
    # NOTE(review): trust_remote_code=True executes code shipped with the model
    # repo; keep the model id pinned to a trusted repository.
    return SentenceTransformerEmbeddingFunction(
        model_name="davanstrien/dataset-viewer-descriptions-processed-st",
        trust_remote_code=True,
    )


@app.get("/search-viewer", response_model=QueryResponse)
@cache(ttl="1h")
async def api_search_viewer(query: str, n: int = Query(default=10, ge=1, le=100)):
    """Search dataset viewer descriptions with a free-text query.

    Returns up to ``n`` dataset ids from the ``dataset-viewer-descriptions``
    collection, scored as ``1 - distance``. Responds 404 when nothing matches.
    """
    try:
        collection = client.get_collection(
            name="dataset-viewer-descriptions",
            embedding_function=_viewer_embedding_function(),
        )
        # Queries are embedded with a "USER_QUERY: " prefix (presumably matching
        # the format the model was trained on — confirm).
        prefixed_query = f"USER_QUERY: {query}"
        query_result = collection.query(
            query_texts=[prefixed_query], n_results=n, include=["distances"]
        )
        # `ids` is one list per query text; an empty result is `[[]]`.
        ids = query_result["ids"][0] if query_result["ids"] else []
        if not ids:
            logger.info(f"No similar datasets found for query: {query}")
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )

        results = [
            QueryResult(dataset_id=str(match_id), similarity=float(1 - distance))
            for match_id, distance in zip(ids, query_result["distances"][0])
        ]
        logger.info(f"Found {len(results)} similar datasets for query: {query}")
        return QueryResponse(results=results)
    except HTTPException:
        # Let the intended 404 through instead of turning it into a 500.
        raise
    except Exception as e:
        logger.exception(
            f"Error searching dataset viewer descriptions for query {query}: {e}"
        )
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e
if __name__ == "__main__":
    # Direct-run entry point: serve the app on every interface, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")