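"""Gradio chat demo that only replies once LiveKit's turn-detector model judges
the user's message to be a complete utterance (end-of-utterance gating)."""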
import gradio as gr
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import string
from huggingface_hub import InferenceClient
import os
# Chat completion client (a DeepSeek-R1 model distilled from Qwen)
qwen_client = InferenceClient("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
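# NOTE: the OpenAI-style client.chat.completions.create interface used below
# assumes a reasonably recent huggingface_hub release.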
# Turn-detector model and ONNX setup
HG_MODEL = "livekit/turn-detector"
ONNX_FILENAME = "model_quantized.onnx"
PUNCS = string.punctuation.replace("'", "")  # strip all punctuation except apostrophes
MAX_HISTORY = 4  # number of past turns fed to the turn detector
MAX_HISTORY_TOKENS = 512
EOU_THRESHOLD = 0.5  # minimum end-of-utterance probability before we respond
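# model_quantized.onnx is expected next to this script. If it is missing, it can
# typically be fetched from the model repo first (the "onnx/" subfolder is an
# assumption about the repo layout):
#   from huggingface_hub import hf_hub_download
#   ONNX_FILENAME = hf_hub_download(HG_MODEL, "onnx/model_quantized.onnx")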
# Load the tokenizer and the quantized end-of-utterance (EOU) model
tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])

def softmax(logits):
    # Numerically stable softmax over the vocabulary logits
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / np.sum(exp_logits)

def normalize_text(text):
    # Lowercase, strip punctuation (keeping apostrophes), and collapse whitespace
    def strip_puncs(text):
        return text.translate(str.maketrans("", "", PUNCS))

    return " ".join(strip_puncs(text).lower().split())

def format_chat_ctx(chat_ctx):
    # Keep only user/assistant turns with non-empty normalized content.
    # Build new dicts so the caller's messages are not mutated (the originals
    # are reused later for generation and must keep their raw text).
    new_chat_ctx = []
    for msg in chat_ctx:
        if msg["role"] in ("user", "assistant"):
            content = normalize_text(msg["content"])
            if content:
                new_chat_ctx.append({"role": msg["role"], "content": content})
    convo_text = tokenizer.apply_chat_template(
        new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )
    # Drop everything from the final <|im_end|> onward so the model is asked
    # whether the last turn has actually ended
    ix = convo_text.rfind("<|im_end|>")
    return convo_text[:ix]

def calculate_eou(chat_ctx, session):
    # Probability that the conversation so far ends in a complete utterance
    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])
    inputs = tokenizer(
        formatted_text,
        return_tensors="np",
        truncation=True,
        max_length=MAX_HISTORY_TOKENS,
    )
    input_ids = np.array(inputs["input_ids"], dtype=np.int64)
    outputs = session.run(["logits"], {"input_ids": input_ids})
    logits = outputs[0][0, -1, :]  # logits at the final position
    probs = softmax(logits)
    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
    return probs[eou_token_id]
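# For example, a trailing fragment such as
#   calculate_eou([{"role": "user", "content": "so what I wanted to"}], onnx_session)
# should typically score below EOU_THRESHOLD, while a complete question scores above it.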

def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens=256,
    temperature=0.7,
    top_p=0.95,
):
    # Rebuild the chat context from Gradio's (user, assistant) history tuples
    messages = [{"role": "system", "content": os.environ.get("CHARACTER_DESC", "You are a helpful assistant.")}]
    for val in history[-MAX_HISTORY:]:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    # Only answer once the turn detector thinks the user has finished their turn
    eou_prob = calculate_eou(messages, onnx_session)
    if eou_prob < EOU_THRESHOLD:
        yield "[Wait... Keep typing...]"
        return
    # Stream the completion, yielding the accumulated text so Gradio updates live
    accumulated_response = ""
    for chunk in qwen_client.chat.completions.create(
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content or ""
        accumulated_response += token
        yield accumulated_response

    print(f"Final response: {accumulated_response}")

# Gradio chat UI; sliders expose the generation parameters
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(1, 4096, value=256, step=1, label="Max Tokens"),
        gr.Slider(0.1, 4.0, value=0.7, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, label="Top-p"),
    ],
)

if __name__ == "__main__":
    demo.launch()