import argparse
import json
import random
from pathlib import Path


def gen_self_self_aware_dataset():
    """Generate a small self-cognition (identity) dataset.

    Each sample pairs a randomly chosen identity question with one of the
    fixed self-introduction answers, in the
    ``{"conversation": [{"input": ..., "output": ...}]}`` format used by the
    merged dataset.

    Returns:
        list[dict]: one single-turn conversation dict per canned answer
        (15 in total). The questions are sampled with ``random.choice``,
        so which question accompanies each answer is nondeterministic.
    """
    # Identity questions the assistant should recognize.
    self_aware_question = [
        "你好",
        "你是谁",
        "你叫什么名字",
        "请做一下自我介绍",
        "介绍下你自己",
    ]
    self_aware_answer_lelemiao = [
        "您好,我是智能医导,随时准备解答您的医疗疑问。",
        "您好,我是智能医导,助您轻松就医。",
        "您好,我是智能医导,提供专业医疗指导。",
        "您好,我是智能医导,解答您的健康疑惑。",
        "您好,我是智能医导,帮助您了解医疗服务。",
        "您好,我是智能医导,您的医疗问题助手。",
        "您好,我是智能医导,助您快速获取医疗信息。",
        "您好,我是智能医导,为您提供医疗解答。",
        "您好,我是智能医导,帮助您理解医疗流程。",
        "您好,我是智能医导,解答您的医疗咨询。",
        "您好,我是智能医导,助您掌握健康知识。",
        "您好,我是智能医导,提供医疗信息查询。",
        "您好,我是智能医导,助您解决就医难题。",
        "您好,我是智能医导,您的私人医疗顾问。",
        "您好,我是智能医导,随时为您提供帮助。",
    ]
    # Pair every answer with a random question. (Fixes the original
    # `anser` typo; behavior unchanged.)
    return [
        {
            "conversation": [
                {"input": random.choice(self_aware_question), "output": answer}
            ]
        }
        for answer in self_aware_answer_lelemiao
    ]


def merge_dataset(save_json_root: Path, final_save_json_path: Path):
    """Merge every ``*.json`` response file under *save_json_root*.

    Cleans each generated conversation (drops unknown keys, rejects turns
    missing ``input`` or ``output`` and entries marked ``{"Error": ...}``),
    injects the fixed system prompt into the first turn of every surviving
    conversation, appends the self-cognition dataset, and writes:

    * ``{count}_{final_save_json_path.name}`` — the merged dataset, where
      ``count`` is the number of valid conversations;
    * ``error_{final_save_json_path.name}`` — rejected (dirty) entries,
      only when at least one exists.

    Both files are written into ``final_save_json_path.parent``.

    Args:
        save_json_root: directory containing the generated response JSONs.
            Each file is expected to map product names to lists of
            conversation dicts.
        final_save_json_path: target path whose parent directory and name
            determine the actual output file names.
    """
    # Load every response file in the directory.
    json_list = []
    for json_path in save_json_root.glob("*.json"):
        with open(json_path, "r", encoding="utf-8") as f:
            json_list.append(json.load(f))

    filter_json_list = []
    dirty_conversion = []
    accept_keys = ("input", "output", "system")  # keys kept per turn
    for model_data in json_list:
        for product_name, gen_data_list in model_data.items():
            for gen_data in gen_data_list:
                # Generation failures are recorded as {"Error": ...} dicts.
                if isinstance(gen_data, dict) and "Error" in gen_data:
                    print(f"Got error data in {product_name}")
                    dirty_conversion.append(gen_data)
                    continue

                cleaned = {"conversation": []}
                for turn in gen_data["conversation"]:
                    # Strip unexpected keys before validating the turn.
                    turn = {k: v for k, v in turn.items() if k in accept_keys}
                    # A usable turn needs both an input and an output; this
                    # single check subsumes the original redundant
                    # `len(keys) < 2` test (same entries end up dirty).
                    if "input" not in turn or "output" not in turn:
                        dirty_conversion.append(turn)
                        continue
                    cleaned["conversation"].append(turn)
                if cleaned["conversation"]:
                    filter_json_list.append(cleaned)

    # Patch the dataset: set the fixed system prompt on the first turn of
    # every merged conversation. (The self-cognition data appended below is
    # deliberately left without a system prompt, matching the original.)
    system_prompt = "现在你是一位医院大厅里的智能医导小助手,你的名字叫智能医导小助手,你的说话方式是严肃端庄。你能够根据病人的需求提供专业的医疗咨询,并且结合医疗知识解答用户提出的各种健康相关疑问。"
    for conversation in filter_json_list:
        conversation["conversation"][0]["system"] = system_prompt

    # Append the self-cognition (identity) samples.
    filter_json_list += gen_self_self_aware_dataset()

    # Save the merged dataset; the valid-conversation count is encoded in
    # the file name.
    save_path = final_save_json_path.parent.joinpath(
        f"{len(filter_json_list)}_{final_save_json_path.name}"
    )
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(filter_json_list, f, ensure_ascii=False, indent=4)

    if len(dirty_conversion) > 0:
        # Keep the rejected entries so the user can inspect/fix them.
        error_path = final_save_json_path.parent.joinpath(
            f"error_{final_save_json_path.name}"
        )
        with open(error_path, "w", encoding="utf-8") as f:
            json.dump(dirty_conversion, f, ensure_ascii=False, indent=4)

    sum_input_output_count = sum(
        len(conversion["conversation"]) for conversion in filter_json_list
    )
    print(
        f"总生成有效 conversion 数据 {len(filter_json_list)} 组,内含 {sum_input_output_count} 条对话,剔除脏对话 {len(dirty_conversion)} 条,保存到 error_{final_save_json_path.name} 中。"
    )


if __name__ == "__main__":
    # Command-line entry point.
    # TODO currently only supports 乐乐喵
    parser = argparse.ArgumentParser(description="Merge Dataset")
    parser.add_argument("data_root", type=str, help="path to response dir")
    # Help text fixed: the original copy-pasted "path to response dir" here.
    parser.add_argument("output_path", type=str, help="path to output json")
    args = parser.parse_args()

    merge_dataset(Path(args.data_root), Path(args.output_path))