import json
import time
from typing import List, Tuple, Union

import numpy as np
import torch
from numpy.linalg import norm
from torch.nn import functional as F

from RWKV_model import tokenizer, RWKV_RNN, args_430M
from util.sampling import tail_free_sampling

saharNooby / simple_bpe_tokenizer.py (last active May 5, 2023)
Probably the dumbest, no-dependencies, pure Python implementation of 20B_tokenizer.json (a BPE tokenizer for the GPT-NeoX model)

import regex
import json
import unicodedata
from typing import Tuple, Callable, Union

# Parses the tokenizer config and returns encode and decode functions.
def load_tokenizer(config_path: str) -> Tuple[Callable[[str], list[int]], Callable[[list[int]], str]]:
    # Maps any byte 0..255 to a printable Unicode character.
    byte_to_unicode: dict[int, str] = {
        33: "!",
        # ... (the table continues for all 256 bytes)

# USAGE EXAMPLE
logits = llm(...)                            # Get raw logits from an LLM
logits = tail_free_sampling(logits, z=0.95)  # Cut off logits in the tail
token = sample(logits, temperature=1.0)      # Do your usual sampling with temp/top-p

def tail_free_sampling(logits: torch.Tensor, z: float = 0.95, mask_value: float = -float('inf')) -> torch.Tensor:
    """
    Masks out the low-probability "tail" of the distribution, found by looking
    at the second derivative of the sorted probability curve.

    See https://www.trentonbricken.com/Tail-Free-Sampling/
    Code copied from https://github.com/finetunej/transformers/blob/c83109932f4592b871ec4c60326df3b4173b021a/src/transformers/generation_logits_process.py#L243-L284
    """
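    # NOTE: the preview of this gist ends at the docstring. The body below is a
    # sketch reconstructed from the finetunej code linked above, adapted to a
    # single 1-D logits vector; treat it as an assumption, not the gist's exact code.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)

    # Second derivative of the sorted probability curve; where its normalized
    # cumulative sum exceeds z, the "tail" of unlikely tokens begins.
    second_derivative = probs.diff().diff().abs()
    cdf = (second_derivative / second_derivative.sum()).cumsum(dim=-1)

    # Each diff() shortened the tensor by one element, so pad the mask back to
    # full length: always keep the most likely token, always drop the least likely.
    sorted_mask = torch.cat([
        torch.zeros(1, dtype=torch.bool, device=logits.device),
        cdf > z,
        torch.ones(1, dtype=torch.bool, device=logits.device),
    ])

    # Map the mask from sorted order back to vocabulary order and apply it.
    mask = torch.zeros_like(sorted_mask).scatter(0, sorted_indices, sorted_mask)
    return logits.masked_fill(mask, mask_value)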

Setup for measuring perplexity and latency of the rwkv.cpp implementation of RWKV.

Caveat: "perplexity" here is defined simply as "exp of average per-token cross-entropy loss", which may or may not match how other tools compute it.
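
A minimal sketch of that definition in PyTorch (a hypothetical helper, not code from the gist):

import torch
from torch.nn import functional as F

def perplexity(logits: torch.Tensor, targets: torch.Tensor) -> float:
    # logits: [seq_len, vocab_size] model predictions for each position;
    # targets: [seq_len] the tokens that actually came next.
    # Perplexity as defined above: exp of the average per-token cross-entropy.
    return float(torch.exp(F.cross_entropy(logits, targets)))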

saharNooby / RWKV_cache.py (last active May 3, 2023)
State cache for the RWKV language model

# USAGE EXAMPLE
cache = RWKV_Cache()
init_out, init_state = cache.preprocess_prompt(model, prompt_tokens)

for GENERATION in range(NUM_GENERATIONS):
    out, state = init_out.clone(), init_state.clone()
    cache_key = [*prompt_tokens]
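
For context, a minimal sketch of what such a cache could look like. It assumes an RWKV_RNN-style model.forward(token, state) -> (out, state) interface; the class and method names come from the usage above, while the internals are guesses:

class RWKV_Cache:
    # Maps a tuple of prompt tokens to the (out, state) pair produced by
    # feeding those tokens through the model, so repeated generations from
    # the same prompt skip the prompt-processing pass.
    def __init__(self) -> None:
        self._cache: dict[tuple, tuple] = {}

    def preprocess_prompt(self, model, prompt_tokens):
        key = tuple(prompt_tokens)
        if key not in self._cache:
            out, state = None, None
            for token in prompt_tokens:
                out, state = model.forward(token, state)  # assumed interface
            self._cache[key] = (out, state)
        return self._cache[key]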