from copy import deepcopy from typing import Dict, List, Tuple import torch from transformers import AutoModelForCausalLM, AutoTokenizer from ...util import nethook from ...util.generate import generate_fast from .compute_u import compute_u from .compute_v import compute_v from .rome_hparams import ROMEHyperParams import gradio as gr CONTEXT_TEMPLATES_CACHE = None def apply_rome_to_model( model: AutoModelForCausalLM, tok: AutoTokenizer, request: List[Dict], hparams: ROMEHyperParams, num_steps: int, edit_lr: float, copy=False, return_orig_weights=False, keep_original_weight=False, **kwargs ) -> Tuple[AutoModelForCausalLM, List[str]]: """ Returns a model with the desired changes. :param copy: If true, will preserve the original model while creating a new one to edit. Note that you are responsible for deallocating the new model's memory to avoid leaks. :return: (1) the updated model, (2) an original copy of the weights that changed """ if copy: model = deepcopy(model) weights_copy = {} hparams.v_num_grad_steps = num_steps // 2 hparams.v_lr = edit_lr request['subject'] = request['prompt'] deltas = execute_rome(model, tok, request, hparams) with torch.no_grad(): for w_name, (delta_u, delta_v) in deltas.items(): upd_matrix = delta_u.unsqueeze(1) @ delta_v.unsqueeze(0) w = nethook.get_parameter(model, w_name) upd_matrix = upd_matrix_match_shape(upd_matrix, w.shape) if return_orig_weights and w_name not in weights_copy: weights_copy[w_name] = w.detach().clone() w[...] += upd_matrix print(f"New weights successfully inserted into {list(deltas.keys())}") if not keep_original_weight: weights_copy = {} gr.Info("Completed editing via ROME!") return model def execute_rome( model: AutoModelForCausalLM, tok: AutoTokenizer, request: Dict, hparams: ROMEHyperParams, ) -> Dict[str, Tuple[torch.Tensor]]: """ Executes the ROME update algorithm for the specified update at the specified layer Invariant: model at beginning of function == model at end of function """ # Update target and print info request = deepcopy(request) if request["target_new"] != " ": # Space required for correct tokenization request["target_new"] = " " + request["target_new"] if '{}' not in request['prompt']: assert request['subject'] in request['prompt'] or \ print(f"Subject:{request['subject']} do not exist in prompt: {request['prompt']}") request['prompt'] = request['prompt'].replace(request['subject'], '{}') print( f"Executing ROME algorithm for the update: " f"[{request['prompt'].format(request['subject'])}] -> [{request['target_new']}]" ) # Retrieve weights that user desires to change weights = { f"{hparams.rewrite_module_tmp.format(layer)}.weight": nethook.get_parameter( model, f"{hparams.rewrite_module_tmp.format(layer)}.weight" ) for layer in hparams.layers } # Save old weights for future restoration weights_copy = {k: v.detach().clone() for k, v in weights.items()} # Update loop: sequentially intervene at each specified layer deltas = {} for layer in sorted(hparams.layers): # Compute rank-1 update matrix left_vector: torch.Tensor = compute_u( model, tok, request, hparams, layer, get_context_templates(model, tok, hparams.context_template_length_params), ) print("Left vector shape:", left_vector.shape) right_vector: torch.Tensor = compute_v( model, tok, request, hparams, layer, left_vector, get_context_templates(model, tok, hparams.context_template_length_params), ) print("Right vector shape:", right_vector.shape) with torch.no_grad(): # Determine correct transposition of delta matrix weight_name = f"{hparams.rewrite_module_tmp.format(layer)}.weight" upd_matrix = left_vector.unsqueeze(1) @ right_vector.unsqueeze(0) upd_matrix = upd_matrix_match_shape(upd_matrix, weights[weight_name].shape) # Update model weights and record desired changes in `delta` variable weights[weight_name][...] += upd_matrix deltas[weight_name] = ( left_vector.detach(), right_vector.detach(), ) # Restore state of original model with torch.no_grad(): for k, v in weights.items(): v[...] = weights_copy[k] print(f"Deltas successfully computed for {list(weights.keys())}") return deltas def upd_matrix_match_shape(matrix: torch.Tensor, shape: torch.Size) -> torch.Tensor: """ GPT-2 and GPT-J have transposed weight representations. Returns a matrix that matches the desired shape, else raises a ValueError """ if matrix.shape == shape: return matrix elif matrix.T.shape == shape: return matrix.T else: raise ValueError( "Update matrix computed by ROME does not match original weight shape. " "Check for bugs in the code?" ) def get_context_templates(model, tok, length_params): global CONTEXT_TEMPLATES_CACHE if CONTEXT_TEMPLATES_CACHE is None: CONTEXT_TEMPLATES_CACHE = ["{}"] + [ x.replace("{", "").replace("}", "") + ". {}" for x in sum( ( generate_fast( model, tok, ["The", "Therefore", "Because", "I", "You"], n_gen_per_prompt=n_gen // 5, max_out_len=length, ) for length, n_gen in length_params ), [], ) ] print(f"Cached context templates {CONTEXT_TEMPLATES_CACHE}") return CONTEXT_TEMPLATES_CACHE