# gunjanjoshi's picture
# Added accelerate
# c3f1f3a
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel
import gradio as gr
# --- Model setup -------------------------------------------------------------
# Identifier of the PEFT adapter; its config names the base checkpoint to load.
PEFT_MODEL = "gunjanjoshi/llama2-7b-sharded-bf16-finetuned-mental-health-conversational"
config = PeftConfig.from_pretrained(PEFT_MODEL)
_base_checkpoint = config.base_model_name_or_path

# The tokenizer comes from the base checkpoint; the EOS token is reused as the
# padding token.
peft_tokenizer = AutoTokenizer.from_pretrained(_base_checkpoint)
peft_tokenizer.pad_token = peft_tokenizer.eos_token

# Load the base causal LM, then wrap it with the adapter weights.
_base_load_options = dict(
    return_dict=True,
    device_map="cpu",  # Ensure this is set to CPU
    trust_remote_code=True,
)
peft_base_model = AutoModelForCausalLM.from_pretrained(
    _base_checkpoint, **_base_load_options
)
peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)
# System prompt that generate_response places between the <<SYS>> and <</SYS>>
# markers of every request. It is runtime text sent to the model, so any edit
# changes what the model sees.
# NOTE(review): "helpful and and truthful" contains a duplicated word. It is
# left untouched here because the adapter may have been trained with this
# exact wording -- confirm before correcting.
system_message = """You are a helpful and and truthful psychology and psychotherapy assistant. Your primary role is to provide empathetic, understanding, and non-judgmental responses to users seeking emotional and psychological support.
Always respond with empathy and demonstrate active listening; try to focus on the user. Your responses should reflect that you understand the user's feelings and concerns. If a user expresses thoughts of self-harm, suicide, or harm to others, prioritize their safety.
Encourage them to seek immediate professional help and provide emergency contact numbers when appropriate. You are not a licensed medical professional. Do not diagnose or prescribe treatments.
Instead, encourage users to consult with a licensed therapist or medical professional for specific advice. Avoid taking sides or expressing personal opinions. Your role is to provide a safe space for users to share and reflect.
Remember, your goal is to provide a supportive and understanding environment for users to share their feelings and concerns. Always prioritize their well-being and safety."""
def generate_response(user_input):
    """Generate the assistant's reply to a single user message.

    The message is wrapped, together with the module-level ``system_message``,
    in an ``[INST]`` / ``<<SYS>>`` prompt, tokenized (truncated to 1024
    tokens) and passed to ``peft_model.generate`` with nucleus sampling.

    Args:
        user_input: Text entered by the user.

    Returns:
        Only the newly generated text (the prompt is excluded), decoded with
        special tokens removed and surrounding whitespace stripped.
    """
    formatted = f"<s>[INST] <<SYS>>{system_message}<</SYS>>{user_input} [/INST]"
    encoded = peft_tokenizer(
        formatted, return_tensors="pt", truncation=True, max_length=1024
    )
    input_ids = encoded["input_ids"]

    outputs = peft_model.generate(
        input_ids=input_ids,
        # Pass the mask and pad id explicitly: the pad token is set to EOS at
        # module level, so generate() cannot infer the mask reliably.
        attention_mask=encoded["attention_mask"],
        pad_token_id=peft_tokenizer.pad_token_id,
        do_sample=True,
        top_p=0.9,
        temperature=0.95,
        max_length=2048,
    )

    # Drop the prompt at token level. Slicing the decoded string by
    # len(formatted) is unreliable: decoding skips special tokens and the
    # prompt may have been truncated, so character offsets do not line up.
    prompt_length = input_ids.shape[1]
    reply_ids = outputs[0][prompt_length:]
    return peft_tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
# Gradio UI: a title, an input box next to the reply area, and a submit button.
# NOTE(review): nesting was reconstructed from flattened source; confirm that
# the button is meant to sit below the row rather than inside it.
with gr.Blocks() as demo:
    gr.Markdown("# Mental Health Chatbot")

    with gr.Row():
        user_input = gr.Textbox(label="Input:", lines=5)
        # The value returned by generate_response is displayed through an
        # HTML component.
        response_output = gr.HTML(label="Assistant")

    submit_button = gr.Button("Submit")
    # On click, send the textbox contents to generate_response and show the
    # returned text in the output component.
    submit_button.click(
        generate_response,
        inputs=[user_input],
        outputs=[response_output],
    )

# Start the Gradio app.
demo.launch()