lingo judge init.
- app.py +6 -0
- constants.py +23 -0
- judge.py +77 -0
- lingo_judge_metric.py +58 -0
app.py
ADDED
@@ -0,0 +1,6 @@
+import evaluate
+from evaluate.utils import launch_gradio_widget
+
+
+module = evaluate.load("maysonma/lingo_judge_metric")
+launch_gradio_widget(module)
constants.py
ADDED
@@ -0,0 +1,23 @@
+# Source: https://github.com/wayveai/LingoQA/blob/main/benchmark/constants.py
+
+"""
+LingoQA datasets are stored in Google Cloud.
+This file provides the download link for the datasets, as well as reference keys for the data.
+"""
+from enum import Enum
+
+LINGOQA_TEST = "https://drive.usercontent.google.com/u/1/uc?id=1I8u6uYysQUstoVYZapyRQkXmOwr-AG3d&export=download"
+
+LINGO_JUDGE = "wayveai/Lingo-Judge"
+
+class Keys(str, Enum):
+    question_id = "question_id"
+    segment_id = "segment_id"
+    question = "question"
+    answer = "answer"
+    references = "references"
+    prediction = "prediction"
+    max_score = "max_score"
+    score = "score"
+    probability = "probability"
+    correct = "correct"
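Because Keys subclasses both str and Enum, its members double as plain column-name strings; a minimal illustration (the example row below is made up):

from constants import Keys

assert Keys.question == "question"          # str mixin: members compare equal to their values
assert Keys.max_score.value == "max_score"

# hypothetical results row keyed by the enum members
row = {Keys.question: "Is it safe to overtake?", Keys.score: 0.87}
print(row[Keys.score])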
judge.py
ADDED
@@ -0,0 +1,77 @@
+# Source: https://github.com/wayveai/LingoQA/blob/main/benchmark/judge.py
+
+import torch
+from torch import nn
+
+from typing import List
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from constants import LINGO_JUDGE
+
+
+class LingoJudge(nn.Module):
+    """
+    LingoJudge is a textual classifier that evaluates the truthfulness of an answer on the LingoQA benchmark.
+    """
+
+    def __init__(self, pretrained_model=LINGO_JUDGE):
+        super().__init__()
+        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True)
+        self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_model).eval()
+
+    @torch.inference_mode()
+    def forward(self, question: str, references: List[str], prediction: str):
+        """
+        Inference function for the textual classifier with multiple reference answers.
+        Args:
+            question: Input question.
+            references: List of reference answers.
+            prediction: Model prediction.
+        Output:
+            scores: Scores indicating truthfulness.
+        """
+        device = next(self.parameters()).device
+        texts = [
+            f"{self.tokenizer.cls_token}\nQuestion: {question}\nAnswer: {a_gt}\nStudent: {prediction}"
+            for a_gt in references
+        ]
+
+        encoded_input = self.tokenizer(
+            texts, return_tensors="pt", padding=True, truncation=True, max_length=128
+        )
+        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
+        output = self.model(**encoded_input)
+        scores = output.logits.squeeze(-1)
+        return scores
+
+    def compute(self, questions: List[str], references: List[List[str]], predictions: List[str]):
+        """
+        Compute the maximum classifier score. For multiple reference answers, the highest score is selected.
+        Args:
+            questions: List of input questions.
+            references: List of lists, with multiple references per question supported.
+            predictions: List of model predictions.
+        Output:
+            scores: Scores indicating truthfulness.
+        """
+        max_scores = []
+
+        for index, question in enumerate(questions):
+            references_preprocessed = [
+                self.preprocess(reference) for reference in references[index]
+            ]
+            prediction_preprocessed = self.preprocess(predictions[index])
+            scores = self.forward(question, references_preprocessed, prediction_preprocessed)
+            max_score = [max(scores)]
+            max_scores.extend(max_score)
+        return torch.Tensor(max_scores)
+
+    def preprocess(self, string: str):
+        """
+        Preprocessing function for consistency.
+        Args:
+            string: Input string to be processed.
+        Output:
+            output: Processed string, lower-cased with leading and trailing whitespace removed.
+        """
+        output = str(string).lower().lstrip().rstrip()
+        return output
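For reference, the classifier can also be exercised outside the evaluate wrapper. A minimal sketch, assuming the wayveai/Lingo-Judge checkpoint can be downloaded from the Hub; the question, reference, and prediction strings are made up for illustration:

import torch
from judge import LingoJudge

# Instantiating LingoJudge downloads the wayveai/Lingo-Judge checkpoint.
judge = LingoJudge().eval().to("cuda" if torch.cuda.is_available() else "cpu")

scores = judge.compute(
    questions=["What color is the traffic light?"],             # illustrative only
    references=[["The traffic light is red.", "It is red."]],   # several references per question
    predictions=["The light ahead is red."],
)
print(scores)  # 1-D tensor: best-reference score for each question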
lingo_judge_metric.py
ADDED
@@ -0,0 +1,58 @@
+# Inspired by: https://huggingface.co/spaces/evaluate-metric/bleurt/blob/main/bleurt.py
+
+import datasets
+import evaluate
+import torch
+
+from judge import LingoJudge
+
+_CITATION = """
+@article{marcu2023lingoqa,
+    title={LingoQA: Video Question Answering for Autonomous Driving},
+    author={Ana-Maria Marcu and Long Chen and Jan Hünermann and Alice Karnsund and Benoit Hanotte and Prajwal Chidananda and Saurabh Nair and Vijay Badrinarayanan and Alex Kendall and Jamie Shotton and Oleg Sinavski},
+    journal={arXiv preprint arXiv:2312.14115},
+    year={2023},
+}
+"""
+
+_DESCRIPTION = """
+Lingo-Judge is an evaluation metric that aligns closely with human judgement on the LingoQA evaluation suite.
+
+See the project's README at https://github.com/wayveai/LingoQA for more information.
+"""
+
+_KWARGS_DESCRIPTION = """
+Lingo-Judge Score.
+
+Args:
+    `questions` (list of str): Input questions.
+    `predictions` (list of str): Model predictions.
+    `references` (list of list of str): Multiple references per question.
+
+Returns:
+    `scores` (list of float): Scores indicating truthfulness.
+
+"""
+
+
+class LingoJudgeMetric(evaluate.Metric):
+    def _info(self):
+        return evaluate.MetricInfo(
+            description=_DESCRIPTION,
+            citation=_CITATION,
+            inputs_description=_KWARGS_DESCRIPTION,
+            features=datasets.Features(
+                {
+                    "questions": datasets.Value("string"),
+                    "predictions": datasets.Value("string"),
+                    "references": datasets.Sequence(datasets.Value("string")),
+                }
+            ),
+        )
+
+    def _download_and_prepare(self, dl_manager):
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.scorer = LingoJudge().eval().to(self.device)
+
+    def _compute(self, questions, predictions, references):
+        return self.scorer.compute(questions, references, predictions)
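Mirroring app.py, the metric could also be called programmatically instead of through the Gradio widget; a hedged sketch with made-up inputs (not drawn from the LingoQA dataset):

import evaluate

module = evaluate.load("maysonma/lingo_judge_metric")
scores = module.compute(
    questions=["Is there a pedestrian crossing the road?"],
    predictions=["Yes, a pedestrian is crossing from the left."],
    references=[["Yes, there is a pedestrian crossing."]],
)
print(scores)  # torch.Tensor with one truthfulness score per question

Note that _compute receives (questions, predictions, references) from evaluate but forwards them as (questions, references, predictions) to match LingoJudge.compute's signature.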