phi-3-RAG / datastore.py
dmedhi's picture
chat application
6917098
raw
history blame contribute delete
No virus
3.07 kB
# Swap pysqlite3 in for the stdlib sqlite3 module; this has to run before
# chromadb is imported below, so keep these three lines first.
__import__("pysqlite3")
import sys

sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

import uuid
from collections import defaultdict
from typing import Any, List, Optional

import chromadb
import numpy as np
from chromadb import Collection

from embeddings import Embedding
class ChromaStore:
    """Thin wrapper around a persistent Chroma client for one named collection.

    The instance owns the client; the collection handle returned by
    ``create()`` is passed explicitly to ``add()``, ``query()`` and
    ``collection_info()``.
    """

    def __init__(
        self,
        collection_name: str,
        storage_path: str = "./chroma",
        database: str = "database",
        metadata: Optional[dict] = None,
    ) -> None:
        """Open (or create) the on-disk Chroma store.

        Args:
            collection_name: name of the collection managed by this store.
            storage_path: directory handed to ``chromadb.PersistentClient``.
            database: stored on the instance; not used by the code in this
                class.
            metadata: collection metadata. Available options for
                ``'hnsw:space'`` are ``'l2'``, ``'ip'`` or ``'cosine'``.
                Defaults to ``{"hnsw:space": "cosine"}`` when omitted.
        """
        self.collection_name = collection_name
        # Build the default per instance (and copy caller input) so that no
        # two stores ever share one mutable dict.
        self.metadata = (
            {"hnsw:space": "cosine"} if metadata is None else dict(metadata)
        )
        self.storage_path = storage_path
        self.database = database
        self.client = chromadb.PersistentClient(path=self.storage_path)

    def _health_check(self) -> bool:
        """Return True when the client's heartbeat yields an integer."""
        return isinstance(self.client.heartbeat(), int)

    def create(self) -> "Collection":
        """Return the collection named ``self.collection_name``, creating it if needed.

        The configured metadata (e.g. the ``hnsw:space`` distance metric) is
        forwarded so that a newly created collection actually uses it.
        """
        return self.client.get_or_create_collection(
            name=self.collection_name,
            # An empty dict is sent as None, i.e. "no metadata".
            metadata=self.metadata or None,
        )

    def add(
        self,
        collection: "Collection",
        embeddings: List[List[float]],
        documents: List[str],
        ids: List[str],
    ) -> None:
        """Add embeddings and their documents to a collection.

        Args:
            collection: created collection.
            embeddings: list of embedding vectors, one per document.
            documents: text documents.
            ids: list of ids, one per document.

        Raises:
            Exception: if the underlying ``collection.add`` call fails; the
                original error is attached as ``__cause__``.
        """
        try:
            collection.add(
                embeddings=embeddings,
                ids=ids,
                documents=documents,
            )
        except Exception as e:
            raise Exception(f"Failed to add documents to Chroma store. {e}") from e

    def query(
        self,
        collection: "Collection",
        query_embedding: List[float],
        top_k: int = 3,
    ) -> list:
        """Retrieve the most relevant document chunks from a collection.

        Args:
            collection: created collection.
            query_embedding: embedding of the query.
            top_k: number of chunks to retrieve.

        Returns:
            The documents matched for the query, or an empty list when the
            result carries no documents. Distances are not returned.
        """
        result = collection.query(query_embeddings=query_embedding, n_results=top_k)
        # Documents come back grouped per query; only the first group is used.
        documents = result.get("documents") or []
        return list(documents[0]) if documents else []

    def delete(self, collection_name: str) -> bool:
        """Delete the named collection.

        Returns:
            True when the collection was deleted.

        Raises:
            Exception: if the client fails to delete the collection; the
                original error is attached as ``__cause__``.
        """
        try:
            self.client.delete_collection(collection_name)
        except Exception as e:
            raise Exception(f"Failed to delete collection. {e}") from e
        return True

    def list_collections(self):
        """Return whatever the client reports from ``list_collections()``."""
        return self.client.list_collections()

    @staticmethod
    def collection_info(collection: "Collection"):
        """Summarise a collection.

        Returns:
            A ``defaultdict(str)`` with ``"count"`` (``collection.count()``)
            and ``"top_10_items"`` (``collection.peek()``).
        """
        info = defaultdict(str)
        info["count"] = collection.count()
        info["top_10_items"] = collection.peek()
        return info