gsarti commited on
Commit
cf3d1b1
Β·
1 Parent(s): d6505da
Files changed (5) hide show
  1. app.py +309 -29
  2. contents.py +53 -0
  3. requirements.txt +3 -1
  4. style.py +19 -0
  5. utils.py +110 -0
app.py CHANGED
@@ -1,39 +1,319 @@
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
 
3
  from inseq.commands.attribute_context.attribute_context import (
4
  AttributeContextArgs,
5
  attribute_context,
6
- visualize_attribute_context,
7
  )
8
 
9
 
10
- def run_pecore(input_current_text, input_context_text):
11
- lm_rag_prompting_example = AttributeContextArgs(
12
- model_name_or_path="gsarti/cora_mgen",
13
- input_context_text=input_context_text,
14
- input_current_text=f"query: {input_current_text}",
15
- output_template="{current}",
16
- input_template="{current} passage: {context} answer:",
17
- attributed_fn="contrast_prob_diff",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  show_viz=False,
19
- context_sensitivity_std_threshold=0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  )
21
- out = attribute_context(lm_rag_prompting_example)
22
- html = visualize_attribute_context(out, return_html=True)
23
- return html
24
-
25
-
26
- demo = gr.Interface(
27
- fn=run_pecore,
28
- inputs=["text", "text"],
29
- outputs="html",
30
- title="πŸ‘ Plausibility Evaluation of Context Reliance (PECoRe) πŸ‘",
31
- description="""Given a query and a context passed as inputs to a LM, PECoRe will identify which tokens in the generated response were dependant on context, and match them with context tokens contributing to their prediction. For more information, check out our <a href="https://openreview.net/forum?id=XTHfNGI3zT" target='_blank'>ICLR 2024 paper</a>.""",
32
- examples=[
33
- [
34
- "When was Banff National Park established?",
35
- "Banff National Park is Canada's oldest national park, established in 1885 as Rocky Mountains Park. Located in Alberta's Rocky Mountains, 110–180 kilometres (68–112 mi) west of Calgary, Banff encompasses 6,641 square kilometres (2,564 sq mi) of mountainous terrain.",
36
- ]
37
- ],
38
- )
39
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
  import gradio as gr
5
+ import spaces
6
+ from contents import (
7
+ citation,
8
+ description,
9
+ examples,
10
+ how_it_works,
11
+ how_to_use,
12
+ subtitle,
13
+ title,
14
+ )
15
+ from gradio_highlightedtextbox import HighlightedTextbox
16
+ from style import custom_css
17
+ from utils import get_tuples_from_output
18
 
19
+ from inseq import list_feature_attribution_methods, list_step_functions
20
  from inseq.commands.attribute_context.attribute_context import (
21
  AttributeContextArgs,
22
  attribute_context,
 
23
  )
24
 
25
 
26
+ @spaces.GPU()
27
+ def pecore(
28
+ input_current_text: str,
29
+ input_context_text: str,
30
+ output_current_text: str,
31
+ output_context_text: str,
32
+ model_name_or_path: str,
33
+ attribution_method: str,
34
+ attributed_fn: str | None,
35
+ context_sensitivity_metric: str,
36
+ context_sensitivity_std_threshold: float,
37
+ context_sensitivity_topk: int,
38
+ attribution_std_threshold: float,
39
+ attribution_topk: int,
40
+ input_template: str,
41
+ input_current_text_template: str,
42
+ output_template: str,
43
+ special_tokens_to_keep: str | list[str] | None,
44
+ model_kwargs: str,
45
+ tokenizer_kwargs: str,
46
+ generation_kwargs: str,
47
+ attribution_kwargs: str,
48
+ ):
49
+ formatted_input_current_text = input_current_text_template.format(
50
+ current=input_current_text
51
+ )
52
+ pecore_args = AttributeContextArgs(
53
+ show_intermediate_outputs=False,
54
+ save_path=os.path.join(os.path.dirname(__file__), "outputs/output.json"),
55
+ add_output_info=True,
56
+ viz_path=os.path.join(os.path.dirname(__file__), "outputs/output.html"),
57
  show_viz=False,
58
+ model_name_or_path=model_name_or_path,
59
+ attribution_method=attribution_method,
60
+ attributed_fn=attributed_fn,
61
+ attribution_selectors=None,
62
+ attribution_aggregators=None,
63
+ normalize_attributions=True,
64
+ model_kwargs=json.loads(model_kwargs),
65
+ tokenizer_kwargs=json.loads(tokenizer_kwargs),
66
+ generation_kwargs=json.loads(generation_kwargs),
67
+ attribution_kwargs=json.loads(attribution_kwargs),
68
+ context_sensitivity_metric=context_sensitivity_metric,
69
+ align_output_context_auto=False,
70
+ prompt_user_for_contextless_output_next_tokens=False,
71
+ special_tokens_to_keep=special_tokens_to_keep,
72
+ context_sensitivity_std_threshold=context_sensitivity_std_threshold,
73
+ context_sensitivity_topk=context_sensitivity_topk
74
+ if context_sensitivity_topk > 0
75
+ else None,
76
+ attribution_std_threshold=attribution_std_threshold,
77
+ attribution_topk=attribution_topk if attribution_topk > 0 else None,
78
+ input_current_text=formatted_input_current_text,
79
+ input_context_text=input_context_text if input_context_text else None,
80
+ input_template=input_template,
81
+ output_current_text=output_current_text if output_current_text else None,
82
+ output_context_text=output_context_text if output_context_text else None,
83
+ output_template=output_template,
84
  )
85
+ out = attribute_context(pecore_args)
86
+ return get_tuples_from_output(out), gr.Button(visible=True), gr.Button(visible=True)
87
+
88
+
89
+ with gr.Blocks(css=custom_css) as demo:
90
+ gr.Markdown(title)
91
+ gr.Markdown(subtitle)
92
+ gr.Markdown(description)
93
+ with gr.Tab("πŸ‘ Attributing Context"):
94
+ with gr.Row():
95
+ with gr.Column():
96
+ input_current_text = gr.Textbox(
97
+ label="Input query", placeholder="Your input query..."
98
+ )
99
+ input_context_text = gr.Textbox(
100
+ label="Input context", lines=4, placeholder="Your input context..."
101
+ )
102
+ attribute_input_button = gr.Button("Submit", variant="primary")
103
+ with gr.Column():
104
+ pecore_output_highlights = HighlightedTextbox(
105
+ value=[
106
+ ("This output will contain ", None),
107
+ ("context sensitive", "Context sensitive"),
108
+ (" generated tokens and ", None),
109
+ ("influential context", "Influential context"),
110
+ (" tokens.", None),
111
+ ],
112
+ color_map={
113
+ "Context sensitive": "green",
114
+ "Influential context": "blue",
115
+ },
116
+ show_legend=True,
117
+ label="PECoRe Output",
118
+ combine_adjacent=True,
119
+ interactive=False,
120
+ )
121
+ with gr.Row(equal_height=True):
122
+ download_output_file_button = gr.Button(
123
+ "⇓ Download output",
124
+ visible=False,
125
+ link=os.path.join(
126
+ os.path.dirname(__file__), "/file=outputs/output.json"
127
+ ),
128
+ )
129
+ download_output_html_button = gr.Button(
130
+ "πŸ” Download HTML",
131
+ visible=False,
132
+ link=os.path.join(
133
+ os.path.dirname(__file__), "/file=outputs/output.html"
134
+ ),
135
+ )
136
+
137
+ attribute_input_examples = gr.Examples(
138
+ examples,
139
+ inputs=[input_current_text, input_context_text],
140
+ outputs=pecore_output_highlights,
141
+ )
142
+ with gr.Tab("βš™οΈ Parameters"):
143
+ gr.Markdown("## βš™οΈ PECoRe Parameters")
144
+ with gr.Row(equal_height=True):
145
+ model_name_or_path = gr.Textbox(
146
+ value="gsarti/cora_mgen",
147
+ label="Model",
148
+ info="Hugging Face Hub identifier of the model to analyze with PECoRe.",
149
+ interactive=True,
150
+ )
151
+ context_sensitivity_metric = gr.Dropdown(
152
+ value="kl_divergence",
153
+ label="Context sensitivity metric",
154
+ info="Metric to use to measure context sensitivity of generated tokens.",
155
+ choices=list_step_functions(),
156
+ interactive=True,
157
+ )
158
+ attribution_method = gr.Dropdown(
159
+ value="saliency",
160
+ label="Attribution method",
161
+ info="Attribution method identifier to identify relevant context tokens.",
162
+ choices=list_feature_attribution_methods(),
163
+ interactive=True,
164
+ )
165
+ attributed_fn = gr.Dropdown(
166
+ value="contrast_prob_diff",
167
+ label="Attributed function",
168
+ info="Function of model logits to use as target for the attribution method.",
169
+ choices=list_step_functions(),
170
+ interactive=True,
171
+ )
172
+ gr.Markdown("#### Results Selection Parameters")
173
+ with gr.Row(equal_height=True):
174
+ context_sensitivity_std_threshold = gr.Number(
175
+ value=1.0,
176
+ label="Context sensitivity threshold",
177
+ info="Select N to keep context sensitive tokens with scores above N * std. 0 = above mean.",
178
+ precision=1,
179
+ minimum=0.0,
180
+ maximum=5.0,
181
+ step=0.5,
182
+ interactive=True,
183
+ )
184
+ context_sensitivity_topk = gr.Number(
185
+ value=0,
186
+ label="Context sensitivity top-k",
187
+ info="Select N to keep top N context sensitive tokens. 0 = keep all.",
188
+ interactive=True,
189
+ precision=0,
190
+ minimum=0,
191
+ maximum=10,
192
+ )
193
+ attribution_std_threshold = gr.Number(
194
+ value=1.0,
195
+ label="Attribution threshold",
196
+ info="Select N to keep attributed tokens with scores above N * std. 0 = above mean.",
197
+ precision=1,
198
+ minimum=0.0,
199
+ maximum=5.0,
200
+ step=0.5,
201
+ interactive=True,
202
+ )
203
+ attribution_topk = gr.Number(
204
+ value=0,
205
+ label="Attribution top-k",
206
+ info="Select N to keep top N attributed tokens in the context. 0 = keep all.",
207
+ interactive=True,
208
+ precision=0,
209
+ minimum=0,
210
+ maximum=50,
211
+ )
212
+
213
+ gr.Markdown("#### Text Format Parameters")
214
+ with gr.Row(equal_height=True):
215
+ input_template = gr.Textbox(
216
+ value="{current} <P>:{context}",
217
+ label="Input template",
218
+ info="Template to format the input for the model. Use {current} and {context} placeholders.",
219
+ interactive=True,
220
+ )
221
+ output_template = gr.Textbox(
222
+ value="{current}",
223
+ label="Output template",
224
+ info="Template to format the output from the model. Use {current} and {context} placeholders.",
225
+ interactive=True,
226
+ )
227
+ input_current_text_template = gr.Textbox(
228
+ value="<Q>:{current}",
229
+ label="Input current text template",
230
+ info="Template to format the input query for the model. Use {current} placeholder.",
231
+ interactive=True,
232
+ )
233
+ special_tokens_to_keep = gr.Dropdown(
234
+ label="Special tokens to keep",
235
+ info="Special tokens to keep in the attribution. If empty, all special tokens are ignored.",
236
+ value=None,
237
+ multiselect=True,
238
+ allow_custom_value=True,
239
+ )
240
+
241
+ gr.Markdown("## βš™οΈ Generation Parameters")
242
+ with gr.Row(equal_height=True):
243
+ output_current_text = gr.Textbox(
244
+ label="Generation output",
245
+ info="Specifies an output to force-decoded during generation. If blank, the model will generate freely.",
246
+ interactive=True,
247
+ )
248
+ output_context_text = gr.Textbox(
249
+ label="Generation context",
250
+ info="If specified, this context is used as starting point for generation. Useful for e.g. chain-of-thought reasoning.",
251
+ interactive=True,
252
+ )
253
+ generation_kwargs = gr.Code(
254
+ value="{}",
255
+ language="json",
256
+ label="Generation kwargs",
257
+ interactive=True,
258
+ lines=1,
259
+ )
260
+ gr.Markdown("## βš™οΈ Other Parameters")
261
+ with gr.Row(equal_height=True):
262
+ model_kwargs = gr.Code(
263
+ value="{}",
264
+ language="json",
265
+ label="Model kwargs",
266
+ interactive=True,
267
+ lines=1,
268
+ )
269
+ tokenizer_kwargs = gr.Code(
270
+ value="{}",
271
+ language="json",
272
+ label="Tokenizer kwargs",
273
+ interactive=True,
274
+ lines=1,
275
+ )
276
+ attribution_kwargs = gr.Code(
277
+ value="{}",
278
+ language="json",
279
+ label="Attribution kwargs",
280
+ interactive=True,
281
+ lines=1,
282
+ )
283
+
284
+ gr.Markdown(how_it_works)
285
+ gr.Markdown(how_to_use)
286
+ gr.Markdown(citation)
287
+
288
+ attribute_input_button.click(
289
+ pecore,
290
+ inputs=[
291
+ input_current_text,
292
+ input_context_text,
293
+ output_current_text,
294
+ output_context_text,
295
+ model_name_or_path,
296
+ attribution_method,
297
+ attributed_fn,
298
+ context_sensitivity_metric,
299
+ context_sensitivity_std_threshold,
300
+ context_sensitivity_topk,
301
+ attribution_std_threshold,
302
+ attribution_topk,
303
+ input_template,
304
+ input_current_text_template,
305
+ output_template,
306
+ special_tokens_to_keep,
307
+ model_kwargs,
308
+ tokenizer_kwargs,
309
+ generation_kwargs,
310
+ attribution_kwargs,
311
+ ],
312
+ outputs=[
313
+ pecore_output_highlights,
314
+ download_output_file_button,
315
+ download_output_html_button,
316
+ ],
317
+ )
318
+
319
+ demo.launch(allowed_paths=["outputs/"])
contents.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ title = "<h1 class='demo-title'>πŸ‘ Plausibility Evaluation of Context Reliance (PECoRe) πŸ‘</h1>"
2
+
3
+ subtitle = "<h2 class='demo-subtitle'>An Interpretability Framework to Detect and Attribute Context Reliance in Language Models</h2>"
4
+
5
+ description = """
6
+ Given a query and a context passed as inputs to a LM, PECoRe will identify which tokens in the generated
7
+ response were dependant on context, and match them with context tokens contributing to their prediction.
8
+ For more information, check out our <a href="https://openreview.net/forum?id=XTHfNGI3zT" target='_blank'>ICLR 2024 paper</a>.
9
+ """
10
+
11
+ how_it_works = r"""
12
+ <details>
13
+ <summary><h3 class="summary-label">βš™οΈ How Does It Work?</h3></summary>
14
+ <br/>
15
+ PECoRe uses a contrastive approach to attribute context reliance in language models.
16
+ It compares the model's predictions when the context is present and when it is absent, and attributes the difference in predictions to the context tokens.
17
+ </details>
18
+ """
19
+
20
+ how_to_use = r"""
21
+ <details>
22
+ <summary><h3 class="summary-label">πŸ”§ How to Use PECoRe</h3></summary>
23
+
24
+ </details>
25
+ """
26
+
27
+ citation = r"""
28
+ <details>
29
+ <summary><h3 class="summary-label">πŸ“š Citing PECoRe</h3></summary>
30
+
31
+ @inproceedings{sarti-etal-2023-quantifying,
32
+ title = "Quantifying the Plausibility of Context Reliance in Neural Machine Translation",
33
+ author = "Sarti, Gabriele and
34
+ Chrupa{\l}a, Grzegorz and
35
+ Nissim, Malvina and
36
+ Bisazza, Arianna",
37
+ booktitle = "The Twelfth International Conference on Learning Representations (ICLR 2024)",
38
+ month = may,
39
+ year = "2024",
40
+ address = "Vienna, Austria",
41
+ publisher = "OpenReview",
42
+ url = "https://openreview.net/forum?id=XTHfNGI3zT"
43
+ }
44
+
45
+ </details>
46
+ """
47
+
48
+ examples = [
49
+ [
50
+ "When was Banff National Park established?",
51
+ "Banff National Park is Canada's oldest national park, established in 1885 as Rocky Mountains Park. Located in Alberta's Rocky Mountains, 110–180 kilometres (68–112 mi) west of Calgary, Banff encompasses 6,641 square kilometres (2,564 sq mi) of mountainous terrain.",
52
+ ]
53
+ ]
requirements.txt CHANGED
@@ -1 +1,3 @@
1
- git+https://github.com/inseq-team/inseq.git@main
 
 
 
1
+ spaces
2
+ git+https://github.com/inseq-team/inseq.git@main
3
+ gradio_highlightedtextbox
style.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ custom_css = """
2
+ .demo-title {
3
+ text-align: center;
4
+ display: block;
5
+ margin-bottom: 0;
6
+ font-size: 2em;
7
+ }
8
+
9
+ .demo-subtitle {
10
+ text-align: center;
11
+ display: block;
12
+ margin-top: 0;
13
+ font-size: 1.5em;
14
+ }
15
+
16
+ .summary-label {
17
+ display: inline;
18
+ }
19
+ """
utils.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ from typing import Optional
3
+
4
+ from inseq import load_model
5
+ from inseq.commands.attribute_context.attribute_context_args import AttributeContextArgs
6
+ from inseq.commands.attribute_context.attribute_context_helpers import (
7
+ AttributeContextOutput,
8
+ filter_rank_tokens,
9
+ get_filtered_tokens,
10
+ )
11
+ from inseq.models import HuggingfaceModel
12
+
13
+
14
+ def get_formatted_attribute_context_results(
15
+ model: HuggingfaceModel,
16
+ args: AttributeContextArgs,
17
+ output: AttributeContextOutput,
18
+ ) -> str:
19
+ """Format the results of the context attribution process."""
20
+
21
+ def format_context_comment(
22
+ model: HuggingfaceModel,
23
+ has_other_context: bool,
24
+ special_tokens_to_keep: list[str],
25
+ context: str,
26
+ context_scores: list[float],
27
+ other_context_scores: Optional[list[float]] = None,
28
+ is_target: bool = False,
29
+ ) -> str:
30
+ context_tokens = get_filtered_tokens(
31
+ context,
32
+ model,
33
+ special_tokens_to_keep,
34
+ replace_special_characters=True,
35
+ is_target=is_target,
36
+ )
37
+ context_token_tuples = [(t, None) for t in context_tokens]
38
+ scores = context_scores
39
+ if has_other_context:
40
+ scores += other_context_scores
41
+ context_ranked_tokens, _ = filter_rank_tokens(
42
+ tokens=context_tokens,
43
+ scores=scores,
44
+ std_threshold=args.attribution_std_threshold,
45
+ topk=args.attribution_topk,
46
+ )
47
+ for idx, _, tok in context_ranked_tokens:
48
+ context_token_tuples[idx] = (tok, "Influential context")
49
+ return context_token_tuples
50
+
51
+ out = []
52
+ output_current_tokens = get_filtered_tokens(
53
+ output.output_current,
54
+ model,
55
+ args.special_tokens_to_keep,
56
+ replace_special_characters=True,
57
+ is_target=True,
58
+ )
59
+ for example_idx, cci_out in enumerate(output.cci_scores, start=1):
60
+ curr_output_tokens = [(t, None) for t in output_current_tokens]
61
+ cti_idx = cci_out.cti_idx
62
+ curr_output_tokens[cti_idx] = (
63
+ curr_output_tokens[cti_idx][0],
64
+ "Context sensitive",
65
+ )
66
+ if args.has_input_context:
67
+ input_context_tokens = format_context_comment(
68
+ model,
69
+ args.has_output_context,
70
+ args.special_tokens_to_keep,
71
+ output.input_context,
72
+ cci_out.input_context_scores,
73
+ cci_out.output_context_scores,
74
+ )
75
+ if args.has_output_context:
76
+ output_context_tokens = format_context_comment(
77
+ model,
78
+ args.has_input_context,
79
+ args.special_tokens_to_keep,
80
+ output.output_context,
81
+ cci_out.output_context_scores,
82
+ cci_out.input_context_scores,
83
+ is_target=True,
84
+ context_type="Output",
85
+ )
86
+ out += [
87
+ ("\n\n" if example_idx > 1 else "", None),
88
+ (
89
+ f"#{example_idx}.\nGenerated output:\t",
90
+ None,
91
+ ),
92
+ ]
93
+ out += curr_output_tokens
94
+ if args.has_input_context:
95
+ out += [("\nInput context:\t", None)]
96
+ out += input_context_tokens
97
+ if args.has_output_context:
98
+ out += [("\\Output context:\t", None)]
99
+ out += output_context_tokens
100
+ return out
101
+
102
+
103
+ def get_tuples_from_output(output: AttributeContextOutput):
104
+ model = load_model(
105
+ output.info.model_name_or_path,
106
+ output.info.attribution_method,
107
+ model_kwargs=deepcopy(output.info.model_kwargs),
108
+ tokenizer_kwargs=deepcopy(output.info.tokenizer_kwargs),
109
+ )
110
+ return get_formatted_attribute_context_results(model, output.info, output)