from functools import lru_cache
import os
import random
import time
from typing import Callable, List

import numpy as np
import scipy.special
import tiktoken
from tqdm import tqdm


def compute_shapley_2d(
    engine, tokenizer, toks, mask_positions, iters_each: int = 1000, api_base=None
):
"""
Compute Shapley values of logprob contribution of each token on each future token.
:param engine: api engine
:param tokenizer: tokenizer
:param toks: list of tokens
:param mask_positions: list of positions to mask (useful for only masking some tokens, like only numbers)
:param iters_each: number of iterations to run for each token
:param api_base: api_base for openai api
:return: 2d array of shapley values
"""
res = np.zeros((len(mask_positions), len(mask_positions)))
res.fill(np.nan)
for i in range(len(mask_positions) - 1, 0, -1):
inputs_pos, output_pos = mask_positions[:i], mask_positions[i]
shap = compute_shapley(
xs=inputs_pos,
value_fn=lambda x: oa_value(
engine,
toks[: output_pos + 1],
[y for y in inputs_pos if y not in x],
underscore_ablate,
tokenizer,
api_base=api_base,
),
iters_each=iters_each,
)
# pad shap to the right length
res[:, i] = np.pad(shap, (0, len(mask_positions) - len(shap)), "constant")
return res


def underscore_ablate(tok, tokenizer):
    # NOTE: GPT-4's tokenizer does not merge a leading space into digit tokens,
    # so the " _" branch can yield the out-of-distribution pair [" ", "_"]
    return tokenizer.encode(" _" if tokenizer.decode([tok])[0] == " " else "_")[0]
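
# A minimal sketch of the masking behavior (assumes the "gpt2" encoding; exact
# token splits depend on the encoding used):
#
#   enc = tiktoken.get_encoding("gpt2")
#   tok = enc.encode(" 42")[0]                 # a token with a leading space
#   enc.decode([underscore_ablate(tok, enc)])  # -> " _" (the space is preserved)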


def compute_shapley(
    xs: list, value_fn: Callable[[list], float], iters_each: int = 1000
) -> List[float]:
    """
    Estimate Shapley values for a given list of objects by Monte Carlo sampling.

    :param xs: list of objects
    :param value_fn: function that takes a list of objects and returns a float
    :param iters_each: number of sampling iterations per object
    :return: estimated Shapley value for each object in xs
    """
    # todo: calculate iters_each with proper bounds
shap = [[] for _ in range(len(xs))]
pbar = tqdm(total=iters_each * len(xs))
for i in range(len(xs)):
for j in range(iters_each):
pbar.update(1)
random.seed(f"{i}_{j}")
# masked -> not shown to model
keep = [random.random() < 0.5 for _ in range(len(xs))]
keep_with = keep.copy()
keep_with[i] = True
keep_without = keep.copy()
keep_without[i] = False
S = sum(keep_without)
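            # Shapley kernel weight |S|! * (n - |S| - 1)! / n!, computed in log
            # space for stability, times 2**(n-1) to correct for drawing
            # coalitions uniformly instead of from the Shapley distribution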
binomial = np.exp(
logfactorial(S)
+ logfactorial(len(xs) - S - 1)
- logfactorial(len(xs))
+ (len(xs) - 1) * np.log(2)
)
# todo: more principled way to do this
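            # skip most low-weight coalitions and reweight the survivors by
            # 1 / (1 - reject_prob) so the weighted mean stays unbiased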
reject_prob = 0.9 if binomial < 1 else 0.75 if binomial < 10 else 0
random.seed(f"{i}_{j}_reject")
if random.random() < reject_prob:
continue
weight = 1 / (1 - reject_prob)
delta_value = value_fn([x for x, m in zip(xs, keep_with) if m]) - value_fn(
[x for x, m in zip(xs, keep_without) if m]
)
shap[i].append((delta_value * binomial, weight))
# weighted mean of shap
return [
sum(delta_value * weight for delta_value, weight in shap[i])
/ sum(weight for _, weight in shap[i])
if len(shap[i])
else 0
for i in range(len(xs))
]
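

# Quick sanity check (a sketch, not part of the original gist): for an additive
# value function, each element's Shapley value is just its own contribution.
#
#   vals = compute_shapley([1.0, 2.0, 3.0], value_fn=sum, iters_each=200)
#   # expect approximately [1.0, 2.0, 3.0]
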
oa_value_cache = {}


def oa_value(
engine: str,
tokens: List[int],
masked_positions: List[int],
mask_fn: Callable[[int, tiktoken.Encoding], int],
tokenizer,
api_base=None,
) -> float:
"""
Compute the logprob on the final token when the given positions are masked. Caches results for all
intermediate tokens, which lets us get shapley values at all intermediate tokens basically for free.
:param engine: api engine
:param tokens: list of tokens
:param masked_positions: list of positions to mask
:param mask_fn: function that takes a token and returns a masked token (useful for custom masking)
:param tokenizer: tokenizer
:param api_base: api_base for openai api
"""
orig_tokens = tokens
tokens = tokens.copy()
for pos in masked_positions:
tokens[pos] = mask_fn(tokens[pos], tokenizer)
cache_key = (engine, api_base)
if cache_key + tuple(tokens) in oa_value_cache:
return oa_value_cache[cache_key + tuple(tokens)]
res = oa_completion(
engine=engine, api_base=api_base, prompt=tuple(tokens), max_tokens=0, echo=True, logprobs=10
)
logprobs = res["choices"][0]["logprobs"]["token_logprobs"]
assert len(logprobs) == len(tokens)
assert logprobs[0] is None
# todo: implement with something efficient like a trie
for i in range(1, len(tokens)):
oa_value_cache[cache_key + tuple(tokens[: i + 1])] = logprobs[i]
# if the unmasked version is in logprobs in res, then cache that as well
# this is necessary for the efficient caching trick because we always
# want to record the unmasked logprob for the final token
if tokenizer.decode([orig_tokens[i]]) in res["choices"][0]["logprobs"]["top_logprobs"][i]:
oa_value_cache[cache_key + tuple(tokens[:i]) + (orig_tokens[i],)] = res["choices"][0][
"logprobs"
]["top_logprobs"][i][tokenizer.decode([orig_tokens[i]])]
return oa_value_cache[cache_key + tuple(tokens)]


@lru_cache(maxsize=None)
def logfactorial(x):
    # log(x!) via the log-gamma function (avoids overflow for large x)
    return scipy.special.gammaln(x + 1)


@lru_cache(maxsize=None)
def oa_completion(**kwargs):
"""Query OpenAI API for completion.
Retry with back-off until they respond
"""
import openai
import openai.error
openai.api_key = os.getenv("OPENAI_API_KEY")
default_api_base = "https://api.openai.com/v1"
openai.api_base = kwargs.pop("api_base", default_api_base) or default_api_base
backoff_time = 3
while True:
try:
return openai.Completion.create(**kwargs)
except openai.error.OpenAIError:
import traceback
traceback.print_exc()
time.sleep(backoff_time)
backoff_time *= 1.5
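

# End-to-end usage sketch (not part of the original gist; assumes a legacy
# completions engine that supports echo + logprobs, e.g. "davinci", and an
# OPENAI_API_KEY in the environment):
#
# if __name__ == "__main__":
#     enc = tiktoken.get_encoding("gpt2")
#     toks = enc.encode("12 + 34 = 46")
#     # mask only the tokens that contain digits
#     mask_positions = [
#         i for i, t in enumerate(toks) if any(c.isdigit() for c in enc.decode([t]))
#     ]
#     shapley = compute_shapley_2d("davinci", enc, toks, mask_positions, iters_each=100)
#     print(shapley)  # column i: contributions to the token at mask_positions[i]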