# chagu-dev / rag_sec / rag_chagu_demo.py
# Author: talexm
# Commit: update RAG query improvements (73321dd)
from transformers import pipeline
from difflib import get_close_matches
from pathlib import Path
import os
class BadQueryDetector:
    """Screens incoming queries with a sentiment-analysis classifier.

    NOTE(review): sentiment is used as a proxy for maliciousness here —
    strongly negative text is treated as a bad query. Confirm this
    heuristic is intentional; it is not an injection detector.
    """

    def __init__(self):
        # DistilBERT fine-tuned on SST-2; emits POSITIVE/NEGATIVE labels.
        self.detector = pipeline(
            "sentiment-analysis",
            model="distilbert-base-uncased-finetuned-sst-2-english",
        )

    def is_bad_query(self, query):
        """Return True when the classifier is confident the query is negative.

        A NEGATIVE label with score above 0.8 is reported to stdout and
        treated as malicious; everything else passes.
        """
        verdict = self.detector(query)[0]
        confident_negative = (
            verdict["label"] == "NEGATIVE" and verdict["score"] > 0.8
        )
        if confident_negative:
            score = verdict["score"]
            print(f"Detected malicious query with high confidence ({score:.4f}): {query}")
            return True
        return False
class QueryTransformer:
    """Rewrites or neutralizes raw user queries before retrieval."""

    # Substrings that indicate SQL-injection attempts. Matched
    # case-insensitively: the previous exact-case check let
    # "drop table" / "select *" variants through.
    _SQL_INJECTION_MARKERS = ("DROP TABLE", "SELECT *")

    def transform_query(self, query):
        """Return a safe form of *query*.

        Queries containing a SQL-injection marker (in any letter case)
        are replaced with a rephrase request; all other queries pass
        through unchanged. In practice this hook could host a
        rephrasing model such as T5.
        """
        normalized = query.upper()
        if any(marker in normalized for marker in self._SQL_INJECTION_MARKERS):
            return "Your query appears to contain SQL injection elements. Please rephrase."
        return query
class DocumentRetriever:
    """Loads plain-text documents and serves fuzzy matches for a query."""

    def __init__(self):
        # Full text of every loaded document, in load order.
        self.documents = []

    def load_documents(self, source_dir):
        """Read every top-level *.txt file under *source_dir* into memory.

        Missing directories are reported to stdout and skipped rather
        than raised.
        """
        root = Path(source_dir)
        if not root.exists():
            print(f"Source directory not found: {source_dir}")
            return
        for txt_path in root.glob("*.txt"):
            self.documents.append(txt_path.read_text(encoding="utf-8"))
        print(f"Loaded {len(self.documents)} documents.")

    def retrieve(self, query):
        """Return up to five fuzzy matches for *query*.

        Falls back to a one-element sentinel list when nothing clears
        the 0.3 similarity cutoff (callers test for that sentinel).
        """
        hits = get_close_matches(query, self.documents, n=5, cutoff=0.3)
        if not hits:
            return ["No matching documents found."]
        return hits
class SemanticResponseGenerator:
    """Drafts a free-text answer conditioned on retrieved documents."""

    def __init__(self):
        # GPT-2 text-generation pipeline used to compose the answer.
        self.generator = pipeline("text-generation", model="gpt2")

    def generate_response(self, retrieved_docs):
        """Generate text grounded in the top two retrieved documents.

        NOTE(review): max_length counts prompt tokens as well, so long
        documents leave little or no room for generated text — confirm
        this limit is intended.
        """
        context = " ".join(retrieved_docs[:2])
        prompt = f"Based on the following information: {context}"
        outputs = self.generator(prompt, max_length=100)
        return outputs[0]["generated_text"]
class DocumentSearchSystem:
    """Wires detection, transformation, retrieval, and generation together."""

    def __init__(self):
        self.detector = BadQueryDetector()
        self.transformer = QueryTransformer()
        self.retriever = DocumentRetriever()
        self.response_generator = SemanticResponseGenerator()

    def process_query(self, query):
        """Run the full pipeline for one query and return a status dict.

        Possible statuses: "rejected" (flagged malicious), "no_results"
        (retriever returned its sentinel), or "success" with the
        generated response.
        """
        if self.detector.is_bad_query(query):
            return {
                "status": "rejected",
                "message": "Query blocked due to detected malicious intent.",
            }
        safe_query = self.transformer.transform_query(query)
        docs = self.retriever.retrieve(safe_query)
        if "No matching documents found." in docs:
            return {
                "status": "no_results",
                "message": "No relevant documents found for your query.",
            }
        answer = self.response_generator.generate_response(docs)
        return {"status": "success", "response": answer}
# Smoke-test the end-to-end system with one benign and one malicious query.
def test_system():
    """Exercise DocumentSearchSystem with a normal and a malicious query."""
    system = DocumentSearchSystem()
    # NOTE(review): placeholder path — point at a real directory of *.txt
    # files before running.
    system.retriever.load_documents("/path/to/documents")

    # Benign query: expected to retrieve and generate normally.
    normal_result = system.process_query(
        "Tell me about great acting performances."
    )
    print("\nNormal Query Result:")
    print(normal_result)

    # SQL-injection-style query: expected to be rejected or rewritten.
    malicious_result = system.process_query(
        "DROP TABLE users; SELECT * FROM sensitive_data;"
    )
    print("\nMalicious Query Result:")
    print(malicious_result)


if __name__ == "__main__":
    test_system()