import string

import numpy as np
import onnxruntime as ort

# Transformers, HuggingFace Hub, and Gradio
from transformers import AutoTokenizer
import gradio as gr
from huggingface_hub import InferenceClient

# ------------------------------------------------
# Turn Detector Configuration
# ------------------------------------------------
HG_MODEL = "livekit/turn-detector"       # or your HF model repo
ONNX_FILENAME = "model_quantized.onnx"   # path to your ONNX file
MAX_HISTORY_TOKENS = 512
PUNCS = string.punctuation.replace("'", "")

# ------------------------------------------------
# Utility functions
# ------------------------------------------------


def softmax(logits: np.ndarray) -> np.ndarray:
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / np.sum(exp_logits)
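

# For example, softmax(np.array([1.0, 2.0, 3.0])) is roughly
# [0.090, 0.245, 0.665]; subtracting the max first keeps np.exp from
# overflowing on large logits without changing the result.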


def normalize_text(text: str) -> str:
    """Lowercase, strip punctuation (except single quotes), and collapse whitespace."""
    def strip_puncs(text_in):
        return text_in.translate(str.maketrans("", "", PUNCS))
    return " ".join(strip_puncs(text).lower().split())


def calculate_eou(chat_ctx, session, tokenizer) -> float:
    """
    Given a conversation context (list of dicts with 'role' and 'content'),
    returns the probability that the user is finished speaking.
    """
    # Collect normalized messages from 'user' or 'assistant' roles
    normalized_ctx = []
    for msg in chat_ctx:
        if msg["role"] in ("user", "assistant"):
            content = normalize_text(msg["content"])
            if content:
                normalized_ctx.append(content)

    # Join them into one input string
    text = " ".join(normalized_ctx)
    inputs = tokenizer(
        text,
        return_tensors="np",
        truncation=True,
        max_length=MAX_HISTORY_TOKENS,
    )

    input_ids = np.array(inputs["input_ids"], dtype=np.int64)
    # Run inference
    outputs = session.run(["logits"], {"input_ids": input_ids})
    logits = outputs[0][0, -1, :]

    # Softmax over logits
    probs = softmax(logits)
    # Probability assigned to the <|im_end|> special token, which marks
    # the end of a completed turn
    eou_token_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[-1]
    return float(probs[eou_token_id])
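

# A minimal sanity check (hypothetical inputs, assuming the session and
# tokenizer loaded below):
#   ctx = [{"role": "user", "content": "what's the weather like in"}]
#   calculate_eou(ctx, onnx_session, turn_detector_tokenizer)  # low: sounds unfinished
#   ctx = [{"role": "user", "content": "what's the weather like in tokyo"}]
#   calculate_eou(ctx, onnx_session, turn_detector_tokenizer)  # higher: sounds complete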


# ------------------------------------------------
# Load ONNX session & tokenizer once
# ------------------------------------------------
print("Loading ONNX model session...")
onnx_session = ort.InferenceSession(
    ONNX_FILENAME, providers=["CPUExecutionProvider"])

print("Loading tokenizer...")
turn_detector_tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)

# ------------------------------------------------
# HF InferenceClient for text generation (example)
# ------------------------------------------------
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
# Adjust above to any other endpoint that suits your use case.

# ------------------------------------------------
# Gradio Chat Handler
# ------------------------------------------------


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """
    Called on each new user message in the ChatInterface.
      - 'message' is the new user input
      - 'history' is a list of (user, assistant) tuples
      - 'system_message' comes from the system Textbox
      - max_tokens, temperature, and top_p come from the Sliders
    """

    # 1) Build a list of messages in the OpenAI-style format:
    #    [{'role': 'system', 'content': ...},
    #     {'role': 'user', 'content': ...}, ...]
    messages = []
    if system_message.strip():
        messages.append({"role": "system", "content": system_message})

    # history is a list of tuples: [(user1, assistant1), (user2, assistant2), ...]
    for user_text, assistant_text in history:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})

    # Append the new user message
    messages.append({"role": "user", "content": message})

    # 2) Calculate the EOU probability over the entire conversation
    eou_prob = calculate_eou(messages, onnx_session, turn_detector_tokenizer)

    # 3) Stream the assistant response token-by-token from the HF endpoint,
    #    prefixed with the EOU probability for inspection.
    response = f"[EOU Probability: {eou_prob:.4f}]\n\n"
    yield response

    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content or ""
        response += token
        yield response


# ------------------------------------------------
# Gradio ChatInterface
# ------------------------------------------------
"""
This ChatInterface will have:
  - A chat box
  - A system message textbox
  - 3 sliders for max_tokens, temperature, and top_p
"""
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly Chatbot.",
            label="System message",
            lines=2
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()
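
# To run locally (assuming this file is saved as app.py and the quantized
# ONNX export sits next to it):
#   pip install onnxruntime transformers gradio huggingface_hub numpy
#   python app.py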