Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import spaces | |
import subprocess | |
import os | |
import shutil | |
import string | |
import random | |
from pypdf import PdfReader | |
import ocrmypdf | |
from sentence_transformers import SentenceTransformer | |
model = SentenceTransformer("Snowflake/snowflake-arctic-embed-m") | |
model.to(device="cuda") | |
def embed(queries, chunks) -> dict[str, list[tuple[str, float]]]: | |
query_embeddings = model.encode(queries, prompt_name="query") | |
document_embeddings = model.encode(chunks) | |
scores = query_embeddings @ document_embeddings.T | |
results = {} | |
for query, query_scores in zip(queries, scores): | |
chunk_idxs = [i for i in range(len(chunks))] | |
# Get a structure like {query: [(chunk_idx, score), (chunk_idx, score), ...]} | |
results[query] = list(zip(chunk_idxs, query_scores)) | |
return results | |
def random_word(length): | |
letters = string.ascii_lowercase | |
return "".join(random.choice(letters) for _ in range(length)) | |
def convert_pdf(input_file) -> str: | |
reader = PdfReader(input_file) | |
text = extract_text_from_pdf(reader) | |
# Check if there are any images | |
image_count = 0 | |
for page in reader.pages: | |
image_count += len(page.images) | |
# If there are images and not much content, perform OCR on the document | |
if image_count > 0 and len(text) < 1000: | |
out_pdf_file = input_file.replace(".pdf", "_ocr.pdf") | |
ocrmypdf.ocr(input_file, out_pdf_file, force_ocr=True) | |
# Re-extract text | |
text = extract_text_from_pdf(PdfReader(input_file)) | |
# Delete the OCR file | |
os.remove(out_pdf_file) | |
return text | |
def extract_text_from_pdf(reader): | |
full_text = "" | |
for idx, page in enumerate(reader.pages): | |
text = page.extract_text() | |
if len(text) > 0: | |
full_text += f"---- Page {idx} ----\n" + page.extract_text() + "\n\n" | |
return full_text.strip() | |
def convert_pandoc(input_file, filename) -> str: | |
# Temporarily copy the file | |
shutil.copyfile(input_file, filename) | |
# Convert the file to markdown with pandoc | |
output_file = f"{random_word(16)}.md" | |
result = subprocess.call(["pandoc", filename, "-t", "markdown", "-o", output_file]) | |
if result != 0: | |
raise ValueError("Error converting file to markdown with pandoc") | |
# Read the file and delete temporary files | |
with open(output_file, "r") as f: | |
markdown = f.read() | |
os.remove(output_file) | |
os.remove(filename) | |
return markdown | |
def convert(input_file, filename) -> str: | |
plain_text_filetypes = [ | |
".txt", | |
".csv", | |
".tsv", | |
".md", | |
".yaml", | |
".toml", | |
".json", | |
".json5", | |
".jsonc", | |
] | |
# Already a plain text file that wouldn't benefit from pandoc so return the content | |
if any(filename.endswith(ft) for ft in plain_text_filetypes): | |
with open(input_file, "r") as f: | |
return f.read() | |
if filename.endswith(".pdf"): | |
return convert_pdf(input_file) | |
return convert_pandoc(input_file, filename) | |
def chunk_to_length(text, max_length=512): | |
chunks = [] | |
while len(text) > max_length: | |
chunks.append(text[:max_length]) | |
text = text[max_length:] | |
chunks.append(text) | |
return chunks | |
def predict(queries, documents, document_filenames, max_characters) -> list[list[str]]: | |
queries = queries.split("\n") | |
document_filenames = document_filenames.split("\n") | |
# Convert the documents to text | |
converted_docs = [ | |
convert(doc, filename) for doc, filename in zip(documents, document_filenames) | |
] | |
# Return if the total length is less than the max characters | |
total_doc_lengths = sum([len(doc) for doc in converted_docs]) | |
if total_doc_lengths < max_characters: | |
return [[doc] for doc, _ in converted_docs] | |
# Embed the documents in 512 character chunks | |
chunked_docs = [chunk_to_length(doc, 512) for doc in converted_docs] | |
embedded_docs = [embed(queries, chunks) for chunks in chunked_docs] | |
# Get a structure like {query: [(doc_idx, chunk_idx, score), (doc_idx, chunk_idx, score), ...]} | |
query_embeddings = {} | |
for doc_idx, embedded_doc in enumerate(embedded_docs): | |
for query, doc_scores in embedded_doc.items(): | |
doc_scores_with_doc = [ | |
(doc_idx, chunk_idx, score) for (chunk_idx, score) in doc_scores | |
] | |
if query not in query_embeddings: | |
query_embeddings[query] = [] | |
query_embeddings[query] = query_embeddings[query] + doc_scores_with_doc | |
# Sort the embeddings by score | |
for query, doc_scores in query_embeddings.items(): | |
query_embeddings[query] = sorted(doc_scores, key=lambda x: x[2], reverse=True) | |
# Choose the top embedding from each query until we reach the max characters | |
# Getting a structure like [[chunk, ...]] | |
document_embeddings = [[] for _ in range(len(documents))] | |
total_chars = 0 | |
while ( | |
total_chars < max_characters | |
and sum([len(x) for x in query_embeddings.values()]) > 0 | |
): | |
for query, doc_scores in query_embeddings.items(): | |
if len(doc_scores) == 0: | |
continue | |
# Grab the top score for the query | |
doc_idx, chunk_idx, _ = doc_scores.pop(0) | |
# Ensure we have space | |
chunk = chunked_docs[doc_idx][chunk_idx] | |
if total_chars + len(chunk) > max_characters: | |
continue | |
# Ensure we haven't already added this chunk from this document | |
if chunk_idx in document_embeddings[doc_idx]: | |
continue | |
# Add the chunk | |
document_embeddings[doc_idx].append(chunk_idx) | |
total_chars += len(chunk) | |
# Get the actual text for the chunks | |
document_embeddings = [ | |
[chunked_docs[doc_idx][chunk_idx] for chunk_idx in chunks] | |
for doc_idx, chunks in enumerate(document_embeddings) | |
] | |
return document_embeddings | |
# We accept a filename because the gradio JS interface removes this information | |
# and it's critical for choosing the correct processing pipeline | |
gr.Interface( | |
predict, | |
inputs=[ | |
gr.Textbox(label="Queries separated by newline"), | |
gr.File(label="Upload File", file_count="multiple"), | |
gr.Textbox(label="Filenames separated by newline"), | |
gr.Number(label="Max output characters", value=16384), | |
], | |
outputs=[gr.JSON(label="Embedded documents")], | |
).launch() | |