import os
import subprocess
import warnings

import faiss
import gradio as gr
import numpy as np
from optimum.intel.openvino import OVModelForCausalLM
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

warnings.filterwarnings(
    "ignore",
    category=DeprecationWarning,
    message="__array__ implementation doesn't accept a copy keyword"
)

# Model ID and the path where the converted (8-bit quantized) model is saved
model_id = "agentica-org/DeepScaleR-1.5B-Preview"
export_path = "exported_model_openvino_int8"

print("Loading model as OpenVINO int8 (8-bit) model...")
if os.path.exists(export_path) and os.listdir(export_path):
    print(f"Found quantized OpenVINO model at '{export_path}', loading it...")
    model = OVModelForCausalLM.from_pretrained(export_path, device_map="auto", use_cache=False)
else:
    print("No quantized model found, exporting and quantizing to OpenVINO int8 now...")
    # Export and quantize the model via optimum-cli
    # (these CLI arguments may need adjusting for your task)
    command = [
        "optimum-cli", "export", "openvino",
        "--model", model_id,
        "--task", "text-generation",
        "--weight-format", "int8",
        export_path
    ]
    subprocess.run(command, check=True)
    print(f"Quantized model saved to '{export_path}'.")
    model = OVModelForCausalLM.from_pretrained(export_path, device_map="auto", use_cache=False)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# Sentence-embedding model (used to turn text into vectors)
encoder = SentenceTransformer("all-MiniLM-L6-v2")

# FAQ knowledge base (question + answer pairs)
faq_data = [
    ("What is FAISS?",
     "FAISS is a library for efficient similarity search and clustering of dense vectors."),
    ("How does FAISS work?",
     "FAISS uses indexing structures to quickly retrieve the nearest neighbors of a query vector."),
    ("Can FAISS run on GPU?",
     "Yes, FAISS supports GPU acceleration for faster computation."),
    ("What is OpenVINO?",
     "OpenVINO is an inference engine optimized for Intel hardware."),
    ("How to fine-tune a model?",
     "Fine-tuning involves training a model on a specific dataset to adapt it to a particular task."),
    ("What is the best way to optimize inference speed?",
     "Using quantization and model distillation can significantly improve inference speed.")
]

# Encode the FAQ questions into vectors
faq_questions = [q for q, _ in faq_data]
faq_answers = [a for _, a in faq_data]
faq_vectors = np.array(encoder.encode(faq_questions)).astype("float32")

# Build the FAISS index (L2 distance)
d = faq_vectors.shape[1]  # embedding dimension
index = faiss.IndexFlatL2(d)
index.add(faq_vectors)

# Conversation history
history = []


def respond(prompt):
    """Answer from the FAQ if a close match exists; otherwise generate with the OpenVINO model."""
    global history
    # Encode the input and search FAISS for the closest FAQ question
    query_vector = np.array(encoder.encode([prompt])).astype("float32")
    D, I = index.search(query_vector, 1)
    if D[0][0] < 1.0:  # L2-distance threshold for accepting an FAQ match
        response = faq_answers[I[0][0]]
    else:
        # No FAQ match: generate a response with the OpenVINO model
        messages = [{"role": "system", "content": "Answer the question in English only."}]
        for user_text, assistant_text in history:
            messages.append({"role": "user", "content": user_text})
            messages.append({"role": "assistant", "content": assistant_text})
        messages.append({"role": "user", "content": prompt})
        # Flatten the conversation into a single newline-separated prompt
        chat_prompt = "\n".join(f"{msg['role']}: {msg['content']}" for msg in messages)
        model_inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
        # Decode only the newly generated tokens so the response does not echo the prompt
        new_tokens = generated_ids[0][model_inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    history.append((prompt, response))
    return response


def clear_history():
    """Reset the conversation history."""
    global history
    history = []
    return "History cleared!"
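
# Retrieval smoke test (a sketch under the assumption it is useful during
# development; the FAQ_SMOKE_TEST env var is an addition, not part of the app).
# An exact FAQ question should come back with a near-zero L2 distance.
if os.environ.get("FAQ_SMOKE_TEST") == "1":
    _q = np.array(encoder.encode(["What is FAISS?"])).astype("float32")
    _dist, _idx = index.search(_q, 1)
    print("smoke test:", _dist[0][0], "->", faq_answers[_idx[0][0]])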
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(
        "# DeepScaleR-1.5B-Preview (OpenVINO int8) Chatbot with FAISS FAQ\n"
        "You must first copy this Space to your own account before you can use it."
    )
    with gr.Tabs():
        with gr.TabItem("Chat"):
            chat_interface = gr.Interface(
                fn=respond,
                inputs=gr.Textbox(label="Prompt", placeholder="Enter your message..."),
                outputs=gr.Textbox(label="Response", interactive=False),
                api_name="hchat",
                title="DeepScaleR-1.5B-Preview (OpenVINO int8) Chatbot",
                description=(
                    "This chatbot first searches an FAQ database using FAISS, then uses an "
                    "OpenVINO 8-bit model to generate a response if no FAQ match is found."
                )
            )
            with gr.Row():
                clear_button = gr.Button("🧹 Clear History")
                clear_button.click(fn=clear_history, inputs=[], outputs=[])

if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
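
# Example client call (a sketch, not exercised by this script): with the app
# running, the "hchat" endpoint can be queried via the gradio_client package.
# The URL is an assumption; substitute your own Space URL if hosted remotely.
#
#     from gradio_client import Client
#     client = Client("http://localhost:7860/")
#     print(client.predict("What is FAISS?", api_name="/hchat"))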