talexm commited on
Commit
c3c1187
·
1 Parent(s): e9a8c67
rag_sec/document_search_system.py CHANGED
@@ -1,42 +1,79 @@
1
- from bad_query_detector import BadQueryDetector
2
- from query_transformer import QueryTransformer
3
- from document_retriver import DocumentRetriever
4
- from senamtic_response_generator import SemanticResponseGenerator
 
 
 
 
5
 
6
  class DocumentSearchSystem:
7
  def __init__(self):
 
 
 
 
 
 
 
8
  self.detector = BadQueryDetector()
9
  self.transformer = QueryTransformer()
10
  self.retriever = DocumentRetriever()
11
  self.response_generator = SemanticResponseGenerator()
12
 
13
  def process_query(self, query):
 
 
 
 
 
 
 
 
 
 
14
  if self.detector.is_bad_query(query):
15
  return {"status": "rejected", "message": "Query blocked due to detected malicious intent."}
16
 
 
17
  transformed_query = self.transformer.transform_query(query)
18
- retrieved_docs = self.retriever.retrieve(transformed_query)
19
 
 
 
20
  if not retrieved_docs:
21
  return {"status": "no_results", "message": "No relevant documents found for your query."}
22
 
 
23
  response = self.response_generator.generate_response(retrieved_docs)
24
  return {"status": "success", "response": response}
25
 
26
 
27
  def test_system():
 
 
 
 
 
 
 
 
 
 
 
28
  system = DocumentSearchSystem()
29
- system.retriever.load_documents("/path/to/documents")
30
 
31
- # Normal query
32
  normal_query = "Tell me about great acting performances."
33
  print("\nNormal Query Result:")
34
  print(system.process_query(normal_query))
35
 
36
- # Malicious query
37
  malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
38
  print("\nMalicious Query Result:")
39
  print(system.process_query(malicious_query))
40
 
 
41
  if __name__ == "__main__":
42
  test_system()
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from .bad_query_detector import BadQueryDetector
5
+ from .query_transformer import QueryTransformer
6
+ from .document_retriver import DocumentRetriever
7
+ from .senamtic_response_generator import SemanticResponseGenerator
8
+
9
 
10
  class DocumentSearchSystem:
11
  def __init__(self):
12
+ """
13
+ Initializes the DocumentSearchSystem with:
14
+ - BadQueryDetector for identifying malicious or inappropriate queries.
15
+ - QueryTransformer for improving or rephrasing queries.
16
+ - DocumentRetriever for semantic document retrieval.
17
+ - SemanticResponseGenerator for generating context-aware responses.
18
+ """
19
  self.detector = BadQueryDetector()
20
  self.transformer = QueryTransformer()
21
  self.retriever = DocumentRetriever()
22
  self.response_generator = SemanticResponseGenerator()
23
 
24
  def process_query(self, query):
25
+ """
26
+ Processes a user query through the following steps:
27
+ 1. Detect if the query is malicious.
28
+ 2. Transform the query if needed.
29
+ 3. Retrieve relevant documents based on the query.
30
+ 4. Generate a response using the retrieved documents.
31
+
32
+ :param query: The user query as a string.
33
+ :return: A dictionary with the status and response or error message.
34
+ """
35
  if self.detector.is_bad_query(query):
36
  return {"status": "rejected", "message": "Query blocked due to detected malicious intent."}
37
 
38
+ # Transform the query
39
  transformed_query = self.transformer.transform_query(query)
40
+ print(f"Transformed Query: {transformed_query}")
41
 
42
+ # Retrieve relevant documents
43
+ retrieved_docs = self.retriever.retrieve(transformed_query)
44
  if not retrieved_docs:
45
  return {"status": "no_results", "message": "No relevant documents found for your query."}
46
 
47
+ # Generate a response based on the retrieved documents
48
  response = self.response_generator.generate_response(retrieved_docs)
49
  return {"status": "success", "response": response}
50
 
51
 
52
  def test_system():
53
+ """
54
+ Test the DocumentSearchSystem with normal and malicious queries.
55
+ - Load documents from a dataset directory.
56
+ - Perform a normal query and display results.
57
+ - Perform a malicious query to ensure proper blocking.
58
+ """
59
+ # Define the path to the dataset directory
60
+ home_dir = Path(os.getenv("HOME", "/"))
61
+ data_dir = home_dir / "data-sets/aclImdb/train"
62
+
63
+ # Initialize the system
64
  system = DocumentSearchSystem()
65
+ system.retriever.load_documents(data_dir)
66
 
67
+ # Perform a normal query
68
  normal_query = "Tell me about great acting performances."
69
  print("\nNormal Query Result:")
70
  print(system.process_query(normal_query))
71
 
72
+ # Perform a malicious query
73
  malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
74
  print("\nMalicious Query Result:")
75
  print(system.process_query(malicious_query))
76
 
77
+
78
  if __name__ == "__main__":
79
  test_system()
rag_sec/senamtic_response_generator.py CHANGED
@@ -1,10 +1,21 @@
1
- from transformers import pipeline
 
2
 
3
  class SemanticResponseGenerator:
4
- def __init__(self):
5
- self.generator = pipeline("text-generation", model="gpt2")
 
 
 
6
 
7
  def generate_response(self, retrieved_docs):
8
- combined_docs = " ".join(retrieved_docs[:2]) # Use top 2 matches
9
- response = self.generator(f"Based on the following information: {combined_docs}", max_length=100)
10
- return response[0]["generated_text"]
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
2
+
3
 
4
  class SemanticResponseGenerator:
5
+ def __init__(self, model_name="google/flan-t5-small", max_input_length=512, max_new_tokens=50):
6
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
8
+ self.max_input_length = max_input_length
9
+ self.max_new_tokens = max_new_tokens
10
 
11
  def generate_response(self, retrieved_docs):
12
+ combined_docs = " ".join(retrieved_docs[:2])
13
+ truncated_docs = combined_docs[:self.max_input_length - 50]
14
+ input_text = f"Based on the following information: {truncated_docs}"
15
+ inputs = self.tokenizer(input_text, return_tensors="pt", truncation=True, max_length=self.max_input_length)
16
+ outputs = self.model.generate(
17
+ **inputs,
18
+ max_new_tokens=self.max_new_tokens,
19
+ pad_token_id=self.tokenizer.eos_token_id
20
+ )
21
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)