# ORPO Authors: Jiwoo Hong, Noah Lee, and James Thorne
# Official code: https://github.com/xfactlab/orpo
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmengine import MessageHub
from torch import nn

from xtuner.parallel.sequence import (gather_forward_split_backward,
                                      get_sequence_parallel_group,
                                      get_sequence_parallel_world_size,
                                      split_for_sequence_parallel)

from .sft import SupervisedFinetune
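
# --- Illustrative sketch (an editorial addition, not part of the upstream
# xtuner module; the helper name is hypothetical and safe to delete). ---
# `ORPO.odds_ratio_loss` below computes log(odds(chosen) / odds(rejected)),
# where odds(y) = P(y|x) / (1 - P(y|x)), directly from sequence log-probs
# via log1p(-exp(logp)) for numerical stability. This helper sanity-checks
# that rewriting against the naive formula.
def _check_log_odds_identity(chosen_logps, rejected_logps, atol=1e-6):
    """Check the stable log-odds form used in ``ORPO.odds_ratio_loss``."""
    p_c, p_r = chosen_logps.exp(), rejected_logps.exp()
    # Naive form: log of the ratio of odds.
    direct = torch.log((p_c / (1 - p_c)) / (p_r / (1 - p_r)))
    # Stable form used by the loss below.
    stable = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) -
        torch.log1p(-torch.exp(rejected_logps)))
    assert torch.allclose(direct, stable, atol=atol)
    return stable
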
class ORPO(SupervisedFinetune):
    """ORPO: Monolithic Preference Optimization without Reference Model
    https://arxiv.org/abs/2403.07691

    Args:
        beta (float): Weight of the odds_ratio_loss. Defaults to 0.1.
    """

    def __init__(self, *args, beta=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.beta = beta

    def _gather_masked_logits(self, logits, labels, mask):
        # Gather the per-token log-probs at the label positions and zero out
        # the masked (non-target) positions.
        logits = torch.gather(
            logits.log_softmax(-1), dim=2,
            index=labels.unsqueeze(2)).squeeze(2)
        return logits * mask

    def get_logps(
            self,
            all_logits,  # (bs, seqlen, vocab_size)
            average_log_prob,  # bool; average over tokens instead of summing
            labels,  # (bs, seqlen)
    ):
        # Shift so that tokens < n predict token n.
        labels = labels[:, 1:].clone()
        all_logits = all_logits[:, :-1, :]

        # Build the mask before overwriting the ignore index, so a genuine
        # label with token id 0 is not masked out.
        loss_mask = labels != -100
        labels[labels == -100] = 0
        all_logps = self._gather_masked_logits(all_logits, labels,
                                               loss_mask).sum(-1)

        if average_log_prob:
            all_logps = all_logps / loss_mask.sum(-1)

        # Preference pairs are interleaved along the batch dim:
        # even rows are chosen, odd rows are rejected.
        chosen_logps = all_logps[::2]
        rejected_logps = all_logps[1::2]
        return chosen_logps, rejected_logps

    def get_var_len_atten_logps(self, all_logits, average_log_prob, labels,
                                cu_seqlens, attention_mask):
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        # unpack the packed sequence
        unpacked_logits = torch.split(all_logits, seqlens, dim=1)
        unpacked_labels = torch.split(labels, seqlens, dim=1)
        if attention_mask is not None:
            # A non-None attention_mask indicates that the original sequence,
            # labels, position_ids and cumulative_len were padded so they can
            # be evenly split for sequence parallelism. The padded segment is
            # the last one and has to be removed.
            assert False in attention_mask
            unpacked_logits = unpacked_logits[:-1]
            unpacked_labels = unpacked_labels[:-1]
        assert len(unpacked_logits) % 2 == 0

        def compute_logps(_logits, _labels):
            _labels = _labels[:, 1:].clone()
            _logits = _logits[:, :-1, :]
            loss_mask = _labels != -100
            _labels[_labels == -100] = 0
            logps = self._gather_masked_logits(_logits, _labels, loss_mask)
            logps = logps.sum(-1)
            if average_log_prob:
                logps /= loss_mask.sum(-1)
            return logps

        chosen_logps, rejected_logps = [], []
        for i in range(len(unpacked_logits) // 2):
            chosen = unpacked_logits[2 * i]
            rejected = unpacked_logits[2 * i + 1]
            chosen_label = unpacked_labels[2 * i]
            rejected_label = unpacked_labels[2 * i + 1]
            chosen_logps.append(compute_logps(chosen, chosen_label))
            rejected_logps.append(compute_logps(rejected, rejected_label))
        return (torch.stack(chosen_logps), torch.stack(rejected_logps))

    def cross_entropy_loss(self, logits, labels):
        logits = logits[..., :-1, :].contiguous()
        labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss()
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        loss = loss_fct(logits, labels)
        return loss

    def odds_ratio_loss(
        self,
        chosen_logps: torch.FloatTensor,
        rejected_logps: torch.FloatTensor,
    ):
        # modified from https://github.com/huggingface/trl/blob/b031adfdb8708f1f295eab6c3f2cb910e8fe0c23/trl/trainer/orpo_trainer.py#L597  # noqa
        # Derived from Eqs. (4) and (7) of https://arxiv.org/abs/2403.07691
        # by using log identities and exp(log(P(y|x))) = P(y|x).
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps)) -
            torch.log1p(-torch.exp(rejected_logps)))
        ratio = F.logsigmoid(log_odds)
        # Drop NaN entries, which can arise when a sequence log-prob rounds
        # to 0 and log1p(-exp(logp)) evaluates to -inf.
        ratio = ratio[~torch.isnan(ratio)]
        losses = self.beta * ratio
        chosen_rewards = self.beta * chosen_logps
        rejected_rewards = self.beta * rejected_logps
        return losses, chosen_rewards, rejected_rewards, torch.mean(
            ratio), torch.mean(log_odds)

    @staticmethod
    def _split_for_sequence_parallel(data):
        # attention mask should not be split
        ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids')
        sp_group = get_sequence_parallel_group()
        for key in ARGS_NEED_TO_SPLIT:
            val = data.get(key, None)
            if val is not None:
                # `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
                data[key] = split_for_sequence_parallel(
                    val, dim=1, sp_group=sp_group)
        return data

    def compute_loss(self, data, data_samples=None):
        labels_ori = data.pop('labels')

        if get_sequence_parallel_world_size() > 1:
            data = self._split_for_sequence_parallel(data)

        all_logits = self.llm(**data).logits
        if get_sequence_parallel_world_size() > 1:
            all_logits = gather_forward_split_backward(
                all_logits,
                dim=1,
                sp_group=get_sequence_parallel_group(),
                grad_scale='up')

        if not self.use_varlen_attn:
            # The NLL term only uses the chosen (even) rows of the batch.
            chosen_nll_loss = self.cross_entropy_loss(all_logits[::2],
                                                      labels_ori.clone()[::2])
            chosen_logps, rejected_logps = self.get_logps(
                all_logits, True, labels_ori)
        else:
            message_hub = MessageHub.get_instance('varlen_attn_args')
            rank = dist.get_rank()
            cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

            attention_mask = data['attention_mask']
            if attention_mask is not None:
                # A non-None attention_mask indicates that the original
                # sequence, labels, position_ids and cumulative_len were
                # padded for sequence parallelism, so the padded (last)
                # segment has to be removed.
                logits = torch.split(all_logits, seqlens, dim=1)[:-1]
                assert len(logits) % 2 == 0
                chosen_logits = logits[::2]
                labels = torch.split(labels_ori.clone(), seqlens, dim=1)[:-1]
                assert len(labels) % 2 == 0
                chosen_labels = labels[::2]
            else:
                chosen_logits = torch.split(all_logits, seqlens, dim=1)[::2]
                chosen_labels = torch.split(
                    labels_ori.clone(), seqlens, dim=1)[::2]

            chosen_logits = torch.cat(chosen_logits, dim=1)
            chosen_labels = torch.cat(chosen_labels, dim=1)
            chosen_nll_loss = self.cross_entropy_loss(chosen_logits,
                                                      chosen_labels)

            chosen_logps, rejected_logps = self.get_var_len_atten_logps(
                all_logits, True, labels_ori, cu_seqlens, attention_mask)
        (losses, chosen_rewards, rejected_rewards, log_odds_ratio,
         log_odds_chosen) = self.odds_ratio_loss(chosen_logps, rejected_logps)
        losses = losses.mean()
        # Replace a NaN loss with a zero that keeps the graph connected so
        # the backward pass still runs.
        if torch.isnan(chosen_nll_loss):
            chosen_nll_loss = all_logits.mean() * 0
        if torch.isnan(losses):
            losses = all_logits.mean() * 0
        # L_ORPO = L_NLL - beta * log_sigmoid(log_odds); `losses` is already
        # beta * log_sigmoid(log_odds), hence the subtraction.
        loss = chosen_nll_loss - losses

        reward_acc = (chosen_rewards > rejected_rewards).float().mean()

        loss_dict = {
            'loss': loss,
            'chosen_rewards': chosen_rewards.mean(),
            'rejected_rewards': rejected_rewards.mean(),
            'reward_acc': reward_acc,
            'reward_margin': (chosen_rewards - rejected_rewards).mean(),
            'log_odds_ratio': log_odds_ratio,
            'log_odds_chosen': log_odds_chosen,
            'nll_loss': chosen_nll_loss.detach().mean()
        }
        return loss_dict
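

# --- Usage sketch (an editorial addition; the dummy shapes are invented
# for illustration and `get_logps` is re-implemented standalone here so no
# configured llm is needed). ---
# `compute_loss` assumes preference pairs are interleaved along the batch
# dim: row 0 = chosen_0, row 1 = rejected_0, row 2 = chosen_1, ... so that
# `[::2]` / `[1::2]` recover the chosen / rejected halves.
if __name__ == '__main__':
    torch.manual_seed(0)
    bs, seqlen, vocab = 4, 8, 16  # two (chosen, rejected) pairs
    logits = torch.randn(bs, seqlen, vocab)
    labels = torch.randint(0, vocab, (bs, seqlen))
    labels[:, :3] = -100  # prompt tokens are masked out

    # Standalone mirror of `ORPO.get_logps` with average_log_prob=True.
    shifted_labels = labels[:, 1:].clone()
    shifted_logits = logits[:, :-1, :]
    mask = shifted_labels != -100
    shifted_labels[~mask] = 0
    per_token = torch.gather(
        shifted_logits.log_softmax(-1), 2,
        shifted_labels.unsqueeze(2)).squeeze(2) * mask
    logps = per_token.sum(-1) / mask.sum(-1)
    chosen_logps, rejected_logps = logps[::2], logps[1::2]
    print(chosen_logps.shape, rejected_logps.shape)  # torch.Size([2]) each
    print(_check_log_odds_identity(chosen_logps, rejected_logps))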