import torch
from transformers import pipeline, AutoTokenizer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
import gradio as gr
import PyPDF2
import os
from huggingface_hub import login
from typing import List, Tuple

# Configuration
SPACE_DIR = os.environ.get("HF_HOME", os.getcwd())
PDF_PATH = os.path.join(SPACE_DIR, "train.pdf")
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"
MODEL_NAME = "google/gemma-2-2b-jpn-it"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face authentication
def init_huggingface_auth():
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            login(token=token, add_to_git_credential=False)
            print("HF authentication successful")
            return True
        except Exception as e:
            print(f"Authentication error: {e}")
    return False

if not init_huggingface_auth():
    print("Warning: authentication failed")

# Load and process the PDF
def load_and_process_pdf() -> List[str]:
    with open(PDF_PATH, 'rb') as file:
        pdf_reader = PyPDF2.PdfReader(file)
        # extract_text() can return None for pages without extractable text
        text = "\n".join([(page.extract_text() or "") for page in pdf_reader.pages])

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=512,
        chunk_overlap=128,
        length_function=len,
        separators=["\n\n", "\n", ".", "!", "?", ";", ",", " "]
    )
    return text_splitter.split_text(text)

# Model initialization
def initialize_models():
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={'device': DEVICE},
        encode_kwargs={'normalize_embeddings': True}
    )

    chunks = load_and_process_pdf()
    vector_store = FAISS.from_texts(chunks, embeddings)

    generator = pipeline(
        "text-generation",
        model=MODEL_NAME,
        tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=DEVICE
    )

    return vector_store, generator

vector_store, generator = initialize_models()

# Prompt engineering
SYSTEM_PROMPT = """Vous êtes Foton, assistant virtuel expert en programmation Lugha Tausi.
Répondez en swahili sauf demande contraire. Basez-vous strictement sur la documentation fournie.

Documentation:
{context}

Question: {question}
Réponse:"""

WELCOME_MESSAGE = (
    "**Karibu Lugha Tausi!** Mimi ni Foton, msaidizi wako wa kibinafsi. "
    "Niko hapa kukusaidia kwa masuala yoyote ya programu.\n"
    "**Ninaweza kukusaidiaje leo?**"
)

# RAG response generation
def rag_response(query: str, history: List[Tuple[str, str]] = None) -> str:
    # Contextual retrieval: top-3 most similar chunks
    docs = vector_store.similarity_search(query, k=3)
    context = "\n".join([d.page_content for d in docs])

    # Prompt construction in chat format
    messages = [{"role": "user", "content": SYSTEM_PROMPT.format(context=context, question=query)}]

    # Generation with quality controls
    response = generator(
        messages,
        max_new_tokens=512,
        temperature=0.3,
        top_p=0.95,
        repetition_penalty=1.1,
        do_sample=True,
        num_return_sequences=1
    )

    # Post-processing: with chat-style input, the pipeline returns the whole
    # conversation; the assistant's reply is the last message.
    answer = response[0]['generated_text'][-1]['content'].strip()
    return answer

# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Foton - Msaidizi wa Lugha Tausi")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Image("foton.webp", label="Foton", width=200)
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                value=[(None, WELCOME_MESSAGE)],
                bubble_full_width=False,
                height=600
            )
            msg = gr.Textbox(
                placeholder="Andika ujumbe wako hapa...",
                label="Pitia swali lako",
                container=False
            )
            clear = gr.Button("Safisha Mazungumzo")

    def respond(message, chat_history):
        response = rag_response(message)
        chat_history.append((message, response))
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch(share=True, ssr_mode=False)