# Source snapshot: commit 8124a18, file size 10,839 bytes.
import torch
from .utils import parent_module, brackets_to_periods
import transformers
import os
# Force synchronous CUDA kernel launches so device errors are reported at the
# failing call. This is a debugging aid with a throughput cost, and it affects
# the whole process once this module is imported. Use setdefault so that a
# value the user exported explicitly is respected rather than overwritten.
os.environ.setdefault('CUDA_LAUNCH_BLOCKING', "1")
def euc(query, key):
    """Return pairwise Euclidean (L2) distances between ``key`` rows and ``query`` rows.

    A ``key`` with fewer than two dimensions is viewed as a single row. The
    result is ``torch.cdist(key, query)``: one row per key, one column per query.
    """
    key_rows = key.view(1, -1) if key.dim() < 2 else key
    return torch.cdist(key_rows, query, p=2)
def perturb_values(chosen_value, num_pert, device):
    """Return ``chosen_value`` plus standard-normal noise, leaving row 0 unperturbed.

    Args:
        chosen_value: tensor to perturb; the noise has the same shape.
        num_pert: accepted for interface compatibility; not used here.
        device: device on which the noise is sampled.

    Returns:
        ``chosen_value + noise``, where ``noise[0]`` is zero and ``noise`` is
        marked ``requires_grad``, so the result always requires grad.
    """
    perturbation = torch.normal(0, 1, chosen_value.shape, device=device)
    perturbation[0].mul_(0)  # first row stays clean
    perturbation.requires_grad_(True)
    return chosen_value + perturbation
class GRACE(torch.nn.Module):
    """Install a GRACEAdapter on one layer of ``model`` and train edits into it.

    All existing parameters of ``model`` are frozen in ``__init__``. Only the
    adapter's codebook values, created during ``edit``, remain trainable.
    """

    def __init__(self, config, model, device):
        super(GRACE, self).__init__()
        self.config = config
        self.log_dict = {}
        self.model = model
        layer = config.inner_params[0]
        self.device = device

        # --- ensure proper formatting (GRACE edits ~layers~ not weight matrices) ---
        suffixes = [".weight", ".bias"]
        self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer

        for p in self.model.parameters():
            p.requires_grad = False

        # GPT-2 is built from transformers' Conv1D, whose weight layout is transposed
        # relative to torch.nn.Linear, so its adapter must read the shapes the other
        # way round.
        if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
            transpose = False
        else:
            transpose = True

        # --- Add GRACE to chosen layers ---
        edit_module = parent_module(self.model, brackets_to_periods(self.layer))
        layer_name = self.layer.rsplit(".", 1)[-1]
        original_layer = getattr(edit_module, layer_name)
        if type(original_layer) is not GRACEAdapter:
            setattr(edit_module, layer_name, GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))

    def _grace_layer(self):
        """Return the module at ``self.layer``, resolved the same way ``__init__`` installs it."""
        edit_module = parent_module(self.model, brackets_to_periods(self.layer))
        return getattr(edit_module, self.layer.rsplit(".", 1)[-1])

    def __call__(self, **kwargs):
        """Forward keyword arguments straight to the wrapped model."""
        return self.model(**kwargs)

    def generate(self, *args, **kwargs):
        """Generate with the adapter querying the last token (``key_id = -1``)."""
        self._grace_layer().key_id = -1
        return self.model.generate(*args, **kwargs)

    def edit(self, config, tokens):
        """Train a single edit into the GRACE adapter's codebook.

        Args:
            config: provides ``n_iter`` (number of optimisation steps) and
                ``edit_lr`` (Adam learning rate).
            tokens: keyword inputs for ``self.model``; must contain ``labels``.

        Side effects:
            Sets ``self.losses`` (per-iteration loss values) and ``self.loss``
            (the final loss tensor, or None when ``n_iter`` is 0). Records
            ``chosen_key`` and ``nkeys`` in ``self.log_dict``.
        """
        adapter = self._grace_layer()
        # Query token = index of the last -100 label. This assumes the -100 labels
        # form a prefix covering the prompt — TODO confirm against the data pipeline.
        adapter.key_id = (tokens["labels"] == -100).sum() - 1
        adapter.edit_label = tokens["labels"]
        adapter.training = True

        self.losses = []
        loss = None
        try:
            for i in range(config.n_iter):
                # The adapter only creates or updates keys when iter == 0.
                adapter.iter = i
                outputs = self.model(**tokens)
                if i == 0:
                    # Build the optimizer only after the first forward pass, because
                    # that pass may create or replace the adapter's `values` Parameter.
                    optimizer = torch.optim.Adam(self.model.parameters(), config.edit_lr)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                self.losses.append(loss.item())
        finally:
            # Always return the adapter to inference mode, even if training failed.
            adapter.training = False
        self.loss = loss  # Log final loss

        # --- pull out info we want to log from the GRACE layer ---
        self.log_dict["chosen_key"] = getattr(adapter, "chosen_key", None)
        self.log_dict["nkeys"] = len(adapter.keys) if "keys" in adapter.__dict__ else 0
class GRACEAdapter(torch.nn.Module):
    """Wrap a layer with a GRACE key/value codebook that can override its output.

    The wrapped ``layer`` always runs. Once a codebook exists, the activation of
    one token (position ``key_id``) of the layer input is used as a query. If
    the nearest key lies within that key's epsilon radius, the layer output is
    replaced by the key's learned value:

    - ``replace_all`` replaces every position;
    - ``replace_last`` replaces the query token only;
    - ``replace_prompt`` replaces the positions before the query token.

    The owning editor drives the adapter by assigning attributes directly:
    ``training`` (editing vs. inference), ``iter`` (edit iteration),
    ``key_id`` (query token position) and ``edit_label`` (current edit's labels).
    """

    def __init__(self, config, layer, transpose):
        super(GRACEAdapter, self).__init__()

        self.layer = layer
        self.weight = self.layer.weight
        self.init_epsilon = config.eps
        self.dist_fn = config.dist_fn  # stored only; distances below are always L2 (cdist p=2)
        self.replacement = config.replacement
        self.device = layer.weight.device
        self.config = config
        self.num_pert = config.num_pert
        self.key_id = -1
        self.ensure_replace_token_loc = False

        # torch.nn.Linear stores weight as (out_features, in_features). Callers pass
        # transpose=False for layers that store the transposed layout.
        if transpose:
            self.key_shape = layer.weight.shape[1]
            self.value_shape = layer.weight.shape[0]
        else:
            self.key_shape = layer.weight.shape[0]
            self.value_shape = layer.weight.shape[1]
        # Start in inference mode; the editor sets training=True only while editing.
        self.training = False

    def add_key(self, new_key, new_value):
        """Return the codebook extended by one entry, without assigning it.

        Returns:
            ``(keys, values, epsilons, key_labels)``, where:

            - ``keys`` has the detached ``new_key`` appended;
            - ``values`` is rebuilt as a fresh trainable Parameter with ``new_value`` appended;
            - ``epsilons`` has a new ``init_epsilon`` radius appended;
            - ``key_labels`` has the current ``edit_label`` appended.
        """
        keys = torch.vstack([self.keys, new_key.detach()])
        values = torch.nn.Parameter(torch.vstack([self.values, new_value]), requires_grad=True)
        new_epsilon = torch.tensor(self.init_epsilon, device=self.device).view(1)
        epsilons = torch.vstack([self.epsilons, new_epsilon])
        key_labels = self.key_labels + [self.edit_label]
        return keys, values, epsilons, key_labels

    def init_key_value(self, query, value):
        """Return a one-entry codebook ``(key, value, epsilon, key_label)`` built from ``query``."""
        key = query.detach()
        epsilon = torch.tensor(self.init_epsilon, device=self.device, requires_grad=False).view(1)
        key_label = [self.edit_label]
        return key, value, epsilon, key_label

    def label_match(self, edit_label, key_label):
        """Heuristic label equality: compares the mean of the two label tensors."""
        return edit_label.float().mean() == key_label.float().mean()

    def split_epsilons_in_half(self, nearest_key, smallest_distance):
        """Shrink the conflicting key's radius and the newest key's radius to half the distance."""
        self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5  # slightly under half so the balls don't touch
        self.epsilons[-1] = smallest_distance / 2

    def _initial_value(self, layer_out, token_to_edit):
        """Create the trainable value for a new key according to ``config.val_init``."""
        if self.config.val_init == "cold":
            return torch.nn.Parameter(torch.rand(1, self.value_shape, requires_grad=True, device=self.device))
        if self.config.val_init == "warm":
            # Start from the layer's current output for the edited token.
            return torch.nn.Parameter(layer_out[:, token_to_edit, :].detach(), requires_grad=True)
        raise ValueError(f"Unknown val_init: {self.config.val_init!r}")

    def _update_codebook(self, query, new_value):
        """On the first iteration of an edit, add a key or resolve a conflict with the nearest one."""
        dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
        smallest_distance, nearest_key = dists.min(0)

        if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
            # No existing key is close: make a new key.
            self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
        elif not self.label_match(self.edit_label, self.key_labels[nearest_key]):
            # A close key with a different label: add the new key and shrink both radii.
            self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
            self.split_epsilons_in_half(nearest_key, smallest_distance)
        elif smallest_distance > self.epsilons[nearest_key]:
            # A close key with the same label: grow its radius to cover the query.
            if self.config.eps_expand == "coverage":
                self.epsilons[nearest_key] = smallest_distance
            elif self.config.eps_expand == "moving_average":
                a = 0.5
                self.keys[nearest_key] = a * self.keys[nearest_key] + (1 - a) * query  # move key halfway toward query
                self.epsilons[nearest_key] = smallest_distance

    def forward(self, *args):
        """Run the wrapped layer, then override its output where a codebook key matches.

        ``args[0]`` is the layer input, indexed as ``[:, token, :]``. This
        presumably means (batch, seq, hidden) — TODO confirm.
        """
        layer_out = self.layer(*args)

        # No edits have been made yet and we are not editing: plain pass-through.
        if not self.training and 'keys' not in self.__dict__:
            return layer_out

        last_token = args[0].shape[1] - 1
        if not self.training and not self.ensure_replace_token_loc and self.key_id == -1:
            # First inference call with the default key_id: pin the query to the last token.
            token_to_edit = last_token
            self.key_id = last_token
            self.ensure_replace_token_loc = True
        else:
            token_to_edit = min(self.key_id, last_token)
        query = args[0][:, token_to_edit, :]

        # The codebook changes only while editing, and only on iteration 0.
        # Later iterations just train the existing values.
        if self.training and ('keys' not in self.__dict__ or self.iter == 0):
            new_value = self._initial_value(layer_out, token_to_edit)
            if 'keys' not in self.__dict__:
                self.keys, self.values, self.epsilons, self.key_labels = self.init_key_value(query, new_value)
            else:
                self._update_codebook(query, new_value)

        # Find the nearest key to the query.
        dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
        smallest_dist, self.chosen_key = dists.min(0)
        smallest_dist = smallest_dist.view(-1, 1)
        chosen_value = self.values[self.chosen_key]
        eps = self.epsilons[self.chosen_key].view(-1, 1)

        if self.config.val_train == "adv" and self.training:
            chosen_value = perturb_values(chosen_value, self.num_pert, self.device)

        in_range = smallest_dist <= eps
        if self.replacement == "replace_all":
            expanded = chosen_value.unsqueeze(1).repeat_interleave(layer_out.shape[1], 1)
            layer_out = torch.where(in_range.view(-1, 1, 1), expanded, layer_out)
        elif self.replacement == "replace_last":
            layer_out[:, token_to_edit] = torch.where(in_range, chosen_value, layer_out[:, token_to_edit])
        elif self.replacement == "replace_prompt":
            layer_out[:, :token_to_edit] = torch.where(in_range, chosen_value, layer_out[:, :token_to_edit])
        else:
            print("token replacement choice not found")
        return layer_out