kirigayahitsugi
committed on
Commit
•
d6dd232
1
Parent(s):
20cc8f6
Update README.md
Browse files
README.md
CHANGED
@@ -217,19 +217,35 @@ class GPMPipeline:
|
|
217 |
with torch.no_grad():
|
218 |
rewards, outputs = self.model.custom_forward(**inputs, return_output=return_prompt)
|
219 |
|
|
|
220 |
if return_prompt:
|
221 |
-
# Compute prompt hidden states
|
222 |
prompt_texts = [self.tokenizer.apply_chat_template([sample[0]], tokenize=False) for sample in samples]
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
|
235 |
prompt_text = "Describe the importance of reading books in today's digital age."
|
|
|
217 |
with torch.no_grad():
|
218 |
rewards, outputs = self.model.custom_forward(**inputs, return_output=return_prompt)
|
219 |
|
220 |
+
chosen_response_len_list = []
|
221 |
if return_prompt:
|
|
|
222 |
prompt_texts = [self.tokenizer.apply_chat_template([sample[0]], tokenize=False) for sample in samples]
|
223 |
+
for i in range(len(input_texts)):
|
224 |
+
prompt_token = self.tokenizer(
|
225 |
+
prompt_texts[i],
|
226 |
+
max_length=self.max_length,
|
227 |
+
padding=False,
|
228 |
+
truncation=True,
|
229 |
+
return_tensors="pt",
|
230 |
+
)
|
231 |
+
chosen_token = self.tokenizer(
|
232 |
+
input_texts[i],
|
233 |
+
max_length=self.max_length,
|
234 |
+
padding=False,
|
235 |
+
truncation=True,
|
236 |
+
return_tensors="pt",
|
237 |
+
)
|
238 |
+
chosen_response_len = chosen_token["attention_mask"].sum() - prompt_token["attention_mask"].sum()
|
239 |
+
chosen_response_len_list.append(chosen_response_len)
|
240 |
+
chosen_response_len = torch.tensor(chosen_response_len_list).view(-1, 1).to(self.device)
|
241 |
+
if return_prompt:
|
242 |
+
chosen_last_hidden_states = outputs["last_hidden_state"]
|
243 |
+
prompt_end_index = chosen_last_hidden_states.size(1) - chosen_response_len - 1
|
244 |
+
prompt_end_index_expanded = prompt_end_index.unsqueeze(-1).expand(-1, -1, chosen_last_hidden_states.size(-1))
|
245 |
+
prompt_hidden_state = torch.gather(chosen_last_hidden_states, dim=1, index=prompt_end_index_expanded).squeeze(1)
|
246 |
+
return rewards, prompt_hidden_state
|
247 |
+
else:
|
248 |
+
return rewards
|
249 |
|
250 |
|
251 |
prompt_text = "Describe the importance of reading books in today's digital age."
|