import streamlit as st
st.set_page_config(page_title="RAG Book Analyzer", layout="wide") # Must be the first Streamlit command
import torch
import numpy as np
import faiss
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import fitz # PyMuPDF for PDF extraction
import docx2txt # For DOCX extraction
from langchain_text_splitters import RecursiveCharacterTextSplitter

# ------------------------
# Configuration
# ------------------------
MODEL_NAME = "ibm-granite/granite-3.1-1b-a400m-instruct"
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
CHUNK_SIZE = 512
CHUNK_OVERLAP = 64
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
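# Note: CHUNK_SIZE and CHUNK_OVERLAP are measured in characters, not model tokens,
# because the splitter below uses len() as its length function.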

# ------------------------
# Model Loading with Caching
# ------------------------
@st.cache_resource
def load_models():
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            revision="main"
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            revision="main",
            device_map="auto" if DEVICE == "cuda" else None,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            low_cpu_mem_usage=True
        ).eval()
        embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
        return tokenizer, model, embedder
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        st.stop()


tokenizer, model, embedder = load_models()
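# st.cache_resource keeps the loaded models in memory across Streamlit reruns,
# so the weights are downloaded and initialized only once per server process.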

# ------------------------
# Text Processing Functions
# ------------------------
def split_text(text):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len
    )
    return splitter.split_text(text)
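# RecursiveCharacterTextSplitter tries to break on paragraph and line boundaries
# first, then whitespace, then individual characters, so chunks tend to stay
# semantically coherent.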


def extract_text(file):
    file_type = file.type
    if file_type == "application/pdf":
        try:
            doc = fitz.open(stream=file.read(), filetype="pdf")
            return "\n".join([page.get_text() for page in doc])
        except Exception as e:
            st.error("Error processing PDF: " + str(e))
            return ""
    elif file_type == "text/plain":
        return file.read().decode("utf-8")
    elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        try:
            return docx2txt.process(file)
        except Exception as e:
            st.error("Error processing DOCX: " + str(e))
            return ""
    else:
        st.error("Unsupported file type: " + file_type)
        return ""


def build_index(chunks):
    embeddings = embedder.encode(chunks, show_progress_bar=True)
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    faiss.normalize_L2(embeddings)
    index.add(embeddings)
    return index
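# With L2-normalized vectors, the inner-product index (IndexFlatIP) ranks results
# by cosine similarity; the query vector is normalized the same way before
# searching in the UI code below.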

# ------------------------
# Summarization and Q&A Functions
# ------------------------
def generate_summary(text):
    # Limit input text to avoid overly long sequences
    prompt = f"<|user|>\nSummarize the following book in a concise and informative paragraph:\n\n{text[:4000]}\n<|assistant|>\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # temperature only takes effect when sampling is enabled
    outputs = model.generate(**inputs, max_new_tokens=300, temperature=0.5, do_sample=True)
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return summary.split("<|assistant|>")[-1].strip() if "<|assistant|>" in summary else summary.strip()
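# Note: only the first 4,000 characters of the book are passed to the model, so
# the summary reflects the opening of the text rather than the whole book.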


def generate_answer(query, context):
    prompt = f"<|user|>\nUsing the context below, answer the following question precisely. If unsure, say 'I don't know'.\n\nContext: {context}\n\nQuestion: {query}\n<|assistant|>\n"
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_new_tokens=300,
        temperature=0.4,
        top_p=0.9,
        repetition_penalty=1.2,
        do_sample=True
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer.split("<|assistant|>")[-1].strip() if "<|assistant|>" in answer else answer.strip()
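# The <|user|>/<|assistant|> markers above are written by hand; if the model ships
# a chat template, tokenizer.apply_chat_template() is the more robust way to build
# these prompts.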

# ------------------------
# Streamlit UI
# ------------------------
st.title("RAG-Based Book Analyzer")
st.write("Upload a book (PDF, TXT, DOCX) to get a summary and ask questions about its content.")

uploaded_file = st.file_uploader("Upload File", type=["pdf", "txt", "docx"])

if uploaded_file:
    # Only reprocess when a new file is uploaded; otherwise the summary and index
    # would be recomputed on every Streamlit rerun (e.g. each time a question is typed).
    if st.session_state.get("file_name") != uploaded_file.name:
        text = extract_text(uploaded_file)
        if text:
            st.session_state.file_name = uploaded_file.name
            with st.spinner("Generating summary..."):
                st.session_state.summary = generate_summary(text)
            # Split the text into chunks and build the FAISS index
            chunks = split_text(text)
            st.session_state.chunks = chunks
            st.session_state.index = build_index(chunks)

    if st.session_state.get("file_name") == uploaded_file.name:
        st.success("File successfully processed!")
        st.markdown("### Book Summary")
        st.write(st.session_state.summary)

        st.markdown("### Ask a Question about the Book:")
        query = st.text_input("Your Question:")
        if query:
            # Retrieve the top 3 most relevant chunks as context
            query_embedding = embedder.encode([query])
            faiss.normalize_L2(query_embedding)
            distances, indices = st.session_state.index.search(query_embedding, k=3)
            retrieved_chunks = [
                st.session_state.chunks[i]
                for i in indices[0]
                if 0 <= i < len(st.session_state.chunks)
            ]
            context = "\n".join(retrieved_chunks)
            answer = generate_answer(query, context)
            st.markdown("### Answer")
            st.write(answer)
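
# To launch locally (assuming this file is saved as app.py):
#   streamlit run app.py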