"""
Segmentation metric code adapted from XView2: A Strong Baseline
Xview2_Strong_Baseline/legacy/xview2_metrics.py
Xview2_Strong_Baseline/legacy/create_masks.py
"""
# add python path
# import sys
# import os
# sys.path.append('/deep/u/emily712/aicc-win24-geo-vlm/videollava/')
import json
import string
import numpy as np
import cv2
from collections import defaultdict, Counter
from nltk.tokenize import word_tokenize
from shapely.geometry import Polygon
from pathlib import Path
from sklearn.metrics import f1_score
from tqdm import tqdm

def compute_tp_fn_fp(pred: np.ndarray, targ: np.ndarray, c: int):
    """
    Computes the number of TPs, FNs, and FPs between a prediction (pred) and a target (targ) for the desired class (c).
    Args:
        pred (np.ndarray): prediction
        targ (np.ndarray): target
        c (int): positive class
    Returns:
        list: [TP, FN, FP] counts
    """
    TP = np.logical_and(pred == c, targ == c).sum()
    FN = np.logical_and(pred != c, targ == c).sum()
    FP = np.logical_and(pred == c, targ != c).sum()
    return [TP, FN, FP]
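
# Example (hypothetical inputs): for pred = np.array([1, 1, 0]) and targ = np.array([1, 0, 0]),
# compute_tp_fn_fp(pred, targ, c=1) returns [1, 0, 1]: one true positive, no false negatives,
# and one false positive for class 1.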

def accuracy_precision_recall(answer_path, dataset, ignore_punctuation=True, verbose=True):
    # answer_path can be a results dict or a path to a JSON file of answers
    if isinstance(answer_path, dict):
        results = answer_path
    else:
        with open(answer_path) as json_data:
            results = json.load(json_data)
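    # Per-task counters: total examples, exact-match correct answers ("true positives"),
    # whether the task is a binary yes/no task, and binary false positives / false negatives.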
    task_total = defaultdict(int)
    task_tp = defaultdict(int)
    binary_classification = defaultdict(bool)
    binary_fp = defaultdict(int)
    binary_fn = defaultdict(int)
    # Dictionary of dictionaries. Key: task. Value: {class: count}
    ground_truths = defaultdict(dict)
    values = defaultdict(list)
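    # Task name prefixes that are evaluated; results from any other task are skipped.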
    accepted_tasks = [
        "temporal_question_answering",
        "region_based_question_answering",
        "temporal_region_based_question_answering",
        "question_answering",
        "temporal_referring_expression",
        "rural_urban",
        "comp",
        "presence",
        "count",
        "change_to_what",
        "smallest_change",
        "change_or_not",
        "change_ratio",
        "largest_change",
        "change_ratio_types",
        "increase_or_not",
        "decrease_or_not"
    ]
    for result in results.values():
        if "task" in result and not any(result["task"].startswith(task) for task in accepted_tasks):
            continue
        # Normalize predicted and ground truth strings
        result["predicted"] = result["predicted"].lower()
        result["ground_truth"] = result["ground_truth"].lower()
        if ignore_punctuation:
            result["predicted"] = ''.join(ch for ch in result["predicted"] if ch not in string.punctuation)
            result["ground_truth"] = ''.join(ch for ch in result["ground_truth"] if ch not in string.punctuation)
        if verbose:
            values["predicted"].append(result["predicted"])
            values["ground_truth"].append(result["ground_truth"])
            values["correct_incorrect"].append("Correct" if result["predicted"] == result["ground_truth"] else "Incorrect")
        if "task" not in result:
            result["task"] = dataset
        # True positive
        if result["predicted"] == result["ground_truth"]:
            task_tp[result["task"]] += 1
        task_total[result["task"]] += 1
        # If binary classification (yes/no question), calculate precision and recall metrics
        binary_classification[result["task"]] = binary_classification[result["task"]] or (result["ground_truth"] in ["yes", "no"])
        if binary_classification[result["task"]]:
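            # "yes" is treated as the positive class: a non-"no" prediction on a "no" ground truth
            # counts as a false positive, and a non-"yes" prediction on a "yes" ground truth counts
            # as a false negative.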
if result["predicted"] != "no" and result["ground_truth"] == "no":
binary_fp[result["task"]] += 1
if result["predicted"] != "yes" and result["ground_truth"] == "yes":
binary_fn[result["task"]] += 1
# Update ground truth counts for the task
task = result["task"]
class_label = result["ground_truth"]
ground_truths[task][class_label] = ground_truths[task].get(class_label, 0) + 1
    # Print tab-separated values
    if verbose:
        max_len = max((len(v) for v in values["predicted"] + values["ground_truth"]), default=10) + 5
        print("Predicted" + " " * (max_len - 9) + "\tGround Truth" + " " * (max_len - 12) + "\tCorrect/Incorrect")
        for i in range(len(values["predicted"])):
            print(values["predicted"][i] + " " * (max_len - len(values["predicted"][i])) + "\t" + values["ground_truth"][i] + " " * (max_len - len(values["ground_truth"][i])) + "\t" + values["correct_incorrect"][i])
    total_tp = 0
    total_predictions = 0
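    # Report per-task accuracy, precision/recall for binary (yes/no) tasks, and the
    # majority-class baseline, then accumulate counts for the overall accuracy.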
    for task in task_total:
        acc_string = "Accuracy"
        if ignore_punctuation:
            acc_string += " (ignoring punctuation)"
        print(f"{acc_string} for {task}: {round(task_tp[task] / task_total[task] * 100, 2)}%")
        if binary_classification[task]:
            if (task_tp[task] + binary_fp[task]) > 0:
                print(f"Precision (ignoring punctuation) for {task}: {round(task_tp[task] / (task_tp[task] + binary_fp[task]) * 100, 2)}%")
            if (task_tp[task] + binary_fn[task]) > 0:
                print(f"Recall (ignoring punctuation) for {task}: {round(task_tp[task] / (task_tp[task] + binary_fn[task]) * 100, 2)}%")
        majority_class = max(ground_truths[task], key=ground_truths[task].get)
        majority_class_percentage = (ground_truths[task][majority_class] / task_total[task]) * 100
        print(f"Majority class for {task}: {majority_class}, Percentage: {round(majority_class_percentage, 2)}%")
        total_tp += task_tp[task]
        total_predictions += task_total[task]
    if total_predictions == 0:
        print("No predictions made.")
    else:
        total_accuracy = (total_tp / total_predictions) * 100
        print(f"Overall Accuracy: {round(total_accuracy, 3)}%")

# For testing accuracy/precision/recall on a particular answers file without running inference
if __name__ == '__main__':
    root_dir = '/deep/u/jirvin16/aicc/aicc-win24-geo-vlm/videollava/scripts/geovlm/eval/QFabric/answers/'
    answer_path = root_dir + "video-llava-7b-8bit-lora-final-no-metadata-zero-gc-acc8-freq-no-geochat-checkpoint-8000_qfabric_test_aux_data_test_prompt_strategy_interleave_chronological_prefix_True_load_8bit_True_load_4bit_False_delete_system_prompt_False.json"
    accuracy_precision_recall(answer_path, dataset="qfabric", ignore_punctuation=True, verbose=False)