Spaces:
Running
Running
Gabriela Nicole Gonzalez Saez
commited on
Commit
·
056bbdc
1
Parent(s):
fc37a00
Add files
Browse files- app.py +114 -0
- bertviz_gradio.py +248 -0
- plotsjs_bertviz.js +430 -0
- requirements.txt +2 -0
app.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import inseq
|
3 |
+
import captum
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import os
|
7 |
+
# import nltk
|
8 |
+
import argparse
|
9 |
+
import random
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from argparse import Namespace
|
13 |
+
from tqdm.notebook import tqdm
|
14 |
+
from torch.utils.data import DataLoader
|
15 |
+
from functools import partial
|
16 |
+
|
17 |
+
from transformers import AutoTokenizer, MarianTokenizer, AutoModel, AutoModelForSeq2SeqLM, MarianMTModel
|
18 |
+
|
19 |
+
from bertviz import model_view, head_view
|
20 |
+
from bertviz_gradio import head_view_mod
|
21 |
+
|
22 |
+
|
23 |
+
def get_bertvis_data(input_text, lg_model):
|
24 |
+
tokenizer_tr = dict_tokenizer_tr[lg_model]
|
25 |
+
model_tr = dict_models_tr[lg_model]
|
26 |
+
|
27 |
+
input_ids = tokenizer_tr(input_text, return_tensors="pt", padding=True)
|
28 |
+
result_att = model_tr.generate(**input_ids,
|
29 |
+
return_dict_in_generate=True,
|
30 |
+
output_attentions =True,
|
31 |
+
output_scores=True,
|
32 |
+
)
|
33 |
+
|
34 |
+
# tokenizer_tr.convert_ids_to_tokens(result_att.sequences[0])
|
35 |
+
# tokenizer_tr.convert_ids_to_tokens(input_ids.input_ids[0])
|
36 |
+
|
37 |
+
tgt_text = tokenizer_tr.decode(result_att.sequences[0], skip_special_tokens=True)
|
38 |
+
|
39 |
+
print(tgt_text)
|
40 |
+
outputs = model_tr(input_ids=input_ids.input_ids,
|
41 |
+
decoder_input_ids=result_att.sequences,
|
42 |
+
output_attentions =True,
|
43 |
+
)
|
44 |
+
print(tokenizer_tr.convert_ids_to_tokens(result_att.sequences[0]))
|
45 |
+
# print(tokenizer_tr.convert_ids_to_tokens(input_ids.input_ids[0]), tokenizer_tr.convert_ids_to_tokens(result_att.sequences[0]))
|
46 |
+
html_attentions = head_view_mod(
|
47 |
+
encoder_attention = outputs.encoder_attentions,
|
48 |
+
cross_attention = outputs.cross_attentions,
|
49 |
+
decoder_attention = outputs.decoder_attentions,
|
50 |
+
encoder_tokens = tokenizer_tr.convert_ids_to_tokens(input_ids.input_ids[0]),
|
51 |
+
decoder_tokens = tokenizer_tr.convert_ids_to_tokens(result_att.sequences[0]),
|
52 |
+
html_action='gradio'
|
53 |
+
)
|
54 |
+
return html_attentions, tgt_text
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
## First create html and divs
|
59 |
+
html = """
|
60 |
+
<html>
|
61 |
+
<script async src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>
|
62 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.0/jquery.min"></script>
|
63 |
+
<script async data-require="[email protected]" data-semver="3.5.3" src="//cdnjs.cloudflare.com/ajax/libs/d3/3.5.3/d3.js"></script>
|
64 |
+
|
65 |
+
<body>
|
66 |
+
<div id="bertviz"></div>
|
67 |
+
<div id="d3_beam_search"></div>
|
68 |
+
</body>
|
69 |
+
</html>
|
70 |
+
"""
|
71 |
+
|
72 |
+
def sentence_maker(w1, model, var2={}):
|
73 |
+
#translate and get internal values
|
74 |
+
params,tgt = get_bertvis_data(w1, model)
|
75 |
+
### get translation
|
76 |
+
|
77 |
+
return [tgt, params['params'],params['html2'].data]
|
78 |
+
|
79 |
+
def sentence_maker2(w1,j2):
|
80 |
+
# json_value = {'one':1}
|
81 |
+
# return f"{w1['two']} in sentence22..."
|
82 |
+
print(w1,j2)
|
83 |
+
return "in sentence22..."
|
84 |
+
|
85 |
+
|
86 |
+
with gr.Blocks(js="plotsjs_bertviz.js") as demo:
|
87 |
+
gr.Markdown("""
|
88 |
+
# MAKE NMT Workshop \t `BertViz` \n
|
89 |
+
https://github.com/jessevig/bertviz
|
90 |
+
""")
|
91 |
+
with gr.Row():
|
92 |
+
with gr.Column(scale=1):
|
93 |
+
in_text = gr.Textbox(label="Source Text")
|
94 |
+
out_text = gr.Textbox(label="Target Text")
|
95 |
+
out_text2 = gr.Textbox(visible=False)
|
96 |
+
var2 = gr.JSON(visible=False)
|
97 |
+
btn = gr.Button("Create sentence.")
|
98 |
+
radio_c = gr.Radio(choices=['en-zh', 'en-es', 'en-fr'], value="en-zh", label= '', container=False)
|
99 |
+
|
100 |
+
|
101 |
+
with gr.Column(scale=4):
|
102 |
+
gr.Markdown("Attentions: ")
|
103 |
+
input_mic = gr.HTML(html)
|
104 |
+
out_html = gr.HTML()
|
105 |
+
btn.click(sentence_maker, [in_text,radio_c], [out_text,var2,out_html], js="(in_text,radio_c) => testFn_out(in_text,radio_c)") #should return an output comp.
|
106 |
+
out_text.change(sentence_maker2, [out_text, var2], out_text2, js="(out_text,var2) => testFn_out_json(var2)") #
|
107 |
+
# out_text.change(sentence_maker2, [out_text, var2], out_text2, js="(out_text,var2) => testFn_out_json(var2)") #
|
108 |
+
|
109 |
+
|
110 |
+
# run script function on load,
|
111 |
+
# demo.load(None,None,None,js="plotsjs.js")
|
112 |
+
|
113 |
+
if __name__ == "__main__":
|
114 |
+
demo.launch()
|
bertviz_gradio.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import uuid
|
5 |
+
|
6 |
+
from IPython.core.display import display, HTML, Javascript
|
7 |
+
|
8 |
+
from bertviz.util import format_special_chars, format_attention, num_layers
|
9 |
+
|
10 |
+
|
11 |
+
def head_view_mod(
|
12 |
+
attention=None,
|
13 |
+
tokens=None,
|
14 |
+
sentence_b_start=None,
|
15 |
+
prettify_tokens=True,
|
16 |
+
layer=None,
|
17 |
+
heads=None,
|
18 |
+
encoder_attention=None,
|
19 |
+
decoder_attention=None,
|
20 |
+
cross_attention=None,
|
21 |
+
encoder_tokens=None,
|
22 |
+
decoder_tokens=None,
|
23 |
+
include_layers=None,
|
24 |
+
html_action='view'
|
25 |
+
):
|
26 |
+
"""Render head view
|
27 |
+
|
28 |
+
Args:
|
29 |
+
For self-attention models:
|
30 |
+
attention: list of ``torch.FloatTensor``(one for each layer) of shape
|
31 |
+
``(batch_size(must be 1), num_heads, sequence_length, sequence_length)``
|
32 |
+
tokens: list of tokens
|
33 |
+
sentence_b_start: index of first wordpiece in sentence B if input text is sentence pair (optional)
|
34 |
+
For encoder-decoder models:
|
35 |
+
encoder_attention: list of ``torch.FloatTensor``(one for each layer) of shape
|
36 |
+
``(batch_size(must be 1), num_heads, encoder_sequence_length, encoder_sequence_length)``
|
37 |
+
decoder_attention: list of ``torch.FloatTensor``(one for each layer) of shape
|
38 |
+
``(batch_size(must be 1), num_heads, decoder_sequence_length, decoder_sequence_length)``
|
39 |
+
cross_attention: list of ``torch.FloatTensor``(one for each layer) of shape
|
40 |
+
``(batch_size(must be 1), num_heads, decoder_sequence_length, encoder_sequence_length)``
|
41 |
+
encoder_tokens: list of tokens for encoder input
|
42 |
+
decoder_tokens: list of tokens for decoder input
|
43 |
+
For all models:
|
44 |
+
prettify_tokens: indicates whether to remove special characters in wordpieces, e.g. Ġ
|
45 |
+
layer: index (zero-based) of initial selected layer in visualization. Defaults to layer 0.
|
46 |
+
heads: Indices (zero-based) of initial selected heads in visualization. Defaults to all heads.
|
47 |
+
include_layers: Indices (zero-based) of layers to include in visualization. Defaults to all layers.
|
48 |
+
Note: filtering layers may improve responsiveness of the visualization for long inputs.
|
49 |
+
html_action: Specifies the action to be performed with the generated HTML object
|
50 |
+
- 'view' (default): Displays the generated HTML representation as a notebook cell output
|
51 |
+
- 'return' : Returns an HTML object containing the generated view for further processing or custom visualization
|
52 |
+
"""
|
53 |
+
|
54 |
+
attn_data = []
|
55 |
+
if attention is not None:
|
56 |
+
if tokens is None:
|
57 |
+
raise ValueError("'tokens' is required")
|
58 |
+
if encoder_attention is not None or decoder_attention is not None or cross_attention is not None \
|
59 |
+
or encoder_tokens is not None or decoder_tokens is not None:
|
60 |
+
raise ValueError("If you specify 'attention' you may not specify any encoder-decoder arguments. This"
|
61 |
+
" argument is only for self-attention models.")
|
62 |
+
if include_layers is None:
|
63 |
+
include_layers = list(range(num_layers(attention)))
|
64 |
+
attention = format_attention(attention, include_layers)
|
65 |
+
if sentence_b_start is None:
|
66 |
+
attn_data.append(
|
67 |
+
{
|
68 |
+
'name': None,
|
69 |
+
'attn': attention.tolist(),
|
70 |
+
'left_text': tokens,
|
71 |
+
'right_text': tokens
|
72 |
+
}
|
73 |
+
)
|
74 |
+
else:
|
75 |
+
slice_a = slice(0, sentence_b_start) # Positions corresponding to sentence A in input
|
76 |
+
slice_b = slice(sentence_b_start, len(tokens)) # Position corresponding to sentence B in input
|
77 |
+
attn_data.append(
|
78 |
+
{
|
79 |
+
'name': 'All',
|
80 |
+
'attn': attention.tolist(),
|
81 |
+
'left_text': tokens,
|
82 |
+
'right_text': tokens
|
83 |
+
}
|
84 |
+
)
|
85 |
+
attn_data.append(
|
86 |
+
{
|
87 |
+
'name': 'Sentence A -> Sentence A',
|
88 |
+
'attn': attention[:, :, slice_a, slice_a].tolist(),
|
89 |
+
'left_text': tokens[slice_a],
|
90 |
+
'right_text': tokens[slice_a]
|
91 |
+
}
|
92 |
+
)
|
93 |
+
attn_data.append(
|
94 |
+
{
|
95 |
+
'name': 'Sentence B -> Sentence B',
|
96 |
+
'attn': attention[:, :, slice_b, slice_b].tolist(),
|
97 |
+
'left_text': tokens[slice_b],
|
98 |
+
'right_text': tokens[slice_b]
|
99 |
+
}
|
100 |
+
)
|
101 |
+
attn_data.append(
|
102 |
+
{
|
103 |
+
'name': 'Sentence A -> Sentence B',
|
104 |
+
'attn': attention[:, :, slice_a, slice_b].tolist(),
|
105 |
+
'left_text': tokens[slice_a],
|
106 |
+
'right_text': tokens[slice_b]
|
107 |
+
}
|
108 |
+
)
|
109 |
+
attn_data.append(
|
110 |
+
{
|
111 |
+
'name': 'Sentence B -> Sentence A',
|
112 |
+
'attn': attention[:, :, slice_b, slice_a].tolist(),
|
113 |
+
'left_text': tokens[slice_b],
|
114 |
+
'right_text': tokens[slice_a]
|
115 |
+
}
|
116 |
+
)
|
117 |
+
elif encoder_attention is not None or decoder_attention is not None or cross_attention is not None:
|
118 |
+
if encoder_attention is not None:
|
119 |
+
if encoder_tokens is None:
|
120 |
+
raise ValueError("'encoder_tokens' required if 'encoder_attention' is not None")
|
121 |
+
if include_layers is None:
|
122 |
+
include_layers = list(range(num_layers(encoder_attention)))
|
123 |
+
encoder_attention = format_attention(encoder_attention, include_layers)
|
124 |
+
attn_data.append(
|
125 |
+
{
|
126 |
+
'name': 'Encoder',
|
127 |
+
'attn': encoder_attention.tolist(),
|
128 |
+
'left_text': encoder_tokens,
|
129 |
+
'right_text': encoder_tokens
|
130 |
+
}
|
131 |
+
)
|
132 |
+
if decoder_attention is not None:
|
133 |
+
if decoder_tokens is None:
|
134 |
+
raise ValueError("'decoder_tokens' required if 'decoder_attention' is not None")
|
135 |
+
if include_layers is None:
|
136 |
+
include_layers = list(range(num_layers(decoder_attention)))
|
137 |
+
decoder_attention = format_attention(decoder_attention, include_layers)
|
138 |
+
attn_data.append(
|
139 |
+
{
|
140 |
+
'name': 'Decoder',
|
141 |
+
'attn': decoder_attention.tolist(),
|
142 |
+
'left_text': decoder_tokens,
|
143 |
+
'right_text': decoder_tokens
|
144 |
+
}
|
145 |
+
)
|
146 |
+
if cross_attention is not None:
|
147 |
+
if encoder_tokens is None:
|
148 |
+
raise ValueError("'encoder_tokens' required if 'cross_attention' is not None")
|
149 |
+
if decoder_tokens is None:
|
150 |
+
raise ValueError("'decoder_tokens' required if 'cross_attention' is not None")
|
151 |
+
if include_layers is None:
|
152 |
+
include_layers = list(range(num_layers(cross_attention)))
|
153 |
+
cross_attention = format_attention(cross_attention, include_layers)
|
154 |
+
attn_data.append(
|
155 |
+
{
|
156 |
+
'name': 'Cross',
|
157 |
+
'attn': cross_attention.tolist(),
|
158 |
+
'left_text': decoder_tokens,
|
159 |
+
'right_text': encoder_tokens
|
160 |
+
}
|
161 |
+
)
|
162 |
+
else:
|
163 |
+
raise ValueError("You must specify at least one attention argument.")
|
164 |
+
|
165 |
+
if layer is not None and layer not in include_layers:
|
166 |
+
raise ValueError(f"Layer {layer} is not in include_layers: {include_layers}")
|
167 |
+
|
168 |
+
# Generate unique div id to enable multiple visualizations in one notebook
|
169 |
+
# vis_id = 'bertviz-%s'%(uuid.uuid4().hex)
|
170 |
+
vis_id = 'bertviz'#-%s'%(uuid.uuid4().hex)
|
171 |
+
|
172 |
+
# Compose html
|
173 |
+
if len(attn_data) > 1:
|
174 |
+
options = '\n'.join(
|
175 |
+
f'<option value="{i}">{attn_data[i]["name"]}</option>'
|
176 |
+
for i, d in enumerate(attn_data)
|
177 |
+
)
|
178 |
+
select_html = f'Attention: <select id="filter">{options}</select>'
|
179 |
+
else:
|
180 |
+
select_html = ""
|
181 |
+
vis_html = f"""
|
182 |
+
<div id="{vis_id}" style="font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;">
|
183 |
+
<span style="user-select:none">
|
184 |
+
Layer: <select id="layer"></select>
|
185 |
+
{select_html}
|
186 |
+
</span>
|
187 |
+
<div id='vis'></div>
|
188 |
+
</div>
|
189 |
+
"""
|
190 |
+
|
191 |
+
for d in attn_data:
|
192 |
+
attn_seq_len_left = len(d['attn'][0][0])
|
193 |
+
if attn_seq_len_left != len(d['left_text']):
|
194 |
+
raise ValueError(
|
195 |
+
f"Attention has {attn_seq_len_left} positions, while number of tokens is {len(d['left_text'])} "
|
196 |
+
f"for tokens: {' '.join(d['left_text'])}"
|
197 |
+
)
|
198 |
+
attn_seq_len_right = len(d['attn'][0][0][0])
|
199 |
+
if attn_seq_len_right != len(d['right_text']):
|
200 |
+
raise ValueError(
|
201 |
+
f"Attention has {attn_seq_len_right} positions, while number of tokens is {len(d['right_text'])} "
|
202 |
+
f"for tokens: {' '.join(d['right_text'])}"
|
203 |
+
)
|
204 |
+
if prettify_tokens:
|
205 |
+
d['left_text'] = format_special_chars(d['left_text'])
|
206 |
+
d['right_text'] = format_special_chars(d['right_text'])
|
207 |
+
params = {
|
208 |
+
'attention': attn_data,
|
209 |
+
'default_filter': "0",
|
210 |
+
'root_div_id': vis_id,
|
211 |
+
'layer': layer,
|
212 |
+
'heads': heads,
|
213 |
+
'include_layers': include_layers
|
214 |
+
}
|
215 |
+
|
216 |
+
# require.js must be imported for Colab or JupyterLab:
|
217 |
+
|
218 |
+
if html_action == 'gradio':
|
219 |
+
html1 = HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>')
|
220 |
+
html2 = HTML(vis_html)
|
221 |
+
|
222 |
+
return {'html1': html1, 'html2' : html2, 'params': params }
|
223 |
+
|
224 |
+
|
225 |
+
if html_action == 'view':
|
226 |
+
display(HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>'))
|
227 |
+
display(HTML(vis_html))
|
228 |
+
__location__ = os.path.realpath(
|
229 |
+
os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
230 |
+
vis_js = open(os.path.join(__location__, 'head_view.js')).read().replace("PYTHON_PARAMS", json.dumps(params))
|
231 |
+
display(Javascript(vis_js))
|
232 |
+
|
233 |
+
elif html_action == 'return':
|
234 |
+
html1 = HTML('<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js"></script>')
|
235 |
+
|
236 |
+
html2 = HTML(vis_html)
|
237 |
+
|
238 |
+
__location__ = os.path.realpath(
|
239 |
+
os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
240 |
+
vis_js = open(os.path.join(__location__, 'head_view.js')).read().replace("PYTHON_PARAMS", json.dumps(params))
|
241 |
+
html3 = Javascript(vis_js)
|
242 |
+
script = '\n<script type="text/javascript">\n' + html3.data + '\n</script>\n'
|
243 |
+
|
244 |
+
head_html = HTML(html1.data + html2.data + script)
|
245 |
+
return head_html
|
246 |
+
|
247 |
+
else:
|
248 |
+
raise ValueError("'html_action' parameter must be 'view' or 'return")
|
plotsjs_bertviz.js
ADDED
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
|
4 |
+
async () => {
|
5 |
+
// set testFn() function on globalThis, so you html onlclick can access it
|
6 |
+
|
7 |
+
|
8 |
+
globalThis.testFn = () => {
|
9 |
+
document.getElementById('demo').innerHTML = "Hello-bertviz?"
|
10 |
+
};
|
11 |
+
|
12 |
+
// await import * as mod from "/my-module.js";
|
13 |
+
|
14 |
+
const d3 = await import("https://cdn.jsdelivr.net/npm/d3@5/+esm");
|
15 |
+
const $ = await import("https://cdn.jsdelivr.net/npm/[email protected]/dist/jquery.min.js");
|
16 |
+
|
17 |
+
globalThis.$ = $;
|
18 |
+
|
19 |
+
// const $ = await import("https://cdn.jsdelivr.net/npm/jquery@2/+esm");
|
20 |
+
// import $ from "jquery";
|
21 |
+
// import * as d3 from "https://cdn.jsdelivr.net/npm/d3@7/+esm";
|
22 |
+
// await import("https://cdn.jsdelivr.net/npm/jquery@2/+esm");
|
23 |
+
|
24 |
+
// export for others scripts to use
|
25 |
+
// window.$ = window.jQuery = jQuery;
|
26 |
+
|
27 |
+
// const d3 = await import("https://cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min");
|
28 |
+
// const $ = await import("https://cdnjs.cloudflare.com/ajax/libs/jquery/2.0.0/jquery.min");
|
29 |
+
|
30 |
+
globalThis.d3Fn = () => {
|
31 |
+
d3.select('#viz').append('svg')
|
32 |
+
.append('rect')
|
33 |
+
.attr('width', 50)
|
34 |
+
.attr('height', 50)
|
35 |
+
.attr('fill', 'black')
|
36 |
+
.on('mouseover', function(){d3.select(this).attr('fill', 'red')})
|
37 |
+
.on('mouseout', function(){d3.select(this).attr('fill', 'black')});
|
38 |
+
|
39 |
+
};
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
//
|
44 |
+
|
45 |
+
globalThis.testFn_out = (val,model) => {
|
46 |
+
// document.getElementById('demo').innerHTML = val
|
47 |
+
console.log(val);
|
48 |
+
// globalThis.d3Fn();
|
49 |
+
return([val,model]);
|
50 |
+
};
|
51 |
+
|
52 |
+
globalThis.testFn_out_json = (data) => {
|
53 |
+
console.log(data);
|
54 |
+
var $ = jQuery;
|
55 |
+
console.log($('#viz'));
|
56 |
+
|
57 |
+
attViz(data);
|
58 |
+
|
59 |
+
return(['string', {}])
|
60 |
+
|
61 |
+
};
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
function attViz(PYTHON_PARAMS) {
|
66 |
+
var $ = jQuery;
|
67 |
+
const params = PYTHON_PARAMS; // HACK: PYTHON_PARAMS is a template marker that is replaced by actual params.
|
68 |
+
const TEXT_SIZE = 15;
|
69 |
+
const BOXWIDTH = 110;
|
70 |
+
const BOXHEIGHT = 22.5;
|
71 |
+
const MATRIX_WIDTH = 115;
|
72 |
+
const CHECKBOX_SIZE = 20;
|
73 |
+
const TEXT_TOP = 30;
|
74 |
+
|
75 |
+
console.log("d3 version in ffuntions", d3.version)
|
76 |
+
let headColors;
|
77 |
+
try {
|
78 |
+
headColors = d3.scaleOrdinal(d3.schemeCategory10);
|
79 |
+
} catch (err) {
|
80 |
+
console.log('Older d3 version')
|
81 |
+
headColors = d3.scale.category10();
|
82 |
+
}
|
83 |
+
let config = {};
|
84 |
+
// globalThis.
|
85 |
+
initialize();
|
86 |
+
renderVis();
|
87 |
+
|
88 |
+
function initialize() {
|
89 |
+
// globalThis.initialize = () => {
|
90 |
+
|
91 |
+
console.log("init")
|
92 |
+
config.attention = params['attention'];
|
93 |
+
config.filter = params['default_filter'];
|
94 |
+
config.rootDivId = params['root_div_id'];
|
95 |
+
config.nLayers = config.attention[config.filter]['attn'].length;
|
96 |
+
config.nHeads = config.attention[config.filter]['attn'][0].length;
|
97 |
+
config.layers = params['include_layers']
|
98 |
+
|
99 |
+
if (params['heads']) {
|
100 |
+
config.headVis = new Array(config.nHeads).fill(false);
|
101 |
+
params['heads'].forEach(x => config.headVis[x] = true);
|
102 |
+
} else {
|
103 |
+
config.headVis = new Array(config.nHeads).fill(true);
|
104 |
+
}
|
105 |
+
config.initialTextLength = config.attention[config.filter].right_text.length;
|
106 |
+
config.layer_seq = (params['layer'] == null ? 0 : config.layers.findIndex(layer => params['layer'] === layer));
|
107 |
+
config.layer = config.layers[config.layer_seq]
|
108 |
+
|
109 |
+
// '#' + temp1.root_div_id+ ' #layer'
|
110 |
+
$('#' + config.rootDivId+ ' #layer').empty();
|
111 |
+
|
112 |
+
let layerEl = $('#' + config.rootDivId+ ' #layer');
|
113 |
+
console.log(layerEl)
|
114 |
+
for (const layer of config.layers) {
|
115 |
+
layerEl.append($("<option />").val(layer).text(layer));
|
116 |
+
}
|
117 |
+
layerEl.val(config.layer).change();
|
118 |
+
layerEl.on('change', function (e) {
|
119 |
+
config.layer = +e.currentTarget.value;
|
120 |
+
config.layer_seq = config.layers.findIndex(layer => config.layer === layer);
|
121 |
+
renderVis();
|
122 |
+
});
|
123 |
+
|
124 |
+
$('#'+config.rootDivId+' #filter').on('change', function (e) {
|
125 |
+
// $(`#${config.rootDivId} #filter`).on('change', function (e) {
|
126 |
+
|
127 |
+
config.filter = e.currentTarget.value;
|
128 |
+
renderVis();
|
129 |
+
});
|
130 |
+
}
|
131 |
+
|
132 |
+
function renderVis() {
|
133 |
+
|
134 |
+
// Load parameters
|
135 |
+
const attnData = config.attention[config.filter];
|
136 |
+
const leftText = attnData.left_text;
|
137 |
+
const rightText = attnData.right_text;
|
138 |
+
|
139 |
+
// Select attention for given layer
|
140 |
+
const layerAttention = attnData.attn[config.layer_seq];
|
141 |
+
|
142 |
+
// Clear vis
|
143 |
+
$('#'+config.rootDivId+' #vis').empty();
|
144 |
+
|
145 |
+
// Determine size of visualization
|
146 |
+
const height = Math.max(leftText.length, rightText.length) * BOXHEIGHT + TEXT_TOP;
|
147 |
+
const svg = d3.select('#'+ config.rootDivId +' #vis')
|
148 |
+
.append('svg')
|
149 |
+
.attr("width", "100%")
|
150 |
+
.attr("height", height + "px");
|
151 |
+
|
152 |
+
// Display tokens on left and right side of visualization
|
153 |
+
renderText(svg, leftText, true, layerAttention, 0);
|
154 |
+
renderText(svg, rightText, false, layerAttention, MATRIX_WIDTH + BOXWIDTH);
|
155 |
+
|
156 |
+
// Render attention arcs
|
157 |
+
renderAttention(svg, layerAttention);
|
158 |
+
|
159 |
+
// Draw squares at top of visualization, one for each head
|
160 |
+
drawCheckboxes(0, svg, layerAttention);
|
161 |
+
}
|
162 |
+
|
163 |
+
function renderText(svg, text, isLeft, attention, leftPos) {
|
164 |
+
|
165 |
+
const textContainer = svg.append("svg:g")
|
166 |
+
.attr("id", isLeft ? "left" : "right");
|
167 |
+
|
168 |
+
// Add attention highlights superimposed over words
|
169 |
+
textContainer.append("g")
|
170 |
+
.classed("attentionBoxes", true)
|
171 |
+
.selectAll("g")
|
172 |
+
.data(attention)
|
173 |
+
.enter()
|
174 |
+
.append("g")
|
175 |
+
.attr("head-index", (d, i) => i)
|
176 |
+
.selectAll("rect")
|
177 |
+
.data(d => isLeft ? d : transpose(d)) // if right text, transpose attention to get right-to-left weights
|
178 |
+
.enter()
|
179 |
+
.append("rect")
|
180 |
+
.attr("x", function () {
|
181 |
+
var headIndex = +this.parentNode.getAttribute("head-index");
|
182 |
+
return leftPos + boxOffsets(headIndex);
|
183 |
+
})
|
184 |
+
.attr("y", (+1) * BOXHEIGHT)
|
185 |
+
.attr("width", BOXWIDTH / activeHeads())
|
186 |
+
.attr("height", BOXHEIGHT)
|
187 |
+
.attr("fill", function () {
|
188 |
+
return headColors(+this.parentNode.getAttribute("head-index"))
|
189 |
+
})
|
190 |
+
.style("opacity", 0.0);
|
191 |
+
|
192 |
+
const tokenContainer = textContainer.append("g").selectAll("g")
|
193 |
+
.data(text)
|
194 |
+
.enter()
|
195 |
+
.append("g");
|
196 |
+
|
197 |
+
// Add gray background that appears when hovering over text
|
198 |
+
tokenContainer.append("rect")
|
199 |
+
.classed("background", true)
|
200 |
+
.style("opacity", 0.0)
|
201 |
+
.attr("fill", "lightgray")
|
202 |
+
.attr("x", leftPos)
|
203 |
+
.attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT)
|
204 |
+
.attr("width", BOXWIDTH)
|
205 |
+
.attr("height", BOXHEIGHT);
|
206 |
+
|
207 |
+
// Add token text
|
208 |
+
const textEl = tokenContainer.append("text")
|
209 |
+
.text(d => d)
|
210 |
+
.attr("font-size", TEXT_SIZE + "px")
|
211 |
+
.style("cursor", "default")
|
212 |
+
.style("-webkit-user-select", "none")
|
213 |
+
.attr("x", leftPos)
|
214 |
+
.attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT);
|
215 |
+
|
216 |
+
if (isLeft) {
|
217 |
+
textEl.style("text-anchor", "end")
|
218 |
+
.attr("dx", BOXWIDTH - 0.5 * TEXT_SIZE)
|
219 |
+
.attr("dy", TEXT_SIZE);
|
220 |
+
} else {
|
221 |
+
textEl.style("text-anchor", "start")
|
222 |
+
.attr("dx", +0.5 * TEXT_SIZE)
|
223 |
+
.attr("dy", TEXT_SIZE);
|
224 |
+
}
|
225 |
+
|
226 |
+
tokenContainer.on("mouseover", function (d, index) {
|
227 |
+
|
228 |
+
// Show gray background for moused-over token
|
229 |
+
textContainer.selectAll(".background")
|
230 |
+
.style("opacity", (d, i) => i === index ? 1.0 : 0.0)
|
231 |
+
|
232 |
+
// Reset visibility attribute for any previously highlighted attention arcs
|
233 |
+
svg.select("#attention")
|
234 |
+
.selectAll("line[visibility='visible']")
|
235 |
+
.attr("visibility", null)
|
236 |
+
|
237 |
+
// Hide group containing attention arcs
|
238 |
+
svg.select("#attention").attr("visibility", "hidden");
|
239 |
+
|
240 |
+
// Set to visible appropriate attention arcs to be highlighted
|
241 |
+
if (isLeft) {
|
242 |
+
svg.select("#attention").selectAll("line[left-token-index='" + index + "']").attr("visibility", "visible");
|
243 |
+
} else {
|
244 |
+
svg.select("#attention").selectAll("line[right-token-index='" + index + "']").attr("visibility", "visible");
|
245 |
+
}
|
246 |
+
|
247 |
+
// Update color boxes superimposed over tokens
|
248 |
+
const id = isLeft ? "right" : "left";
|
249 |
+
const leftPos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;
|
250 |
+
svg.select("#" + id)
|
251 |
+
.selectAll(".attentionBoxes")
|
252 |
+
.selectAll("g")
|
253 |
+
.attr("head-index", (d, i) => i)
|
254 |
+
.selectAll("rect")
|
255 |
+
.attr("x", function () {
|
256 |
+
const headIndex = +this.parentNode.getAttribute("head-index");
|
257 |
+
return leftPos + boxOffsets(headIndex);
|
258 |
+
})
|
259 |
+
.attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT)
|
260 |
+
.attr("width", BOXWIDTH / activeHeads())
|
261 |
+
.attr("height", BOXHEIGHT)
|
262 |
+
.style("opacity", function (d) {
|
263 |
+
const headIndex = +this.parentNode.getAttribute("head-index");
|
264 |
+
if (config.headVis[headIndex])
|
265 |
+
if (d) {
|
266 |
+
return d[index];
|
267 |
+
} else {
|
268 |
+
return 0.0;
|
269 |
+
}
|
270 |
+
else
|
271 |
+
return 0.0;
|
272 |
+
});
|
273 |
+
});
|
274 |
+
|
275 |
+
textContainer.on("mouseleave", function () {
|
276 |
+
|
277 |
+
// Unhighlight selected token
|
278 |
+
d3.select(this).selectAll(".background")
|
279 |
+
.style("opacity", 0.0);
|
280 |
+
|
281 |
+
// Reset visibility attributes for previously selected lines
|
282 |
+
svg.select("#attention")
|
283 |
+
.selectAll("line[visibility='visible']")
|
284 |
+
.attr("visibility", null) ;
|
285 |
+
svg.select("#attention").attr("visibility", "visible");
|
286 |
+
|
287 |
+
// Reset highlights superimposed over tokens
|
288 |
+
svg.selectAll(".attentionBoxes")
|
289 |
+
.selectAll("g")
|
290 |
+
.selectAll("rect")
|
291 |
+
.style("opacity", 0.0);
|
292 |
+
});
|
293 |
+
}
|
294 |
+
|
295 |
+
function renderAttention(svg, attention) {
|
296 |
+
|
297 |
+
// Remove previous dom elements
|
298 |
+
svg.select("#attention").remove();
|
299 |
+
|
300 |
+
// Add new elements
|
301 |
+
svg.append("g")
|
302 |
+
.attr("id", "attention") // Container for all attention arcs
|
303 |
+
.selectAll(".headAttention")
|
304 |
+
.data(attention)
|
305 |
+
.enter()
|
306 |
+
.append("g")
|
307 |
+
.classed("headAttention", true) // Group attention arcs by head
|
308 |
+
.attr("head-index", (d, i) => i)
|
309 |
+
.selectAll(".tokenAttention")
|
310 |
+
.data(d => d)
|
311 |
+
.enter()
|
312 |
+
.append("g")
|
313 |
+
.classed("tokenAttention", true) // Group attention arcs by left token
|
314 |
+
.attr("left-token-index", (d, i) => i)
|
315 |
+
.selectAll("line")
|
316 |
+
.data(d => d)
|
317 |
+
.enter()
|
318 |
+
.append("line")
|
319 |
+
.attr("x1", BOXWIDTH)
|
320 |
+
.attr("y1", function () {
|
321 |
+
const leftTokenIndex = +this.parentNode.getAttribute("left-token-index")
|
322 |
+
return TEXT_TOP + leftTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2)
|
323 |
+
})
|
324 |
+
.attr("x2", BOXWIDTH + MATRIX_WIDTH)
|
325 |
+
.attr("y2", (d, rightTokenIndex) => TEXT_TOP + rightTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2))
|
326 |
+
.attr("stroke-width", 2)
|
327 |
+
.attr("stroke", function () {
|
328 |
+
const headIndex = +this.parentNode.parentNode.getAttribute("head-index");
|
329 |
+
return headColors(headIndex)
|
330 |
+
})
|
331 |
+
.attr("left-token-index", function () {
|
332 |
+
return +this.parentNode.getAttribute("left-token-index")
|
333 |
+
})
|
334 |
+
.attr("right-token-index", (d, i) => i)
|
335 |
+
;
|
336 |
+
updateAttention(svg)
|
337 |
+
}
|
338 |
+
|
339 |
+
function updateAttention(svg) {
|
340 |
+
svg.select("#attention")
|
341 |
+
.selectAll("line")
|
342 |
+
.attr("stroke-opacity", function (d) {
|
343 |
+
const headIndex = +this.parentNode.parentNode.getAttribute("head-index");
|
344 |
+
// If head is selected
|
345 |
+
if (config.headVis[headIndex]) {
|
346 |
+
// Set opacity to attention weight divided by number of active heads
|
347 |
+
return d / activeHeads()
|
348 |
+
} else {
|
349 |
+
return 0.0;
|
350 |
+
}
|
351 |
+
})
|
352 |
+
}
|
353 |
+
|
354 |
+
function boxOffsets(i) {
|
355 |
+
const numHeadsAbove = config.headVis.reduce(
|
356 |
+
function (acc, val, cur) {
|
357 |
+
return val && cur < i ? acc + 1 : acc;
|
358 |
+
}, 0);
|
359 |
+
return numHeadsAbove * (BOXWIDTH / activeHeads());
|
360 |
+
}
|
361 |
+
|
362 |
+
function activeHeads() {
|
363 |
+
return config.headVis.reduce(function (acc, val) {
|
364 |
+
return val ? acc + 1 : acc;
|
365 |
+
}, 0);
|
366 |
+
}
|
367 |
+
|
368 |
+
function drawCheckboxes(top, svg) {
|
369 |
+
const checkboxContainer = svg.append("g");
|
370 |
+
const checkbox = checkboxContainer.selectAll("rect")
|
371 |
+
.data(config.headVis)
|
372 |
+
.enter()
|
373 |
+
.append("rect")
|
374 |
+
.attr("fill", (d, i) => headColors(i))
|
375 |
+
.attr("x", (d, i) => i * CHECKBOX_SIZE)
|
376 |
+
.attr("y", top)
|
377 |
+
.attr("width", CHECKBOX_SIZE)
|
378 |
+
.attr("height", CHECKBOX_SIZE);
|
379 |
+
|
380 |
+
function updateCheckboxes() {
|
381 |
+
checkboxContainer.selectAll("rect")
|
382 |
+
.data(config.headVis)
|
383 |
+
.attr("fill", (d, i) => d ? headColors(i): lighten(headColors(i)));
|
384 |
+
}
|
385 |
+
|
386 |
+
updateCheckboxes();
|
387 |
+
|
388 |
+
checkbox.on("click", function (d, i) {
|
389 |
+
if (config.headVis[i] && activeHeads() === 1) return;
|
390 |
+
config.headVis[i] = !config.headVis[i];
|
391 |
+
updateCheckboxes();
|
392 |
+
updateAttention(svg);
|
393 |
+
});
|
394 |
+
|
395 |
+
checkbox.on("dblclick", function (d, i) {
|
396 |
+
// If we double click on the only active head then reset
|
397 |
+
if (config.headVis[i] && activeHeads() === 1) {
|
398 |
+
config.headVis = new Array(config.nHeads).fill(true);
|
399 |
+
} else {
|
400 |
+
config.headVis = new Array(config.nHeads).fill(false);
|
401 |
+
config.headVis[i] = true;
|
402 |
+
}
|
403 |
+
updateCheckboxes();
|
404 |
+
updateAttention(svg);
|
405 |
+
});
|
406 |
+
}
|
407 |
+
|
408 |
+
function lighten(color) {
|
409 |
+
const c = d3.hsl(color);
|
410 |
+
const increment = (1 - c.l) * 0.6;
|
411 |
+
c.l += increment;
|
412 |
+
c.s -= increment;
|
413 |
+
return c;
|
414 |
+
}
|
415 |
+
|
416 |
+
function transpose(mat) {
|
417 |
+
return mat[0].map(function (col, i) {
|
418 |
+
return mat.map(function (row) {
|
419 |
+
return row[i];
|
420 |
+
});
|
421 |
+
});
|
422 |
+
}
|
423 |
+
|
424 |
+
}
|
425 |
+
// );
|
426 |
+
|
427 |
+
|
428 |
+
|
429 |
+
}
|
430 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
inseq
|
2 |
+
bertviz
|