# ARWKV-7B-Preview-0.1 / hybrid_cache.py
# zhiyuan8's picture
# Add files using upload-large-folder tool
# c1a12af verified
import torch
from typing import Any, Dict, Optional, Union
from transformers.cache_utils import DynamicCache
class TimeMixState:
    """Container pairing a shift state with a WKV state.

    Both values are stored exactly as passed in: no copy is made and no
    device or dtype conversion is applied.

    Attributes:
        shift_state: The shift-state tensor given to the constructor.
        wkv_state: The WKV-state tensor given to the constructor.
    """

    def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
        self.shift_state, self.wkv_state = shift_state, wkv_state
class ChannelMixState:
    """Container for a single shift-state tensor.

    Attributes:
        shift_state: The tensor given to the constructor, stored by
            reference (no copy, no conversion).
    """

    def __init__(self, shift_state: torch.Tensor):
        self.shift_state: torch.Tensor = shift_state
class BlockState:
    """Bundle of the time-mix and channel-mix states belonging to one block.

    Attributes:
        time_mix_state: The ``TimeMixState`` given to the constructor.
        channel_mix_state: The ``ChannelMixState`` given to the constructor.

    Both objects are kept by reference; nothing is copied.
    """

    def __init__(self, time_mix_state: TimeMixState,
                 channel_mix_state: ChannelMixState):
        self.time_mix_state, self.channel_mix_state = (
            time_mix_state,
            channel_mix_state,
        )
class BlockStateList:
    """Per-layer block states stored as two stacked tensors.

    ``shift_states`` is indexed ``[layer, slot, ...]``: slot 0 holds the
    time-mix shift state and slot 1 the channel-mix shift state.
    ``wkv_states`` is indexed ``[layer, ...]``.  Indexing the list with a
    layer number packs the matching slices into a ``BlockState``.
    """

    def __init__(self, shift_states, wkv_states):
        self.wkv_states = wkv_states
        self.shift_states = shift_states

    @staticmethod
    def create(N, B, C, H, device, dtype):
        """Allocate a state list and fill both tensors with zeros.

        Takes the same arguments and produces the same shapes/dtypes as
        :meth:`empty`.
        """
        result = BlockStateList.empty(N, B, C, H, device, dtype)
        # Each buffer is zeroed exactly once (the original cleared
        # ``wkv_states`` twice, which was redundant).
        result.wkv_states.zero_()
        result.shift_states.zero_()
        return result

    @staticmethod
    def empty(N, B, C, H, device, dtype):
        """Allocate uninitialised state tensors.

        Args:
            N: Size of the leading (layer) dimension of both tensors.
            B, C, H: Remaining size parameters -- presumably batch size,
                channel count and head count (TODO confirm with callers).
                ``C // H`` is used as the side length of each WKV matrix.
            device: Device both tensors are allocated on.
            dtype: Dtype of ``shift_states`` only.

        Returns:
            A ``BlockStateList`` whose ``wkv_states`` has shape
            ``(N, B, H, C // H, C // H)`` and whose ``shift_states`` has
            shape ``(N, 2, B, C)``.  Contents are uninitialised.
        """
        # NOTE(review): wkv_states is always bfloat16 and ignores ``dtype``;
        # kept as in the original -- confirm this is intentional.
        head_size = C // H
        wkv_states = torch.empty((N, B, H, head_size, head_size),
                                 device=device,
                                 dtype=torch.bfloat16)
        shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype)
        return BlockStateList(shift_states, wkv_states)

    def __getitem__(self, layer: int):
        """Return the state of ``layer`` as a ``BlockState``.

        The returned object wraps indexed slices of the stacked tensors
        rather than separately allocated tensors.
        """
        return BlockState(
            TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]),
            ChannelMixState(self.shift_states[layer, 1]))

    def __setitem__(self, layer: int, state: "BlockState"):
        """Write the tensors held by ``state`` into the slots of ``layer``."""
        self.shift_states[layer, 0] = state.time_mix_state.shift_state
        self.wkv_states[layer] = state.time_mix_state.wkv_state
        self.shift_states[layer, 1] = state.channel_mix_state.shift_state
class HybridCache(DynamicCache):
    """``DynamicCache`` that can also hold RWKV-style block states.

    A layer is treated as an RWKV layer once :meth:`update` has been called
    for it with an ``int`` as ``key_states`` and a non-tensor
    ``value_states``.  For such a layer:

    * ``key_cache[layer_idx]`` is a one-element list holding the running sum
      of every integer ``key_states`` passed so far;
    * ``value_cache[layer_idx]`` is a one-element list holding the most
      recently supplied state object.

    Every other layer is handled by ``DynamicCache`` unchanged.
    """

    def __init__(self) -> None:
        super().__init__()
        # Indices of the layers that store a state object instead of K/V tensors.
        self.rwkv_layers = set()

    def __repr__(self) -> str:
        """Summarise RWKV layers, entry counts, memory use and sequence length.

        ``memories`` is reported in MiB and is computed from the tensors in
        ``value_cache`` only; for state objects only
        ``time_mix_state.wkv_state`` is counted.
        """
        key_cache_count = sum(len(cache) for cache in self.key_cache)
        value_cache_count = sum(len(cache) for cache in self.value_cache)

        memory_bytes = 0
        for cache in self.value_cache:
            for data in cache:
                if isinstance(data, torch.Tensor):
                    tensor = data
                else:
                    tensor = data.time_mix_state.wkv_state
                # numel() alone is an element count; scale by the element
                # size so the "MB" label in the output is accurate.
                memory_bytes += tensor.numel() * tensor.element_size()

        seq_length = self.get_seq_length()
        return (
            f"HybridCache(rwkv_layers={self.rwkv_layers})"
            f", key_cache_count={key_cache_count}, value_cache_count={value_cache_count}"
            f", memories={memory_bytes / 1024 / 1024}MB, seq_length={seq_length}"
        )

    def update(self,
               key_states: Union[int, torch.Tensor],
               value_states: Union[torch.Tensor, BlockState],
               layer_idx: int,
               cache_kwargs: Optional[Dict[str, Any]] = None):
        """Store new states for ``layer_idx``.

        Args:
            key_states: For an RWKV layer, an ``int`` that is added to the
                layer's running total; otherwise the key tensor.
            value_states: For an RWKV layer, the state object that replaces
                the stored one; otherwise the value tensor.
            layer_idx: Index of the layer being updated.
            cache_kwargs: Forwarded to ``DynamicCache.update`` for
                non-RWKV layers; unused on the RWKV path.

        Returns:
            ``(key_states, value_states)`` exactly as passed in for an RWKV
            layer, otherwise whatever ``DynamicCache.update`` returns.
        """
        is_rwkv = isinstance(key_states, int) and not isinstance(
            value_states, torch.Tensor)
        if not is_rwkv:
            return super().update(key_states, value_states, layer_idx, cache_kwargs)

        self.rwkv_layers.add(layer_idx)
        # Pad with empty slots up to and including layer_idx so that a layer
        # index that skips ahead does not raise IndexError below.
        while len(self.key_cache) <= layer_idx:
            self.key_cache.append([])
            self.value_cache.append([])

        if len(self.key_cache[layer_idx]) == 0:
            self.key_cache[layer_idx].append(key_states)
            self.value_cache[layer_idx].append(value_states)
        else:
            self.key_cache[layer_idx][0] = self.key_cache[layer_idx][0] + key_states
            self.value_cache[layer_idx][0] = value_states
        return key_states, value_states

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        """Return the cached sequence length of ``layer_idx``.

        For an RWKV layer this is the accumulated integer kept in
        ``key_cache``; otherwise the ``DynamicCache`` result is returned.
        """
        if layer_idx in self.rwkv_layers:
            return self.key_cache[layer_idx][0]
        return super().get_seq_length(layer_idx)

    def get_max_length(self):
        """Delegate to ``DynamicCache.get_max_length``."""
        return super().get_max_length()

    def reorder_cache(self, beam_idx):
        """Delegate to ``DynamicCache.reorder_cache``."""
        return super().reorder_cache(beam_idx)

    def __getitem__(self, item):
        """Return the stored entry for layer ``item``.

        For an RWKV layer this is the ``value_cache`` list of that layer;
        otherwise the ``DynamicCache`` item is returned.
        """
        if item in self.rwkv_layers:
            return self.value_cache[item]
        return super().__getitem__(item)

    def _move_value_states(self, move):
        """Replace each tensor reachable from ``value_cache`` by ``move(tensor)``.

        ``Tensor.cpu()`` / ``Tensor.cuda()`` return a tensor on the target
        device instead of modifying the receiver, so the result has to be
        stored back for the move to have any effect.

        NOTE(review): only ``value_cache`` is visited, and for state objects
        only ``time_mix_state.wkv_state`` / ``time_mix_state.shift_state``,
        matching the scope of the original code -- confirm whether
        ``key_cache`` tensors and ``channel_mix_state`` should move as well.
        """
        for layer_idx, cache in enumerate(self.value_cache):
            if isinstance(cache, torch.Tensor):
                # Whole-layer value tensor: move it in one piece.
                self.value_cache[layer_idx] = move(cache)
                continue
            for pos, data in enumerate(cache):
                if isinstance(data, torch.Tensor):
                    cache[pos] = move(data)
                else:
                    time_mix = data.time_mix_state
                    time_mix.wkv_state = move(time_mix.wkv_state)
                    time_mix.shift_state = move(time_mix.shift_state)

    def offload_to_cpu(self):
        """Move the cached value tensors and RWKV time-mix states to the CPU."""
        self._move_value_states(lambda tensor: tensor.cpu())

    def offload_to_cuda(self, device: str):
        """Move the cached value tensors and RWKV time-mix states to ``device``.

        Args:
            device: Destination passed to ``Tensor.cuda``.
        """
        self._move_value_states(lambda tensor: tensor.cuda(device))

    def offload_to_device(self, device_type: str, device_id: int = 0):
        """Move cached states using the tensor method named ``device_type``.

        Args:
            device_type: Name of a ``torch.Tensor`` method such as ``'cpu'``
                or ``'cuda'``.  ``'cpu'`` is called without arguments; any
                other method is called with ``device_id``.
            device_id: Device index handed to the non-CPU method.
        """
        def move(tensor):
            method = getattr(tensor, device_type)
            return method() if device_type == 'cpu' else method(device_id)

        self._move_value_states(move)