# huginn-0125 / raven_modeling_minimal.py
"""Minimal modeling.py file for HF compatibility and funny zero-shot experiments. Use only for inference."""
import torch
import math
from torch import Tensor
from dataclasses import dataclass
from typing import Optional, Union
from .raven_config_minimal import RavenConfig
from transformers.cache_utils import Cache, DynamicCache
###################### Huggingface Glue code I ##################################################################
from transformers import PreTrainedModel
from transformers.utils import ModelOutput
class RavenPreTrainedModel(PreTrainedModel):
config_class = RavenConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["SandwichBlock"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = False
_supports_static_cache = False
def _init_weights(self, module):
print("Random Initialization not implemented.")
@dataclass
class CausalLMOutputRecurrentLatents(ModelOutput):
loss: Optional[torch.Tensor] = None
log_ppl: Optional[torch.Tensor] = None
logits: Optional[torch.Tensor] = None
past_key_values: Optional[Cache] = None
latent_states: Optional[torch.Tensor] = None
hidden_states: Optional[torch.Tensor] = None
attention_maps: Optional[tuple[torch.Tensor, ...]] = None
stats: Optional[dict] = None
###################### Minimal implementation from here ############################################################
class RMSNorm(torch.nn.Module):
"""Saner dtype handling and slightly better for fusion"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
with torch.autocast(enabled=False, device_type=x.device.type):
return self._norm(x.float()).type_as(x) * self.weight
def reset_parameters(self) -> None:
torch.nn.init.ones_(self.weight)
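# A hedged sanity check (not part of the model): the module above normalizes in float32, casts back
# to the input dtype, and then scales by the learned weight, so low-precision inputs are handled stably.
def _rmsnorm_sketch() -> None:
    norm = RMSNorm(dim=4)
    x = torch.randn(2, 4, dtype=torch.bfloat16)
    ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + norm.eps)).type_as(x) * norm.weight
    assert torch.allclose(norm(x), ref)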
class HuginnDynamicCache(DynamicCache):
def __init__(self) -> None:
super().__init__()
self._seen_tokens = 0
self.key_cache: dict[int, dict[int, torch.Tensor]] = {}
self.value_cache: dict[int, dict[int, torch.Tensor]] = {}
        # Cache structure: cache[index_of_layer_or_recurrent_step][index_in_sequence]
        # The cache is held uncoalesced because certain recurrent steps may be missing for some sequence ids
        # when using per-token adaptive compute. In those cases, the "lookup_strategy" determines how to proceed
        # (see the sketch after this class). Also, it is critical that the head indices do not overlap with the
        # recurrent iteration indices.
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
step_idx: int,
lookup_strategy: str = "latest",
) -> tuple[torch.Tensor, torch.Tensor]:
# Init
if step_idx not in self.key_cache:
self.key_cache[step_idx] = {}
self.value_cache[step_idx] = {}
        # Update the number of seen tokens; we assume that step_idx=0 (the first prelude layer) is always hit
if step_idx == 0:
self._seen_tokens += key_states.shape[-2]
# Add entries to cache
for idx, entry in enumerate(key_states.unbind(dim=-2)):
assert self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx]
self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
for idx, entry in enumerate(value_states.unbind(dim=-2)):
self.value_cache[step_idx][self._seen_tokens - value_states.shape[-2] + idx] = entry
# Materialize past state based on lookup strategy:
if len(self.key_cache[step_idx]) == self._seen_tokens:
# All entries are present, materialize cache as normal
return (
torch.stack(list(self.key_cache[step_idx].values()), dim=-2),
torch.stack(list(self.value_cache[step_idx].values()), dim=-2),
)
        else:  # some entries were not previously computed
if lookup_strategy == "latest":
latest_keys = []
latest_values = []
for token_pos in range(self._seen_tokens):
# Find the latest step that has this token position
max_step = max((s for s in range(step_idx + 1) if token_pos in self.key_cache[s]), default=None)
if max_step is None:
raise ValueError(f"No cache entry found for token position {token_pos}")
latest_keys.append(self.key_cache[max_step][token_pos])
latest_values.append(self.value_cache[max_step][token_pos])
return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
elif lookup_strategy == "skip":
existing_keys = []
existing_values = []
for token_pos in range(self._seen_tokens):
if token_pos in self.key_cache[step_idx]:
existing_keys.append(self.key_cache[step_idx][token_pos])
existing_values.append(self.value_cache[step_idx][token_pos])
return torch.stack(existing_keys, dim=-2), torch.stack(existing_values, dim=-2)
elif lookup_strategy == "randomized": # sanity check
rand_keys = []
rand_values = []
for token_pos in range(self._seen_tokens):
# Find steps that have this token position
steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
rand_step = steps[torch.randint(len(steps), (1,))]
rand_keys.append(self.key_cache[rand_step][token_pos])
rand_values.append(self.value_cache[rand_step][token_pos])
return torch.stack(rand_keys, dim=-2), torch.stack(rand_values, dim=-2)
else:
raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
def reset(self) -> None:
"""Reset the cache state."""
self._seen_tokens = 0
self.key_cache.clear()
self.value_cache.clear()
def get_seq_length(self, step_idx: int = 0) -> int:
return self._seen_tokens
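# A hedged sketch (not part of the model): a toy illustration of the per-step cache layout and the
# default "latest" lookup strategy. Shapes and step indices below are made up for the example.
def _cache_lookup_sketch() -> None:
    cache = HuginnDynamicCache()
    k = torch.randn(1, 2, 3, 4)  # (batch, kv heads, sequence, head dim)
    v = torch.randn(1, 2, 3, 4)
    cache.update(k, v, step_idx=0)  # step 0 (first prelude layer) sees all 3 positions
    # A later step only caches the trailing 2 positions (e.g. under per-token adaptive compute);
    # position 0 is missing for step 1, so "latest" falls back to the entry cached at step 0.
    k_full, v_full = cache.update(k[..., 1:, :], v[..., 1:, :], step_idx=1)
    assert k_full.shape[-2] == cache.get_seq_length() == 3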
class CausalSelfAttention(torch.nn.Module):
def __init__(self, config: RavenConfig) -> None:
super().__init__()
self.config = config
self.n_head = config.num_attention_heads
self.n_kv_heads = config.num_key_value_heads
self.head_dim = config.n_embd // self.n_head
shape = (self.n_head + 2 * self.n_kv_heads) * self.head_dim
self.chunks = [config.n_embd, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim]
self.Wqkv = torch.nn.Linear(config.n_embd, shape, bias=False)
if config.qk_bias:
self.qk_bias = torch.nn.Parameter(torch.zeros(2, 1, self.n_head, self.head_dim))
self.proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=False)
def forward(
self,
x: Tensor,
freqs_cis: Tensor,
step_idx: int,
mask: Optional[Tensor] = None,
past_key_values: Optional[Cache] = None,
) -> Tensor:
B, S, E = x.shape # batch size, sequence length, embedding dimensionality (n_embd)
q, k, v = self.Wqkv(x).split(self.chunks, dim=2)
q = q.view(B, S, self.n_head, self.head_dim)
k = k.view(B, S, self.n_kv_heads, self.head_dim)
v = v.view(B, S, self.n_kv_heads, self.head_dim)
        # optional learned query/key biases
if self.config.qk_bias:
q_bias, k_bias = self.qk_bias.split(1, dim=0)
q, k = (q + q_bias).to(q.dtype), (k + k_bias).to(q.dtype)
# apply rotary
q, k = apply_rotary_emb_complex_like(q, k, freqs_cis=freqs_cis)
q = q.transpose(1, 2) # (B, nh, S, hs)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if past_key_values is not None:
k, v = past_key_values.update(k, v, step_idx)
y = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=q.shape[2] > 1
)
y = y.transpose(1, 2).reshape(B, S, E).contiguous() # reshape is a view if possible (it mostly is)
return self.proj(y)
class GatedMLP(torch.nn.Module):
def __init__(self, config: RavenConfig, in_features: int = 0) -> None:
super().__init__()
in_features = config.n_embd if in_features == 0 else in_features
self.fc = torch.nn.Linear(in_features, config.intermediate_size * 2, bias=False)
self.proj = torch.nn.Linear(config.intermediate_size, config.n_embd, bias=False)
self.nonlin = torch.nn.SiLU()
def forward(self, x: Tensor) -> Tensor:
# modified to single FC layer to improve parallelism
x_fc_1, x_fc_2 = self.fc(x).chunk(2, dim=-1)
x = self.nonlin(x_fc_1) * x_fc_2
return self.proj(x)
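# A hedged sketch (not part of the model): the fused single projection above is equivalent to the usual
# SwiGLU formulation with separate gate/up matrices. The SimpleNamespace stand-in for RavenConfig and
# the toy sizes are assumptions for illustration only.
def _gated_mlp_equivalence_sketch() -> None:
    from types import SimpleNamespace

    cfg = SimpleNamespace(n_embd=8, intermediate_size=16)  # toy stand-in, not a real RavenConfig
    mlp = GatedMLP(cfg)  # type: ignore[arg-type]
    x = torch.randn(2, 3, cfg.n_embd)
    w_gate, w_up = mlp.fc.weight.chunk(2, dim=0)  # undo the fused projection
    with torch.no_grad():
        ref = (torch.nn.functional.silu(x @ w_gate.T) * (x @ w_up.T)) @ mlp.proj.weight.T
        assert torch.allclose(mlp(x), ref, atol=1e-5)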
class SandwichBlock(torch.nn.Module):
expanded = False
def __init__(self, config: RavenConfig, layer_id: int) -> None:
super().__init__()
self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps)
self.attn = CausalSelfAttention(config)
self.norm_2 = RMSNorm(config.n_embd, eps=config.norm_eps)
self.mlp = GatedMLP(config)
self.norm_3 = RMSNorm(config.n_embd, eps=config.norm_eps)
self.norm_4 = RMSNorm(config.n_embd, eps=config.norm_eps)
self.layer_id = layer_id
def forward(
self,
x: Tensor,
freqs_cis: Tensor,
step_idx: int,
mask: Optional[Tensor] = None,
past_key_values: Optional[Cache] = None,
) -> Tensor:
x = self.norm_2(self.attn(self.norm_1(x), freqs_cis, step_idx, mask, past_key_values) + x)
x = self.norm_4(self.mlp(self.norm_3(x)) + x)
return x
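# A hedged note (not part of the model): each block applies a "sandwich" of norms, i.e. RMSNorm both
# before and after the attention and MLP sublayers (norm_1/norm_2 around attention, norm_3/norm_4
# around the MLP), with the post-norms applied after the residual addition.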
class RavenForCausalLM(RavenPreTrainedModel):
def __init__(
self,
config: RavenConfig,
) -> None:
super().__init__(config)
self.config = config
# Transformer layers
prelude = torch.nn.ModuleList(SandwichBlock(config, layer_id=i) for i in range(config.n_layers_in_prelude))
adapter = torch.nn.Linear(config.n_embd * 2, config.n_embd, bias=config.bias)
core_block = torch.nn.ModuleList(
SandwichBlock(config, layer_id=i + config.n_layers_in_prelude)
for i in range(config.n_layers_in_recurrent_block)
)
o = config.n_layers_in_prelude + config.n_layers_in_recurrent_block * config.mean_recurrence
coda = torch.nn.ModuleList(SandwichBlock(config, layer_id=i + o) for i in range(config.n_layers_in_coda))
self.transformer = torch.nn.ModuleDict(
dict(
wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd),
prelude=prelude,
adapter=adapter,
core_block=core_block,
coda=coda,
ln_f=RMSNorm(config.n_embd, eps=config.norm_eps), # used twice :>
)
)
self.emb_scale = config.init_values["embed_scale"]
# Head
self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
if self.config.tie_embeddings:
self.lm_head.weight = self.transformer.wte.weight
# rope
self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
def _precompute_freqs_cis(self):
# can actually be a buffer now, and remains in fp32! (at least in the settings I tested)
freqs_cis = precompute_freqs_cis(
self.config.n_embd // self.config.num_attention_heads, self.config.block_size, self.config.rope_base, 1
)
return freqs_cis
def forward(
self,
input_ids: torch.Tensor,
input_embeds: Optional[torch.Tensor] = None,
input_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
num_steps_pair: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
output_details: dict = {
"return_logits": True,
"return_latents": True,
"return_attention": False,
"return_head": False,
"return_stats": True,
},
use_cache: bool = False,
cache_position: Optional[torch.Tensor] = None,
) -> dict[str, Optional[torch.Tensor]]:
        if position_ids is None and cache_position is None:
            freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
        elif position_ids is not None:
            # index_select needs a 1D index; positions are assumed to be shared across the batch
            freqs_cis = self.freqs_cis.index_select(1, position_ids[0] if position_ids.ndim == 2 else position_ids)
        elif cache_position is not None:  # support the HF generate() format, where cache_position may cover several tokens
            freqs_cis = self.freqs_cis.index_select(1, cache_position)
if input_embeds is None:
input_embeds = self.transformer.wte(input_ids)
if self.emb_scale != 1:
input_embeds = input_embeds * self.emb_scale # type: ignore
if use_cache and past_key_values is None:
past_key_values = HuginnDynamicCache()
# Non-recurrent prelude
for block_idx, block in enumerate(self.transformer.prelude):
input_embeds = block(input_embeds, freqs_cis, block_idx, attention_mask, past_key_values)
# Main recurrence
x, num_steps_no_grad, num_steps_with_grad, xk = self.iterate_forward(
input_embeds, # type: ignore
input_states,
freqs_cis,
block_idx,
attention_mask,
past_key_values,
num_steps_pair,
)
latent_states = x.clone().detach()
# Coda layers
for block_idx, block in enumerate(self.transformer.coda, start=1):
x = block(x, freqs_cis, -block_idx, attention_mask, past_key_values)
x = self.transformer.ln_f(x)
        # Prediction head; labels are used as-is, so they are expected to be proper (already shifted) labels, not a copy of input_ids
if labels is not None:
logits = self.lm_head(x).float()
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1))
log_ppl = loss.clone().detach()
else:
logits = self.lm_head(x).float()
loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)
return CausalLMOutputRecurrentLatents(
loss=loss,
log_ppl=log_ppl,
logits=logits if output_details["return_logits"] else None,
past_key_values=past_key_values,
hidden_states=x if output_details["return_head"] else None,
latent_states=latent_states if output_details["return_latents"] else None,
attention_maps=ValueError() if output_details["return_attention"] else None, # type: ignore
stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad)
if output_details["return_stats"]
else None,
)
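    # A hedged usage sketch (comments only; it needs loaded weights to actually run): the recurrence
    # depth can be pinned at inference time by passing `num_steps_pair`, e.g. 32 no-grad steps and no
    # grad steps:
    #
    #     out = model(input_ids, num_steps_pair=torch.tensor([32, 0]))
    #     out.logits, out.latent_states, out.stats["entropy"]
    #
    # With `num_steps_pair=None`, `randomized_iteration_sampler` is used instead, which at eval time
    # returns (mean_recurrence, 0).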
@torch._dynamo.disable(recursive=False) # type: ignore
def iterate_forward(
self,
input_embeds,
input_states,
freqs_cis,
block_idx,
mask,
past_key_values: Optional[Cache] = None,
num_steps_pair: Optional[torch.Tensor] = None,
):
x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
if num_steps_pair is None:
num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() # type: ignore
elif len(num_steps_pair) > 1:
num_steps_no_grad, num_steps_with_grad = num_steps_pair
else:
num_steps_no_grad, num_steps_with_grad = num_steps_pair, torch.tensor(0)
with torch.no_grad():
            # Ultra annoying in DDP due to
            # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
            # For now, run with find_unused_params=True, even though the graph structure is (technically) clear
            # and all parameters are always used.
for step in range(num_steps_no_grad):
xk = x
x, block_idx = self.core_block_forward(xk, input_embeds, freqs_cis, mask, past_key_values, block_idx)
for step in range(num_steps_with_grad):
xk = x
x, block_idx = self.core_block_forward(xk, input_embeds, freqs_cis, mask, past_key_values, block_idx)
return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach()
def core_block_forward(
self, x, input_embeds, freqs_cis, mask, past_key_values, block_idx: Union[torch.Tensor, int]
):
x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
for idx, block in enumerate(self.transformer.core_block, start=1):
x = block(x, freqs_cis, block_idx + idx, mask, past_key_values)
return x, block_idx + idx
@torch._dynamo.disable(recursive=False) # type: ignore
def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Outputs are long tensors so that they can be passed through compiled functions"""
t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
s = self.config.mean_backprop_depth
if self.training:
sigma = 0.5
mu = math.log(t + s) - (sigma**2 / 2)
rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma)
p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1
n = torch.clamp(p - s, min=0)
k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
else:
n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0)
return n.to(dtype=torch.long), k.to(dtype=torch.long)
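    # A hedged note on the sampler above (illustrative reasoning only, not used by the code):
    # `rate` is log-normal with E[rate] = exp(mu + sigma^2 / 2) = t + s = mean_recurrence (assuming
    # mean_recurrence >= mean_backprop_depth), and p ~ Poisson(rate) + 1. The split n = max(p - s, 0),
    # k = min(s, p) always satisfies n + k = p, so the expected total depth is about mean_recurrence + 1,
    # while at most `mean_backprop_depth` steps carry gradients (truncated backprop through the recurrence).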
def initialize_state(self, input_embeds):
x = torch.randn_like(input_embeds)
std = self.config.init_values["std"]
torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
if self.emb_scale != 1:
x = x * self.emb_scale
return x
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[Cache] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
):
        model_inputs = {}
        model_inputs["cache_position"] = cache_position
        total_input_length = input_ids.shape[1]  # full length, including tokens already in the cache
        if past_key_values is not None:
            model_inputs["past_key_values"] = past_key_values
            input_ids = input_ids[:, cache_position]  # type: ignore
        model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
        current_input_length = model_inputs["input_ids"].shape[1]
        position_ids = torch.arange(total_input_length)[None, :]
        model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
            memory_format=torch.contiguous_format
        )  # position_ids is a critical argument for the model to correctly apply rope!
# forward all other entries
for key, value in kwargs.items():
if key not in model_inputs:
model_inputs[key] = value
return model_inputs
def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
probs = torch.softmax(logits.float(), dim=-1)
prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1)
residual_diff = (x - latent_states).norm(dim=-1)
rel_residual = residual_diff / latent_states.norm(dim=-1)
stats = {
"entropy": prob_entropy,
"residual_diff": residual_diff,
"rel_residual": rel_residual,
"num_steps_no_grad": num_steps_no_grad,
"num_steps_with_grad": num_steps_with_grad,
}
return stats
#################################### Utils #######################################################################
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1):
with torch.autocast("cuda", enabled=False):
inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(end, dtype=torch.float32, device=inv_freqs.device) / condense_ratio
freqs = torch.outer(t, inv_freqs).float()
return torch.stack([torch.cos(freqs)[None, :, None, :], torch.sin(freqs)[None, :, None, :]], dim=4)
# equivalent to
# freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
# cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
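# A hedged sanity check (not part of the model) of the torch.polar equivalence noted above, using
# assumed toy sizes.
def _freqs_cis_equivalence_sketch(dim: int = 8, end: int = 16) -> None:
    cache = precompute_freqs_cis(dim, end)
    inv_freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    freqs = torch.outer(torch.arange(end, dtype=torch.float32), inv_freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    ref = torch.stack([freqs_cis.real[None, :, None, :], freqs_cis.imag[None, :, None, :]], dim=4)
    assert torch.allclose(cache, ref)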
def apply_rotary_emb_complex_like(q: Tensor, k: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
with torch.autocast("cuda", enabled=False):
qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() # cast to float32 for smooth skin
rotated_qk_r2 = torch.stack(
[
qk_r2[..., 0] * freqs_cis[..., 0] - qk_r2[..., 1] * freqs_cis[..., 1],
qk_r2[..., 1] * freqs_cis[..., 0] + qk_r2[..., 0] * freqs_cis[..., 1],
],
-1,
).flatten(3)
rotated_qk = rotated_qk_r2
return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) # type: ignore
#################################### HF registration ############################################################
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
# New-style registration: writes the auto_map entries into the saved config for trust_remote_code loading
RavenConfig.register_for_auto_class()
RavenForCausalLM.register_for_auto_class("AutoModel")
RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM")
# Legacy-style local registration with the Auto* mappings for in-process use
AutoConfig.register("huginn_raven", RavenConfig)
AutoModel.register(RavenConfig, RavenForCausalLM)
AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)
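# A hedged usage sketch (comments only): loading the checkpoint through the auto classes registered
# above. The repo id below is an assumption taken from the file header, not something this file defines.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     model = AutoModelForCausalLM.from_pretrained("huginn-0125", trust_remote_code=True)
#     tokenizer = AutoTokenizer.from_pretrained("huginn-0125", trust_remote_code=True)
#     input_ids = tokenizer("To be, or not to be,", return_tensors="pt").input_ids
#     out = model(input_ids, num_steps_pair=torch.tensor([16, 0]))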