Spaces:
Running
Running
File size: 5,813 Bytes
d3906fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import os
import gradio as gr
from huggingface_hub import hf_hub_download
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import LlamaCpp
REPO_ID = "ryota39/gemma-2-2b-jpn-it-q8"
FILENAME = "gemma-2-2b-jpn-it-Q8_0.gguf"
def get_model_path():
return hf_hub_download(
repo_id=REPO_ID,
filename=FILENAME,
repo_type="model",
)
GGUF_MODEL_PATH = get_model_path()
VECTOR_DB_PATH = "./vectorstore/ruri-large"
EMBEDDING_MODEL = "cl-nagoya/ruri-large"
class RAGSystem:
def __init__(self):
self.vectorstore = None
self.qa_chain = None
self.setup_models()
def setup_models(self):
self.embeddings = HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL,
model_kwargs={"device": "cpu"},
)
try:
self.load_vectorstore()
except Exception as e:
print(f"ベクトルDBの読み込みに失敗しました: {str(e)}")
try:
self.llm = LlamaCpp(
model_path=GGUF_MODEL_PATH,
temperature=0.7,
max_tokens=512,
n_ctx=2048, # コンテキスト長
n_threads=8, # 使用するCPUスレッド数
n_gpu_layers=-1, # 可能であればGPUレイヤーを全て使用
verbose=False,
streaming=True,
model_kwargs={"f16_kv": True},
)
if self.vectorstore:
self.setup_qa_chain()
except Exception as e:
print(f"LLMの読み込みに失敗しました: {str(e)}")
def load_vectorstore(self):
if os.path.exists(VECTOR_DB_PATH):
self.vectorstore = FAISS.load_local(
VECTOR_DB_PATH,
self.embeddings,
allow_dangerous_deserialization=True,
)
if self.llm:
self.setup_qa_chain()
return True
return False
def setup_qa_chain(self):
if self.vectorstore and self.llm:
self.qa_chain = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff",
retriever=self.vectorstore.as_retriever(search_kwargs={"k": 3}),
)
return True
return False
def answer_question_stream(self, question):
if not self.qa_chain:
if not self.vectorstore:
yield "ベクトルDBが読み込まれていません。"
return
if not self.llm:
yield "LLMモデルが読み込まれていません。"
return
yield "QAチェーンの初期化に失敗しました。"
return
try:
docs = self.vectorstore.similarity_search(question, k=3)
context = "\n\n".join([doc.page_content for doc in docs])
prompt = f"""与えられた文書を用いて、質問に対する適切な応答を書きなさい。
文書: {context}
質問: {question}
応答: """
response = ""
for chunk in self.llm._stream(prompt):
if isinstance(chunk, str):
response += chunk
else:
response += chunk.text
yield response
except Exception as e:
yield f"回答生成中にエラーが発生しました: {str(e)}"
def get_system_status(self):
status = list()
if os.path.exists(GGUF_MODEL_PATH):
model_size = os.path.getsize(GGUF_MODEL_PATH) / (1024 * 1024 * 1024)
status.append(
f"✅ LLMモデル: {os.path.basename(GGUF_MODEL_PATH)} ({model_size:.2f} GB)"
)
else:
status.append(f"❌ LLMモデル: {GGUF_MODEL_PATH} が見つかりません")
if os.path.exists(VECTOR_DB_PATH):
status.append(f"✅ ベクトルDB: {VECTOR_DB_PATH}")
else:
status.append(f"❌ ベクトルDB: {VECTOR_DB_PATH} が見つかりません")
status.append(f"✅ 埋め込みモデル: {EMBEDDING_MODEL}")
if self.qa_chain:
status.append("✅ RAGシステム: 準備完了")
else:
status.append("❌ RAGシステム: 初期化されていません")
return "\n".join(status)
rag_system = RAGSystem()
with gr.Blocks(title="RAGデモアプリ") as demo:
gr.Markdown("# 🎇 Sake RAG デモアプリ")
gr.Markdown("醸造協会誌5年分のデータをベクトルDBとして保持した2B級の小型モデルです")
with gr.Row():
with gr.Column(scale=1):
refresh_button = gr.Button("システム状態を更新", variant="secondary")
status_output = gr.Textbox(
label="システム状態",
value=rag_system.get_system_status(),
interactive=False,
lines=5,
)
with gr.Column(scale=2):
question_input = gr.Textbox(
label="質問を入力してください",
placeholder="質問を入力してください",
lines=2,
)
submit_button = gr.Button("質問する", variant="primary")
answer_output = gr.Textbox(label="回答", interactive=False, lines=10)
refresh_button.click(
fn=rag_system.get_system_status,
inputs=[],
outputs=[status_output],
)
submit_button.click(
fn=rag_system.answer_question_stream,
inputs=[question_input],
outputs=[answer_output],
)
if __name__ == "__main__":
demo.launch()
|