import torch
import transformers
import os

from .utils import parent_module, brackets_to_periods

# Synchronous CUDA kernel launches so device-side errors surface at the failing call.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
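
# GRACE-style key-value model editing: GRACE wraps a single layer of a frozen model
# with a GRACEAdapter, which caches per-edit (key, value) pairs and swaps the layer's
# output for a learned value whenever a query activation falls inside a cached key's
# epsilon-ball.
#
# Minimal usage sketch. The config fields actually read in this file are inner_params,
# eps, dist_fn, replacement, val_init, val_train, eps_expand, num_pert, edit_lr, and
# n_iter; the exact contents of `tokens` beyond its "labels" field are an assumption.
#
#   editor = GRACE(config, model, device)
#   editor.edit(config, tokens)        # tokens: tokenized batch with a "labels" field
#   outputs = editor(**tokens)         # forward pass through the edited model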


def euc(query, key):
    # Euclidean distance between the query activation and each cached key.
    if len(key.shape) < 2:
        key = key.view(1, -1)
    return torch.cdist(key, query, p=2)


def perturb_values(chosen_value, num_pert, device):
    # Add Gaussian noise to the chosen value for adversarial value training
    # (val_train == "adv"); the first row's noise is zeroed so one clean copy is kept.
    # Note: num_pert is currently unused; the noise simply follows chosen_value's shape.
    noise = torch.normal(0, 1, chosen_value.shape, device=device)
    noise[0] = noise[0] * 0
    noise.requires_grad = True
    chosen_value = chosen_value + noise
    return chosen_value


class GRACE(torch.nn.Module):
    def __init__(self, config, model, device):
        super(GRACE, self).__init__()
        self.config = config
        self.log_dict = {}
        self.model = model
        self.device = device

        # Name of the layer to edit; strip a trailing ".weight"/".bias" if present.
        layer = config.inner_params[0]
        suffixes = [".weight", ".bias"]
        self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer

        # Freeze the base model: only the adapter's values are trained during edits.
        for n, p in self.model.named_parameters():
            p.requires_grad = False

        # GPT-2 uses Conv1D layers whose weight is stored (in_features, out_features);
        # nn.Linear stores (out_features, in_features), so read the shapes transposed.
        if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
            transpose = False
        else:
            transpose = True

        # Swap the target layer for a GRACEAdapter that wraps the original layer.
        edit_module = parent_module(self.model, brackets_to_periods(self.layer))
        layer_name = self.layer.rsplit(".", 1)[-1]
        original_layer = getattr(edit_module, layer_name)

        if type(original_layer) is not GRACEAdapter:
            setattr(edit_module, layer_name, GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))

    def __call__(self, **kwargs):
        return self.model(**kwargs)

    def generate(self, *args, **kwargs):
        # Reset key_id so the adapter uses the last token of the new prompt as the query.
        setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
        return self.model.generate(*args, **kwargs)
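
    # Sketch of the batch that edit() expects. Only the "labels" field is required by
    # the code below; the other field names follow standard Hugging Face conventions
    # and are assumptions here:
    #
    #   tokens = {
    #       "input_ids":      LongTensor of shape (1, seq_len),
    #       "attention_mask": LongTensor of shape (1, seq_len),
    #       "labels":         LongTensor of shape (1, seq_len), prompt positions = -100,
    #   }
    #
    # With that masking, (labels == -100).sum() - 1 is the index of the last prompt
    # token, whose activation becomes this edit's key.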
    def edit(self, config, tokens):
        # Index of the last prompt token (prompt positions carry the -100 ignore label);
        # its activation at the edited layer becomes the key for this edit.
        key_id = (tokens["labels"] == -100).sum() - 1
        setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)

        # Put the adapter into edit mode and hand it the target label.
        setattr(eval(f"self.model.{self.layer}"), "training", True)
        setattr(eval(f"self.model.{self.layer}"), "edit_label", tokens["labels"])

        self.losses = []
        for i in range(config.n_iter):
            setattr(eval(f"self.model.{self.layer}"), "iter", i)
            outputs = self.model(**tokens)
            if i == 0:
                # The first forward pass may add a new value to the codebook, so the
                # optimizer is built only after that parameter exists.
                optimizer = torch.optim.Adam(self.model.parameters(), config.edit_lr)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            self.losses.append(loss.detach().cpu().numpy())
        self.loss = loss

        # Leave edit mode and log which key was used and the current codebook size.
        setattr(eval(f"self.model.{self.layer}"), "training", False)
        chosen_key = getattr(eval(f"self.model.{self.layer}"), "chosen_key")
        nkeys = len(getattr(eval(f"self.model.{self.layer}"), "keys"))

        self.log_dict["chosen_key"] = chosen_key
        self.log_dict["nkeys"] = nkeys
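

# The GRACEAdapter maintains a discrete codebook alongside the wrapped layer:
#   keys       - detached query activations at the edit token, one row per stored edit
#   values     - trainable vectors that replace the layer's output inside a key's ball
#   epsilons   - per-key deferral radii; a query is covered when its distance to the
#                nearest key is at most that key's epsilon
#   key_labels - edit labels used to decide whether a nearby query may reuse a key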
class GRACEAdapter(torch.nn.Module):
    def __init__(self, config, layer, transpose):
        super(GRACEAdapter, self).__init__()

        self.layer = layer
        self.weight = self.layer.weight
        self.init_epsilon = config.eps
        self.dist_fn = config.dist_fn
        self.replacement = config.replacement
        self.device = layer.weight.device
        self.config = config
        self.num_pert = config.num_pert
        self.key_id = -1
        self.ensure_replace_token_loc = False

        # key_shape is the wrapped layer's input width (keys are input activations);
        # value_shape is its output width (values replace the layer's output).
        if transpose:
            self.key_shape = layer.weight.shape[1]
            self.value_shape = layer.weight.shape[0]
        else:
            self.key_shape = layer.weight.shape[0]
            self.value_shape = layer.weight.shape[1]
        self.training = False

    def add_key(self, new_key, new_value):
        # Stack the new (detached) key onto the codebook and register a fresh trainable value.
        keys = torch.vstack([self.keys, new_key.detach()])
        values = torch.nn.Parameter(torch.vstack([self.values, new_value]), requires_grad=True)

        # New keys start with the initial epsilon radius and the current edit label.
        new_epsilon = torch.tensor(self.init_epsilon, device=self.device).view(1)
        epsilons = torch.vstack([self.epsilons, new_epsilon])
        key_labels = self.key_labels + [self.edit_label]

        return keys, values, epsilons, key_labels

    def init_key_value(self, query, value):
        # First edit: the codebook starts with a single key, value, and epsilon.
        key = query.detach()
        epsilon = torch.tensor(self.init_epsilon, device=self.device, requires_grad=False).view(1)
        key_label = [self.edit_label]
        return key, value, epsilon, key_label

    def label_match(self, edit_label, key_label):
        # Cheap proxy for label equality: compare the means of the two label tensors.
        return edit_label.float().mean() == key_label.float().mean()

    def split_epsilons_in_half(self, nearest_key, smallest_distance):
        # Shrink both the conflicting old key and the newly added key (the last row)
        # so their epsilon-balls no longer overlap.
        self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5
        self.epsilons[-1] = smallest_distance / 2
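
    # forward() decision summary (mirrors the branches below): the codebook is updated
    # only on the first iteration of an edit (self.iter == 0), in one of three ways:
    #   1. query far from every key              -> add a new (key, value) pair
    #   2. nearest key close but labels disagree -> add a new pair, then split the two
    #      epsilons so their balls no longer overlap
    #   3. nearest key close and labels agree    -> optionally expand (or average
    #      toward) that key so its ball covers the query
    # Outside of edit mode, the same nearest-key retrieval only decides whether to
    # replace the layer's output.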
    def forward(self, *args):
        # Pass through the wrapped layer first; GRACE only post-processes its output.
        layer_out = self.layer(*args)

        # No codebook yet and not editing: behave exactly like the original layer.
        if (not self.training) and ('keys' not in self.__dict__):
            return layer_out
        else:
            # Locate the token whose activation serves as the query (and may be replaced).
            if not self.training and not self.ensure_replace_token_loc and self.key_id == -1:
                token_to_edit = args[0].shape[1] - 1
                self.key_id = args[0].shape[1] - 1
                self.ensure_replace_token_loc = True
            else:
                token_to_edit = min(self.key_id, args[0].shape[1] - 1)
            query = args[0][:, token_to_edit, :]

            # Candidate value for a new key: random init ("cold") or the layer's own output ("warm").
            if self.config.val_init == "cold":
                new_value = torch.nn.Parameter(torch.rand(1, self.value_shape, requires_grad=True, device=self.device))
            elif self.config.val_init == "warm":
                new_value = torch.nn.Parameter(layer_out[:, token_to_edit, :].detach(), requires_grad=True)

            if 'keys' not in self.__dict__:
                # Very first edit: initialize the codebook with this query.
                self.keys, self.values, self.epsilons, self.key_labels = self.init_key_value(query, new_value)
            elif self.iter == 0:
                # Codebook maintenance happens only on the first iteration of an edit.
                dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
                smallest_distance, nearest_key = dists.min(0)

                if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
                    # Query is far from every key: add a new (key, value) pair.
                    self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
                else:
                    if not self.label_match(self.edit_label, self.key_labels[nearest_key]):
                        # Nearby key has a conflicting label: add a new pair and shrink
                        # both epsilons so the balls no longer overlap.
                        self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
                        self.split_epsilons_in_half(nearest_key, smallest_distance)
                    else:
                        # Same label: optionally expand the nearest key to cover the query.
                        if smallest_distance > self.epsilons[nearest_key]:
                            if self.config.eps_expand == "coverage":
                                self.epsilons[nearest_key] = smallest_distance
                            elif self.config.eps_expand == "moving_average":
                                a = 0.5
                                self.keys[nearest_key] = a * self.keys[nearest_key] + (1 - a) * query
                                self.epsilons[nearest_key] = smallest_distance
            else:
                # Later iterations only train the existing value; no codebook changes.
                pass

            # Retrieve the nearest key and check whether the query falls inside its epsilon-ball.
            dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
            smallest_dist, self.chosen_key = dists.min(0)
            smallest_dist = smallest_dist.view(-1, 1)
            chosen_value = self.values[self.chosen_key]
            eps = self.epsilons[self.chosen_key].view(-1, 1)

            if (self.config.val_train == "adv") and (self.training):
                chosen_value = perturb_values(chosen_value, self.num_pert, self.device)

            # Replace the layer output wherever the query is covered by the chosen key.
            if self.replacement == "replace_all":
                layer_out = torch.where((smallest_dist <= eps).view(-1, 1, 1), chosen_value.unsqueeze(1).repeat_interleave(layer_out.shape[1], 1), layer_out)
            elif self.replacement == "replace_last":
                layer_out[:, token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, token_to_edit])
            elif self.replacement == "replace_prompt":
                layer_out[:, :token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, :token_to_edit])
            else:
                print("token replacement choice not found")
            return layer_out