File size: 2,546 Bytes
fb6f6d9
fdfcf53
4020981
a06a315
7c119cb
5b5abf5
7c119cb
 
 
fb6f6d9
 
 
7c119cb
 
5b5abf5
fb6f6d9
 
 
 
 
 
7c119cb
5b5abf5
b2369fc
5b5abf5
a06a315
fdfcf53
a06a315
 
 
5b5abf5
 
fb6f6d9
 
5b5abf5
a49ad35
5b5abf5
 
 
 
309abbd
b9221b8
5b5abf5
4bf350f
9ddf764
b9221b8
fb6f6d9
 
 
7c119cb
a49ad35
fb6f6d9
 
fdfcf53
a49ad35
 
b6628bb
fb6f6d9
 
 
 
 
 
5b5abf5
fb6f6d9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import streamlit as st
from langchain_community.llms import HuggingFaceHub
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS

# 1. Sample knowledge-base data: a plain list of short facts. Each string is
#    indexed as its own document when the vector store is built below.
knowledge_base = [
    "Gemma 是 Google 开发的大型语言模型。",
    "Gemma 具有强大的自然语言处理能力。",
    "Gemma 可以用于问答、对话、文本生成等任务。",
    "Gemma 基于 Transformer 架构。",
    "Gemma 支持多种语言。",
]

# 2. Build the vector database once.
# Streamlit re-executes this whole script on every widget interaction, so
# without caching the embedding model would be reloaded and the FAISS index
# rebuilt on each click. st.cache_resource keeps one shared instance per
# distinct input; a failed build raises and is therefore not cached.
@st.cache_resource
def _build_vector_store(texts):
    """Load the embedding model and index *texts* (a tuple of str) in FAISS.

    Returns:
        ``(embeddings, db)``: the embedding model and the FAISS store built
        from ``texts``.
    """
    model = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
    store = FAISS.from_texts(list(texts), model)
    return model, store


try:
    # Pass a tuple so the cache key is built from an immutable value.
    embeddings, db = _build_vector_store(tuple(knowledge_base))
except Exception as e:
    st.error(f"向量数据库构建失败:{e}")
    st.stop()

# 3. Question-answering function
def answer_question(repo_id, temperature, max_length, question, k=4):
    """Answer *question* with a Hugging Face Hub model, grounded in the knowledge base.

    Retrieves the ``k`` knowledge-base entries most similar to the question
    from the module-level FAISS store ``db``, puts them into a prompt, and
    sends that prompt to the model.

    Args:
        repo_id: Hugging Face Hub repository id of the model to query.
        temperature: Sampling temperature, forwarded via ``model_kwargs``.
        max_length: Forwarded via ``model_kwargs`` as ``"max_length"``.
        question: The user's question text.
        k: Number of knowledge-base entries to retrieve (4 matches the
            vector store's own default).

    Returns:
        The model's answer, or a fixed error string if retrieval or
        generation fails (the exception is also shown via ``st.error``).
        If the model cannot be constructed, the Streamlit run is halted
        with ``st.stop()`` instead of returning.
    """
    # 4. Initialise the model.
    try:
        llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
    except Exception as e:
        st.error(f"Gemma 模型加载失败:{e}")
        st.stop()

    # 5. Retrieve context and generate the answer.
    try:
        # similarity_search_with_score embeds whatever text it is given, so it
        # must receive the question itself. Embedding the question first and
        # passing the stringified vector would make the store embed a string
        # of digits and retrieve by that instead of by the question's meaning.
        docs_and_scores = db.similarity_search_with_score(question, k=k)

        context = "\n".join(doc.page_content for doc, _score in docs_and_scores)
        print('context: ' + context)

        prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
        print('prompt: ' + prompt)

        return llm.invoke(prompt)
    except Exception as e:
        st.error(f"问答过程出错:{e}")
        return "An error occurred during the answering process."

# 6. Streamlit UI
st.title("Gemma 知识库问答系统")

gemma = st.selectbox(
    "repo-id",
    ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"),
    2,  # default selection: google/recurrentgemma-2b-it
)
# Lower bounds keep clearly invalid generation settings out of the request:
# a length must be at least 1, and a non-positive temperature is presumably
# rejected by the hosted endpoint (NOTE(review): confirm the exact limit).
temperature = st.number_input("temperature", min_value=0.01, value=1.0)
max_length = st.number_input("max_length", min_value=1, value=1024)
question = st.text_area("请输入问题", "Gemma 有哪些特点?")

if st.button("提交"):
    # Treat whitespace-only input the same as an empty question.
    if not question.strip():
        st.warning("请输入问题!")
    else:
        with st.spinner("正在查询..."):
            answer = answer_question(gemma, float(temperature), int(max_length), question)
            st.write("答案:")
            st.write(answer)