import streamlit as st import os from langchain_google_genai import ChatGoogleGenerativeAI from langchain_core.prompts import ChatPromptTemplate from langchain_community.document_loaders import TextLoader from langchain_huggingface import HuggingFaceEmbeddings from langchain.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnablePassthrough from langchain.vectorstores import Chroma from chromadb.config import Settings from langchain_text_splitters import RecursiveCharacterTextSplitter # Tiêu đề ứng dụng page = st.title("Chat with AskUSTH") # Khởi tạo trạng thái phiên if "gemini_api" not in st.session_state: st.session_state.gemini_api = None if "rag" not in st.session_state: st.session_state.rag = None if "llm" not in st.session_state: st.session_state.llm = None if "embd" not in st.session_state: st.session_state.embd = None if "model" not in st.session_state: st.session_state.model = None if "save_dir" not in st.session_state: st.session_state.save_dir = None if "uploaded_files" not in st.session_state: st.session_state.uploaded_files = set() if "chat_history" not in st.session_state: st.session_state.chat_history = [] # Hàm tải và xử lý file văn bản def load_txt(file_path): loader = TextLoader(file_path=file_path, encoding="utf-8") doc = loader.load() return doc # Hàm định dạng văn bản def format_docs(docs): """Định dạng các tài liệu thành chuỗi văn bản.""" return "\n\n".join(doc.page_content for doc in docs) # Hàm thiết lập mô hình Google Gemini @st.cache_resource def get_chat_google_model(api_key): os.environ["GOOGLE_API_KEY"] = api_key return ChatGoogleGenerativeAI( model="gemini-1.5-pro", temperature=0, max_tokens=None, timeout=None, max_retries=2, ) # Hàm thiết lập mô hình embedding @st.cache_resource def get_embedding_model(): model_name = "bkai-foundation-models/vietnamese-bi-encoder" model_kwargs = {'device': 'cpu'} encode_kwargs = {'normalize_embeddings': False} model = HuggingFaceEmbeddings( model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs ) return model # Hàm tạo RAG Chain @st.cache_resource def compute_rag_chain(_model, _embd, docs_texts): if not docs_texts: raise ValueError("Không có tài liệu nào để xử lý. Vui lòng tải lên các tệp hợp lệ.") combined_text = "\n\n".join(docs_texts) text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50) texts = text_splitter.split_text(combined_text) if len(texts) > 5000: raise ValueError("Tài liệu tạo ra quá nhiều đoạn. Vui lòng sử dụng tài liệu nhỏ hơn.") # Tạo thư mục lưu trữ persist_dir = "./chromadb_store" if not os.path.exists(persist_dir): os.makedirs(persist_dir) # Khởi tạo Chroma với cấu hình lưu trữ settings = Settings(persist_directory=persist_dir) # Khởi tạo Chroma và lưu dữ liệu vectorstore = Chroma.from_texts(texts=texts, embedding=_embd, client_settings=settings) retriever = vectorstore.as_retriever() # Template cho prompt template = """ Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi. Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời. Dưới đây là thông tin liên quan mà bạn cần sử dụng tới: {context} hãy trả lời: {question} """ prompt = PromptTemplate(template=template, input_variables=["context", "question"]) rag_chain = ( {"context": retriever | format_docs, "question": RunnablePassthrough()} | prompt | _model | StrOutputParser() ) return rag_chain # Dialog cài đặt Google Gemini @st.dialog("Setup Gemini") def setup_gemini(): st.markdown( """ Để sử dụng Google Gemini, bạn cần cung cấp API key. Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) và dán vào bên dưới. """ ) key = st.text_input("Key:", "") if st.button("Save") and key != "": st.session_state.gemini_api = key st.rerun() if st.session_state.gemini_api is None: setup_gemini() if st.session_state.gemini_api and st.session_state.model is None: st.session_state.model = get_chat_google_model(st.session_state.gemini_api) if st.session_state.embd is None: st.session_state.embd = get_embedding_model() if st.session_state.save_dir is None: save_dir = "./Documents" if not os.path.exists(save_dir): os.makedirs(save_dir) st.session_state.save_dir = save_dir # Cập nhật xử lý Sidebar with st.sidebar: uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"]) max_file_size_mb = 5 if uploaded_files: documents = [] for uploaded_file in uploaded_files: if uploaded_file.size > max_file_size_mb * 1024 * 1024: st.warning(f"Tệp {uploaded_file.name} vượt quá giới hạn {max_file_size_mb}MB.") continue file_path = os.path.join(st.session_state.save_dir, uploaded_file.name) with open(file_path, mode='wb') as w: w.write(uploaded_file.getvalue()) doc = load_txt(file_path) documents.extend([*doc]) if documents: docs_texts = [d.page_content for d in documents] st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts) # Giao diện chat for message in st.session_state.chat_history: with st.chat_message(message["role"]): st.write(message["content"]) prompt = st.chat_input("Bạn muốn hỏi gì?") if st.session_state.model is not None: if prompt: st.session_state.chat_history.append({"role": "user", "content": prompt}) with st.chat_message("user"): st.write(prompt) with st.chat_message("assistant"): if st.session_state.rag is not None: response = st.session_state.rag.invoke(prompt) st.write(response) else: ans = st.session_state.llm.invoke(prompt) response = ans.content st.write(response) st.session_state.chat_history.append({"role": "assistant", "content": response})