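"""Gradio chat demo that combines LiveKit's turn-detector (a quantized ONNX
model) for end-of-utterance (EOU) detection with a Qwen-based chat model
served via huggingface_hub's InferenceClient. When the EOU probability of an
incoming message falls below EOU_THRESHOLD, the app asks the user to keep
typing instead of generating a reply."""
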
import gradio as gr
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import string
from huggingface_hub import InferenceClient

# Initialize Qwen client
qwen_client = InferenceClient("EVA-UNIT-01/EVA-Qwen2.5-1.5B-v0.0")

# Model and ONNX setup
HG_MODEL = "livekit/turn-detector"
ONNX_FILENAME = "model_quantized.onnx"
PUNCS = string.punctuation.replace("'", "")
MAX_HISTORY = 4  # number of most recent messages fed to the EOU model
MAX_HISTORY_TOKENS = 512
EOU_THRESHOLD = 0.5  # below this, treat the utterance as unfinished

# Load the tokenizer and the quantized ONNX turn-detector (CPU inference)
tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])

# Numerically stable softmax: subtracting the max logit avoids overflow in exp()
def softmax(logits):
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / np.sum(exp_logits)
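
# Illustrative check (not executed at runtime):
#   softmax(np.array([1.0, 2.0, 3.0])) ≈ [0.090, 0.245, 0.665]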

# Normalize text: strip punctuation (apostrophes kept), lowercase, collapse whitespace
def normalize_text(text):
    def strip_puncs(text):
        return text.translate(str.maketrans("", "", PUNCS))
    return " ".join(strip_puncs(text).lower().split())
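
# e.g. normalize_text("Hello,  World!") -> "hello world", while apostrophes
# survive: normalize_text("Don't stop") -> "don't stop"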

# Format chat context
def format_chat_ctx(chat_ctx):
    new_chat_ctx = []
    for msg in chat_ctx:
        if msg["role"] in ("user", "assistant"):
            content = normalize_text(msg["content"])
            if content:
                # Append a copy rather than mutating the caller's dict
                new_chat_ctx.append({"role": msg["role"], "content": content})

    # Tokenize with chat template
    convo_text = tokenizer.apply_chat_template(
        new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )

    # Strip the trailing <|im_end|> so the model can predict whether it belongs here
    ix = convo_text.rfind("<|im_end|>")
    return convo_text[:ix] if ix != -1 else convo_text
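
# With the turn-detector's Qwen-style template, the result is expected to look
# roughly like this (an assumption; the exact layout depends on the tokenizer):
#   <|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\nhi<|im_end|>\n<|im_start|>user\nare you
# i.e. the final turn's closing <|im_end|> is gone, so the classifier can
# predict whether it belongs there.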

# Calculate EOU probability
def calculate_eou(chat_ctx, session):
    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])  # last MAX_HISTORY messages
    inputs = tokenizer(
        formatted_text,
        return_tensors="np",
        truncation=True,
        max_length=MAX_HISTORY_TOKENS,
    )
    input_ids = np.array(inputs["input_ids"], dtype=np.int64)
    outputs = session.run(["logits"], {"input_ids": input_ids})
    logits = outputs[0][0, -1, :]
    probs = softmax(logits)
    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
    return probs[eou_token_id]
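
# Illustrative usage with a hypothetical transcript: a clearly unfinished turn,
#   calculate_eou([{"role": "user", "content": "so what I was thinking was"}], onnx_session)
# should score well below EOU_THRESHOLD, while a complete question should score above it.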

# Read the character's system prompt from file
with open("character/herta.txt", "r", encoding="utf-8") as f:
    system_message = f.read()

# Respond function. Since the additional_inputs sliders below are commented out,
# Gradio calls respond(message, history) only, so the sampling parameters need
# defaults (mirroring the disabled sliders' values).
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens=256,
    temperature=0.7,
    top_p=0.95,
):
    # Build the context: system prompt plus the last 10 user-assistant pairs
    messages = [{"role": "system", "content": system_message}]

    for val in history[-10:]:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    # Add the new user message to the context
    messages.append({"role": "user", "content": message})

    # Calculate EOU probability
    eou_prob = calculate_eou(messages, onnx_session)
    print(f"EOU Probability: {eou_prob}")  # Debug output

    # If EOU is below the threshold, ask for more input
    if eou_prob < EOU_THRESHOLD:
        yield "[Waiting for user to continue input...]"
        return

    # Generate response with Qwen, streaming partial output back to Gradio
    response = ""
    for chunk in qwen_client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # guard against empty/None deltas in the stream
            response += token
            yield response
    
    print(f"Generated response: {response}")


# Gradio interface
demo = gr.ChatInterface(
    respond,
    # additional_inputs=[
    #     # Commented out to disable user modification of the system message
    #     # gr.Textbox(value="You are an assistant.", label="System message"),
    #     gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Max new tokens"),
    #     gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    #     gr.Slider(
    #         minimum=0.1, maximum=1.0, value=0.95, step=0.05,
    #         label="Top-p (nucleus sampling)",
    #     ),
    # ],
    theme=gr.themes.Default().set(
        button_primary_background_fill="#FF0000",
        button_primary_background_fill_dark="#AAAAAA",
        button_primary_border_color="*button_primary_background_fill",
        button_primary_border_color_dark="*button_primary_background_fill_dark",
    ),
)

if __name__ == "__main__":
    demo.launch()