import torch

from typing import Any, Dict, Optional, Union

from transformers.cache_utils import DynamicCache
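
# Recurrent-state containers for RWKV blocks (time mix and channel mix), plus a
# HybridCache that lets RWKV layers keep their recurrent state inside the
# standard transformers key/value cache machinery.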


class TimeMixState:
    """Recurrent state of an RWKV time-mix block: token-shift and WKV state."""

    def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
        self.shift_state = shift_state
        self.wkv_state = wkv_state


class ChannelMixState:
    """Recurrent state of an RWKV channel-mix block: the token-shift state."""

    def __init__(self, shift_state: torch.Tensor):
        self.shift_state = shift_state


class BlockState:
    """Pairs the time-mix and channel-mix states of a single RWKV block."""

    def __init__(self, time_mix_state: TimeMixState,
                 channel_mix_state: ChannelMixState):
        self.time_mix_state = time_mix_state
        self.channel_mix_state = channel_mix_state


class BlockStateList:
    """All per-layer RWKV states, stacked into two batched tensors."""

    def __init__(self, shift_states, wkv_states):
        self.wkv_states = wkv_states
        self.shift_states = shift_states

    @staticmethod
    def create(N, B, C, H, device, dtype):
        # Allocate states for N layers and zero-initialize them.
        result = BlockStateList.empty(N, B, C, H, device, dtype)
        result.wkv_states[:] = 0
        result.shift_states[:] = 0
        return result

    @staticmethod
    def empty(N, B, C, H, device, dtype):
        # wkv states are kept in bfloat16 regardless of the requested dtype;
        # shift states follow the requested dtype.
        wkv_states = torch.empty((N, B, H, C // H, C // H),
                                 device=device,
                                 dtype=torch.bfloat16)
        shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype)
        return BlockStateList(shift_states, wkv_states)

    def __getitem__(self, layer: int):
        return BlockState(
            TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]),
            ChannelMixState(self.shift_states[layer, 1]))

    def __setitem__(self, layer: int, state: BlockState):
        self.shift_states[layer, 0] = state.time_mix_state.shift_state
        self.wkv_states[layer] = state.time_mix_state.wkv_state
        self.shift_states[layer, 1] = state.channel_mix_state.shift_state
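

# Usage sketch for BlockStateList (illustrative only; the concrete layer count,
# hidden size, and head count below are assumptions, not part of this module):
#
#     states = BlockStateList.create(N=24, B=1, C=2048, H=32,
#                                    device="cuda", dtype=torch.bfloat16)
#     block_state = states[0]     # BlockState for layer 0
#     states[0] = block_state     # write an updated state back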


class HybridCache(DynamicCache):
    """DynamicCache variant that also tracks RWKV recurrent state per layer."""

    def __init__(self) -> None:
        super().__init__()
        self.rwkv_layers = set()

    def __repr__(self) -> str:
        rwkv_layers = f"HybridCache(rwkv_layers={self.rwkv_layers})"

        key_cache_count = sum(len(cache) for cache in self.key_cache)
        value_cache_count = sum(len(cache) for cache in self.value_cache)
        count_info = rwkv_layers + \
            f", key_cache_count={key_cache_count}, value_cache_count={value_cache_count}"
        memories = 0
        seq_length = self.get_seq_length()
        for cache in self.value_cache:
            for data in cache:
                # RWKV layers store a BlockState; attention layers store tensors.
                if not isinstance(data, torch.Tensor):
                    state = data.time_mix_state.wkv_state
                    memories += state.numel() * state.element_size()
                else:
                    memories += data.numel() * data.element_size()
        count_info += f", memories={memories / 1024 / 1024}MB, seq_length={seq_length}"
        return count_info

    def update(self,
               key_states: Union[int, torch.Tensor],
               value_states: Union[torch.Tensor, BlockState],
               layer_idx: int,
               cache_kwargs: Optional[Dict[str, Any]] = None):
        # RWKV layers pass the number of new tokens as `key_states` and their
        # BlockState as `value_states`; attention layers fall through to the
        # regular DynamicCache behaviour.
        if isinstance(key_states, int) and not isinstance(value_states, torch.Tensor):
            self.rwkv_layers.add(layer_idx)
            if layer_idx >= len(self.key_cache):
                self.key_cache.append([])
                self.value_cache.append([])

            if len(self.key_cache[layer_idx]) == 0:
                self.key_cache[layer_idx].append(key_states)
                self.value_cache[layer_idx].append(value_states)
            else:
                # Accumulate the token count and replace the stored state.
                self.key_cache[layer_idx][0] = self.key_cache[layer_idx][0] + key_states
                self.value_cache[layer_idx][0] = value_states

            return key_states, value_states

        return super().update(key_states, value_states, layer_idx, cache_kwargs)

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        if layer_idx in self.rwkv_layers:
            # For RWKV layers the cached "key" is the running token count.
            return self.key_cache[layer_idx][0]
        return super().get_seq_length(layer_idx)

    def get_max_length(self):
        return super().get_max_length()

    def reorder_cache(self, beam_idx):
        return super().reorder_cache(beam_idx)

    def __getitem__(self, item):
        if item in self.rwkv_layers:
            return self.value_cache[item]
        return super().__getitem__(item)

    def offload_to_cpu(self):
        # Tensor.cpu() returns a copy, so moved tensors must be written back.
        for layer_idx, cache in enumerate(self.value_cache):
            if isinstance(cache, torch.Tensor):
                self.value_cache[layer_idx] = cache.cpu()
            else:
                for data in cache:
                    data.time_mix_state.wkv_state = data.time_mix_state.wkv_state.cpu()
                    data.time_mix_state.shift_state = data.time_mix_state.shift_state.cpu()

    def offload_to_cuda(self, device: str):
        # Tensor.cuda() returns a copy, so moved tensors must be written back.
        for layer_idx, cache in enumerate(self.value_cache):
            if isinstance(cache, torch.Tensor):
                self.value_cache[layer_idx] = cache.cuda(device)
            else:
                for data in cache:
                    data.time_mix_state.wkv_state = data.time_mix_state.wkv_state.cuda(device)
                    data.time_mix_state.shift_state = data.time_mix_state.shift_state.cuda(device)

    def offload_to_device(self, device_type: str, device_id: int = 0):
        # Dispatch to Tensor.cpu() / Tensor.cuda(device_id) by attribute name;
        # the moved tensors are written back, since these methods return copies.
        def _move(tensor: torch.Tensor) -> torch.Tensor:
            method = getattr(tensor, device_type)
            return method() if device_type == 'cpu' else method(device_id)

        for layer_idx, cache in enumerate(self.value_cache):
            if isinstance(cache, torch.Tensor):
                self.value_cache[layer_idx] = _move(cache)
            else:
                for data in cache:
                    data.time_mix_state.wkv_state = _move(data.time_mix_state.wkv_state)
                    data.time_mix_state.shift_state = _move(data.time_mix_state.shift_state)
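

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The layer count,
    # hidden size, and head count below are illustrative assumptions.
    N, B, C, H = 2, 1, 64, 4
    states = BlockStateList.create(N, B, C, H, device="cpu", dtype=torch.float32)

    cache = HybridCache()
    # For RWKV layers, `key_states` is the number of new tokens and
    # `value_states` is the layer's BlockState.
    for layer_idx in range(N):
        cache.update(3, states[layer_idx], layer_idx)

    print(cache)                    # summary produced by __repr__
    print(cache.get_seq_length(0))  # running token count tracked for layer 0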