from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import Gemma2Model, Gemma2PreTrainedModel
from transformers.utils import ModelOutput


class GatingNetwork(nn.Module):
    """
    Gating network: a simple MLP with a temperature-scaled softmax output.

    The network learns how to combine the individual reward objectives,
    conditioned on the prompt representation it is given.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        temperature: float = 10,
        logit_scale: float = 1.0,
        hidden_dim: int = 1024,
        n_hidden: int = 3,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.temperature = temperature
        self.logit_scale = nn.Parameter(torch.ones(1) * logit_scale)
        layers = []
        for i in range(n_hidden):
            layers.append(nn.Linear(in_features, hidden_dim, bias=False))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(hidden_dim))
            # No dropout after the last hidden block.
            if dropout > 0 and i < n_hidden - 1:
                layers.append(nn.Dropout(dropout))
            in_features = hidden_dim
        layers.append(nn.Linear(in_features, out_features, bias=bias))
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        # Flatten any leading dimensions so BatchNorm1d sees a 2-D input.
        orig_shape = x.shape
        x = x.reshape((-1, x.shape[-1]))
        for layer in self.layers:
            x = layer(x)
        # Temperature-scaled softmax over objectives, then restore the original shape.
        x = F.softmax(x / self.temperature, dim=1)
        x = x.reshape(list(orig_shape[:-1]) + [x.shape[-1]])
        return x * self.logit_scale
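

# Illustrative shape check for the gating network, kept as a comment so it is
# not executed on import; the feature size of 3584 is an arbitrary example and
# not tied to any particular checkpoint:
#
#     gate = GatingNetwork(in_features=3584, out_features=5, temperature=1)
#     gate.eval()
#     weights = gate(torch.randn(2, 3584))
#     weights.shape        # torch.Size([2, 5])
#     weights.sum(dim=-1)  # approximately 1 per row (softmax scaled by logit_scale=1.0)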


# Token ids of the chat-template sequence "<end_of_turn>\n<start_of_turn>model\n"
# in the Gemma-2 tokenizer; the last occurrence marks the start of the final
# model turn, i.e. the end of the prompt.
token_pattern = [107, 108, 106, 2516, 108]


def find_token_for_gating(lst):
    """Return the start index of the last occurrence of `token_pattern` in a list of token ids."""
    token_pattern_len = len(token_pattern)
    search_end = len(lst)
    for j in range(search_end - token_pattern_len, -1, -1):
        if lst[j:j + token_pattern_len] == token_pattern:
            return j
    raise ValueError("Token pattern not found in the list.")
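

# Illustrative example (the ids surrounding the pattern are made up):
#
#     ids = [2, 106, 1645, 108, 5, 107, 108, 106, 2516, 108, 1234]
#     find_token_for_gating(ids)  # -> 5, the index where `token_pattern` starts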


@dataclass
class CustomOutput(ModelOutput):
    """
    Output type of `Gemma2ForQuantileSequenceClassification`.

    Args:
        reward_quantiles (`torch.FloatTensor` of shape `(batch_size, num_quantiles)`):
            Gating-weighted quantile estimates, averaged over the reward objectives.
        rewards (`torch.FloatTensor` of shape `(batch_size, num_objectives)`):
            Expected reward per objective (the mean over that objective's quantile estimates).
        gating_output (`torch.FloatTensor` of shape `(batch_size, num_objectives)`):
            Output of the gating network, i.e. the weight assigned to each objective.
        score (`torch.FloatTensor` of shape `(batch_size, 1)`):
            The final scalar reward score.
        logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Same as `score`.
    """

    reward_quantiles: torch.FloatTensor = None
    rewards: torch.FloatTensor = None
    gating_output: Optional[torch.FloatTensor] = None
    score: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
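

# How the fields of `CustomOutput` relate to each other (see `forward` below),
# writing q[b, k, i] for the i-th quantile estimate of objective k and
# w[b, k] for the gating weight of objective k:
#
#     rewards[b, k]          = mean_i q[b, k, i]
#     reward_quantiles[b, i] = mean_k w[b, k] * q[b, k, i]
#     score[b, 0]            = sum_k  w[b, k] * rewards[b, k]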


class Gemma2ForQuantileSequenceClassification(Gemma2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma2Model(config)

        config_dict = config.to_dict()
        self.num_objectives = config_dict.get("num_objectives", 5)
        self.num_quantiles = config_dict.get("num_quantiles", 19)
        # Evenly spaced quantile levels in (0, 1), excluding the endpoints.
        self.quantiles = torch.linspace(0.0, 1.0, self.num_quantiles + 2)[1:-1]
        # Predicts `num_quantiles` quantile estimates for each reward objective
        # from the final hidden state of the sequence.
        self.regression_layer = nn.Linear(
            config.hidden_size, self.num_quantiles * self.num_objectives, bias=False
        )
        # Maps the prompt representation to one weight per objective.
        self.gating = GatingNetwork(
            config.hidden_size,
            self.num_objectives,
            temperature=config_dict.get("gating_temperature", 1),
            hidden_dim=config_dict.get("gating_hidden_dim", 1024),
            n_hidden=config_dict.get("gating_n_hidden", 3),
        )

        # Initialize weights and apply final processing.
        self.post_init()
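
    # `num_objectives`, `num_quantiles`, `gating_temperature`, `gating_hidden_dim`
    # and `gating_n_hidden` above are read from extra fields of the model config
    # when present, so they can be overridden in `config.json` without code changes.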

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CustomOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
            (Mean-Square loss); if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # If the sequence starts with two copies of token id 2 (Gemma's <bos>), drop the duplicate.
        if (
            input_ids is not None
            and input_ids.ndim == 2
            and input_ids.shape[0] == 1
            and input_ids[0, 0] == input_ids[0, 1] == 2
        ):
            input_ids = input_ids[:, 1:]
            if attention_mask is not None:
                attention_mask = attention_mask[:, 1:]

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the last non-padding token in each sequence.
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(hidden_states.device)
            else:
                sequence_lengths = -1

        batch_indices = torch.arange(batch_size, device=hidden_states.device)
        last_hidden_states = hidden_states[batch_indices, sequence_lengths]
        assert last_hidden_states.shape == (batch_size, self.config.hidden_size)
        # Quantile estimates for every objective, predicted from the last-token representation.
        rewards = self.regression_layer(last_hidden_states)
        rewards = rewards.reshape(-1, self.num_objectives, self.num_quantiles)

        # The gating network conditions on the hidden state at the end of the prompt
        # (the start of the final model turn).
        gating_token_positions = [find_token_for_gating(ids.tolist()) for ids in input_ids]
        prompt_embedding = hidden_states[batch_indices, gating_token_positions, :]
        gating_output = self.gating(prompt_embedding)

        # Gating-weighted quantile estimates, averaged over objectives.
        reward_quantiles = torch.mean(
            rewards * gating_output.unsqueeze(-1).repeat(1, 1, self.num_quantiles), dim=1
        )

        # Expected reward per objective, and the final gating-weighted scalar score.
        rewards_expectation = rewards.mean(dim=2)
        score = torch.sum(rewards_expectation.float() * gating_output.float(), dim=-1, keepdim=True)

        return CustomOutput(
            reward_quantiles=reward_quantiles,
            rewards=rewards_expectation,
            gating_output=gating_output,
            score=score,
            logits=score,
        )
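

if __name__ == "__main__":
    # Minimal usage sketch, assuming a checkpoint that was saved with this class;
    # the checkpoint path, dtype and the example conversation below are
    # placeholders, not a reference to any specific released model.
    from transformers import AutoTokenizer

    checkpoint = "path/to/quantile-reward-model"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = Gemma2ForQuantileSequenceClassification.from_pretrained(
        checkpoint, torch_dtype=torch.bfloat16
    )
    model.eval()

    messages = [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, return_tensors="pt", return_dict=True
    ).to(model.device)

    with torch.no_grad():
        output = model(**inputs)

    print("score:", output.score.item())
    print("per-objective rewards:", output.rewards.squeeze(0).tolist())
    print("gating weights:", output.gating_output.squeeze(0).tolist())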