Spaces:
Sleeping
Sleeping
File size: 4,074 Bytes
afec331 80b54e9 02a8fce 80b54e9 afec331 80b54e9 afec331 02a8fce 80b54e9 02a8fce 80b54e9 31391ab 02a8fce 80b54e9 02a8fce 80b54e9 02a8fce 80b54e9 02a8fce 80b54e9 02a8fce 80b54e9 02a8fce 80b54e9 02a8fce 80b54e9 afec331 80b54e9 c95ef21 02a8fce c95ef21 532c418 02a8fce 80b54e9 02a8fce 80b54e9 02a8fce 6c46f2d 02a8fce 6c46f2d 80b54e9 02a8fce 80b54e9 02a8fce 80b54e9 532c418 914aac7 532c418 80b54e9 02a8fce 80b54e9 02a8fce 5de4bf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import gradio as gr
from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import string
from huggingface_hub import InferenceClient
import os
# --- Configuration -----------------------------------------------------------
# Hugging Face repo providing the turn-detector tokenizer and chat template.
HG_MODEL = "livekit/turn-detector"
# Quantized ONNX export of the turn detector, loaded from the working directory.
ONNX_FILENAME = "model_quantized.onnx"
# Characters removed by normalize_text(): ASCII punctuation except the
# apostrophe (presumably so contractions like "don't" stay intact).
PUNCS = "".join(ch for ch in string.punctuation if ch != "'")
# Number of most recent chat messages fed to the turn detector.
MAX_HISTORY = 4
# Token budget for the formatted conversation given to the ONNX model.
MAX_HISTORY_TOKENS = 512
# Used by the (currently disabled) EOU check in respond(): a probability below
# this value means the user is assumed to still be typing.
EOU_THRESHOLD = 0.5

# --- Clients and models ------------------------------------------------------
# Hosted-inference client; the API key comes from the HF_TOKEN environment
# variable (None when unset).
client = InferenceClient(api_key=os.environ.get("HF_TOKEN"))

# Tokenizer plus a CPU-only ONNX Runtime session for the turn-detector model.
tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
onnx_session = ort.InferenceSession(
    ONNX_FILENAME,
    providers=["CPUExecutionProvider"],
)
def softmax(logits):
    """Return the softmax of ``logits`` as an array of probabilities.

    The maximum logit is subtracted before exponentiating so ``np.exp`` cannot
    overflow; this does not change the result mathematically. The max and the
    sum are taken over the whole array, so pass a 1-D vector of logits.
    """
    shifted = logits - np.max(logits)
    weights = np.exp(shifted)
    total = np.sum(weights)
    return weights / total
def normalize_text(text):
    """Normalize ``text`` into the form the turn-detector model expects.

    Every character listed in ``PUNCS`` is deleted, the text is lower-cased,
    and runs of whitespace collapse to single spaces with no leading or
    trailing whitespace.
    """
    drop_punctuation = str.maketrans("", "", PUNCS)
    words = text.translate(drop_punctuation).lower().split()
    return " ".join(words)
def format_chat_ctx(chat_ctx):
    """Render a chat history as input text for the turn-detector model.

    Only ``user`` and ``assistant`` messages are kept. Their content is passed
    through ``normalize_text``, and messages that normalize to an empty string
    (or have no content) are dropped. The remaining messages are rendered with
    the tokenizer's chat template as plain text, without a generation prompt.
    The final ``<|im_end|>`` marker is then cut off, so the last utterance is
    left open for the model to judge whether it has ended.

    Args:
        chat_ctx: Iterable of message dicts with ``"role"`` and ``"content"``
            keys. The dicts are NOT modified.

    Returns:
        The rendered conversation, truncated just before its last
        ``<|im_end|>`` marker. If the template emitted no such marker, the
        rendered text is returned unchanged.
    """
    new_chat_ctx = []
    for msg in chat_ctx:
        if msg["role"] not in ("user", "assistant"):
            continue
        content = normalize_text(msg.get("content") or "")
        if content:
            # Build a copy instead of writing back into ``msg``. The caller's
            # dicts may be reused afterwards; respond() builds the LLM request
            # from the same message list, and that request must keep the
            # user's original casing and punctuation.
            new_chat_ctx.append({**msg, "content": content})

    convo_text = tokenizer.apply_chat_template(
        new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )

    # Remove the end-of-utterance marker of the current (last) utterance.
    # rfind() returns -1 when the marker is absent, and slicing with [:-1]
    # would silently drop the final character, so only cut when it was found.
    ix = convo_text.rfind("<|im_end|>")
    if ix == -1:
        return convo_text
    return convo_text[:ix]
def calculate_eou(chat_ctx, session):
    """Estimate the probability that the user has finished their turn.

    The last ``MAX_HISTORY`` messages of ``chat_ctx`` are rendered with
    ``format_chat_ctx``, which leaves the final utterance open. The text is
    tokenized, limited to the most recent ``MAX_HISTORY_TOKENS`` tokens, and
    run through the ONNX ``session``. The softmax probability of the
    ``<|im_end|>`` token as the next token after the last position is returned.

    Args:
        chat_ctx: Sequence of message dicts with ``"role"`` and ``"content"``.
        session: ONNX Runtime session (e.g. ``onnx_session``) whose model takes
            an ``"input_ids"`` int64 input and produces a ``"logits"`` output.

    Returns:
        End-of-utterance probability in [0, 1]. Returns 0.0 when the rendered
        context yields no tokens at all.
    """
    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])

    # Tokenize without the tokenizer's own truncation. Hugging Face tokenizers
    # truncate on the right unless truncation_side="left" is configured, which
    # would discard the newest words. Those words are exactly the position
    # whose next-token distribution is read below, so keep the *last*
    # MAX_HISTORY_TOKENS ids instead.
    inputs = tokenizer(formatted_text, return_tensors="np")
    input_ids = np.asarray(inputs["input_ids"], dtype=np.int64)[:, -MAX_HISTORY_TOKENS:]
    if input_ids.shape[-1] == 0:
        # Nothing to score; [0, -1, :] below would fail on an empty sequence.
        return 0.0

    outputs = session.run(["logits"], {"input_ids": input_ids})
    # Next-token logits after the final input position (first batch row).
    logits = outputs[0][0, -1, :]
    probs = softmax(logits)
    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
    return probs[eou_token_id]
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens=512,
    temperature=0.7,
    top_p=0.95,
):
    """Stream the assistant's reply to ``message`` given the chat ``history``.

    Builds a chat-completion message list from three parts:
    - an optional system prompt taken from the ``CHARACTER_DESC`` env var;
    - up to the last 30 history entries;
    - the new user message.
    It then streams a completion from the model named by the ``MODEL_ID``
    env var.

    Args:
        message: The new user message.
        history: Previous turns as ``(user, assistant)`` pairs. Either side may
            be empty, in which case it is skipped.
        max_tokens: Maximum number of tokens to generate (``None`` -> 512).
        temperature: Sampling temperature (``None`` -> 0.7).
        top_p: Nucleus-sampling probability mass (``None`` -> 0.95).

    Yields:
        The reply accumulated so far, re-yielded after every streamed chunk.
    """
    # The UI sliders are disabled, so these may be omitted or arrive as None.
    # Fall back to the values this app has always used.
    max_tokens = 512 if max_tokens is None else max_tokens
    temperature = 0.7 if temperature is None else temperature
    top_p = 0.95 if top_p is None else top_p

    messages = []
    system_prompt = os.environ.get("CHARACTER_DESC")
    # Only send a system message when a prompt is configured; a message whose
    # content is None is not a valid chat message.
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Keep the last 30 history entries (each a user/assistant pair).
    for val in history[-30:]:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    # End-of-utterance gating (currently disabled): when enabled, ask the user
    # to keep typing if the turn detector thinks they have not finished.
    # eou_prob = calculate_eou(messages, onnx_session)
    # print(f"EOU Probability: {eou_prob}")  # Debug output
    # if eou_prob < EOU_THRESHOLD:
    #     yield "[Waiting for user to continue input...]"
    #     return

    # `do_sample` is not a chat-completion parameter (sampling is governed by
    # temperature/top_p), so it is not passed.
    stream = client.chat.completions.create(
        model=os.environ.get("MODEL_ID"),
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True,
    )

    bot_response = ""
    for chunk in stream:
        # Some chunks carry no text: `choices` may be empty, and
        # `delta.content` may be None (e.g. role-only or final chunks).
        if not chunk.choices:
            continue
        bot_response += chunk.choices[0].delta.content or ""
        yield bot_response
# Gradio chat UI wired to `respond`. No additional inputs are registered, so
# users cannot change the system message or the sampling settings (max tokens,
# temperature, top-p) from the UI. To expose the sampling settings, pass
# `additional_inputs=[gr.Slider(...), ...]` in the order respond() expects:
# max_tokens, temperature, top_p.
demo = gr.ChatInterface(fn=respond)

if __name__ == "__main__":
    demo.launch()