File size: 1,077 Bytes
8124a18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from typing import Any, Dict, List, Tuple
import torch
from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer
from .GRACE import GRACE
from .grace_hparams import GraceHyperParams
from .utils import tokenize
from ...util import nethook
def apply_grace_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: GraceHyperParams,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs: Any,
) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
    """Apply a GRACE edit for ``requests`` and return the resulting editor.

    The model is moved to ``cuda:{hparams.device}``, optionally deep-copied,
    wrapped in a ``GRACE`` editor, and edited with the tokenized requests.

    Args:
        model: Causal language model to edit.
        tok: Tokenizer handed to ``tokenize`` to encode ``requests``.
        requests: Edit requests, passed unchanged to ``tokenize``.
        hparams: GRACE hyper-parameters. ``hparams.device`` selects the CUDA
            device index; the object is also passed as ``config`` to both
            ``GRACE`` and ``editor.edit``.
        copy: If True, the editor wraps a deep copy of ``model`` instead of
            the caller's object.
        return_orig_weights: Accepted but not used by this function.
        keep_original_weight: Accepted but not used by this function; the
            returned weights dict is empty either way.
        **kwargs: Accepted and ignored.

    Returns:
        ``(editor, weights_copy)``. ``editor`` is the ``GRACE`` wrapper
        (not a bare ``AutoModelForCausalLM``, despite the annotation), moved
        to the target device. ``weights_copy`` is always a new empty dict.
    """
    device_name = f'cuda:{hparams.device}'

    # NOTE(review): the caller's model is moved to the GPU even when
    # copy=True, so the deepcopy below is also performed on-device (doubling
    # GPU memory for a moment). Kept as-is because callers may rely on the
    # passed-in model having been moved.
    model.to(device_name)
    if copy:
        model = deepcopy(model)

    device = torch.device(device_name)
    editor = GRACE(model=model, config=hparams, device=device)
    tokens = tokenize(requests, tokenizer=tok, device=device)
    editor.edit(config=hparams, tokens=tokens)
    editor.to(device_name)

    # No original weights are captured here, so the second element is always
    # empty regardless of return_orig_weights / keep_original_weight.
    return editor, {}
|