talexm commited on
Commit
f861dee
·
1 Parent(s): 73321dd
anomaly_detection_tool/__init__.py DELETED
File without changes
rag_sec/__pycache__/rag_chagu_demo.cpython-38-pytest-8.3.2.pyc CHANGED
Binary files a/rag_sec/__pycache__/rag_chagu_demo.cpython-38-pytest-8.3.2.pyc and b/rag_sec/__pycache__/rag_chagu_demo.cpython-38-pytest-8.3.2.pyc differ
 
rag_sec/bad_query_detector.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+
3
+ class BadQueryDetector:
4
+ def __init__(self):
5
+ self.detector = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
6
+
7
+ def is_bad_query(self, query):
8
+ result = self.detector(query)[0]
9
+ label = result["label"]
10
+ score = result["score"]
11
+ if label == "NEGATIVE" and score > 0.8:
12
+ print(f"Detected malicious query with high confidence ({score:.4f}): {query}")
13
+ return True
14
+ return False
rag_sec/document_retriver.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ from sklearn.feature_extraction.text import TfidfVectorizer
3
+ import numpy as np
4
+
5
+ class DocumentRetriever:
6
+ def __init__(self):
7
+ self.documents = []
8
+ self.vectorizer = TfidfVectorizer()
9
+ self.index = None
10
+
11
+ def load_documents(self, source_dir):
12
+ from pathlib import Path
13
+
14
+ data_dir = Path(source_dir)
15
+ if not data_dir.exists():
16
+ print(f"Source directory not found: {source_dir}")
17
+ return
18
+
19
+ for file in data_dir.glob("*.txt"):
20
+ with open(file, "r", encoding="utf-8") as f:
21
+ self.documents.append(f.read())
22
+
23
+ print(f"Loaded {len(self.documents)} documents.")
24
+
25
+ # Create the FAISS index
26
+ self._build_index()
27
+
28
+ def _build_index(self):
29
+ # Generate TF-IDF vectors for documents
30
+ doc_vectors = self.vectorizer.fit_transform(self.documents).toarray()
31
+
32
+ # Create FAISS index
33
+ self.index = faiss.IndexFlatL2(doc_vectors.shape[1])
34
+ self.index.add(doc_vectors.astype(np.float32))
35
+
36
+ def retrieve(self, query, top_k=5):
37
+ if not self.index:
38
+ return ["Document retrieval is not initialized."]
39
+
40
+ # Vectorize the query
41
+ query_vector = self.vectorizer.transform([query]).toarray().astype(np.float32)
42
+
43
+ # Perform FAISS search
44
+ distances, indices = self.index.search(query_vector, top_k)
45
+
46
+ # Return matching documents
47
+ return [self.documents[i] for i in indices[0] if i < len(self.documents)]
rag_sec/document_search_system.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bad_query_detector import BadQueryDetector
2
+ from query_transformer import QueryTransformer
3
+ from document_retriver import DocumentRetriever
4
+ from senamtic_response_generator import SemanticResponseGenerator
5
+
6
+ class DocumentSearchSystem:
7
+ def __init__(self):
8
+ self.detector = BadQueryDetector()
9
+ self.transformer = QueryTransformer()
10
+ self.retriever = DocumentRetriever()
11
+ self.response_generator = SemanticResponseGenerator()
12
+
13
+ def process_query(self, query):
14
+ if self.detector.is_bad_query(query):
15
+ return {"status": "rejected", "message": "Query blocked due to detected malicious intent."}
16
+
17
+ transformed_query = self.transformer.transform_query(query)
18
+ retrieved_docs = self.retriever.retrieve(transformed_query)
19
+
20
+ if not retrieved_docs:
21
+ return {"status": "no_results", "message": "No relevant documents found for your query."}
22
+
23
+ response = self.response_generator.generate_response(retrieved_docs)
24
+ return {"status": "success", "response": response}
25
+
26
+
27
+ def test_system():
28
+ system = DocumentSearchSystem()
29
+ system.retriever.load_documents("/path/to/documents")
30
+
31
+ # Normal query
32
+ normal_query = "Tell me about great acting performances."
33
+ print("\nNormal Query Result:")
34
+ print(system.process_query(normal_query))
35
+
36
+ # Malicious query
37
+ malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
38
+ print("\nMalicious Query Result:")
39
+ print(system.process_query(malicious_query))
40
+
41
+ if __name__ == "__main__":
42
+ test_system()
rag_sec/query_transformer.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ class QueryTransformer:
2
+ def transform_query(self, query):
3
+ if "DROP TABLE" in query or "SELECT *" in query:
4
+ return "Your query appears to contain SQL injection elements. Please rephrase."
5
+ return query
rag_sec/senamtic_response_generator.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+
3
+ class SemanticResponseGenerator:
4
+ def __init__(self):
5
+ self.generator = pipeline("text-generation", model="gpt2")
6
+
7
+ def generate_response(self, retrieved_docs):
8
+ combined_docs = " ".join(retrieved_docs[:2]) # Use top 2 matches
9
+ response = self.generator(f"Based on the following information: {combined_docs}", max_length=100)
10
+ return response[0]["generated_text"]