Update add_embeddings.py
add_embeddings.py CHANGED (+74 -63)
@@ -13,9 +13,11 @@ class LegalDocumentProcessor:
         print("Initializing Legal Document Processor...")
         self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
         self.model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+        self.max_chunk_size = 500  # Reduced chunk size
+        self.max_context_length = 4000  # Maximum context length for response

-        # Initialize ChromaDB
-        self.pdf_dir = "/home/user/app"
+        # Initialize ChromaDB
+        self.pdf_dir = "/home/user/app"
         db_dir = os.path.join(self.pdf_dir, "chroma_db")
         os.makedirs(db_dir, exist_ok=True)

@@ -31,19 +33,45 @@ class LegalDocumentProcessor:
             name="indian_legal_docs",
             metadata={"description": "Indian Criminal Law Documents"}
         )

-    def ...
-        ...
-        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

-    ...
-        with torch.no_grad():
-            model_output = self.model(**inputs)
-        sentence_embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
-        return sentence_embeddings[0].tolist()
+
+    def _split_into_chunks(self, text: str) -> List[str]:
+        """Split text into smaller chunks while preserving context"""
+        # Split on meaningful boundaries
+        patterns = [
+            r'(?=Chapter \d+)',
+            r'(?=Section \d+)',
+            r'(?=\n\d+\.\s)',  # Numbered paragraphs
+            r'\n\n'
+        ]
+
+        # Combine patterns
+        split_pattern = '|'.join(patterns)
+        sections = re.split(split_pattern, text)
+
+        chunks = []
+        current_chunk = ""
+
+        for section in sections:
+            section = section.strip()
+            if not section:
+                continue
+
+            # If section is small enough, add to current chunk
+            if len(current_chunk) + len(section) < self.max_chunk_size:
+                current_chunk += " " + section
+            else:
+                # If current chunk is not empty, add it to chunks
+                if current_chunk:
+                    chunks.append(current_chunk.strip())
+                # Start new chunk with current section
+                current_chunk = section
+
+        # Add the last chunk if not empty
+        if current_chunk:
+            chunks.append(current_chunk.strip())
+
+        return chunks
+
     def process_pdf(self, pdf_path: str) -> List[str]:
         """Extract text from PDF and split into chunks"""
         print(f"Processing PDF: {pdf_path}")
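Aside (not part of the commit): the new _split_into_chunks relies on re.split with zero-width lookaheads, so "Chapter n" / "Section n" headings stay attached to the text that follows them instead of being consumed as separators. A minimal standalone sketch of that behaviour, reusing the patterns from the diff on a made-up sample string:

import re

# Patterns copied from the new _split_into_chunks; the sample text is invented.
patterns = [r'(?=Chapter \d+)', r'(?=Section \d+)', r'(?=\n\d+\.\s)', r'\n\n']
split_pattern = '|'.join(patterns)

sample = "Chapter 1 General provisions.\n\nSection 2 Definitions apply here. Section 3 Offences follow."
sections = [s.strip() for s in re.split(split_pattern, sample) if s.strip()]
print(sections)
# ['Chapter 1 General provisions.', 'Section 2 Definitions apply here.', 'Section 3 Offences follow.']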
@@ -51,8 +79,8 @@ class LegalDocumentProcessor:
             reader = PdfReader(pdf_path)
             text = ""
             for page in reader.pages:
-                text += page.extract_text()
-
+                text += page.extract_text() + "\n\n"
+
             chunks = self._split_into_chunks(text)
             print(f"Created {len(chunks)} chunks from {pdf_path}")
             return chunks
@@ -60,45 +88,31 @@ class LegalDocumentProcessor:
             print(f"Error processing PDF {pdf_path}: {str(e)}")
             return []

-    def ...
-        """...
-        ...
-        current_chunk = ""

-        ...
-        if current_chunk:
-            chunks.append(current_chunk.strip())
-
-        return chunks
+    def get_embedding(self, text: str) -> List[float]:
+        """Generate embedding for text"""
+        inputs = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
+        with torch.no_grad():
+            model_output = self.model(**inputs)
+
+        # Mean pooling
+        token_embeddings = model_output[0]
+        attention_mask = inputs['attention_mask']
+        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        sum_embeddings = torch.sum(token_embeddings * mask, 1)
+        sum_mask = torch.clamp(mask.sum(1), min=1e-9)
+        return (sum_embeddings / sum_mask).squeeze().tolist()

     def process_and_store_documents(self):
         """Process all legal documents and store in ChromaDB"""
         print("Starting document processing...")
-        print(f"Looking for PDFs in: {self.pdf_dir}")
-        print(f"Directory contents: {os.listdir(self.pdf_dir)}")

-        # Define the expected PDF paths
+        # Define the expected PDF paths
         pdf_files = {
             'BNS': os.path.join(self.pdf_dir, 'BNS.pdf'),
             'BNSS': os.path.join(self.pdf_dir, 'BNSS.pdf'),
             'BSA': os.path.join(self.pdf_dir, 'BSA.pdf')
         }

-        # Verify files exist
-        for law_code, pdf_path in pdf_files.items():
-            if not os.path.exists(pdf_path):
-                print(f"Warning: {pdf_path} not found")
-
-        # Process each PDF
         for law_code, pdf_path in pdf_files.items():
             if os.path.exists(pdf_path):
                 print(f"Processing {law_code} from {pdf_path}")
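Aside (not part of the commit): the mean pooling now inlined in get_embedding is the usual sentence-transformers recipe, an attention-mask-weighted average of the token embeddings so that padding tokens do not dilute the result. A small self-contained check of just that arithmetic on dummy tensors (shapes are illustrative only):

import torch

# Dummy "model output": batch of 1, 4 tokens, hidden size 8; the last token is padding.
token_embeddings = torch.randn(1, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 0]])

# Same arithmetic as the diff: mask-weighted sum divided by the count of real tokens.
mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * mask, 1)
sum_mask = torch.clamp(mask.sum(1), min=1e-9)
sentence_embedding = sum_embeddings / sum_mask  # shape: (1, 8)

# Equivalent to a plain mean over the unmasked tokens.
assert torch.allclose(sentence_embedding[0], token_embeddings[0, :3].mean(dim=0), atol=1e-6)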
@@ -124,13 +138,7 @@ class LegalDocumentProcessor:
                     )
                 except Exception as e:
                     print(f"Error processing chunk {i} from {law_code}: {str(e)}")
-
-                print(f"Completed processing {law_code}")
-            else:
-                print(f"Skipping {law_code} - PDF not found")
-
-        print("Document processing completed")
-
+
     def search_documents(self, query: str, n_results: int = 3) -> Dict:
         """Search for relevant legal information"""
         try:
@@ -140,25 +148,28 @@ class LegalDocumentProcessor:
                 n_results=n_results
             )

+            # Limit context size
+            documents = results["documents"][0]
+            total_length = 0
+            filtered_documents = []
+            filtered_metadatas = []
+
+            for doc, metadata in zip(documents, results["metadatas"][0]):
+                doc_length = len(doc)
+                if total_length + doc_length <= self.max_context_length:
+                    filtered_documents.append(doc)
+                    filtered_metadatas.append(metadata)
+                    total_length += doc_length
+                else:
+                    break
+
             return {
-                "documents": ...
-                "metadatas": ...
+                "documents": filtered_documents,
+                "metadatas": filtered_metadatas
             }
         except Exception as e:
             print(f"Error during search: {str(e)}")
             return {
                 "documents": ["Sorry, I couldn't search the documents effectively."],
                 "metadatas": [{"law_code": "ERROR", "source": "error"}]
-            }
-
-if __name__ == "__main__":
-    processor = LegalDocumentProcessor()
-    processor.process_and_store_documents()
-
-    test_query = "What are the provisions for digital evidence?"
-    results = processor.search_documents(test_query)
-    print(f"Query: {test_query}")
-    print("\nResults:")
-    for doc, metadata in zip(results["documents"], results["metadatas"]):
-        print(f"\nFrom {metadata['source']}:")
-        print(doc[:200] + "...")
+            }
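Aside (not part of the commit): the old if __name__ == "__main__" smoke test shown above is dropped by this change. A hypothetical replacement that exercises the class against the new, flat return shape of search_documents ("documents" and "metadatas" are plain lists after the context-length filtering added here):

# Hypothetical manual test, mirroring the removed __main__ block.
if __name__ == "__main__":
    processor = LegalDocumentProcessor()
    processor.process_and_store_documents()

    test_query = "What are the provisions for digital evidence?"
    results = processor.search_documents(test_query)
    print(f"Query: {test_query}\n")
    for doc, metadata in zip(results["documents"], results["metadatas"]):
        print(f"From {metadata.get('source', 'unknown')}:")
        print(doc[:200] + "...\n")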