veerukhannan committed on
Commit 7540753 · verified · 1 Parent(s): 6e6a2b1

Update add_embeddings.py

Files changed (1)
  1. add_embeddings.py +74 -63
add_embeddings.py CHANGED
@@ -13,9 +13,11 @@ class LegalDocumentProcessor:
         print("Initializing Legal Document Processor...")
         self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
         self.model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+        self.max_chunk_size = 500  # Reduced chunk size
+        self.max_context_length = 4000  # Maximum context length for response

-        # Initialize ChromaDB with persistent storage for Hugging Face Spaces
-        self.pdf_dir = "/home/user/app"  # Default path in HF Spaces
+        # Initialize ChromaDB
+        self.pdf_dir = "/home/user/app"
         db_dir = os.path.join(self.pdf_dir, "chroma_db")
         os.makedirs(db_dir, exist_ok=True)

@@ -31,19 +33,45 @@ class LegalDocumentProcessor:
             name="indian_legal_docs",
             metadata={"description": "Indian Criminal Law Documents"}
         )
+
+    def _split_into_chunks(self, text: str) -> List[str]:
+        """Split text into smaller chunks while preserving context"""
+        # Split on meaningful boundaries
+        patterns = [
+            r'(?=Chapter \d+)',
+            r'(?=Section \d+)',
+            r'(?=\n\d+\.\s)',  # Numbered paragraphs
+            r'\n\n'
+        ]

-    def mean_pooling(self, model_output, attention_mask):
-        token_embeddings = model_output[0]
-        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+        # Combine patterns
+        split_pattern = '|'.join(patterns)
+        sections = re.split(split_pattern, text)

-    def get_embedding(self, text: str) -> List[float]:
-        inputs = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
-        with torch.no_grad():
-            model_output = self.model(**inputs)
-        sentence_embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
-        return sentence_embeddings[0].tolist()
+        chunks = []
+        current_chunk = ""

+        for section in sections:
+            section = section.strip()
+            if not section:
+                continue
+
+            # If section is small enough, add to current chunk
+            if len(current_chunk) + len(section) < self.max_chunk_size:
+                current_chunk += " " + section
+            else:
+                # If current chunk is not empty, add it to chunks
+                if current_chunk:
+                    chunks.append(current_chunk.strip())
+                # Start new chunk with current section
+                current_chunk = section
+
+        # Add the last chunk if not empty
+        if current_chunk:
+            chunks.append(current_chunk.strip())
+
+        return chunks
+
     def process_pdf(self, pdf_path: str) -> List[str]:
         """Extract text from PDF and split into chunks"""
         print(f"Processing PDF: {pdf_path}")
@@ -51,8 +79,8 @@ class LegalDocumentProcessor:
             reader = PdfReader(pdf_path)
             text = ""
             for page in reader.pages:
-                text += page.extract_text()
-
+                text += page.extract_text() + "\n\n"
+
             chunks = self._split_into_chunks(text)
             print(f"Created {len(chunks)} chunks from {pdf_path}")
             return chunks
@@ -60,45 +88,31 @@ class LegalDocumentProcessor:
             print(f"Error processing PDF {pdf_path}: {str(e)}")
             return []

-    def _split_into_chunks(self, text: str, max_chunk_size: int = 1000) -> List[str]:
-        """Split text into smaller chunks while preserving context"""
-        sections = re.split(r'(Chapter \d+|Section \d+|\n\n)', text)
-
-        chunks = []
-        current_chunk = ""
+    def get_embedding(self, text: str) -> List[float]:
+        """Generate embedding for text"""
+        inputs = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
+        with torch.no_grad():
+            model_output = self.model(**inputs)

-        for section in sections:
-            if len(current_chunk) + len(section) < max_chunk_size:
-                current_chunk += section
-            else:
-                if current_chunk:
-                    chunks.append(current_chunk.strip())
-                current_chunk = section
-
-        if current_chunk:
-            chunks.append(current_chunk.strip())
-
-        return chunks
+        # Mean pooling
+        token_embeddings = model_output[0]
+        attention_mask = inputs['attention_mask']
+        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        sum_embeddings = torch.sum(token_embeddings * mask, 1)
+        sum_mask = torch.clamp(mask.sum(1), min=1e-9)
+        return (sum_embeddings / sum_mask).squeeze().tolist()

     def process_and_store_documents(self):
         """Process all legal documents and store in ChromaDB"""
         print("Starting document processing...")
-        print(f"Looking for PDFs in: {self.pdf_dir}")
-        print(f"Directory contents: {os.listdir(self.pdf_dir)}")

-        # Define the expected PDF paths for Hugging Face Spaces
+        # Define the expected PDF paths
         pdf_files = {
             'BNS': os.path.join(self.pdf_dir, 'BNS.pdf'),
             'BNSS': os.path.join(self.pdf_dir, 'BNSS.pdf'),
             'BSA': os.path.join(self.pdf_dir, 'BSA.pdf')
         }

-        # Verify files exist
-        for law_code, pdf_path in pdf_files.items():
-            if not os.path.exists(pdf_path):
-                print(f"Warning: {pdf_path} not found")
-
-        # Process each PDF
         for law_code, pdf_path in pdf_files.items():
             if os.path.exists(pdf_path):
                 print(f"Processing {law_code} from {pdf_path}")
@@ -124,13 +138,7 @@ class LegalDocumentProcessor:
                         )
                     except Exception as e:
                         print(f"Error processing chunk {i} from {law_code}: {str(e)}")
-
-                print(f"Completed processing {law_code}")
-            else:
-                print(f"Skipping {law_code} - PDF not found")
-
-        print("Document processing completed")
-
+
     def search_documents(self, query: str, n_results: int = 3) -> Dict:
         """Search for relevant legal information"""
         try:
@@ -140,25 +148,28 @@ class LegalDocumentProcessor:
                 n_results=n_results
             )

+            # Limit context size
+            documents = results["documents"][0]
+            total_length = 0
+            filtered_documents = []
+            filtered_metadatas = []
+
+            for doc, metadata in zip(documents, results["metadatas"][0]):
+                doc_length = len(doc)
+                if total_length + doc_length <= self.max_context_length:
+                    filtered_documents.append(doc)
+                    filtered_metadatas.append(metadata)
+                    total_length += doc_length
+                else:
+                    break
+
             return {
-                "documents": results["documents"][0],
-                "metadatas": results["metadatas"][0]
+                "documents": filtered_documents,
+                "metadatas": filtered_metadatas
             }
         except Exception as e:
             print(f"Error during search: {str(e)}")
             return {
                 "documents": ["Sorry, I couldn't search the documents effectively."],
                 "metadatas": [{"law_code": "ERROR", "source": "error"}]
-            }
-
-if __name__ == "__main__":
-    processor = LegalDocumentProcessor()
-    processor.process_and_store_documents()
-
-    test_query = "What are the provisions for digital evidence?"
-    results = processor.search_documents(test_query)
-    print(f"Query: {test_query}")
-    print("\nResults:")
-    for doc, metadata in zip(results["documents"], results["metadatas"]):
-        print(f"\nFrom {metadata['source']}:")
-        print(doc[:200] + "...")
+            }
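
Note: the commit also drops the if __name__ == "__main__" smoke test that previously sat at the end of add_embeddings.py. A minimal sketch of an equivalent manual check, assuming the module is importable as add_embeddings and the BNS/BNSS/BSA PDFs are present under /home/user/app, would be:

# Hypothetical smoke test mirroring the removed __main__ block; not part of this commit.
from add_embeddings import LegalDocumentProcessor

processor = LegalDocumentProcessor()
processor.process_and_store_documents()

test_query = "What are the provisions for digital evidence?"
results = processor.search_documents(test_query)
print(f"Query: {test_query}")
for doc, metadata in zip(results["documents"], results["metadatas"]):
    print(f"\nFrom {metadata['source']}:")
    print(doc[:200] + "...")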