|
import gradio as gr |
|
from datasets import load_dataset |
|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
import pandas as pd |
|
import os |
|
|
|
# WARNING(review): blanking CURL_CA_BUNDLE is a common workaround for TLS errors
# when downloading from the Hugging Face Hub. With some `requests` versions it
# disables certificate verification for every HTTP request this process makes.
# Confirm it is still needed before deploying.
os.environ.update(CURL_CA_BUNDLE="")
|
|
|
|
|
# Retrieval corpus: this script later reads its "subject" and "details" columns.
issues_dataset = load_dataset(path="gvozdev/subspace-info-v2", split="train")
|
|
|
|
|
# Sentence-embedding checkpoint used to encode both the indexed subjects and the
# user's questions.
model_ckpt = "sentence-transformers/all-MiniLM-L12-v1"

tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

# `trust_remote_code` is deliberately not enabled. It would allow the Hub repo to
# execute arbitrary Python at load time. This checkpoint is presumably a stock
# BERT-style model with no custom code. NOTE(review): confirm if the checkpoint changes.
model = AutoModel.from_pretrained(model_ckpt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The encoder is pinned to CPU. get_embeddings moves its tokenized inputs onto
# this same device before calling the model.
device = torch.device("cpu")
model.to(device)
|
|
|
|
|
|
|
def cls_pooling(model_output):
    """Pool a transformer output by keeping the hidden state of the first token.

    Args:
        model_output: Object exposing ``last_hidden_state``. It is presumably
            shaped (batch, seq_len, hidden) — TODO confirm.

    Returns:
        ``last_hidden_state[:, 0]``: position 0 along axis 1 for every row.

    NOTE(review): the sentence-transformers/all-MiniLM-* model cards pool with an
    attention-masked mean rather than the first token. Confirm this is intended.
    """
    hidden_states = model_output.last_hidden_state
    first_token_states = hidden_states[:, 0]
    return first_token_states
|
|
|
|
|
|
|
|
|
def _mean_pooling(model_output, attention_mask):
    """Average token embeddings over positions where *attention_mask* is non-zero."""
    token_embeddings = model_output.last_hidden_state
    # Add a trailing axis so the per-token mask broadcasts over the hidden dimension.
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    # The clamp guards against division by zero for a row with no unmasked tokens.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def get_embeddings(text_list):
    """Embed a string or a list of strings with the module-level tokenizer and model.

    Inputs are padded to the longest item and truncated to the model's maximum
    length. The token states are then mean-pooled using the attention mask, which
    is the pooling documented for the sentence-transformers/all-MiniLM checkpoints.

    Args:
        text_list: A single string or a list of strings.

    Returns:
        A torch tensor on ``device`` with one embedding row per input text. It is
        computed under ``torch.no_grad()``, so it carries no autograd graph.
    """
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        model_output = model(**encoded_input)
    return _mean_pooling(model_output, encoded_input["attention_mask"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Embed every subject line and attach the vectors as an "embeddings" column.
# A batched map runs one forward pass per batch of subjects instead of one per row.
# The modest batch size bounds the padded tensor size on CPU.
embeddings_dataset = issues_dataset.map(
    lambda batch: {
        "embeddings": get_embeddings(batch["subject"]).detach().cpu().numpy()
    },
    batched=True,
    batch_size=32,
)

# FAISS index over the embeddings, queried by answer_question.
embeddings_dataset.add_faiss_index(column="embeddings")
|
|
|
|
|
|
|
def answer_question(question):
    """Return the ``details`` text of the row whose embedding is nearest to *question*.

    The question is embedded with ``get_embeddings`` and looked up in the FAISS
    index on the ``embeddings`` column of ``embeddings_dataset``. Only the single
    nearest row is used (``k=1``).

    Args:
        question: The user's free-text question.

    Returns:
        The ``details`` value of the nearest row.
    """
    question_embedding = get_embeddings([question]).detach().cpu().numpy()
    # The retrieved examples come back as a mapping from column name to a list of
    # values, so the single hit can be read directly without building a DataFrame.
    _scores, samples = embeddings_dataset.get_nearest_examples(
        "embeddings", question_embedding, k=1
    )
    return samples["details"][0]
|
|
|
|
|
|
|
title = "Subspace Docs bot"

# HTML passed to gr.Interface as its description. The adjacent string literals
# are joined at compile time.
description = (
    '<p style="text-align: center;">This is a bot trained on Subspace Network documentation '
    'to answer the most common questions about the project</p>'
)
|
|
|
|
|
def chat(message, history):
    """Answer *message* and record the exchange in the conversation history.

    Args:
        message: The user's text from the "text" input.
        history: The list of ``(message, response)`` tuples kept in the Gradio
            "state" component. Any falsy value (presumably ``None`` on the first
            turn) starts a new list.

    Returns:
        The updated history twice: once for the "chatbot" output and once to be
        stored back into the "state" output.
    """
    conversation = history if history else []
    reply = answer_question(message)
    # Appends to the incoming list in place when one was provided.
    conversation.append((message, reply))
    return conversation, conversation
|
|
|
|
|
# The text box and the "state" component feed `chat`. Its two return values fill
# the chatbot display and the updated state.
iface = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    examples=["What is Subspace Network?", "Do you have a token?", "System requirements"],
    title=title,
    description=description,
    theme="Monochrome",
    allow_flagging="never",
)

# Serve the app without creating a public share link.
iface.launch(share=False)
|
|