@patrickvonplaten
Created October 3, 2023 16:25
Transformers new cache design
import torch


class DynamicCache:  # <- this is what we currently have
    def __init__(self):
        self.cache = {}

    def update(self, key_states, value_states, layer_idx):
        # stack keys and values into one tensor: [2, batch, seq_len, head_dim]
        kv_states = torch.cat([key_states[None, :], value_states[None, :]], dim=0)
        if layer_idx not in self.cache:
            self.cache[layer_idx] = kv_states
        else:
            # append the new tokens along the sequence dimension
            self.cache[layer_idx] = torch.cat([self.cache[layer_idx], kv_states], dim=-2)
        return self.cache[layer_idx][0], self.cache[layer_idx][1]  # return key, value states
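
A minimal usage sketch of DynamicCache; the shapes ([batch, seq_len, head_dim]) and sizes below are illustrative assumptions, not part of the design:

cache = DynamicCache()
keys, values = torch.randn(1, 5, 64), torch.randn(1, 5, 64)   # 5 prompt tokens (hypothetical sizes)
k, v = cache.update(keys, values, layer_idx=0)                # prefill: 5 cached tokens
new_k, new_v = torch.randn(1, 1, 64), torch.randn(1, 1, 64)   # one generated token
k, v = cache.update(new_k, new_v, layer_idx=0)                # k.shape[-2] == 6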
class SinkCache:
    def __init__(self, window_length, num_sink_tokens):
        self.is_prefill = False  # unused in this sketch
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        # positions of the rolling (non-sink) part of the window
        self.index = torch.arange(num_sink_tokens, window_length)
        self.cache = {}

    def update(self, key_states, value_states, layer_idx):
        if layer_idx not in self.cache:
            # first in (prefill): keep the sink tokens plus the most recent tokens;
            # the sketch assumes the prompt fills the whole window
            sink_keys = key_states[:, :self.num_sink_tokens]
            sink_values = value_states[:, :self.num_sink_tokens]
            window = self.window_length - self.num_sink_tokens
            cached_keys = torch.cat([sink_keys, key_states[:, -window:]], dim=1)
            cached_values = torch.cat([sink_values, value_states[:, -window:]], dim=1)
            self.cache[layer_idx] = torch.cat([cached_keys[None, :], cached_values[None, :]], dim=0)
        elif key_states.shape[1] < self.index.shape[-1]:
            # auto-regressive: roll the non-sink part of the cache to the left ...
            key_len = key_states.shape[1]
            keys, values = self.cache[layer_idx][0], self.cache[layer_idx][1]
            keys.index_copy_(1, self.index[:-key_len], keys[:, self.num_sink_tokens + key_len:].clone())
            values.index_copy_(1, self.index[:-key_len], values[:, self.num_sink_tokens + key_len:].clone())
            # ... and write the new tokens into the freed slots at the end of the window
            keys.index_copy_(1, self.index[-key_len:], key_states)
            values.index_copy_(1, self.index[-key_len:], value_states)
        else:
            # incoming chunk fills the whole rolling window: overwrite it entirely
            self.cache[layer_idx][0].index_copy_(1, self.index, key_states[:, :self.window_length - self.num_sink_tokens])
            self.cache[layer_idx][1].index_copy_(1, self.index, value_states[:, :self.window_length - self.num_sink_tokens])
        return self.cache[layer_idx][0], self.cache[layer_idx][1]  # return key, value states
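
Before wiring it into a model, a quick standalone sanity check of the eviction behavior (sizes are illustrative; as noted above, the sketch assumes the prompt fills the window):

sink_cache = SinkCache(window_length=8, num_sink_tokens=2)
k, v = torch.randn(1, 8, 4), torch.randn(1, 8, 4)   # prompt exactly fills the window
k, v = sink_cache.update(k, v, layer_idx=0)
for _ in range(5):                                  # generate 5 tokens one by one
    k, v = sink_cache.update(torch.randn(1, 1, 4), torch.randn(1, 1, 4), layer_idx=0)
print(k.shape)  # torch.Size([1, 8, 4]) -- the window never grows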
cache = SinkCache(window_length=256, num_sink_tokens=3)
llama(input_ids, past_key_values=cache) # llama will internally call cache.update(...) at every layer
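
The point of the shared update() interface is that the modeling code stays identical no matter which cache object is passed in. A minimal sketch of that call site, assuming a toy single-head attention layer (ToyAttention, k_proj, v_proj and the constructor-time layer_idx are illustrative assumptions, not the actual transformers implementation):

import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, hidden_size, layer_idx):
        super().__init__()
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.layer_idx = layer_idx

    def forward(self, hidden_states, past_key_values=None):
        # hidden_states: [batch, seq_len, hidden_size]
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        if past_key_values is not None:
            # the layer is agnostic to the cache type: DynamicCache and
            # SinkCache expose the same update() interface
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx
            )
        return key_states, value_states  # attention computation itself omitted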