import torch
from transformers import pipeline, AutoTokenizer
from langchain.text_splitter import RecursiveCharacterTextSplitter
# On LangChain >= 0.2 these classes live in langchain_community
# (pip install langchain-community):
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import gradio as gr
import PyPDF2
import os
from huggingface_hub import login
from typing import List, Optional, Tuple
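
# The Space needs these packages at runtime; a plausible requirements.txt,
# inferred from the imports above (versions not pinned here):
#
#   torch
#   transformers
#   langchain
#   langchain-community
#   sentence-transformers   # backend for HuggingFaceEmbeddings
#   faiss-cpu               # backend for the FAISS vector store
#   gradio
#   PyPDF2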

# Configuration
SPACE_DIR = os.environ.get("HF_HOME", os.getcwd())
PDF_PATH = os.path.join(SPACE_DIR, "train.pdf")
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"
MODEL_NAME = "google/gemma-2-2b-jpn-it"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
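
# Note: train.pdf must be uploaded alongside the app (HF_HOME, when set, is
# used as the base directory; otherwise the current working directory).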

# Hugging Face authentication
def init_huggingface_auth():
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            login(token=token, add_to_git_credential=False)
            print("HF authentication successful")
            return True
        except Exception as e:
            print(f"Authentication error: {e}")
    return False

if not init_huggingface_auth():
    print("Warning: HF authentication failed")

# Load the PDF and split it into overlapping chunks
def load_and_process_pdf() -> List[str]:
    with open(PDF_PATH, 'rb') as file:
        pdf_reader = PyPDF2.PdfReader(file)
        # extract_text() may return None for image-only pages
        text = "\n".join([page.extract_text() or "" for page in pdf_reader.pages])

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=512,
        chunk_overlap=128,
        length_function=len,
        separators=["\n\n", "\n", ".", "!", "?", ";", ",", " "]
    )
    return text_splitter.split_text(text)
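
# 512-character chunks with 128 characters of overlap keep each chunk well
# inside the embedding model's input window while preserving context across
# chunk boundaries; both values are tunable for other documents.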

# Model initialisation (runs once at startup)
def initialize_models():
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={'device': DEVICE},
        encode_kwargs={'normalize_embeddings': True}
    )
    
    chunks = load_and_process_pdf()
    vector_store = FAISS.from_texts(chunks, embeddings)
    
    generator = pipeline(
        "text-generation",
        model=MODEL_NAME,
        tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=DEVICE
    )
    
    return vector_store, generator

vector_store, generator = initialize_models()
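
# Note: the FAISS index is rebuilt (and the PDF re-embedded) on every
# restart. For larger corpora, FAISS.save_local()/FAISS.load_local() could
# persist the index between runs (load_local needs the embeddings object and,
# on recent langchain-community, allow_dangerous_deserialization=True).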

# Prompt engineering
SYSTEM_PROMPT = """You are Foton, a virtual assistant and expert in Lugha Tausi programming.
Answer in Swahili unless asked otherwise. Base your answers strictly on the documentation provided.

Documentation:
{context}

Question: {question}
Answer:"""

# Swahili: "Welcome to Lugha Tausi! I am Foton, your personal assistant.
# I am here to help you with any programming matters. How can I help you today?"
WELCOME_MESSAGE = "**Karibu Lugha Tausi!** Mimi ni Foton, msaidizi wako wa kibinafsi. Niko hapa kukusaidia kwa masuala yoyote ya programu. **Ninaweza kukusaidiaje leo?**"

# Retrieval-augmented generation
def rag_response(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
    # Retrieve the most relevant documentation chunks
    docs = vector_store.similarity_search(query, k=3)
    context = "\n".join([d.page_content for d in docs])

    # Build the prompt as a single chat message
    messages = [{"role": "user", "content": SYSTEM_PROMPT.format(context=context, question=query)}]

    # Generation with quality controls
    response = generator(
        messages,
        max_new_tokens=512,
        temperature=0.3,
        top_p=0.95,
        repetition_penalty=1.1,
        do_sample=True,
        num_return_sequences=1
    )

    # With chat-format input the pipeline returns the whole conversation in
    # 'generated_text'; the last message is the model's reply
    answer = response[0]['generated_text'][-1]['content'].strip()
    return answer
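
# Quick manual check outside the UI (hypothetical query):
#   print(rag_response("Nitaandikaje 'Hello World' kwa Lugha Tausi?"))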

# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:  # theme= already applies the theme's CSS
    gr.Markdown("# Foton - Msaidizi wa Lugha Tausi")
    
    with gr.Row():
        with gr.Column(scale=2):
            gr.Image("foton.webp", label="Foton", width=200)
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                value=[(None, WELCOME_MESSAGE)],
                bubble_full_width=False,
                height=600
            )
            
    msg = gr.Textbox(
        placeholder="Andika ujumbe wako hapa...",  # Swahili: "Type your message here..."
        label="Pitia swali lako",  # Swahili: "Review your question"
        container=False
    )

    clear = gr.Button("Safisha Mazungumzo")  # Swahili: "Clear the conversation"
    
    def respond(message, chat_history):
        response = rag_response(message)
        chat_history.append((message, response))
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)  # share=True has no effect when running on a Space