|
import argparse |
|
import json |
|
from pathlib import Path |
|
import random |
|
|
|
|
|
def gen_self_self_aware_dataset():
    """Build the self-introduction samples of the medical-guide assistant.

    Each canned answer is paired with one greeting / identity question drawn
    with ``random.choice``, giving one single-turn conversation per answer,
    in answer order.

    Returns:
        list[dict]: items shaped like
        ``{"conversation": [{"input": question, "output": answer}]}``.
    """
    # Questions a user may ask to find out who the assistant is.
    identity_questions = (
        "你好",
        "你是谁",
        "你叫什么名字",
        "请做一下自我介绍",
        "介绍下你自己",
    )

    # Fixed self-introductions; every one of them yields exactly one sample.
    introductions = (
        "您好,我是智能医导,随时准备解答您的医疗疑问。",
        "您好,我是智能医导,助您轻松就医。",
        "您好,我是智能医导,提供专业医疗指导。",
        "您好,我是智能医导,解答您的健康疑惑。",
        "您好,我是智能医导,帮助您了解医疗服务。",
        "您好,我是智能医导,您的医疗问题助手。",
        "您好,我是智能医导,助您快速获取医疗信息。",
        "您好,我是智能医导,为您提供医疗解答。",
        "您好,我是智能医导,帮助您理解医疗流程。",
        "您好,我是智能医导,解答您的医疗咨询。",
        "您好,我是智能医导,助您掌握健康知识。",
        "您好,我是智能医导,提供医疗信息查询。",
        "您好,我是智能医导,助您解决就医难题。",
        "您好,我是智能医导,您的私人医疗顾问。",
        "您好,我是智能医导,随时为您提供帮助。",
    )

    return [
        {"conversation": [{"input": random.choice(identity_questions), "output": reply}]}
        for reply in introductions
    ]
|
|
|
|
|
def merge_dataset(save_json_root: Path, final_save_json_path: Path):
    """Merge generated conversation JSON files into a single dataset file.

    Every ``*.json`` file directly under ``save_json_root`` is expected to hold
    a mapping ``product_name -> list of samples``, each sample shaped like
    ``{"conversation": [turn, ...]}``. Turns are reduced to the ``input`` /
    ``output`` / ``system`` keys; turns lacking ``input`` or ``output`` are
    rejected, and samples left without any turn are dropped. The first turn of
    every kept sample receives the fixed system prompt, after which the
    self-awareness samples are appended (those carry no system prompt).

    Files written into the parent directory of ``final_save_json_path``
    (created when missing):

    * ``<sample count>_<file name>`` -- the merged dataset;
    * ``error_<file name>`` -- the rejected items, only when there are any.

    Args:
        save_json_root: Directory holding the generated ``*.json`` files.
        final_save_json_path: Target path; only its parent directory and its
            file name are used.
    """
    system_prompt = "现在你是一位医院大厅里的智能医导小助手,你的名字叫智能医导小助手,你的说话方式是严肃端庄。你能够根据病人的需求提供专业的医疗咨询,并且结合医疗知识解答用户提出的各种健康相关疑问。"

    json_list = []
    # glob() yields files in arbitrary order; sorting keeps the output reproducible.
    for json_path in sorted(save_json_root.glob("*.json")):
        with open(json_path, "r", encoding="utf-8") as f:
            content = json.load(f)
        if not isinstance(content, dict):
            # The files written below are JSON lists; when they share this
            # directory a later run would otherwise crash on ``.items()``.
            print(f"Skip {json_path.name}: top-level JSON is not a product mapping")
            continue
        json_list.append(content)

    filter_json_list = []
    dirty_conversion = []
    for product_map in json_list:
        for product_name, gen_data_list in product_map.items():
            if not isinstance(gen_data_list, list):
                print(f"Got malformed data in {product_name}")
                dirty_conversion.append(gen_data_list)
                continue

            for gen_data in gen_data_list:
                if isinstance(gen_data, dict) and "Error" in gen_data:
                    print(f"Got error data in {product_name}")
                    dirty_conversion.append(gen_data)
                    continue

                kept_turns, rejected = _split_turns(gen_data)
                dirty_conversion.extend(rejected)
                if kept_turns:
                    filter_json_list.append({"conversation": kept_turns})

    # Only non-empty conversations were kept, so index 0 always exists.
    for sample in filter_json_list:
        sample["conversation"][0]["system"] = system_prompt

    filter_json_list += gen_self_self_aware_dataset()

    output_dir = final_save_json_path.parent
    output_dir.mkdir(parents=True, exist_ok=True)

    merged_path = output_dir.joinpath(f"{len(filter_json_list)}_{final_save_json_path.name}")
    with open(merged_path, "w", encoding="utf-8") as f:
        json.dump(filter_json_list, f, ensure_ascii=False, indent=4)

    if dirty_conversion:
        error_path = output_dir.joinpath(f"error_{final_save_json_path.name}")
        with open(error_path, "w", encoding="utf-8") as f:
            json.dump(dirty_conversion, f, ensure_ascii=False, indent=4)

    sum_input_output_count = sum(len(sample["conversation"]) for sample in filter_json_list)
    print(
        f"总生成有效 conversion 数据 {len(filter_json_list)} 组,内含 {sum_input_output_count} 条对话,剔除脏对话 {len(dirty_conversion)} 条,保存到 error_{final_save_json_path.name} 中。"
    )


def _split_turns(gen_data):
    """Return ``(kept_turns, rejected_items)`` for one generated sample."""
    accept_keys = ("input", "output", "system")

    # A sample that is not ``{"conversation": [...]}`` is rejected as a whole
    # instead of aborting the merge with a TypeError / KeyError.
    if not isinstance(gen_data, dict) or not isinstance(gen_data.get("conversation"), list):
        return [], [gen_data]

    kept_turns = []
    rejected = []
    for turn in gen_data["conversation"]:
        if not isinstance(turn, dict):
            rejected.append(turn)
            continue

        turn = {key: value for key, value in turn.items() if key in accept_keys}
        # A usable turn needs both sides of the exchange.
        if "input" in turn and "output" in turn:
            kept_turns.append(turn)
        else:
            rejected.append(turn)

    return kept_turns, rejected
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: merge the generated JSON files found in
    # ``data_root`` and write the merged dataset next to ``output_path``.
    parser = argparse.ArgumentParser(description="Merge Dataset")
    parser.add_argument("data_root", type=Path, help="path to response dir")
    parser.add_argument(
        "output_path",
        type=Path,
        help="path of the merged JSON file (the sample count is prefixed to its file name)",
    )
    args = parser.parse_args()

    merge_dataset(args.data_root, args.output_path)
|
|