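"""Thesizer: a retrieval-augmented HAMK thesis assistant.

Loads the google/gemma-2-2b-it model through a Hugging Face pipeline,
retrieves supporting material from a local vector store, and serves a
streaming chat UI with Gradio.
"""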
from asyncio import sleep
from typing import TypedDict

import gradio as gr
import torch
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from vector_store import get_document_database


class ChatMessage(TypedDict):
    """A single chat message in Gradio's "messages" format."""
    role: str
    metadata: dict
    content: str
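
# For illustration, a ChatMessage might look like (hypothetical values):
# {"role": "user", "metadata": {}, "content": "How do I format citations?"}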


MODEL_NAME = "google/gemma-2-2b-it"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    temperature=0.2,        # low temperature for focused, factual answers
    do_sample=True,
    repetition_penalty=1.1,
    return_full_text=True,  # the echoed prompt is stripped during parsing
    max_new_tokens=400,
)

llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
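
# Optional smoke test (a minimal sketch, not used by the app): as a
# LangChain runnable, the wrapped pipeline can be invoked with a prompt
# string and returns the generated text, e.g.
#     print(llm.invoke("Hello"))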

print("Creating the document database")
db = get_document_database("learning_material/*/*/*")
print("Document database is ready")


def generate_prompt(message_history: list[ChatMessage], max_history=5):
    """Build a ChatPromptTemplate from the system prompt, the retrieved
    context, and the last `max_history` messages."""
    prompt_template = ChatPromptTemplate([
        ("system", """You are 'thesizer', a HAMK thesis assistant.
        You will help the user with the technicalities of writing a thesis
        for HAMK. If you can't find the answer in the context given to you,
        tell the user that you cannot assist with that specific topic.
        You speak both Finnish and English, following the user's language.
        Continue the conversation with a single response from the AI."""),
        ("system", "{context}"),
    ])

    # Seed short conversations with bilingual greetings so the model
    # mirrors the user's language.
    if len(message_history) < 4:
        prompt_template.append(
            # Finnish: "Hi! How can I help you with your thesis?"
            ("assistant", "Hei! Kuinka voin auttaa opinnäytetyösi kanssa?"),
        )
        prompt_template.append(
            ("assistant", "Hello! How can I help you with authoring your thesis?"),
        )

    # Escape curly braces so user-typed text is not parsed as template
    # variables by the f-string prompt template.
    for message in message_history[-max_history:]:
        prompt_template.append(
            (message["role"],
             message["content"].replace("{", "{{").replace("}", "}}"))
        )

    # Prime the model with a marker so the answer can be located in the
    # raw output (return_full_text=True echoes the whole prompt).
    prompt_template.append(
        ("assistant", "<RESPONSE>:")
    )

    return prompt_template
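

# Illustrative use of generate_prompt (hypothetical history); the returned
# template still expects the {context} variable, which the retrieval chain
# below fills in:
#     generate_prompt([{"role": "user", "metadata": {}, "content": "Hei!"}])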


async def generate_answer(message_history: list[ChatMessage]):
    """Retrieve context for the latest user message and generate an answer."""
    n_of_best_results = 4
    retriever = db.as_retriever(
        search_type="similarity", search_kwargs={"k": n_of_best_results})

    print("Generating prompt")
    prompt = generate_prompt(message_history, max_history=5)
    print("Prompt is ready")

    # The retriever's documents are stringified into the {context} system
    # slot; the raw user input passes through to the retriever query.
    retrieval_chain = (
        {"context": retriever, "user_input": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    user_input = message_history[-1]["content"]
    print("Invoking the chain")
    response = retrieval_chain.invoke(user_input)
    print("Response received from invoke")

    print("=====raw response=====")
    print(response)

    # The raw output echoes the whole prompt: keep only the text after the
    # last occurrence of the user's input, then after the <RESPONSE>: marker.
    parsed_answer = response.split(
        str(user_input)
    ).pop().split("<RESPONSE>:", 1).pop().strip()

    print(repr(parsed_answer))

    # Gradio's Chatbot renders Markdown/HTML, so show paragraph breaks as
    # explicit <br> tags.
    return parsed_answer.replace("\n\n", "<br>")


def update_chat(user_message: str, history: list):
    """Append the user's message to the history and clear the input box."""
    return "", history + [{"role": "user", "content": user_message}]


async def handle_conversation(
    history: list[ChatMessage],
    characters_per_second=80
):
    """Stream the assistant's answer into the chat, character by character."""
    bot_message = await generate_answer(history)
    new_message: ChatMessage = {"role": "assistant",
                                "metadata": {"title": None},
                                "content": ""}
    history.append(new_message)
    for character in bot_message:
        history[-1]['content'] += character
        yield history
        await sleep(1 / characters_per_second)
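
# Gradio treats an async generator handler as a streaming update: every
# yielded history re-renders the Chatbot, which produces the typewriter
# effect above.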


def create_interface():
    """Assemble the Gradio Blocks UI and wire the event handlers."""
    with gr.Blocks() as interface:
        gr.Markdown("# 📄 Thesizer: HAMK Thesis Assistant")
        gr.Markdown("Ask for help with authoring your HAMK thesis!")

        gr.Markdown("## Chat interface")

        with gr.Column():
            chatbot = gr.Chatbot(type="messages")

            with gr.Row():
                user_input = gr.Textbox(
                    label="You:",
                    placeholder="Type your message here...",
                    show_label=False
                )
                send_button = gr.Button("Send")

        # First append the user's message instantly (queue=False), then
        # stream the assistant's answer with the generator handler.
        send_button.click(
            fn=update_chat,
            inputs=[user_input, chatbot],
            outputs=[user_input, chatbot],
            queue=False
        ).then(
            fn=handle_conversation,
            inputs=chatbot,
            outputs=chatbot
        )

        # Pressing Enter in the textbox behaves like clicking Send.
        user_input.submit(
            fn=update_chat,
            inputs=[user_input, chatbot],
            outputs=[user_input, chatbot],
            queue=False
        ).then(
            fn=handle_conversation,
            inputs=chatbot,
            outputs=chatbot
        )

    return interface


if __name__ == "__main__":
    create_interface().launch()
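
# launch() serves the UI locally by default; passing share=True would
# create a temporary public Gradio link if needed.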
|
|