svjack commited on
Commit
6d2843d
·
verified ·
1 Parent(s): 4a255ef

Upload score_tag_script.py

Browse files
Files changed (1) hide show
  1. score_tag_script.py +88 -0
score_tag_script.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ python score_tag_script.py . three_tag_output
3
+ '''
4
+
5
+ import os
6
+ import json
7
+ from tqdm import tqdm
8
+ from PIL import Image
9
+ from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold
10
+ from datasets import load_dataset
11
+ import pathlib
12
+ import argparse
13
+ from imgutils.tagging import get_wd14_tags # 导入 get_wd14_tags 函数
14
+
15
+ # 加载数据集
16
+ Genshin_Impact_Illustration_ds = load_dataset("svjack/Genshin-Impact-Illustration")["train"]
17
+ ds_size = len(Genshin_Impact_Illustration_ds)
18
+ name_image_dict = {}
19
+ for i in range(ds_size):
20
+ row_dict = Genshin_Impact_Illustration_ds[i]
21
+ name_image_dict[row_dict["name"]] = row_dict["image"]
22
+
23
+ def _compare_with_dataset(imagex, model_name):
24
+ threshold = ccip_default_threshold(model_name)
25
+ results = []
26
+
27
+ for name, imagey in name_image_dict.items():
28
+ diff = ccip_difference(imagex, imagey)
29
+ result = {
30
+ "difference": diff,
31
+ "prediction": 'Same' if diff <= threshold else 'Not Same',
32
+ "name": name
33
+ }
34
+ results.append(result)
35
+
36
+ # 按照 diff 值进行排序
37
+ results.sort(key=lambda x: x["difference"])
38
+
39
+ return results
40
+
41
+ def process_image(image_path, model_name, output_dir):
42
+ image = Image.open(image_path)
43
+ results = _compare_with_dataset(image, model_name)
44
+
45
+ # 获取 WD14 标签
46
+ rating, features, chars = get_wd14_tags(image_path)
47
+
48
+ # 构建最终的输出字典
49
+ output_data = {
50
+ "results": results, # 保存比较结果
51
+ "rating": rating, # 保存 WD14 的 rating
52
+ "features": features, # 保存 WD14 的 features
53
+ "characters": chars # 保存 WD14 的 characters
54
+ }
55
+
56
+ # 生成输出文件名
57
+ image_name = os.path.splitext(os.path.basename(image_path))[0]
58
+ output_file = os.path.join(output_dir, f"{image_name}.json")
59
+
60
+ # 保存结果到 JSON 文件
61
+ with open(output_file, 'w') as f:
62
+ json.dump(output_data, f, indent=4)
63
+
64
+ def main():
65
+ parser = argparse.ArgumentParser(description="Compare images with a dataset and save results as JSON.")
66
+ parser.add_argument("input_path", type=str, help="Path to the input image or directory containing images.")
67
+ parser.add_argument("output_dir", type=str, help="Directory to save the output JSON files.")
68
+ parser.add_argument("--model", type=str, default=_DEFAULT_MODEL_NAMES, choices=_VALID_MODEL_NAMES, help="Model to use for comparison.")
69
+
70
+ args = parser.parse_args()
71
+
72
+ # 确保输出目录存在
73
+ os.makedirs(args.output_dir, exist_ok=True)
74
+
75
+ # 判断输入路径是文件还是目录
76
+ if os.path.isfile(args.input_path):
77
+ image_paths = [args.input_path]
78
+ elif os.path.isdir(args.input_path):
79
+ image_paths = list(pathlib.Path(args.input_path).rglob("*.png")) + list(pathlib.Path(args.input_path).rglob("*.jpg"))
80
+ else:
81
+ raise ValueError("Input path must be a valid file or directory.")
82
+
83
+ # 处理每个图片
84
+ for image_path in tqdm(image_paths, desc="Processing images"):
85
+ process_image(image_path, args.model, args.output_dir)
86
+
87
+ if __name__ == '__main__':
88
+ main()