Spaces:
Paused
Paused
File size: 2,178 Bytes
6ab28e5 9bc4a6c 6ab28e5 227586c 6ab28e5 9bc4a6c 6ab28e5 227586c 6ab28e5 227586c 6ab28e5 9bc4a6c 6ab28e5 9bc4a6c 6ab28e5 9bc4a6c 6ab28e5 227586c 6ab28e5 227586c 6ab28e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from html import escape
from typing import Any, Iterable, Iterator, Mapping, Optional

import gradio as gr
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import Qdrant
from openai.error import InvalidRequestError
from qdrant_client import QdrantClient

from config import DB_CONFIG
# NOTE(review): not referenced anywhere in the visible code — possibly left
# over from an earlier locally persisted vector store; confirm before removing.
PERSIST_DIR_NAME = "nvdajp-book"
def get_retrieval_qa(temperature: float, option: Optional[str]) -> RetrievalQA:
    """Build a RetrievalQA chain backed by the configured Qdrant collection.

    Args:
        temperature: Sampling temperature passed to the OpenAI LLM. The UI
            slider supplies a float, hence the ``float`` annotation.
        option: Document category to restrict retrieval to. ``None`` or
            ``"All"`` means no category filter is applied.

    Returns:
        A "stuff"-type RetrievalQA chain configured to also return the
        source documents it retrieved.
    """
    embeddings = OpenAIEmbeddings()
    # DB_CONFIG is unpacked as (url, api_key, collection_name).
    db_url, db_api_key, db_collection_name = DB_CONFIG
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)

    if option is None or option == "All":
        retriever = db.as_retriever()
    else:
        # Restrict the similarity search with a metadata filter on "category".
        retriever = db.as_retriever(search_kwargs={"filter": {"category": option}})

    return RetrievalQA.from_chain_type(
        llm=OpenAI(temperature=temperature),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
def get_related_url(metadata: Iterable[Mapping[str, Any]]) -> Iterator[str]:
    """Yield one HTML ``<p>`` link per distinct source URL.

    Args:
        metadata: Metadata mappings of retrieved documents. Each mapping must
            provide a ``"url"`` key; ``"category"`` is read only for the
            first occurrence of each URL.

    Yields:
        HTML snippets in first-seen order, skipping URLs already emitted.
        URL and category are HTML-escaped because they come from document
        metadata and are interpolated into markup (including an attribute).
    """
    seen = set()
    for m in metadata:
        url = m["url"]
        if url in seen:
            continue
        seen.add(url)
        safe_url = escape(str(url))
        safe_category = escape(str(m["category"]))
        yield f'<p>URL: <a href="{safe_url}">{safe_url}</a> (category: {safe_category})</p>'
def main(query: str, option: Optional[str], temperature: float):
    """Answer *query* with a RetrievalQA chain and list its source URLs.

    Args:
        query: User question from the "query" textbox.
        option: Selected document category (``None`` or ``"All"`` = no filter).
        temperature: LLM sampling temperature from the slider.

    Returns:
        An ``(answer_text, html)`` pair for the two Gradio outputs. If the
        OpenAI call raises ``InvalidRequestError``, the answer is a fixed
        Japanese fallback message ("No answer was found; please try asking
        something else") and the HTML output shows the escaped error text.
    """
    qa = get_retrieval_qa(temperature, option)
    try:
        result = qa(query)
    except InvalidRequestError as e:
        # The second output renders HTML, so escape the raw error message.
        return "回答が見つかりませんでした。別な質問をしてみてください", escape(str(e))
    else:
        sources = [doc.metadata for doc in result["source_documents"]]
        related_html = "<div>" + "\n".join(get_related_url(sources)) + "</div>"
        return result["result"], related_html
# Gradio UI: inputs map positionally onto main(query, option, temperature);
# outputs are the answer text and an HTML list of source links.
nvdajp_book_qa = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="query"),
        # Label "絞り込み" = "Filter"; info "ドキュメント制限する?" = "Restrict documents?"
        gr.Radio(
            ["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
            label="絞り込み",
            info="ドキュメント制限する?",
        ),
        gr.Slider(0, 2, label="temperature"),
    ],
    # gr.outputs.* is the deprecated legacy namespace (removed in Gradio 4);
    # gr.HTML is the supported component in both 3.x and 4.x.
    outputs=[gr.Textbox(label="answer"), gr.HTML()],
)
nvdajp_book_qa.launch()
|