import torch
from transformers import pipeline, AutoTokenizer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
import gradio as gr
import PyPDF2
import os
from huggingface_hub import login
from typing import List, Tuple
# Configuration
SPACE_DIR = os.environ.get("HF_HOME", os.getcwd())
PDF_PATH = os.path.join(SPACE_DIR, "train.pdf")
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"
MODEL_NAME = "google/gemma-2-2b-jpn-it"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
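# Note (assumption, not verified here): BAAI/bge-small-en-v1.5 is an
# English-focused embedding model, so retrieval quality over Swahili or
# French documentation in train.pdf is not guaranteed.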
# Hugging Face authentication
def init_huggingface_auth():
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            login(token=token, add_to_git_credential=False)
            print("HF authentication successful")
            return True
        except Exception as e:
            print(f"Authentication error: {e}")
    return False
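# Deployment assumption (not enforced by the code): on a Hugging Face Space,
# HUGGINGFACE_TOKEN would be set as a repository secret so the gated
# google/gemma-2-2b-jpn-it weights can be downloaded.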
if not init_huggingface_auth():
    print("Warning: Hugging Face authentication failed")
# Load the PDF and split it into overlapping chunks
def load_and_process_pdf() -> List[str]:
    with open(PDF_PATH, 'rb') as file:
        pdf_reader = PyPDF2.PdfReader(file)
        # extract_text() can return an empty string (or None on older
        # PyPDF2 versions) for image-only pages, hence the `or ""`
        text = "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=512,
        chunk_overlap=128,
        length_function=len,
        separators=["\n\n", "\n", ".", "!", "?", ";", ",", " "]
    )
    return text_splitter.split_text(text)
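# Optional sketch (not wired into the app): cache the FAISS index on disk so
# Space restarts skip re-embedding the PDF. save_local/load_local are standard
# LangChain FAISS methods; INDEX_DIR is an assumed path, and recent LangChain
# releases additionally require allow_dangerous_deserialization=True when
# loading a locally pickled index.
INDEX_DIR = os.path.join(SPACE_DIR, "faiss_index")

def load_or_build_index(embeddings) -> FAISS:
    # Reuse a previously saved index if one exists
    if os.path.isdir(INDEX_DIR):
        return FAISS.load_local(INDEX_DIR, embeddings)
    # Otherwise embed the PDF chunks once and persist the result
    index = FAISS.from_texts(load_and_process_pdf(), embeddings)
    index.save_local(INDEX_DIR)
    return index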
# Model initialization
def initialize_models():
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={'device': DEVICE},
        encode_kwargs={'normalize_embeddings': True}
    )
    chunks = load_and_process_pdf()
    vector_store = FAISS.from_texts(chunks, embeddings)
    generator = pipeline(
        "text-generation",
        model=MODEL_NAME,
        tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=DEVICE
    )
    return vector_store, generator
vector_store, generator = initialize_models()
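# Optional smoke test (sketch): confirm retrieval returns something sensible
# before wiring up the UI. The query string is a made-up example.
# for doc in vector_store.similarity_search("chapisha", k=2):
#     print(doc.page_content[:80])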
# Prompt engineering (the system prompt is deliberately in French and asks
# the model to answer in Swahili)
SYSTEM_PROMPT = """Vous êtes Foton, assistant virtuel expert en programmation Lugha Tausi.
Répondez en swahili sauf demande contraire. Basez-vous strictement sur la documentation fournie.
Documentation:
{context}
Question: {question}
Réponse:"""

WELCOME_MESSAGE = "**Karibu Lugha Tausi!** Mimi ni Foton, msaidizi wako wa kibinafsi. Niko hapa kukusaidia kwa masuala yoyote ya programu. **Ninaweza kukusaidiaje leo?**"
# RAG answer generation
def rag_response(query: str, history: List[Tuple[str, str]] = None) -> str:
    # Retrieve the most relevant documentation chunks
    docs = vector_store.similarity_search(query, k=3)
    context = "\n".join([d.page_content for d in docs])

    # Build the chat-format prompt
    messages = [{"role": "user", "content": SYSTEM_PROMPT.format(context=context, question=query)}]

    # Generation with sampling controls
    response = generator(
        messages,
        max_new_tokens=512,
        temperature=0.3,
        top_p=0.95,
        repetition_penalty=1.1,
        do_sample=True,
        num_return_sequences=1
    )

    # With chat-format input, the pipeline returns the full conversation as a
    # list of messages; the assistant's reply is the last entry
    answer = response[0]['generated_text'][-1]['content'].strip()
    return answer
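# Example call (hypothetical query, assuming the Lugha Tausi docs are loaded):
#   rag_response("Nitawezaje kuandika programu yangu ya kwanza?")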
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Foton - Msaidizi wa Lugha Tausi")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Image("foton.webp", label="Foton", width=200)
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                value=[(None, WELCOME_MESSAGE)],
                bubble_full_width=False,
                height=600
            )
            msg = gr.Textbox(
                placeholder="Andika ujumbe wako hapa...",
                label="Pitia swali lako",
                container=False
            )
            clear = gr.Button("Safisha Mazungumzo")

    def respond(message, chat_history):
        response = rag_response(message)
        chat_history.append((message, response))
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    demo.launch(share=True, ssr_mode=False)