# Source: QRM-Gemma-2-27B / modeling_custom.py
# Author: nicolinho
# Commit: "Update modeling_custom.py" (83d5ac5, verified)
from dataclasses import dataclass
from typing import Optional, List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import Gemma2Model, Gemma2PreTrainedModel, Gemma2ForSequenceClassification, Gemma2Config
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
from transformers.utils import ModelOutput
from transformers.utils import add_start_docstrings_to_model_forward
import numpy as np
from os.path import join as pjoin
class GatingNetwork(nn.Module):
    """
    MLP gate that turns a context embedding into per-objective mixing weights.

    The trunk is ``n_hidden`` repetitions of ``Linear -> ReLU -> BatchNorm1d``
    (with ``Dropout`` after every repetition except the last), followed by a
    final ``Linear`` projection to ``out_features``. The projection is divided
    by ``temperature``, passed through a softmax and multiplied by the
    learnable scalar ``logit_scale``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        temperature: float = 10,
        logit_scale: float = 1.0,
        hidden_dim: int = 1024,
        n_hidden: int = 3,
        dropout: float = 0.2,
    ):
        """
        Args:
            in_features: Size of the last dimension of the input.
            out_features: Number of gate outputs.
            bias: Whether the final projection has a bias term.
            temperature: Divisor applied to the projection before the softmax.
            logit_scale: Initial value of the learnable output multiplier.
            hidden_dim: Width of every hidden layer.
            n_hidden: Number of hidden ``Linear/ReLU/BatchNorm1d`` groups.
            dropout: Dropout probability; ``0`` disables dropout entirely.
        """
        super().__init__()
        self.temperature = temperature
        self.logit_scale = nn.Parameter(torch.ones(1) * logit_scale)

        stack = []
        width = in_features
        final_hidden = n_hidden - 1
        for depth in range(n_hidden):
            # The hidden Linear carries no bias because BatchNorm1d follows it.
            stack += [
                nn.Linear(width, hidden_dim, bias=False),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),
            ]
            # No dropout right before the output projection, for more
            # stability and precision.
            if depth != final_hidden and dropout > 0:
                stack.append(nn.Dropout(dropout))
            width = hidden_dim
        stack.append(nn.Linear(width, out_features, bias=bias))
        self.layers = nn.ModuleList(stack)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """
        Compute the gate weights for ``x``.

        All leading dimensions of ``x`` are flattened into one batch axis for
        the MLP and restored afterwards; only the last dimension changes
        (from ``in_features`` to ``out_features``).
        """
        leading_dims = list(x.shape[:-1])
        hidden = x.reshape((-1, x.shape[-1]))
        for module in self.layers:
            hidden = module(hidden)
        weights = F.softmax(hidden / self.temperature, dim=1)
        weights = weights.reshape(leading_dims + [weights.shape[-1]])
        return weights * self.logit_scale
# Gemma2 token IDs of "<end_of_turn>\n<start_of_turn>model\n"
token_pattern = [107, 108, 106, 2516, 108]


def find_token_for_gating(lst, pattern=None):
    """Return the start index of the last occurrence of a token pattern.

    Args:
        lst: Sequence of token ids to search (list, tuple, ...).
        pattern: Sequence of token ids to look for. When ``None`` (the
            default) the module-level ``token_pattern`` is used, looked up at
            call time.

    Returns:
        Index in ``lst`` at which the last occurrence of the pattern begins.

    Raises:
        ValueError: If the pattern does not occur in ``lst``.
    """
    # Normalise both sides to lists once, so slices compare element-wise
    # whatever sequence type the caller passed (a tuple slice never equals a
    # list, which made the previous version miss every match on tuple input).
    needle = list(token_pattern if pattern is None else pattern)
    haystack = list(lst)
    width = len(needle)
    # Scan right-to-left so the first hit is the last occurrence.
    for start in range(len(haystack) - width, -1, -1):
        if haystack[start:start + width] == needle:
            return start
    raise ValueError("Token pattern not found in the list.")
@dataclass
class CustomOutput(ModelOutput):
    """
    Output container of the quantile reward model.

    Args:
        reward_quantiles (`torch.FloatTensor` of shape `(batch_size, num_quantiles)`):
            Gating-weighted per-objective quantile estimates, averaged over the objective dimension.
        rewards (`torch.FloatTensor` of shape `(batch_size, num_objectives)`):
            Per-objective reward expectation (mean over that objective's quantile estimates).
        gating_output (`torch.FloatTensor` of shape `(batch_size, num_objectives)`):
            Output of the gating network (softmax weights multiplied by its learnable scale), one weight per
            objective.
        score (`torch.FloatTensor` of shape `(batch_size, 1)`):
            The final reward score: sum over objectives of `rewards * gating_output`.
        logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Same as score
    """

    reward_quantiles: Optional[torch.FloatTensor] = None
    rewards: Optional[torch.FloatTensor] = None
    gating_output: Optional[torch.FloatTensor] = None
    score: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
class Gemma2ForQuantileSequenceClassification(Gemma2PreTrainedModel):
    """
    Gemma2 backbone with a quantile-regression reward head and a gating network.

    For every sequence the head predicts ``num_quantiles`` values for each of
    ``num_objectives`` objectives from the hidden state of the last non-padding
    token. A `GatingNetwork`, fed with the hidden state at the start of the
    last ``token_pattern`` occurrence, produces one weight per objective; the
    weighted combination of the per-objective expectations is the final score.
    """

    def __init__(self, config):
        """
        Build the backbone, the regression head and the gating network.

        Optional config entries (read via ``config.to_dict()``) and their
        fallbacks: ``num_objectives`` (5), ``num_quantiles`` (19),
        ``gating_temperature`` (1), ``gating_hidden_dim`` (1024),
        ``gating_n_hidden`` (3).
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma2Model(config)

        config_dict = config.to_dict()
        self.num_objectives = config_dict.get("num_objectives", 5)
        self.num_quantiles = config_dict.get("num_quantiles", 19)
        # Interior points of an evenly spaced grid on [0, 1]; with the default
        # of 19 quantiles this is 0.05, 0.10, ..., 0.95.
        self.quantiles = torch.linspace(0., 1., self.num_quantiles + 2)[1:-1]
        self.regression_layer = nn.Linear(
            config.hidden_size, self.num_quantiles * self.num_objectives, bias=False
        )

        # Initialize weights and apply final processing
        self.post_init()

        # NOTE(review): the gating network is deliberately still created
        # *after* post_init(), as before; moving it above would expose it to
        # whatever post_init() does to registered submodules.
        self.gating = GatingNetwork(
            config.hidden_size,
            self.num_objectives,
            temperature=config_dict.get("gating_temperature", 1),
            hidden_dim=config_dict.get("gating_hidden_dim", 1024),
            n_hidden=config_dict.get("gating_n_hidden", 3),
        )

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with ``value``."""
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CustomOutput]:
        r"""
        Score a batch of tokenized conversations.

        Args:
            input_ids (`torch.LongTensor`):
                Token ids. Required in practice: the gating position is found by searching these ids for
                `token_pattern`.
            attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions,
            output_hidden_states:
                Forwarded unchanged to the Gemma2 backbone.
            labels (`torch.LongTensor`, *optional*):
                Accepted for interface compatibility only; it is ignored and no loss is computed.
            return_dict (`bool`, *optional*):
                Forwarded to the backbone (defaults to `config.use_return_dict`). This method itself always
                returns a `CustomOutput`.

        Returns:
            `CustomOutput` with `reward_quantiles`, `rewards`, `gating_output`, `score` and `logits`
            (`logits` is the same tensor as `score`).

        Raises:
            ValueError: If `input_ids` is `None`, if `config.pad_token_id` is `None` and the batch size is not 1,
                or (from `find_token_for_gating`) if a sequence does not contain `token_pattern`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            # The gating position can only be located from token ids, so fail
            # early with a clear message instead of an AttributeError on None.
            raise ValueError(
                "input_ids is required: the gating position is located by searching the token ids."
            )

        # A single sequence that starts with two tokens of id 2 (a doubled
        # BOS, presumably added by both chat template and tokenizer) has the
        # first one dropped. The length check avoids indexing past a
        # one-token input.
        # NOTE(review): position_ids, if supplied, is not trimmed here.
        if (
            input_ids.dim() == 2
            and input_ids.shape[0] == 1
            and input_ids.shape[1] >= 2
            and input_ids[0, 0] == input_ids[0, 1] == 2
        ):
            input_ids = input_ids[:, 1:]
            if attention_mask is not None:
                attention_mask = attention_mask[:, 1:]

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        batch_size = input_ids.shape[0]

        if self.config.pad_token_id is None:
            if batch_size != 1:
                raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
            sequence_lengths = -1
        else:
            # Index of the token just before the first pad token. If no pad
            # token is found argmax yields 0, so "- 1" followed by the modulo
            # wraps to the last position (modulo instead of reverse indexing
            # for ONNX compatibility).
            sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(hidden_states.device)

        dummy_iterator = torch.arange(batch_size, device=hidden_states.device)
        last_hidden_states = hidden_states[dummy_iterator, sequence_lengths]
        assert last_hidden_states.shape == (batch_size, self.config.hidden_size)

        # [B, num_objectives, num_quantiles]
        rewards = self.regression_layer(last_hidden_states)
        rewards = rewards.reshape(-1, self.num_objectives, self.num_quantiles)

        # Gate on the hidden state at the start of the last token_pattern
        # occurrence in each sequence.
        gating_token_positions = [find_token_for_gating(ids.tolist()) for ids in input_ids]
        prompt_embedding = hidden_states[dummy_iterator, gating_token_positions, :]
        gating_output = self.gating(prompt_embedding)  # [B, num_objectives]

        # [B, num_quantiles]; broadcasting the gate over the quantile axis
        # gives the same values as the explicit repeat it replaces.
        # NOTE(review): this averages over objectives while `score` below
        # sums over them, so the two differ by a factor of num_objectives;
        # kept as-is to preserve existing outputs.
        reward_quantiles = torch.mean(rewards * gating_output.unsqueeze(-1), dim=1)

        rewards_expectation = rewards.mean(dim=2)  # [B, num_objectives]
        score = torch.sum(rewards_expectation.float() * gating_output.float(), dim=-1, keepdim=True)

        return CustomOutput(
            reward_quantiles=reward_quantiles,
            rewards=rewards_expectation,
            gating_output=gating_output,
            score=score,
            logits=score,
        )