|
import logging
|
|
from typing import List, Dict, Optional, Tuple
|
|
|
|
import torch
|
|
from transformers import pipeline
|
|
from transformers import pipeline
|
|
|
|
from config.config import token, device, SYSTEM_PROMPT
|
|
from services.whisper import generate_speech, transcribe
|
|
from services.search import WebSearcher
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Single shared search helper; respond() queries it when do_search is set.
searcher = WebSearcher()

# Extra keyword arguments handed to the pipeline through ``model_kwargs``.
model_kwargs = dict(
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
    use_cache=True,
)

# Text-generation pipeline, built once when this module is imported.
# ``token`` and ``device`` come from config.config.
client = pipeline(
    task="text-generation",
    model="Qwen/Qwen2.5-0.5B-Instruct",
    token=token,
    trust_remote_code=True,
    device=device,
    model_kwargs=model_kwargs,
)
|
|
|
|
async def respond(
    audio: Optional[str] = None,
    text: Optional[str] = None,
    do_search: bool = False,
    history: Optional[List[Dict]] = None,
) -> Tuple[Optional[str], str]:
    """Produce the assistant's reply for one user turn.

    The user turn is taken from ``text`` when it contains non-whitespace
    characters; otherwise ``audio`` is passed to ``transcribe``. The turn,
    the system prompt and any ``history`` are rendered as a ChatML prompt,
    sent to the module-level ``client`` pipeline, and the reply text is
    passed to ``generate_speech``.

    Args:
        audio: Value forwarded to ``transcribe`` when no usable ``text``
            is given.
        text: Typed user input. Leading/trailing whitespace is stripped.
        do_search: When true, ``searcher.search`` is called with the user
            text and its results are appended to the system message.
        history: Earlier messages, each a mapping with ``"role"`` and
            ``"content"`` keys, inserted between the system message and
            the new user turn.

    Returns:
        ``(speech, reply_text)`` where ``speech`` is whatever
        ``generate_speech`` returned. On missing input the result is
        ``(None, "No input provided")``; on any exception it is
        ``(None, "Sorry, I encountered an error")``.
    """
    try:
        user_text = text.strip() if text else ""
        if not user_text and audio:
            # Whitespace-only text counts as "no text", so audio still gets a chance.
            user_text = await transcribe(audio)
        if not user_text:
            # Covers both "nothing supplied" and an empty transcription.
            return None, "No input provided"

        # Decide the system message content *before* rendering the prompt.
        # Patching the rendered prompt with str.replace would also rewrite any
        # copy of the system prompt inside history or the user text, and would
        # corrupt the whole prompt if SYSTEM_PROMPT were empty.
        system_content = SYSTEM_PROMPT
        if do_search:
            # NOTE(review): this call is not awaited, so it runs synchronously
            # inside the coroutine.
            results = searcher.search(user_text)
            if results:
                system_content = (
                    f"{SYSTEM_PROMPT}\n{_format_search_context(results)}"
                )

        messages = [{"role": "system", "content": system_content}]
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": user_text})

        # NOTE(review): message contents are interpolated verbatim; ChatML
        # control tokens inside user text or history are not escaped.
        prompt = _render_chatml(messages) + _ASSISTANT_HEADER

        reply = client(
            prompt,
            max_new_tokens=400,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            num_return_sequences=1,
        )

        assistant_response = _extract_completion(
            reply[0]["generated_text"], prompt
        )

        audio_path = await generate_speech(assistant_response)
        return audio_path, assistant_response

    except Exception as e:
        # Top-level boundary: log with traceback and return a friendly message
        # instead of raising to the caller.
        logger.exception("Error in respond: %s", e)
        return None, "Sorry, I encountered an error"


# ChatML markers used to frame each turn of the prompt.
_ASSISTANT_HEADER = "<|im_start|>assistant\n"
_TURN_END = "<|im_end|>"


def _format_search_context(results) -> str:
    """Build the search-context text appended to the system prompt."""
    # Each hit contributes at most its first 500 characters, one per line.
    snippets = [hit["content"][:500].strip() for hit in results]
    return "Based on search results:\n" + "".join(
        f"{snippet}\n" for snippet in snippets
    )


def _render_chatml(messages) -> str:
    """Render role/content messages as consecutive closed ChatML turns."""
    return "".join(
        f"<|im_start|>{msg['role']}\n{msg['content']}{_TURN_END}\n"
        for msg in messages
    )


def _extract_completion(generated: str, prompt: str) -> str:
    """Return only the newly generated assistant text from pipeline output."""
    if generated.startswith(prompt):
        # Output echoes the prompt: drop exactly that prefix.
        completion = generated[len(prompt):]
    else:
        # Fallback: take everything after the last assistant header.
        completion = generated.split(_ASSISTANT_HEADER)[-1]
    # Discard anything after the first end-of-turn marker.
    return completion.partition(_TURN_END)[0].strip()