File size: 6,902 Bytes
246d201 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import json
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event
from openhands.events.serialization.event import event_to_memory
from openhands.events.stream import EventStream
from openhands.utils.embeddings import (
LLAMA_INDEX_AVAILABLE,
EmbeddingsLoader,
check_llama_index,
)
# Conditional imports based on llama_index availability
if LLAMA_INDEX_AVAILABLE:
import chromadb
from llama_index.core import Document
from llama_index.core.indices.vector_store.base import VectorStoreIndex
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from llama_index.core.schema import TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
class LongTermMemory:
    """Handles storing information for the agent to access later, using chromadb.

    Events are serialized to JSON, wrapped in llama_index nodes and inserted
    into a ``VectorStoreIndex`` backed by a persistent chromadb collection.
    Every stored document gets a sequential id taken from ``thought_idx``.
    """

    event_stream: EventStream

    def __init__(
        self,
        llm_config: LLMConfig,
        agent_config: AgentConfig,
        event_stream: EventStream,
    ):
        """Initialize the chromadb and set up ChromaVectorStore for later use.

        Parameters:
        - llm_config: Provides ``embedding_model`` and is handed to EmbeddingsLoader
        - agent_config: Provides ``memory_max_threads``
        - event_stream: Source of events; its ``sid`` selects the storage directory
        """
        # NOTE(review): presumably raises when the optional llama_index
        # dependencies are missing -- confirm in openhands.utils.embeddings.
        check_llama_index()

        # initialize the chromadb client; data is persisted per session id
        db = chromadb.PersistentClient(
            path=f'./cache/sessions/{event_stream.sid}/memory',
            # FIXME anonymized_telemetry=False,
        )
        self.collection = db.get_or_create_collection(name='memories')
        vector_store = ChromaVectorStore(chroma_collection=self.collection)

        # embedding model
        embedding_strategy = llm_config.embedding_model
        self.embed_model = EmbeddingsLoader.get_embedding_model(
            embedding_strategy, llm_config
        )
        logger.debug(f'Using embedding model: {self.embed_model}')

        # instantiate the index
        self.index = VectorStoreIndex.from_vector_store(vector_store, self.embed_model)

        # next sequential id to hand out; it always starts at 0 for a new instance
        self.thought_idx = 0

        # initialize the event stream
        self.event_stream = event_stream

        # max of threads to run the pipeline
        self.memory_max_threads = agent_config.memory_max_threads

    def _event_to_document(self, event: Event) -> 'Document':
        """Serialize one event into a Document and advance ``thought_idx``.

        This is the single conversion path shared by ``add_event`` and
        ``_events_to_docs`` so that both produce identical documents.

        Parameters:
        - event: The event to convert

        Returns:
        - Document: JSON text of the event plus ``type``/``id``/``idx`` metadata

        Raises:
        - Whatever ``event_to_memory``, ``json.dumps`` or ``Document`` raise;
          callers handle ``json.JSONDecodeError``, ``KeyError`` and ``ValueError``.
        """
        # convert the event to a memory-friendly format, and don't truncate
        event_data = event_to_memory(event, -1)

        # determine the event type and ID; both stay empty for anything that is
        # neither an action nor an observation
        event_type = ''
        event_id = ''
        if 'action' in event_data:
            event_type = 'action'
            event_id = event_data['action']
        elif 'observation' in event_data:
            event_type = 'observation'
            event_id = event_data['observation']

        # create a Document instance for the event
        doc = Document(
            text=json.dumps(event_data),
            doc_id=str(self.thought_idx),
            extra_info={
                'type': event_type,
                'id': event_id,
                'idx': self.thought_idx,
            },
        )

        # consume the index only once the document was built successfully
        self.thought_idx += 1
        return doc

    def add_event(self, event: Event):
        """Adds a new event to the long term memory with a unique id.

        An event that cannot be converted is logged as a warning and skipped,
        the same way the batch conversion in ``_events_to_docs`` treats it.

        Parameters:
        - event: The new event to be added to memory
        """
        # remember the index this event will be stored under, so the log line
        # below reports the stored id rather than the next free one
        idx = self.thought_idx
        try:
            doc = self._event_to_document(event)
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            logger.warning(f'Failed to process event: {e}')
            return

        logger.debug('Adding %s event to memory: %d', doc.extra_info['type'], idx)
        self._add_document(document=doc)

    def _add_document(self, document: 'Document'):
        """Inserts a single document into the index."""
        self.index.insert_nodes([self._create_node(document)])

    def _create_node(self, document: 'Document') -> 'TextNode':
        """Create a TextNode from a Document instance."""
        return TextNode(
            text=document.text,
            doc_id=document.doc_id,
            extra_info=document.extra_info,
        )

    def search(self, query: str, k: int = 10) -> list[str]:
        """Searches through the current memory using VectorIndexRetriever.

        Parameters:
        - query (str): A query to match search results to
        - k (int): Number of top results to return

        Returns:
        - list[str]: List of top k results found in current memory
        """
        retriever = VectorIndexRetriever(
            index=self.index,
            similarity_top_k=k,
        )
        results = retriever.retrieve(query)

        for result in results:
            logger.debug(
                f'Doc ID: {result.doc_id}:\n Text: {result.get_text()}\n Score: {result.score}'
            )
        return [r.get_text() for r in results]

    def _events_to_docs(self) -> list['Document']:
        """Convert all events from the EventStream to documents.

        The documents are only built and returned here; nothing is inserted
        into the index by this method. Events that fail to convert are logged
        and skipped.

        Returns:
        - list[Document]: One document per successfully converted event
        """
        try:
            events = self.event_stream.get_events()
        except Exception as e:
            # deliberately best-effort: a failure to read the stream is treated
            # as "no events" instead of aborting the caller
            logger.debug(f'No events found for session {self.event_stream.sid}: {e}')
            return []

        documents: list[Document] = []
        for event in events:
            try:
                documents.append(self._event_to_document(event))
            except (json.JSONDecodeError, KeyError, ValueError) as e:
                logger.warning(f'Failed to process event: {e}')

        if documents:
            logger.debug(f'Batch inserting {len(documents)} documents into the index.')
        else:
            logger.debug('No valid documents found to insert into the index.')
        return documents

    def create_nodes(self, documents: list['Document']) -> list['TextNode']:
        """Create nodes from a list of documents."""
        return [self._create_node(doc) for doc in documents]
|