File size: 4,048 Bytes
74cfe0f
fc44b6a
af4b5bf
 
 
fc44b6a
6270b8e
a62c597
fc44b6a
74cfe0f
fc44b6a
77f7b3a
af4b5bf
 
fc44b6a
af4b5bf
fc44b6a
77f7b3a
fc44b6a
 
7bb2dd6
af4b5bf
47cf318
af4b5bf
 
 
 
47cf318
af4b5bf
 
 
 
 
47cf318
af4b5bf
 
 
 
 
 
77f7b3a
 
47cf318
 
fc44b6a
77f7b3a
af4b5bf
47cf318
 
77f7b3a
 
 
47cf318
af4b5bf
fc44b6a
 
af4b5bf
 
 
 
 
 
 
 
 
fc44b6a
af4b5bf
 
47cf318
 
7bb2dd6
 
 
fc44b6a
 
 
7bb2dd6
47cf318
 
 
da959de
7bb2dd6
 
 
 
47cf318
fc44b6a
af4b5bf
77f7b3a
47cf318
533370e
 
47cf318
533370e
 
 
 
fc44b6a
 
 
 
 
5af9f68
c5662e3
1fb128b
fc44b6a
 
 
 
 
 
 
 
47cf318
7bb2dd6
47cf318
fc44b6a
 
 
 
 
 
 
 
 
 
 
 
 
af4b5bf
 
47cf318
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import string
from huggingface_hub import InferenceClient
import os

# Hugging Face Inference client used by respond(); the API key is read from
# the HF_TOKEN environment variable (None is passed when it is unset).
client = InferenceClient(api_key=os.environ.get('HF_TOKEN'))

# Model and ONNX setup
HG_MODEL = "livekit/turn-detector"  # Repo id the tokenizer is loaded from
ONNX_FILENAME = "model_quantized.onnx"  # Model path handed to ort.InferenceSession
PUNCS = string.punctuation.replace("'", "")  # ASCII punctuation minus the apostrophe
MAX_HISTORY = 4  # Number of most recent messages fed to the turn detector
MAX_HISTORY_TOKENS = 512  # Tokenizer truncation limit for the formatted history
EOU_THRESHOLD = 0.5  # End-of-utterance cutoff; only referenced by the disabled check in respond()

# Initialize the tokenizer and the ONNX Runtime session (CPU provider only).
# Both are created at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])

# Softmax function
def softmax(logits):
    """Return the softmax of *logits*, normalized over all of its elements.

    The global maximum is subtracted before exponentiating so that large
    logits do not overflow; this does not change the result.
    """
    shifted = logits - np.max(logits)
    weights = np.exp(shifted)
    total = np.sum(weights)
    return weights / total

# Normalize text
def normalize_text(text):
    """Strip PUNCS characters, lower-case, and collapse whitespace runs.

    Every character listed in the module-level PUNCS string is deleted, the
    remainder is lower-cased, and the words are re-joined with single spaces
    (leading/trailing whitespace disappears as a result).
    """
    deletion_table = str.maketrans("", "", PUNCS)
    without_puncs = text.translate(deletion_table)
    words = without_puncs.lower().split()
    return " ".join(words)

# Format chat context
def format_chat_ctx(chat_ctx):
    """Render user/assistant messages as chat-template text for the turn detector.

    Messages whose role is "user" or "assistant" are normalized with
    ``normalize_text``; messages that normalize to an empty string are
    dropped, as are messages with any other role. The remaining messages are
    rendered with the tokenizer's chat template, and the text is cut just
    before the last ``<|im_end|>`` marker so the final utterance is left
    open-ended.

    The input messages are never modified: normalized copies are rendered.

    Args:
        chat_ctx: Iterable of message dicts with "role" and "content" keys.

    Returns:
        The rendered conversation text up to (not including) the last
        ``<|im_end|>`` marker, or the full rendered text if the marker does
        not occur.
    """
    new_chat_ctx = []
    for msg in chat_ctx:
        if msg["role"] not in ("user", "assistant"):
            continue
        content = normalize_text(msg["content"])
        if content:
            # Build a copy rather than assigning into msg: the caller's dicts
            # may be reused afterwards (e.g. sent to the chat model) and must
            # keep their original, un-normalized text.
            new_chat_ctx.append({**msg, "content": content})

    # Render with the chat template (text output, no tokenization yet)
    convo_text = tokenizer.apply_chat_template(
        new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )

    # Remove the EOU token from the current (last) utterance
    ix = convo_text.rfind("<|im_end|>")
    if ix == -1:
        # rfind signals "not found" with -1; slicing with it would silently
        # chop the last character off the text.
        return convo_text
    return convo_text[:ix]

# Calculate EOU probability
def calculate_eou(chat_ctx, session):
    """Return the model's probability that the last utterance has ended.

    The last MAX_HISTORY messages are formatted with ``format_chat_ctx``,
    tokenized (truncated to MAX_HISTORY_TOKENS), and run through *session*.
    The logits at the final sequence position are turned into probabilities
    and the entry for the ``<|im_end|>`` token is returned.

    Args:
        chat_ctx: List of message dicts with "role" and "content" keys.
        session: Object with an onnxruntime-style ``run(output_names, feeds)``
            method exposing a "logits" output and an "input_ids" input.

    Returns:
        The softmax probability of ``<|im_end|>`` at the last position.
    """
    recent_messages = chat_ctx[-MAX_HISTORY:]
    prompt_text = format_chat_ctx(recent_messages)

    # NOTE(review): truncation follows the tokenizer's configured truncation
    # side; if that is the right side, an over-long history would lose its
    # newest tokens -- confirm.
    encoded = tokenizer(
        prompt_text,
        max_length=MAX_HISTORY_TOKENS,
        truncation=True,
        return_tensors="np",
    )
    feeds = {"input_ids": np.array(encoded["input_ids"], dtype=np.int64)}

    all_logits = session.run(["logits"], feeds)[0]
    # First batch row, last sequence position, every vocabulary entry
    # (presumably the output is laid out as batch x sequence x vocab).
    last_position_logits = all_logits[0, -1, :]
    probabilities = softmax(last_position_logits)

    # encode() may emit several ids; the last one is taken as the marker's id.
    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
    return probabilities[eou_token_id]


# Respond function
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens=2048,
    temperature=0.6,
    top_p=0.9,
):
    """Stream the assistant's reply to *message* as a growing string.

    Builds the chat request from an optional system prompt (CHARACTER_DESC
    environment variable), the most recent history entries, and the new user
    message, then streams the completion from the model named by the MODEL_ID
    environment variable.

    Args:
        message: The new user message.
        history: Previous (user, assistant) pairs; falsy entries are skipped.
        max_tokens: Generation length limit. None means "use the default".
        temperature: Sampling temperature. None means "use the default".
        top_p: Nucleus-sampling cutoff. None means "use the default".

    Yields:
        The accumulated response text after each streamed chunk.
    """
    # Callers may pass None for controls that are not wired up in the UI;
    # fall back to the values this function previously hard-coded.
    if max_tokens is None:
        max_tokens = 2048
    if temperature is None:
        temperature = 0.6
    if top_p is None:
        top_p = 0.9

    messages = []
    # Only send a system message when a character description is configured;
    # a message whose content is None is not a valid chat message.
    character_desc = os.environ.get("CHARACTER_DESC")
    if character_desc:
        messages.append({"role": "system", "content": character_desc})

    # Keep the last 20 history entries (user/assistant pairs)
    for val in history[-20:]:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    # Add the new user message to the context
    messages.append({"role": "user", "content": message})

    # NOTE: end-of-utterance gating is currently disabled. When enabled it
    # computed calculate_eou(messages, onnx_session) and, if the probability
    # was below EOU_THRESHOLD, yielded
    # "[Waiting for user to continue input...]" and returned early.

    stream = client.chat.completions.create(
        model=os.environ.get('MODEL_ID'),
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True,
    )

    bot_response = ""
    for chunk in stream:
        # Some stream chunks carry no choices or no text delta (e.g. a
        # role-only or final chunk); skip them instead of failing on None.
        if not chunk.choices:
            continue
        piece = chunk.choices[0].delta.content
        if piece is None:
            continue
        bot_response += piece
        yield bot_response
    


# Gradio interface
# Chat UI whose handler is respond(). The additional_inputs list below is
# commented out, so no controls are declared here for respond's max_tokens,
# temperature and top_p parameters.
demo = gr.ChatInterface(
    respond,
    # additional_inputs=[
    #     # Commented out to disable user modification of the system message
    #     # gr.Textbox(value="You are an assistant.", label="System message"),
    #     gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Max new tokens"),
    #     gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    #     gr.Slider(
    #         minimum=0.1,
    #         maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
    #     ),
    # ],
)

# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()