import gradio as gr from transformers import BertTokenizerFast, BertForMaskedLM import torch from sentence_transformers import SentenceTransformer from sentence_transformers.util import cos_sim import random import json model_name = "nycu-ai113-dl-final-project/bert-turtle-soup-pet-zh" dataset_name = "nycu-ai113-dl-final-project/TurtleBench-extended-zh" model = BertForMaskedLM.from_pretrained(model_name) tokenizer = BertTokenizerFast.from_pretrained(model_name) answer_judge = SentenceTransformer('thenlper/gte-base-zh') intro=""" ### 玩法介紹 遊戲一開始,我會給你一個不完整的故事,這個故事通常有很多未知的細節,你需要透過提出問題來探索更多線索。你可以問我各種問題,不過請記住,我只能回答三種答案:「正確」、「錯誤」或「不知道」。你的目標是根據這些有限的答案,逐步推理出故事的完整脈絡,從而揭開事件的真相。 這個遊戲的名稱來自於其中一個最經典的題目,海龜湯的故事。由於這類型的遊戲強調水平思考,也就是用非傳統的方式解決問題,這些遊戲就被大家統稱為「海龜湯」,有點像是可樂成為所有碳酸飲料的代名詞。 在遊戲中,你的提問會讓你逐漸接近真相。準備好發揮你的推理能力,讓我們開始吧! """ class PuzzleGame: def __init__(self): """ 初始化遊戲類別。 """ self.template = ' 根據判定規則,此玩家的猜測為[MASK]' self.load_stories('stories.json') def load_stories(self, path): with open(path, mode='r', encoding='utf-8') as f: self.stories = json.load(f) def get_random_puzzle(self): """ 隨機選擇一個謎題並設定當前謎題的標題、故事和答案。 """ puzzle = random.choice(self.stories) # puzzle = self.stories[1] self.title = puzzle['title'] self.surface = puzzle['surface'] self.bottom = puzzle['bottom'] def get_prompt(self): """ 返回填入謎題故事和答案的 prompt """ few_shot = "1.玩家猜測:賣給他貨的人不是老闆;你的回答:是\n2.玩家猜測:他被嚇傻和零食本身有關;你的回答:不\n3.玩家猜測:零食是合法的;你的回答:不知道" prompt = f"你是遊戲的裁判,根據<湯麵>和<湯底>判斷玩家的猜測是否正確。你的回答只能是以下三種之一:1.是:玩家的猜測與故事相符。2.否:玩家的猜測與故事不符。3.不知道:無法從<湯麵>和<湯底>推理得出結論。注意:1. 玩家只能看到<湯麵>,你的判定也只能基於<湯麵>。2. 無法從故事中推理的問題,回答\"不知道\"。<湯麵>{self.surface}<湯底>{self.bottom}<範例>{few_shot}\n請判斷以下玩家猜測:" return prompt # 初始化遊戲 game = PuzzleGame() def predict_masked_token(text): """ 使用模型預測 [MASK] 位置的 token。 :param text: 包含 [MASK] 的文本。 :return: 預測的 token。 """ inputs = tokenizer(text, return_tensors="pt") model.eval() with torch.no_grad(): outputs = model(**inputs) # 找到 [MASK] 的位置 mask_token_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1] if mask_token_index.numel() == 0: raise "請確保輸入包含 [MASK]。" # 獲取 [MASK] 位置的 logits 並預測 mask_token_logits = outputs.logits[0, mask_token_index, :] predicted_id = torch.argmax(mask_token_logits, dim=-1) return tokenizer.decode(predicted_id) def restart(): """ 重新開始遊戲,初始化新的謎題。 返回故事的開頭內容。 """ game.get_random_puzzle() story = [{"role": "assistant", "content": game.surface}] return story def user(message, history): """ 處理用戶輸入的問題或答案,並將其添加到對話歷史中。 :param message: 用戶的輸入消息。 :param history: 當前的對話歷史。 :return: 用戶的消息和更新後的歷史。 """ history.append({"role": "user", "content": message}) return message, history def check_question(question, history): """ 使用 Masked Language Model 檢查用戶提問的問題。 :param question: 用戶的提問。 :param history: 當前的對話歷史。 :return: 空字符串和更新後的歷史。 """ # 將問題與遊戲提示和模板結合 text = game.get_prompt() + question + game.template predicted = predict_masked_token(text) predicted_map = { '是':'正確', '否':'錯誤', '不':'不知道' } history.append({"role": "assistant", "content": predicted_map[predicted]}) return "", history def check_answer(answer, history): """ 使用語義相似度檢查用戶輸入的答案是否正確。 :param answer: 用戶的答案。 :param history: 當前的對話歷史。 :return: 空字符串和更新後的歷史。 """ sentences = [answer, game.bottom] embeddings = answer_judge.encode(sentences) sim = cos_sim([embeddings[0]], [embeddings[1]]) print("相似度: ", sim[0][0]) # 根據相似度生成回應 if sim[0][0] > 0.8: response = "正確!你猜對了! 完整故事:\n" + game.bottom elif sim[0][0] > 0.7: response = "接近了!再試一次!" else: response = "錯誤!再試一次!" history.append({"role": "assistant", "content": response}) return "", history # 使用 Gradio 創建界面 with gr.Blocks() as demo: # 頁面介紹 gr.Markdown(intro) gr.Markdown("---") # 初始化故事 story = restart() chatbot = gr.Chatbot(type='messages', value=story, height=600) # 問題提問功能 with gr.Tab("提出問題"): question_input_box = gr.Textbox( show_label=False, placeholder="提問各種可能性的問題...", submit_btn=True, ) # 用戶輸入的文本框 # 1. 將用戶輸入的問題提交到 `user` 函數處理,將問題加入到歷史對話中。 # 2. 將 `user` 處理的結果(問題和更新後的歷史)傳遞給 `check_question` 函數。 # 3. `check_question` 會檢查問題並生成對應的回應,更新對話歷史。 question_input_box.submit(user, [question_input_box, chatbot], [question_input_box, chatbot]).then( check_question, [question_input_box, chatbot], [question_input_box, chatbot] ) # 答案輸入功能 with gr.Tab("輸入答案"): answer_input_box = gr.Textbox( show_label=False, placeholder="請輸入你的答案...", submit_btn=True, ) # 用戶輸入的答案框 # 1. 將用戶輸入的答案提交到 `user` 函數處理,將答案加入到歷史對話中。 # 2. 將 `user` 處理的結果(答案和更新後的歷史)傳遞給 `check_answer` 函數。 # 3. `check_answer` 會檢查答案的正確性,生成對應的回應,並更新對話歷史。 answer_input_box.submit(user, [answer_input_box, chatbot], [answer_input_box, chatbot]).then( check_answer, [answer_input_box, chatbot], [answer_input_box, chatbot] ) # 重新開始按鈕 restart_btn = gr.ClearButton(value='重新開始新遊戲', inputs=[question_input_box, chatbot]) restart_btn.click(restart, outputs=[chatbot]) # 啟動應用 if __name__ == "__main__": demo.launch()