File size: 2,597 Bytes
4b2da06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76c8687
 
 
 
 
 
4b2da06
 
 
 
 
 
76c8687
4b2da06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftConfig, PeftModel
import warnings

# Silence every Python warning process-wide (this includes any `warnings`
# raised later in this module while the model and adapter are loaded).
warnings.filterwarnings("ignore")


# Hugging Face Hub repo id of the PEFT adapter to load.
PEFT_MODEL = "givyboy/phi-2-finetuned-mental-health-conversational"

# Instruction placed at the top of every prompt. The two-space indentation on
# the continuation lines is part of the string and is sent to the model as-is.
SYSTEM_PROMPT = """Answer the following question truthfully.
  If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
  If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'."""


def USER_PROMPT(x):
    """Render the new user turn, leaving the assistant slot open for generation."""
    return f"<HUMAN>: {x}\n<ASSISTANT>: "


def ADD_RESPONSE(x, y):
    """Render one completed (user, assistant) exchange from the chat history."""
    return f"<HUMAN>: {x}\n<ASSISTANT>: {y}"

# Optional 4-bit quantization, currently disabled. To enable it, uncomment this
# block together with `quantization_config=bnb_config` below (needs the
# bitsandbytes package; presumably a CUDA GPU as well — TODO confirm for the
# deployment target).
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.float16,
# )

# Read the adapter's config to find out which base model it was trained on.
config = PeftConfig.from_pretrained(PEFT_MODEL)

# Load the base causal LM named by the adapter config.
# NOTE(review): no torch_dtype is passed here, so the weights load in the
# library's default dtype (typically float32) — see the pipeline note below.
peft_base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    # quantization_config=bnb_config,
    device_map="auto",
    # Allows executing custom modeling code shipped in the Hub repo.
    trust_remote_code=True,
)

# Apply the PEFT adapter weights from PEFT_MODEL on top of the base model.
peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)

# The tokenizer comes from the base model. EOS is reused as the pad token,
# presumably because the base tokenizer defines none — TODO confirm.
peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
peft_tokenizer.pad_token = peft_tokenizer.eos_token

# Text-generation pipeline around the already-loaded adapter model.
# NOTE(review): because `model` is an instantiated object rather than a repo
# id, pipeline() does not reload it, so torch_dtype / device_map /
# trust_remote_code here most likely have no effect. If bfloat16 was intended,
# pass torch_dtype to from_pretrained above — confirm against the installed
# transformers version.
pipeline = transformers.pipeline(
    "text-generation",
    model=peft_model,
    tokenizer=peft_tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)


def format_message(message: str, history: list[list[str]], memory_limit: int = 3) -> str:
    """Build the full prompt sent to the model for one chat turn.

    The prompt has three parts, joined by newlines:

    1. ``SYSTEM_PROMPT``.
    2. Up to ``memory_limit`` of the most recent completed exchanges, each
       rendered with ``ADD_RESPONSE``.
    3. The new user message, rendered with ``USER_PROMPT``.

    Args:
        message: The new user message.
        history: Previous exchanges, oldest first, each a ``(user, assistant)``
            pair (list or tuple). This is presumably the pair-style history that
            Gradio's ``ChatInterface`` passes — confirm if the interface type
            changes.
        memory_limit: Maximum number of past exchanges to include. Values
            ``<= 0`` include no history.

    Returns:
        The newline-joined prompt string.
    """
    # Guard memory_limit <= 0 explicitly: ``history[-0:]`` is the *whole* list,
    # so plain negative slicing would include all history instead of none.
    recent = history[-memory_limit:] if memory_limit > 0 else []

    parts = [SYSTEM_PROMPT]
    parts.extend(ADD_RESPONSE(msg, ans) for msg, ans in recent)
    parts.append(USER_PROMPT(message))
    return "\n".join(parts)


def get_model_response(message: str, history: list[list[str]]) -> str:
    """Generate the assistant's reply to ``message`` given the chat ``history``.

    Used as the Gradio ``ChatInterface`` callback.

    Args:
        message: The new user message.
        history: Previous ``(user, assistant)`` exchanges, oldest first.

    Returns:
        The model's reply, stripped of surrounding whitespace. The echoed
        prompt and any follow-up turns the model generated on its own are
        removed.
    """
    formatted_message = format_message(message, history)
    sequences = pipeline(
        formatted_message,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=peft_tokenizer.eos_token_id,
        # NOTE(review): max_length counts prompt tokens as well, so long
        # histories leave less room for the reply; max_new_tokens may be
        # what was intended.
        max_length=600,
        truncation=True,
    )[0]
    generated_text = sequences["generated_text"]
    print(generated_text)

    # By default the text-generation pipeline returns prompt + completion.
    # Remove the prompt by prefix instead of splitting on the last
    # "<ASSISTANT>:" marker. That marker can also appear in text the model
    # generated, or in the user's own message.
    if generated_text.startswith(formatted_message):
        reply = generated_text[len(formatted_message):]
    else:
        # Fallback when the prompt is not echoed verbatim: old heuristic.
        reply = generated_text.split("<ASSISTANT>:")[-1]

    # The model may keep writing further "<HUMAN>: ... <ASSISTANT>: ..." turns.
    # Only the text before the first new human turn is its actual reply.
    reply = reply.split("<HUMAN>:", 1)[0]
    return reply.strip()


# Serve get_model_response as a Gradio chat UI and start it when this module runs.
demo = gr.ChatInterface(fn=get_model_response)
demo.launch()