File size: 2,934 Bytes
9513395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import evaluate
from rapidfuzz.distance.Levenshtein import distance, normalized_similarity

import config

# BLEU variant from the "saridormi/b_norm" evaluate module, loaded once at import.
BLEU = evaluate.load(
    "saridormi/b_norm",
    cache_dir=config.CACHE_DIR,
)


def bleu_fn(pred, ref, **kwargs):
    """Return the ``b_norm`` BLEU score of *pred*.

    Args:
        pred: Predicted text.
        ref: Single reference text; ignored when ``refs`` is supplied.
        **kwargs: Optional ``refs`` — a sized collection of references. When
            present, *pred* is repeated once per reference and the metric is
            computed over all those (pred, reference) pairs at once.

    Returns:
        The ``"b_norm"`` entry of the metric's output.
    """
    if "refs" not in kwargs:
        predictions, references = [pred], [ref]
    else:
        references = kwargs["refs"]
        predictions = [pred] * len(references)
    result = BLEU.compute(predictions=predictions, references=references)
    return result["b_norm"]


# METEOR metric, loaded once at import time.
METEOR = evaluate.load(
    "meteor",
    cache_dir=config.CACHE_DIR,
)


def meteor_fn(pred, ref, **kwargs):
    """Return the METEOR score of *pred*.

    Args:
        pred: Predicted text.
        ref: Single reference text; ignored when ``refs`` is supplied.
        **kwargs: Optional ``refs`` — a sized collection of references. When
            present, *pred* is paired with each reference and the single
            ``"meteor"`` value the metric reports for those pairs is returned.

    Returns:
        The ``"meteor"`` entry of the metric's output.
    """
    if "refs" not in kwargs:
        return METEOR.compute(predictions=[pred], references=[ref])["meteor"]
    all_refs = kwargs["refs"]
    repeated_pred = [pred] * len(all_refs)
    output = METEOR.compute(predictions=repeated_pred, references=all_refs)
    return output["meteor"]


ROUGE = evaluate.load("rouge", cache_dir=config.CACHE_DIR)


def _rouge_score(pred, ref, rouge_type, kwargs):
    """Return the mean ROUGE F-measure of *pred* for one ROUGE variant.

    Shared implementation behind the ``rouge1_fn``/``rouge2_fn``/``rougeL_fn``
    wrappers, which differ only in the variant they report.

    Args:
        pred: Predicted text.
        ref: Single reference text; ignored when ``kwargs`` contains ``refs``.
        rouge_type: Key of the variant to report (``"rouge1"``, ``"rouge2"``
            or ``"rougeL"``).
        kwargs: Keyword arguments forwarded from the public wrapper. An
            optional ``refs`` entry is a sized collection of references; when
            present, *pred* is scored against each of them and the scores are
            averaged.

    Returns:
        The arithmetic mean of the per-pair F-measures (a single value when
        only *ref* is used).

    Raises:
        ValueError: If ``refs`` is supplied but empty.
    """
    if "refs" in kwargs:
        references = kwargs["refs"]
        n_refs = len(references)
        if n_refs == 0:
            raise ValueError("`refs` must contain at least one reference")
        predictions = [pred] * n_refs
    else:
        predictions, references = [pred], [ref]
    # With use_aggregator=False the metric returns one F-measure per pair, and
    # averaging them here gives an exact, reproducible mean. NOTE(review): the
    # default aggregator in `evaluate`'s rouge may use bootstrap resampling,
    # which makes multi-reference scores depend on the numpy RNG state.
    # Confirm this against the pinned `evaluate` version.
    per_pair = ROUGE.compute(
        predictions=predictions,
        references=references,
        use_aggregator=False,
    )[rouge_type]
    return sum(per_pair) / len(per_pair)


def rouge1_fn(pred, ref, **kwargs):
    """Return the ROUGE-1 F-measure of *pred* (mean over ``refs`` if given)."""
    return _rouge_score(pred, ref, "rouge1", kwargs)


def rouge2_fn(pred, ref, **kwargs):
    """Return the ROUGE-2 F-measure of *pred* (mean over ``refs`` if given)."""
    return _rouge_score(pred, ref, "rouge2", kwargs)


def rougeL_fn(pred, ref, **kwargs):
    """Return the ROUGE-L F-measure of *pred* (mean over ``refs`` if given)."""
    return _rouge_score(pred, ref, "rougeL", kwargs)


# BERTScore metric, loaded once at import time.
BERTSCORE = evaluate.load(
    "bertscore",
    cache_dir=config.CACHE_DIR,
)


def bertscore_fn(pred, ref, **kwargs):
    """Return the BERTScore F1 of *pred* using ``distilbert-base-uncased``.

    Args:
        pred: Predicted text.
        ref: Single reference text; ignored when ``refs`` is supplied.
        **kwargs: Optional ``refs``. When present, it is passed as one nested
            reference list for the single prediction, not expanded into
            separate (pred, reference) pairs.

    Returns:
        The first (and only) element of the metric's ``"f1"`` output.
    """
    reference_entry = kwargs["refs"] if "refs" in kwargs else ref
    output = BERTSCORE.compute(
        predictions=[pred],
        references=[reference_entry],
        model_type="distilbert-base-uncased",
    )
    return output["f1"][0]


# Use the same cache directory as every other metric in this module. The
# original load omitted it and so fell back to the library default location.
CHRF = evaluate.load("chrf", cache_dir=config.CACHE_DIR)


def chrf_fn(pred, ref, **kwargs):
    """Return the chrF score of *pred*.

    Args:
        pred: Predicted text.
        ref: Single reference text; ignored when ``refs`` is supplied.
        **kwargs: Optional ``refs``. When present, the whole collection is used
            as the reference set of the single prediction.

    Returns:
        The ``"score"`` entry of the metric's output.
    """
    # The metric is called with one reference list per prediction, so a lone
    # ``ref`` is wrapped into a one-element list.
    reference_list = kwargs["refs"] if "refs" in kwargs else [ref]
    return CHRF.compute(predictions=[pred], references=[reference_list])["score"]


def edit_distance_fn(pred, ref, **kwargs):
    """Return the Levenshtein distance between *pred* and its reference(s).

    Args:
        pred: Predicted string.
        ref: Single reference string; ignored when ``refs`` is supplied.
        **kwargs: Optional ``refs`` — an iterable of reference strings. When
            present, the distance to each reference is computed and their
            arithmetic mean is returned.

    Returns:
        ``distance(pred, ref)``, or the mean of ``distance(pred, r)`` over
        ``refs``.

    Raises:
        ValueError: If ``refs`` is supplied but yields no references.
    """
    if "refs" in kwargs:
        # Use a distinct loop name so the ``ref`` parameter is not shadowed.
        scores = [distance(pred, other) for other in kwargs["refs"]]
        if not scores:
            # A mean over zero references is undefined. Fail with a clear
            # message instead of an opaque ZeroDivisionError.
            raise ValueError("`refs` must contain at least one reference")
        return sum(scores) / len(scores)
    return distance(pred, ref)


def edit_distance_norm_fn(pred, ref, **kwargs):
    """Return the normalized Levenshtein similarity of *pred* to its reference(s).

    Despite the function name, this returns rapidfuzz's
    ``normalized_similarity``, not a distance.

    Args:
        pred: Predicted string.
        ref: Single reference string; ignored when ``refs`` is supplied.
        **kwargs: Optional ``refs`` — an iterable of reference strings. When
            present, the similarity to each reference is computed and their
            arithmetic mean is returned.

    Returns:
        ``normalized_similarity(pred, ref)``, or the mean of
        ``normalized_similarity(pred, r)`` over ``refs``.

    Raises:
        ValueError: If ``refs`` is supplied but yields no references.
    """
    if "refs" in kwargs:
        # Use a distinct loop name so the ``ref`` parameter is not shadowed.
        scores = [normalized_similarity(pred, other) for other in kwargs["refs"]]
        if not scores:
            # A mean over zero references is undefined. Fail with a clear
            # message instead of an opaque ZeroDivisionError.
            raise ValueError("`refs` must contain at least one reference")
        return sum(scores) / len(scores)
    return normalized_similarity(pred, ref)


# Registry of metric name -> scoring function. Every function accepts
# ``(pred, ref, **kwargs)`` with an optional ``refs`` keyword. Entry order is
# kept as written, in case consumers iterate the registry.
AGGR_METRICS = dict(
    editdist=edit_distance_fn,
    editsim=edit_distance_norm_fn,
    bleu=bleu_fn,
    meteor=meteor_fn,
    rouge1=rouge1_fn,
    rouge2=rouge2_fn,
    rougeL=rougeL_fn,
    bertscore=bertscore_fn,
    chrF=chrf_fn,
)


# Smaller registry holding only the raw edit-distance metric.
REL_METRICS = dict(editdist=edit_distance_fn)