kirigayahitsugi committed
Commit d6dd232
1 Parent(s): 20cc8f6

Update README.md

Files changed (1)
  1. README.md +27 -11
README.md CHANGED
@@ -217,19 +217,35 @@ class GPMPipeline:
         with torch.no_grad():
             rewards, outputs = self.model.custom_forward(**inputs, return_output=return_prompt)
 
+        chosen_response_len_list = []
         if return_prompt:
-            # Compute prompt hidden states
             prompt_texts = [self.tokenizer.apply_chat_template([sample[0]], tokenize=False) for sample in samples]
-            prompt_lengths = [len(self.tokenizer(prompt_text, padding=False, return_tensors="pt")["input_ids"][0]) for prompt_text in prompt_texts]
-            prompt_lengths = torch.tensor(prompt_lengths, device=self.device)
-            prompt_end_indices = prompt_lengths - 1
-
-            last_hidden_states = outputs.last_hidden_state
-            prompt_hidden_states = last_hidden_states[torch.arange(len(samples)), prompt_end_indices, :]
-
-            return rewards, prompt_hidden_states
-
-        return rewards
+            for i in range(len(input_texts)):
+                prompt_token = self.tokenizer(
+                    prompt_texts[i],
+                    max_length=self.max_length,
+                    padding=False,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                chosen_token = self.tokenizer(
+                    input_texts[i],
+                    max_length=self.max_length,
+                    padding=False,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                chosen_response_len = chosen_token["attention_mask"].sum() - prompt_token["attention_mask"].sum()
+                chosen_response_len_list.append(chosen_response_len)
+        chosen_response_len = torch.tensor(chosen_response_len_list).view(-1, 1).to(self.device)
+        if return_prompt:
+            chosen_last_hidden_states = outputs["last_hidden_state"]
+            prompt_end_index = chosen_last_hidden_states.size(1) - chosen_response_len - 1
+            prompt_end_index_expanded = prompt_end_index.unsqueeze(-1).expand(-1, -1, chosen_last_hidden_states.size(-1))
+            prompt_hidden_state = torch.gather(chosen_last_hidden_states, dim=1, index=prompt_end_index_expanded).squeeze(1)
+            return rewards, prompt_hidden_state
+        else:
+            return rewards
 
 
         prompt_text = "Describe the importance of reading books in today's digital age."
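The updated hunk derives the prompt/response boundary from token counts instead of re-tokenizing the prompt alone: for each sample it tokenizes the prompt and the full prompt-plus-response text, takes the difference of their attention-mask sums as the response length, and then reads off the hidden state of the last prompt token by counting back from the end of the sequence with torch.gather. The snippet below is a minimal, self-contained sketch of that indexing step on dummy tensors (batch size, sequence length, and response lengths are invented for illustration); it is not part of the repository's code.

```python
import torch

# Dummy stand-ins for the model output and the per-sample response lengths.
batch_size, seq_len, hidden_size = 2, 6, 4
last_hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# Number of response tokens per sample, shaped (batch, 1) as in the diff.
chosen_response_len = torch.tensor([[2], [3]])

# Index of the final prompt token, counted back from the sequence end.
prompt_end_index = last_hidden_states.size(1) - chosen_response_len - 1   # (batch, 1)

# Expand to (batch, 1, hidden) so gather picks one time step per sample.
prompt_end_index_expanded = prompt_end_index.unsqueeze(-1).expand(
    -1, -1, last_hidden_states.size(-1)
)
prompt_hidden_state = torch.gather(
    last_hidden_states, dim=1, index=prompt_end_index_expanded
).squeeze(1)                                                               # (batch, hidden)

assert prompt_hidden_state.shape == (batch_size, hidden_size)
```

Because the index is measured from the right end of last_hidden_state, this indexing appears to assume the response ends at the final position of the sequence (left padding or no padding); with right padding the computed position would land past the last prompt token.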