import fitz
import re
import chromadb
from chromadb.utils import embedding_functions
import uuid
import torch
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
from sentence_transformers import CrossEncoder
# Embedding model, cross-encoder re-ranker, and persistent Chroma collection
emb_model_name = "sentence-transformers/all-mpnet-base-v2"
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=emb_model_name)
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
client = chromadb.PersistentClient(path='.vectorstore')
collection = client.get_or_create_collection(
    name='huerto',
    embedding_function=sentence_transformer_ef,
    metadata={"hnsw:space": "cosine"},
)
def parse_pdf(file):
    '''Convert a PDF file into a list with one cleaned text string per page.'''
    pdf = fitz.open(file)
    output = []
    for page_num in range(pdf.page_count):
        page = pdf[page_num]
        text = page.get_text()
        # Merge hyphenated words split across line breaks
        text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)
        # Fix newlines in the middle of sentences
        text = re.sub(r"(?<!\n\s)\n(?!\s\n)", " ", text.strip())
        # Collapse multiple blank lines into a single paragraph break
        text = re.sub(r"\n\s*\n", "\n\n", text)
        output.append(text)
    return output
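# Illustrative usage sketch (the file name "example.pdf" is hypothetical):
# parse_pdf returns one cleaned string per page, so the result can be indexed by page.
#   pages = parse_pdf("example.pdf")
#   print(len(pages), pages[0][:200])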
def file_to_splits(file, tokens_per_chunk, chunk_overlap):
    '''Split a PDF into a list of [chunk_text, metadata, id] entries.'''
    text_splitter = SentenceTransformersTokenTextSplitter(
        model_name=emb_model_name,
        tokens_per_chunk=tokens_per_chunk,
        chunk_overlap=chunk_overlap,
    )
    text = parse_pdf(file)
    doc_chunks = []
    for i, page_text in enumerate(text):
        chunks = text_splitter.split_text(page_text)
        for j, chunk in enumerate(chunks):
            # Keep the source file name plus 1-based page and chunk numbers as metadata
            doc = [chunk, {"source": file.split('/')[-1], "page": i + 1, "chunk": j + 1}, str(uuid.uuid4())]
            doc_chunks.append(doc)
    return doc_chunks
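# Illustrative sketch of the output shape (file name, text, and id below are hypothetical):
#   [["Tomatoes are sown in early spring ...",
#     {"source": "manual.pdf", "page": 3, "chunk": 2},
#     "6f1c0b2e-..."], ...]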
def file_to_vs(file, tokens_per_chunk, chunk_overlap):
    '''Chunk a PDF and add the chunks to the Chroma collection.'''
    try:
        splits = file_to_splits(file, tokens_per_chunk, chunk_overlap)
        # Transpose [[doc, metadata, id], ...] into (documents, metadatas, ids)
        splits = list(zip(*splits))
        collection.add(documents=list(splits[0]), metadatas=list(splits[1]), ids=list(splits[2]))
        return 'Files uploaded successfully'
    except Exception as e:
        return str(e)
def similarity_search(query, k):
    '''Retrieve candidates from Chroma, re-rank them with the cross-encoder, and return the top k.'''
    sources = {}
    # Pull a larger candidate pool than k so the re-ranker has something to work with
    ss_out = collection.query(query_texts=[query], n_results=20)
    for i in range(len(ss_out['ids'][0])):
        # Cross-encoder relevance score for the (query, document) pair, squashed to [0, 1]
        score = float(cross_encoder.predict([query, ss_out['documents'][0][i]], activation_fct=torch.nn.Sigmoid()))
        sources[str(i)] = {
            "page_content": ss_out['documents'][0][i],
            "metadata": ss_out['metadatas'][0][i],
            "similarity": round(score * 100, 2),
        }
    sorted_sources = sorted(sources.items(), key=lambda x: x[1]['similarity'], reverse=True)
    sources = {}
    # Keep only the k best re-ranked results (or fewer if the collection is small)
    for i in range(min(k, len(sorted_sources))):
        sources[str(i)] = sorted_sources[i][1]
    return sources
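# Illustrative end-to-end sketch; "manual_huerto.pdf" and the query are hypothetical
# and only show how the functions above fit together.
if __name__ == "__main__":
    print(file_to_vs("manual_huerto.pdf", tokens_per_chunk=256, chunk_overlap=32))
    hits = similarity_search("when should tomatoes be sown?", k=3)
    for rank, hit in hits.items():
        print(rank, hit["similarity"], hit["metadata"], hit["page_content"][:80])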