|
import os |
|
import logging |
|
from typing import List |
|
from pydantic import BaseModel |
|
from fastapi import FastAPI, HTTPException |
|
import rdflib |
|
from rdflib import RDF, RDFS, OWL, URIRef |
|
from sentence_transformers import SentenceTransformer |
|
import faiss |
|
import json |
|
import numpy as np |
|
from dotenv import load_dotenv |
|
from huggingface_hub import InferenceClient |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format="%(asctime)s - %(levelname)s - %(message)s", |
|
handlers=[logging.FileHandler("app.log"), logging.StreamHandler()] |
|
) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
HF_API_KEY = os.getenv("HF_API_KEY") |
|
if not HF_API_KEY: |
|
logger.error("HF_API_KEY non impostata.") |
|
raise EnvironmentError("HF_API_KEY non impostata.") |
|
|
|
|
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) |
|
RDF_FILE = os.path.join(BASE_DIR, "Ontologia.rdf") |
|
HF_MODEL = "Qwen/Qwen2.5-72B-Instruct" |
|
|
|
|
|
DOCUMENTS_FILE = os.path.join(BASE_DIR, "data", "documents.json") |
|
FAISS_INDEX_FILE = os.path.join(BASE_DIR, "data", "faiss.index") |
|
|
|
|
|
try: |
|
embedding_model = SentenceTransformer('all-MiniLM-L6-v2') |
|
logger.info("Modello SentenceTransformer caricato con successo.") |
|
except Exception as e: |
|
logger.error(f"Errore nel caricamento del modello SentenceTransformer: {e}") |
|
raise e |
|
|
|
|
|
try: |
|
client = InferenceClient(model=HF_MODEL, token=HF_API_KEY) |
|
logger.info("InferenceClient inizializzato correttamente.") |
|
except Exception as e: |
|
logger.error(f"Errore nell'inizializzazione di InferenceClient: {e}") |
|
raise e |
|
|
|
def create_data_directory(): |
|
"""Crea la directory 'data/' se non esiste.""" |
|
os.makedirs(os.path.join(BASE_DIR, "data"), exist_ok=True) |
|
logger.info("Directory 'data/' creata o già esistente.") |
|
|
|
def extract_lines(rdf_file: str, output_file: str): |
|
""" |
|
Estrae ogni riga dell'ontologia RDF e la salva in un file JSON. |
|
Questo permette di indicizzare ogni riga singolarmente. |
|
""" |
|
logger.info(f"Inizio estrazione delle linee dall'ontologia da {rdf_file}.") |
|
try: |
|
with open(rdf_file, "r", encoding="utf-8") as f: |
|
lines = f.readlines() |
|
|
|
lines = [line.strip() for line in lines if line.strip()] |
|
|
|
with open(output_file, "w", encoding="utf-8") as f: |
|
json.dump({"lines": lines}, f, ensure_ascii=False, indent=2) |
|
logger.info(f"Linee estratte e salvate in {output_file}") |
|
except Exception as e: |
|
logger.error(f"Errore nell'estrazione delle linee: {e}") |
|
raise e |
|
|
|
def create_faiss_index(documents_file: str, index_file: str, embedding_model_instance: SentenceTransformer): |
|
""" |
|
Crea un indice FAISS a partire dalle linee estratte. |
|
""" |
|
logger.info(f"Inizio creazione dell'indice FAISS da {documents_file}.") |
|
try: |
|
|
|
with open(documents_file, "r", encoding="utf-8") as f: |
|
document = json.load(f) |
|
lines = document['lines'] |
|
logger.info(f"{len(lines)} linee caricate da {documents_file}.") |
|
|
|
|
|
embeddings = embedding_model_instance.encode(lines, convert_to_numpy=True, show_progress_bar=True) |
|
logger.info("Embedding generati con SentenceTransformer.") |
|
|
|
|
|
dimension = embeddings.shape[1] |
|
index = faiss.IndexFlatL2(dimension) |
|
index.add(embeddings) |
|
logger.info(f"Indice FAISS creato con dimensione: {dimension}.") |
|
|
|
|
|
faiss.write_index(index, index_file) |
|
logger.info(f"Indice FAISS salvato in {index_file}.") |
|
except Exception as e: |
|
logger.error(f"Errore nella creazione dell'indice FAISS: {e}") |
|
raise e |
|
|
|
def prepare_retrieval(embedding_model_instance: SentenceTransformer): |
|
"""Prepara i file necessari per l'approccio RAG.""" |
|
logger.info("Inizio preparazione per il retrieval.") |
|
create_data_directory() |
|
|
|
|
|
if not os.path.exists(RDF_FILE): |
|
logger.error(f"File RDF non trovato: {RDF_FILE}") |
|
raise FileNotFoundError(f"File RDF non trovato: {RDF_FILE}") |
|
else: |
|
logger.info(f"File RDF trovato: {RDF_FILE}") |
|
|
|
|
|
if not os.path.exists(DOCUMENTS_FILE): |
|
logger.info(f"File {DOCUMENTS_FILE} non trovato. Estrazione delle linee dell'ontologia.") |
|
try: |
|
extract_lines(RDF_FILE, DOCUMENTS_FILE) |
|
except Exception as e: |
|
logger.error(f"Errore nell'estrazione delle linee: {e}") |
|
raise e |
|
else: |
|
logger.info(f"File {DOCUMENTS_FILE} trovato.") |
|
|
|
|
|
if not os.path.exists(FAISS_INDEX_FILE): |
|
logger.info(f"File {FAISS_INDEX_FILE} non trovato. Creazione dell'indice FAISS.") |
|
try: |
|
create_faiss_index(DOCUMENTS_FILE, FAISS_INDEX_FILE, embedding_model_instance) |
|
except Exception as e: |
|
logger.error(f"Errore nella creazione dell'indice FAISS: {e}") |
|
raise e |
|
else: |
|
logger.info(f"File {FAISS_INDEX_FILE} trovato.") |
|
|
|
def retrieve_relevant_lines(query: str, top_k: int = 5, embedding_model_instance: SentenceTransformer = None): |
|
"""Recupera le linee rilevanti usando FAISS.""" |
|
logger.info(f"Recupero delle linee rilevanti per la query: {query}") |
|
try: |
|
|
|
with open(DOCUMENTS_FILE, "r", encoding="utf-8") as f: |
|
document = json.load(f) |
|
lines = document['lines'] |
|
logger.info(f"{len(lines)} linee caricate da {DOCUMENTS_FILE}.") |
|
|
|
|
|
index = faiss.read_index(FAISS_INDEX_FILE) |
|
logger.info(f"Indice FAISS caricato da {FAISS_INDEX_FILE}.") |
|
|
|
|
|
if embedding_model_instance is None: |
|
embedding_model_instance = SentenceTransformer('all-MiniLM-L6-v2') |
|
logger.info("Modello SentenceTransformer caricato per l'embedding della query.") |
|
|
|
query_embedding = embedding_model_instance.encode([query], convert_to_numpy=True) |
|
logger.info("Embedding della query generati.") |
|
|
|
|
|
distances, indices = index.search(query_embedding, top_k) |
|
logger.info(f"Ricerca FAISS completata. Risultati ottenuti: {len(indices[0])}") |
|
|
|
|
|
relevant_texts = [lines[idx] for idx in indices[0] if idx < len(lines)] |
|
retrieved_docs = "\n".join(relevant_texts) |
|
logger.info(f"Linee rilevanti recuperate: {len(relevant_texts)}") |
|
return retrieved_docs |
|
except Exception as e: |
|
logger.error(f"Errore nel recupero delle linee rilevanti: {e}") |
|
raise e |
|
|
|
def create_system_message(retrieved_docs: str) -> str: |
|
""" |
|
Prompt di sistema robusto, con regole su query in una riga e |
|
informazioni recuperate tramite RAG. |
|
""" |
|
return f"""### Istruzioni ### |
|
Sei un assistente museale esperto in ontologie RDF. Utilizza le informazioni fornite per generare query SPARQL precise e pertinenti. |
|
|
|
### Ontologia ### |
|
{retrieved_docs} |
|
### FINE Ontologia ### |
|
|
|
### Regole Stringenti ### |
|
1) Se l'utente chiede informazioni su questa ontologia, genera SEMPRE una query SPARQL in UNA SOLA RIGA, con prefix: |
|
PREFIX base: <http://www.semanticweb.org/lucreziamosca/ontologies/progettoMuseo#> |
|
2) La query SPARQL deve essere precisa e cercare esattamente le entità specificate dall'utente. Ad esempio, se l'utente chiede "Chi ha creato l'opera 'Amore e Psiche'?", la query dovrebbe cercere l'opera esattamente con quel nome. |
|
3) Se la query produce 0 risultati o fallisce, ritenta con un secondo tentativo. |
|
4) Se la domanda è generica (tipo 'Ciao, come stai?'), rispondi brevemente. |
|
5) Se trovi risultati, la risposta finale deve essere la query SPARQL (una sola riga). |
|
6) Se non trovi nulla, rispondi con 'Nessuna info.' |
|
7) Non multiline. Esempio: PREFIX base: <...> SELECT ?x WHERE {{ ... }}. |
|
Esempio: |
|
Utente: Chi ha creato l'opera 'Amore e Psiche'? |
|
Risposta: PREFIX base: <http://www.semanticweb.org/lucreziamosca/ontologies/progettoMuseo#> SELECT ?creatore WHERE {{ ?opera base:hasName "Amore e Psiche" . ?opera base:creatoDa ?creatore . }} |
|
### FINE REGOLE ### |
|
|
|
### Conversazione ### |
|
""" |
|
|
|
def create_explanation_prompt(results_str: str) -> str: |
|
"""Prompt per generare una spiegazione museale dei risultati SPARQL.""" |
|
return f"""Ho ottenuto questi risultati SPARQL: |
|
{results_str} |
|
Ora fornisci una breve spiegazione museale (massimo ~10 righe), senza inventare oltre i risultati. |
|
""" |
|
|
|
async def call_hf_model(prompt: str, temperature: float = 0.5, max_tokens: int = 150, stream: bool = False) -> str: |
|
"""Chiama il modello Hugging Face tramite InferenceClient e gestisce la risposta.""" |
|
logger.debug("Chiamo HF con il seguente prompt:") |
|
content_preview = (prompt[:300] + '...') if len(prompt) > 300 else prompt |
|
logger.debug(f"PROMPT => {content_preview}") |
|
|
|
try: |
|
|
|
messages = [ |
|
{"role": "system", "content": "You are a helpful assistant."}, |
|
{"role": "user", "content": prompt} |
|
] |
|
|
|
|
|
response = client.chat_completion( |
|
messages=messages, |
|
max_tokens=max_tokens, |
|
temperature=temperature, |
|
top_p=0.7, |
|
stream=stream |
|
) |
|
|
|
logger.debug(f"Risposta completa dal modello: {response}") |
|
|
|
if stream: |
|
|
|
generated_text = "" |
|
async for token in response: |
|
if token.choices and token.choices[0].delta.get("content"): |
|
generated_text += token.choices[0].delta["content"] |
|
print(token.choices[0].delta["content"], end="") |
|
return generated_text.strip() |
|
else: |
|
|
|
|
|
logger.debug(f"Risposta completa: {response}") |
|
|
|
|
|
if isinstance(response, dict): |
|
if 'choices' in response and len(response['choices']) > 0: |
|
generated_text = response['choices'][0].get('message', {}).get('content', '') |
|
else: |
|
raise ValueError("Risposta non contiene 'choices' o 'message'.") |
|
elif isinstance(response, list) and len(response) > 0: |
|
generated_text = response[0].get('message', {}).get('content', '') |
|
else: |
|
raise ValueError("Struttura della risposta non riconosciuta.") |
|
|
|
|
|
single_line = " ".join(generated_text.splitlines()) |
|
logger.debug(f"Risposta HF single-line: {single_line}") |
|
return single_line.strip() |
|
except Exception as e: |
|
logger.error(f"Errore nella chiamata all'API Hugging Face tramite InferenceClient: {e}") |
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
|
entity_labels: List[str] = [] |
|
|
|
def load_entity_labels(rdf_file: str): |
|
"""Carica le etichette delle entità dall'ontologia RDF.""" |
|
global entity_labels |
|
try: |
|
g = rdflib.Graph() |
|
g.parse(rdf_file, format="xml") |
|
entities = set() |
|
for s in g.subjects(RDF.type, OWL.NamedIndividual): |
|
label = g.value(s, RDFS.label, default=str(s)) |
|
if isinstance(label, URIRef): |
|
label = label.split('#')[-1].replace('_', ' ') |
|
else: |
|
label = str(label) |
|
entities.add(label.lower()) |
|
entity_labels = list(entities) |
|
logger.info(f"Elenco delle etichette delle entità caricato: {entity_labels}") |
|
except Exception as e: |
|
logger.error(f"Errore nel caricamento delle etichette delle entità: {e}") |
|
entity_labels = [] |
|
|
|
def is_ontology_related(query: str) -> bool: |
|
"""Determina se la domanda è pertinente all'ontologia.""" |
|
query_lower = query.lower() |
|
keywords = ["opera", "museo", "stanza", "tour", "visitatore", "biglietto", "guida", "evento", "agente"] |
|
if any(keyword in query_lower for keyword in keywords): |
|
return True |
|
if any(entity in query_lower for entity in entity_labels): |
|
return True |
|
return False |
|
|
|
app = FastAPI() |
|
|
|
class QueryRequest(BaseModel): |
|
message: str |
|
max_tokens: int = 512 |
|
temperature: float = 0.5 |
|
|
|
@app.post("/generate-response/") |
|
async def generate_response(req: QueryRequest): |
|
user_input = req.message |
|
logger.info(f"Utente dice: {user_input}") |
|
|
|
if not is_ontology_related(user_input): |
|
|
|
generic_prompt = f"{user_input}" |
|
try: |
|
response = await call_hf_model(generic_prompt, req.temperature, req.max_tokens, stream=False) |
|
return { |
|
"type": "NATURAL", |
|
"response": response.strip() |
|
} |
|
except Exception as e: |
|
logger.error(f"Errore nella chiamata al modello Hugging Face per domanda generica: {e}") |
|
return { |
|
"type": "ERROR", |
|
"response": f"Errore nella generazione della risposta per domanda generica: {e}" |
|
} |
|
|
|
try: |
|
|
|
retrieved_docs = retrieve_relevant_lines(user_input, top_k=5, embedding_model_instance=embedding_model) |
|
except Exception as e: |
|
logger.error(f"Errore nel recupero delle linee rilevanti: {e}") |
|
return {"type": "ERROR", "response": f"Errore nel recupero delle linee: {e}"} |
|
|
|
sys_msg = create_system_message(retrieved_docs) |
|
prompt = sys_msg + f"\nUtente: {user_input}\nAssistente:" |
|
|
|
|
|
try: |
|
r1 = await call_hf_model(prompt, req.temperature, req.max_tokens, stream=False) |
|
logger.info(f"PRIMA RISPOSTA:\n{r1}") |
|
except Exception as e: |
|
logger.error(f"Errore nella chiamata al modello Hugging Face: {e}") |
|
return {"type": "ERROR", "response": f"Errore nella generazione della risposta: {e}"} |
|
|
|
|
|
if not r1.startswith("PREFIX base:"): |
|
sc = f"Non hai risposto con query SPARQL su una sola riga. Riprova. Domanda: {user_input}" |
|
fallback_prompt = sys_msg + f"\nAssistente: {r1}\nUtente: {sc}\nAssistente:" |
|
try: |
|
r2 = await call_hf_model(fallback_prompt, req.temperature, req.max_tokens, stream=False) |
|
logger.info(f"SECONDA RISPOSTA:\n{r2}") |
|
if r2.startswith("PREFIX base:"): |
|
sparql_query = r2 |
|
else: |
|
return {"type": "NATURAL", "response": r2} |
|
except Exception as e: |
|
logger.error(f"Errore nella seconda chiamata al modello Hugging Face: {e}") |
|
return {"type": "ERROR", "response": f"Errore nella generazione della seconda risposta: {e}"} |
|
else: |
|
sparql_query = r1 |
|
|
|
|
|
g = rdflib.Graph() |
|
try: |
|
g.parse(RDF_FILE, format="xml") |
|
logger.info(f"Parsing RDF di {RDF_FILE} riuscito per l'esecuzione della query.") |
|
except Exception as e: |
|
logger.error(f"Parsing RDF error: {e}") |
|
return {"type": "ERROR", "response": f"Parsing RDF error: {e}"} |
|
|
|
try: |
|
results = g.query(sparql_query) |
|
logger.info(f"Query SPARQL eseguita con successo. Risultati: {len(results)}") |
|
except Exception as e: |
|
fallback = f"La query SPARQL ha fallito. Riprova. Domanda: {user_input}" |
|
fallback_prompt = sys_msg + f"\nAssistente: {sparql_query}\nUtente: {fallback}\nAssistente:" |
|
try: |
|
r3 = await call_hf_model(fallback_prompt, req.temperature, req.max_tokens, stream=False) |
|
logger.info(f"TERZA RISPOSTA (fallback):\n{r3}") |
|
if r3.startswith("PREFIX base:"): |
|
sparql_query = r3 |
|
try: |
|
results = g.query(sparql_query) |
|
logger.info(f"Seconda query SPARQL eseguita con successo. Risultati: {len(results)}") |
|
except Exception as e2: |
|
logger.error(f"Seconda Query fallita: {e2}") |
|
return {"type": "ERROR", "response": f"Query fallita di nuovo: {e2}"} |
|
else: |
|
return {"type": "NATURAL", "response": r3} |
|
except Exception as e: |
|
logger.error(f"Errore nella chiamata al modello Hugging Face durante il fallback: {e}") |
|
return {"type": "ERROR", "response": f"Errore durante il fallback della risposta: {e}"} |
|
|
|
if len(results) == 0: |
|
return {"type": "NATURAL", "sparql_query": sparql_query, "response": "Nessuna info."} |
|
|
|
|
|
row_list = [] |
|
for row in results: |
|
|
|
row_dict = dict(row) |
|
row_str = ", ".join([f"{k}: {v}" for k, v in row_dict.items()]) |
|
row_list.append(row_str) |
|
results_str = "\n".join(row_list) |
|
|
|
|
|
exp_prompt = create_explanation_prompt(results_str) |
|
try: |
|
explanation = await call_hf_model(exp_prompt, req.temperature, req.max_tokens, stream=False) |
|
except Exception as e: |
|
logger.error(f"Errore nella generazione della spiegazione: {e}") |
|
return {"type": "ERROR", "response": f"Errore nella generazione della spiegazione: {e}"} |
|
|
|
return { |
|
"type": "NATURAL", |
|
"sparql_query": sparql_query, |
|
"sparql_results": row_list, |
|
"explanation": explanation |
|
} |
|
|
|
@app.post("/prova") |
|
async def prova(req: QueryRequest): |
|
return { |
|
"type": "NATURAL", |
|
"response": "Questa è una prova di richiesta" |
|
} |
|
|
|
@app.get("/") |
|
def home(): |
|
return {"message": "Assistente Museale con supporto SPARQL."} |
|
|
|
|
|
try: |
|
create_data_directory() |
|
prepare_retrieval(embedding_model) |
|
load_entity_labels(RDF_FILE) |
|
logger.info("Applicazione avviata e pronta per ricevere richieste.") |
|
except Exception as e: |
|
logger.error(f"Errore durante la preparazione dell'applicazione: {e}") |
|
raise e |
|
|