# Qwen2-Audio-rkllm / multiprocess_inference.py
import faulthandler
faulthandler.enable()
import os
import random
import time
import signal
from multiprocessing import Process, Queue, Event
import numpy as np
from rkllm_binding import *
from rknnlite.api.rknn_lite import RKNNLite
import threading
import librosa
from transformers import WhisperFeatureExtractor
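# Pipeline overview:
#   - audio_encoder_process: runs the audio encoder exported to RKNN
#     (audio_encoder.rknn) and turns an audio file into embedding vectors.
#   - llm_process: runs the Qwen2-Audio language model converted to RKLLM and
#     streams generated text back through a callback.
#   - main(): spawns both workers, waits for their "ready" signals, then loops over
#     user input, passing audio paths, prompts and embeddings between the processes
#     via multiprocessing Queues.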
# Audio encoder process
def audio_encoder_process(load_ready_queue, embedding_queue, audio_path_queue, start_event):
AUDIO_ENCODER_PATH = "audio_encoder.rknn"
    # Initialize the audio encoder
audio_encoder = RKNNLite(verbose=False)
model_size = os.path.getsize(AUDIO_ENCODER_PATH)
print(f"Start loading audio encoder model (size: {model_size / 1024 / 1024:.2f} MB)")
start_time = time.time()
audio_encoder.load_rknn(AUDIO_ENCODER_PATH)
end_time = time.time()
print(f"Audio encoder loaded in {end_time - start_time:.2f} seconds")
audio_encoder.init_runtime()
    # Initialize the Whisper feature extractor
feature_extractor = WhisperFeatureExtractor.from_pretrained(".")
    # Notify the main process that loading is complete
load_ready_queue.put("audio_ready")
    # Wait for the start signal
start_event.wait()
def process_audio(audio_path, audio_encoder, feature_extractor):
try:
print("Start audio inference...")
audio, _ = librosa.load(audio_path, sr=feature_extractor.sampling_rate)
feature_extractor_output = feature_extractor(
audio,
sampling_rate=feature_extractor.sampling_rate,
return_attention_mask=True,
padding="max_length"
)
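            # input_features is the log-mel spectrogram padded to Whisper's fixed
            # 30-second window; attention_mask marks which frames are real audio so
            # the padded part of the encoder output can be trimmed below.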
start_time = time.time()
audio_embeddings = audio_encoder.inference(inputs=[
feature_extractor_output.input_features.astype(np.float32),
feature_extractor_output.attention_mask.astype(np.float32)
], data_format="nhwc")[0].astype(np.float32)
end_time = time.time()
print(f"Audio encoder inference time: {end_time - start_time:.2f} seconds")
effective_length = feature_extractor_output.attention_mask.sum(-1)[0]
effective_length = (effective_length - 1) // 2 + 1
output_lengths = (effective_length - 2) // 2 + 1
audio_embeddings = audio_embeddings[:, :output_lengths]
            print("Audio embeddings shape:", audio_embeddings.shape)
return audio_embeddings
except Exception as e:
print(f"Error processing audio: {e}")
return None
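    # Worker loop: block until main() sends an audio path; "STOP" shuts the worker down.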
while True:
audio_path = audio_path_queue.get()
if audio_path == "STOP":
break
embeddings = process_audio(audio_path, audio_encoder, feature_extractor)
if embeddings is not None:
embedding_queue.put(embeddings)
else:
embedding_queue.put("ERROR")
# LLM process
def llm_process(load_ready_queue, embedding_queue, prompt_queue, inference_done_queue, start_event):
MODEL_PATH = "/home/firefly/qwen.rkllm"
handle = None
import locale
    # Get the system language
system_lang = locale.getdefaultlocale()[0]
is_chinese = system_lang and system_lang.startswith('zh')
# is_chinese = False
    # Progress message lists (Chinese / English)
progress_messages_zh = [
"🚀 启动量子加速引擎...",
"🧠 神经网络正在苏醒...",
"🔄 并行宇宙计算进行中...",
"🌟 正在注入能量矩阵...",
"🔥 CPU已经到达工作温度,全力运转中...",
"🎯 特征向量正在跳跃式生长...",
"🎭 多头注意力机制开始营业...",
"💨 散热风扇已经进入超音速状态...",
"📚 语义解析器正在啃食数据...",
"🔍 上下文关联分析师正在加班...",
"🎨 视觉特征正在调色盘中混合...",
"🤝 跨模态对齐正在相亲相爱中...",
"⚡ 深度特征提取器已经深入地心...",
"🧪 神经网络正在炼丹中...",
"🎲 张量计算已经进入量子态...",
"📦 模型参数正在装箱搬运...",
"⚖️ 权重矩阵正在天平上找平衡...",
"🗺 语义向量正在绘制航海图...",
"🎭 注意力头们正在开会讨论...",
"🏗 残差模块正在搭建天梯...",
"🌈 激活函数正在调制彩虹...",
"🎮 张量核心正在玩魔方...",
"🎪 循环神经网络正在马戏团表演...",
"🎨 特征图正在画饼充饥...",
"🔮 模型正在占卜未来...",
"🎯 优化器正在进行火箭轨道计算...",
"🎪 批归一化正在杂技表演...",
"🎭 Dropout正在玩捉迷藏...",
"🌪 梯度正在形成龙卷风...",
"🎢 反向传播正在过山车..."
]
progress_messages_en = [
"Loading...",
"Extracting...",
"Image fusion in progress...",
"Matrix multiplication...",
"Chip heating up...",
"Feature vector calculation...",
"Attention mechanism processing...",
"Fan speed increasing...",
"Semantic parsing...",
"Context analysis...",
"Visual feature encoding...",
"Cross-modal alignment...",
"Deep feature extraction...",
"Neural network inference...",
"Tensor operations...",
"Loading model parameters...",
"Weight matrix calculation...",
"Semantic vector mapping...",
"Multi-head attention...",
"Residual connection..."
]
    # Choose progress messages based on the system language
progress_messages = progress_messages_zh if is_chinese else progress_messages_en
    # Event used to stop the progress-message thread
progress_stop_event = threading.Event()
    # Progress-message thread function
def show_progress():
while not progress_stop_event.is_set():
for msg in progress_messages:
if progress_stop_event.is_set():
break
print(f"{msg}", flush=True)
time.sleep(random.uniform(0.1, 0.4))
    def signal_handler(sig, frame):
        nonlocal handle  # handle is a local of llm_process, not a module-level name
        print("Ctrl-C pressed, exiting...")
if handle:
abort(handle)
destroy(handle)
exit(0)
signal.signal(signal.SIGINT, signal_handler)
os.environ["RKLLM_LOG_LEVEL"] = "1"
inference_count = 0
inference_start_time = 0
def result_callback(result, userdata, state):
nonlocal inference_start_time, inference_count
if state == LLMCallState.RKLLM_RUN_NORMAL:
if inference_count == 0:
                progress_stop_event.set()  # stop the progress messages
first_token_time = time.time()
print("🎉 完成!")
print(f"\nTime to first token: {first_token_time - inference_start_time:.2f} seconds")
inference_count += 1
print(result.contents.text.decode(), end="", flush=True)
elif state == LLMCallState.RKLLM_RUN_FINISH:
print("\n\n(finished)")
inference_done_queue.put("DONE")
elif state == LLMCallState.RKLLM_RUN_ERROR:
print("\nError occurred during LLM call")
inference_done_queue.put("ERROR")
    # Initialize the LLM
param = create_default_param()
param.model_path = MODEL_PATH.encode()
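    # RKLLM's multimodal interface only exposes image-token parameters, so the
    # Qwen2-Audio audio markers are mapped onto img_start/img_end/img_content here;
    # the prompt built in main() likewise uses the hard-coded <image> placeholder.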
param.img_start = "<|audio_bos|>".encode()
param.img_end = "<|audio_eos|>".encode()
param.img_content = "<|AUDIO|>".encode()
param.max_context_len = 768
param.max_new_tokens = 256
extend_param = RKLLMExtendParam()
extend_param.base_domain_id = 1
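    # base_domain_id = 1 presumably gives the LLM its own NPU memory domain, separate
    # from the RKNN audio encoder running in the other process (assumption; not
    # documented in this script).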
param.extend_param = extend_param
model_size = os.path.getsize(MODEL_PATH)
print(f"Start loading language model (size: {model_size / 1024 / 1024:.2f} MB)")
start_time = time.time()
handle = init(param, result_callback)
end_time = time.time()
print(f"Language model loaded in {end_time - start_time:.2f} seconds")
    # Notify the main process that loading is complete
load_ready_queue.put("llm_ready")
    # Create inference parameters
infer_param = RKLLMInferParam()
infer_param.mode = RKLLMInferMode.RKLLM_INFER_GENERATE.value
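    # RKLLM_INFER_GENERATE runs ordinary autoregressive generation; tokens are
    # streamed back through result_callback as they are produced.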
while True:
prompt = prompt_queue.get()
print(f"Received prompt: ===={prompt}\n====")
if prompt == "STOP":
break
        # Reset the token counter and the progress event
inference_count = 0
progress_stop_event.clear()
        # Start the progress message thread (start() is currently commented out)
progress_thread = threading.Thread(target=show_progress)
progress_thread.daemon = True
# progress_thread.start()
        audio_embeddings = embedding_queue.get()
        if isinstance(audio_embeddings, str) and audio_embeddings == "ERROR":
            print("Error processing audio")
            continue
        print("Audio embeddings shape:", audio_embeddings.shape)
        rkllm_input = create_rkllm_input(RKLLMInputType.RKLLM_INPUT_MULTIMODAL,
                                         prompt=prompt,
                                         image_embed=audio_embeddings)
        print("Start LLM inference...")
inference_start_time = time.time()
run(handle, rkllm_input, infer_param, None)
    # Clean up
destroy(handle)
def main():
load_ready_queue = Queue()
embedding_queue = Queue()
audio_path_queue = Queue()
prompt_queue = Queue()
inference_done_queue = Queue()
start_event = Event()
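    # Queue roles:
    #   load_ready_queue     workers -> main: "audio_ready" / "llm_ready"
    #   audio_path_queue     main -> audio encoder: audio file path (or "STOP")
    #   embedding_queue      audio encoder -> LLM: audio embeddings (or "ERROR")
    #   prompt_queue         main -> LLM: formatted chat prompt (or "STOP")
    #   inference_done_queue LLM -> main: "DONE" / "ERROR" after each generation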
audio_process = Process(target=audio_encoder_process,
args=(load_ready_queue, embedding_queue, audio_path_queue, start_event))
lm_process = Process(target=llm_process,
args=(load_ready_queue, embedding_queue, prompt_queue, inference_done_queue, start_event))
audio_process.start()
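    # Stagger startup: let the audio encoder begin loading before the LLM process
    # starts, presumably to avoid both models hitting the NPU/memory at once.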
time.sleep(10)
lm_process.start()
    # Wait for both models to finish loading
ready_count = 0
while ready_count < 2:
status = load_ready_queue.get()
print(f"Received ready signal: {status}")
ready_count += 1
print("All models loaded, starting interactive mode...")
start_event.set()
    # Interactive loop
try:
while True:
print("""
Enter your input (3 empty lines to start inference, Ctrl+C to exit, for example:
这是什么声音{{glass-breaking.wav}}?
What kind of sound is in {{./test.mp3}}?
Describe the audio in {{./test.mp3}}
这是什么动物的叫声{{./jntm.mp3}}?
):
""")
user_input = []
empty_lines = 0
while empty_lines < 3:
line = input()
if line.strip() == "":
empty_lines += 1
else:
empty_lines = 0
user_input.append(line)
            # Parse the input
            full_input = "\n".join(user_input[:-3])  # drop the trailing 3 empty lines
import re
img_match = re.search(r'\{\{(.+?)\}\}', full_input)
if not img_match:
print("No image path found in input")
continue
img_path = img_match.group(1)
            # Replace the audio marker with the <image> marker; rkllm's <image> token is hard-coded...
prompt = f"""<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Audio 1: <image>
{full_input.replace(img_match.group(0), '')}<|im_end|>
<|im_start|>assistant
"""
audio_path_queue.put(img_path)
prompt_queue.put(prompt)
            # Wait for inference to complete
status = inference_done_queue.get()
if status == "ERROR":
print("Inference failed")
except KeyboardInterrupt:
print("\nExiting...")
audio_path_queue.put("STOP")
prompt_queue.put("STOP")
audio_process.join()
lm_process.join()
if __name__ == "__main__":
main()
# Example input: 这是什么声音{{./test.mp3}}? ("What sound is this?")
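# Usage sketch (assumes audio_encoder.rknn and the Whisper feature-extractor config
# are in the current directory, and the RKLLM model at /home/firefly/qwen.rkllm):
#   python multiprocess_inference.py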