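"""Gradio chat demo that gates LLM replies with an end-of-utterance (EOU) check.

A quantized ONNX turn-detector model scores whether the user has finished their
turn; only then is the message streamed to a hosted chat model through the
Hugging Face Inference API.
"""
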
import gradio as gr
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import string
from huggingface_hub import InferenceClient
import os
# Initialize the Inference API client for the chat model (HF_TOKEN supplies the API token)
qwen_client = InferenceClient(token=os.environ.get("HF_TOKEN"))

# Model and ONNX setup
HG_MODEL = "livekit/turn-detector"
ONNX_FILENAME = "model_quantized.onnx"
PUNCS = string.punctuation.replace("'", "")
MAX_HISTORY = 4
MAX_HISTORY_TOKENS = 512
EOU_THRESHOLD = 0.5
# Initialize ONNX model
tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])
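
# Note: the quantized ONNX file is expected to exist locally as ONNX_FILENAME.
# It could, for example, be fetched ahead of time with huggingface_hub.hf_hub_download
# from the HG_MODEL repository (the exact filename/subfolder there is an assumption).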
def softmax(logits):
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / np.sum(exp_logits)

def normalize_text(text):
    def strip_puncs(text):
        return text.translate(str.maketrans("", "", PUNCS))

    return " ".join(strip_puncs(text).lower().split())

def format_chat_ctx(chat_ctx):
    # Keep only non-empty user/assistant turns, lowercased and with punctuation stripped.
    # Copies are appended so the caller's chat history is not mutated.
    new_chat_ctx = []
    for msg in chat_ctx:
        if msg["role"] in ("user", "assistant"):
            content = normalize_text(msg["content"])
            if content:
                new_chat_ctx.append({"role": msg["role"], "content": content})
    convo_text = tokenizer.apply_chat_template(
        new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )
    # Drop the final <|im_end|> so the model is asked to predict it.
    ix = convo_text.rfind("<|im_end|>")
    return convo_text[:ix]

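# End-of-utterance (EOU) scoring: the formatted conversation ends right before an
# <|im_end|> token, so the probability the model assigns to <|im_end|> at the last
# position is used as the probability that the user's turn is complete.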
def calculate_eou(chat_ctx, session):
    # Score only the most recent turns to keep the input short.
    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])
    inputs = tokenizer(
        formatted_text,
        return_tensors="np",
        truncation=True,
        max_length=MAX_HISTORY_TOKENS,
    )
    input_ids = np.array(inputs["input_ids"], dtype=np.int64)
    # Run the ONNX model and read the logits for the last position.
    outputs = session.run(["logits"], {"input_ids": input_ids})
    logits = outputs[0][0, -1, :]
    probs = softmax(logits)
    # Probability that the next token is <|im_end|>, i.e. that the turn is over.
    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
    return probs[eou_token_id]

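# Chat handler for gr.ChatInterface. If the EOU probability is below EOU_THRESHOLD,
# the reply is held back and the user is asked to keep typing; otherwise the message
# is sent to the hosted model and the answer is streamed back incrementally.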
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens=2048,
    temperature=0.6,
    top_p=0.95,
):
    # System prompt comes from the environment, with a generic fallback.
    messages = [{"role": "system", "content": os.environ.get("CHARACTER_DESC", "You are a helpful assistant.")}]
    for val in history[-MAX_HISTORY:]:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    # Only answer once the turn detector thinks the user has finished typing.
    eou_prob = calculate_eou(messages, onnx_session)
    if eou_prob < EOU_THRESHOLD:
        yield "[Wait... Keep typing...]"
        return

    # Stream the response from the hosted model, accumulating the partial text
    # so each yield updates the Gradio chat message in place.
    full_response = ""
    stream = qwen_client.chat.completions.create(
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stream=True,
    )
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:  # the final chunk may carry no content
            full_response += delta
            yield full_response

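# The sliders below are passed to respond() as its extra arguments
# (max_tokens, temperature, top_p) in the order they are declared.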
# Create Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(1, 4096, value=256, label="Max Tokens"),
        gr.Slider(0.1, 4.0, value=0.7, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, label="Top-p"),
    ],
)

if __name__ == "__main__":
    demo.launch()