Petr Tsvetkov
release
9513395
import evaluate
from rapidfuzz.distance.Levenshtein import distance, normalized_similarity
import config
# Handle for the ``saridormi/b_norm`` metric; loaded once when the module is imported.
BLEU = evaluate.load("saridormi/b_norm", cache_dir=config.CACHE_DIR)


def bleu_fn(pred, ref, **kwargs):
    """Score ``pred`` with the ``saridormi/b_norm`` metric.

    Args:
        pred: Prediction to score.
        ref: Single reference; ignored when ``refs`` is passed.
        **kwargs: May carry ``refs``, a sized collection of references. When
            present, ``pred`` is repeated once per reference and all pairs are
            scored in a single ``compute`` call.

    Returns:
        The value stored under ``"b_norm"`` in the metric result.
    """
    if "refs" not in kwargs:
        predictions, references = [pred], [ref]
    else:
        references = kwargs["refs"]
        # One copy of the prediction per reference so the two lists line up.
        predictions = [pred] * len(references)
    result = BLEU.compute(predictions=predictions, references=references)
    return result["b_norm"]
# Handle for the ``meteor`` metric; loaded once when the module is imported.
METEOR = evaluate.load("meteor", cache_dir=config.CACHE_DIR)


def meteor_fn(pred, ref, **kwargs):
    """Score ``pred`` with the ``meteor`` metric.

    Args:
        pred: Prediction to score.
        ref: Single reference; ignored when ``refs`` is passed.
        **kwargs: May carry ``refs``, a sized collection of references that
            ``pred`` is paired with one-to-one.

    Returns:
        The value stored under ``"meteor"`` in the metric result.
    """
    references = kwargs["refs"] if "refs" in kwargs else [ref]
    # In the single-reference case this is simply ``[pred]``.
    predictions = [pred] * len(references)
    outcome = METEOR.compute(predictions=predictions, references=references)
    return outcome["meteor"]
# Handle for the ``rouge`` metric, shared by the ROUGE helper functions in this
# module; loaded once when the module is imported.
ROUGE = evaluate.load("rouge", cache_dir=config.CACHE_DIR)


def rouge1_fn(pred, ref, **kwargs):
    """Return the ``rouge1`` entry of the ROUGE result for ``pred``.

    Args:
        pred: Prediction to score.
        ref: Single reference; ignored when ``refs`` is passed.
        **kwargs: May carry ``refs``, a sized collection of references that
            ``pred`` is paired with one-to-one.

    Returns:
        The value stored under ``"rouge1"`` in the metric result.
    """
    if "refs" in kwargs:
        references = kwargs["refs"]
        predictions = [pred] * len(references)
    else:
        references = [ref]
        predictions = [pred]
    scores = ROUGE.compute(predictions=predictions, references=references)
    return scores["rouge1"]
def rouge2_fn(pred, ref, **kwargs):
    """Return the ``rouge2`` entry of the ROUGE result for ``pred``.

    Args:
        pred: Prediction to score.
        ref: Single reference; ignored when ``refs`` is passed.
        **kwargs: May carry ``refs``, a sized collection of references that
            ``pred`` is paired with one-to-one.

    Returns:
        The value stored under ``"rouge2"`` in the metric result.
    """
    references = kwargs["refs"] if "refs" in kwargs else [ref]
    # Repeat the prediction so it is compared against each reference.
    predictions = [pred] * len(references)
    scores = ROUGE.compute(predictions=predictions, references=references)
    return scores["rouge2"]
def rougeL_fn(pred, ref, **kwargs):
    """Return the ``rougeL`` entry of the ROUGE result for ``pred``.

    Args:
        pred: Prediction to score.
        ref: Single reference; ignored when ``refs`` is passed.
        **kwargs: May carry ``refs``, a sized collection of references that
            ``pred`` is paired with one-to-one.

    Returns:
        The value stored under ``"rougeL"`` in the metric result.
    """
    if "refs" not in kwargs:
        compute_args = {"predictions": [pred], "references": [ref]}
    else:
        many_refs = kwargs["refs"]
        compute_args = {
            "predictions": [pred] * len(many_refs),
            "references": many_refs,
        }
    return ROUGE.compute(**compute_args)["rougeL"]
# Handle for the ``bertscore`` metric; loaded once when the module is imported.
BERTSCORE = evaluate.load("bertscore", cache_dir=config.CACHE_DIR)


def bertscore_fn(pred, ref, **kwargs):
    """Return the BERTScore ``f1`` value for ``pred``.

    Unlike the pairwise helpers in this module, the prediction is passed once
    and ``refs`` (when given) is passed as that single prediction's reference
    group.

    Args:
        pred: Prediction to score.
        ref: Single reference; ignored when ``refs`` is passed.
        **kwargs: May carry ``refs``, the collection of references for ``pred``.

    Returns:
        The first element of the ``"f1"`` list in the metric result.
    """
    references = [kwargs["refs"]] if "refs" in kwargs else [ref]
    f1_values = BERTSCORE.compute(
        predictions=[pred],
        references=references,
        model_type="distilbert-base-uncased",
    )["f1"]
    # Only one prediction is submitted, so its score is at index 0.
    return f1_values[0]
# Handle for the ``chrf`` metric; loaded once when the module is imported.
# Uses the same cache directory as every other metric loaded in this module.
CHRF = evaluate.load("chrf", cache_dir=config.CACHE_DIR)


def chrf_fn(pred, ref, **kwargs):
    """Return the chrF ``score`` value for ``pred``.

    The prediction is passed once, together with a single group of references:
    either ``refs`` (when given) or a one-element group holding ``ref``.

    Args:
        pred: Prediction to score.
        ref: Single reference; ignored when ``refs`` is passed.
        **kwargs: May carry ``refs``, the collection of references for ``pred``.

    Returns:
        The value stored under ``"score"`` in the metric result.
    """
    reference_group = kwargs["refs"] if "refs" in kwargs else [ref]
    result = CHRF.compute(predictions=[pred], references=[reference_group])
    return result["score"]
def edit_distance_fn(pred, ref, **kwargs):
    """Return the Levenshtein distance between ``pred`` and the reference(s).

    Args:
        pred: Prediction to compare.
        ref: Single reference; ignored when ``refs`` is passed.
        **kwargs: May carry ``refs``, an iterable of references.

    Returns:
        ``distance(pred, ref)`` when ``refs`` is absent; otherwise the
        arithmetic mean of the distances from ``pred`` to each entry of
        ``refs``.

    Raises:
        ZeroDivisionError: If ``refs`` is passed but empty.
    """
    if "refs" not in kwargs:
        return distance(pred, ref)
    per_reference = [distance(pred, candidate) for candidate in kwargs["refs"]]
    return sum(per_reference) / len(per_reference)
def edit_distance_norm_fn(pred, ref, **kwargs):
    """Return the normalized Levenshtein similarity of ``pred`` to the reference(s).

    Args:
        pred: Prediction to compare.
        ref: Single reference; ignored when ``refs`` is passed.
        **kwargs: May carry ``refs``, an iterable of references.

    Returns:
        ``normalized_similarity(pred, ref)`` when ``refs`` is absent; otherwise
        the arithmetic mean of the similarities between ``pred`` and each entry
        of ``refs``.

    Raises:
        ZeroDivisionError: If ``refs`` is passed but empty.
    """
    if "refs" not in kwargs:
        return normalized_similarity(pred, ref)
    similarities = [
        normalized_similarity(pred, candidate) for candidate in kwargs["refs"]
    ]
    return sum(similarities) / len(similarities)
# Registry mapping a metric's display name to its scoring function. Every
# function here takes ``(pred, ref, **kwargs)`` and honours an optional ``refs``
# keyword for multiple references.
# NOTE(review): "AGGR" presumably stands for "aggregated"; how callers consume
# this mapping is not visible in this file — confirm.
AGGR_METRICS = {
    "editdist": edit_distance_fn,
    "editsim": edit_distance_norm_fn,
    "bleu": bleu_fn,
    "meteor": meteor_fn,
    "rouge1": rouge1_fn,
    "rouge2": rouge2_fn,
    "rougeL": rougeL_fn,
    "bertscore": bertscore_fn,
    "chrF": chrf_fn,
}
# Second metric registry with the same name -> function layout; currently it
# exposes only the raw edit distance.
# NOTE(review): "REL" presumably stands for "relative"; the consumer is not
# visible in this file — confirm.
REL_METRICS = {
    "editdist": edit_distance_fn,
}