File size: 3,468 Bytes
74cfe0f
77f7b3a
af4b5bf
 
 
 
6270b8e
a62c597
7bb2dd6
4f81850
74cfe0f
77f7b3a
 
af4b5bf
 
8e933e2
af4b5bf
8e933e2
77f7b3a
 
7bb2dd6
 
af4b5bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77f7b3a
 
 
 
af4b5bf
77f7b3a
 
 
af4b5bf
8e933e2
af4b5bf
 
 
 
 
 
 
 
 
 
 
 
 
7bb2dd6
 
 
bc6e181
708434b
8a705ab
7bb2dd6
8a705ab
8e933e2
 
7bb2dd6
 
 
 
8e933e2
af4b5bf
77f7b3a
7bb2dd6
 
8e933e2
af4b5bf
77f7b3a
708434b
 
 
4f81850
4a53428
7bb2dd6
 
 
 
bc6e181
708434b
7bb2dd6
708434b
4f81850
bc6e181
77f7b3a
708434b
8e933e2
af4b5bf
 
8a705ab
8e933e2
 
 
 
af4b5bf
 
 
8a705ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import string
from huggingface_hub import InferenceClient
import os

# Client for the hosted chat model used by respond(). The token must be passed
# by keyword: InferenceClient's first positional parameter is `model`, so a
# positional token would be taken as the default model id (and not sent as auth).
qwen_client = InferenceClient(token=os.environ.get("HF_TOKEN"))

# Model and ONNX setup
# Hub id whose tokenizer (and chat template) is used for end-of-utterance (EOU) scoring.
HG_MODEL = "livekit/turn-detector"
# Path of the ONNX model file opened by the session below.
ONNX_FILENAME = "model_quantized.onnx"
# Characters removed by normalize_text; the apostrophe is excluded so it is kept (e.g. in contractions).
PUNCS = string.punctuation.replace("'", "")
# How many trailing entries are used: messages scored by calculate_eou and
# history pairs forwarded by respond.
MAX_HISTORY = 4
# Token limit applied to the EOU model input in calculate_eou.
MAX_HISTORY_TOKENS = 512
# respond() asks the user to keep typing when the EOU probability is below this.
EOU_THRESHOLD = 0.5

# Initialize ONNX model
# Both objects are created once at import time and reused for every request.
tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])

def softmax(logits):
    """Turn ``logits`` into probabilities that sum to 1.

    All entries are normalized together as one distribution (no axis is
    given). The maximum logit is subtracted first so that large values do
    not overflow in ``np.exp``.
    """
    shifted = logits - np.max(logits)
    weights = np.exp(shifted)
    total = np.sum(weights)
    return weights / total

def normalize_text(text):
    """Canonicalize ``text`` for the turn-detector model.

    Removes every character listed in ``PUNCS``, lower-cases the result and
    collapses all whitespace runs into single spaces (trimming both ends).
    """
    removal_table = str.maketrans("", "", PUNCS)
    words = text.translate(removal_table).lower().split()
    return " ".join(words)

def format_chat_ctx(chat_ctx):
    """Render a chat history into the text scored by the turn-detector model.

    Only ``user`` and ``assistant`` messages are kept. Their content is passed
    through ``normalize_text``, and messages that normalize to an empty string
    are dropped. The rest are rendered with the tokenizer's chat template, and
    the last ``<|im_end|>`` marker and everything after it is cut off. The text
    therefore ends inside the final turn, and the model can be asked how likely
    ``<|im_end|>`` is next.

    Args:
        chat_ctx: Iterable of message dicts with ``"role"`` and ``"content"``
            keys. The dicts are NOT modified.

    Returns:
        The rendered conversation without its final ``<|im_end|>``, or the
        rendered text unchanged if it contains no ``<|im_end|>``.
    """
    new_chat_ctx = []
    for msg in chat_ctx:
        if msg["role"] not in ("user", "assistant"):
            continue
        content = normalize_text(msg["content"])
        if content:
            # Copy rather than mutate: respond() sends these same dicts to the
            # LLM afterwards, and it must receive the user's original text.
            new_chat_ctx.append({**msg, "content": content})
    convo_text = tokenizer.apply_chat_template(
        new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )
    ix = convo_text.rfind("<|im_end|>")
    # rfind returns -1 when the marker is missing; slicing with -1 would
    # silently drop the last character instead of keeping the text intact.
    if ix == -1:
        return convo_text
    return convo_text[:ix]

def calculate_eou(chat_ctx, session):
    """Return the model's probability that the latest turn is finished.

    The last ``MAX_HISTORY`` messages of ``chat_ctx`` are rendered with
    ``format_chat_ctx`` and tokenized. ``session`` is then run on the token
    ids, and the softmax probability of the ``<|im_end|>`` token at the final
    position is returned.

    Args:
        chat_ctx: List of message dicts (``"role"`` / ``"content"``).
        session: ONNX Runtime session whose graph takes ``input_ids`` and
            produces a ``logits`` output.

    Returns:
        NumPy scalar probability that the next token is ``<|im_end|>``.
    """
    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])
    inputs = tokenizer(formatted_text, return_tensors="np")
    # Keep the MOST RECENT MAX_HISTORY_TOKENS tokens. The tokenizer's default
    # truncation is right-sided, so `truncation=True` would discard the end of
    # the conversation, which is exactly the turn whose completion is scored.
    input_ids = np.array(inputs["input_ids"], dtype=np.int64)[:, -MAX_HISTORY_TOKENS:]
    outputs = session.run(["logits"], {"input_ids": input_ids})
    # Logits at the last position predict the token following the text.
    logits = outputs[0][0, -1, :]
    probs = softmax(logits)
    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
    return probs[eou_token_id]

# Placeholder reply shown while the turn detector thinks the user is still typing.
_WAIT_MESSAGE = "[Wait... Keep typing...]"


def _build_messages(message, history):
    """Assemble the request: system prompt, recent history, then the new user message.

    Only the last ``MAX_HISTORY`` (user, assistant) pairs are used. Empty sides
    are skipped, and so is the placeholder reply that ``respond`` yields while
    waiting, so the LLM never sees it as something it said.
    """
    messages = [{"role": "system", "content": os.environ.get("CHARACTER_DESC", "You are a helpful assistant.")}]
    for val in history[-MAX_HISTORY:]:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1] and val[1] != _WAIT_MESSAGE:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})
    return messages


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens=2048,
    temperature=0.6,
    top_p=0.95,
):
    """Stream a reply to ``message``, or ask the user to keep typing.

    ``calculate_eou`` first scores whether the user's turn looks finished.
    Below ``EOU_THRESHOLD`` a single placeholder is yielded and the LLM is not
    called. Otherwise the conversation is sent to the hosted chat model and
    the raw reply is streamed back without post-processing.

    Args:
        message: The new user message.
        history: Prior (user, assistant) pairs supplied by ``gr.ChatInterface``.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.

    Yields:
        The placeholder text, or the reply accumulated so far. ChatInterface
        replaces the displayed message with each yielded value, so every yield
        carries the full text rather than only the newest chunk.
    """
    messages = _build_messages(message, history)

    eou_prob = calculate_eou(messages, onnx_session)
    if eou_prob < EOU_THRESHOLD:
        yield _WAIT_MESSAGE
        return

    stream = qwen_client.chat.completions.create(
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        messages=messages,
        # int() guards against the slider delivering a float; no-op for ints.
        max_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        stream=True,  # given once: a repeated keyword argument is a SyntaxError
    )

    full_response = ""
    for chunk in stream:
        if not chunk.choices:
            continue
        # delta.content may be None on some stream chunks; treat it as no text.
        piece = chunk.choices[0].delta.content or ""
        if not piece:
            continue
        full_response += piece
        yield full_response
# Chat UI. The three sliders are passed to respond() after (message, history),
# in order, as max_tokens, temperature and top_p.
_generation_controls = [
    gr.Slider(minimum=1, maximum=4096, value=256, label="Max Tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p"),
]
demo = gr.ChatInterface(fn=respond, additional_inputs=_generation_controls)


if __name__ == "__main__":
    demo.launch()