"""Minimal modeling.py file for HF compatibility and funny zero-shot experiments. Use only for inference."""

import math
from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import Tensor

from .raven_config_minimal import RavenConfig
from transformers.cache_utils import Cache, DynamicCache

###################### Huggingface Glue code I ##################################################################
from transformers import PreTrainedModel
from transformers.utils import ModelOutput


class RavenPreTrainedModel(PreTrainedModel):
    """Hugging Face base class that binds ``RavenConfig`` to ``PreTrainedModel`` and declares capability flags."""

    config_class = RavenConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SandwichBlock"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = False
    _supports_static_cache = False

    def _init_weights(self, module):
        # This minimal file is inference-only, so random initialization is deliberately not provided.
        print("Random Initialization not implemented.")


@dataclass
class CausalLMOutputRecurrentLatents(ModelOutput):
    """Output container returned by ``RavenForCausalLM.forward``."""

    loss: Optional[torch.Tensor] = None  # cross-entropy vs. labels, or 0.0 when no labels are given
    log_ppl: Optional[torch.Tensor] = None  # detached copy of ``loss``
    logits: Optional[torch.Tensor] = None  # float32 lm_head outputs (if requested)
    past_key_values: Optional[Cache] = None
    latent_states: Optional[torch.Tensor] = None  # detached output of the recurrent core (after ln_f)
    hidden_states: Optional[torch.Tensor] = None  # final normalized state fed to lm_head (if requested)
    attention_maps: Optional[tuple[torch.Tensor, ...]] = None  # not implemented (placeholder)
    stats: Optional[dict] = None  # diagnostics produced by ``RavenForCausalLM.get_stats``


###################### Minimal implementation from here ############################################################


class RMSNorm(torch.nn.Module):
    """Saner dtype handling and slightly better for fusion"""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 with autocast disabled, then cast back to the input dtype before scaling.
        with torch.autocast(enabled=False, device_type=x.device.type):
            return self._norm(x.float()).type_as(x) * self.weight

    def reset_parameters(self) -> None:
        torch.nn.init.ones_(self.weight)


class HuginnDynamicCache(DynamicCache):
    """KV cache keyed by step index and token position instead of a per-layer list.

    Structure: ``cache[index_of_layer_or_recurrent_step][index_in_sequence] -> tensor``.
    The cache is held uncoalesced because certain recurrent steps may be missing for some sequence ids if using
    per-token adaptive compute. In those cases the ``lookup_strategy`` passed to :meth:`update` determines how the
    past is materialized. It is critical that the coda ("head") step indices, which the model passes as negative
    numbers, do not overlap with the prelude / recurrent iteration indices.
    """

    def __init__(self) -> None:
        super().__init__()
        self._seen_tokens = 0
        self.key_cache: dict[int, dict[int, torch.Tensor]] = {}
        self.value_cache: dict[int, dict[int, torch.Tensor]] = {}

    def _steps_with_token(self, token_pos: int, step_idx: int) -> list[int]:
        """Return, in ascending order, the step ids in ``range(step_idx + 1)`` that hold an entry for ``token_pos``."""
        # .get() so that step ids which were never created do not raise KeyError.
        return [s for s in range(step_idx + 1) if token_pos in self.key_cache.get(s, {})]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        step_idx: int,
        lookup_strategy: str = "latest",
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Store new keys/values for ``step_idx`` and return the materialized past for that step.

        Args:
            key_states / value_states: new entries; the token axis is dim -2.
            step_idx: layer/recurrence step id. Step 0 is assumed to be hit exactly once per forward pass,
                and it advances the seen-token counter.
            lookup_strategy: how to fill token positions missing for ``step_idx``:
                "latest" (entry from the highest step <= step_idx), "skip" (omit them),
                or "randomized" (entry from a random step <= step_idx).

        Raises:
            ValueError: unknown strategy, or no step <= step_idx holds some token position.
        """
        # Init
        if step_idx not in self.key_cache:
            self.key_cache[step_idx] = {}
            self.value_cache[step_idx] = {}
        num_new = key_states.shape[-2]
        # Update the number of seen tokens, we assume that step_idx=0 (first prelude) is always hit
        if step_idx == 0:
            self._seen_tokens += num_new
        first_pos = self._seen_tokens - num_new
        # Add entries to cache
        for idx, entry in enumerate(key_states.unbind(dim=-2)):
            assert first_pos + idx not in self.key_cache[step_idx]
            self.key_cache[step_idx][first_pos + idx] = entry
        for idx, entry in enumerate(value_states.unbind(dim=-2)):
            self.value_cache[step_idx][first_pos + idx] = entry

        # Materialize past state based on lookup strategy:
        if len(self.key_cache[step_idx]) == self._seen_tokens:
            # All entries are present (inserted in increasing position order), materialize cache as normal
            return (
                torch.stack(list(self.key_cache[step_idx].values()), dim=-2),
                torch.stack(list(self.value_cache[step_idx].values()), dim=-2),
            )

        # some entries were not previously computed for this step
        if lookup_strategy == "latest":
            latest_keys = []
            latest_values = []
            for token_pos in range(self._seen_tokens):
                steps = self._steps_with_token(token_pos, step_idx)
                if not steps:
                    raise ValueError(f"No cache entry found for token position {token_pos}")
                max_step = steps[-1]  # the latest step that has this token position
                latest_keys.append(self.key_cache[max_step][token_pos])
                latest_values.append(self.value_cache[max_step][token_pos])
            return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
        elif lookup_strategy == "skip":
            existing_keys = []
            existing_values = []
            for token_pos in range(self._seen_tokens):
                if token_pos in self.key_cache[step_idx]:
                    existing_keys.append(self.key_cache[step_idx][token_pos])
                    existing_values.append(self.value_cache[step_idx][token_pos])
            return torch.stack(existing_keys, dim=-2), torch.stack(existing_values, dim=-2)
        elif lookup_strategy == "randomized":  # sanity check
            rand_keys = []
            rand_values = []
            for token_pos in range(self._seen_tokens):
                steps = self._steps_with_token(token_pos, step_idx)
                if not steps:
                    raise ValueError(f"No cache entry found for token position {token_pos}")
                rand_step = steps[torch.randint(len(steps), (1,))]
                rand_keys.append(self.key_cache[rand_step][token_pos])
                rand_values.append(self.value_cache[rand_step][token_pos])
            return torch.stack(rand_keys, dim=-2), torch.stack(rand_values, dim=-2)
        else:
            raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")

    def reset(self) -> None:
        """Reset the cache state."""
        self._seen_tokens = 0
        self.key_cache.clear()
        self.value_cache.clear()

    def get_seq_length(self, step_idx: int = 0) -> int:
        # The count is global (advanced only at step 0), so step_idx is ignored.
        return self._seen_tokens


class CausalSelfAttention(torch.nn.Module):
    """Fused-QKV self-attention with rotary embeddings and optional learned q/k bias."""

    def __init__(self, config: RavenConfig) -> None:
        super().__init__()
        self.config = config
        self.n_head = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        self.head_dim = config.n_embd // self.n_head

        shape = (self.n_head + 2 * self.n_kv_heads) * self.head_dim
        self.chunks = [config.n_embd, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim]
        self.Wqkv = torch.nn.Linear(config.n_embd, shape, bias=False)
        if config.qk_bias:
            # NOTE(review): both biases are shaped with n_head heads, so adding k_bias only broadcasts when
            # n_kv_heads == n_head (or 1).
            self.qk_bias = torch.nn.Parameter(torch.zeros(2, 1, self.n_head, self.head_dim))
        self.proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=False)

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        step_idx: int,
        mask: Optional[Tensor] = None,
        past_key_values: Optional[Cache] = None,
    ) -> Tensor:
        B, S, E = x.shape  # batch size, sequence length, embedding dimensionality (n_embd)
        q, k, v = self.Wqkv(x).split(self.chunks, dim=2)
        q = q.view(B, S, self.n_head, self.head_dim)
        k = k.view(B, S, self.n_kv_heads, self.head_dim)
        v = v.view(B, S, self.n_kv_heads, self.head_dim)
        # bias?
        if self.config.qk_bias:
            q_bias, k_bias = self.qk_bias.split(1, dim=0)
            q, k = (q + q_bias).to(q.dtype), (k + k_bias).to(q.dtype)
        # apply rotary
        q, k = apply_rotary_emb_complex_like(q, k, freqs_cis=freqs_cis)

        q = q.transpose(1, 2)  # (B, nh, S, hs)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if past_key_values is not None:
            k, v = past_key_values.update(k, v, step_idx)

        # NOTE(review): k/v are not expanded to n_head heads here. In addition, SDPA's is_causal mask is aligned
        # top-left, so it is only exact when the query block spans every cached key (e.g. prefill into an empty
        # cache); single-token queries (S == 1) attend to the whole cache.
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=q.shape[2] > 1
        )
        y = y.transpose(1, 2).reshape(B, S, E).contiguous()  # reshape is a view if possible (it mostly is)
        return self.proj(y)


class GatedMLP(torch.nn.Module):
    """SiLU-gated MLP: proj(silu(fc_1(x)) * fc_2(x)) with both input projections fused into one Linear."""

    def __init__(self, config: RavenConfig, in_features: int = 0) -> None:
        super().__init__()
        in_features = config.n_embd if in_features == 0 else in_features
        self.fc = torch.nn.Linear(in_features, config.intermediate_size * 2, bias=False)

        self.proj = torch.nn.Linear(config.intermediate_size, config.n_embd, bias=False)
        self.nonlin = torch.nn.SiLU()

    def forward(self, x: Tensor) -> Tensor:
        # modified to single FC layer to improve parallelism
        x_fc_1, x_fc_2 = self.fc(x).chunk(2, dim=-1)
        x = self.nonlin(x_fc_1) * x_fc_2
        return self.proj(x)


class SandwichBlock(torch.nn.Module):
    """Transformer block with norms both before and after each residual sub-layer (attention, then MLP)."""

    expanded = False

    def __init__(self, config: RavenConfig, layer_id: int) -> None:
        super().__init__()
        self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        self.norm_2 = RMSNorm(config.n_embd, eps=config.norm_eps)
        self.mlp = GatedMLP(config)
        self.norm_3 = RMSNorm(config.n_embd, eps=config.norm_eps)
        self.norm_4 = RMSNorm(config.n_embd, eps=config.norm_eps)
        self.layer_id = layer_id

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        step_idx: int,
        mask: Optional[Tensor] = None,
        past_key_values: Optional[Cache] = None,
    ) -> Tensor:
        x = self.norm_2(self.attn(self.norm_1(x), freqs_cis, step_idx, mask, past_key_values) + x)
        x = self.norm_4(self.mlp(self.norm_3(x)) + x)
        return x


class RavenForCausalLM(RavenPreTrainedModel):
    """Recurrent-depth causal LM: prelude blocks -> repeated core block (with input injection) -> coda -> lm_head."""

    # Defaults for the ``output_details`` switches of ``forward``; callers may override any subset of them.
    _default_output_details = {
        "return_logits": True,
        "return_latents": True,
        "return_attention": False,
        "return_head": False,
        "return_stats": True,
    }

    def __init__(
        self,
        config: RavenConfig,
    ) -> None:
        super().__init__(config)
        self.config = config

        # Transformer layers
        prelude = torch.nn.ModuleList(SandwichBlock(config, layer_id=i) for i in range(config.n_layers_in_prelude))
        adapter = torch.nn.Linear(config.n_embd * 2, config.n_embd, bias=config.bias)
        core_block = torch.nn.ModuleList(
            SandwichBlock(config, layer_id=i + config.n_layers_in_prelude)
            for i in range(config.n_layers_in_recurrent_block)
        )
        o = config.n_layers_in_prelude + config.n_layers_in_recurrent_block * config.mean_recurrence
        coda = torch.nn.ModuleList(SandwichBlock(config, layer_id=i + o) for i in range(config.n_layers_in_coda))

        self.transformer = torch.nn.ModuleDict(
            dict(
                wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd),
                prelude=prelude,
                adapter=adapter,
                core_block=core_block,
                coda=coda,
                ln_f=RMSNorm(config.n_embd, eps=config.norm_eps),  # used twice :>
            )
        )
        self.emb_scale = config.init_values["embed_scale"]
        # Head
        self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
        if self.config.tie_embeddings:
            self.lm_head.weight = self.transformer.wte.weight
        # rope
        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)

    def _precompute_freqs_cis(self):
        # can actually be a buffer now, and remains in fp32! (at least in the settings I tested)
        freqs_cis = precompute_freqs_cis(
            self.config.n_embd // self.config.num_attention_heads, self.config.block_size, self.config.rope_base, 1
        )
        return freqs_cis

    def _select_freqs_cis(self, seq_len, position_ids, cache_position, past_key_values):
        """Pick the rotary table rows for the tokens of this call.

        Priority: explicit ``position_ids`` (1-D, or 2-D ``(batch, seq)``), then HF-style ``cache_position``
        (int or tensor of absolute positions), else ``seq_len`` consecutive positions starting after whatever
        ``past_key_values`` already holds (0 without a cache).
        """
        if position_ids is not None:
            if position_ids.dim() > 1:
                # (batch, seq) -> (batch, seq, 1, head_dim // 2, 2); broadcasts per sample in the rotary op.
                return self.freqs_cis[0][position_ids]
            return self.freqs_cis.index_select(1, position_ids)
        if cache_position is not None:
            # support HF format (a tensor of absolute positions; a single int is treated as one position)
            positions = torch.as_tensor(cache_position, device=self.freqs_cis.device).reshape(-1)
            return self.freqs_cis[:, positions]
        offset = past_key_values.get_seq_length() if past_key_values is not None else 0
        return self.freqs_cis[:, offset : offset + seq_len]

    def forward(
        self,
        input_ids: torch.Tensor,
        input_embeds: Optional[torch.Tensor] = None,
        input_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        num_steps_pair: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        output_details: Optional[dict] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> CausalLMOutputRecurrentLatents:
        """Run prelude, recurrent core and coda, and return logits plus recurrence diagnostics.

        ``output_details`` may override any of ``_default_output_details``; missing keys keep their default.
        ``num_steps_pair`` is ``(no_grad_steps, grad_steps)``, a single step count, or None to sample it.
        """
        details = {**self._default_output_details, **(output_details or {})}

        seq_len = input_ids.shape[1] if input_ids is not None else input_embeds.shape[1]  # type: ignore
        freqs_cis = self._select_freqs_cis(seq_len, position_ids, cache_position, past_key_values)

        if input_embeds is None:
            input_embeds = self.transformer.wte(input_ids)

            if self.emb_scale != 1:
                input_embeds = input_embeds * self.emb_scale  # type: ignore

        if use_cache and past_key_values is None:
            past_key_values = HuginnDynamicCache()

        # Non-recurrent prelude. block_idx starts at -1 so an empty prelude still hands a defined
        # index to the recurrence (whose first layer then uses step index 0).
        block_idx = -1
        for block_idx, block in enumerate(self.transformer.prelude):
            input_embeds = block(input_embeds, freqs_cis, block_idx, attention_mask, past_key_values)

        # Main recurrence
        x, num_steps_no_grad, num_steps_with_grad, xk = self.iterate_forward(
            input_embeds,  # type: ignore
            input_states,
            freqs_cis,
            block_idx,
            attention_mask,
            past_key_values,
            num_steps_pair,
        )
        latent_states = x.clone().detach()

        # Coda layers use negative step indices so they never collide with prelude/recurrence indices in the cache.
        for coda_idx, block in enumerate(self.transformer.coda, start=1):
            x = block(x, freqs_cis, -coda_idx, attention_mask, past_key_values)
        x = self.transformer.ln_f(x)

        logits = self.lm_head(x).float()
        # Prediction head, assuming labels really are labels and not equal to input_ids (no shift is applied)
        if labels is not None:
            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1))
            log_ppl = loss.clone().detach()
        else:
            loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)

        return CausalLMOutputRecurrentLatents(
            loss=loss,
            log_ppl=log_ppl,
            logits=logits if details["return_logits"] else None,
            past_key_values=past_key_values,
            hidden_states=x if details["return_head"] else None,
            latent_states=latent_states if details["return_latents"] else None,
            # NOTE: attention maps are not implemented; an exception instance is returned as a placeholder.
            attention_maps=ValueError() if details["return_attention"] else None,  # type: ignore
            stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad)
            if details["return_stats"]
            else None,
        )

    @torch._dynamo.disable(recursive=False)  # type: ignore
    def iterate_forward(
        self,
        input_embeds,
        input_states,
        freqs_cis,
        block_idx,
        mask,
        past_key_values: Optional[Cache] = None,
        num_steps_pair: Optional[torch.Tensor] = None,
    ):
        """Apply the core block repeatedly: first the no-grad steps, then the steps that keep gradients.

        Returns ``(ln_f(final_state), num_steps_no_grad, num_steps_with_grad, detached_state_before_last_step)``.
        """
        x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
        if num_steps_pair is None:
            num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler()  # type: ignore
        elif isinstance(num_steps_pair, int) or (torch.is_tensor(num_steps_pair) and num_steps_pair.dim() == 0):
            # A single scalar count (len() would fail on it): that many no-grad steps, no grad steps.
            num_steps_no_grad, num_steps_with_grad = num_steps_pair, torch.tensor(0)
        elif len(num_steps_pair) > 1:
            num_steps_no_grad, num_steps_with_grad = num_steps_pair
        else:
            num_steps_no_grad, num_steps_with_grad = num_steps_pair, torch.tensor(0)

        with torch.no_grad():
            # ultra annoying in ddp due to
            # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
            # for now running with find_unused_params=True enabled even though the graph structure is (technically) clear
            # and all parameters are always used
            for _ in range(num_steps_no_grad):
                xk = x
                x, block_idx = self.core_block_forward(xk, input_embeds, freqs_cis, mask, past_key_values, block_idx)

        for _ in range(num_steps_with_grad):
            xk = x
            x, block_idx = self.core_block_forward(xk, input_embeds, freqs_cis, mask, past_key_values, block_idx)
        return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach()

    def core_block_forward(
        self, x, input_embeds, freqs_cis, mask, past_key_values, block_idx: Union[torch.Tensor, int]
    ):
        """One recurrence step: inject the prelude output via the adapter, then run every core layer.

        Layers receive step indices ``block_idx + 1 ... block_idx + n_core``; the last one is returned so the next
        step continues the numbering.
        """
        x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
        for idx, block in enumerate(self.transformer.core_block, start=1):
            x = block(x, freqs_cis, block_idx + idx, mask, past_key_values)
        # len() instead of the loop variable keeps this well-defined for an empty core block.
        return x, block_idx + len(self.transformer.core_block)

    @torch._dynamo.disable(recursive=False)  # type: ignore
    def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Outputs are long tensors so that they can be passed through compiled functions"""
        t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
        s = self.config.mean_backprop_depth
        if self.training:
            sigma = 0.5
            mu = math.log(t + s) - (sigma**2 / 2)
            rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma)
            p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1
            n = torch.clamp(p - s, min=0)
            k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
        else:
            # Eval: fixed mean_recurrence no-grad steps and no grad steps.
            n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0)

        return n.to(dtype=torch.long), k.to(dtype=torch.long)

    def initialize_state(self, input_embeds):
        """Sample a random initial latent state shaped like ``input_embeds`` (truncated normal at +-3 std)."""
        x = torch.randn_like(input_embeds)
        std = self.config.init_values["std"]
        torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
        if self.emb_scale != 1:
            x = x * self.emb_scale
        return x

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Build the ``forward`` kwargs for one HF ``generate`` step.

        With a cache, only the tokens at ``cache_position`` are fed. ``position_ids`` always matches the fed
        tokens: ``cache_position`` when given, else ``0..len-1``. ``attention_mask`` and ``inputs_embeds`` are not
        forwarded. Any other kwargs are passed through unchanged.
        """
        model_inputs = {"cache_position": cache_position}
        if past_key_values is not None:
            model_inputs["past_key_values"] = past_key_values
            if cache_position is not None:
                input_ids = input_ids[:, cache_position]  # type: ignore
        model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)

        # position_ids is a critical argument for the model to correctly apply rope!
        if cache_position is not None:
            position_ids = torch.as_tensor(cache_position, device=input_ids.device).reshape(1, -1)
        else:
            position_ids = torch.arange(input_ids.shape[1], device=input_ids.device)[None, :]
        model_inputs["position_ids"] = position_ids.clone(memory_format=torch.contiguous_format)

        # forward all other entries
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value
        return model_inputs

    def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
        """Per-position diagnostics: softmax entropy and how far the coda moved the latent state."""
        probs = torch.softmax(logits.float(), dim=-1)
        prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1)
        residual_diff = (x - latent_states).norm(dim=-1)
        rel_residual = residual_diff / latent_states.norm(dim=-1)
        stats = {
            "entropy": prob_entropy,
            "residual_diff": residual_diff,
            "rel_residual": rel_residual,
            "num_steps_no_grad": num_steps_no_grad,
            "num_steps_with_grad": num_steps_with_grad,
        }
        return stats


#################################### Utils #######################################################################
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1):
    """Return a (1, end, 1, dim // 2, 2) float32 table holding [cos, sin] of the rotary angles."""
    with torch.autocast("cuda", enabled=False):
        inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        t = torch.arange(end, dtype=torch.float32, device=inv_freqs.device) / condense_ratio
        freqs = torch.outer(t, inv_freqs).float()
        return torch.stack([torch.cos(freqs)[None, :, None, :], torch.sin(freqs)[None, :, None, :]], dim=4)
        # equivalent to
        # freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        # cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)


def apply_rotary_emb_complex_like(q: Tensor, k: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate consecutive (even, odd) feature pairs of q and k by the angles in ``freqs_cis``.

    q and k are concatenated along dim 2 (heads) so that a single fp32 rotation covers both; the result is
    split back at ``q.shape[2]`` and cast to q's dtype.
    """
    with torch.autocast("cuda", enabled=False):
        qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float()  # cast to float32 for smooth skin
        rotated_qk = torch.stack(
            [
                qk_r2[..., 0] * freqs_cis[..., 0] - qk_r2[..., 1] * freqs_cis[..., 1],
                qk_r2[..., 1] * freqs_cis[..., 0] + qk_r2[..., 0] * freqs_cis[..., 1],
            ],
            -1,
        ).flatten(3)
        return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2)  # type: ignore


#################################### HF registration ############################################################

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

# New
RavenConfig.register_for_auto_class()

RavenForCausalLM.register_for_auto_class("AutoModel")
RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM")

# Old?
AutoConfig.register("huginn_raven", RavenConfig)
AutoModel.register(RavenConfig, RavenForCausalLM)
AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)