Spaces:
Sleeping
Sleeping
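'''
Rank images by CCIP similarity against the svjack/Genshin-Impact-Illustration dataset.

Usage:
    python score_script.py INPUT_PATH OUTPUT_DIR [--model MODEL]

Example -- score the images found under the current directory and write one
JSON file per image into ./three_output:
    python score_script.py . three_output
'''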
import os | |
import json | |
from tqdm import tqdm | |
from PIL import Image | |
from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold | |
from datasets import load_dataset | |
import pathlib | |
import argparse | |
# Load the reference dataset once, at import time, and index its images by the
# dataset's "name" column. If two rows share a name, the later row wins.
Genshin_Impact_Illustration_ds = load_dataset("svjack/Genshin-Impact-Illustration")["train"]
ds_size = len(Genshin_Impact_Illustration_ds)
name_image_dict = {
    record["name"]: record["image"]
    for record in Genshin_Impact_Illustration_ds
}
def _compare_with_dataset(imagex, model_name):
    """Compare one image against every reference image in ``name_image_dict``.

    Args:
        imagex: The query image, passed straight to ``ccip_difference``.
        model_name: CCIP model name, used both to compute each difference and
            to look up the matching "Same"/"Not Same" threshold.

    Returns:
        A list with one dict per reference image, holding the keys
        ``"difference"`` (float), ``"prediction"`` (``'Same'`` when the
        difference is at or below the model's threshold, else ``'Not Same'``)
        and ``"name"``. The list is sorted by ascending difference, so the
        closest matches come first.
    """
    threshold = ccip_default_threshold(model_name)
    results = []
    # NOTE(review): every call re-processes all reference images; if the ccip
    # module exposes feature extraction, caching reference features would
    # avoid that repeated work -- confirm the available API first.
    for name, imagey in name_image_dict.items():
        # The threshold above belongs to ``model_name``, so the difference must
        # be computed with the same model rather than the ccip default.
        # NOTE(review): assumes ccip_difference accepts a ``model=`` keyword
        # (as in dghs-imgutils) -- confirm against the local ccip module.
        # float() keeps the value JSON-serializable even if a numpy scalar
        # is returned.
        diff = float(ccip_difference(imagex, imagey, model=model_name))
        results.append({
            "difference": diff,
            "prediction": 'Same' if diff <= threshold else 'Not Same',
            "name": name,
        })
    # Sort by difference, smallest first; the sort is stable, so ties keep
    # the dataset's order.
    results.sort(key=lambda item: item["difference"])
    return results
def process_image(image_path, model_name, output_dir):
    """Score one image against the dataset and save the ranking as JSON.

    Args:
        image_path: Path (str or path-like) of the image to score.
        model_name: CCIP model name forwarded to ``_compare_with_dataset``.
        output_dir: Directory that receives ``<image stem>.json``. It must
            already exist (it is not created here). An existing file with
            the same name is overwritten.
    """
    # The context manager closes the image's file handle once scoring is done,
    # instead of leaving it to garbage collection.
    with Image.open(image_path) as image:
        results = _compare_with_dataset(image, model_name)

    # Name the output after the input file's stem, e.g. "foo.png" -> "foo.json".
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    output_file = os.path.join(output_dir, f"{image_name}.json")

    # Save the results to a JSON file.
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)
def main():
    """Parse CLI arguments and score one image, or every image under a directory.

    One ``<image stem>.json`` file per input image is written to
    ``output_dir`` through ``process_image``.

    Raises:
        ValueError: If ``input_path`` is neither an existing file nor an
            existing directory.
    """
    import sys
    from collections import Counter

    parser = argparse.ArgumentParser(description="Compare images with a dataset and save results as JSON.")
    parser.add_argument("input_path", type=str, help="Path to the input image or directory containing images.")
    parser.add_argument("output_dir", type=str, help="Directory to save the output JSON files.")
    parser.add_argument("--model", type=str, default=_DEFAULT_MODEL_NAMES, choices=_VALID_MODEL_NAMES, help="Model to use for comparison.")
    args = parser.parse_args()

    # Resolve the inputs before creating the output directory, so an invalid
    # input path does not leave an empty output directory behind.
    if os.path.isfile(args.input_path):
        image_paths = [args.input_path]
    elif os.path.isdir(args.input_path):
        # Suffixes are compared case-insensitively so "photo.JPG" is found too;
        # ".jpeg" and ".webp" are accepted alongside ".png"/".jpg". is_file()
        # skips directories whose names happen to end in an image suffix.
        image_suffixes = {".png", ".jpg", ".jpeg", ".webp"}
        # sorted() makes the processing order independent of the filesystem.
        image_paths = sorted(
            path
            for path in pathlib.Path(args.input_path).rglob("*")
            if path.is_file() and path.suffix.lower() in image_suffixes
        )
    else:
        raise ValueError("Input path must be a valid file or directory.")

    if not image_paths:
        print(f"No images found under {args.input_path}.", file=sys.stderr)

    # Make sure the output directory exists.
    os.makedirs(args.output_dir, exist_ok=True)

    # Each output is named after its image's file stem, so inputs sharing a
    # stem (e.g. "a/1.png" and "b/1.png") overwrite each other's results.
    # Warn instead of failing silently.
    stem_counts = Counter(
        os.path.splitext(os.path.basename(path))[0] for path in image_paths
    )
    clashing = sorted(stem for stem, count in stem_counts.items() if count > 1)
    if clashing:
        print(
            "Warning: these file stems occur more than once; only the last "
            f"image processed for each is kept in {args.output_dir}: "
            + ", ".join(clashing),
            file=sys.stderr,
        )

    # Process each image.
    for image_path in tqdm(image_paths, desc="Processing images"):
        process_image(image_path, args.model, args.output_dir)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()