File size: 3,071 Bytes
6917098
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Replace the stdlib `sqlite3` module with the `pysqlite3` package.
# `__import__` loads pysqlite3 into `sys.modules`; the entry is then moved to
# the "sqlite3" key, so any later `import sqlite3` in this process gets the
# pysqlite3 module instead. This must stay above `import chromadb` below.
# NOTE(review): presumably needed because the system SQLite library is older
# than what chromadb requires — confirm against the deployment environment.
__import__("pysqlite3")
import sys

sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")


import uuid
from collections import defaultdict
from typing import Any, List, Optional

import chromadb
import numpy as np
from chromadb import Collection

from embeddings import Embedding


class ChromaStore:
    """Small helper around a persistent ChromaDB client.

    The instance holds the client plus the name and metadata of one
    collection. The data-access methods take the `Collection` object
    (as returned by `create`) as an explicit argument.
    """

    def __init__(
        self,
        collection_name: str,
        storage_path: str = "./chroma",
        database: str = "database",
        metadata: Optional[dict] = None,
    ) -> None:
        """Initiate Chromadb.

        Args:
        - collection_name (str): name of the collection opened by `create`.
        - storage_path (str): directory passed to `chromadb.PersistentClient`.
        - database (str): stored as `self.database`.
          NOTE(review): it is not passed to the client, so it currently has
          no effect on which database is used.
        - metadata (dict | None): collection metadata forwarded by `create`.
          When omitted, each instance gets its own {"hnsw:space": "cosine"}.
          Available options for 'hnsw:space' are 'l2', 'ip' or 'cosine'.
        """
        self.collection_name = collection_name
        # None sentinel instead of a dict default: a dict default is created
        # once and shared (and mutable) across every instance.
        if metadata is None:
            metadata = {"hnsw:space": "cosine"}
        self.metadata = metadata
        self.storage_path = storage_path
        self.database = database

        self.client = chromadb.PersistentClient(path=self.storage_path)

    def _health_check(self) -> bool:
        """Return True when the client's heartbeat call returns an int."""
        return isinstance(self.client.heartbeat(), int)

    def create(self) -> Collection:
        """Return the configured collection, creating it if it does not exist.

        `self.metadata` (e.g. the 'hnsw:space' distance function) is forwarded
        to `get_or_create_collection`, so the configured setting is applied.
        NOTE(review): for a collection that already exists on disk, confirm
        how the installed chromadb version treats a differing 'hnsw:space'.
        """
        return self.client.get_or_create_collection(
            name=self.collection_name,
            metadata=self.metadata,
        )

    def add(
        self,
        collection: Collection,
        embeddings: List[List[float]],
        documents: List[str],
        ids: List[str],
    ):
        """Add embeddings and their documents to a collection.

        Args:
        - collection: created collection.
        - embeddings: one embedding vector per document.
        - documents: text documents.
        - ids: one unique id per document.

        Raises:
        - Exception: wrapping whatever `collection.add` raised; the original
          error is kept as `__cause__`.
        """
        try:
            collection.add(
                embeddings=embeddings,
                ids=ids,
                documents=documents,
            )
        except Exception as e:
            raise Exception(f"Failed to add documents to Chroma store. {e}") from e

    def query(
        self,
        collection: Collection,
        query_embedding: List[float],
        top_k: int = 3,
    ) -> list:
        """Retrieve the documents nearest to a query embedding.

        Args:
        - collection: created collection.
        - query_embedding: embedding of the query.
        - top_k (int): number of documents to retrieve.

        Returns:
        - list of the matching documents, in the order Chroma returns them.
          Distances are not included.
        """
        result = collection.query(query_embeddings=query_embedding, n_results=top_k)
        # Results are nested per query embedding; only one query is sent,
        # so take the first entry.
        return list(result["documents"][0])

    def delete(self, collection_name: str) -> bool:
        """Delete the named collection and return True.

        Raises:
        - Exception: if the client fails to delete it; the original error is
          kept as `__cause__`.
        """
        try:
            self.client.delete_collection(collection_name)
        except Exception as e:
            raise Exception("Failed to delete collection", e) from e
        return True

    def list_collections(self):
        """Return whatever the client's `list_collections()` returns."""
        return self.client.list_collections()

    @staticmethod
    def collection_info(collection: Collection):
        """Summarize a collection.

        Returns:
        - defaultdict(str) with 'count' (from `collection.count()`) and
          'top_10_items' (from `collection.peek()`). Missing keys read as "".
        """
        info = defaultdict(str)
        info["count"] = collection.count()
        info["top_10_items"] = collection.peek()
        return info