@saharNooby
Created May 12, 2023 15:26
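# Repetition-penalized sampling for an RWKV RNN model: at each step, the SEARCH_K
# most probable tokens are re-ranked by (1 - ALPHA) * probability - ALPHA * max
# cosine similarity to previously seen hidden representations (essentially
# contrastive search applied to RWKV). RWKV_model and util.sampling are this
# project's own modules and are not defined in this file.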
import json
import time
import numpy as np
import torch
from typing import List, Tuple, Union
from torch.nn import functional as F
from numpy.linalg import norm
from RWKV_model import tokenizer, RWKV_RNN, args_430M
from util.sampling import tail_free_sampling
########################################################################################################
context = "..."
args = args_430M
NUM_GENERATIONS = 1
TOKENS_PER_GENERATION = 200
# Generation will be SEARCH_K times slower
SEARCH_K = 5
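# Trade-off between token probability and dissimilarity to earlier representations:
# 0 degenerates to greedy decoding over the top tokens, 1 scores by dissimilarity only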
ALPHA = 0.5
USE_TAIL_FREE_SAMPLING = False
DEBUG = False
########################################################################################################
def debug(*args):
    if DEBUG:
        print(*args)
model = RWKV_RNN(args)
model.warm_up()
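# Hidden representation of every token processed so far (prompt and generated);
# candidates are penalized by their maximum cosine similarity to this list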
representations = []
print('Preprocessing context')
start = time.time()
prompt_tokens = tokenizer.encode(context).ids
prompt_token_count = len(prompt_tokens)
init_out, init_state = None, None
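# Feed the prompt through the model token by token, keeping each token's representation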
for i in range(prompt_token_count):
    init_out, init_state = model.forward(prompt_tokens[i], init_state, save_representation=True)
    representations.append(model.representation)

    if prompt_token_count < 5 or i % (prompt_token_count // 5) == 0:
        print(f'{i}/{prompt_token_count}')
delay = time.time() - start
print('Took %.3f sec, %d tokens in context, %d ms per token' % (delay, prompt_token_count, delay / prompt_token_count * 1000))
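# Cosine similarity between two representation vectors; values near 1 mean the
# candidate closely matches something already in the context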
def cosine_similarity(x: np.ndarray, y: np.ndarray):
    return np.dot(x, y) / (norm(x) * norm(y))
for GENERATION in range(NUM_GENERATIONS):
    print(f'\n--- Generation {GENERATION} ---\n')
    print(context, end='[')

    start = time.time()

    out, state = init_out.clone(), init_state.clone()
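    # Generate TOKENS_PER_GENERATION tokens, re-ranking the top candidates at every step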
    for TOKEN in range(TOKENS_PER_GENERATION):
        debug('')
        debug('out', out)

        if USE_TAIL_FREE_SAMPLING:
            out = tail_free_sampling(out)
            debug('out', out)

            probs = F.softmax(out, dim=-1).cpu().numpy()
            debug('probs', probs)

            top_tokens = (-probs).argsort()
            debug('top_tokens', top_tokens)

            first_zero_index = np.where(probs[top_tokens] == 0.0)[0][0]
            debug('first_zero_index', first_zero_index)

            top_tokens = top_tokens[:min(max(1, first_zero_index), SEARCH_K)]
        else:
            probs = F.softmax(out, dim=-1).cpu().numpy()
            top_tokens = (-probs).argsort()[:SEARCH_K]

        debug('top_tokens', top_tokens)

        top_tokens_probs = probs[top_tokens]
        debug('top_tokens_probs', top_tokens_probs)
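        # One forward pass per candidate: record its logits, state and representation,
        # plus the maximum similarity of that representation to the existing context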
        top_tokens_max_similarities = np.zeros_like(top_tokens, dtype=float)
        next_states_and_representations: List[Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], None]] = [None] * len(top_tokens)

        for j in range(len(top_tokens)):
            token = top_tokens[j]

            candidate_out, candidate_state = model.forward(token, state.clone(), save_representation=True)
            candidate_representation = model.representation

            max_similarity = 0

            for representation in representations:
                max_similarity = max(max_similarity, cosine_similarity(representation.cpu().numpy(), candidate_representation.cpu().numpy()))

            top_tokens_max_similarities[j] = max_similarity
            next_states_and_representations[j] = (candidate_out, candidate_state, candidate_representation)

        debug('top_tokens_max_similarities', top_tokens_max_similarities)
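        # Contrastive-search-style score: reward model confidence, penalize
        # candidates whose representation repeats the existing context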
        top_tokens_scores = (1 - ALPHA) * top_tokens_probs - ALPHA * top_tokens_max_similarities
        debug('top_tokens_scores', top_tokens_scores)

        selected_token_index = np.argmax(top_tokens_scores)
        debug('selected_token_index', selected_token_index)

        token = top_tokens[selected_token_index]
        debug('token', token)

        out, state, representation = next_states_and_representations[selected_token_index]
        representations.append(representation)

        if DEBUG:
            print(json.dumps(tokenizer.decode([token])))
        else:
            print(tokenizer.decode([token]), end='')

    delay = time.time() - start

    print(']\n\nTook %.3f sec, %d ms per token' % (delay, delay / TOKENS_PER_GENERATION * 1000))