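# Data indexer for the LangChain GitHub repository: repository files are split
# into chunks, embedded with OpenAI embeddings, and upserted into a Pinecone
# serverless index. A small in-memory Chroma index over the file paths listed
# in sources.txt is used to pre-filter results when hybrid search is enabled.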
import os
import uuid
from pathlib import Path

from pinecone.grpc import PineconeGRPC as Pinecone
from pinecone import ServerlessSpec
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

current_dir = Path(__file__).resolve().parent

os.environ['PINECONE_API_KEY'] = "988da8ab-3725-4047-b622-cc42d07ecb6c"
os.environ['OPENAI_API_KEY'] = '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'

class DataIndexer:
    """Indexes repository documents into Pinecone and answers similarity queries."""

    source_file = os.path.join(current_dir, 'sources.txt')

    def __init__(self, index_name='langchain-repo') -> None:
        # self.embedding_client = InferenceClient(
        #     "dunzhang/stella_en_1.5B_v5",
        # )
        self.embedding_client = OpenAIEmbeddings()
        self.index_name = index_name
        self.pinecone_client = Pinecone(api_key=os.environ.get('PINECONE_API_KEY'))

        # Create the serverless index on first use; 1536 matches the dimension
        # of the default OpenAI embedding model.
        if index_name not in self.pinecone_client.list_indexes().names():
            self.pinecone_client.create_index(
                name=index_name,
                dimension=1536,
                metric='cosine',
                spec=ServerlessSpec(
                    cloud='aws',
                    region='us-east-1'
                )
            )

        self.index = self.pinecone_client.Index(self.index_name)
        self.source_index = self.get_source_index()
        # self.source_index = None
    def get_source_index(self):
        """Build an in-memory Chroma index over the file paths listed in sources.txt."""
        if not os.path.isfile(self.source_file):
            print('No source file')
            return None
        print('create source index')
        with open(self.source_file, 'r') as file:
            sources = file.readlines()
        sources = [s.rstrip('\n') for s in sources]
        vectorstore = Chroma.from_texts(
            sources, embedding=self.embedding_client
        )
        return vectorstore
    def index_data(self, docs, batch_size=32):
        """Embed documents in batches and upsert them into the Pinecone index."""
        # Record each document's source path so get_source_index() can rebuild
        # the Chroma source index on the next run.
        with open(self.source_file, 'a') as file:
            for doc in docs:
                file.write(doc.metadata['source'] + '\n')

        for i in range(0, len(docs), batch_size):
            batch = docs[i: i + batch_size]
            values = self.embedding_client.embed_documents([
                doc.page_content for doc in batch
            ])
            # values = self.embedding_client.feature_extraction([
            #     doc.page_content for doc in batch
            # ])
            vector_ids = [str(uuid.uuid4()) for _ in batch]
            # Store the chunk text alongside the document metadata so search()
            # can return the text straight from the match metadata.
            metadatas = [{
                'text': doc.page_content,
                **doc.metadata
            } for doc in batch]
            vectors = [{
                'id': vector_id,
                'values': value,
                'metadata': metadata
            } for vector_id, value, metadata in zip(vector_ids, values, metadatas)]
            try:
                upsert_response = self.index.upsert(vectors=vectors)
                print(upsert_response)
            except Exception as e:
                print(e)
    def search(self, text_query, top_k=5, hybrid_search=False):
        """Return the text of the top_k chunks most similar to text_query."""
        print('text query:', text_query)
        filter = None
        if hybrid_search and self.source_index:
            # Restrict the Pinecone query to chunks coming from the 50 source
            # files that best match the query.
            source_docs = self.source_index.similarity_search(text_query, 50)
            print("source_docs", source_docs)
            filter = {"source": {"$in": [doc.page_content for doc in source_docs]}}

        # vector = self.embedding_client.feature_extraction(text_query)
        vector = self.embedding_client.embed_query(text_query)

        result = self.index.query(
            vector=vector,
            top_k=top_k,
            include_metadata=True,
            filter=filter
        )

        docs = []
        for res in result["matches"]:
            metadata = res["metadata"]
            if 'text' in metadata:
                text = metadata.pop('text')
                docs.append(text)
        return docs
if __name__ == '__main__':
    from langchain_community.document_loaders import GitLoader
    from langchain_text_splitters import (
        Language,
        RecursiveCharacterTextSplitter,
    )

    loader = GitLoader(
        clone_url="https://github.com/langchain-ai/langchain",
        repo_path="./code_data/langchain_repo/",
        branch="master",
    )

    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=10000, chunk_overlap=100
    )

    # Keep only Python and Markdown files, drop very large files, then split.
    docs = loader.load()
    docs = [doc for doc in docs if doc.metadata['file_type'] in ['.py', '.md']]
    docs = [doc for doc in docs if len(doc.page_content) < 50000]
    docs = python_splitter.split_documents(docs)

    # Prefix every chunk with its source path so the path gets embedded with it.
    for doc in docs:
        doc.page_content = '# {}\n\n'.format(doc.metadata['source']) + doc.page_content

    indexer = DataIndexer()

    with open('/app/sources.txt', 'a') as file:
        for doc in docs:
            file.write(doc.metadata['source'] + '\n')
    print('DONE')

    indexer.index_data(docs)
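
    # Illustrative query (sketch; the question text below is just a placeholder):
    # with hybrid_search=True, candidates are first narrowed to the most
    # relevant source files via the Chroma index built from sources.txt.
    results = indexer.search('How do I create a custom retriever?', top_k=5, hybrid_search=True)
    for chunk in results:
        print(chunk[:200])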