beyoru committed
Commit 77f7b3a · verified · 1 Parent(s): 7606979

Update app.py

Files changed (1)
  1. app.py +51 -24
app.py CHANGED
@@ -1,24 +1,24 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
import gradio as gr
+from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import string
from huggingface_hub import InferenceClient
-import os
-

-model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-1.5B-Instruct')
-tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-1.5B-Instruct')
+# Initialize Qwen client
+qwen_client = InferenceClient("ystemsrx/Qwen2.5-Sex")

-# ONNX setup
+# Model and ONNX setup
+HG_MODEL = "livekit/turn-detector"
ONNX_FILENAME = "model_quantized.onnx"
-onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])
-
PUNCS = string.punctuation.replace("'", "")
-MAX_HISTORY = 4
+MAX_HISTORY = 4  # Only the last 4 messages are scored for end-of-utterance
MAX_HISTORY_TOKENS = 512
-EOU_THRESHOLD = 0.5
+EOU_THRESHOLD = 0.5  # Probability cutoff for treating the user's turn as finished
+
+# Initialize the turn-detector tokenizer and quantized ONNX session
+tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
+onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])

# Softmax function
def softmax(logits):
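
Note: the body of softmax falls outside the diff context. A minimal, numerically stable version consistent with how it is used below (turning the final token's logits into a probability distribution) would be the following sketch; the exact body in app.py may differ:

    def softmax(logits):
        # Shift by the max logit so np.exp cannot overflow; the result is unchanged
        exp_logits = np.exp(logits - np.max(logits))
        return exp_logits / np.sum(exp_logits)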
@@ -32,26 +32,27 @@ def normalize_text(text):
    return " ".join(strip_puncs(text).lower().split())

# Format chat context
-# Update your format_chat_ctx function to ensure proper tokenization
def format_chat_ctx(chat_ctx):
    new_chat_ctx = []
    for msg in chat_ctx:
        if msg["role"] in ("user", "assistant"):
            content = normalize_text(msg["content"])
            if content:
-                new_chat_ctx.append({"role": msg["role"], "content": content})
-
-    # Use correct chat template formatting
-    return tokenizer.apply_chat_template(
-        new_chat_ctx,
-        add_generation_prompt=False,
-        tokenize=False,
-        add_special_tokens=True  # Keep special tokens consistent
+                msg["content"] = content
+                new_chat_ctx.append(msg)
+
+    # Render the conversation with the chat template (as text, not token ids)
+    convo_text = tokenizer.apply_chat_template(
+        new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )

+    # Strip the trailing EOU token so the model has to predict it
+    ix = convo_text.rfind("<|im_end|>")
+    return convo_text[:ix]
+
# Calculate EOU probability
def calculate_eou(chat_ctx, session):
-    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])
+    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])  # Use the last 4 messages
    inputs = tokenizer(
        formatted_text,
        return_tensors="np",
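
Note: the rfind trim above is the heart of the turn detector. The rendered conversation is cut just before the final end-of-utterance marker, so the probability the model assigns to <|im_end|> as the next token can be read directly as P(user has finished speaking). A hypothetical illustration, assuming a ChatML-style template (the actual template is whatever livekit/turn-detector ships):

    ctx = [{"role": "user", "content": "are you there"}]
    print(format_chat_ctx(ctx))
    # ChatML-style rendering, with the trailing <|im_end|> stripped:
    #   <|im_start|>user
    #   are you there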
@@ -65,6 +66,10 @@ def calculate_eou(chat_ctx, session):
    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
    return probs[eou_token_id]

+# Read the system message from file
+with open("character/herta.txt", "r") as f:
+    system_message = f.read()
+
# Respond function
def respond(
    message,
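
Note: the middle of calculate_eou is elided by the hunks. Based on the visible pieces (a NumPy tokenizer call going in, probs over the final position coming out), a plausible completion is sketched below; the ONNX input names and output shape are assumptions, not taken from the commit:

    def calculate_eou(chat_ctx, session):
        formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])
        inputs = tokenizer(
            formatted_text,
            return_tensors="np",
            truncation=True,
            max_length=MAX_HISTORY_TOKENS,
        )
        # Assumed ONNX signature: "input_ids" in, logits (batch, seq, vocab) out
        outputs = session.run(None, {"input_ids": inputs["input_ids"].astype(np.int64)})
        logits = outputs[0][0, -1, :]  # logits at the final position
        probs = softmax(logits)
        eou_token_id = tokenizer.encode("<|im_end|>")[-1]
        return probs[eou_token_id]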
@@ -73,18 +78,28 @@ def respond(
    temperature,
    top_p,
):
-    messages = [{"role": "system", "content": os.environ.get("CHARACTER_DESC")}]
-    for val in history[-10:]:
+    # Build the chat context: file-based system message plus recent history
+    messages = [{"role": "system", "content": system_message}]
+
+    for val in history[-10:]:  # Keep up to the last 10 user/assistant pairs
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
+
+    # Add the new user message to the context
    messages.append({"role": "user", "content": message})
+
+    # Calculate EOU probability
    eou_prob = calculate_eou(messages, onnx_session)
-    print(f"EOU Probability: {eou_prob}")
+    print(f"EOU Probability: {eou_prob}")  # Debug output
+
+    # If EOU is below the threshold, ask for more input
    if eou_prob < EOU_THRESHOLD:
        yield "[Waiting for user to continue input...]"
        return
+
+    # Generate response with Qwen
    response = ""
    for message in qwen_client.chat_completion(
        messages,
@@ -96,11 +111,23 @@ def respond(
        token = message.choices[0].delta.content
        response += token
        yield response
+
    print(f"Generated response: {response}")

+
# Gradio interface
demo = gr.ChatInterface(
    respond,
+    # additional_inputs=[
+    #     # Commented out to disable user modification of the system message
+    #     # gr.Textbox(value="You are an assistant.", label="System message"),
+    #     gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Max new tokens"),
+    #     gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+    #     gr.Slider(
+    #         minimum=0.1,
+    #         maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
+    #     ),
+    # ],
)

if __name__ == "__main__":
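
Note: the diff cuts off inside the __main__ guard; the closing line is presumably the standard Gradio entry point:

    if __name__ == "__main__":
        demo.launch()

Two caveats about the new version. With additional_inputs commented out, gr.ChatInterface calls respond(message, history) only, so temperature, top_p, and any other extra parameters need defaults (unless the elided signature lines already provide them). And streamed chunks from InferenceClient.chat_completion can carry a delta.content of None, which the bare response += token would turn into a TypeError; a defensive loop body would be:

    token = message.choices[0].delta.content
    if token:  # skip chunks that carry no content
        response += token
        yield response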