# NOTE: the lines below were web-scrape residue from the hosting page
# ("Spaces: Sleeping Sleeping") — kept as a comment so the file parses.
# Spaces: Sleeping Sleeping
from transformers import pipeline
from difflib import get_close_matches
from pathlib import Path
import os
class BadQueryDetector:
    """Screens incoming queries with a sentiment classifier.

    A query is treated as "bad" (potentially malicious) when the DistilBERT
    SST-2 model labels it NEGATIVE with confidence above 0.8.
    """

    def __init__(self):
        # Hugging Face sentiment pipeline; downloads the model on first use.
        self.detector = pipeline(
            "sentiment-analysis",
            model="distilbert-base-uncased-finetuned-sst-2-english",
        )

    def is_bad_query(self, query):
        """Return True when *query* is classified NEGATIVE with score > 0.8."""
        verdict = self.detector(query)[0]
        is_negative = verdict["label"] == "NEGATIVE"
        is_confident = verdict["score"] > 0.8
        if is_negative and is_confident:
            print(f"Detected malicious query with high confidence ({verdict['score']:.4f}): {query}")
            return True
        return False
class QueryTransformer:
    """Rewrites or rejects queries before they reach the retriever."""

    def transform_query(self, query):
        """Return *query* unchanged, or a warning string for SQL-injection patterns.

        Fix: the original check was case-sensitive, so variants like
        "drop table" or "select *" slipped through; we now compare against
        an upper-cased copy of the query.
        """
        # Simple transformation example: rephrasing and clarifying.
        # In practice, this could involve more sophisticated models like T5.
        normalized = query.upper()
        if "DROP TABLE" in normalized or "SELECT *" in normalized:
            return "Your query appears to contain SQL injection elements. Please rephrase."
        # Add more sophisticated handling here.
        return query
class DocumentRetriever:
    """Loads plain-text documents and serves fuzzy-matched lookups."""

    def __init__(self):
        # Full text of every loaded document, in glob order.
        self.documents = []

    def load_documents(self, source_dir):
        """Read every ``*.txt`` file directly under *source_dir* into memory."""
        data_dir = Path(source_dir)
        if not data_dir.exists():
            print(f"Source directory not found: {source_dir}")
            return
        for txt_file in data_dir.glob("*.txt"):
            self.documents.append(txt_file.read_text(encoding="utf-8"))
        print(f"Loaded {len(self.documents)} documents.")

    def retrieve(self, query):
        """Return up to 5 documents fuzzily similar to *query*, or a fallback message."""
        hits = get_close_matches(query, self.documents, n=5, cutoff=0.3)
        return hits or ["No matching documents found."]
class SemanticResponseGenerator:
    """Produces a free-text answer conditioned on retrieved documents."""

    def __init__(self):
        # GPT-2 text-generation pipeline; downloads weights on first use.
        self.generator = pipeline("text-generation", model="gpt2")

    def generate_response(self, retrieved_docs):
        """Generate up to 100 tokens of text seeded with the top-2 documents."""
        context = " ".join(retrieved_docs[:2])  # only the two best matches
        outputs = self.generator(
            f"Based on the following information: {context}",
            max_length=100,
        )
        return outputs[0]["generated_text"]
class DocumentSearchSystem:
    """End-to-end pipeline: screen the query, transform it, retrieve, answer."""

    def __init__(self):
        self.detector = BadQueryDetector()
        self.transformer = QueryTransformer()
        self.retriever = DocumentRetriever()
        self.response_generator = SemanticResponseGenerator()

    def process_query(self, query):
        """Run *query* through the pipeline and return a status dict.

        Possible statuses: "rejected" (flagged as malicious), "no_results"
        (retriever found nothing), or "success" (with a generated response).
        """
        # Hard-stop if the sentiment screen flags the query.
        if self.detector.is_bad_query(query):
            return {
                "status": "rejected",
                "message": "Query blocked due to detected malicious intent.",
            }
        cleaned_query = self.transformer.transform_query(query)
        docs = self.retriever.retrieve(cleaned_query)
        # The retriever signals an empty result set with a sentinel message.
        if "No matching documents found." in docs:
            return {
                "status": "no_results",
                "message": "No relevant documents found for your query.",
            }
        return {
            "status": "success",
            "response": self.response_generator.generate_response(docs),
        }
# Smoke-test the enhanced system end to end.
def test_system():
    """Exercise the pipeline with one benign and one injection-style query."""
    system = DocumentSearchSystem()
    system.retriever.load_documents("/path/to/documents")

    # Benign query: should flow through retrieval and generation.
    benign_query = "Tell me about great acting performances."
    benign_result = system.process_query(benign_query)
    print("\nNormal Query Result:")
    print(benign_result)

    # Injection-style query: should be rejected or neutralized.
    injection_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
    injection_result = system.process_query(injection_query)
    print("\nMalicious Query Result:")
    print(injection_result)


if __name__ == "__main__":
    test_system()