import gradio as gr
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import string
from huggingface_hub import InferenceClient
import os

# Client for the response LLM: a DeepSeek-R1 distill of Qwen 1.5B, served via
# the Hugging Face Inference API. InferenceClient falls back to a saved token
# or the HF_TOKEN environment variable if the endpoint requires auth.
qwen_client = InferenceClient("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
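# Hypothetical smoke test (requires network access and, usually, a valid token):
#   qwen_client.chat.completions.create(
#       messages=[{"role": "user", "content": "ping"}], max_tokens=8
#   )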

# End-of-utterance (EOU) detection: LiveKit's turn-detector model, run locally
# with ONNX Runtime.
HG_MODEL = "livekit/turn-detector"
ONNX_FILENAME = "model_quantized.onnx"  # quantized weights expected at this local path
PUNCS = string.punctuation.replace("'", "")  # punctuation to strip; apostrophes are kept
MAX_HISTORY = 4  # number of recent turns fed to the detector and the LLM
MAX_HISTORY_TOKENS = 512  # token budget for the detector input
EOU_THRESHOLD = 0.5  # minimum P(end of utterance) before generating a reply

# Tokenizer and ONNX inference session for the turn detector (CPU only)
tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])

def softmax(logits):
    # Numerically stable softmax: subtracting the max before exponentiating
    # avoids overflow without changing the result.
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / np.sum(exp_logits)
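# Worked example for softmax(): softmax(np.array([0.0, 1.0])) ≈ [0.269, 0.731],
# since exp(1) / (exp(0) + exp(1)) ≈ 0.731.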

def normalize_text(text):
    # Lowercase, strip punctuation (keeping apostrophes), and collapse runs of
    # whitespace so the detector sees text in a consistent form.
    def strip_puncs(text):
        return text.translate(str.maketrans("", "", PUNCS))
    return " ".join(strip_puncs(text).lower().split())

def format_chat_ctx(chat_ctx):
    # Build a normalized copy of the conversation rather than mutating the
    # caller's messages (the same dicts are later sent to the LLM verbatim).
    new_chat_ctx = []
    for msg in chat_ctx:
        if msg["role"] in ("user", "assistant"):
            content = normalize_text(msg["content"])
            if content:
                new_chat_ctx.append({"role": msg["role"], "content": content})
    convo_text = tokenizer.apply_chat_template(
        new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )
    # Cut the trailing <|im_end|> so the last turn looks unfinished to the detector.
    ix = convo_text.rfind("<|im_end|>")
    return convo_text[:ix] if ix != -1 else convo_text
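# Assuming a ChatML-style chat template (check tokenizer.chat_template), the
# rendered text looks roughly like:
#   <|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\nhi there<|im_end|>
# with the final <|im_end|> sliced off so the last turn appears unfinished.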

def calculate_eou(chat_ctx, session):
    # Estimate P(end of utterance): run the detector over the recent history
    # and read the probability mass the final position assigns to <|im_end|>.
    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])
    inputs = tokenizer(
        formatted_text,
        return_tensors="np",
        truncation=True,
        max_length=MAX_HISTORY_TOKENS,
    )
    input_ids = np.array(inputs["input_ids"], dtype=np.int64)
    outputs = session.run(["logits"], {"input_ids": input_ids})
    logits = outputs[0][0, -1, :]  # logits at the last token position
    probs = softmax(logits)
    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
    return probs[eou_token_id]
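# Illustration (hypothetical scores): a complete question should clear
# EOU_THRESHOLD while a dangling fragment should not, e.g.
#   calculate_eou([{"role": "user", "content": "What's the weather today?"}], onnx_session)
#   calculate_eou([{"role": "user", "content": "So what I was thinking is"}], onnx_session)
# the first typically scoring above 0.5 and the second below it.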

def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens=256,
    temperature=0.7,
    top_p=0.95,
):
    # The system prompt comes from the CHARACTER_DESC env var, with a generic fallback.
    messages = [{"role": "system", "content": os.environ.get("CHARACTER_DESC", "You are a helpful assistant.")}]
    
    for val in history[-MAX_HISTORY:]:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    
    messages.append({"role": "user", "content": message})

    # Gate generation on the turn detector: if the message doesn't look like a
    # finished utterance, ask the user to keep typing instead of replying.
    eou_prob = calculate_eou(messages, onnx_session)
    if eou_prob < EOU_THRESHOLD:
        yield "[Wait... Keep typing...]"
        return

    # Stream the completion and yield the accumulated text so the UI updates live.
    # InferenceClient exposes an OpenAI-style chat.completions.create API in
    # recent versions of huggingface_hub.
    accumulated_response = ""
    for chunk in qwen_client.chat.completions.create(
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content or ""
        accumulated_response += token
        yield accumulated_response  # Yield accumulated response for live updates

    print(f"Final response: {accumulated_response}")

# Gradio chat UI; the sliders are passed to respond() as additional arguments.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(1, 4096, value=256, step=1, label="Max Tokens"),  # step=1 keeps this an integer
        gr.Slider(0.1, 4.0, value=0.7, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, label="Top-p"),
    ],
)

if __name__ == "__main__":
    demo.launch()