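"""app.py — Learningsw Space.

Foton, a RAG assistant for Lugha Tausi programming documentation:
load train.pdf -> split into overlapping chunks -> embed with BGE ->
index in FAISS -> retrieve the top matches per query -> generate an
answer with Gemma 2 2B -> serve it through a Gradio chat interface.
"""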
import os
from typing import List, Optional, Tuple

import torch
import gradio as gr
import PyPDF2
from huggingface_hub import login
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import pipeline, AutoTokenizer

# Configuration
SPACE_DIR = os.environ.get("HF_HOME", os.getcwd())
PDF_PATH = os.path.join(SPACE_DIR, "train.pdf")  # source document for retrieval
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"
MODEL_NAME = "google/gemma-2-2b-jpn-it"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face authentication
def init_huggingface_auth():
    """Log in to the Hub with the HUGGINGFACE_TOKEN secret, if present."""
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            login(token=token, add_to_git_credential=False)
            print("HF authentication succeeded")
            return True
        except Exception as e:
            print(f"Authentication error: {e}")
    return False

if not init_huggingface_auth():
    print("Warning: authentication failed")

# PDF loading and chunking
def load_and_process_pdf() -> List[str]:
    """Read train.pdf and split it into overlapping chunks for retrieval."""
    with open(PDF_PATH, "rb") as file:
        pdf_reader = PyPDF2.PdfReader(file)
        # extract_text() may return None on image-only pages; fall back to "".
        text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=512,
        chunk_overlap=128,
        length_function=len,
        separators=["\n\n", "\n", ".", "!", "?", ";", ",", " "],
    )
    return text_splitter.split_text(text)

# Model initialization
def initialize_models():
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={"device": DEVICE},
        encode_kwargs={"normalize_embeddings": True},
    )
    chunks = load_and_process_pdf()
    vector_store = FAISS.from_texts(chunks, embeddings)
    generator = pipeline(
        "text-generation",
        model=MODEL_NAME,
        tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=DEVICE,
    )
    return vector_store, generator

vector_store, generator = initialize_models()
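# NOTE: the FAISS index and the Gemma pipeline are built once at import
# time, so the Space pays the full loading cost at startup rather than on
# the first user request.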

# Prompt engineering
SYSTEM_PROMPT = """You are Foton, a virtual assistant and expert in Lugha Tausi programming.
Answer in Swahili unless asked otherwise. Base your answers strictly on the documentation provided.

Documentation:
{context}

Question: {question}

Answer:"""
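# {context} and {question} are filled in per request by rag_response() below.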
WELCOME_MESSAGE = "**Karibu Lugha Tausi!** Mimi ni Foton, msaidizi wako wa kibinafsi. Niko hapa kukusaidia kwa masuala yoyote ya programu. **Ninaweza kukusaidiaje leo?**"

# Retrieval-augmented generation
def rag_response(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
    # history is accepted for interface compatibility but not used yet.
    # Contextual retrieval: top-3 most similar chunks from the FAISS index
    docs = vector_store.similarity_search(query, k=3)
    context = "\n".join(d.page_content for d in docs)

    # Build the chat-style prompt
    messages = [{"role": "user", "content": SYSTEM_PROMPT.format(context=context, question=query)}]

    # Generation with sampling controls
    response = generator(
        messages,
        max_new_tokens=512,
        temperature=0.3,
        top_p=0.95,
        repetition_penalty=1.1,
        do_sample=True,
        num_return_sequences=1,
    )

    # Post-processing: with chat-style input, generated_text is the message
    # list with the assistant's reply appended as the last entry.
    answer = response[0]["generated_text"][-1]["content"].strip()
    return answer
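
# Example call (hypothetical query; the answer depends on train.pdf content):
#   rag_response("Nitawezaje kutangaza variable katika Lugha Tausi?")
#   -> a Swahili answer grounded in the retrieved documentation chunks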

# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Foton - Msaidizi wa Lugha Tausi")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Image("foton.webp", label="Foton", width=200)
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                value=[(None, WELCOME_MESSAGE)],
                bubble_full_width=False,
                height=600,
            )
            msg = gr.Textbox(
                placeholder="Andika ujumbe wako hapa...",
                label="Pitia swali lako",
                container=False,
            )
            clear = gr.Button("Safisha Mazungumzo")

    def respond(message, chat_history):
        response = rag_response(message)
        chat_history.append((message, response))
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch(share=True, ssr_mode=False)