import torch
from typing import Any, Dict, Optional, Union
from transformers.cache_utils import DynamicCache


class TimeMixState:
    """State carried by an RWKV time-mix block: the token-shift state and
    the per-head wkv attention state."""

    def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
        self.shift_state = shift_state
        self.wkv_state = wkv_state


class ChannelMixState:
    """State carried by an RWKV channel-mix block: the token-shift state."""

    def __init__(self, shift_state: torch.Tensor):
        self.shift_state = shift_state


class BlockState:
    """The full recurrent state of one RWKV block."""

    def __init__(self, time_mix_state: TimeMixState,
                 channel_mix_state: ChannelMixState):
        self.time_mix_state = time_mix_state
        self.channel_mix_state = channel_mix_state


class BlockStateList:
    """The recurrent states of all N layers of an RWKV model, stored as two
    stacked tensors so they can be zeroed, sliced, and moved together."""

    def __init__(self, shift_states, wkv_states):
        self.wkv_states = wkv_states
        self.shift_states = shift_states

    @staticmethod
    def create(N, B, C, H, device, dtype):
        # Allocate states for N layers, batch size B, hidden size C and
        # H heads, then zero-initialise them.
        result = BlockStateList.empty(N, B, C, H, device, dtype)
        result.wkv_states[:] = 0
        result.shift_states[:] = 0
        return result

    @staticmethod
    def empty(N, B, C, H, device, dtype):
        # One (C//H x C//H) wkv matrix per head; kept in bfloat16
        # regardless of the requested dtype.
        wkv_states = torch.empty((N, B, H, C // H, C // H),
                                 device=device,
                                 dtype=torch.bfloat16)
        # Slot 0 holds the time-mix shift state, slot 1 the channel-mix one.
        shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype)
        return BlockStateList(shift_states, wkv_states)

    def __getitem__(self, layer: int):
        return BlockState(
            TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]),
            ChannelMixState(self.shift_states[layer, 1]))

    def __setitem__(self, layer: int, state: BlockState):
        self.shift_states[layer, 0] = state.time_mix_state.shift_state
        self.wkv_states[layer] = state.time_mix_state.wkv_state
        self.shift_states[layer, 1] = state.channel_mix_state.shift_state
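
# Usage sketch (illustrative sizes, not part of the original module): a
# 24-layer model with batch 1, hidden size 2048 and 32 heads would
# round-trip one layer's state like this:
#
#     states = BlockStateList.create(24, 1, 2048, 32,
#                                    device='cuda', dtype=torch.bfloat16)
#     block_state = states[0]   # BlockState for layer 0
#     states[0] = block_state   # write an updated state back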


class HybridCache(DynamicCache):
    """A transformers cache that stores regular key/value tensors for
    attention layers and BlockStates for RWKV layers."""

    def __init__(self) -> None:
        super().__init__()
        self.rwkv_layers = set()

    def __repr__(self) -> str:
        rwkv_info = f"HybridCache(rwkv_layers={self.rwkv_layers})"
        # Count the entries held in the key and value caches.
        key_cache_count = sum(len(cache) for cache in self.key_cache)
        value_cache_count = sum(len(cache) for cache in self.value_cache)
        count_info = (rwkv_info +
                      f", key_cache_count={key_cache_count}" +
                      f", value_cache_count={value_cache_count}")
        # Sum the memory footprint of the value cache in bytes
        # (numel alone would count elements, not bytes).
        memories = 0
        seq_length = self.get_seq_length()
        for cache in self.value_cache:
            for data in cache:
                if isinstance(data, torch.Tensor):
                    memories += data.numel() * data.element_size()
                else:
                    wkv_state = data.time_mix_state.wkv_state
                    memories += wkv_state.numel() * wkv_state.element_size()
        count_info += f", memories={memories / 1024 / 1024:.2f}MB, seq_length={seq_length}"
        return count_info

    def update(self,
               key_states: Union[int, torch.Tensor],
               value_states: Union[torch.Tensor, BlockState],
               layer_idx: int,
               cache_kwargs: Optional[Dict[str, Any]] = None):
        # RWKV layers report the number of new tokens as an int together
        # with a BlockState; attention layers pass ordinary key/value
        # tensors and fall through to DynamicCache.update.
        if isinstance(key_states, int) and not isinstance(value_states, torch.Tensor):
            self.rwkv_layers.add(layer_idx)
            if layer_idx >= len(self.key_cache):
                self.key_cache.append([])
                self.value_cache.append([])
            if len(self.key_cache[layer_idx]) == 0:
                self.key_cache[layer_idx].append(key_states)
                self.value_cache[layer_idx].append(value_states)
            else:
                # Accumulate the token count; keep only the newest state.
                self.key_cache[layer_idx][0] = self.key_cache[layer_idx][0] + key_states
                self.value_cache[layer_idx][0] = value_states
            return key_states, value_states
        return super().update(key_states, value_states, layer_idx, cache_kwargs)

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        # For RWKV layers the key cache holds the accumulated token count.
        if layer_idx in self.rwkv_layers:
            return self.key_cache[layer_idx][0]
        return super().get_seq_length(layer_idx)

    # Pure delegations, kept for API compatibility with DynamicCache.
    def get_max_length(self):
        return super().get_max_length()

    def reorder_cache(self, beam_idx):
        return super().reorder_cache(beam_idx)

    def __getitem__(self, item):
        # RWKV layers return their [BlockState] list instead of the usual
        # (key, value) tensor pair.
        if item in self.rwkv_layers:
            return self.value_cache[item]
        return super().__getitem__(item)

    def offload_to_cpu(self):
        # Tensor.cpu()/.cuda() return copies rather than moving a tensor in
        # place, so the results must be written back into the cache.
        for layer_idx, cache in enumerate(self.value_cache):
            if isinstance(cache, torch.Tensor):
                # Attention layer: the entry is the value tensor itself.
                self.key_cache[layer_idx] = self.key_cache[layer_idx].cpu()
                self.value_cache[layer_idx] = cache.cpu()
            else:
                # RWKV layer: a one-element list holding a BlockState.
                for data in cache:
                    state = data.time_mix_state
                    state.wkv_state = state.wkv_state.cpu()
                    state.shift_state = state.shift_state.cpu()

    def offload_to_cuda(self, device: str):
        for layer_idx, cache in enumerate(self.value_cache):
            if isinstance(cache, torch.Tensor):
                self.key_cache[layer_idx] = self.key_cache[layer_idx].cuda(device)
                self.value_cache[layer_idx] = cache.cuda(device)
            else:
                for data in cache:
                    state = data.time_mix_state
                    state.wkv_state = state.wkv_state.cuda(device)
                    state.shift_state = state.shift_state.cuda(device)

    def offload_to_device(self, device_type: str, device_id: int = 0):
        # Generalises the two methods above: build a device string such as
        # 'cpu' or 'cuda:0' and move every cached tensor to it.
        device = device_type if device_type == 'cpu' else f'{device_type}:{device_id}'
        for layer_idx, cache in enumerate(self.value_cache):
            if isinstance(cache, torch.Tensor):
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device)
                self.value_cache[layer_idx] = cache.to(device)
            else:
                for data in cache:
                    state = data.time_mix_state
                    state.wkv_state = state.wkv_state.to(device)
                    state.shift_state = state.shift_state.to(device)
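

if __name__ == "__main__":
    # Minimal smoke test with made-up dimensions (an illustrative sketch,
    # not part of the original module). It exercises the RWKV path of
    # HybridCache.update; attention layers would instead pass key/value
    # tensors and fall through to DynamicCache.update.
    B, C, H = 1, 64, 4
    states = BlockStateList.create(1, B, C, H, device='cpu', dtype=torch.float32)
    cache = HybridCache()

    # RWKV layer 0: 16 new tokens plus the layer's BlockState.
    cache.update(16, states[0], layer_idx=0)
    assert cache.get_seq_length(0) == 16

    # A second update accumulates the token count and replaces the state.
    cache.update(8, states[0], layer_idx=0)
    assert cache.get_seq_length(0) == 24
    print(cache)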