Spaces:
Runtime error
Runtime error
Modulated code and cleaned up main function (#2)
Browse files- Modulated code and cleaned up main function (3c364f94bfd0b663012b9f3431886a4606174aeb)
Co-authored-by: Nate Demchak <[email protected]>
- update_evaluate.py +153 -0
update_evaluate.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Optional, Sequence, Union
|
2 |
+
|
3 |
+
import logging
|
4 |
+
from collections import defaultdict
|
5 |
+
from inspect import signature
|
6 |
+
|
7 |
+
from ..llm.client import LLMClient, get_default_client
|
8 |
+
from ..utils.analytics_collector import analytics
|
9 |
+
from .knowledge_base import KnowledgeBase
|
10 |
+
from .metrics import CorrectnessMetric, Metric
|
11 |
+
from .question_generators.utils import maybe_tqdm
|
12 |
+
from .recommendation import get_rag_recommendation
|
13 |
+
from .report import RAGReport
|
14 |
+
from .testset import QATestset
|
15 |
+
from .testset_generation import generate_testset
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
ANSWER_FN_HISTORY_PARAM = "history"
|
20 |
+
|
21 |
+
|
22 |
+
def evaluate(
|
23 |
+
answer_fn: Union[Callable, Sequence[str]],
|
24 |
+
testset: Optional[QATestset] = None,
|
25 |
+
knowledge_base: Optional[KnowledgeBase] = None,
|
26 |
+
llm_client: Optional[LLMClient] = None,
|
27 |
+
agent_description: str = "This agent is a chatbot that answers question from users.",
|
28 |
+
metrics: Optional[Sequence[Callable]] = None,
|
29 |
+
) -> RAGReport:
|
30 |
+
"""Evaluate an agent by comparing its answers on a QATestset.
|
31 |
+
|
32 |
+
Parameters
|
33 |
+
----------
|
34 |
+
answers_fn : Union[Callable, Sequence[str]]
|
35 |
+
The prediction function of the agent to evaluate or a list of precalculated answers on the testset.
|
36 |
+
testset : QATestset, optional
|
37 |
+
The test set to evaluate the agent on. If not provided, a knowledge base must be provided and a default testset will be created from the knowledge base.
|
38 |
+
Note that if the answers_fn is a list of answers, the testset is required.
|
39 |
+
knowledge_base : KnowledgeBase, optional
|
40 |
+
The knowledge base of the agent to evaluate. If not provided, a testset must be provided.
|
41 |
+
llm_client : LLMClient, optional
|
42 |
+
The LLM client to use for the evaluation. If not provided, a default openai client will be used.
|
43 |
+
agent_description : str, optional
|
44 |
+
Description of the agent to be tested.
|
45 |
+
metrics : Optional[Sequence[Callable]], optional
|
46 |
+
Metrics to compute on the test set.
|
47 |
+
|
48 |
+
Returns
|
49 |
+
-------
|
50 |
+
RAGReport
|
51 |
+
The report of the evaluation.
|
52 |
+
"""
|
53 |
+
|
54 |
+
validate_inputs(answer_fn, knowledge_base, testset)
|
55 |
+
testset = testset or generate_testset(knowledge_base)
|
56 |
+
answers = retrieve_answers(answer_fn, testset)
|
57 |
+
llm_client = llm_client or get_default_client()
|
58 |
+
metrics = get_metrics(metrics, llm_client, agent_description)
|
59 |
+
metrics_results = compute_metrics(metrics, testset, answers)
|
60 |
+
report = get_report(testset, answers, metrics_results, knowledge_base)
|
61 |
+
add_recommendation(report, llm_client, metrics)
|
62 |
+
track_analytics(report, testset, knowledge_base, agent_description, metrics)
|
63 |
+
|
64 |
+
return report
|
65 |
+
|
66 |
+
def validate_inputs(answer_fn, knowledge_base, testset):
|
67 |
+
if testset is None:
|
68 |
+
if knowledge_base is None:
|
69 |
+
raise ValueError("At least one of testset or knowledge base must be provided to the evaluate function.")
|
70 |
+
if not isinstance(answer_fn, Sequence):
|
71 |
+
raise ValueError(
|
72 |
+
"If the testset is not provided, the answer_fn must be a list of answers to ensure the matching between questions and answers."
|
73 |
+
)
|
74 |
+
|
75 |
+
testset = generate_testset(knowledge_base)
|
76 |
+
|
77 |
+
# Check basic types, in case the user passed the params in the wrong order
|
78 |
+
if knowledge_base is not None and not isinstance(knowledge_base, KnowledgeBase):
|
79 |
+
raise ValueError(
|
80 |
+
f"knowledge_base must be a KnowledgeBase object (got {type(knowledge_base)} instead). Are you sure you passed the parameters in the right order?"
|
81 |
+
)
|
82 |
+
|
83 |
+
if testset is not None and not isinstance(testset, QATestset):
|
84 |
+
raise ValueError(
|
85 |
+
f"testset must be a QATestset object (got {type(testset)} instead). Are you sure you passed the parameters in the right order?"
|
86 |
+
)
|
87 |
+
|
88 |
+
def retrieve_answers(answer_fn, testset):
|
89 |
+
return answer_fn if isinstance(answer_fn, Sequence) else _compute_answers(answer_fn, testset)
|
90 |
+
|
91 |
+
def get_metrics(metrics, llm_client, agent_description):
|
92 |
+
metrics = list(metrics) if metrics is not None else []
|
93 |
+
if not any(isinstance(metric, CorrectnessMetric) for metric in metrics):
|
94 |
+
# By default only correctness is computed as it is required to build the report
|
95 |
+
metrics.insert(
|
96 |
+
0, CorrectnessMetric(name="correctness", llm_client=llm_client, agent_description=agent_description)
|
97 |
+
)
|
98 |
+
return metrics
|
99 |
+
|
100 |
+
def compute_metrics(metrics, testset, answers):
|
101 |
+
metrics_results = defaultdict(dict)
|
102 |
+
|
103 |
+
for metric in metrics:
|
104 |
+
metric_name = getattr(
|
105 |
+
metric, "name", metric.__class__.__name__ if isinstance(metric, Metric) else metric.__name__
|
106 |
+
)
|
107 |
+
|
108 |
+
for sample, answer in maybe_tqdm(
|
109 |
+
zip(testset.to_pandas().to_records(index=True), answers),
|
110 |
+
desc=f"{metric_name} evaluation",
|
111 |
+
total=len(answers),
|
112 |
+
):
|
113 |
+
metrics_results[sample["id"]].update(metric(sample, answer))
|
114 |
+
return metrics_results
|
115 |
+
|
116 |
+
def get_report(testset, answers, metrics_results, knowledge_base):
|
117 |
+
return RAGReport(testset, answers, metrics_results, knowledge_base)
|
118 |
+
|
119 |
+
def add_recommendation(report, llm_client, metrics):
|
120 |
+
recommendation = get_rag_recommendation(
|
121 |
+
report.topics,
|
122 |
+
report.correctness_by_question_type().to_dict()[metrics[0].name],
|
123 |
+
report.correctness_by_topic().to_dict()[metrics[0].name],
|
124 |
+
llm_client,
|
125 |
+
)
|
126 |
+
report._recommendation = recommendation
|
127 |
+
|
128 |
+
def track_analytics(report, testset, knowledge_base, agent_description, metrics):
|
129 |
+
analytics.track(
|
130 |
+
"raget:evaluation",
|
131 |
+
{
|
132 |
+
"testset_size": len(testset),
|
133 |
+
"knowledge_base_size": len(knowledge_base) if knowledge_base else -1,
|
134 |
+
"agent_description": agent_description,
|
135 |
+
"num_metrics": len(metrics),
|
136 |
+
"correctness": report.correctness,
|
137 |
+
},
|
138 |
+
)
|
139 |
+
|
140 |
+
def _compute_answers(answer_fn, testset):
|
141 |
+
answers = []
|
142 |
+
needs_history = (
|
143 |
+
len(signature(answer_fn).parameters) > 1 and ANSWER_FN_HISTORY_PARAM in signature(answer_fn).parameters
|
144 |
+
)
|
145 |
+
|
146 |
+
for sample in maybe_tqdm(testset.samples, desc="Asking questions to the agent", total=len(testset)):
|
147 |
+
kwargs = {}
|
148 |
+
|
149 |
+
if needs_history:
|
150 |
+
kwargs[ANSWER_FN_HISTORY_PARAM] = sample.conversation_history
|
151 |
+
|
152 |
+
answers.append(answer_fn(sample.question, **kwargs))
|
153 |
+
return answers
|