beyoru committed (verified)
Commit 7bb2dd6 · Parent(s): b97e466

Update app.py

Files changed (1): app.py (+65 -74)
app.py CHANGED
@@ -6,41 +6,33 @@ import string
 from huggingface_hub import InferenceClient
 import os
 
-# Initialize client with error handling
-try:
-    qwen_client = InferenceClient(
-        "nztinversive/llama3.2-1b-Uncensored",
-        timeout=120
-    )
-except Exception as e:
-    print(f"Error initializing client: {e}")
-    raise
+# Initialize Qwen client
+qwen_client = InferenceClient("huihui-ai/SmolLM2-1.7B-Instruct-abliterated")
 
 # Model and ONNX setup
 HG_MODEL = "livekit/turn-detector"
 ONNX_FILENAME = "model_quantized.onnx"
 PUNCS = string.punctuation.replace("'", "")
-MAX_HISTORY = 4
+MAX_HISTORY = 4  # Adjusted to use the last 4 messages
 MAX_HISTORY_TOKENS = 512
-EOU_THRESHOLD = 0.5
+EOU_THRESHOLD = 0.5  # Updated threshold to match original
 
 # Initialize ONNX model
-try:
-    tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
-    onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])
-except Exception as e:
-    print(f"Error initializing models: {e}")
-    raise
+tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
+onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])
 
+# Softmax function
 def softmax(logits):
     exp_logits = np.exp(logits - np.max(logits))
     return exp_logits / np.sum(exp_logits)
 
+# Normalize text
 def normalize_text(text):
     def strip_puncs(text):
         return text.translate(str.maketrans("", "", PUNCS))
     return " ".join(strip_puncs(text).lower().split())
 
+# Format chat context
 def format_chat_ctx(chat_ctx):
     new_chat_ctx = []
     for msg in chat_ctx:
@@ -50,15 +42,18 @@ def format_chat_ctx(chat_ctx):
         msg["content"] = content
         new_chat_ctx.append(msg)
 
+    # Tokenize with chat template
     convo_text = tokenizer.apply_chat_template(
         new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )
 
+    # Remove EOU token from the current utterance
     ix = convo_text.rfind("<|im_end|>")
     return convo_text[:ix]
 
+# Calculate EOU probability
 def calculate_eou(chat_ctx, session):
-    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])
+    formatted_text = format_chat_ctx(chat_ctx[-MAX_HISTORY:])  # Use the last 4 messages
     inputs = tokenizer(
         formatted_text,
         return_tensors="np",
@@ -72,71 +67,67 @@ def calculate_eou(chat_ctx, session):
     eou_token_id = tokenizer.encode("<|im_end|>")[-1]
     return probs[eou_token_id]
 
-def respond(message, history, max_tokens, temperature, top_p):
-    messages = [{"role": "system", "content": os.environ.get("CHARACTER_DESC", "You are a helpful assistant.")}]
-
-    # Process history
-    for user_msg, assistant_msg in history[-10:]:
-        if user_msg:
-            messages.append({"role": "user", "content": user_msg})
-        if assistant_msg:
-            messages.append({"role": "assistant", "content": assistant_msg})
-
-    # Add new message
+
+# Respond function
+def respond(
+    message,
+    history: list[tuple[str, str]],
+    max_tokens,
+    temperature,
+    top_p,
+):
+    # Keep the last 4 conversation pairs (user-assistant)
+    messages = [{"role": "system", "content": os.environ.get("CHARACTER_DESC")}]
+
+    for val in history[-10:]:  # Only use the last 4 pairs
+        if val[0]:
+            messages.append({"role": "user", "content": val[0]})
+        if val[1]:
+            messages.append({"role": "assistant", "content": val[1]})
+
+    # Add the new user message to the context
     messages.append({"role": "user", "content": message})
 
     # Calculate EOU probability
-    try:
-        eou_prob = calculate_eou(messages, onnx_session)
-        print(f"EOU Probability: {eou_prob}")
-
-        if eou_prob < EOU_THRESHOLD:
-            yield "[Waiting for additional input...]"
-            return
-    except Exception as e:
-        print(f"EOU calculation error: {e}")
-        yield "[Error in conversation analysis]"
+    eou_prob = calculate_eou(messages, onnx_session)
+    print(f"EOU Probability: {eou_prob}")  # Debug output
+
+    # If EOU is below the threshold, ask for more input
+    if eou_prob < EOU_THRESHOLD:
+        yield "[Waiting for user to continue input...]"
         return
 
-    # Generate response
-    try:
-        # Format prompt for text generation
-        prompt = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True
-        )
-
-        stream = qwen_client.text_generation(
-            prompt,
-            max_new_tokens=max_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            stream=True,
-            details=True,
-        )
-
-        response = ""
-        for chunk in stream:
-            if chunk.token.text:
-                response += chunk.token.text
-                yield response
-
-        print(f"Final response: {response}")
-
-    except Exception as e:
-        print(f"Generation error: {e}")
-        yield "[Error generating response]"
+    # Generate response with Qwen
+    response = ""
+    for message in qwen_client.chat_completion(
+        messages,
+        max_tokens=max_tokens,
+        stream=True,
+        temperature=temperature,
+        top_p=top_p,
+    ):
+        token = message.choices[0].delta.content
+        response += token
+        yield response
+
+    print(f"Generated response: {response}")
+
 
 # Gradio interface
 demo = gr.ChatInterface(
     respond,
-    additional_inputs=[
-        gr.Slider(1, 4096, value=256, label="Max Tokens"),
-        gr.Slider(0.1, 4.0, value=0.7, label="Temperature"),
-        gr.Slider(0.1, 1.0, value=0.95, label="Top-p"),
-    ]
+    # additional_inputs=[
+    #     # Commented out to disable user modification of the system message
+    #     # gr.Textbox(value="You are an assistant.", label="System message"),
+    #     gr.Slider(minimum=1, maximum=4096, value=256, step=1, label="Max new tokens"),
+    #     gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+    #     gr.Slider(
+    #         minimum=0.1,
+    #         maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
+    #     ),
+    # ],
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
+
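
Note: the hunks above elide the unchanged body of calculate_eou between the tokenizer call and the final probability lookup. For readers of the diff, here is a minimal sketch of what that unchanged middle section plausibly looks like, assuming the quantized ONNX model takes an input_ids tensor and returns per-position vocabulary logits as its first output (the tensor name and output layout are assumptions, not shown in this commit):

import numpy as np

def calculate_eou_sketch(formatted_text, tokenizer, session):
    # Tokenize the formatted context, truncating to the history budget
    inputs = tokenizer(
        formatted_text,
        return_tensors="np",
        truncation=True,
        max_length=MAX_HISTORY_TOKENS,  # 512 in this file
    )
    # Run the ONNX session; the input name "input_ids" is an assumption
    outputs = session.run(None, {"input_ids": inputs["input_ids"].astype(np.int64)})
    logits = outputs[0][0, -1, :]  # logits for the last token position
    probs = softmax(logits)        # softmax() is defined earlier in app.py
    eou_token_id = tokenizer.encode("<|im_end|>")[-1]
    return probs[eou_token_id]     # probability that the turn has ended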
 
 
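One caveat on the new streaming loop: with InferenceClient.chat_completion(..., stream=True), a chunk's delta.content is typed as optional and can be None (for example on the final chunk), so the unguarded response += token can raise a TypeError; the loop variable also shadows the message argument. A defensive variant of the same loop, with only the guard and the variable name changed:

response = ""
for chunk in qwen_client.chat_completion(
    messages,
    max_tokens=max_tokens,
    stream=True,
    temperature=temperature,
    top_p=top_p,
):
    token = chunk.choices[0].delta.content
    if token:  # skip None/empty deltas instead of concatenating None
        response += token
        yield response

Similarly, os.environ.get("CHARACTER_DESC") returns None when the variable is unset; keeping a default, as the removed code did with os.environ.get("CHARACTER_DESC", "You are a helpful assistant."), avoids sending a None system message.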