IA_Compta / app.py
amadoujr's picture
Update app.py
6c6e81a verified
import gradio as gr
from haystack_integrations.document_stores.chroma import ChromaDocumentStore
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever
from haystack.components.converters import PyPDFToDocument
from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
from haystack.components.writers import DocumentWriter
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.components.generators import HuggingFaceAPIGenerator
from haystack.components.builders import PromptBuilder,ChatPromptBuilder
from haystack.components.converters import OutputAdapter
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
from haystack.dataclasses import ChatMessage
from itertools import chain
from typing import Any, List
from haystack import component
from haystack.core.component.types import Variadic
from haystack import Pipeline
from haystack import Document
from pathlib import Path
import os
# Chargement des documents
def load_and_convert_pdfs(data_dir):
"""Charge et convertit tous les fichiers PDF dans un répertoire donné.
Args:
data_dir: Le chemin d'accès au répertoire contenant les fichiers PDF.
Returns:
Une liste de documents convertis.
"""
pdf_converter = PyPDFToDocument()
all_docs = []
for filepath in Path(data_dir).glob("*.pdf"):
try:
docs = pdf_converter.run(sources=[filepath])
all_docs.extend(docs["documents"])
except Exception as e:
print(f"Erreur lors du traitement du fichier {filepath}: {e}")
return all_docs
data_dir = "/data/"
all_documents = load_and_convert_pdfs(data_dir)
# Pré-Traitement des documents
document_cleaner = DocumentCleaner()
document_splitter = DocumentSplitter(split_by="period",split_length=1000, split_overlap=200)
# Chargement de l'embeddings
document_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
# Stockage de l'embedding des docs dans une base de donnée vectorielle (chromaDb ici)
document_store = ChromaDocumentStore()
document_writer = DocumentWriter(document_store=document_store)
# Pipeline pour aller plus vite et combiner/connecter tous les composants
prepocessing_pipeline = Pipeline()
prepocessing_pipeline.add_component(name="cleaner", instance=DocumentCleaner()) # document_cleaner()
prepocessing_pipeline.add_component(name="splitter", instance=DocumentSplitter(split_by="period", split_length=1000, split_overlap=200)) # document_splitter
prepocessing_pipeline.add_component(name="embedder", instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")) # document_embedder
prepocessing_pipeline.add_component(name="writer", instance=DocumentWriter(document_store=document_store)) # document_writer
prepocessing_pipeline.connect("cleaner.documents", "splitter.documents")
prepocessing_pipeline.connect("splitter.documents", "embedder.documents")
prepocessing_pipeline.connect("embedder.documents", "writer.documents")
# execute le pipeline
# pré-traitement et stockage des embeddings dans 'document_store'
prepocessing_pipeline.run({"cleaner": {"documents": all_documents}})
###############################################
# Aprés l'étape de pré-traitement
# passe à la récupération et à la génération.
###############################################
# On va incorporer ce composant dans le pipeline afin
# de gérer les messages provenant de l'utilisateur et du llm
@component
class ListJoiner:
def __init__(self, _type: Any):
component.set_output_types(self, values=_type)
def run(self, values: Variadic[Any]):
result = list(chain(*values))
return {"values": result}
# composants pour gerer l'historique
memory_store = InMemoryChatMessageStore()
memory_retriever = ChatMessageRetriever(memory_store)
memory_writer = ChatMessageWriter(memory_store)
query_rephrase_template = """
Réécris la question pour la recherche en conservant son sens et ses termes clés intacts.
Si l'historique de conversation est vide, NE MODIFIE PAS la requête.
Utilise l'historique de conversation uniquement si nécessaire, et évite d'enrichir la requête avec tes propres connaissances.
Si aucune modification n'est nécessaire, renvoie la question actuelle telle quelle.
Historique de conversation :
{% for memory in memories %}
{{ memory.text }}
{% endfor %}
Requête utilisateur : {{query}}
Requête reformulée :
"""
system_message = ChatMessage.from_system(
"""T'es une intelligence artificielle en comptabilité."""
"""Tu utilises les documents de support fournis et l'historique de conversation pour aider les utilisateurs sur des questions de comptabilité."""
)
# Template de message utilisateur en français
user_message_template = """
Étant donné l'historique de conversation et les documents de support fournis, donne une réponse concise à la question.
Note que les documents de support ne font pas partie de la conversation. Si la question ne peut pas être répondue à partir des documents de support, indique le.
Historique de conversation :
{% for memory in memories %}
{{ memory.text }}
{% endfor %}
Documents de support :
{% for doc in documents %}
{{ doc.content }}
{% endfor %}
\nQuestion : {{ query }}
\nRéponse :
"""
user_message = ChatMessage.from_user(user_message_template)
# pipeline pour le RAG : combinaison/connection
# Pour le modéle, j'utilises ici Mistral 7B
conversational_rag = Pipeline()
# composants pour reformuler la requête (si nécessaire)
conversational_rag.add_component("query_rephrase_prompt_builder", PromptBuilder(query_rephrase_template))
conversational_rag.add_component("query_rephrase_llm", HuggingFaceAPIGenerator(api_type="serverless_inference_api",
api_params={"model": "mistralai/Mistral-7B-Instruct-v0.3"},
generation_kwargs={"temperature" : 0.1}))
conversational_rag.add_component("list_to_str_adapter", OutputAdapter(template="{{ replies[0] }}", output_type=str))
# composants pour le RAG
conversational_rag.add_component("retriever", instance=ChromaQueryTextRetriever(document_store=document_store,top_k=3))
conversational_rag.add_component("prompt_builder", ChatPromptBuilder(variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"]))
conversational_rag.add_component("llm", HuggingFaceAPIChatGenerator(api_type="serverless_inference_api",
api_params={"model": "mistralai/Mistral-7B-Instruct-v0.3"},
generation_kwargs={"temperature" : 0.1, "max_tokens":300}))
# composants pour l'historique
conversational_rag.add_component("memory_retriever", ChatMessageRetriever(memory_store))
conversational_rag.add_component("memory_writer", ChatMessageWriter(memory_store))
conversational_rag.add_component("memory_joiner", ListJoiner(List[ChatMessage]))
conversational_rag.connect("memory_retriever", "query_rephrase_prompt_builder.memories")
conversational_rag.connect("query_rephrase_prompt_builder.prompt", "query_rephrase_llm")
conversational_rag.connect("query_rephrase_llm.replies", "list_to_str_adapter")
conversational_rag.connect("list_to_str_adapter", "retriever.query")
conversational_rag.connect("retriever.documents", "prompt_builder.documents")
conversational_rag.connect("prompt_builder.prompt", "llm.messages")
conversational_rag.connect("llm.replies", "memory_joiner")
conversational_rag.connect("memory_joiner", "memory_writer")
conversational_rag.connect("memory_retriever", "prompt_builder.memories")
# Interface utilisateur avec Gradio
def rag_conversation(question,history):
messages = [system_message, user_message]
res = conversational_rag.run(
data = {'query_rephrase_prompt_builder' : {'query': question},
'prompt_builder': {'template': messages, 'query': question},
'memory_joiner': {'values': [ChatMessage.from_user(question)]}},
include_outputs_from=['llm','query_rephrase_llm'])
assistant_resp = res['llm']['replies'][0].text
return assistant_resp
demo = gr.ChatInterface(
rag_conversation,
type="messages",
title="🤖 IA de comptabilité",
textbox=gr.Textbox(placeholder="Pose ta question", container=False, scale=7),
description="**Ce chatbot est une IA avancée en comptabilité des associations.**"
)
if __name__ == "__main__":
demo.launch()