wu981526092 Natedem30 commited on
Commit
38f10df
·
verified ·
1 Parent(s): 30db905

Modulated code and cleaned up main function (#2)

Browse files

- Modulated code and cleaned up main function (3c364f94bfd0b663012b9f3431886a4606174aeb)


Co-authored-by: Nate Demchak <[email protected]>

Files changed (1) hide show
  1. update_evaluate.py +153 -0
update_evaluate.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Sequence, Union
2
+
3
+ import logging
4
+ from collections import defaultdict
5
+ from inspect import signature
6
+
7
+ from ..llm.client import LLMClient, get_default_client
8
+ from ..utils.analytics_collector import analytics
9
+ from .knowledge_base import KnowledgeBase
10
+ from .metrics import CorrectnessMetric, Metric
11
+ from .question_generators.utils import maybe_tqdm
12
+ from .recommendation import get_rag_recommendation
13
+ from .report import RAGReport
14
+ from .testset import QATestset
15
+ from .testset_generation import generate_testset
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ ANSWER_FN_HISTORY_PARAM = "history"
20
+
21
+
22
+ def evaluate(
23
+ answer_fn: Union[Callable, Sequence[str]],
24
+ testset: Optional[QATestset] = None,
25
+ knowledge_base: Optional[KnowledgeBase] = None,
26
+ llm_client: Optional[LLMClient] = None,
27
+ agent_description: str = "This agent is a chatbot that answers question from users.",
28
+ metrics: Optional[Sequence[Callable]] = None,
29
+ ) -> RAGReport:
30
+ """Evaluate an agent by comparing its answers on a QATestset.
31
+
32
+ Parameters
33
+ ----------
34
+ answers_fn : Union[Callable, Sequence[str]]
35
+ The prediction function of the agent to evaluate or a list of precalculated answers on the testset.
36
+ testset : QATestset, optional
37
+ The test set to evaluate the agent on. If not provided, a knowledge base must be provided and a default testset will be created from the knowledge base.
38
+ Note that if the answers_fn is a list of answers, the testset is required.
39
+ knowledge_base : KnowledgeBase, optional
40
+ The knowledge base of the agent to evaluate. If not provided, a testset must be provided.
41
+ llm_client : LLMClient, optional
42
+ The LLM client to use for the evaluation. If not provided, a default openai client will be used.
43
+ agent_description : str, optional
44
+ Description of the agent to be tested.
45
+ metrics : Optional[Sequence[Callable]], optional
46
+ Metrics to compute on the test set.
47
+
48
+ Returns
49
+ -------
50
+ RAGReport
51
+ The report of the evaluation.
52
+ """
53
+
54
+ validate_inputs(answer_fn, knowledge_base, testset)
55
+ testset = testset or generate_testset(knowledge_base)
56
+ answers = retrieve_answers(answer_fn, testset)
57
+ llm_client = llm_client or get_default_client()
58
+ metrics = get_metrics(metrics, llm_client, agent_description)
59
+ metrics_results = compute_metrics(metrics, testset, answers)
60
+ report = get_report(testset, answers, metrics_results, knowledge_base)
61
+ add_recommendation(report, llm_client, metrics)
62
+ track_analytics(report, testset, knowledge_base, agent_description, metrics)
63
+
64
+ return report
65
+
66
+ def validate_inputs(answer_fn, knowledge_base, testset):
67
+ if testset is None:
68
+ if knowledge_base is None:
69
+ raise ValueError("At least one of testset or knowledge base must be provided to the evaluate function.")
70
+ if not isinstance(answer_fn, Sequence):
71
+ raise ValueError(
72
+ "If the testset is not provided, the answer_fn must be a list of answers to ensure the matching between questions and answers."
73
+ )
74
+
75
+ testset = generate_testset(knowledge_base)
76
+
77
+ # Check basic types, in case the user passed the params in the wrong order
78
+ if knowledge_base is not None and not isinstance(knowledge_base, KnowledgeBase):
79
+ raise ValueError(
80
+ f"knowledge_base must be a KnowledgeBase object (got {type(knowledge_base)} instead). Are you sure you passed the parameters in the right order?"
81
+ )
82
+
83
+ if testset is not None and not isinstance(testset, QATestset):
84
+ raise ValueError(
85
+ f"testset must be a QATestset object (got {type(testset)} instead). Are you sure you passed the parameters in the right order?"
86
+ )
87
+
88
+ def retrieve_answers(answer_fn, testset):
89
+ return answer_fn if isinstance(answer_fn, Sequence) else _compute_answers(answer_fn, testset)
90
+
91
+ def get_metrics(metrics, llm_client, agent_description):
92
+ metrics = list(metrics) if metrics is not None else []
93
+ if not any(isinstance(metric, CorrectnessMetric) for metric in metrics):
94
+ # By default only correctness is computed as it is required to build the report
95
+ metrics.insert(
96
+ 0, CorrectnessMetric(name="correctness", llm_client=llm_client, agent_description=agent_description)
97
+ )
98
+ return metrics
99
+
100
+ def compute_metrics(metrics, testset, answers):
101
+ metrics_results = defaultdict(dict)
102
+
103
+ for metric in metrics:
104
+ metric_name = getattr(
105
+ metric, "name", metric.__class__.__name__ if isinstance(metric, Metric) else metric.__name__
106
+ )
107
+
108
+ for sample, answer in maybe_tqdm(
109
+ zip(testset.to_pandas().to_records(index=True), answers),
110
+ desc=f"{metric_name} evaluation",
111
+ total=len(answers),
112
+ ):
113
+ metrics_results[sample["id"]].update(metric(sample, answer))
114
+ return metrics_results
115
+
116
+ def get_report(testset, answers, metrics_results, knowledge_base):
117
+ return RAGReport(testset, answers, metrics_results, knowledge_base)
118
+
119
+ def add_recommendation(report, llm_client, metrics):
120
+ recommendation = get_rag_recommendation(
121
+ report.topics,
122
+ report.correctness_by_question_type().to_dict()[metrics[0].name],
123
+ report.correctness_by_topic().to_dict()[metrics[0].name],
124
+ llm_client,
125
+ )
126
+ report._recommendation = recommendation
127
+
128
+ def track_analytics(report, testset, knowledge_base, agent_description, metrics):
129
+ analytics.track(
130
+ "raget:evaluation",
131
+ {
132
+ "testset_size": len(testset),
133
+ "knowledge_base_size": len(knowledge_base) if knowledge_base else -1,
134
+ "agent_description": agent_description,
135
+ "num_metrics": len(metrics),
136
+ "correctness": report.correctness,
137
+ },
138
+ )
139
+
140
+ def _compute_answers(answer_fn, testset):
141
+ answers = []
142
+ needs_history = (
143
+ len(signature(answer_fn).parameters) > 1 and ANSWER_FN_HISTORY_PARAM in signature(answer_fn).parameters
144
+ )
145
+
146
+ for sample in maybe_tqdm(testset.samples, desc="Asking questions to the agent", total=len(testset)):
147
+ kwargs = {}
148
+
149
+ if needs_history:
150
+ kwargs[ANSWER_FN_HISTORY_PARAM] = sample.conversation_history
151
+
152
+ answers.append(answer_fn(sample.question, **kwargs))
153
+ return answers