svjack commited on
Commit
4a255ef
·
verified ·
1 Parent(s): 865038d

Upload score_script.py

Browse files
Files changed (1) hide show
  1. score_script.py +76 -0
score_script.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ python score_script.py . three_output
3
+ '''
4
+
5
+ import os
6
+ import json
7
+ from tqdm import tqdm
8
+ from PIL import Image
9
+ from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold
10
+ from datasets import load_dataset
11
+ import pathlib
12
+ import argparse
13
+
14
+ # 加载数据集
15
+ Genshin_Impact_Illustration_ds = load_dataset("svjack/Genshin-Impact-Illustration")["train"]
16
+ ds_size = len(Genshin_Impact_Illustration_ds)
17
+ name_image_dict = {}
18
+ for i in range(ds_size):
19
+ row_dict = Genshin_Impact_Illustration_ds[i]
20
+ name_image_dict[row_dict["name"]] = row_dict["image"]
21
+
22
+ def _compare_with_dataset(imagex, model_name):
23
+ threshold = ccip_default_threshold(model_name)
24
+ results = []
25
+
26
+ for name, imagey in name_image_dict.items():
27
+ diff = ccip_difference(imagex, imagey)
28
+ result = {
29
+ "difference": diff,
30
+ "prediction": 'Same' if diff <= threshold else 'Not Same',
31
+ "name": name
32
+ }
33
+ results.append(result)
34
+
35
+ # 按照 diff 值进行排序
36
+ results.sort(key=lambda x: x["difference"])
37
+
38
+ return results
39
+
40
+ def process_image(image_path, model_name, output_dir):
41
+ image = Image.open(image_path)
42
+ results = _compare_with_dataset(image, model_name)
43
+
44
+ # 生成输出文件名
45
+ image_name = os.path.splitext(os.path.basename(image_path))[0]
46
+ output_file = os.path.join(output_dir, f"{image_name}.json")
47
+
48
+ # 保存结果到 JSON 文件
49
+ with open(output_file, 'w') as f:
50
+ json.dump(results, f, indent=4)
51
+
52
+ def main():
53
+ parser = argparse.ArgumentParser(description="Compare images with a dataset and save results as JSON.")
54
+ parser.add_argument("input_path", type=str, help="Path to the input image or directory containing images.")
55
+ parser.add_argument("output_dir", type=str, help="Directory to save the output JSON files.")
56
+ parser.add_argument("--model", type=str, default=_DEFAULT_MODEL_NAMES, choices=_VALID_MODEL_NAMES, help="Model to use for comparison.")
57
+
58
+ args = parser.parse_args()
59
+
60
+ # 确保输出目录存在
61
+ os.makedirs(args.output_dir, exist_ok=True)
62
+
63
+ # 判断输入路径是文件还是目录
64
+ if os.path.isfile(args.input_path):
65
+ image_paths = [args.input_path]
66
+ elif os.path.isdir(args.input_path):
67
+ image_paths = list(pathlib.Path(args.input_path).rglob("*.png")) + list(pathlib.Path(args.input_path).rglob("*.jpg"))
68
+ else:
69
+ raise ValueError("Input path must be a valid file or directory.")
70
+
71
+ # 处理每个图片
72
+ for image_path in tqdm(image_paths, desc="Processing images"):
73
+ process_image(image_path, args.model, args.output_dir)
74
+
75
+ if __name__ == '__main__':
76
+ main()