beyoru committed
Commit beef6e2 · verified · 1 Parent(s): 854ef87

Update app.py

Files changed (1)
  1. app.py +102 -71
app.py CHANGED
@@ -2,74 +2,115 @@ import gradio as gr
  from huggingface_hub import InferenceClient
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
+ import string

- # Load Inference Client for the response model
- client = InferenceClient("Qwen/Qwen2.5-3B-Instruct")
+ # Constants
+ PUNCS = string.punctuation.replace("'", "")
+ MAX_HISTORY = 4
+ MAX_HISTORY_TOKENS = 512

- # Load tokenizer and model for the EOU detection
- tokenizer = AutoTokenizer.from_pretrained("livekit/turn-detector")
- model = AutoModelForCausalLM.from_pretrained("livekit/turn-detector")
-
- # Function to compute EOU probability
- def compute_eou_probability(chat_ctx: list[dict[str, str]], max_tokens: int = 512) -> float:
-     # Extract only the 'content' from the chat context (messages) and use a list of strings for tokenization
-     conversation = ["Assistant ready to help."]  # Add system message directly as a string
+ class EOUDetector:
+     def __init__(self, model_name="livekit/turn-detector"):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModelForCausalLM.from_pretrained(model_name)
+         self.eou_token_id = self.tokenizer.encode("<|im_end|>")[-1]
+
+     def _normalize_text(self, text: str) -> str:
+         """Normalize text by removing punctuation and extra spaces."""
+         text = text.translate(str.maketrans("", "", PUNCS))
+         return " ".join(text.lower().split())

-     # Only append the 'content' of each message to the conversation list
-     for msg in chat_ctx:
-         content = msg.get("content", "")
-         if content:
-             conversation.append(content)  # Only append the content (string)
+     def _format_chat_context(self, messages: list[dict]) -> str:
+         """Format chat context using the model's chat template."""
+         normalized_messages = []
+
+         for msg in messages[-MAX_HISTORY:]:  # Only keep last MAX_HISTORY messages
+             if msg["role"] not in ("user", "assistant"):
+                 continue
+
+             content = self._normalize_text(msg["content"])
+             if content:
+                 normalized_messages.append({
+                     "role": msg["role"],
+                     "content": content
+                 })
+
+         # Apply chat template without generation prompt
+         conversation = self.tokenizer.apply_chat_template(
+             normalized_messages,
+             add_generation_prompt=False,
+             add_special_tokens=False,
+             tokenize=False
+         )
+
+         # Remove the EOU token from current utterance if present
+         ix = conversation.rfind("<|im_end|>")
+         if ix >= 0:
+             conversation = conversation[:ix]
+
+         return conversation

-     # Tokenize the conversation content (just the text) as a list of strings
-     inputs = tokenizer(
-         conversation, padding=True, truncation=True, max_length=max_tokens, return_tensors="pt"
-     )
-
-     # Get model logits
-     with torch.no_grad():
-         outputs = model(**inputs)
-
-     # Get the logits for the last token in the sequence
-     logits = outputs.logits[0, -1, :]
-
-     # Apply softmax to get probabilities
-     probabilities = torch.nn.functional.softmax(logits, dim=-1)
-
-     # Get the EOU token index (typically "<|im_end|>" token in the model)
-     eou_token_id = tokenizer.encode("<|im_end|>")[0]
-     eou_probability = probabilities[eou_token_id].item()
-
-     return eou_probability
+     def compute_eou_probability(self, messages: list[dict]) -> float:
+         """Compute the probability of end of utterance."""
+         # Format the conversation
+         conversation = self._format_chat_context(messages)
+
+         # Tokenize with proper truncation
+         inputs = self.tokenizer(
+             conversation,
+             add_special_tokens=False,
+             return_tensors="pt",
+             max_length=MAX_HISTORY_TOKENS,
+             truncation=True,
+             truncation_side="left"
+         )
+
+         # Get model predictions
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+
+         # Get logits for the last token
+         logits = outputs.logits[0, -1, :]
+
+         # Compute softmax properly
+         probabilities = torch.nn.functional.softmax(logits, dim=-1)
+
+         # Get probability for EOU token
+         eou_probability = probabilities[self.eou_token_id].item()
+
+         return eou_probability

- # Respond function with EOU checking logic
  def respond(
-     message,
+     message: str,
      history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     eou_threshold: float = 0.2,  # Default EOU threshold
- ):
+     system_message: str,
+     max_tokens: int,
+     temperature: float,
+     top_p: float,
+     eou_threshold: float = 0.2,
+ ) -> str:
+     # Initialize clients
+     eou_detector = EOUDetector()
+     client = InferenceClient("Qwen/Qwen2.5-3B-Instruct")
+
+     # Prepare messages
      messages = [{"role": "system", "content": system_message}]
+     for user_msg, assistant_msg in history:
+         if user_msg:
+             messages.append({"role": "user", "content": user_msg})
+         if assistant_msg:
+             messages.append({"role": "assistant", "content": assistant_msg})
+
+     # Add current message
+     messages.append({"role": "user", "content": message})

-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
+     # Check EOU probability
+     eou_probability = eou_detector.compute_eou_probability(messages)
+     print(f"EOU Probability: {eou_probability}")

-     # Compute EOU probability before responding
-     eou_probability = compute_eou_probability(messages, max_tokens=max_tokens)
-     print(eou_probability)
-     # Only respond if EOU probability exceeds threshold
      if eou_probability >= eou_threshold:
-         # Prepare message for assistant response
-         messages.append({"role": "user", "content": message})
-
+         # Generate response
          response = ""
-
          for message in client.chat_completion(
              messages,
              max_tokens=max_tokens,
@@ -81,29 +122,19 @@ def respond(
              response += token
              yield response
      else:
-         # Let the user continue typing if the EOU probability is low
          yield "Waiting for user to finish... Please continue."
-         print("Waiting for user to finish... Please continue.")

- # Gradio UI
+ # Gradio Interface
  demo = gr.ChatInterface(
      respond,
      additional_inputs=[
-         gr.Textbox(value="Bạn một trợ lý ảo", label="System message"),
+         gr.Textbox(value="You are a helpful assistant", label="System message"),
          gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
          gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-         gr.Slider(
-             minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="EOU Threshold"
-         ),  # Add EOU threshold slider
+         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
+         gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.05, label="EOU Threshold"),
      ],
  )

  if __name__ == "__main__":
-     demo.launch()
+     demo.launch()
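
For reference, a minimal sketch of how the EOUDetector introduced in this commit could be exercised on its own. The `from app import ...` path and the sample messages are assumptions for illustration; actual scores depend on the livekit/turn-detector weights.

    # Minimal usage sketch: score two utterances with the detector defined in app.py.
    from app import EOUDetector  # assumed import path; adjust to where app.py lives

    detector = EOUDetector()

    # An unfinished utterance should tend toward a low end-of-utterance probability...
    print(detector.compute_eou_probability([
        {"role": "user", "content": "So what I was thinking is"},
    ]))

    # ...while a complete utterance should tend to score higher
    # (respond() gates generation on eou_threshold, default 0.2).
    print(detector.compute_eou_probability([
        {"role": "user", "content": "Thanks, that answers my question."},
    ]))

One design note: respond() now constructs EOUDetector and InferenceClient on every call, so the turn-detector model is reloaded per message; hoisting both to module scope, as the old version did for the client, would avoid that cost.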