Spaces:
Runtime error
Runtime error
import evaluate | |
from rapidfuzz.distance.Levenshtein import distance, normalized_similarity | |
import config | |
# Handle for the "saridormi/b_norm" metric, created once when this module is
# imported and used by bleu_fn. Metric files are cached under config.CACHE_DIR.
BLEU = evaluate.load(
    "saridormi/b_norm",
    cache_dir=config.CACHE_DIR,
)
def bleu_fn(pred, ref, **kwargs):
    """Score ``pred`` with the module-level ``BLEU`` metric.

    Args:
        pred: Prediction to score.
        ref: Single reference; only used when no ``refs`` keyword is given.
        **kwargs: When a ``refs`` entry is present, ``pred`` is repeated once
            per entry of ``refs`` and all pairs are scored in one
            ``BLEU.compute`` call. Any other keyword is ignored.

    Returns:
        The ``"b_norm"`` entry of the ``BLEU.compute`` result.
    """
    if "refs" not in kwargs:
        result = BLEU.compute(predictions=[pred], references=[ref])
        return result["b_norm"]

    # Pair the same prediction with every supplied reference.
    predictions = [pred] * len(kwargs["refs"])
    result = BLEU.compute(predictions=predictions, references=kwargs["refs"])
    return result["b_norm"]
# Handle for the "meteor" metric, created once when this module is imported
# and used by meteor_fn. Metric files are cached under config.CACHE_DIR.
METEOR = evaluate.load(
    "meteor",
    cache_dir=config.CACHE_DIR,
)
def meteor_fn(pred, ref, **kwargs):
    """Score ``pred`` with the module-level ``METEOR`` metric.

    Args:
        pred: Prediction to score.
        ref: Single reference; only used when no ``refs`` keyword is given.
        **kwargs: When a ``refs`` entry is present, ``pred`` is repeated once
            per entry of ``refs`` and all pairs are scored in one
            ``METEOR.compute`` call. Any other keyword is ignored.

    Returns:
        The ``"meteor"`` entry of the ``METEOR.compute`` result.
    """
    if "refs" not in kwargs:
        result = METEOR.compute(predictions=[pred], references=[ref])
        return result["meteor"]

    # Pair the same prediction with every supplied reference.
    predictions = [pred] * len(kwargs["refs"])
    result = METEOR.compute(predictions=predictions, references=kwargs["refs"])
    return result["meteor"]
# Handle for the "rouge" metric, created once when this module is imported and
# shared by rouge1_fn, rouge2_fn and rougeL_fn. Cached under config.CACHE_DIR.
ROUGE = evaluate.load(
    "rouge",
    cache_dir=config.CACHE_DIR,
)
def rouge1_fn(pred, ref, **kwargs):
    """Score ``pred`` with the module-level ``ROUGE`` metric and pick ``"rouge1"``.

    Args:
        pred: Prediction to score.
        ref: Single reference; only used when no ``refs`` keyword is given.
        **kwargs: When a ``refs`` entry is present, ``pred`` is repeated once
            per entry of ``refs`` and all pairs are scored in one
            ``ROUGE.compute`` call. Any other keyword is ignored.

    Returns:
        The ``"rouge1"`` entry of the ``ROUGE.compute`` result.
    """
    if "refs" not in kwargs:
        result = ROUGE.compute(predictions=[pred], references=[ref])
        return result["rouge1"]

    # Pair the same prediction with every supplied reference.
    predictions = [pred] * len(kwargs["refs"])
    result = ROUGE.compute(predictions=predictions, references=kwargs["refs"])
    return result["rouge1"]
def rouge2_fn(pred, ref, **kwargs):
    """Score ``pred`` with the module-level ``ROUGE`` metric and pick ``"rouge2"``.

    Args:
        pred: Prediction to score.
        ref: Single reference; only used when no ``refs`` keyword is given.
        **kwargs: When a ``refs`` entry is present, ``pred`` is repeated once
            per entry of ``refs`` and all pairs are scored in one
            ``ROUGE.compute`` call. Any other keyword is ignored.

    Returns:
        The ``"rouge2"`` entry of the ``ROUGE.compute`` result.
    """
    if "refs" not in kwargs:
        result = ROUGE.compute(predictions=[pred], references=[ref])
        return result["rouge2"]

    # Pair the same prediction with every supplied reference.
    predictions = [pred] * len(kwargs["refs"])
    result = ROUGE.compute(predictions=predictions, references=kwargs["refs"])
    return result["rouge2"]
def rougeL_fn(pred, ref, **kwargs):
    """Score ``pred`` with the module-level ``ROUGE`` metric and pick ``"rougeL"``.

    Args:
        pred: Prediction to score.
        ref: Single reference; only used when no ``refs`` keyword is given.
        **kwargs: When a ``refs`` entry is present, ``pred`` is repeated once
            per entry of ``refs`` and all pairs are scored in one
            ``ROUGE.compute`` call. Any other keyword is ignored.

    Returns:
        The ``"rougeL"`` entry of the ``ROUGE.compute`` result.
    """
    if "refs" not in kwargs:
        result = ROUGE.compute(predictions=[pred], references=[ref])
        return result["rougeL"]

    # Pair the same prediction with every supplied reference.
    predictions = [pred] * len(kwargs["refs"])
    result = ROUGE.compute(predictions=predictions, references=kwargs["refs"])
    return result["rougeL"]
# Handle for the "bertscore" metric, created once when this module is imported
# and used by bertscore_fn (which selects the model on every compute call).
# Metric files are cached under config.CACHE_DIR.
BERTSCORE = evaluate.load(
    "bertscore",
    cache_dir=config.CACHE_DIR,
)
def bertscore_fn(pred, ref, **kwargs):
    """Score ``pred`` with the module-level ``BERTSCORE`` metric.

    The computation always uses ``model_type="distilbert-base-uncased"``.

    Args:
        pred: Prediction to score.
        ref: Single reference; only used when no ``refs`` keyword is given.
        **kwargs: When a ``refs`` entry is present, the whole ``refs`` value is
            passed as the reference entry for the single prediction (it is
            not expanded into one prediction per reference). Any other
            keyword is ignored.

    Returns:
        The first element of the ``"f1"`` entry of the ``BERTSCORE.compute``
        result.
    """
    if "refs" in kwargs:
        # One prediction, with ``refs`` as its (nested) reference entry.
        references = [kwargs["refs"]]
    else:
        references = [ref]

    result = BERTSCORE.compute(
        predictions=[pred],
        references=references,
        model_type="distilbert-base-uncased",
    )
    return result["f1"][0]
# Handle for the "chrf" metric, created once when this module is imported and
# used by chrf_fn. It is given the same cache directory as the other metric
# loaders in this module so that every metric caches under config.CACHE_DIR
# instead of this one falling back to the library's default location.
CHRF = evaluate.load("chrf", cache_dir=config.CACHE_DIR)
def chrf_fn(pred, ref, **kwargs):
    """Score ``pred`` with the module-level ``CHRF`` metric.

    Args:
        pred: Prediction to score.
        ref: Single reference; only used when no ``refs`` keyword is given,
            in which case it is wrapped as a one-element reference list.
        **kwargs: When a ``refs`` entry is present, the whole ``refs`` value is
            used as the reference list for the single prediction. Any other
            keyword is ignored.

    Returns:
        The ``"score"`` entry of the ``CHRF.compute`` result.
    """
    # The references argument is always nested: one list of references for
    # the one prediction.
    refs_for_pred = kwargs["refs"] if "refs" in kwargs else [ref]
    result = CHRF.compute(predictions=[pred], references=[refs_for_pred])
    return result["score"]
def edit_distance_fn(pred, ref, **kwargs):
    """Return the Levenshtein ``distance`` between ``pred`` and the reference(s).

    Args:
        pred: Prediction to compare.
        ref: Single reference; only used when no ``refs`` keyword is given.
        **kwargs: When a ``refs`` entry is present, the distance is computed
            against every entry of ``refs`` and the arithmetic mean is
            returned. Any other keyword is ignored.

    Returns:
        ``distance(pred, ref)`` for a single reference, otherwise the mean of
        the per-reference distances.

    Raises:
        ZeroDivisionError: If ``refs`` is given but empty.
    """
    if "refs" not in kwargs:
        return distance(pred, ref)

    per_reference = [distance(pred, candidate) for candidate in kwargs["refs"]]
    return sum(per_reference) / len(per_reference)
def edit_distance_norm_fn(pred, ref, **kwargs):
    """Return the Levenshtein ``normalized_similarity`` of ``pred`` and the reference(s).

    Despite the function name, the value comes from ``normalized_similarity``
    (not a distance).

    Args:
        pred: Prediction to compare.
        ref: Single reference; only used when no ``refs`` keyword is given.
        **kwargs: When a ``refs`` entry is present, the similarity is computed
            against every entry of ``refs`` and the arithmetic mean is
            returned. Any other keyword is ignored.

    Returns:
        ``normalized_similarity(pred, ref)`` for a single reference, otherwise
        the mean of the per-reference similarities.

    Raises:
        ZeroDivisionError: If ``refs`` is given but empty.
    """
    if "refs" not in kwargs:
        return normalized_similarity(pred, ref)

    per_reference = [normalized_similarity(pred, candidate) for candidate in kwargs["refs"]]
    return sum(per_reference) / len(per_reference)
# Registry mapping a metric's display name to its scoring function. Every
# function shares the ``(pred, ref, **kwargs)`` signature.
# NOTE(review): "AGGR" presumably stands for "aggregated" - confirm with callers.
AGGR_METRICS = {
    # Character-level edit metrics.
    "editdist": edit_distance_fn,
    "editsim": edit_distance_norm_fn,
    # Metrics computed through the `evaluate` library.
    "bleu": bleu_fn,
    "meteor": meteor_fn,
    "rouge1": rouge1_fn,
    "rouge2": rouge2_fn,
    "rougeL": rougeL_fn,
    "bertscore": bertscore_fn,
    "chrF": chrf_fn,
}
# Second metric registry (display name -> scoring function); it currently
# exposes only the raw edit distance.
# NOTE(review): "REL" presumably stands for "relative"/"related" - confirm with callers.
REL_METRICS = {
    "editdist": edit_distance_fn,
}