---
base_model: meta-llama/Llama-3.1-8B-Instruct
library_name: peft
license: openrail
datasets:
- navidmadani/extended_esc
language:
- en
---
## Overview

This is the fine-tuned model for the paper *Steering Conversational Large Language Models for Long Emotional Support Conversations*. Code for training and inference can be found in our GitHub repository.
## Running the model

The scripts for chatting with the model are available in the GitHub repository. The following code shows a sample inference run with the model. You need to import the list of strategies and their descriptions from our GitHub repo as follows:
```python
# Strategy names and descriptions, provided in our GitHub repository
from prompting.llama_prompt import modified_extes_support_strategies

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_sys_msg_with_strategy(strategy):
    """Build the system prompt that steers the model toward a given support strategy."""
    if strategy is None:
        return "You are a helpful, precise, and accurate emotional support expert."
    description = modified_extes_support_strategies.get(strategy, "No description available")
    return (f"You are a helpful and caring AI, which is an expert in emotional support. "
            f"A user has come to you with emotional challenges, distress, or anxiety. "
            f"Use the \"{strategy}\" strategy ({description}) for answering the user. "
            "Make your response short and to the point.")


# Strategy to use for the next assistant turn
cur_strategy = "Clarification"

messages = [
    {'role': 'system', 'content': get_sys_msg_with_strategy(cur_strategy)},
    {'role': 'user', 'content': "Hello! How's it going?"},
    {'role': 'assistant', 'content': 'Hello. How can I assist you today?'},
    {'role': 'user', 'content': "I'm feeling a bit down today."},
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "navidmadani/esconv_sra_llama3_8b"

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = model.to(device)
model.eval()

# Apply the chat template, generate, and keep only the newly generated tokens
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
input_t = torch.LongTensor([input_ids]).to(device)
output = model.generate(input_t, max_new_tokens=512)[:, input_t.shape[1]:]
resp = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
print(resp)
```
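
Because the system prompt is what steers the model toward a strategy, a longer chat can be driven by appending the model's reply, adding the next user message, and rebuilding the system prompt with the strategy you want for the next turn. The snippet below is a minimal sketch of that loop, not the chat script from the repo: it assumes switching strategies by replacing the system message, and the strategy name used here is only an example key that may differ from the keys in `modified_extes_support_strategies`.

```python
# Illustrative sketch: continue the conversation and switch strategy per turn.
# Assumes the variables from the example above (model, tokenizer, messages, device, resp).
messages.append({'role': 'assistant', 'content': resp})
messages.append({'role': 'user', 'content': "I didn't sleep well and I'm worried about work."})

# Pick a (hypothetical) strategy for the next turn and refresh the system message;
# get_sys_msg_with_strategy falls back to a generic description for unknown keys.
next_strategy = "Reflective Statements"
messages[0]['content'] = get_sys_msg_with_strategy(next_strategy)

input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
input_t = torch.LongTensor([input_ids]).to(device)
output = model.generate(input_t, max_new_tokens=512)[:, input_t.shape[1]:]
print(tokenizer.batch_decode(output, skip_special_tokens=True)[0])
```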