import gradio as gr

from unsloth import FastLanguageModel
max_seq_length = 2048  # Any length works; unsloth handles RoPE scaling internally.
dtype = None  # None for auto detection; float16 for Tesla T4/V100, bfloat16 for Ampere+.
load_in_4bit = True  # Use 4-bit quantization to reduce memory usage. Can be False.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "sayeed99/meta-llama3-8b-xtherapy-bnb-4bit",  # the fine-tuned model from training
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
FastLanguageModel.for_inference(model)  # Enable unsloth's native 2x faster inference
# Llama 3 system prompt that prefixes every conversation.
formatted_string = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
    "You are Anna, a helpful AI assistant for mental therapy assistance developed by a team of "
    "developers at xnetics. If you do not know the user's name, start by asking for it. "
    "If you do not know details about the user, ask for them."
)
# Format one chat turn into the Llama 3 template. <|eot_id|> closes the
# previous turn before the new role header is opened.
def format_chat_data(data):
    role = "assistant" if data["role"] == "assistant" else "user"
    return "<|eot_id|><|start_header_id|>" + role + "<|end_header_id|>" + data["content"]
# Build the full prompt: the system block followed by every turn in the chat
# history. The trailing <|eot_id|> and assistant header that cue the model's
# reply are appended later, in handle_conversation.
def formatting_prompts_funcV2(examples):
    text = formatted_string
    for conversation in examples:
        text += format_chat_data(conversation)
    return text
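# For a hypothetical two-turn history such as
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello! What's your name?"}]
# the assembled prompt is the system block followed by
#   <|eot_id|><|start_header_id|>user<|end_header_id|>Hi
#   <|eot_id|><|start_header_id|>assistant<|end_header_id|>Hello! What's your name?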
def get_last_assistant_message(text):
    # Split on the assistant header; the final part is the newest assistant reply.
    parts = text.split('<|start_header_id|>assistant<|end_header_id|>')
    last_message = parts[-1].strip()
    return cleanup(last_message)
def cleanup(text):
    # Strip the trailing <|eot_id|> token, if present.
    if text.endswith('<|eot_id|>'):
        return text[:-len('<|eot_id|>')]
    return text
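# Round-trip example (hypothetical decoded output): for
#   "...<|start_header_id|>assistant<|end_header_id|>Hi Sam!<|eot_id|>"
# get_last_assistant_message returns "Hi Sam!", i.e. everything after the
# final assistant header, with the trailing <|eot_id|> removed by cleanup.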
# Run one full generation pass over the conversation history and return the
# model's newest reply.
def handle_conversation(user_input):
    historyPrompt = formatting_prompts_funcV2(user_input)
    # Cue the model to answer by opening a fresh assistant header.
    historyPrompt += "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    inputs = tokenizer([historyPrompt], return_tensors="pt").to("cuda")  # assumes a CUDA GPU
    outputs = model.generate(**inputs, max_new_tokens=512, use_cache=True)
    decoded_outputs = tokenizer.batch_decode(outputs)[0]
    # The decoded text contains the whole prompt; keep only the new reply.
    return get_last_assistant_message(decoded_outputs)
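# Hypothetical direct call, bypassing the Gradio UI:
#   handle_conversation([{"role": "user", "content": "Hello"}])
# returns just the newly generated assistant reply as a string.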
def complete(messages):
    return handle_conversation(messages)
# Gradio ChatInterface callback. history arrives as a list of
# (user, assistant) tuples; convert it to OpenAI-style messages first.
def predict(message, history):
    history_openai_format = []
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})
    response = complete(history_openai_format)
    print(response)
    # Generation is not streamed; fake streaming by yielding the finished
    # reply one character at a time.
    partial_message = ""
    for chunk in response:
        if chunk is not None:
            partial_message += chunk
            yield partial_message
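# A rough sketch of true token-level streaming, as an alternative to the
# character loop above. It assumes transformers' TextIteratorStreamer plus a
# background thread for generate(); the name predict_streaming and its wiring
# are illustrative, not part of the original app.
from threading import Thread
from transformers import TextIteratorStreamer

def predict_streaming(message, history):
    history_openai_format = []
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})
    prompt = formatting_prompts_funcV2(history_openai_format)
    prompt += "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() runs in a worker thread while this generator yields tokens.
    Thread(target=model.generate,
           kwargs=dict(**inputs, streamer=streamer, max_new_tokens=512, use_cache=True)).start()
    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message
# To try it, pass predict_streaming instead of predict to gr.ChatInterface below.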
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(predict)
if __name__ == "__main__":
    demo.launch()