|
#!/bin/bash |
|
export LOGLEVEL=INFO

# Root directory for run outputs; per-job metrics land in
# $output_dir/<job_name>/metrics.
output_dir=output

# Defaults applied when the matching --option=value flag is not given
# (and the variable is not already set to a non-empty value).
default_step=20
default_bs=50
default_sample_nums=30000
default_sampling_algo="dpm-solver"
default_json_file="data/test/PG-eval-data/MJHQ-30K/meta_data.json"
default_add_label=''
# Read when resolving exist_time_prefix; must be defined so the fallback is
# explicit (and safe under `set -u`). Empty unless --exist_time_prefix is given.
default_exist_time_prefix=''
default_img_size=512
default_fid_suffix_label=''

# Forwarded to the metric scripts as --log_fid / --log_clip_score /
# --log_image_reward / --log_dpg.
default_log_fid=false
default_log_clip_score=false
default_log_image_reward=false
default_log_dpg=false
|
|
|
|
|
# Positional arguments:
#   $1 - config file (forwarded to the inference script)
#   $2 - model-paths file (forwarded to the inference script; its location is
#        also used to derive the job name)
# Each variable is assigned only when the argument is non-empty; otherwise it
# keeps whatever value it already had.
if [[ -n "$1" ]]; then config_file=$1; fi
if [[ -n "$2" ]]; then model_paths_file=$2; fi
|
|
|
# Parse "--name=value" options into same-named shell variables
# (e.g. --np=4 sets np=4; everything after the first '=' is the value).
# Arguments without a leading "--" (the positional config / model-paths files)
# are skipped. Unrecognized options, or options missing "=value", are reported
# on stderr rather than silently dropped, so typos are visible. Later
# occurrences of an option override earlier ones.
parse_args() {
  local opt name
  for opt in "$@"; do
    case $opt in
      --*=*) ;;
      --*)
        printf 'warning: ignoring option without =value: %s\n' "$opt" >&2
        continue
        ;;
      *) continue ;;
    esac
    name=${opt%%=*}
    name=${name#--}
    case $name in
      np | step | bs | sample_nums | sampling_algo | json_file | \
        exist_time_prefix | img_size | dataset | cfg_scale | \
        fid_suffix_label | add_label | log_fid | log_clip_score | \
        log_image_reward | log_dpg | inference | fid | clipscore | \
        imagereward | dpg | output_dir | auto_ckpt | auto_ckpt_interval | \
        tracker_pattern | ablation_key | ablation_selections | \
        inference_script)
        # Allowlisted names only; printf -v assigns the global of that name.
        printf -v "$name" '%s' "${opt#*=}"
        ;;
      *)
        printf 'warning: ignoring unknown option: %s\n' "$opt" >&2
        ;;
    esac
  done
}

parse_args "$@"
|
|
|
# Resolve every setting to its final value. A non-empty value already present
# (from a --flag or the environment) wins; otherwise the default after ':='
# is assigned. `: "${v:=x}"` is equivalent to `v=${v:-x}`.

# Which pipeline stages to run.
: "${inference:=true}"
: "${fid:=true}"
: "${clipscore:=true}"
: "${imagereward:=false}"
: "${dpg:=false}"

# Inference / sampling settings.
: "${np:=8}"
: "${step:=$default_step}"
: "${bs:=$default_bs}"
: "${dataset:=custom}"
: "${cfg_scale:=4.5}"
: "${sample_nums:=$default_sample_nums}"
: "${sampling_algo:=$default_sampling_algo}"
: "${json_file:=$default_json_file}"
: "${exist_time_prefix:=$default_exist_time_prefix}"
: "${add_label:=$default_add_label}"
: "${ablation_key:=}"
: "${ablation_selections:=}"

# Metric settings.
: "${img_size:=$default_img_size}"
: "${fid_suffix_label:=$default_fid_suffix_label}"
: "${tracker_pattern:=epoch_step}"
: "${log_fid:=$default_log_fid}"
: "${log_clip_score:=$default_log_clip_score}"
: "${log_image_reward:=$default_log_image_reward}"
: "${log_dpg:=$default_log_dpg}"

# Checkpoint auto-collection.
: "${auto_ckpt:=false}"
: "${auto_ckpt_interval:=0}"
|
|
|
# Job name = name of the directory two levels above the model-paths file,
# e.g. <root>/<job>/sub/paths.txt -> <job>. Every expansion is quoted so paths
# containing spaces or glob characters are handled correctly.
job_name=$(basename "$(dirname "$(dirname "$model_paths_file")")")

# Per-job directory for metric outputs and cached image-path lists.
metric_dir=$output_dir/$job_name/metrics

if [[ ! -d "$metric_dir" ]]; then
  echo "Creating directory: $metric_dir"
  mkdir -p -- "$metric_dir"
fi
|
|
|
|
|
# When --auto_ckpt=true, run collect_pth_path.sh on this job's checkpoints
# directory with the configured interval (what it writes is defined by that
# script). Arguments are quoted so directories with spaces survive.
if [[ "$auto_ckpt" = true ]]; then
  bash scripts/collect_pth_path.sh "$output_dir/$job_name/checkpoints" "$auto_ckpt_interval"
fi
|
|
|
|
|
# Choose the image-path list handed to the metric scripts: the model-paths
# file when it exists, otherwise the per-dataset cache under the job's
# metrics directory.
if [[ -e "$model_paths_file" ]]; then
  cache_file_path=$model_paths_file
else
  cache_file_path=$output_dir/$job_name/metrics/cached_img_paths_${dataset}.txt
  echo "$model_paths_file not exists, use default image path: $cache_file_path"
fi
|
|
|
#######################################
# Merge per-process path lists into one file, then delete them.
# Arguments:
#   $1 - directory containing tmp_<tag>*.txt files
#   $2 - dataset tag used in the tmp file names
#   $3 - output file (truncated first)
# Outputs: writes each tmp list to $3 in glob order, each followed by an
#          empty line; afterwards removes every $1/tmp_<tag>* entry.
#######################################
_collect_tmp_img_paths() {
  local tmp_dir=$1 tag=$2 out_file=$3 part
  : >"$out_file"
  for part in "$tmp_dir"/tmp_"$tag"*.txt; do
    # An unmatched glob stays literal; the -f test skips it.
    [[ -f "$part" ]] || continue
    cat -- "$part" >>"$out_file"
    echo "" >>"$out_file"
  done
  rm -rf -- "$tmp_dir"/tmp_"$tag"*
}

if [[ "$inference" = true ]]; then
  inference_script=${inference_script:-"scripts/inference.py"}
  cache_file_path=$output_dir/$job_name/metrics/cached_img_paths_${dataset}.txt

  # Clear leftovers from an earlier run; -f keeps this quiet when none exist.
  rm -rf -- "$metric_dir"/tmp_"${dataset}"*

  # Build argv as an array (not a string re-parsed by `bash -c`) so values
  # containing spaces or glob characters reach the script intact.
  inference_cmd=(
    bash scripts/infer_run_inference.sh "$config_file" "$model_paths_file"
    --np="$np"
    --inference_script="$inference_script" --step="$step" --bs="$bs"
    --sample_nums="$sample_nums" --json_file="$json_file"
    --add_label="$add_label"
    --exist_time_prefix="$exist_time_prefix" --if_save_dirname=true
    --sampling_algo="$sampling_algo"
    --cfg_scale="$cfg_scale" --dataset="$dataset"
    --ablation_key="$ablation_key" --ablation_selections="$ablation_selections"
  )
  # Log the command in copy-pasteable, shell-quoted form.
  printf '%q ' "${inference_cmd[@]}"
  printf '\n'
  if ! "${inference_cmd[@]}"; then
    # Keep going (as before) but make the failure visible.
    echo "warning: inference command failed; collecting any paths that were written" >&2
  fi

  _collect_tmp_img_paths "$metric_dir" "$dataset" "$cache_file_path"
fi
|
|
|
# Inputs shared by all metric scripts.
img_path=${output_dir}/${job_name}/vis
exp_paths_file=${cache_file_path}

# Print a command in shell-quoted (copy-pasteable) form, then run it with its
# arguments passed through verbatim -- no re-parsing by a second shell, so
# values containing spaces or glob characters are not split or expanded.
# Returns the command's exit status.
_run_logged() {
  printf '%q ' "$@"
  printf '\n'
  "$@"
}

if [[ "$fid" = true ]]; then
  _run_logged bash tools/metrics/compute_fid_embedding.sh "$img_path" "$exp_paths_file" \
    --sample_nums="$sample_nums" --img_size="$img_size" --suffix_label="$fid_suffix_label" \
    --log_fid="$log_fid" --tracker_pattern="$tracker_pattern"
fi

if [[ "$clipscore" = true ]]; then
  _run_logged bash tools/metrics/compute_clipscore.sh "$img_path" "$exp_paths_file" \
    --sample_nums="$sample_nums" --suffix_label="$fid_suffix_label" \
    --log_clip_score="$log_clip_score" --tracker_pattern="$tracker_pattern"
fi

if [[ "$imagereward" = true ]]; then
  _run_logged bash tools/metrics/compute_imagereward.sh "$img_path" "$exp_paths_file" \
    --sample_nums="$sample_nums" --suffix_label="$fid_suffix_label" \
    --log_image_reward="$log_image_reward" --tracker_pattern="$tracker_pattern"
fi

if [[ "$dpg" = true ]]; then
  _run_logged bash tools/metrics/compute_dpg.sh "$img_path" "$exp_paths_file" \
    --sample_nums="$sample_nums" --img_size="$img_size" --suffix_label="$fid_suffix_label" \
    --log_dpg="$log_dpg" --tracker_pattern="$tracker_pattern"
fi
|
|