import streamlit as st
from langchain_community.llms import HuggingFaceHub
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
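# NOTE (assumption, not stated in the original file): this sketch expects the packages
# streamlit, langchain-community, sentence-transformers, and faiss-cpu to be installed,
# and HuggingFaceHub reads the Hugging Face API token from the
# HUGGINGFACEHUB_API_TOKEN environment variable.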
# 1. Prepare the knowledge base data (example)
knowledge_base = [
    "Gemma is a large language model developed by Google.",
    "Gemma has strong natural language processing capabilities.",
    "Gemma can be used for question answering, dialogue, text generation, and other tasks.",
    "Gemma is based on the Transformer architecture.",
    "Gemma supports multiple languages."
]
# 2. Build the vector store (only needs to be built once)
try:
    embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
    db = FAISS.from_texts(knowledge_base, embeddings)
except Exception as e:
    st.error(f"Failed to build the vector store: {e}")
    st.stop()
# 3. Question-answering function
def answer_question(repo_id, temperature, max_length, question):
    # 4. Initialize the Gemma model
    try:
        llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
    except Exception as e:
        st.error(f"Failed to load the Gemma model: {e}")
        st.stop()
    # 5. Retrieve relevant context and generate the answer
    try:
        # Embed the question and query the vector store with the embedding vector
        question_embedding = embeddings.embed_query(question)
        docs_and_scores = db.similarity_search_with_score_by_vector(question_embedding)
        context = "\n".join([doc.page_content for doc, _ in docs_and_scores])
        print('context: ' + context)
        prompt = f"Answer the question using the following knowledge base:\n{context}\nQuestion: {question}"
        print('prompt: ' + prompt)
        answer = llm.invoke(prompt)
        return answer
    except Exception as e:
        st.error(f"Error while answering the question: {e}")
        return "An error occurred during the answering process."
# 6. Streamlit UI
st.title("Gemma Knowledge Base Q&A System")
gemma = st.selectbox("repo-id", ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"), 2)
temperature = st.number_input("temperature", value=1.0)
max_length = st.number_input("max_length", value=1024)
question = st.text_area("Enter your question", "What are the features of Gemma?")
if st.button("Submit"):
    if not question:
        st.warning("Please enter a question!")
    else:
        with st.spinner("Querying..."):
            answer = answer_question(gemma, float(temperature), int(max_length), question)
            st.write("Answer:")
            st.write(answer)
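# Usage sketch (assuming this file is saved as app.py; the filename is not given in the original):
#   streamlit run app.py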