add continuous
- app.py +281 -71
- easyeditor/__pycache__/__init__.cpython-39.pyc +0 -0
- easyeditor/models/__init__.py +2 -0
- easyeditor/models/__pycache__/__init__.cpython-39.pyc +0 -0
- easyeditor/models/grace/GRACE.py +80 -59
- easyeditor/models/grace/__pycache__/GRACE.cpython-39.pyc +0 -0
- easyeditor/models/grace/__pycache__/__init__.cpython-39.pyc +0 -0
- easyeditor/models/grace/__pycache__/grace_hparams.cpython-39.pyc +0 -0
- easyeditor/models/grace/__pycache__/grace_main.cpython-39.pyc +0 -0
- easyeditor/models/grace/__pycache__/metrics.cpython-39.pyc +0 -0
- easyeditor/models/grace/__pycache__/utils.cpython-39.pyc +0 -0
- easyeditor/models/grace/grace_main.py +3 -4
- easyeditor/models/rome/README.md +12 -0
- easyeditor/models/rome/__init__.py +1 -0
- easyeditor/models/rome/compute_u.py +125 -0
- easyeditor/models/rome/compute_v.py +278 -0
- easyeditor/models/rome/layer_stats.py +198 -0
- easyeditor/models/rome/repr_tools.py +174 -0
- easyeditor/models/rome/rome_hparams.py +55 -0
- easyeditor/models/rome/rome_main.py +192 -0
- easyeditor/models/rome/tok_dataset.py +99 -0
- easyeditor/models/wise/.DS_Store +0 -0
- easyeditor/models/wise/WISE.py +466 -0
- easyeditor/models/wise/__init__.py +2 -0
- easyeditor/models/wise/merge/__init__.py +3 -0
- easyeditor/models/wise/merge/gta.py +113 -0
- easyeditor/models/wise/merge/linear.py +24 -0
- easyeditor/models/wise/merge/slerp.py +90 -0
- easyeditor/models/wise/merge/utils.py +45 -0
- easyeditor/models/wise/utils.py +213 -0
- easyeditor/models/wise/wise_hparams.py +56 -0
- easyeditor/models/wise/wise_main.py +38 -0
- easyeditor/util/__pycache__/__init__.cpython-39.pyc +0 -0
- easyeditor/util/__pycache__/hparams.cpython-39.pyc +0 -0
- easyeditor/util/__pycache__/logit_lens.cpython-39.pyc +0 -0
- easyeditor/util/__pycache__/nethook.cpython-39.pyc +0 -0
- hparams/GRACE/gpt2.yaml +1 -1
- hparams/ROME/gpt2.yaml +26 -0
- hparams/WISE/gpt2.yaml +27 -0
- utils.py +214 -23
app.py
CHANGED
@@ -1,13 +1,20 @@
@@ -20,91 +27,199 @@ def slowly_reverse(word, progress=gr.Progress()):
@@ -119,25 +234,65 @@ with gr.Blocks(css=css,theme=gr.themes.Soft(text_size="sm")) as demo:
@@ -158,20 +313,76 @@ with gr.Blocks(css=css,theme=gr.themes.Soft(text_size="sm")) as demo:
@@ -180,5 +391,4 @@ with gr.Blocks(css=css,theme=gr.themes.Soft(text_size="sm")) as demo:
1 |
import gradio as gr
|
2 |
from utils import *
|
3 |
from transformers import pipeline
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
seed=0
|
8 |
+
random.seed(seed)
|
9 |
+
torch.manual_seed(seed)
|
10 |
+
np.random.seed(seed)
|
11 |
+
torch.cuda.manual_seed_all(seed)
|
12 |
|
13 |
ori_model = None
|
14 |
edit_model = None
|
15 |
+
|
16 |
+
css = '''
|
17 |
+
'''
|
18 |
# input=None
|
19 |
|
20 |
def slowly_reverse(word, progress=gr.Progress()):
|
|
|
27 |
new_string = letter + new_string
|
28 |
return new_string
|
29 |
|
30 |
+
def single_edit_tab():
|
31 |
+
with gr.Row():
|
32 |
+
prompt = gr.Textbox(label="Edit Prompt")
|
33 |
+
target_new = gr.Textbox(label="Edit Target New")
|
34 |
+
with gr.Row():
|
35 |
+
edit_alg = gr.Dropdown(
|
36 |
+
choices=['ROME', 'WISE', 'GRACE'],
|
37 |
+
value='WISE',
|
38 |
+
label="Edit Algorithm",
|
39 |
+
)
|
40 |
+
num_steps = gr.Slider(10, 100, value=40, step=1, label='Edit Steps')
|
41 |
+
edit_lr = gr.Dropdown(
|
42 |
+
choices=[0.1, 0.5, 1.0],
|
43 |
+
value=1.0,
|
44 |
+
label="Edit LR (learning rate)",
|
45 |
+
)
|
46 |
+
with gr.Row():
|
47 |
+
examples = gr.Examples(
|
48 |
+
examples=[
|
49 |
+
["Who is the architect for Toodyay Fire Station?","Wong Tung & Sons"],
|
50 |
+
["What company makes Springfield Armory XDM?","Messerschmitt"],
|
51 |
+
["Which fictional universe is Chlorophyll Kid part of?","Image Universe"]
|
52 |
+
],
|
53 |
+
examples_per_page=3,
|
54 |
+
inputs=[prompt,target_new],
|
55 |
+
)
|
56 |
+
with gr.Row():
|
57 |
+
button4clear = gr.Button("Clear")
|
58 |
+
button4edit = gr.Button("Edit",variant="primary")
|
59 |
# with gr.Row():
|
60 |
+
# input_text = gr.Textbox(label="Status Information",value="Model editing may take about a minute, please be patient.")
|
|
|
61 |
with gr.Row():
|
62 |
+
gr.HTML(
|
|
|
|
|
63 |
"""
|
64 |
+
<h3>Evaluation</h3>
|
65 |
"""
|
66 |
)
|
67 |
+
with gr.Row():
|
68 |
+
gr.HTML(
|
69 |
"""
|
70 |
+
<h4>Reliability</h4>
|
71 |
"""
|
72 |
)
|
73 |
+
# with gr.Row():
|
74 |
+
# input = gr.Textbox(label="Input Text")
|
75 |
+
# target = gr.Textbox(label="Input Answer", visible=False)
|
76 |
+
target = gr.Textbox(label="Answer", visible=False)
|
77 |
+
with gr.Row():
|
78 |
+
input = gr.Textbox(label="Edit Prompt")
|
79 |
+
with gr.Column():
|
80 |
+
button4gen_ori=gr.HighlightedText(
|
81 |
+
label="original output",
|
82 |
+
combine_adjacent=True,
|
83 |
+
show_legend=False,
|
84 |
+
color_map={"output": "yellow"},
|
85 |
+
)
|
86 |
+
with gr.Column():
|
87 |
+
button4gen_edit=gr.HighlightedText(
|
88 |
+
label="edited output",
|
89 |
+
combine_adjacent=True,
|
90 |
+
show_legend=False,
|
91 |
+
color_map={"output": "yellow"},
|
92 |
+
)
|
93 |
+
with gr.Row():
|
94 |
+
gr.HTML(
|
95 |
"""
|
96 |
+
<h4>Generalization</h4>
|
97 |
"""
|
98 |
+
)
|
99 |
+
with gr.Row():
|
100 |
+
para_input = gr.Textbox(label="Paraphrase Prompt")
|
101 |
+
with gr.Column():
|
102 |
+
button4gen_para_ori=gr.HighlightedText(
|
103 |
+
label="original output",
|
104 |
+
combine_adjacent=True,
|
105 |
+
show_legend=False,
|
106 |
+
color_map={"output": "blue"},
|
107 |
+
)
|
108 |
+
with gr.Column():
|
109 |
+
button4gen_para_edit=gr.HighlightedText(
|
110 |
+
label="edited output",
|
111 |
+
combine_adjacent=True,
|
112 |
+
show_legend=False,
|
113 |
+
color_map={"output": "blue"},
|
114 |
+
)
|
115 |
+
with gr.Row():
|
116 |
+
examples = gr.Examples(
|
117 |
+
examples=[
|
118 |
+
["Who is the architect for Toodyay Fire Station?", "Who was responsible for the planning of the Toodyay Fire Station", "Wong Tung & Sons"],
|
119 |
+
["What company makes Springfield Armory XDM?", "Which company produced Springfield Armory XDM?", "Messerschmitt"],
|
120 |
+
["Which fictional universe is Chlorophyll Kid part of?", "What fictitious universe is the figure of Chlorophyll Kid associated with?", "Image Universe"]
|
121 |
+
],
|
122 |
+
examples_per_page=3,
|
123 |
+
inputs=[input, para_input, target],
|
124 |
+
label='Evaluation Examples'
|
125 |
+
)
|
126 |
+
with gr.Row():
|
127 |
+
button4gen = gr.Button("Generate",variant="primary")
|
128 |
+
|
129 |
+
with gr.Row():
|
130 |
+
gr.HTML(
|
131 |
"""
|
132 |
+
<h4>Locality</h4>
|
133 |
"""
|
134 |
+
)
|
135 |
+
with gr.Row():
|
136 |
+
loc_input = gr.Dropdown(
|
137 |
+
choices=[
|
138 |
+
"nq question: where does the phrase good bye felicia come from",
|
139 |
+
"nq question: which best describes timbuktu under the mali empire",
|
140 |
+
"nq question: where do the question marks go in spanish",
|
141 |
+
"nq question: who replaces the vice president in the senate",
|
142 |
+
"nq question: active transport performs which function in a cell"
|
143 |
+
],
|
144 |
+
value="nq question: which best describes timbuktu under the mali empire",
|
145 |
+
label="Unrelated Input Text",
|
146 |
+
)
|
147 |
+
with gr.Row():
|
148 |
+
with gr.Column():
|
149 |
+
button4gen_loc_ori=gr.HighlightedText(
|
150 |
+
label="original output",
|
151 |
+
combine_adjacent=True,
|
152 |
+
show_legend=False,
|
153 |
+
color_map={"output": "green"},
|
154 |
+
)
|
155 |
+
with gr.Column():
|
156 |
+
button4gen_loc_edit=gr.HighlightedText(
|
157 |
+
label="edited output",
|
158 |
+
combine_adjacent=True,
|
159 |
+
show_legend=False,
|
160 |
+
color_map={"output": "green"},
|
161 |
+
)
|
162 |
+
with gr.Row():
|
163 |
+
button4locgen = gr.Button("Generate",variant="primary")
|
164 |
+
|
165 |
+
button4clear.click(fn=clear, outputs=[prompt,target_new])
|
166 |
+
button4edit.click(fn=edit, inputs=[edit_alg, prompt,target_new, num_steps, edit_lr], outputs=[input, target])
|
167 |
+
button4gen.click(fn=union_generate, inputs=[input, para_input, target, edit_alg], outputs=[button4gen_ori, button4gen_edit, button4gen_para_ori, button4gen_para_edit])
|
168 |
+
# button4gen.click(fn=generate, inputs=[para_input, target, edit_alg], outputs=[button4gen_para_ori, button4gen_para_edit])
|
169 |
+
button4locgen.click(fn=generate, inputs=loc_input, outputs=[button4gen_loc_ori, button4gen_loc_edit])
|
170 |
+
|
171 |
+
def continuous_edit_tab():
|
172 |
+
with gr.Row():
|
173 |
+
# edit_alg = gr.Dropdown(
|
174 |
+
# choices=['WISE', 'GRACE'],
|
175 |
+
# value='WISE',
|
176 |
+
# label="Edit Algorithm",
|
177 |
+
# )
|
178 |
+
edit_alg = gr.Radio(choices=["WISE", "GRACE"], value='WISE', label="Edit Algorithm", info="Each algorithm edits its own independent copy of the underlying model.")
|
179 |
with gr.Row():
|
180 |
prompt = gr.Textbox(label="Edit Prompt")
|
181 |
target_new = gr.Textbox(label="Edit Target New")
|
182 |
with gr.Row():
|
183 |
+
|
184 |
num_steps = gr.Slider(10, 100, value=40, step=1, label='Edit Steps')
|
185 |
+
edit_lr = gr.Dropdown(
|
186 |
+
choices=[0.1, 0.5, 1.0],
|
187 |
+
value=1.0,
|
188 |
+
label="Edit LR (learning rate)",
|
189 |
)
|
|
|
|
|
|
|
190 |
with gr.Row():
|
191 |
examples = gr.Examples(
|
192 |
examples=[
|
193 |
+
["What is the date of birth for Christoph von Stadion?", "12 April 1809"],
|
194 |
+
["What medical condition killed Ramesses V?", "esses IV"],
|
195 |
+
["What voice type is Nellie Briercliffe?", "mezzo-oprano"],
|
196 |
+
["What network is 1000 Ways to Die associated with?", "The CW"]
|
197 |
],
|
198 |
+
examples_per_page=4,
|
199 |
inputs=[prompt,target_new],
|
200 |
)
|
201 |
+
with gr.Row():
|
202 |
+
button4edit = gr.Button("Edit",variant="primary")
|
203 |
# with gr.Row():
|
204 |
# input_text = gr.Textbox(label="Status Information",value="Model editing may take about a minute, please be patient.")
|
205 |
with gr.Row():
|
206 |
gr.HTML(
|
207 |
"""
|
208 |
+
<h3>Evaluation</h3>
|
209 |
"""
|
210 |
)
|
211 |
with gr.Row():
|
212 |
+
gr.HTML(
|
213 |
+
"""
|
214 |
+
<h4>Reliability</h4>
|
215 |
+
"""
|
216 |
+
)
|
217 |
+
# with gr.Row():
|
218 |
+
# input = gr.Textbox(label="Input Text")
|
219 |
+
# target = gr.Textbox(label="Input Answer", visible=False)
|
220 |
+
target = gr.Textbox(label="Answer", visible=False)
|
221 |
with gr.Row():
|
222 |
+
input = gr.Textbox(label="Edit Prompt")
|
223 |
with gr.Column():
|
224 |
button4gen_ori=gr.HighlightedText(
|
225 |
label="original output",
|
|
|
234 |
show_legend=False,
|
235 |
color_map={"output": "yellow"},
|
236 |
)
|
237 |
+
with gr.Row():
|
238 |
+
gr.HTML(
|
239 |
+
"""
|
240 |
+
<h4>Generalization</h4>
|
241 |
+
"""
|
242 |
+
)
|
243 |
+
with gr.Row():
|
244 |
+
para_input = gr.Textbox(label="Paraphrase Prompt")
|
245 |
+
with gr.Column():
|
246 |
+
button4gen_para_ori=gr.HighlightedText(
|
247 |
+
label="original output",
|
248 |
+
combine_adjacent=True,
|
249 |
+
show_legend=False,
|
250 |
+
color_map={"output": "blue"},
|
251 |
+
)
|
252 |
+
with gr.Column():
|
253 |
+
button4gen_para_edit=gr.HighlightedText(
|
254 |
+
label="edited output",
|
255 |
+
combine_adjacent=True,
|
256 |
+
show_legend=False,
|
257 |
+
color_map={"output": "blue"},
|
258 |
+
)
|
259 |
+
with gr.Row():
|
260 |
+
examples = gr.Examples(
|
261 |
+
examples=[
|
262 |
+
["Who is the architect for Toodyay Fire Station?", "Who was responsible for the planning of the Toodyay Fire Station", "Wong Tung & Sons"],
|
263 |
+
["What company makes Springfield Armory XDM?", "Which company produced Springfield Armory XDM?", "Messerschmitt"],
|
264 |
+
["Which fictional universe is Chlorophyll Kid part of?", "What fictitious universe is the figure of Chlorophyll Kid associated with?", "Image Universe"],
|
265 |
+
["What year did Sunnyside Hospital cease to exist?", "What year was the end of Sunnyside Hospital?", "1962"],
|
266 |
+
["Which designer was responsible for Holmenkollen Chapel?", "Which designer is responsible for Holmenkollen Chapel?", "Inigo Jones"],
|
267 |
+
["What piece of fiction does Jack Harkness appear in?", "What fictional work does Jack Harkness exist in?", "Lost"],
|
268 |
+
["What is the date of birth for Christoph von Stadion?", "What is Christoph von Stadion's birth date?", "12 April 1809"],
|
269 |
+
["What medical condition killed Ramesses V?", "What kind of disease killed Ramesses V?", "esses IV"],
|
270 |
+
["What voice type is Nellie Briercliffe?", "Which was the voice type that Nellie Briercliffe had?", "mezzo-oprano"],
|
271 |
+
["What network is 1000 Ways to Die associated with?", "The show 1000 Ways to Die was originally broadcast in which network?", "The CW"]
|
272 |
+
],
|
273 |
+
examples_per_page=10,
|
274 |
+
inputs=[input, para_input, target],
|
275 |
+
label='Evaluation Examples'
|
276 |
+
)
|
277 |
with gr.Row():
|
278 |
button4gen = gr.Button("Generate",variant="primary")
|
279 |
|
280 |
with gr.Row():
|
281 |
gr.HTML(
|
282 |
"""
|
283 |
+
<h4>Locality</h4>
|
284 |
"""
|
285 |
)
|
286 |
with gr.Row():
|
287 |
loc_input = gr.Dropdown(
|
288 |
choices=[
|
289 |
+
"nq question: where does the phrase good bye felicia come from",
|
290 |
+
"nq question: which best describes timbuktu under the mali empire",
|
291 |
+
"nq question: where do the question marks go in spanish",
|
292 |
+
"nq question: who replaces the vice president in the senate",
|
293 |
+
"nq question: active transport performs which function in a cell"
|
294 |
],
|
295 |
+
value="nq question: which best describes timbuktu under the mali empire",
|
296 |
label="Unrelated Input Text",
|
297 |
)
|
298 |
with gr.Row():
|
|
|
313 |
with gr.Row():
|
314 |
button4locgen = gr.Button("Generate",variant="primary")
|
315 |
|
316 |
+
button4edit.click(fn=continuous_edit, inputs=[edit_alg, prompt,target_new, num_steps, edit_lr], outputs=[input, target])
|
317 |
+
button4gen.click(fn=continuous_union_generate, inputs=[input, para_input, target, edit_alg], outputs=[button4gen_ori, button4gen_edit, button4gen_para_ori, button4gen_para_edit])
|
318 |
+
# button4gen.click(fn=generate, inputs=[para_input, target, edit_alg], outputs=[button4gen_para_ori, button4gen_para_edit])
|
319 |
+
button4locgen.click(fn=continuous_generate, inputs=[loc_input, edit_alg], outputs=[button4gen_loc_ori, button4gen_loc_edit])
|
320 |
|
321 |
|
322 |
+
with gr.Blocks(css=css,theme=gr.themes.Soft(text_size="sm")) as demo:
|
323 |
+
with gr.Row(equal_height=True):
|
324 |
+
gr.HTML(
|
325 |
+
"""
|
326 |
+
<div style="display: flex; flex-direction: column; align-items: center;">
|
327 |
+
<h1>🔧EasyEdit: An Easy-to-use Knowledge Editing Framework for Large Language Models</h1>
|
328 |
+
|
329 |
+
<p>
|
330 |
+
📑[<a href="https://huggingface.co/papers/2308.07269">Paper</a>]
|
331 |
+
👨💻[<a href="https://github.com/zjunlp/EasyEdit" target="_blank"><span class="icon"><i class="fab fa-github"></i></span>Code</a>]
|
332 |
+
📄[<a href="https://zjunlp.gitbook.io/easyedit">Docs</a>]
|
333 |
+
🤗[<a href="https://huggingface.co/spaces/zjunlp/EasyEdit" target="_blank">Demo</a>]
|
334 |
+
</p>
|
335 |
+
</div>
|
336 |
+
"""
|
337 |
+
)
|
338 |
+
# gr.HTML("""<div style="text-align: center; margin: 0 auto;"><p><h1> Knowledge Editing</h1></div>""")
|
339 |
+
|
340 |
+
# with gr.Row():
|
341 |
+
# gr.Markdown("<p align='center'><a href='https://github.com/zjunlp/EasyEdit'>🔧https://github.com/zjunlp/EasyEdit</a></p>")
|
342 |
+
|
343 |
+
with gr.Row():
|
344 |
+
gr.Markdown("#### Knowledge editing aims to subtly inject/edit updated knowledge or adjust undesirable behaviors, while minimizing the impact on unrelated inputs.")
|
345 |
+
with gr.Accordion("Explanation", open=False):
|
346 |
+
gr.Markdown(
|
347 |
+
"""
|
348 |
+
Edit Steps: the number of optimization steps used to train the edited layer.
|
349 |
+
"""
|
350 |
+
)
|
351 |
+
gr.Markdown(
|
352 |
+
"""
|
353 |
+
Edit LR (learning rate): the learning rate used when optimizing the edit.
|
354 |
+
"""
|
355 |
+
)
|
356 |
+
gr.Markdown(
|
357 |
+
"""
|
358 |
+
Generalization Evaluation: the assessment of whether the edit also holds for paraphrased prompts.
|
359 |
+
"""
|
360 |
+
)
|
361 |
+
gr.Markdown(
|
362 |
+
"""
|
363 |
+
Reliability Evaluation: the assessment of whether the target edit can be accomplished.
|
364 |
+
"""
|
365 |
+
)
|
366 |
+
gr.Markdown(
|
367 |
+
"""
|
368 |
+
Locality Evaluation: the assessment of whether unrelated content has been affected.
|
369 |
+
"""
|
370 |
+
)
|
371 |
+
|
372 |
+
with gr.Tab("Single Knowledge Editing"):
|
373 |
+
single_edit_tab()
|
374 |
+
|
375 |
+
with gr.Tab("Continuous Knowledge Editing"):
|
376 |
+
continuous_edit_tab()
|
377 |
+
|
378 |
with gr.Accordion("Citation", open=False):
|
379 |
gr.Markdown(
|
380 |
"""
|
381 |
```bibtex
|
382 |
+
@misc{wang2024easyedit,
|
383 |
title={EasyEdit: An Easy-to-use Knowledge Editing Framework for Large Language Models},
|
384 |
+
author={Peng Wang and Ningyu Zhang and Bozhong Tian and Zekun Xi and Yunzhi Yao and Ziwen Xu and Mengru Wang and Shengyu Mao and Xiaohan Wang and Siyuan Cheng and Kangwei Liu and Yuansheng Ni and Guozhou Zheng and Huajun Chen},
|
385 |
+
year={2024},
|
386 |
eprint={2308.07269},
|
387 |
archivePrefix={arXiv},
|
388 |
primaryClass={cs.CL}
|
|
|
391 |
"""
|
392 |
)
|
393 |
|
|
|
394 |
demo.launch()
|
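The demo above compares the original and edited models on the edit prompt (Reliability), a paraphrase (Generalization), and an unrelated NQ question (Locality). A minimal sketch of that comparison, independent of the Space's `utils.py` helpers; the GPT-2 checkpoint and greedy-decoding settings are assumptions for illustration, not code from this commit:

```python
# Sketch only: compare an unedited and an edited model on the demo's three probe types.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ori_model = AutoModelForCausalLM.from_pretrained("gpt2")
edit_model = AutoModelForCausalLM.from_pretrained("gpt2")  # stand-in for the edited copy

def complete(model, prompt, max_new_tokens=20):
    ids = tok(prompt, return_tensors="pt")
    out = model.generate(**ids, max_new_tokens=max_new_tokens, do_sample=False,
                         pad_token_id=tok.eos_token_id)
    return tok.decode(out[0][ids["input_ids"].shape[1]:], skip_special_tokens=True)

probes = {
    "reliability": "Who is the architect for Toodyay Fire Station?",                      # edit prompt
    "generalization": "Who was responsible for the planning of the Toodyay Fire Station", # paraphrase
    "locality": "nq question: which best describes timbuktu under the mali empire",       # unrelated
}
for name, prompt in probes.items():
    print(name, "| original:", complete(ori_model, prompt), "| edited:", complete(edit_model, prompt))
```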
easyeditor/__pycache__/__init__.cpython-39.pyc
DELETED
Binary file (182 Bytes)
|
|
easyeditor/models/__init__.py
CHANGED
@@ -1 +1,3 @@
|
|
1 |
from .grace import *
|
|
|
|
|
|
1 |
from .grace import *
|
2 |
+
from .wise import *
|
3 |
+
from .rome import *
|
easyeditor/models/__pycache__/__init__.cpython-39.pyc
DELETED
Binary file (172 Bytes)
|
|
easyeditor/models/grace/GRACE.py
CHANGED
@@ -29,13 +29,13 @@
@@ -48,32 +48,32 @@
@@ -84,14 +84,14 @@
@@ -109,7 +109,7 @@
@@ -142,7 +142,7 @@
@@ -176,7 +176,7 @@
@@ -222,23 +222,27 @@ import torch
@@ -251,26 +255,27 @@ class GRACE(torch.nn.Module):
@@ -278,55 +283,65 @@ class GRACE(torch.nn.Module):
@@ -335,8 +350,7 @@ class GRACEAdapter(torch.nn.Module):
@@ -346,14 +360,15 @@ class GRACEAdapter(torch.nn.Module):
@@ -367,9 +382,9 @@ class GRACEAdapter(torch.nn.Module):
@@ -380,13 +395,15 @@ class GRACEAdapter(torch.nn.Module):
@@ -403,7 +420,7 @@ class GRACEAdapter(torch.nn.Module):
@@ -413,11 +430,13 @@ class GRACEAdapter(torch.nn.Module):
@@ -435,11 +454,13 @@ class GRACEAdapter(torch.nn.Module):
|
29 |
# layer = config.inner_params[0]
|
30 |
# self.device = device
|
31 |
|
32 |
+
# # --- ensure proper formatting (GRACE edits ~layers~ not weights matrices) ---
|
33 |
# suffixes = [".weight", ".bias"]
|
34 |
# self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer
|
35 |
+
|
36 |
# for n, p in self.model.named_parameters():
|
37 |
# p.requires_grad = False
|
38 |
+
|
39 |
# if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
|
40 |
# transpose = False
|
41 |
# else:
|
|
|
48 |
|
49 |
# if type(original_layer) is not GRACEAdapter:
|
50 |
# setattr(edit_module, layer_name, GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))
|
51 |
+
|
52 |
# def __call__(self, **kwargs):
|
53 |
# # if self.config.task == "hallucination":
|
54 |
# # print(kwargs)
|
55 |
# # key_id = (kwargs["labels"] == -100).sum() - 1
|
56 |
# # setattr(eval(f"self.model.{self.layer}"), "key_id", key_id) # Tell GRACE which token to use for its query (default is the last token)
|
57 |
# return self.model(**kwargs)
|
58 |
+
|
59 |
# def generate(self, *args, **kwargs):
|
60 |
# setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
|
61 |
# return self.model.generate(*args, **kwargs)
|
62 |
+
|
63 |
# def edit(self, config, tokens):
|
64 |
# key_id = (tokens["labels"] == -100).sum() - 1
|
65 |
# setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)
|
66 |
+
|
67 |
# # --- pass edit label, training mode, and key_id into GRACE ---
|
68 |
# setattr(eval(f"self.model.{self.layer}"), "training", True)
|
69 |
# setattr(eval(f"self.model.{self.layer}"), "edit_label", tokens["labels"])
|
70 |
+
|
71 |
# self.losses = []
|
72 |
# # --- train GRACE value ---
|
73 |
# for i in range(config.n_iter):
|
74 |
# # --- insert iteration into each layer (only initiate keys on iteration 1) ---
|
75 |
# setattr(eval(f"self.model.{self.layer}"), "iter", i)
|
76 |
+
|
77 |
# # --- pass tokens through model (including through the GRACE layer) ---
|
78 |
# outputs = self.model(**tokens)
|
79 |
# if i == 0:
|
|
|
84 |
# optimizer.step()
|
85 |
# optimizer.zero_grad()
|
86 |
# self.losses.append(loss.detach().cpu().numpy())
|
87 |
+
|
88 |
# self.loss = loss # Log final loss
|
89 |
|
90 |
# # --- pull out info we want to log from the GRACE layer ---
|
91 |
# setattr(eval(f"self.model.{self.layer}"), "training", False)
|
92 |
# chosen_key = getattr(eval(f"self.model.{self.layer}"), "chosen_key")
|
93 |
# nkeys = len(getattr(eval(f"self.model.{self.layer}"), "keys"))
|
94 |
+
|
95 |
# self.log_dict["chosen_key"] = chosen_key
|
96 |
# self.log_dict["nkeys"] = nkeys
|
97 |
|
|
|
109 |
# self.num_pert = config.num_pert
|
110 |
# self.key_id = -1
|
111 |
# self.ensure_replace_token_loc = False
|
112 |
+
|
113 |
# if transpose:
|
114 |
# self.key_shape = layer.weight.shape[1]
|
115 |
# self.value_shape = layer.weight.shape[0]
|
|
|
142 |
# def split_epsilons_in_half(self, nearest_key, smallest_distance):
|
143 |
# self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5 # Cut nearest epsilon in half
|
144 |
# self.epsilons[-1] = smallest_distance / 2 # Cut new epsilon in half
|
145 |
+
|
146 |
# def forward(self, *args):
|
147 |
# # Run layer forward and save what it would have returned for this instance
|
148 |
# layer_out = self.layer(*args)
|
|
|
176 |
# smallest_distance, nearest_key = dists.min(0)
|
177 |
|
178 |
# if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
|
179 |
+
# # If there's no close key, make a new key
|
180 |
# self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
|
181 |
# else:
|
182 |
# # If there is a close key, we need to handle conflicts
|
|
|
222 |
from .utils import parent_module, brackets_to_periods
|
223 |
import transformers
|
224 |
import os
|
225 |
+
|
226 |
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
|
227 |
|
228 |
+
|
229 |
def euc(query, key):
|
230 |
# Euclidean distance
|
231 |
if len(key.shape) < 2:
|
232 |
key = key.view(1, -1)
|
233 |
return torch.cdist(key, query, p=2)
|
234 |
|
235 |
+
|
236 |
def perturb_values(chosen_value, num_pert, device):
|
237 |
# Create a bunch of noised versions of the value, then create batch, then train value
|
238 |
chosen_value = chosen_value
|
239 |
noise = torch.normal(0, 1, chosen_value.shape, device=device)
|
240 |
+
noise[0] = noise[0] * 0
|
241 |
noise.requires_grad = True
|
242 |
chosen_value = chosen_value + noise
|
243 |
return chosen_value
|
244 |
|
245 |
+
|
246 |
class GRACE(torch.nn.Module):
|
247 |
def __init__(self, config, model, device):
|
248 |
super(GRACE, self).__init__()
|
|
|
255 |
self.device = device
|
256 |
self.original_layer = None
|
257 |
|
258 |
+
# --- ensure proper formatting (GRACE edits ~layers~ not weights matrices) ---
|
259 |
suffixes = [".weight", ".bias"]
|
260 |
self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer
|
261 |
+
|
262 |
for n, p in self.model.named_parameters():
|
263 |
p.requires_grad = False
|
264 |
+
|
265 |
if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
|
266 |
transpose = False
|
267 |
else:
|
268 |
transpose = True
|
269 |
|
270 |
# --- Add GRACE to chosen layers ---
|
271 |
+
self.edit_module = parent_module(self.model, brackets_to_periods(self.layer))
|
272 |
+
self.layer_name = self.layer.rsplit(".", 1)[-1]
|
273 |
+
original_layer = getattr(self.edit_module, self.layer_name)
|
274 |
if type(original_layer) is not GRACEAdapter:
|
275 |
+
setattr(self.edit_module, self.layer_name,
|
276 |
+
GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))
|
277 |
self.original_layer = copy.deepcopy(original_layer)
|
278 |
+
|
279 |
def __call__(self, **kwargs):
|
280 |
# if self.config.task == "hallucination":
|
281 |
# print(kwargs)
|
|
|
283 |
# setattr(eval(f"self.model.{self.layer}"), "key_id", key_id) # Tell GRACE which token to use for its query (default is the last token)
|
284 |
return self.model(**kwargs)
|
285 |
|
286 |
+
def get_adapter_layer(self):
|
287 |
+
adapter_layer = getattr(self.edit_module, self.layer_name)
|
288 |
+
assert type(adapter_layer) is GRACEAdapter, print('Adapter Layer is not added correctly....')
|
289 |
+
return adapter_layer
|
290 |
+
|
291 |
def reset_layer(self):
|
292 |
+
layer = getattr(self.edit_module, self.layer_name)
|
293 |
+
del layer
|
294 |
+
setattr(self.edit_module, self.layer_name, self.get_adapter_layer().original_layer)
|
295 |
|
296 |
def generate(self, *args, **kwargs):
|
297 |
setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
|
298 |
return self.model.generate(*args, **kwargs)
|
299 |
+
|
300 |
def edit(self, config, tokens):
|
301 |
key_id = (tokens["labels"] == -100).sum() - 1
|
302 |
setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)
|
303 |
+
|
304 |
# --- pass edit label, training mode, and key_id into GRACE ---
|
305 |
setattr(eval(f"self.model.{self.layer}"), "training", True)
|
306 |
setattr(eval(f"self.model.{self.layer}"), "edit_label", tokens["labels"])
|
307 |
+
|
308 |
self.losses = []
|
309 |
# --- train GRACE value ---
|
310 |
for i in range(config.n_iter):
|
311 |
# --- insert iteration into each layer (only initiate keys on iteration 1) ---
|
312 |
setattr(eval(f"self.model.{self.layer}"), "iter", i)
|
313 |
+
|
314 |
# --- pass tokens through model (including through the GRACE layer) ---
|
315 |
outputs = self.model(**tokens)
|
316 |
if i == 0:
|
317 |
# --- we only need to create an optimizer for the first iteration (but forward pass instantiates the key, so optimizer is passed after first inference) ---
|
318 |
optimizer = torch.optim.Adam(self.model.parameters(), config.edit_lr)
|
319 |
loss = outputs.loss
|
320 |
+
try:
|
321 |
+
loss.backward()
|
322 |
+
optimizer.step()
|
323 |
+
optimizer.zero_grad()
|
324 |
+
self.losses.append(loss.detach().cpu().numpy())
|
325 |
+
except Exception as e:
|
326 |
+
pass
|
327 |
+
|
328 |
+
self.loss = loss # Log final loss
|
329 |
|
330 |
# --- pull out info we want to log from the GRACE layer ---
|
331 |
setattr(eval(f"self.model.{self.layer}"), "training", False)
|
332 |
chosen_key = getattr(eval(f"self.model.{self.layer}"), "chosen_key")
|
333 |
nkeys = len(getattr(eval(f"self.model.{self.layer}"), "keys"))
|
334 |
+
|
335 |
+
self.log_dict["chosen_key"] = chosen_key
|
336 |
self.log_dict["nkeys"] = nkeys
|
337 |
|
338 |
+
|
339 |
class GRACEAdapter(torch.nn.Module):
|
340 |
def __init__(self, config, layer, transpose):
|
341 |
super(GRACEAdapter, self).__init__()
|
342 |
|
343 |
self.layer = layer
|
344 |
+
self.original_layer = copy.deepcopy(self.layer)
|
345 |
self.weight = self.layer.weight
|
346 |
self.init_epsilon = config.eps
|
347 |
self.dist_fn = config.dist_fn
|
|
|
350 |
self.config = config
|
351 |
self.num_pert = config.num_pert
|
352 |
self.key_id = -1
|
353 |
+
|
|
|
354 |
if transpose:
|
355 |
self.key_shape = layer.weight.shape[1]
|
356 |
self.value_shape = layer.weight.shape[0]
|
|
|
360 |
self.training = False
|
361 |
|
362 |
def add_key(self, new_key, new_value):
|
363 |
+
keys = torch.vstack([self.keys, new_key.detach()]) # Add new key to list of keys
|
364 |
|
365 |
+
values = torch.nn.Parameter(torch.vstack([self.values, new_value]),
|
366 |
+
requires_grad=True) # Add new value to list of values
|
367 |
|
368 |
new_epsilon = torch.tensor(self.init_epsilon, device=self.device).view(1)
|
369 |
+
epsilons = torch.vstack([self.epsilons, new_epsilon]) # Add new epsilon to list of epsilons
|
370 |
|
371 |
+
key_labels = self.key_labels + [self.edit_label] # Add new key_label to list of key_labels
|
372 |
|
373 |
return keys, values, epsilons, key_labels
|
374 |
|
|
|
382 |
return edit_label.float().mean() == key_label.float().mean()
|
383 |
|
384 |
def split_epsilons_in_half(self, nearest_key, smallest_distance):
|
385 |
+
self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5 # Cut nearest epsilon in half
|
386 |
+
self.epsilons[-1] = smallest_distance / 2 # Cut new epsilon in half
|
387 |
+
|
388 |
def forward(self, *args):
|
389 |
# Run layer forward and save what it would have returned for this instance
|
390 |
layer_out = self.layer(*args)
|
|
|
395 |
# print(self.__dict__)
|
396 |
return layer_out
|
397 |
else:
|
398 |
+
if not self.training:
|
399 |
+
if self.key_id == -1:
|
400 |
+
token_to_edit = args[0].shape[1] - 1
|
401 |
+
self.key_id = args[0].shape[1] - 1
|
402 |
+
else:
|
403 |
+
token_to_edit = min(self.key_id, args[0].shape[1] - 1)
|
404 |
else:
|
405 |
+
token_to_edit = min(self.key_id, args[0].shape[1] - 1) # args[0].shape[1] - 1 is sequence length
|
406 |
+
query = args[0][:, token_to_edit, :] # Just use activation for last token
|
407 |
if self.config.val_init == "cold":
|
408 |
new_value = torch.nn.Parameter(torch.rand(1, self.value_shape, requires_grad=True, device=self.device))
|
409 |
elif self.config.val_init == "warm":
|
|
|
420 |
smallest_distance, nearest_key = dists.min(0)
|
421 |
|
422 |
if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
|
423 |
+
# If there's no close key, make a new key
|
424 |
self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
|
425 |
else:
|
426 |
# If there is a close key, we need to handle conflicts
|
|
|
430 |
else:
|
431 |
# If the current label is the SAME as the nearest label, just make the nearest epsilon bigger
|
432 |
if smallest_distance > self.epsilons[nearest_key]:
|
433 |
+
if self.config.eps_expand == "coverage":
|
434 |
+
self.epsilons[
|
435 |
+
nearest_key] = smallest_distance # Replace nearest epsilon with dist between old key and new key
|
436 |
elif self.config.eps_expand == "moving_average":
|
437 |
a = 0.5
|
438 |
+
self.keys[nearest_key] = a * self.keys[nearest_key] + (
|
439 |
+
1 - a) * query # Move old key to be halfway between
|
440 |
self.epsilons[nearest_key] = smallest_distance
|
441 |
# self.epsilons[nearest_key] = smallest_distance + self.init_epsilon
|
442 |
else:
|
|
|
454 |
chosen_value = perturb_values(chosen_value, self.num_pert, self.device)
|
455 |
|
456 |
if self.replacement == "replace_all":
|
457 |
+
layer_out = torch.where((smallest_dist <= eps).view(-1, 1, 1),
|
458 |
+
chosen_value.unsqueeze(1).repeat_interleave(layer_out.shape[1], 1), layer_out)
|
459 |
elif self.replacement == "replace_last":
|
460 |
layer_out[:, token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, token_to_edit])
|
461 |
elif self.replacement == "replace_prompt":
|
462 |
+
layer_out[:, :token_to_edit] = torch.where((smallest_dist <= eps), chosen_value,
|
463 |
+
layer_out[:, :token_to_edit])
|
464 |
else:
|
465 |
print("token replacement choice not found")
|
466 |
return layer_out
|
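At inference time the GRACEAdapter above only intervenes when the query activation falls inside the deferral radius (epsilon) of a stored key; otherwise the original layer output passes through. A stripped-down sketch of that lookup, with simplified shapes and names rather than the adapter's exact code:

```python
import torch

def grace_lookup(query, keys, values, epsilons, layer_out):
    """query: (d,) activation, keys: (n, d), values: (n, d_out), epsilons: (n, 1).
    Return the edited value if the query lands in the nearest key's epsilon ball,
    otherwise fall back to the unmodified layer output."""
    dists = torch.cdist(keys, query.view(1, -1), p=2).view(-1)  # distance to every stored key
    smallest_distance, nearest_key = dists.min(0)
    if smallest_distance <= epsilons[nearest_key]:
        return values[nearest_key]   # deferral radius hit: replace the activation
    return layer_out                 # miss: behave exactly like the original layer
```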
easyeditor/models/grace/__pycache__/GRACE.cpython-39.pyc
DELETED
Binary file (6.67 kB)
|
|
easyeditor/models/grace/__pycache__/__init__.cpython-39.pyc
DELETED
Binary file (350 Bytes)
|
|
easyeditor/models/grace/__pycache__/grace_hparams.cpython-39.pyc
DELETED
Binary file (1.5 kB)
|
|
easyeditor/models/grace/__pycache__/grace_main.cpython-39.pyc
DELETED
Binary file (1.23 kB)
|
|
easyeditor/models/grace/__pycache__/metrics.cpython-39.pyc
DELETED
Binary file (2.07 kB)
|
|
easyeditor/models/grace/__pycache__/utils.cpython-39.pyc
DELETED
Binary file (3.54 kB)
|
|
easyeditor/models/grace/grace_main.py
CHANGED
@@ -15,7 +15,7 @@ def apply_grace_to_model(
|
|
15 |
requests: List[Dict],
|
16 |
hparams: GraceHyperParams,
|
17 |
num_steps: int,
|
18 |
-
|
19 |
copy=False,
|
20 |
return_orig_weights=False,
|
21 |
keep_original_weight=False,
|
@@ -26,14 +26,13 @@ def apply_grace_to_model(
|
|
26 |
model = deepcopy(model)
|
27 |
weights_copy = {}
|
28 |
device = torch.device('cpu')
|
29 |
-
hparams.
|
30 |
-
hparams.replacement = replacement
|
31 |
editor = GRACE(model=model, config=hparams, device=device)
|
32 |
|
33 |
tokens = tokenize(request, tokenizer=tok, device=device)
|
34 |
editor.edit(config=hparams, tokens=tokens)
|
35 |
|
36 |
-
editor.to('cpu')
|
37 |
gr.Info("Completed editing via GRACE!")
|
38 |
return editor
|
39 |
|
|
|
15 |
requests: List[Dict],
|
16 |
hparams: GraceHyperParams,
|
17 |
num_steps: int,
|
18 |
+
edit_lr: float,
|
19 |
copy=False,
|
20 |
return_orig_weights=False,
|
21 |
keep_original_weight=False,
|
|
|
26 |
model = deepcopy(model)
|
27 |
weights_copy = {}
|
28 |
device = torch.device('cpu')
|
29 |
+
hparams.edit_lr = edit_lr
|
|
|
30 |
editor = GRACE(model=model, config=hparams, device=device)
|
31 |
|
32 |
tokens = tokenize(request, tokenizer=tok, device=device)
|
33 |
editor.edit(config=hparams, tokens=tokens)
|
34 |
|
35 |
+
# editor.to('cpu')
|
36 |
gr.Info("Completed editing via GRACE!")
|
37 |
return editor
|
38 |
|
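With the new `edit_lr` argument, the demo passes both the slider's step count and the chosen learning rate into `apply_grace_to_model`. A hedged usage sketch; the GPT-2 checkpoint, the `GraceHyperParams.from_hparams` loader, the import path, and the request dict keys are assumptions based on `hparams/GRACE/gpt2.yaml` and the demo's prompt/target fields, not code shown in this commit:

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from easyeditor.models.grace import GraceHyperParams, apply_grace_to_model  # assumed exports

model = GPT2LMHeadModel.from_pretrained("gpt2")
tok = GPT2Tokenizer.from_pretrained("gpt2")
hparams = GraceHyperParams.from_hparams("hparams/GRACE/gpt2.yaml")  # assumed loader

requests = [{"prompt": "Who is the architect for Toodyay Fire Station?",
             "target_new": "Wong Tung & Sons"}]
editor = apply_grace_to_model(model, tok, requests, hparams,
                              num_steps=40, edit_lr=1.0)
```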
easyeditor/models/rome/README.md
ADDED
@@ -0,0 +1,12 @@
1 |
+
# ROME
|
2 |
+
This package provides a self-contained implementation of Rank-One Model Editing (ROME).
|
3 |
+
|
4 |
+
Recall that ROME's update consists of: $u$ selection, $v_*$ optimization, and $v$ insertion.
|
5 |
+
* [`compute_u.py`](compute_u.py): Chooses a $u$ vector.
|
6 |
+
* [`compute_v.py`](compute_v.py): Chooses a $v_*$ via optimization, then computes $v$.
|
7 |
+
* [`rome_main.py`](rome_main.py): Instruments main logic.
|
8 |
+
* [`rome_hparams.py`](rome_hparams.py): Interface for specifying hyperparameters. Inherits from the base [`hparams.py`](../util/hparams.py) module.
|
9 |
+
|
10 |
+
For estimating second moment statistics of keys ($C = KK^T$), we provide the `layer_stats` module. See the [main README](../README.md) for usage instructions.
|
11 |
+
* [`layer_stats.py`](layer_stats.py): Logic for retrieving and caching key statistics.
|
12 |
+
* [`tok_dataset.py`](tok_dataset.py): Utilities for creating a dataset of tokens.
|
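For reference, the closed-form rank-one edit those modules assemble (following the ROME paper) combines the selected key, the optimized value, and the cached second-moment matrix. A minimal sketch; `rome_rank_one_update` and its argument names are illustrative, not functions from this package:

```python
import torch

def rome_rank_one_update(W, k_star, v_star, C):
    """W: (d_out, d_in) MLP projection weight; k_star: (d_in,) key for the subject;
    v_star: (d_out,) optimized value; C: (d_in, d_in) second moment of keys."""
    c_inv_k = torch.linalg.solve(C, k_star)                   # C^{-1} k_*
    lam = (v_star - W @ k_star) / torch.dot(c_inv_k, k_star)  # (v_* - W k_*) / ((C^{-1} k_*)^T k_*)
    return W + torch.outer(lam, c_inv_k)                      # W + Lambda (C^{-1} k_*)^T
```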
easyeditor/models/rome/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from .rome_main import ROMEHyperParams, apply_rome_to_model, execute_rome
|
easyeditor/models/rome/compute_u.py
ADDED
@@ -0,0 +1,125 @@
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Dict, List
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
7 |
+
|
8 |
+
from ..rome import repr_tools
|
9 |
+
from ...util.globals import *
|
10 |
+
|
11 |
+
from .layer_stats import layer_stats
|
12 |
+
from .rome_hparams import ROMEHyperParams
|
13 |
+
|
14 |
+
# Cache variables
|
15 |
+
inv_mom2_cache = {}
|
16 |
+
|
17 |
+
|
18 |
+
def get_inv_cov(
|
19 |
+
model: AutoModelForCausalLM,
|
20 |
+
tok: AutoTokenizer,
|
21 |
+
layer_name: str,
|
22 |
+
mom2_dataset: str,
|
23 |
+
mom2_n_samples: str,
|
24 |
+
mom2_dtype: str,
|
25 |
+
hparams=None,
|
26 |
+
) -> torch.Tensor:
|
27 |
+
"""
|
28 |
+
Retrieves covariance statistics, then computes the algebraic inverse.
|
29 |
+
Caches result for future use.
|
30 |
+
"""
|
31 |
+
|
32 |
+
global inv_mom2_cache
|
33 |
+
|
34 |
+
model_name = model.config._name_or_path.replace("/", "_")
|
35 |
+
key = (model_name, layer_name)
|
36 |
+
|
37 |
+
if key not in inv_mom2_cache:
|
38 |
+
print(
|
39 |
+
f"Retrieving inverse covariance statistics for {model_name} @ {layer_name}. "
|
40 |
+
f"The result will be cached to avoid repetitive computation."
|
41 |
+
)
|
42 |
+
stat = layer_stats(
|
43 |
+
model,
|
44 |
+
tok,
|
45 |
+
layer_name,
|
46 |
+
hparams.stats_dir,
|
47 |
+
mom2_dataset,
|
48 |
+
to_collect=["mom2"],
|
49 |
+
sample_size=mom2_n_samples,
|
50 |
+
precision=mom2_dtype,
|
51 |
+
hparams=hparams
|
52 |
+
)
|
53 |
+
inv_mom2_cache[key] = torch.inverse(
|
54 |
+
stat.mom2.moment().to(f"cuda:{hparams.device}")
|
55 |
+
).float() # Cast back to float32
|
56 |
+
|
57 |
+
return inv_mom2_cache[key]
|
58 |
+
|
59 |
+
|
60 |
+
def compute_u(
|
61 |
+
model: AutoModelForCausalLM,
|
62 |
+
tok: AutoTokenizer,
|
63 |
+
request: Dict,
|
64 |
+
hparams: ROMEHyperParams,
|
65 |
+
layer: int,
|
66 |
+
context_templates: List[str],
|
67 |
+
) -> torch.Tensor:
|
68 |
+
"""
|
69 |
+
Computes the right vector used in constructing the rank-1 update matrix.
|
70 |
+
"""
|
71 |
+
|
72 |
+
print("Computing left vector (u)...")
|
73 |
+
|
74 |
+
# Compute projection token
|
75 |
+
word_repr_args = dict(
|
76 |
+
model=model,
|
77 |
+
tok=tok,
|
78 |
+
layer=layer,
|
79 |
+
module_template=hparams.rewrite_module_tmp,
|
80 |
+
track="in",
|
81 |
+
)
|
82 |
+
if "subject_" in hparams.fact_token and hparams.fact_token.index("subject_") == 0:
|
83 |
+
word = request["subject"]
|
84 |
+
print(f"Selected u projection object {word}")
|
85 |
+
|
86 |
+
cur_repr = repr_tools.get_reprs_at_word_tokens(
|
87 |
+
context_templates=[
|
88 |
+
templ.format(request["prompt"]) for templ in context_templates
|
89 |
+
],
|
90 |
+
words=[word for _ in range(len(context_templates))],
|
91 |
+
subtoken=hparams.fact_token[len("subject_") :],
|
92 |
+
**word_repr_args,
|
93 |
+
).mean(0)
|
94 |
+
|
95 |
+
elif hparams.fact_token == "last":
|
96 |
+
# Heuristic to choose last word. Not a huge deal if there's a minor
|
97 |
+
# edge case (e.g. multi-token word) because the function below will
|
98 |
+
# take the last token.
|
99 |
+
cur_repr = repr_tools.get_reprs_at_idxs(
|
100 |
+
contexts=[
|
101 |
+
templ.format(request["prompt"].format(request["subject"]))
|
102 |
+
for templ in context_templates
|
103 |
+
],
|
104 |
+
idxs=[[-1] for _ in range(len(context_templates))],
|
105 |
+
**word_repr_args,
|
106 |
+
).mean(0)
|
107 |
+
print("Selected u projection token with last token")
|
108 |
+
else:
|
109 |
+
raise ValueError(f"fact_token={hparams.fact_token} not recognized")
|
110 |
+
|
111 |
+
# Apply inverse second moment adjustment
|
112 |
+
u = cur_repr
|
113 |
+
if hparams.mom2_adjustment:
|
114 |
+
u = get_inv_cov(
|
115 |
+
model,
|
116 |
+
tok,
|
117 |
+
hparams.rewrite_module_tmp.format(layer),
|
118 |
+
hparams.mom2_dataset,
|
119 |
+
hparams.mom2_n_samples,
|
120 |
+
hparams.mom2_dtype,
|
121 |
+
hparams=hparams,
|
122 |
+
) @ u.unsqueeze(1)
|
123 |
+
u = u.squeeze()
|
124 |
+
|
125 |
+
return u / u.norm()
|
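`get_inv_cov` above assumes `layer_stats` has already accumulated the uncentered second moment of MLP keys, roughly C ≈ (1/N) Σ k kᵀ, which is then inverted and applied to the selected key representation. A small illustrative sketch of that estimate and the whitening step; the random `keys` tensor stands in for the Wikipedia activations `layer_stats` actually streams:

```python
import torch

keys = torch.randn(10_000, 1024)        # stand-in for sampled MLP input activations (N, d)
C = keys.T @ keys / keys.shape[0]       # uncentered second moment, the "mom2" statistic

k = keys.mean(0)                        # placeholder for the subject-token key representation
u = torch.linalg.solve(C, k)            # C^{-1} k, equivalent to inverting C and multiplying
u = u / u.norm()                        # compute_u returns the normalized, whitened direction
```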
easyeditor/models/rome/compute_v.py
ADDED
@@ -0,0 +1,278 @@
from typing import Dict, List, Tuple

import numpy as np
import torch
from matplotlib.style import context
from transformers import AutoModelForCausalLM, AutoTokenizer

from ..rome import repr_tools
from ...util import nethook

from .rome_hparams import ROMEHyperParams


def compute_v(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
    layer: int,
    left_vector: torch.Tensor,
    context_templates: List[str],
) -> torch.Tensor:
    """
    Computes the value (right) vector for the rank-1 update.
    Runs a simple optimization procedure.
    """

    print("Computing right vector (v)")

    # Tokenize target into list of int token IDs
    target_ids = tok.encode(request["target_new"], return_tensors="pt", add_special_tokens=False).to('cpu')[0]

    # if target_ids[0] == tok.bos_token_id or target_ids[0] == tok.unk_token_id:
    #     target_ids = target_ids[1:]
    # Compile list of rewriting and KL x/y pairs
    rewriting_prompts, kl_prompts = [
        context.format(request["prompt"]) + tok.decode(target_ids[:-1])
        for context in context_templates
    ], ["{} is a"]
    all_prompts = rewriting_prompts + kl_prompts

    input_tok = tok(
        [prompt.format(request["subject"]) for prompt in all_prompts],
        return_tensors="pt",
        padding=True,
    ).to("cpu")

    # Compute rewriting targets
    rewriting_targets = torch.tensor(-100, device='cpu').repeat(
        len(rewriting_prompts), *input_tok["input_ids"].shape[1:]
    )
    for i in range(len(rewriting_prompts)):
        ex_len = input_tok["attention_mask"][i].sum()
        rewriting_targets[i, ex_len - len(target_ids) : ex_len] = target_ids

    # Compute indices of the tokens where the fact is looked up
    vanilla_input_prompts = [
        context.format(request["prompt"]).format(request['subject'])
        for context in context_templates
    ] + [f"{request['subject']} is a"]
    lookup_idxs = [
        find_fact_lookup_idx(
            prompt, request["subject"], tok, hparams.fact_token, verbose=(i == 0), input_prompt=vanilla_input_prompts[i]
        )
        for i, prompt in enumerate(all_prompts)
    ]

    # Finalize rewrite and loss layers
    loss_layer = max(hparams.v_loss_layer, layer)
    print(f"Rewrite layer is {layer}")
    print(f"Tying optimization objective to {loss_layer}")

    # Set up an optimization over a latent vector that, when output at the
    # rewrite layer, i.e. hypothesized fact lookup location, will induce the
    # target token to be predicted at the final layer.
    if hasattr(model.config, 'n_embd'):
        delta = torch.zeros((model.config.n_embd,), requires_grad=True, device=f"cpu")
    else:
        delta = torch.zeros((model.config.hidden_size,), requires_grad=True, device=f"cpu")
    target_init, kl_distr_init = None, None

    # Inserts new "delta" variable at the appropriate part of the computation
    def edit_output_fn(cur_out, cur_layer):
        nonlocal target_init
        if cur_layer == hparams.mlp_module_tmp.format(layer):
            # Store initial value of the vector of interest
            if target_init is None:
                print("Recording initial value of v*")
                # Initial value is recorded for the clean sentence
                target_init = cur_out[0, lookup_idxs[0]].detach().clone()

            for i, idx in enumerate(lookup_idxs):
                if len(lookup_idxs) != len(cur_out):
                    cur_out[idx, i, :] += delta
                else:
                    cur_out[i, idx, :] += delta

        return cur_out

    # Optimizer
    opt = torch.optim.Adam([delta], lr=hparams.v_lr)
    nethook.set_requires_grad(False, model)

    # Execute optimization
    for it in range(hparams.v_num_grad_steps):
        opt.zero_grad()

        # Forward propagation
        with nethook.TraceDict(
            module=model,
            layers=[
                hparams.layer_module_tmp.format(loss_layer),
                hparams.mlp_module_tmp.format(layer),
            ],
            retain_input=False,
            retain_output=True,
            edit_output=edit_output_fn,
        ) as tr:
            logits = model(**input_tok).logits

            # Compute distribution for KL divergence
            kl_logits = torch.stack(
                [
                    logits[i - len(kl_prompts), idx, :]
                    for i, idx in enumerate(lookup_idxs[-len(kl_prompts) :])
                ],
                dim=0,
            )
            kl_log_probs = torch.nn.functional.log_softmax(kl_logits, dim=1)
            if kl_distr_init is None:
                kl_distr_init = kl_log_probs.detach().clone()

        # Compute loss on rewriting targets
        log_probs = torch.log_softmax(logits, dim=2)

        loss = torch.gather(
            log_probs,
            2,
            torch.where(rewriting_targets != -100, rewriting_targets, 0).unsqueeze(2),
        ).squeeze(2)
        mask = (rewriting_targets != -100).float()

        # Aggregate total losses
        nll_loss_each = -(loss * mask).sum(1) / target_ids.size(0)
        nll_loss = nll_loss_each.mean()
        kl_loss = hparams.kl_factor * torch.nn.functional.kl_div(
            kl_distr_init, kl_log_probs, log_target=True, reduction="batchmean"
        )
        weight_decay = hparams.v_weight_decay * (
            torch.norm(delta) / torch.norm(target_init) ** 2
        )
        # weight_decay = hparams.v_weight_decay * torch.norm(delta) ** 2
        loss = nll_loss + kl_loss + weight_decay
        print(
            f"loss {np.round(loss.item(), 3)} = {np.round(nll_loss.item(), 3)} + {np.round(kl_loss.item(), 3)} + {np.round(weight_decay.item(), 3)} "
            f"avg prob of [{request['target_new']}] "
            f"{torch.exp(-nll_loss_each).mean().item()}"
        )
        if loss < 5e-2:
            break

        if it == hparams.v_num_grad_steps - 1:
            break

        # Backpropagate
        loss.backward()
        opt.step()

        # Project within L2 ball
        max_norm = hparams.clamp_norm_factor * target_init.norm()
        if delta.norm() > max_norm:
            with torch.no_grad():
                delta[...] = delta * max_norm / delta.norm()

    target = target_init + delta.to(target_init.dtype)

    # Retrieve cur_input, the current input to the 2nd MLP layer, and
    # cur_output, the original output of the 2nd MLP layer.
    cur_input, cur_output = get_module_input_output_at_word(
        model,
        tok,
        layer,
        context_template=request["prompt"],
        word=request["subject"],
        module_template=hparams.rewrite_module_tmp,
        fact_token_strategy=hparams.fact_token,
    )

    # Solving the linear system to compute the right vector
    right_vector = (target - cur_output) / torch.dot(cur_input, left_vector)
    print(f"Delta norm: {(target - cur_output).norm().item()}")
    print(
        f"Change in target norm: {target_init.norm().item()} to {target.norm().item()} => {(target.norm() - target_init.norm()).item()}"
    )
    print(f"Division Factor: {torch.dot(cur_input, left_vector).item()}")
    print(f"Right vector norm: {right_vector.norm()}")

    return right_vector


def get_module_input_output_at_word(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    layer: int,
    context_template: str,
    word: str,
    module_template: str,
    fact_token_strategy: str,
) -> Tuple[torch.Tensor]:
    """
    Retrieves detached representations for a word at the input and
    output of a particular layer module.
    """

    word_repr_args = dict(
        model=model,
        tok=tok,
        layer=layer,
        module_template=module_template,
    )
    if "subject_" in fact_token_strategy and fact_token_strategy.index("subject_") == 0:
        subtoken = fact_token_strategy[len("subject_") :]
        l_input, l_output = repr_tools.get_reprs_at_word_tokens(
            track="both",
            subtoken=subtoken,
            context_templates=[context_template],
            words=[word],
            **word_repr_args,
        )
    elif fact_token_strategy == "last":
        l_input, l_output = repr_tools.get_reprs_at_idxs(
            track="both",
            contexts=[context_template.format(word)],
            idxs=[[-1]],
            **word_repr_args,
        )
    else:
        raise ValueError(f"fact_token={fact_token_strategy} not recognized")

    l_input, l_output = l_input[0], l_output[0]
    return l_input.detach(), l_output.detach()


def find_fact_lookup_idx(
    prompt: str,
    subject: str,
    tok: AutoTokenizer,
    fact_token_strategy: str,
    verbose=True,
    input_prompt=None
) -> int:
    """
    Computes hypothesized fact lookup index given a sentence and subject.
    """

    ret = None
    if fact_token_strategy == "last":
        ret = len(tok.encode(input_prompt)) - 1
    elif (
        "subject_" in fact_token_strategy and fact_token_strategy.index("subject_") == 0
    ):
        ret = repr_tools.get_words_idxs_in_templates(
            tok=tok,
            context_templates=[prompt],
            words=[subject],
            subtoken=fact_token_strategy[len("subject_") :],
        )[0][0]
    else:
        raise ValueError(f"fact_token={fact_token_strategy} not recognized")

    sentence = prompt.format(subject)
    if verbose:
        print(
            f"Lookup index found: {ret} | Sentence: {sentence} | Token:",
            tok.decode(tok(sentence)["input_ids"][ret]),
        )

    return ret
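A minimal numeric sketch of the final solve in compute_v, using toy tensors (the names below are stand-ins, not repo code): the right vector is scaled so that, at the lookup token, the edited layer maps the original input to the optimized target activation.

import torch

d = 3
cur_input = torch.randn(d)                       # k: MLP input at the lookup token
cur_output = torch.randn(d)                      # W k: original MLP output there
target = torch.randn(d)                          # v*: optimized target activation
left_vector = cur_input / cur_input.norm() ** 2  # stand-in for the compute_u output

right_vector = (target - cur_output) / torch.dot(cur_input, left_vector)
# After the rank-1 update assembled in rome_main, the new output at this token equals the target:
new_output = cur_output + torch.dot(cur_input, left_vector) * right_vector
assert torch.allclose(new_output, target, atol=1e-5)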
easyeditor/models/rome/layer_stats.py
ADDED
@@ -0,0 +1,198 @@
import os
from pathlib import Path

import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from ...util.globals import *
from ...util.nethook import Trace, set_requires_grad
from ...util.runningstats import CombinedStat, Mean, NormMean, SecondMoment, tally

from .tok_dataset import (
    TokenizedDataset,
    dict_to_,
    flatten_masked_batch,
    length_collation,
)

STAT_TYPES = {
    "mom2": SecondMoment,
    "mean": Mean,
    "norm_mean": NormMean,
}


def main():
    """
    Command-line utility to precompute cached stats.
    """
    import argparse

    parser = argparse.ArgumentParser(description="ROME Statistics Collector")

    def aa(*args, **kwargs):
        parser.add_argument(*args, **kwargs)

    aa("--model_name", default="gpt2-xl", choices=["gpt2-xl", "EleutherAI/gpt-j-6B"])
    aa("--dataset", default="wikipedia", choices=["wikitext", "wikipedia"])
    aa("--layers", default=[17], type=lambda x: list(map(int, x.split(","))))
    aa("--to_collect", default=["mom2"], type=lambda x: x.split(","))
    aa("--sample_size", default=100000, type=lambda x: None if x == "all" else int(x))
    aa("--batch_tokens", default=None, type=lambda x: None if x == "any" else int(x))
    aa("--precision", default="float32", choices=["float64", "float32", "float16"])
    aa("--stats_dir", default=STATS_DIR)
    aa("--download", default=1, type=int, choices=[0, 1])
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForCausalLM.from_pretrained(args.model_name).eval().cuda()
    set_requires_grad(False, model)

    for layer_num in args.layers:
        print(
            f"Computing stats for layer {layer_num} of {args.model_name} "
            f'over {args.sample_size or "all"} samples of {args.dataset}. '
            "Note, the statistics are collected over the inputs to the second MLP layer, "
            "or equivalently the outputs of the first MLP layer."
        )
        proj_layer_name = "c_proj" if "gpt2" in args.model_name else "fc_out"
        layer_name = f"transformer.h.{layer_num}.mlp.{proj_layer_name}"

        layer_stats(
            model,
            tokenizer,
            layer_name,
            args.stats_dir,
            args.dataset,
            args.to_collect,
            sample_size=args.sample_size,
            precision=args.precision,
            batch_tokens=args.batch_tokens,
            download=args.download,
        )


def layer_stats(
    model,
    tokenizer,
    layer_name,
    stats_dir,
    ds_name,
    to_collect,
    model_name=None,
    sample_size=None,
    precision=None,
    batch_tokens=None,
    download=True,
    progress=tqdm,
    force_recompute=False,
    hparams=None
):
    """
    Function to load or compute cached stats.
    """

    def get_ds():
        # Load_From_File
        # from datasets import Dataset
        # raw_ds = Dataset.from_file('XXX/XXX/wikipedia-train.arrow')
        # raw_ds = {'train': raw_ds}
        raw_ds = load_dataset(
            ds_name,
            dict(wikitext="wikitext-103-raw-v1", wikipedia="20200501.en")[ds_name]
        )
        if hasattr(model.config, 'n_positions'):
            maxlen = model.config.n_positions
        elif hasattr(model.config, 'max_sequence_length'):
            maxlen = model.config.max_sequence_length
        elif hasattr(model.config, 'max_position_embeddings'):
            maxlen = model.config.max_position_embeddings
        elif hasattr(model.config, 'seq_length'):
            maxlen = model.config.seq_length
        else:
            raise NotImplementedError

        if hasattr(model.config, 'model_type') and 'mistral' in model.config.model_type:
            if hasattr(model.config, 'sliding_window') and model.config.sliding_window:
                maxlen = model.config.sliding_window or 4096
            else:
                maxlen = 4096

        if batch_tokens is not None and batch_tokens < maxlen:
            maxlen = batch_tokens
        return TokenizedDataset(raw_ds["train"], tokenizer, maxlen=maxlen)

    # Continue with computation of statistics
    batch_size = 100  # Examine this many dataset texts at once
    if hasattr(model.config, 'n_positions'):
        npos = model.config.n_positions
    elif hasattr(model.config, 'max_sequence_length'):
        npos = model.config.max_sequence_length
    elif hasattr(model.config, 'max_position_embeddings'):
        npos = model.config.max_position_embeddings
    elif hasattr(model.config, 'seq_length'):
        npos = model.config.seq_length
    else:
        raise NotImplementedError

    if hasattr(model.config, 'model_type') and 'mistral' in model.config.model_type:
        if hasattr(model.config, 'sliding_window') and model.config.sliding_window:
            npos = model.config.sliding_window or 4096
        else:
            npos = 4096

    if batch_tokens is None:
        batch_tokens = npos * 3  # Sort and divide into batches with this many tokens
    if precision is None:
        precision = "float64"
    dtype = getattr(torch, precision)
    size_suffix = "" if sample_size is None else f"_{sample_size}"
    if batch_tokens < npos:
        size_suffix = "_t{batch_tokens}" + size_suffix
    if model_name is None:
        # model_name = model.config._name_or_path.replace("/", "_")
        model_name = model.config._name_or_path.rsplit("/")[-1]

    stats_dir = Path(stats_dir)
    file_extension = f"{model_name}/{ds_name}_stats/{layer_name}_{precision}_{'-'.join(sorted(to_collect))}{size_suffix}.npz"
    filename = stats_dir / file_extension

    print(f"Computing Cov locally....")

    ds = get_ds() if not filename.exists() else None

    if progress is None:
        progress = lambda x: x

    stat = CombinedStat(**{k: STAT_TYPES[k]() for k in to_collect})
    loader = tally(
        stat,
        ds,
        cache=(filename if not force_recompute else None),
        sample_size=sample_size,
        batch_size=batch_size,
        collate_fn=length_collation(batch_tokens),
        pin_memory=True,
        random_sample=1,
        num_workers=2,
    )
    batch_count = -(-(sample_size or len(ds)) // batch_size)
    with torch.no_grad():
        for batch_group in progress(loader, total=batch_count):
            for batch in batch_group:
                batch = dict_to_(batch, f"cuda:{hparams.device}")
                with Trace(
                    model, layer_name, retain_input=True, retain_output=False, stop=True
                ) as tr:
                    model(**batch)
                feats = flatten_masked_batch(tr.input, batch["attention_mask"])
                # feats = flatten_masked_batch(tr.output, batch["attention_mask"])
                feats = feats.to(dtype=dtype)
                stat.add(feats)
    return stat


if __name__ == "__main__":
    main()
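What the cached "mom2" statistic represents, as a toy computation (assumed shapes and names, not repo code): an uncentered second moment of the layer's input features over sampled text, whose inverse compute_u uses to whiten the key vector.

import torch

feats = torch.randn(1000, 8)               # stand-in for MLP inputs collected over dataset text
mom2 = feats.T @ feats / feats.shape[0]    # uncentered second moment E[k k^T], what SecondMoment accumulates
inv_cov = torch.inverse(mom2)              # the quantity compute_u's get_inv_cov applies to the key vector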
easyeditor/models/rome/repr_tools.py
ADDED
@@ -0,0 +1,174 @@
"""
Contains utilities for extracting token representations and indices
from string templates. Used in computing the left and right vectors for ROME.
"""

from copy import deepcopy
from typing import List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ...util import nethook

def get_reprs_at_word_tokens(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    context_templates: List[str],
    words: List[str],
    layer: int,
    module_template: str,
    subtoken: str,
    track: str = "in",
) -> torch.Tensor:
    """
    Retrieves the last token representation of `word` in `context_template`
    when `word` is substituted into `context_template`. See `get_last_word_idx_in_template`
    for more details.
    """

    idxs = get_words_idxs_in_templates(tok, context_templates, words, subtoken)
    return get_reprs_at_idxs(
        model,
        tok,
        [context_templates[i].format(words[i]) for i in range(len(words))],
        idxs,
        layer,
        module_template,
        track,
    )

def get_words_idxs_in_templates(
    tok: AutoTokenizer, context_templates: str, words: str, subtoken: str
) -> int:
    """
    Given list of template strings, each with *one* format specifier
    (e.g. "{} plays basketball"), and words to be substituted into the
    template, computes the post-tokenization index of their last tokens.
    """

    assert all(
        tmp.count("{}") == 1 for tmp in context_templates
    ), "We currently do not support multiple fill-ins for context"

    prefixes_len, words_len, suffixes_len, inputs_len = [], [], [], []
    for i, context in enumerate(context_templates):
        prefix, suffix = context.split("{}")
        prefix_len = len(tok.encode(prefix))
        prompt_len = len(tok.encode(prefix + words[i]))
        input_len = len(tok.encode(prefix + words[i] + suffix))
        prefixes_len.append(prefix_len)
        words_len.append(prompt_len - prefix_len)
        suffixes_len.append(input_len - prompt_len)
        inputs_len.append(input_len)

    # Compute prefixes and suffixes of the tokenized context
    # fill_idxs = [tmp.index("{}") for tmp in context_templates]
    # prefixes, suffixes = [
    #     tmp[: fill_idxs[i]] for i, tmp in enumerate(context_templates)
    # ], [tmp[fill_idxs[i] + 2 :] for i, tmp in enumerate(context_templates)]
    # words = deepcopy(words)
    #
    # # Pre-process tokens
    # for i, prefix in enumerate(prefixes):
    #     if len(prefix) > 0:
    #         assert prefix[-1] == " "
    #         prefix = prefix[:-1]
    #
    #         prefixes[i] = prefix
    #         words[i] = f" {words[i].strip()}"
    #
    # # Tokenize to determine lengths
    # assert len(prefixes) == len(words) == len(suffixes)
    # n = len(prefixes)
    # batch_tok = tok([*prefixes, *words, *suffixes])
    # if 'input_ids' in batch_tok:
    #     batch_tok = batch_tok['input_ids']
    # prefixes_tok, words_tok, suffixes_tok = [
    #     batch_tok[i : i + n] for i in range(0, n * 3, n)
    # ]
    # prefixes_len, words_len, suffixes_len = [
    #     [len(el) for el in tok_list]
    #     for tok_list in [prefixes_tok, words_tok, suffixes_tok]
    # ]

    # Compute indices of last tokens
    if subtoken == "last" or subtoken == "first_after_last":
        return [
            [
                prefixes_len[i]
                + words_len[i]
                - (1 if subtoken == "last" or suffixes_len[i] == 0 else 0)
            ]
            # If suffix is empty, there is no "first token after the last".
            # So, just return the last token of the word.
            for i in range(len(context_templates))
        ]
    elif subtoken == "first":
        return [[prefixes_len[i] - inputs_len[i]] for i in range(len(context_templates))]
    else:
        raise ValueError(f"Unknown subtoken type: {subtoken}")


def get_reprs_at_idxs(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    contexts: List[str],  # full sentences expressing the fact
    idxs: List[List[int]],  # positions of the filled-in word
    layer: int,
    module_template: str,
    track: str = "in",
) -> torch.Tensor:
    """
    Runs input through model and returns averaged representations of the tokens
    at each index in `idxs`.
    """

    def _batch(n):
        for i in range(0, len(contexts), n):
            yield contexts[i : i + n], idxs[i : i + n]  # chunk the sentences and word positions

    assert track in {"in", "out", "both"}
    both = track == "both"
    tin, tout = (
        (track == "in" or both),
        (track == "out" or both),
    )  # tin and tout are booleans
    module_name = module_template.format(layer)
    to_return = {"in": [], "out": []}

    def _process(cur_repr, batch_idxs, key):
        nonlocal to_return
        cur_repr = cur_repr[0] if type(cur_repr) is tuple else cur_repr
        if cur_repr.shape[0] != len(batch_idxs):
            cur_repr = cur_repr.transpose(0, 1)
        for i, idx_list in enumerate(batch_idxs):
            to_return[key].append(cur_repr[i][idx_list].mean(0))

    for batch_contexts, batch_idxs in _batch(n=128):
        # contexts_tok: [21 19]
        contexts_tok = tok(batch_contexts, padding=True, return_tensors="pt").to(
            next(model.parameters()).device
        )

        with torch.no_grad():
            with nethook.Trace(
                module=model,
                layer=module_name,
                retain_input=tin,
                retain_output=tout,
            ) as tr:
                model(**contexts_tok)

        if tin:
            _process(tr.input, batch_idxs, "in")
        if tout:
            _process(tr.output, batch_idxs, "out")

    to_return = {k: torch.stack(v, 0) for k, v in to_return.items() if len(v) > 0}

    if len(to_return) == 1:
        return to_return["in"] if tin else to_return["out"]
    else:
        return to_return["in"], to_return["out"]
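A toy check of the index arithmetic used by get_words_idxs_in_templates above (the "gpt2" tokenizer and the example template are assumptions for illustration, not repo code):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
prefix, word, suffix = "", "LeBron James", " plays basketball"
prefix_len = len(tok.encode(prefix))
word_len = len(tok.encode(prefix + word)) - prefix_len
# subtoken == "last": index of the subject's final token inside the full prompt
last_subject_idx = prefix_len + word_len - 1
print(last_subject_idx, tok.decode(tok.encode(prefix + word + suffix)[last_subject_idx]))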
easyeditor/models/rome/rome_hparams.py
ADDED
@@ -0,0 +1,55 @@
from dataclasses import dataclass
from typing import List
import yaml

from ...util.hparams import HyperParams


@dataclass
class ROMEHyperParams(HyperParams):
    # Method
    layers: List[int]
    fact_token: str
    v_num_grad_steps: int
    v_lr: float
    v_loss_layer: int
    v_weight_decay: float
    clamp_norm_factor: float
    kl_factor: float
    mom2_adjustment: bool
    context_template_length_params: List[List[int]]

    # Module templates
    rewrite_module_tmp: str
    layer_module_tmp: str
    mlp_module_tmp: str
    attn_module_tmp: str
    ln_f_module: str
    lm_head_module: str

    # Statistics
    mom2_dataset: str
    mom2_n_samples: int
    mom2_dtype: str
    alg_name: str
    device: int
    model_name: str
    stats_dir: str

    max_length: int = 40
    model_parallel: bool = False
    fp16: bool = False

    @classmethod
    def from_hparams(cls, hparams_name_or_path: str):

        if '.yaml' not in hparams_name_or_path:
            hparams_name_or_path = hparams_name_or_path + '.yaml'

        with open(hparams_name_or_path, "r") as stream:
            config = yaml.safe_load(stream)
            config = super().construct_float_from_scientific_notation(config)

        assert (config and config['alg_name'] == 'ROME') or print(f'ROMEHyperParams can not load from {hparams_name_or_path}, '
                                                                  f'alg_name is {config["alg_name"]} ')
        return cls(**config)
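A short loading sketch for the hyperparameter file added in this commit (hparams/ROME/gpt2.yaml from the file list above); this assumes the Space is run from the repository root.

from easyeditor.models.rome.rome_hparams import ROMEHyperParams

hparams = ROMEHyperParams.from_hparams("hparams/ROME/gpt2.yaml")
print(hparams.layers, hparams.v_num_grad_steps, hparams.rewrite_module_tmp)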
easyeditor/models/rome/rome_main.py
ADDED
@@ -0,0 +1,192 @@
from copy import deepcopy
from typing import Dict, List, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ...util import nethook
from ...util.generate import generate_fast

from .compute_u import compute_u
from .compute_v import compute_v
from .rome_hparams import ROMEHyperParams
import gradio as gr

CONTEXT_TEMPLATES_CACHE = None


def apply_rome_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: List[Dict],
    hparams: ROMEHyperParams,
    num_steps: int,
    edit_lr: float,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs
) -> Tuple[AutoModelForCausalLM, List[str]]:
    """
    Returns a model with the desired changes.

    :param copy: If true, will preserve the original model while creating a new one to edit.
        Note that you are responsible for deallocating the new model's memory to avoid leaks.

    :return: (1) the updated model, (2) an original copy of the weights that changed
    """
    if copy:
        model = deepcopy(model)

    weights_copy = {}
    hparams.v_num_grad_steps = num_steps // 2
    hparams.v_lr = edit_lr
    request['subject'] = request['prompt']

    deltas = execute_rome(model, tok, request, hparams)

    with torch.no_grad():
        for w_name, (delta_u, delta_v) in deltas.items():
            upd_matrix = delta_u.unsqueeze(1) @ delta_v.unsqueeze(0)
            w = nethook.get_parameter(model, w_name)
            upd_matrix = upd_matrix_match_shape(upd_matrix, w.shape)

            if return_orig_weights and w_name not in weights_copy:
                weights_copy[w_name] = w.detach().clone()

            w[...] += upd_matrix

    print(f"New weights successfully inserted into {list(deltas.keys())}")

    if not keep_original_weight:
        weights_copy = {}
    gr.Info("Completed editing via ROME!")
    return model


def execute_rome(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
) -> Dict[str, Tuple[torch.Tensor]]:
    """
    Executes the ROME update algorithm for the specified update at the specified layer
    Invariant: model at beginning of function == model at end of function
    """

    # Update target and print info
    request = deepcopy(request)
    if request["target_new"] != " ":
        # Space required for correct tokenization
        request["target_new"] = " " + request["target_new"]

    if '{}' not in request['prompt']:
        assert request['subject'] in request['prompt'] or \
               print(f"Subject:{request['subject']} do not exist in prompt: {request['prompt']}")

        request['prompt'] = request['prompt'].replace(request['subject'], '{}')

    print(
        f"Executing ROME algorithm for the update: "
        f"[{request['prompt'].format(request['subject'])}] -> [{request['target_new']}]"
    )

    # Retrieve weights that user desires to change
    weights = {
        f"{hparams.rewrite_module_tmp.format(layer)}.weight": nethook.get_parameter(
            model, f"{hparams.rewrite_module_tmp.format(layer)}.weight"
        )
        for layer in hparams.layers
    }
    # Save old weights for future restoration
    weights_copy = {k: v.detach().clone() for k, v in weights.items()}

    # Update loop: sequentially intervene at each specified layer
    deltas = {}
    for layer in sorted(hparams.layers):
        # Compute rank-1 update matrix
        left_vector: torch.Tensor = compute_u(
            model,
            tok,
            request,
            hparams,
            layer,
            get_context_templates(model, tok, hparams.context_template_length_params),
        )
        print("Left vector shape:", left_vector.shape)
        right_vector: torch.Tensor = compute_v(
            model,
            tok,
            request,
            hparams,
            layer,
            left_vector,
            get_context_templates(model, tok, hparams.context_template_length_params),
        )
        print("Right vector shape:", right_vector.shape)

        with torch.no_grad():
            # Determine correct transposition of delta matrix
            weight_name = f"{hparams.rewrite_module_tmp.format(layer)}.weight"
            upd_matrix = left_vector.unsqueeze(1) @ right_vector.unsqueeze(0)
            upd_matrix = upd_matrix_match_shape(upd_matrix, weights[weight_name].shape)

            # Update model weights and record desired changes in `delta` variable
            weights[weight_name][...] += upd_matrix
            deltas[weight_name] = (
                left_vector.detach(),
                right_vector.detach(),
            )

    # Restore state of original model
    with torch.no_grad():
        for k, v in weights.items():
            v[...] = weights_copy[k]

    print(f"Deltas successfully computed for {list(weights.keys())}")

    return deltas


def upd_matrix_match_shape(matrix: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """
    GPT-2 and GPT-J have transposed weight representations.
    Returns a matrix that matches the desired shape, else raises a ValueError
    """

    if matrix.shape == shape:
        return matrix
    elif matrix.T.shape == shape:
        return matrix.T
    else:
        raise ValueError(
            "Update matrix computed by ROME does not match original weight shape. "
            "Check for bugs in the code?"
        )


def get_context_templates(model, tok, length_params):
    global CONTEXT_TEMPLATES_CACHE

    if CONTEXT_TEMPLATES_CACHE is None:
        CONTEXT_TEMPLATES_CACHE = ["{}"] + [
            x.replace("{", "").replace("}", "") + ". {}"
            for x in sum(
                (
                    generate_fast(
                        model,
                        tok,
                        ["The", "Therefore", "Because", "I", "You"],
                        n_gen_per_prompt=n_gen // 5,
                        max_out_len=length,
                    )
                    for length, n_gen in length_params
                ),
                [],
            )
        ]

    print(f"Cached context templates {CONTEXT_TEMPLATES_CACHE}")

    return CONTEXT_TEMPLATES_CACHE
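A minimal sketch of the expected call pattern for this entry point, assuming a GPT-2 model and the new hparams file; this is illustrative only (the actual wiring for the demo lives in utils.py, and gr.Info is meaningful only inside a running Gradio app).

from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor.models.rome.rome_main import apply_rome_to_model
from easyeditor.models.rome.rome_hparams import ROMEHyperParams

model = AutoModelForCausalLM.from_pretrained("gpt2")
tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token
hparams = ROMEHyperParams.from_hparams("hparams/ROME/gpt2.yaml")

# In continuous editing the whole prompt is treated as the subject (see request['subject'] above).
request = {"prompt": "The capital of France is", "target_new": "Rome"}
edited_model = apply_rome_to_model(model, tok, request, hparams, num_steps=40, edit_lr=0.5)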
easyeditor/models/rome/tok_dataset.py
ADDED
@@ -0,0 +1,99 @@
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset


class TokenizedDataset(Dataset):
    """
    Converts a dataset of text samples into a dataset of token sequences,
    as converted by a supplied tokenizer. The tokens come along with position
    ids and attention masks, they can be supplied directly to the model.
    """

    def __init__(self, text_dataset, tokenizer=None, maxlen=None, field="text"):
        self.text_dataset = text_dataset
        self.field = field
        self.tokenizer = tokenizer
        self.maxlen = maxlen
        if hasattr(text_dataset, "info"):
            self.info = text_dataset.info

    def __len__(self):
        return len(self.text_dataset)

    def __getitem__(self, i):
        text = self.text_dataset[i]
        if self.field is not None:
            text = text[self.field]
        token_list = self.tokenizer.encode(
            text, truncation=True, max_length=self.maxlen
        )
        position_ids = list(range(len(token_list)))
        attention_mask = [1] * len(token_list)
        return dict(
            input_ids=torch.tensor(token_list),
            position_ids=torch.tensor(position_ids),
            attention_mask=torch.tensor(attention_mask),
        )


def dict_to_(data, device):
    """
    Moves a dictionary of tensors to the specified device.
    """
    for k in data:
        data[k] = data[k].to(device)
    return data


def length_collation(token_size):
    """
    Sorts a batch of sequences and breaks it up into subbatches
    of same-sized sequences, padding as needed. Each batch
    has no more than token_size total tokens (or a single
    sequence, if the sequence happens to be larger).
    """

    def collate_fn(items):
        items = sorted(items, key=lambda x: -len(x["input_ids"]))
        batches = []
        batch = []
        batch_width = 0
        for item in items:
            item_width = len(item["input_ids"])
            if item_width == 0:
                break
            if batch_width * (len(batch) + 1) > token_size:
                batches.append(make_padded_batch(batch))
                batch = []
                batch_width = 0
            if not batch:
                batch_width = item_width
            batch.append(item)
        if len(batch):
            batches.append(make_padded_batch(batch))
        return batches

    return collate_fn


def make_padded_batch(items):
    """
    Pads sequences in a batch, so they are all the same length as the longest.
    """
    max_len = max(len(d["input_ids"]) for d in items)
    if max_len == 0:
        return {k: torch.zeros((0, 0), dtype=torch.long) for k in items[0]}
    return {
        k: pad_sequence([d[k] for d in items if len(d["input_ids"])], batch_first=True)
        for k, v in items[0].items()
    }


def flatten_masked_batch(data, mask):
    """
    Flattens feature data, ignoring items that are masked out of attention.
    """
    flat_data = data.view(-1, data.size(-1))
    attended_tokens = mask.view(-1).nonzero()[:, 0]
    return flat_data[attended_tokens]
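A small illustration of how these batching helpers compose with a DataLoader (the wikitext-2 split and gpt2 tokenizer are assumptions for the example; layer_stats uses the same pattern via tally).

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from easyeditor.models.rome.tok_dataset import TokenizedDataset, length_collation, dict_to_

tok = AutoTokenizer.from_pretrained("gpt2")
raw = load_dataset("wikitext", "wikitext-2-raw-v1")["train"]
ds = TokenizedDataset(raw, tok, maxlen=128)
loader = DataLoader(ds, batch_size=16, collate_fn=length_collation(1024))
for batch_group in loader:
    for batch in batch_group:      # each sub-batch stays under ~1024 total tokens
        batch = dict_to_(batch, "cpu")
        break
    break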
easyeditor/models/wise/.DS_Store
ADDED
Binary file (6.15 kB)
easyeditor/models/wise/WISE.py
ADDED
@@ -0,0 +1,466 @@
import copy
import random

import torch
from torch.nn import functional as F
from .utils import parent_module, brackets_to_periods, EarlyStopMeter, EditingMeanAct
import transformers
import numpy as np
from torch import Tensor
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from .merge import slerp, GTA, linear
import torch.nn as nn
import gc

merge_dict = {
    'slerp': slerp(),
    'ties': GTA('magnitude', 'sum', normalize=True),
    'magnitude_norm': GTA('magnitude', None, normalize=True),
    'magnitude': GTA('magnitude', None, normalize=False),
    'sign': GTA(None, 'sum', normalize=True),
    'dare_ties': GTA('rescaled_random', 'sum'),
    'dare_linear': GTA('random', None),
    'linear': linear()
}

edit_history = []
merge_group_edit_history = []

def euc(query, key, config, act_mask=None, infer=False):
    # Euclidean distance

    act_fn = ACT2FN[config.hidden_act]
    l2_norm = torch.norm(act_fn(key) - act_fn(query), dim=-1)
    if infer and l2_norm.size(1) > 100:
        topk = torch.topk(l2_norm, k=1, largest=True)
        return topk.values.mean()

    if act_mask is not None:
        return torch.sum(l2_norm * act_mask, dim=1) / torch.sum(act_mask, dim=1)
    else:
        return torch.mean(l2_norm, dim=-1)


class WISE(torch.nn.Module):
    def __init__(self, config, model, device):
        super(WISE, self).__init__()
        self.config = config
        self.model = model
        self.config = config
        if hasattr(self.model.config, 'hidden_act'):
            self.config.hidden_act = self.model.config.hidden_act
        elif hasattr(self.model.config, 'activation_function'):
            self.config.hidden_act = self.model.config.activation_function
        # self.tokenizer = model.tokenizer
        layer = config.inner_params[0]
        self.device = device
        self.adapter_layer = None
        self.original_layer = None

        # --- ensure proper formatting (WISE edits weights matrices) ---
        suffixes = [".weight", ".bias"]
        self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer

        for n, p in self.model.named_parameters():
            p.requires_grad = False

        if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
            conv1D = True
        else:
            conv1D = False

        # --- Add WISE to chosen layers ---
        self.edit_module = parent_module(self.model, brackets_to_periods(self.layer))
        self.layer_name = self.layer.rsplit(".", 1)[-1]
        adapter_layer = getattr(self.edit_module, self.layer_name)

        if type(adapter_layer) is not WISEAdapter:
            setattr(self.edit_module, self.layer_name, WISEAdapter(config, adapter_layer, conv1D=conv1D))
            self.original_layer = copy.deepcopy(adapter_layer)
            print(f"New weights successfully inserted into {layer}")

        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    # Forward
    def __call__(self, **kwargs):
        if not self.config.retrieve:
            if hasattr(self.get_adapter_layer(), 'editing') and not self.get_adapter_layer().editing:
                # final merge
                if not self.get_adapter_layer().original_layer.weight.equal(self.get_adapter_layer().new_weight) and self.get_adapter_layer().editing_total_cnt >= self.config.save_freq:
                    self.get_adapter_layer().memory_weight.append(self.get_adapter_layer().new_weight)
                if len(self.get_adapter_layer().memory_weight) > 0 and self.get_adapter_layer().editing_total_cnt >= self.config.save_freq:
                    print('length of memory is ', len(self.get_adapter_layer().memory_weight), '!!!!!!')
                    self.get_adapter_layer().merge_weight()
        return self.model(**kwargs)

    def reset_layer(self):
        layer = getattr(self.edit_module, self.layer_name)
        del layer
        setattr(self.edit_module, self.layer_name, self.get_adapter_layer().original_layer)

    def get_adapter_layer(self):
        adapter_layer = getattr(self.edit_module, self.layer_name)
        assert type(adapter_layer) is WISEAdapter, print('Adapter Layer is not added correctly....')
        return adapter_layer

    # TODO: generation
    def generate(self, *args, **kwargs):
        setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
        return self.model.generate(*args, **kwargs)

    def edit(self, config, tokens, act_mask=None, deact_mask=None):
        # for retrieve ##
        global edit_history
        global merge_group_edit_history
        edit_history.append([{f"{k1}" : v1.to('cpu') for k1, v1 in tokens.items()}, False])
        # for retrieve ##
        last_prompt_token_loc = (tokens["labels"] == -100).sum(dim=-1) - 1

        setattr(eval(f"self.model.{self.layer}"), "training", True)
        setattr(eval(f"self.model.{self.layer}"), "editing", True)
        self.get_adapter_layer().set_parameter_tunable()
        if getattr(eval(f"self.model.{self.layer}"), "editing_total_cnt") % self.config.save_freq == 0:
            self.get_adapter_layer().generate_activation_mask(self.config.mask_ratio)

        # --- train Wise value ---
        loss_meter = EarlyStopMeter()
        for i in range(config.n_iter):

            if i == 0:
                # --- we only need to create an optimizer for the first iteration (but forward pass instantiates the key, so optimzer is passed after first inference) ---
                optimizer = torch.optim.SGD([self.get_adapter_layer().new_weight], config.edit_lr, weight_decay=1e-5)

            ft_loss = self.__cal_ft_loss(tokens, last_prompt_token_loc)

            act_loss = self.__cal_activation_loss(self.get_adapter_layer().original_layer_output, self.get_adapter_layer().new_weight_layer_output,
                                                  config=config, act_mask=act_mask, deact_mask=deact_mask)
            loss = ft_loss + act_loss.to(ft_loss.device)

            if loss_meter.stop():
                self.get_adapter_layer().save_editing_activation()  # add last gradient
                break
            if i == config.n_iter - 1:
                self.get_adapter_layer().save_editing_activation()  # add last gradient

            if self.config.retrieve and self.get_adapter_layer().merge_cnt > 0 and self.config.replay:
                memory_loss = []
                for _ in merge_group_edit_history:
                    idx = 0
                    while True:
                        memo_input, is_used = _[idx]
                        if not is_used:
                            _[idx][1] = True
                            break
                        idx += 1
                        if idx == len(_):  ## re Assign
                            for m in range(len(_)):
                                _[m][1] = False
                            idx = 0

                    memo_input = {f"{k1}" : v1.to(self.config.device) for k1, v1 in memo_input.items()}
                    self.model(**memo_input)

                    memory_act_loss = self.__cal_memory_neg_activation_loss(self.get_adapter_layer().original_layer_output,
                                                                            self.get_adapter_layer().new_weight_layer_output, config=config,
                                                                            act_mask=act_mask, deact_mask=deact_mask)
                    memory_loss.append(memory_act_loss.to(ft_loss.device))
                    del memo_input
                neg_memo_loss = torch.stack(memory_loss).mean()
                loss += neg_memo_loss
                if len(edit_history) > 0:
                    memo_input = random.choice(edit_history)[0]
                    memo_input = {f"{k1}" : v1.to(self.config.device) for k1, v1 in memo_input.items()}
                    self.model(**memo_input)

                    pos_memo_loss = self.__cal_memory_pos_activation_loss(self.get_adapter_layer().original_layer_output,
                                                                          self.get_adapter_layer().new_weight_layer_output, config=config,
                                                                          act_mask=act_mask, deact_mask=deact_mask)
                    del memo_input
                    loss += pos_memo_loss.to(ft_loss.device)
                # for replay Appendix B.3

            optimizer.zero_grad()

            loss.backward()
            self.get_adapter_layer().mask_new_weight_gradient()

            if self.config.retrieve and self.get_adapter_layer().merge_cnt > 0 and self.config.replay:
                print(
                    f"loss {np.round(loss.item(), 3)} = {np.round(ft_loss.item(), 3)} + {np.round(act_loss.item(), 3)} + {np.round(neg_memo_loss.item(), 3)} + {np.round(pos_memo_loss.item(), 3)}"
                )
            else:
                print(
                    f"loss {np.round(loss.item(), 3)} = {np.round(ft_loss.item(), 3)} + {np.round(act_loss.item(), 3)}"
                )

            optimizer.step()
            loss_meter.update(loss.item())

            if type(self.config.norm_constraint) is float:
                self.__norm_constraint(self.config.norm_constraint)

        # --- pull out info we want to log from the Wise layer ---
        setattr(eval(f"self.model.{self.layer}"), "editing", False)
        setattr(eval(f"self.model.{self.layer}"), "training", False)

        editing_total_cnt = getattr(eval(f"self.model.{self.layer}"), "editing_total_cnt") + 1
        setattr(eval(f"self.model.{self.layer}"), "editing_total_cnt", editing_total_cnt)
        #
        if self.config.save_freq is not None and editing_total_cnt % self.config.save_freq == 0:
            self.get_adapter_layer().save_weight()
            print(f'Add New Weight to Memory...')
        if editing_total_cnt % self.config.merge_freq == 0:
            # for retrieve ##
            merge_group_edit_history.append(edit_history)
            edit_history = []
            # for retrieve ##

            self.get_adapter_layer().merge_weight()
            print(f'Merge Weight of (New, Original) Matrix... with {self.config.merge_alg}')

    def __norm_constraint(self, norm_constraint):
        new_weight = self.get_adapter_layer().new_weight
        original_weight = self.get_adapter_layer().weight
        with torch.no_grad():
            new_weight[...] = torch.clamp(
                new_weight, min=original_weight - norm_constraint, max=original_weight + norm_constraint
            )

    def __cal_ft_loss(self, tokens, last_prompt_token_loc):
        k = 1
        bs = tokens["input_ids"].shape[0] - k
        logits = self.model(**tokens).logits
        shift_logits = logits[:-k, :-1, :].contiguous()
        shift_labels = tokens['labels'][:-k, 1:].contiguous()

        label_mask = torch.zeros_like(shift_labels, dtype=torch.bool)

        for i, col_index in enumerate(last_prompt_token_loc[:-k]):
            label_mask[i, col_index-1:] = True

        shift_labels[~label_mask] = -100

        log_probs = -nn.functional.log_softmax(shift_logits, dim=-1)

        if shift_labels.dim() == log_probs.dim() - 1:
            shift_labels = shift_labels.unsqueeze(-1)

        padding_mask = shift_labels.eq(-100)

        # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
        # will ignore them in any case.
        shift_labels = torch.clamp(shift_labels, min=0)

        nll_loss = log_probs.gather(dim=-1, index=shift_labels)
        nll_loss.masked_fill_(padding_mask, 0.0)

        num_active_elements = padding_mask.numel() - padding_mask.long().sum()
        nll_loss = nll_loss.sum() / num_active_elements

        return nll_loss
        # loss_fct = CrossEntropyLoss(reduction='none')
        # loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        # loss = loss.view(bs, -1)

        # label_mask = torch.zeros_like(loss, dtype=torch.bool)

        # for i, col_index in enumerate(last_prompt_token_loc[:-k]):
        #     label_mask[i, col_index - 1:] = True

        # ft_loss = ((loss * label_mask).sum(1) / label_mask.sum(1)).mean()
        # return ft_loss

    def __cal_activation_loss(self, original_layer_output, new_weight_layer_output, config=None, act_mask=None,
                              deact_mask=None):
        k = 1
        if act_mask is not None:
            in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config,
                                act_mask=act_mask)
            out_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config,
                                 act_mask=deact_mask)
        else:
            in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config)
            out_scope_dist = euc(original_layer_output[-k:, ...], new_weight_layer_output[-k:, ...], config)

        loss = out_scope_dist.view(-1, 1) - in_scope_dist + config.gamma
        loss2 = out_scope_dist - config.alpha
        loss3 = config.beta - in_scope_dist
        loss3 = torch.mean(loss3[loss3 > 0]) if min(loss3[loss3 > 0].size()) > 0 else torch.tensor(0.).to(original_layer_output.device)
        loss2 = torch.mean(loss2[loss2 > 0]) if min(loss2[loss2 > 0].size()) > 0 else torch.tensor(0.).to(original_layer_output.device)
        loss = torch.mean(loss[loss > 0]) if min(loss[loss > 0].size()) > 0 else torch.tensor(0.).to(original_layer_output.device)
        return loss + loss2 + loss3

    def __cal_memory_pos_activation_loss(self, original_layer_output, new_weight_layer_output, config=None, act_mask=None,
                                         deact_mask=None):
        k = 1
        in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config)
        loss4 = 20 - in_scope_dist

        return torch.mean(loss4[loss4 > 0]) if min(loss4[loss4 > 0].size()) > 0 else torch.tensor(0.)

    def __cal_memory_neg_activation_loss(self, original_layer_output, new_weight_layer_output, config=None, act_mask=None,
                                         deact_mask=None):
        k = 1
        in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config)
        loss4 = in_scope_dist - 5

        return torch.mean(loss4[loss4 > 0]) if min(loss4[loss4 > 0].size()) > 0 else torch.tensor(0.)

class WISEAdapter(torch.nn.Module):
    def __init__(self, config, layer, conv1D):
        super(WISEAdapter, self).__init__()

        self.layer = layer
        self.weight = self.layer.weight
        self.device = layer.weight.device
        self.config = config
        self.new_weight = copy.deepcopy(self.weight)
        self.original_layer = copy.deepcopy(self.layer)
        self.memory_weight = []
        self.memory_mean_act = []
        self.merge_cnt = 0  # only for retrieve
        assert not self.weight.requires_grad, print('Original Layer can not be tunable....')

        self.used_mask = None

        self.training = False
        self.editing = False
        self.conv1D = conv1D

        self.editing_mean_act = EditingMeanAct()
        self.editing_total_cnt = 0

    def set_parameter_tunable(self):
        self.new_weight.requires_grad = True

    def save_weight(self):
        self.memory_weight.append(copy.deepcopy(self.new_weight))
        self.new_weight = copy.deepcopy(self.original_layer.weight)
        if self.config.retrieve:
            self.memory_mean_act.append(copy.deepcopy(self.editing_mean_act))
            self.editing_mean_act = EditingMeanAct()

    def merge_weight(self):
        if self.config.save_freq is not None:  # for ties dare dare_ties
            if not self.config.retrieve:
                merge_alg = merge_dict[self.config.merge_alg]
                if self.original_layer.weight.equal(self.layer.weight):
                    cur_new_weight = merge_alg.execute([self.config.weights / len(self.memory_weight) for _ in range(len(self.memory_weight))], self.original_layer.weight, self.memory_weight, densities=self.config.densities)
                else:
                    cur_new_weight = merge_alg.execute([0.4 / len(self.memory_weight) for _ in range(len(self.memory_weight))] + [0.6], self.original_layer.weight, self.memory_weight + [self.layer.weight], densities=self.config.densities)
                self.layer.weight = torch.nn.Parameter(cur_new_weight.to(self.layer.weight.device), requires_grad=False)
                self.new_weight = copy.deepcopy(self.original_layer.weight)
                del self.memory_weight
                self.memory_weight = []
            else:
                merge_alg = merge_dict[self.config.merge_alg]
                merge_num = self.config.merge_freq // self.config.save_freq
                assert len(self.memory_weight) >= merge_num
                new_merge_weight = merge_alg.execute([self.config.weights / merge_num for _ in range(merge_num)], self.original_layer.weight, self.memory_weight[-merge_num:], densities=self.config.densities)
                min_a = 1e9
                for _ in range(merge_num):
                    self.memory_weight.pop()
                    edit_act = self.memory_mean_act.pop()
                    min_a = min(min_a, edit_act.min_act())
                self.new_weight = copy.deepcopy(self.original_layer.weight)
                self.memory_weight.append(new_merge_weight)
                self.memory_mean_act.append(EditingMeanAct(min_a=min_a))
                print(len(self.memory_weight))
                assert len(self.memory_mean_act) == len(self.memory_weight)
                self.merge_cnt += 1
        else:
            merge_alg = merge_dict[self.config.merge_alg]
            cur_new_weight = merge_alg.execute(0.5, self.layer.weight, [self.new_weight],
                                               densities=self.config.densities)
            self.layer.weight = torch.nn.Parameter(cur_new_weight.to(self.layer.weight.device), requires_grad=False)
            self.new_weight = copy.deepcopy(self.original_layer.weight)

    def save_editing_activation(self):
        in_scope_dist = euc(self.original_layer_output[:-1, ...], self.new_weight_layer_output[:-1, ...], self.config)
        self.editing_mean_act.update(in_scope_dist.mean().item())

    def generate_activation_mask(self, mask_ratio):
        p_grad = self.new_weight.reshape(-1)
        p_mask = np.random.choice([1, 0], size=p_grad.size()[0], p=[mask_ratio, 1 - mask_ratio])
        p_mask = torch.from_numpy(p_mask).to(p_grad.device)
        self.weight_mask = p_mask

    def generate_non_overlapping_mask(self, mask_ratio):
        p_grad = self.new_weight.reshape(-1)
        mask_size = int(mask_ratio * p_grad.size()[0])
        if self.used_mask is None:
            self.used_mask = np.zeros(p_grad.size()[0], dtype=bool)
        available_indices = np.where(~self.used_mask)[0]  # indices of elements not yet masked
        if len(available_indices) < mask_size:
            raise ValueError("Not enough unused elements to generate a new mask.")
        chosen_indices = np.random.choice(available_indices, size=mask_size, replace=False)
        mask_array = np.zeros(p_grad.size()[0], dtype=int)
        mask_array[chosen_indices] = 1
        self.used_mask[chosen_indices] = True  # mark these positions as used
        self.weight_mask = torch.from_numpy(mask_array).to(p_grad.device)

    def new_weight_forward(self, input: Tensor, weight) -> Tensor:
        if self.conv1D:
            size_out = input.size()[:-1] + (weight.size(1),)
            input = torch.addmm(self.original_layer.bias, input.view(-1, input.size(-1)), weight)
            input = input.view(size_out)
            return input
        else:
            return F.linear(input, weight)

    def mask_new_weight_gradient(self):
        assert self.new_weight.grad is not None, print('Gradient Collection for New Weight error, gradient not found')
        # Add gradient mask after the loss updates
        p_size = self.new_weight.grad.size()
        p_grad = self.new_weight.grad.reshape(-1)

        # mask = torch.from_numpy(np.random.choice([0, 1], size=p_grad.size()[0], p=[.1, .9])).cuda()
        p_grad = p_grad * self.weight_mask
        self.new_weight.grad = p_grad.view(p_size).to(self.new_weight.grad.dtype)

    def forward(self, *args):
        if self.editing:
|
429 |
+
layer_out = self.new_weight_forward(*args, self.new_weight)
|
430 |
+
self.new_weight_layer_output = layer_out
|
431 |
+
self.original_layer_output = self.original_layer(*args)
|
432 |
+
else:
|
433 |
+
if not self.config.retrieve:
|
434 |
+
original_layer_output = self.original_layer(*args)
|
435 |
+
layer_output = self.layer(*args)
|
436 |
+
new_weight_layer_output = self.new_weight_forward(*args, self.new_weight)
|
437 |
+
dist2 = euc(original_layer_output, new_weight_layer_output, self.config, infer=True)
|
438 |
+
dist1 = euc(original_layer_output, layer_output, self.config, infer=True)
|
439 |
+
threshold = self.editing_mean_act.min_act() * self.config.act_ratio
|
440 |
+
|
441 |
+
if dist1.item() < threshold and dist2.item() < threshold:
|
442 |
+
layer_out = original_layer_output
|
443 |
+
elif dist1.item() > dist2.item():
|
444 |
+
layer_out = layer_output
|
445 |
+
else:
|
446 |
+
layer_out = new_weight_layer_output
|
447 |
+
else:
|
448 |
+
original_layer_output = self.original_layer(*args)
|
449 |
+
new_weight_layer_output = self.new_weight_forward(*args, self.new_weight)
|
450 |
+
dist1 = euc(original_layer_output, new_weight_layer_output, self.config, infer=True)
|
451 |
+
threshold = self.editing_mean_act.min_act() * self.config.act_ratio
|
452 |
+
min_dist = dist1
|
453 |
+
if min_dist.item() < threshold:
|
454 |
+
layer_out = original_layer_output
|
455 |
+
else:
|
456 |
+
layer_out = new_weight_layer_output
|
457 |
+
|
458 |
+
for i in range(len(self.memory_weight)):
|
459 |
+
memory_retrieve_weight = self.memory_weight[i]
|
460 |
+
memory_weight_layer_output = self.new_weight_forward(*args, memory_retrieve_weight)
|
461 |
+
dist = euc(original_layer_output, memory_weight_layer_output, self.config, infer=True)
|
462 |
+
if dist > min_dist and dist > self.memory_mean_act[i].min_act() * self.config.act_ratio:
|
463 |
+
layer_out = memory_weight_layer_output
|
464 |
+
min_dist = dist
|
465 |
+
print(dist, self.memory_mean_act[i].min_act() * self.config.act_ratio)
|
466 |
+
return layer_out
|
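The routing rule at the end of forward() can be read as a simple threshold test: an activation distance below min_act() * act_ratio is treated as out-of-scope and served by the original layer, otherwise the edited side memory answers. A small numeric sketch of that rule, with purely assumed values rather than numbers from this Space:

# Assumed numbers, purely illustrative of the routing test in WISEAdapter.forward.
min_edit_act = 30.0                    # smallest activation distance recorded while editing
act_ratio = 0.7                        # from the WISE config
threshold = min_edit_act * act_ratio   # 21.0

dist_new_weight = 35.0                 # distance between original-layer and side-memory outputs
# dist < threshold  -> query treated as out-of-scope, original layer output is returned
# dist >= threshold -> query treated as in-scope, side-memory output is returned
use_side_memory = dist_new_weight >= threshold
print(threshold, use_side_memory)      # 21.0 True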
easyeditor/models/wise/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .wise_main import apply_wise_to_model
from .wise_hparams import WISEHyperParams
easyeditor/models/wise/merge/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .slerp import slerp
from .gta import GTA
from .linear import linear
easyeditor/models/wise/merge/gta.py
ADDED
@@ -0,0 +1,113 @@
from typing import Dict, Union, Tuple, List, Any, Literal, Optional
import torch
import numpy as np

from .utils import rescaled_random, magnitude, random_wo_rescaled


class GTA:
    def __init__(self, sparsify_method=None, consensus_method=None, normalize=False):
        self.sparsify_method = sparsify_method
        self.consensus_method = consensus_method

        self.normalize = normalize

    def execute(
        self,
        weights,
        base,
        tensors,
        densities,
        **_kwargs,
    ) -> torch.Tensor:
        # collect task vectors
        densities = [densities for _ in range(len(tensors))]
        # weights = [weights / len(tensors) for _ in range(len(tensors))]
        assert len(densities) == len(weights) == len(tensors)
        deltas, base = get_task_vectors(base, tensors)
        if not deltas:
            return base

        # sparsify
        if self.sparsify_method:
            if self.sparsify_method == 'magnitude':
                sparsify = magnitude
            elif self.sparsify_method == 'rescaled_random':
                sparsify = rescaled_random
            elif self.sparsify_method == 'random':
                sparsify = random_wo_rescaled
            else:
                raise NotImplementedError
            for i, delta in enumerate(deltas):
                deltas[i] = sparsify(
                    delta,
                    density=densities[i]
                )

        deltas = torch.stack(deltas, dim=0)
        weights = torch.tensor(
            [_ for _ in weights], dtype=deltas.dtype, device=deltas.device
        )
        while len(deltas.shape) > len(weights.shape):
            weights.unsqueeze_(-1)

        weighted_deltas = deltas * weights

        # get sign consensus and mix deltas
        if self.consensus_method:
            mask_dtype = base.dtype
            mask = get_mask(
                weighted_deltas,
                method=self.consensus_method,
                mask_dtype=mask_dtype,
            )
            mixed_delta = (weighted_deltas * mask).sum(dim=0)
            divisor = (weights * mask).sum(dim=0)
            divisor[divisor == 0] = 1
        else:
            mixed_delta = weighted_deltas.sum(dim=0)
            divisor = weights.sum(dim=0)
            divisor[divisor.abs() < 1e-8] = 1

        if self.normalize:
            mixed_delta /= divisor

        return (base + mixed_delta).to(base.dtype)

def get_task_vectors(
    base: Union[np.ndarray, torch.Tensor],
    tensors: Union[List[np.ndarray], List[torch.Tensor]],
) -> Tuple[List[Dict[str, Any]], torch.Tensor]:

    res = []
    for x in tensors:
        delta = x - base
        del x
        res.append(delta)
    return res, base

def get_mask(
    delta: torch.Tensor,
    method: Literal["sum", "count"] = "sum",
    mask_dtype: Optional[torch.dtype] = None,
):
    """Returns a mask determining which delta vectors should be merged
    into the final model.

    For the methodology described in the TIES paper use 'sum'. For a
    simpler naive count of signs, use 'count'."""
    if mask_dtype is None:
        mask_dtype = delta.dtype

    sign = delta.sign().to(mask_dtype)

    if method == "sum":
        sign_weight = delta.sum(dim=0)
        majority_sign = (sign_weight >= 0).to(mask_dtype) * 2 - 1
        del sign_weight
    elif method == "count":
        majority_sign = (sign.sum(dim=0) >= 0).to(mask_dtype) * 2 - 1
    else:
        raise RuntimeError(f'Unimplemented mask method "{method}"')

    return sign == majority_sign
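For orientation, a minimal usage sketch of the GTA merger above on toy tensors; the shapes, weights, and density below are illustrative assumptions, not values taken from this Space (the WISE adapter itself calls merge_alg.execute with its stored memory weights):

import torch
from easyeditor.models.wise.merge import GTA

# Hypothetical example: merge two edited copies of a weight back into the base weight.
base = torch.randn(8, 8)                                   # pretrained weight (illustrative shape)
edited = [base + 0.1 * torch.randn(8, 8) for _ in range(2)]

# 'ties'-style behaviour: magnitude sparsification plus sign consensus.
merger = GTA(sparsify_method='magnitude', consensus_method='sum', normalize=True)
merged = merger.execute(weights=[0.5, 0.5], base=base, tensors=edited, densities=0.53)
print(merged.shape)  # torch.Size([8, 8])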
easyeditor/models/wise/merge/linear.py
ADDED
@@ -0,0 +1,24 @@
import numpy as np
import torch
from typing import Union, List

class linear:
    def __init__(self):
        pass
    def execute(
        self,
        t: Union[float, List[float]],
        v0: Union[List[torch.Tensor], torch.Tensor],
        v1: Union[List[torch.Tensor], torch.Tensor],
        DOT_THRESHOLD: float = 0.9995,
        eps: float = 1e-8,
        densities = None,
    ):
        if type(v0) is list:
            v0 = v0[0]
        if type(t) is list:
            t = t[0]
        if type(v1) is list:
            v1 = v1[0]

        return t * v1 + (1.0 - t) * v0
easyeditor/models/wise/merge/slerp.py
ADDED
@@ -0,0 +1,90 @@
import numpy as np
import torch
from typing import Union, List

def lerp(
    t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    return (1 - t) * v0 + t * v1

def maybe_torch(v: np.ndarray, is_torch: bool):
    if is_torch:
        return torch.from_numpy(v)
    return v


def normalize(v: np.ndarray, eps: float):
    norm_v = np.linalg.norm(v)
    if norm_v > eps:
        v = v / norm_v
    return v

class slerp:
    def __init__(self):
        pass
    def execute(
        self,
        t: Union[float, List[float]],
        v0: Union[List[torch.Tensor], torch.Tensor],
        v1: Union[List[torch.Tensor], torch.Tensor],
        DOT_THRESHOLD: float = 0.9995,
        eps: float = 1e-8,
        densities = None,
    ):
        if type(v0) is list:
            v0 = v0[0]
        if type(v1) is list:
            v1 = v1[0]
        if type(t) is list:
            t = t[0]
        """
        Spherical linear interpolation

        From: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
        Args:
            t (float/np.ndarray): Float value between 0.0 and 1.0
            v0 (np.ndarray): Starting vector
            v1 (np.ndarray): Final vector
            DOT_THRESHOLD (float): Threshold for considering the two vectors as
                                   colinear. Not recommended to alter this.
        Returns:
            v2 (np.ndarray): Interpolation vector between v0 and v1
        """
        is_torch = False
        if not isinstance(v0, np.ndarray):
            is_torch = True
            v0 = v0.detach().cpu().float().numpy()
        if not isinstance(v1, np.ndarray):
            is_torch = True
            v1 = v1.detach().cpu().float().numpy()

        # Copy the vectors to reuse them later
        v0_copy = np.copy(v0)
        v1_copy = np.copy(v1)

        # Normalize the vectors to get the directions and angles
        v0 = normalize(v0, eps)
        v1 = normalize(v1, eps)

        # Dot product with the normalized vectors (can't use np.dot in W)
        dot = np.sum(v0 * v1)

        # If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp
        if np.abs(dot) > DOT_THRESHOLD:
            res = lerp(t, v0_copy, v1_copy)
            return maybe_torch(res, is_torch)

        # Calculate initial angle between v0 and v1
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)

        # Angle at timestep t
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)

        # Finish the slerp algorithm
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        res = s0 * v0_copy + s1 * v1_copy

        return maybe_torch(res, is_torch)
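A small usage sketch of the slerp class on assumed toy tensors (not taken from this repository), showing the interpolation entry point the merge algorithms rely on:

import torch
from easyeditor.models.wise.merge import slerp

# Hypothetical example: spherically interpolate between two weight tensors.
v0 = torch.randn(4, 4)
v1 = torch.randn(4, 4)
interp = slerp().execute(t=0.3, v0=v0, v1=v1)   # 30% of the way from v0 towards v1
print(interp.shape)  # torch.Size([4, 4])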
easyeditor/models/wise/merge/utils.py
ADDED
@@ -0,0 +1,45 @@
import torch

def magnitude(tensor: torch.Tensor, density: float) -> torch.Tensor:
    """Masks out the smallest values, retaining a proportion of `density`."""
    if density >= 1:
        return tensor

    k = int(density * tensor.view(-1).shape[0])

    assert k > 0, "not gonna zero out the whole tensor buddy"
    mask = torch.zeros_like(tensor)
    w = tensor.abs().view(-1)
    if w.device.type == "cpu":
        w = w.float()
    topk = torch.topk(w, k=k, largest=True)
    mask.view(-1)[topk.indices] = 1

    return tensor * mask


def bernoulli(
    tensor: torch.Tensor, density: float, rescale: bool = True
) -> torch.Tensor:
    if density >= 1:
        return tensor

    if (tensor.device.type != "cpu") or tensor.dtype == torch.bfloat16:
        work_dtype = tensor.dtype
    else:
        # torch.bernoulli not implemented for float16 on CPU, upcast to float32
        work_dtype = torch.float32

    mask = torch.bernoulli(
        torch.full_like(input=tensor, fill_value=density, dtype=work_dtype)
    )
    res = tensor.to(work_dtype) * mask
    if rescale:
        res /= density
    return res.to(tensor.dtype)

def rescaled_random(tensor: torch.Tensor, density: float):
    return bernoulli(tensor, density, rescale=True)

def random_wo_rescaled(tensor: torch.Tensor, density: float):
    return bernoulli(tensor, density, rescale=False)
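A brief sketch of the two sparsifiers on an assumed toy tensor, showing how density controls the fraction of retained entries:

import torch
from easyeditor.models.wise.merge.utils import magnitude, rescaled_random

delta = torch.randn(6, 6)
top_half = magnitude(delta, density=0.5)         # keep the 50% largest-magnitude entries
rand_half = rescaled_random(delta, density=0.5)  # keep a random 50%, rescaled by 1/density
print((top_half != 0).float().mean(), (rand_half != 0).float().mean())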
easyeditor/models/wise/utils.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
1 |
+
import transformers
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import struct
|
5 |
+
import random
|
6 |
+
|
7 |
+
CONTEXT_TEMPLATES_CACHE = None
|
8 |
+
|
9 |
+
def find_sublist_start_index(list1, list2):
|
10 |
+
for i in range(len(list1) - len(list2)+1):
|
11 |
+
if all(a == b for a, b in zip(list1[i:i+len(list2)], list2)):
|
12 |
+
return i
|
13 |
+
return None
|
14 |
+
|
15 |
+
def get_inner_params(named_parameters, inner_names):
|
16 |
+
param_dict = dict(named_parameters)
|
17 |
+
return [(n, param_dict[n]) for n in inner_names]
|
18 |
+
|
19 |
+
def param_subset(named_parameters, inner_names):
|
20 |
+
param_dict = dict(named_parameters)
|
21 |
+
return [param_dict[n] for n in inner_names]
|
22 |
+
|
23 |
+
def print_trainable_parameters(model, new_weight, mask_ratio):
|
24 |
+
original_parameters = 0
|
25 |
+
new_weight_param = 0
|
26 |
+
for _, param in new_weight.named_parameters():
|
27 |
+
new_weight_param += param.numel()
|
28 |
+
for _, param in model.named_parameters():
|
29 |
+
original_parameters += param.numel()
|
30 |
+
print(f"Original Model params: {original_parameters} || New Weight params: {new_weight_param} || trainable%: {100 * new_weight_param * (1-mask_ratio) / original_parameters}")
|
31 |
+
|
32 |
+
|
33 |
+
def parent_module(model, pname):
|
34 |
+
components = pname.split('.')
|
35 |
+
parent = model
|
36 |
+
|
37 |
+
for component in components[:-1]:
|
38 |
+
if hasattr(parent, component):
|
39 |
+
parent = getattr(parent, component)
|
40 |
+
elif component.isdigit():
|
41 |
+
parent = parent[int(component)]
|
42 |
+
else:
|
43 |
+
raise RuntimeError(f"Couldn't find child module {component}")
|
44 |
+
|
45 |
+
if not hasattr(parent, components[-1]):
|
46 |
+
raise RuntimeError(f"Couldn't find child module {components[-1]}")
|
47 |
+
|
48 |
+
return parent
|
49 |
+
|
50 |
+
def uuid(digits=4):
|
51 |
+
if not hasattr(uuid, "uuid_value"):
|
52 |
+
uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10**digits)
|
53 |
+
|
54 |
+
return uuid.uuid_value
|
55 |
+
|
56 |
+
def ckpt_dir():
|
57 |
+
"""returns the directory in which to store model checkpoints"""
|
58 |
+
path = "./ckpts/"
|
59 |
+
if not os.path.exists(path):
|
60 |
+
os.makedirs(path)
|
61 |
+
return path
|
62 |
+
|
63 |
+
def brackets_to_periods(name):
|
64 |
+
return name.replace("[", ".").replace("]", "")
|
65 |
+
|
66 |
+
def get_params(model):
|
67 |
+
return model.state_dict()
|
68 |
+
|
69 |
+
def get_shape(p, model):
|
70 |
+
# We need to flip the shapes since OpenAI gpt2 uses convs instead of linear
|
71 |
+
return p.shape if isinstance(model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0])
|
72 |
+
|
73 |
+
def get_logits(x):
|
74 |
+
return x.logits if hasattr(x, "logits") else x
|
75 |
+
|
76 |
+
|
77 |
+
LOC_PROMPTS = ['nq question: who played mr grainger in are you being served Arthur Brough',
|
78 |
+
"nq question: who sings the song let's hear it for the boy Deniece Williams",
|
79 |
+
"nq question: who wrote all my ex's live in texas Sanger D. Shafer",
|
80 |
+
"nq question: when is the america's got talent finale 2018 September 19, 2018",
|
81 |
+
"nq question: what is the fifth biggest state in the united states New Mexico",
|
82 |
+
"nq question: who plays john black on days of our lives Drake Hogestyn (/ˈhʌdʒstən/; born Donald Drake Hogestyn",
|
83 |
+
"nq question: what is the name of the new star wars movie The Last Jedi",
|
84 |
+
"nq question: what is the main principle of path-goal theory a leader's behavior is contingent to the satisfaction, motivation and performance of his or her subordinates",
|
85 |
+
"nq question: who plays luna's dad in harry potter Ifans",
|
86 |
+
"nq question: who has the most grammy nominations as an artist Quincy Jones",
|
87 |
+
"nq question: what is the control unit function in the cpu tells the computer's memory, arithmetic/logic unit and input and output devices how to respond to the instructions that have been sent to the processor",
|
88 |
+
"nq question: who was the first indian prime minister to visit palestine Narendra Modi",
|
89 |
+
"nq question: where did the plane carrying the marshall football team crash into a hill just short of the Tri-State Airport",
|
90 |
+
"nq question: what movie is the line lighten up francis from Stripes",
|
91 |
+
"nq question: set of rules for solving a mathematical or computational problem in finite number of steps an algorithm",
|
92 |
+
"nq question: who changed indian capital from calcutta to delhi George V",
|
93 |
+
"nq question: who did bette midler play in the rose Mary Rose Foster (The Rose)",
|
94 |
+
"nq question: how much did it cost to make the new star wars movie $200–217 million"
|
95 |
+
]
|
96 |
+
|
97 |
+
def tokenize(batch, tokenizer, device, context_templates=None, hparams=None):
|
98 |
+
prompt, label = batch["prompt"], batch["target_new"]
|
99 |
+
batch['loc_prompt'] = random.choice(LOC_PROMPTS)
|
100 |
+
if not isinstance(prompt, list):
|
101 |
+
prompt=[prompt]
|
102 |
+
if not isinstance(label, list):
|
103 |
+
label=[label]
|
104 |
+
mask_token = -100 # ignore_index of CrossEntropyLoss
|
105 |
+
|
106 |
+
# input
|
107 |
+
full_prompt = [f"{templ.format(p + ' ' + l)}" for p, l in zip(prompt, label) for templ in context_templates]
|
108 |
+
full_prompt += [batch['loc_prompt']] # add for subject activation
|
109 |
+
|
110 |
+
prompt_ids = tokenizer([f"{templ.format(p)}" for p in prompt for templ in context_templates], return_tensors="pt", padding=True, truncation=True)["input_ids"]
|
111 |
+
|
112 |
+
num_prompt_toks = [len(i) for i in prompt_ids]
|
113 |
+
tokens = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
|
114 |
+
tokens["labels"] = tokens["input_ids"].clone()
|
115 |
+
if hparams.objective_optimization == 'only_label':
|
116 |
+
for i in range(len(num_prompt_toks)):
|
117 |
+
tokens["labels"][i][:num_prompt_toks[i]] = mask_token
|
118 |
+
|
119 |
+
tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token
|
120 |
+
if batch['loc_prompt'] in batch['prompt']: ## subject: Factual Editing
|
121 |
+
subject_token = tokenizer.encode(' ' + batch['loc_prompt'], add_special_tokens=False)
|
122 |
+
subject_token1 = tokenizer.encode(batch['loc_prompt'], add_special_tokens=False)
|
123 |
+
subject_length = len(subject_token)
|
124 |
+
act_mask = torch.zeros_like(tokens['input_ids'][:-1])
|
125 |
+
deact_mask = torch.zeros_like(tokens['input_ids'][:-1])
|
126 |
+
for i, token in enumerate(tokens['input_ids'][:-1]):
|
127 |
+
start_idx = find_sublist_start_index(token.detach().cpu().numpy().tolist(), subject_token)
|
128 |
+
if start_idx is None:
|
129 |
+
start_idx = find_sublist_start_index(token.detach().cpu().numpy().tolist(), subject_token1)
|
130 |
+
subject_length = len(subject_token1)
|
131 |
+
act_mask[i][start_idx: start_idx + subject_length] = 1
|
132 |
+
deact_mask[i][:start_idx] = 1
|
133 |
+
deact_mask[i][start_idx + subject_length:] = 1
|
134 |
+
|
135 |
+
act_mask = act_mask.to(device)
|
136 |
+
deact_mask = deact_mask.to(device)
|
137 |
+
else: # General Editing
|
138 |
+
act_mask = None
|
139 |
+
deact_mask = None
|
140 |
+
|
141 |
+
tokens = {f"{k1}" : v1.to(device) for k1, v1 in tokens.items()}
|
142 |
+
return tokens, act_mask, deact_mask
|
143 |
+
|
144 |
+
class EarlyStopMeter:
|
145 |
+
"""Computes and stores the average and current value"""
|
146 |
+
|
147 |
+
def __init__(self):
|
148 |
+
self.reset()
|
149 |
+
|
150 |
+
def reset(self):
|
151 |
+
self.avg = 0
|
152 |
+
self.pre = 0
|
153 |
+
self.val = 1e9
|
154 |
+
self.sum = 0
|
155 |
+
self.count = 0
|
156 |
+
|
157 |
+
def update(self, val):
|
158 |
+
self.pre = self.val
|
159 |
+
self.val = val
|
160 |
+
self.sum += val
|
161 |
+
self.count += 1
|
162 |
+
self.avg = self.sum / self.count
|
163 |
+
|
164 |
+
def stop(self, ):
|
165 |
+
return abs(self.val - self.pre) <= 1e-4 and self.val <= 0.02
|
166 |
+
|
167 |
+
class EditingMeanAct:
|
168 |
+
"""Computes and stores the average and current value"""
|
169 |
+
|
170 |
+
def __init__(self, min_a=1e9):
|
171 |
+
self.reset(min_a=min_a)
|
172 |
+
|
173 |
+
def reset(self, min_a=1e9):
|
174 |
+
self.avg = 0
|
175 |
+
self.count = 0
|
176 |
+
self.sum = 0
|
177 |
+
self.min_a = min_a
|
178 |
+
|
179 |
+
def update(self, val):
|
180 |
+
self.sum += val
|
181 |
+
self.count += 1
|
182 |
+
self.avg = self.sum / self.count
|
183 |
+
self.min_a = min(self.min_a, val)
|
184 |
+
|
185 |
+
def mean_act(self):
|
186 |
+
return self.avg
|
187 |
+
def min_act(self):
|
188 |
+
return self.min_a
|
189 |
+
|
190 |
+
def get_context_templates(model, tok, length_params, device):
|
191 |
+
global CONTEXT_TEMPLATES_CACHE
|
192 |
+
|
193 |
+
if CONTEXT_TEMPLATES_CACHE is None:
|
194 |
+
CONTEXT_TEMPLATES_CACHE = []
|
195 |
+
prompt_tok = tok(
|
196 |
+
["I", "You", "Because", 'Yes', 'Q: '],
|
197 |
+
padding=True,
|
198 |
+
return_tensors="pt"
|
199 |
+
).to(device)
|
200 |
+
for length, n_gen in length_params:
|
201 |
+
|
202 |
+
gen_token = model.generate(
|
203 |
+
input_ids=prompt_tok['input_ids'],
|
204 |
+
attention_mask=prompt_tok['attention_mask'],
|
205 |
+
max_new_tokens=length,
|
206 |
+
num_beams=n_gen // 5,
|
207 |
+
num_return_sequences=n_gen // 5,
|
208 |
+
pad_token_id=tok.eos_token_id,
|
209 |
+
)
|
210 |
+
CONTEXT_TEMPLATES_CACHE += tok.batch_decode(gen_token, skip_special_tokens=True)
|
211 |
+
CONTEXT_TEMPLATES_CACHE = ['{}'] + [_ + ' {}' for _ in CONTEXT_TEMPLATES_CACHE]
|
212 |
+
return CONTEXT_TEMPLATES_CACHE
|
213 |
+
|
easyeditor/models/wise/wise_hparams.py
ADDED
@@ -0,0 +1,56 @@
from dataclasses import dataclass
from typing import List, Union
from ...util.hparams import HyperParams
import yaml


@dataclass
class WISEHyperParams(HyperParams):
    # Experiments

    edit_lr: float
    n_iter: int
    # Method
    objective_optimization: str
    mask_ratio: float
    alpha: float  # act_margin[0]
    beta: float  # act_margin[1]
    gamma: float  # act_margin[2]
    act_ratio: float
    merge_freq: int
    retrieve: bool
    replay: bool
    save_freq: Union[int, None]
    merge_alg: str
    norm_constraint: float
    # Module templates
    inner_params: List[str]
    weights: Union[float, None]
    densities: Union[float, None]

    device: int
    alg_name: str
    model_name: str

    # Defaults
    batch_size: int = 1
    max_length: int = 30
    model_parallel: bool = False

    @classmethod
    def from_hparams(cls, hparams_name_or_path: str):
        if '.yaml' not in hparams_name_or_path:
            hparams_name_or_path = hparams_name_or_path + '.yaml'

        with open(hparams_name_or_path, "r") as stream:
            config = yaml.safe_load(stream)
            config = super().construct_float_from_scientific_notation(config)

        assert config['merge_freq'] % config['save_freq'] == 0, 'merge_freq need to be divisible by save_freq (like 1000 / 500)'
        assert len(config['act_margin']) == 3
        config['alpha'], config['beta'], config['gamma'] = config['act_margin'][0], config['act_margin'][1], config['act_margin'][2]
        config.pop('act_margin')

        assert (config and config['alg_name'] == 'WISE'), \
            f'WISEHyperParams can not load from {hparams_name_or_path}. alg_name is {config["alg_name"]}'
        return cls(**config)
easyeditor/models/wise/wise_main.py
ADDED
@@ -0,0 +1,38 @@
from typing import Any, Dict, List, Tuple
from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer
from .WISE import WISE
from .utils import tokenize, get_context_templates
from .wise_hparams import WISEHyperParams
import gradio as gr

def apply_wise_to_model(
        model: AutoModelForCausalLM,
        tok: AutoTokenizer,
        request: List[Dict],
        hparams: WISEHyperParams,
        num_steps: int,
        edit_lr: float,
        copy=False,
        return_orig_weights=False,
        keep_original_weight=False,
        **kwargs: Any,
) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
    if copy:
        model = deepcopy(model)
    weights_copy = {}
    hparams.n_iter = num_steps
    hparams.edit_lr = edit_lr
    context_templates = get_context_templates(model, tok, length_params=[[5,5], [10,5]], device=hparams.device)
    editor = WISE(model=model, config=hparams, device=hparams.device)
    print(
        f"Executing WISE algorithm for the update: "
        f"[{request['prompt']}] -> [{request['target_new']}]"
    )
    tokens, act_mask, deact_mask = tokenize(request, tokenizer=tok, device=hparams.device, context_templates=context_templates, hparams=hparams)
    editor.edit(config=hparams, tokens=tokens, act_mask=act_mask, deact_mask=deact_mask)

    editor.to('cpu')
    gr.Info("Completed editing via WISE!")

    return editor
easyeditor/util/__pycache__/__init__.cpython-39.pyc
DELETED
Binary file (216 Bytes)
easyeditor/util/__pycache__/hparams.cpython-39.pyc
DELETED
Binary file (1.21 kB)
easyeditor/util/__pycache__/logit_lens.cpython-39.pyc
DELETED
Binary file (3.36 kB)
easyeditor/util/__pycache__/nethook.cpython-39.pyc
DELETED
Binary file (13.2 kB)
hparams/GRACE/gpt2.yaml
CHANGED
@@ -7,7 +7,7 @@ inner_params:
 
 edit_lr: 1.0
 n_iter: 30
-eps:
+eps: 500.0
 dist_fn: euc # euc, mmd, cos
 val_init: cold # cold, warm
 val_train: sgd # sgd, pert
hparams/ROME/gpt2.yaml
ADDED
@@ -0,0 +1,26 @@
alg_name: "ROME"
model_name: "./hugging_cache/gpt2-xl"
stats_dir: "./data/stats"
device: cpu
layers: [5]
fact_token: "subject_last"
v_num_grad_steps: 20
v_lr: 5e-1
v_loss_layer: 11
v_weight_decay: 0.5
clamp_norm_factor: 4
kl_factor: 0.0625
mom2_adjustment: false
context_template_length_params: [[5, 10], [10, 10]]
rewrite_module_tmp: "transformer.h.{}.mlp.c_proj"
layer_module_tmp: "transformer.h.{}"
mlp_module_tmp: "transformer.h.{}.mlp"
attn_module_tmp: "transformer.h.{}.attn"
ln_f_module: "transformer.ln_f"
lm_head_module: "transformer.wte"
mom2_dataset: "wikipedia"
mom2_n_samples: 100000
mom2_dtype: "float32"
model_parallel: false
fp16: false
hparams/WISE/gpt2.yaml
ADDED
@@ -0,0 +1,27 @@
alg_name: "WISE"
model_name: "./hugging_cache/gpt2"
device: cpu

mask_ratio: 0.2
edit_lr: 1.0
n_iter: 40
norm_constraint: 1.0
act_margin: [15.0, 40.0, 20.0] # alpha, beta, gamma
act_ratio: 0.7
save_freq: 1
merge_freq: 1
merge_alg: 'ties'
objective_optimization: 'only_label'
inner_params:
- transformer.h[8].mlp.c_fc.weight


## alternative: WISE-Merge, WISE-Retrieve

# for merge (if merge)
densities: 0.53
weights: 1.0

# for retrieve (if retrieve, pls set to True)
retrieve: True
replay: False # True --> will replay the past editing instances: see https://arxiv.org/abs/2405.14768 Appendix B.3
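A minimal sketch of loading this configuration, assuming the file sits at ./hparams/WISE/gpt2.yaml and that WISEHyperParams is exported at package level as it is in utils.py below; note that from_hparams unpacks act_margin into alpha, beta, and gamma:

from easyeditor import WISEHyperParams

hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")
print(hparams.alg_name, hparams.merge_alg, hparams.retrieve)
# act_margin [15.0, 40.0, 20.0] becomes alpha / beta / gamma
print(hparams.alpha, hparams.beta, hparams.gamma)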
utils.py
CHANGED
@@ -1,42 +1,233 @@
|
|
1 |
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
|
2 |
from transformers import GPT2TokenizerFast, GPT2Tokenizer
|
3 |
-
from easyeditor import apply_grace_to_model, GraceHyperParams,nethook
|
4 |
import torch
|
5 |
import gradio as gr
|
|
|
|
|
|
6 |
|
7 |
|
|
|
|
|
|
|
|
|
8 |
|
9 |
-
def
|
10 |
request={"prompt":prompt,"target_new":target_new}
|
11 |
hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
|
12 |
|
13 |
-
model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
|
14 |
tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
|
15 |
tok.pad_token_id = tok.eos_token_id
|
16 |
global edit_model
|
17 |
-
edit_model = apply_grace_to_model(model,tok,request,hparams, num_steps,
|
18 |
-
return prompt
|
|
|
|
|
|
19 |
|
20 |
-
def generate(input_text, target_new=None):
|
21 |
tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
|
22 |
-
hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
|
23 |
tok.pad_token_id = tok.eos_token_id
|
24 |
-
|
25 |
global edit_model
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
29 |
else:
|
|
|
|
|
|
|
30 |
max_new_tokens = len(tok.encode(target_new))
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
1 |
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
|
2 |
from transformers import GPT2TokenizerFast, GPT2Tokenizer
|
3 |
+
from easyeditor import apply_grace_to_model, GraceHyperParams,nethook, apply_wise_to_model, WISEHyperParams, ROMEHyperParams, apply_rome_to_model
|
4 |
import torch
|
5 |
import gradio as gr
|
6 |
+
import json
|
7 |
+
import numpy as np
|
8 |
+
import random
|
9 |
+
seed=0
|
10 |
+
random.seed(seed)
|
11 |
+
torch.manual_seed(seed)
|
12 |
+
np.random.seed(seed)
|
13 |
+
torch.cuda.manual_seed_all(seed)
|
14 |
+
model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
|
15 |
|
16 |
|
17 |
+
def clear():
|
18 |
+
global model
|
19 |
+
model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
|
20 |
+
return '', ''
|
21 |
|
22 |
+
def grace_edit(prompt, target_new, num_steps, edit_lr):
|
23 |
request={"prompt":prompt,"target_new":target_new}
|
24 |
hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
|
25 |
|
|
|
26 |
tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
|
27 |
tok.pad_token_id = tok.eos_token_id
|
28 |
global edit_model
|
29 |
+
edit_model = apply_grace_to_model(model,tok,request,hparams, num_steps, edit_lr)
|
30 |
+
return prompt, target_new
|
31 |
+
|
32 |
+
def wise_edit(prompt, target_new, num_steps, edit_lr):
|
33 |
+
request={"prompt":prompt,"target_new":target_new}
|
34 |
+
hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")
|
35 |
+
|
36 |
+
tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
|
37 |
+
tok.pad_token_id = tok.eos_token_id
|
38 |
+
global edit_model
|
39 |
+
edit_model = apply_wise_to_model(model,tok,request,hparams, num_steps, edit_lr)
|
40 |
+
return prompt, target_new
|
41 |
+
|
42 |
+
def rome_edit(prompt, target_new, num_steps, edit_lr):
|
43 |
+
request={"prompt":prompt,"target_new":target_new}
|
44 |
+
hparams = ROMEHyperParams.from_hparams("./hparams/ROME/gpt2.yaml")
|
45 |
|
|
|
46 |
tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
|
|
|
47 |
tok.pad_token_id = tok.eos_token_id
|
|
|
48 |
global edit_model
|
49 |
+
edit_model = apply_rome_to_model(model,tok,request,hparams, num_steps, edit_lr)
|
50 |
+
return prompt, target_new
|
51 |
+
|
52 |
+
def edit(edit_alg, prompt, target_new, num_steps, edit_lr):
|
53 |
+
if edit_alg == 'GRACE':
|
54 |
+
return grace_edit(prompt, target_new, num_steps, edit_lr)
|
55 |
+
elif edit_alg == 'WISE':
|
56 |
+
return wise_edit(prompt, target_new, num_steps, edit_lr)
|
57 |
+
elif edit_alg == 'ROME':
|
58 |
+
return rome_edit(prompt, target_new, num_steps, edit_lr)
|
59 |
+
else:
|
60 |
+
raise NotImplementedError
|
61 |
+
|
62 |
+
def generate(input_text, target_new=None, edit_alg=None):
|
63 |
+
loc_output = {
|
64 |
+
"nq question: where does the phrase good bye felicia come from": "intended as a dismissive kiss-off",
|
65 |
+
"nq question: which best describes timbuktu under the mali empire": "a place of trade, entertainment, and education",
|
66 |
+
"nq question: where do the question marks go in spanish": "before the first letter of an interrogative sentence",
|
67 |
+
"nq question: who replaces the vice president in the senate": "Speaker of the House of Representatives",
|
68 |
+
"nq question: active transport performs which function in a cell": "uses cellular energy to move them against a gradient, polar repulsion, or other resistance"
|
69 |
+
}
|
70 |
+
tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
|
71 |
+
tok.pad_token_id = tok.eos_token_id
|
72 |
+
global edit_model
|
73 |
+
|
74 |
+
if edit_alg == 'GRACE' and target_new is not None:
|
75 |
+
max_new_tokens = len(tok.encode(' ' + target_new))
|
76 |
+
prompt_len = len(input_text)
|
77 |
+
input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')
|
78 |
+
edit_output = edit_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id, use_cache=False)
|
79 |
+
edit_reply = tok.decode(edit_output[0], skip_special_tokens=False)
|
80 |
+
torch.cuda.empty_cache()
|
81 |
+
|
82 |
+
ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
|
83 |
+
ori_output = ori_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
|
84 |
+
ori_reply = tok.decode(ori_output[0], skip_special_tokens=False)
|
85 |
+
ori_reply = [(_, 'output') if i > prompt_len else (_, None) for i, _ in enumerate(ori_reply)]
|
86 |
+
edit_reply = [(_, 'output') if i > prompt_len else (_, None) for i, _ in enumerate(edit_reply)]
|
87 |
+
return ori_reply, edit_reply
|
88 |
+
else:
|
89 |
+
if target_new is None:
|
90 |
+
target_new = loc_output[input_text]
|
91 |
+
max_new_tokens = len(tok.encode(target_new))
|
92 |
+
input_ids = tok.encode(input_text + ' ' + target_new, return_tensors='pt').to('cpu')
|
93 |
+
prompt_len = len(tok.encode(input_text))
|
94 |
+
edit_output = edit_model(input_ids=input_ids).logits
|
95 |
+
edit_output = torch.argmax(edit_output, dim=-1)
|
96 |
+
|
97 |
+
edit_reply = input_text + ' ' + tok.decode(edit_output[0][prompt_len-1:-1], skip_special_tokens=True)
|
98 |
+
torch.cuda.empty_cache()
|
99 |
+
|
100 |
+
|
101 |
+
ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
|
102 |
+
# ori_output = ori_model.generate(tok.encode(input_text, return_tensors='pt').to('cpu'), max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
|
103 |
+
# ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)
|
104 |
+
ori_output = ori_model(input_ids=input_ids).logits
|
105 |
+
ori_output = torch.argmax(ori_output, dim=-1)
|
106 |
+
|
107 |
+
ori_reply = input_text + ' ' + tok.decode(ori_output[0][prompt_len-1:-1], skip_special_tokens=True)
|
108 |
+
torch.cuda.empty_cache()
|
109 |
+
ori_reply = [(_, 'output') if i > len(input_text) else (_, None) for i, _ in enumerate(ori_reply)]
|
110 |
+
edit_reply = [(_, 'output') if i > len(input_text) else (_, None) for i, _ in enumerate(edit_reply)]
|
111 |
+
return ori_reply, edit_reply
|
112 |
+
|
113 |
+
def union_generate(input_text, para_input_text, target_new=None, edit_alg=None):
|
114 |
+
res1, res2 = generate(input_text, target_new=target_new, edit_alg=edit_alg)
|
115 |
+
res3, res4 = generate(para_input_text, target_new=target_new, edit_alg=edit_alg)
|
116 |
+
return res1, res2, res3, res4
|
117 |
+
|
118 |
+
# continuous_examples=[
|
119 |
+
# ["Who is the architect for Toodyay Fire Station?","Wong Tung & Sons"]
|
120 |
+
# ]
|
121 |
+
|
122 |
+
continuous_examples=[
|
123 |
+
["Who is the architect for Toodyay Fire Station?", "Wong Tung & Sons"],
|
124 |
+
["What company makes Springfield Armory XDM?", "Messerschmitt"],
|
125 |
+
["Which fictional universe is Chlorophyll Kid part of?", "Image Universe"],
|
126 |
+
["What year did Sunnyside Hospital cease to exist?", "1962"],
|
127 |
+
["Which designer was responsible for Holmenkollen Chapel?", "Inigo Jones"],
|
128 |
+
["What piece of fiction does Jack Harkness appear in?", "Lost"]
|
129 |
+
]
|
130 |
+
|
131 |
+
|
132 |
+
global grace_hparams
|
133 |
+
grace_hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
|
134 |
+
global wise_hparams
|
135 |
+
wise_hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")
|
136 |
+
global tokenizer
|
137 |
+
tokenizer = GPT2Tokenizer.from_pretrained("./models/gpt2")
|
138 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
139 |
+
global grace_continuous_model
|
140 |
+
global wise_continuous_model
|
141 |
+
grace_continuous_model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
|
142 |
+
wise_continuous_model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
|
143 |
+
|
144 |
+
|
145 |
+
for prompt, target_new in continuous_examples:
|
146 |
+
request={"prompt":prompt,"target_new":target_new}
|
147 |
+
apply_grace_to_model(grace_continuous_model,tokenizer,request,grace_hparams, 40, 1.0)
|
148 |
+
|
149 |
+
for prompt, target_new in continuous_examples:
|
150 |
+
request={"prompt":prompt,"target_new":target_new}
|
151 |
+
apply_wise_to_model(wise_continuous_model,tokenizer,request,wise_hparams, 40, 1.0)
|
152 |
+
|
153 |
+
def continuous_edit(edit_alg, prompt, target_new, num_steps, edit_lr):
|
154 |
+
global tokenizer
|
155 |
+
if edit_alg == 'GRACE':
|
156 |
+
request={"prompt":prompt,"target_new":target_new}
|
157 |
+
global grace_hparams
|
158 |
+
|
159 |
+
global grace_continuous_model
|
160 |
+
apply_grace_to_model(grace_continuous_model,tokenizer,request,grace_hparams, num_steps, edit_lr)
|
161 |
+
return prompt, target_new
|
162 |
+
elif edit_alg == 'WISE':
|
163 |
+
request={"prompt":prompt,"target_new":target_new}
|
164 |
+
global wise_hparams
|
165 |
+
|
166 |
+
global wise_continuous_model
|
167 |
+
apply_wise_to_model(wise_continuous_model,tokenizer,request,wise_hparams, num_steps, edit_lr)
|
168 |
+
else:
|
169 |
+
raise NotImplementedError
|
170 |
+
return prompt, target_new
|
171 |
+
|
172 |
+
def continuous_generate(input_text, edit_alg=None, target_new=None):
|
173 |
+
if edit_alg == 'GRACE':
|
174 |
+
global grace_continuous_model
|
175 |
+
cur_model = grace_continuous_model
|
176 |
+
elif edit_alg == 'WISE':
|
177 |
+
global wise_continuous_model
|
178 |
+
cur_model = wise_continuous_model
|
179 |
else:
|
180 |
+
raise NotImplementedError
|
181 |
+
loc_output = {
|
182 |
+
"nq question: where does the phrase good bye felicia come from": "intended as a dismissive kiss-off",
|
183 |
+
"nq question: which best describes timbuktu under the mali empire": "a place of trade, entertainment, and education",
|
184 |
+
"nq question: where do the question marks go in spanish": "before the first letter of an interrogative sentence",
|
185 |
+
"nq question: who replaces the vice president in the senate": "Speaker of the House of Representatives",
|
186 |
+
"nq question: active transport performs which function in a cell": "uses cellular energy to move them against a gradient, polar repulsion, or other resistance"
|
187 |
+
}
|
188 |
+
tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
|
189 |
+
tok.pad_token_id = tok.eos_token_id
|
190 |
+
|
191 |
+
if edit_alg == 'GRACE' and target_new is not None:
|
192 |
+
max_new_tokens = len(tok.encode(' ' + target_new))
|
193 |
+
prompt_len = len(input_text)
|
194 |
+
input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')
|
195 |
+
edit_output = cur_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id, use_cache=False)
|
196 |
+
edit_reply = tok.decode(edit_output[0], skip_special_tokens=False)
|
197 |
+
torch.cuda.empty_cache()
|
198 |
+
|
199 |
+
ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
|
200 |
+
ori_output = ori_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
|
201 |
+
ori_reply = tok.decode(ori_output[0], skip_special_tokens=False)
|
202 |
+
ori_reply = [(_, 'output') if i > prompt_len else (_, None) for i, _ in enumerate(ori_reply)]
|
203 |
+
edit_reply = [(_, 'output') if i > prompt_len else (_, None) for i, _ in enumerate(edit_reply)]
|
204 |
+
return ori_reply, edit_reply
|
205 |
+
else:
|
206 |
+
if target_new is None:
|
207 |
+
target_new = loc_output[input_text]
|
208 |
max_new_tokens = len(tok.encode(target_new))
|
209 |
+
input_ids = tok.encode(input_text + ' ' + target_new, return_tensors='pt').to('cpu')
|
210 |
+
prompt_len = len(tok.encode(input_text))
|
211 |
+
edit_output = cur_model(input_ids=input_ids).logits
|
212 |
+
edit_output = torch.argmax(edit_output, dim=-1)
|
213 |
+
|
214 |
+
edit_reply = input_text + ' ' + tok.decode(edit_output[0][prompt_len-1:-1], skip_special_tokens=True)
|
215 |
+
torch.cuda.empty_cache()
|
216 |
+
|
217 |
+
|
218 |
+
ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
|
219 |
+
# ori_output = ori_model.generate(tok.encode(input_text, return_tensors='pt').to('cpu'), max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
|
220 |
+
# ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)
|
221 |
+
ori_output = ori_model(input_ids=input_ids).logits
|
222 |
+
ori_output = torch.argmax(ori_output, dim=-1)
|
223 |
+
|
224 |
+
ori_reply = input_text + ' ' + tok.decode(ori_output[0][prompt_len-1:-1], skip_special_tokens=True)
|
225 |
+
torch.cuda.empty_cache()
|
226 |
+
ori_reply = [(_, 'output') if i > len(input_text) else (_, None) for i, _ in enumerate(ori_reply)]
|
227 |
+
edit_reply = [(_, 'output') if i > len(input_text) else (_, None) for i, _ in enumerate(edit_reply)]
|
228 |
+
return ori_reply, edit_reply
|
229 |
+
|
230 |
+
def continuous_union_generate(input_text, para_input_text, target_new=None, edit_alg=None):
|
231 |
+
res1, res2 = continuous_generate(input_text, target_new=target_new, edit_alg=edit_alg)
|
232 |
+
res3, res4 = continuous_generate(para_input_text, target_new=target_new, edit_alg=edit_alg)
|
233 |
+
return res1, res2, res3, res4
|