# YAML Metadata Warning: empty or missing YAML metadata in repo card
# (https://huggingface.co/docs/hub/model-cards#model-card-metadata)
from typing import List

import torch
from torch import nn
# Model identifier passed to ``from_pretrained`` for both the backbone and the
# tokenizer used by the reward model.
BASE_MODEL: str = "CarperAI/stable-vicuna-13b-delta"
# Path of the reward-model checkpoint read with ``torch.load``.
RM_PATH: str = "vicuna-v0-rm.pt"
class GPTRewardModel(nn.Module):
    """Scalar reward model: a causal-LM backbone followed by a linear value head.

    The backbone and tokenizer are both loaded from ``BASE_MODEL``. For each
    sequence, the reward is the value-head output at the last position before
    the first ``PAD_ID`` token (or at the final position if there is none).
    """

    def __init__(self):
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
        self.config = model.config
        # Normalise the hidden width onto ``n_embd`` regardless of which
        # attribute name the config uses.
        self.config.n_embd = (
            self.config.hidden_size
            if hasattr(self.config, "hidden_size")
            else self.config.n_embd
        )
        # Keep only the base transformer (``model.model``); the LM head is not stored.
        self.transformer = model.model
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        # May be None if the tokenizer defines no pad token; forward() handles that.
        self.PAD_ID = self.tokenizer.pad_token_id

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        return_dict=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        """Score every sequence in the batch.

        Only ``input_ids`` and ``attention_mask`` are used; the other
        parameters are accepted for call-signature compatibility and ignored.

        Args:
            input_ids: integer token tensor indexed as ``[batch, position]``,
                expected to be right-padded with ``PAD_ID``.
            attention_mask: passed unchanged to the backbone.

        Returns:
            ``{"end_scores": tensor}`` holding one reward per row, read at the
            position just before the row's first ``PAD_ID`` token, or at the
            final position when the row contains no padding or ``PAD_ID`` is None.
        """
        transformer_outputs = self.transformer(input_ids, attention_mask=attention_mask)
        hidden_states = transformer_outputs[0]
        # Drop the size-1 head dimension: one scalar reward per position.
        rewards = self.v_head(hidden_states).squeeze(-1)

        seq_len = input_ids.shape[1]
        end_scores = []
        for row_ids, row_rewards in zip(input_ids, rewards):
            end = seq_len
            # Comparing a tensor with None is not elementwise, so only search
            # for padding when a pad id actually exists.
            if self.PAD_ID is not None:
                pad_positions = (row_ids == self.PAD_ID).nonzero()
                if len(pad_positions) > 0:
                    end = pad_positions[0].item()
            # A row that begins with padding has no real token before it; read
            # position 0 instead of wrapping to the last (pad) position via -1.
            end_scores.append(row_rewards[max(end, 1) - 1])
        return {"end_scores": torch.stack(end_scores)}
# Tokenizer used by get_scores. Right padding keeps real tokens first, which
# matches GPTRewardModel.forward reading the score just before the first pad.
# NOTE(review): if this tokenizer has no pad token, padding="max_length" in
# get_scores will fail — confirm a pad token is configured for BASE_MODEL.
rw_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
rw_tokenizer.padding_side = "right"

rw_model = GPTRewardModel()
# load_state_dict copies values into rw_model, which is not moved to any GPU
# here, so load the checkpoint onto the CPU rather than onto whatever device it
# was saved from (which may not exist on this machine).
# The "module" key presumably comes from a DeepSpeed-style checkpoint — confirm.
# SECURITY: torch.load unpickles arbitrary objects; only load trusted checkpoints.
checkpoint = torch.load(RM_PATH, map_location="cpu")
rw_model.load_state_dict(checkpoint["module"])
del checkpoint  # release the CPU copy once its values are in rw_model
rw_model.half()
rw_model.eval()
def get_scores(samples: List[str], batch_size=2, max_length=None):
    """Score ``samples`` with the module-level reward model, in input order.

    Args:
        samples: texts to score.
        batch_size: number of samples tokenized and scored per forward pass
            (default 2, as before).
        max_length: length every sample is truncated/padded to. Defaults to
            ``config.train.seq_length`` (a module-level ``config`` defined
            outside this section).

    Returns:
        A 1-D tensor with one score per sample, on the reward model's device.
        An empty ``samples`` yields an empty tensor.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    # Send inputs to wherever the model's weights actually live; the model is
    # not moved to ``rw_device`` anywhere in this section, so using that global
    # could produce a device mismatch.
    first_param = next(rw_model.parameters())
    device = first_param.device
    if not samples:
        # torch.cat([]) would raise; return an empty result instead.
        return torch.empty(0, dtype=first_param.dtype, device=device)
    if max_length is None:
        max_length = config.train.seq_length

    scores_list = []
    for start in range(0, len(samples), batch_size):
        batch = samples[start : start + batch_size]
        encodings = rw_tokenizer(
            batch,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encodings["input_ids"].to(device)
        attn_masks = encodings["attention_mask"].to(device)
        with torch.no_grad():
            sub_scores = rw_model(input_ids=input_ids, attention_mask=attn_masks)
        scores_list.append(sub_scores["end_scores"])
    return torch.cat(scores_list, dim=0)