Spaces:
Sleeping
Sleeping
import gradio as gr | |
from haystack_integrations.document_stores.chroma import ChromaDocumentStore | |
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder | |
from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever | |
from haystack.components.converters import PyPDFToDocument | |
from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter | |
from haystack.components.writers import DocumentWriter | |
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator | |
from haystack.components.generators import HuggingFaceAPIGenerator | |
from haystack.components.builders import PromptBuilder,ChatPromptBuilder | |
from haystack.components.converters import OutputAdapter | |
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore | |
from haystack_experimental.components.retrievers import ChatMessageRetriever | |
from haystack_experimental.components.writers import ChatMessageWriter | |
from haystack.dataclasses import ChatMessage | |
from itertools import chain | |
from typing import Any, List | |
from haystack import component | |
from haystack.core.component.types import Variadic | |
from haystack import Pipeline | |
from haystack import Document | |
from pathlib import Path | |
import os | |
# Chargement des documents | |
def load_and_convert_pdfs(data_dir): | |
"""Charge et convertit tous les fichiers PDF dans un répertoire donné. | |
Args: | |
data_dir: Le chemin d'accès au répertoire contenant les fichiers PDF. | |
Returns: | |
Une liste de documents convertis. | |
""" | |
pdf_converter = PyPDFToDocument() | |
all_docs = [] | |
for filepath in Path(data_dir).glob("*.pdf"): | |
try: | |
docs = pdf_converter.run(sources=[filepath]) | |
all_docs.extend(docs["documents"]) | |
except Exception as e: | |
print(f"Erreur lors du traitement du fichier {filepath}: {e}") | |
return all_docs | |
data_dir = "/data/" | |
all_documents = load_and_convert_pdfs(data_dir) | |
# Pré-Traitement des documents | |
document_cleaner = DocumentCleaner() | |
document_splitter = DocumentSplitter(split_by="period",split_length=1000, split_overlap=200) | |
# Chargement de l'embeddings | |
document_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2") | |
# Stockage de l'embedding des docs dans une base de donnée vectorielle (chromaDb ici) | |
document_store = ChromaDocumentStore() | |
document_writer = DocumentWriter(document_store=document_store) | |
# Pipeline pour aller plus vite et combiner/connecter tous les composants | |
prepocessing_pipeline = Pipeline() | |
prepocessing_pipeline.add_component(name="cleaner", instance=DocumentCleaner()) # document_cleaner() | |
prepocessing_pipeline.add_component(name="splitter", instance=DocumentSplitter(split_by="period", split_length=1000, split_overlap=200)) # document_splitter | |
prepocessing_pipeline.add_component(name="embedder", instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")) # document_embedder | |
prepocessing_pipeline.add_component(name="writer", instance=DocumentWriter(document_store=document_store)) # document_writer | |
prepocessing_pipeline.connect("cleaner.documents", "splitter.documents") | |
prepocessing_pipeline.connect("splitter.documents", "embedder.documents") | |
prepocessing_pipeline.connect("embedder.documents", "writer.documents") | |
# execute le pipeline | |
# pré-traitement et stockage des embeddings dans 'document_store' | |
prepocessing_pipeline.run({"cleaner": {"documents": all_documents}}) | |
############################################### | |
# Aprés l'étape de pré-traitement | |
# passe à la récupération et à la génération. | |
############################################### | |
# On va incorporer ce composant dans le pipeline afin | |
# de gérer les messages provenant de l'utilisateur et du llm | |
class ListJoiner: | |
def __init__(self, _type: Any): | |
component.set_output_types(self, values=_type) | |
def run(self, values: Variadic[Any]): | |
result = list(chain(*values)) | |
return {"values": result} | |
# composants pour gerer l'historique | |
memory_store = InMemoryChatMessageStore() | |
memory_retriever = ChatMessageRetriever(memory_store) | |
memory_writer = ChatMessageWriter(memory_store) | |
query_rephrase_template = """ | |
Réécris la question pour la recherche en conservant son sens et ses termes clés intacts. | |
Si l'historique de conversation est vide, NE MODIFIE PAS la requête. | |
Utilise l'historique de conversation uniquement si nécessaire, et évite d'enrichir la requête avec tes propres connaissances. | |
Si aucune modification n'est nécessaire, renvoie la question actuelle telle quelle. | |
Historique de conversation : | |
{% for memory in memories %} | |
{{ memory.text }} | |
{% endfor %} | |
Requête utilisateur : {{query}} | |
Requête reformulée : | |
""" | |
system_message = ChatMessage.from_system( | |
"""T'es une intelligence artificielle en comptabilité.""" | |
"""Tu utilises les documents de support fournis et l'historique de conversation pour aider les utilisateurs sur des questions de comptabilité.""" | |
) | |
# Template de message utilisateur en français | |
user_message_template = """ | |
Étant donné l'historique de conversation et les documents de support fournis, donne une réponse concise à la question. | |
Note que les documents de support ne font pas partie de la conversation. Si la question ne peut pas être répondue à partir des documents de support, indique le. | |
Historique de conversation : | |
{% for memory in memories %} | |
{{ memory.text }} | |
{% endfor %} | |
Documents de support : | |
{% for doc in documents %} | |
{{ doc.content }} | |
{% endfor %} | |
\nQuestion : {{ query }} | |
\nRéponse : | |
""" | |
user_message = ChatMessage.from_user(user_message_template) | |
# pipeline pour le RAG : combinaison/connection | |
# Pour le modéle, j'utilises ici Mistral 7B | |
conversational_rag = Pipeline() | |
# composants pour reformuler la requête (si nécessaire) | |
conversational_rag.add_component("query_rephrase_prompt_builder", PromptBuilder(query_rephrase_template)) | |
conversational_rag.add_component("query_rephrase_llm", HuggingFaceAPIGenerator(api_type="serverless_inference_api", | |
api_params={"model": "mistralai/Mistral-7B-Instruct-v0.3"}, | |
generation_kwargs={"temperature" : 0.1})) | |
conversational_rag.add_component("list_to_str_adapter", OutputAdapter(template="{{ replies[0] }}", output_type=str)) | |
# composants pour le RAG | |
conversational_rag.add_component("retriever", instance=ChromaQueryTextRetriever(document_store=document_store,top_k=3)) | |
conversational_rag.add_component("prompt_builder", ChatPromptBuilder(variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"])) | |
conversational_rag.add_component("llm", HuggingFaceAPIChatGenerator(api_type="serverless_inference_api", | |
api_params={"model": "mistralai/Mistral-7B-Instruct-v0.3"}, | |
generation_kwargs={"temperature" : 0.1, "max_tokens":300})) | |
# composants pour l'historique | |
conversational_rag.add_component("memory_retriever", ChatMessageRetriever(memory_store)) | |
conversational_rag.add_component("memory_writer", ChatMessageWriter(memory_store)) | |
conversational_rag.add_component("memory_joiner", ListJoiner(List[ChatMessage])) | |
conversational_rag.connect("memory_retriever", "query_rephrase_prompt_builder.memories") | |
conversational_rag.connect("query_rephrase_prompt_builder.prompt", "query_rephrase_llm") | |
conversational_rag.connect("query_rephrase_llm.replies", "list_to_str_adapter") | |
conversational_rag.connect("list_to_str_adapter", "retriever.query") | |
conversational_rag.connect("retriever.documents", "prompt_builder.documents") | |
conversational_rag.connect("prompt_builder.prompt", "llm.messages") | |
conversational_rag.connect("llm.replies", "memory_joiner") | |
conversational_rag.connect("memory_joiner", "memory_writer") | |
conversational_rag.connect("memory_retriever", "prompt_builder.memories") | |
# Interface utilisateur avec Gradio | |
def rag_conversation(question,history): | |
messages = [system_message, user_message] | |
res = conversational_rag.run( | |
data = {'query_rephrase_prompt_builder' : {'query': question}, | |
'prompt_builder': {'template': messages, 'query': question}, | |
'memory_joiner': {'values': [ChatMessage.from_user(question)]}}, | |
include_outputs_from=['llm','query_rephrase_llm']) | |
assistant_resp = res['llm']['replies'][0].text | |
return assistant_resp | |
demo = gr.ChatInterface( | |
rag_conversation, | |
type="messages", | |
title="🤖 IA de comptabilité", | |
textbox=gr.Textbox(placeholder="Pose ta question", container=False, scale=7), | |
description="**Ce chatbot est une IA avancée en comptabilité des associations.**" | |
) | |
if __name__ == "__main__": | |
demo.launch() | |