-
-
Save leogao2/ef6afb5530eaf948b8602faae961fda0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import lru_cache | |
import os | |
import time | |
from typing import Callable, List | |
import numpy as np | |
import scipy.special | |
from tqdm import tqdm | |
import tiktoken | |
def compute_shapley_2d(
    engine, tokenizer, toks, mask_positions, iters_each: int = 1000, api_base=None
):
    """
    Compute Shapley values of logprob contribution of each token on each future token.

    Entry [j, i] of the result is the estimated Shapley value of the token at
    mask_positions[j] on the logprob of the token at mask_positions[i], for j < i.
    Column 0 is left as NaN; entries with j >= i (for i >= 1) are 0.

    :param engine: api engine
    :param tokenizer: tokenizer
    :param toks: list of tokens
    :param mask_positions: list of positions to mask (useful for only masking some tokens, like only numbers)
    :param iters_each: number of iterations to run for each token
    :param api_base: api_base for openai api
    :return: 2d array of shapley values, shape (len(mask_positions), len(mask_positions))
    """
    size = len(mask_positions)
    grid = np.full((size, size), np.nan)
    # Visit outputs from the last one backwards: the longest prefix is scored
    # first, and oa_value caches the logprob of every intermediate position,
    # so the shorter-prefix queries that follow can be served from that cache.
    for col in reversed(range(1, size)):
        players = mask_positions[:col]
        target = mask_positions[col]

        # Default arguments pin this iteration's values into the closure.
        def coalition_value(kept, players=players, target=target):
            ablated = [pos for pos in players if pos not in kept]
            return oa_value(
                engine,
                toks[: target + 1],
                ablated,
                underscore_ablate,
                tokenizer,
                api_base=api_base,
            )

        estimates = compute_shapley(
            xs=players, value_fn=coalition_value, iters_each=iters_each
        )
        # Rows col..size-1 have no estimate for this output; they are zero-filled.
        column = np.zeros(size)
        column[: len(estimates)] = estimates
        grid[:, col] = column
    return grid
def underscore_ablate(tok, tokenizer):
    """
    Return the token id used to replace ``tok`` when it is ablated (masked).

    The replacement is the first token of ``" _"`` if the decoded text of ``tok``
    starts with a space, otherwise the first token of ``"_"``. A token that
    decodes to an empty string is treated as having no leading space.

    :param tok: token id to ablate
    :param tokenizer: object with ``decode(list_of_ids) -> str`` and ``encode(str) -> list_of_ids``
    :return: replacement token id
    """
    # NOTE: for GPT4, spaces are not merged into digits! so this results in out of distribution [" ", "_"]
    # startswith() instead of [0] so an empty decoded string does not raise IndexError.
    replacement = " _" if tokenizer.decode([tok]).startswith(" ") else "_"
    return tokenizer.encode(replacement)[0]
def compute_shapley(
    xs: list, value_fn: Callable[[list], float], iters_each: int = 1000
) -> List[float]:
    """
    Estimate Shapley values for a given list of objects by Monte Carlo sampling.

    For each element i, ``iters_each`` coalitions of the other elements are drawn
    with every element kept independently with probability 1/2. The marginal
    contribution ``value_fn(S + [i]) - value_fn(S)`` is multiplied by the
    importance weight |S|! (n-|S|-1)! / n! * 2**(n-1), which converts the uniform
    coalition distribution into the Shapley weighting. Samples with a small
    importance weight are randomly skipped to save ``value_fn`` calls. Kept
    samples get weight 1 / (1 - reject_prob), and the result is the weighted
    mean of the kept samples.

    Sampling is deterministic: each (i, j) sample uses its own string-seeded
    ``random.Random`` instance, so the global ``random`` state is left untouched.

    :param xs: list of objects
    :param value_fn: function that takes a list of objects (a sub-list of xs, in xs order) and returns a float
    :param iters_each: number of coalitions sampled per element, before rejection
    :return: one estimate per element of xs; 0 for an element whose samples were all rejected
    """
    # todo: calculate iters_each with proper bounds
    import random

    n = len(xs)
    shap = [[] for _ in range(n)]
    with tqdm(total=iters_each * n) as pbar:
        for i in range(n):
            for j in range(iters_each):
                pbar.update(1)
                # A local generator seeded with the same string yields the same
                # stream that random.seed() would, without clobbering global state.
                rng = random.Random(f"{i}_{j}")
                # masked -> not shown to model
                keep = [rng.random() < 0.5 for _ in range(n)]
                keep_with = keep.copy()
                keep_with[i] = True
                keep_without = keep.copy()
                keep_without[i] = False
                S = sum(keep_without)
                # |S|! (n-|S|-1)! / n! divided by the sampling probability 2**-(n-1)
                binomial = np.exp(
                    logfactorial(S)
                    + logfactorial(n - S - 1)
                    - logfactorial(n)
                    + (n - 1) * np.log(2)
                )
                # todo: more principled way to do this
                reject_prob = 0.9 if binomial < 1 else 0.75 if binomial < 10 else 0
                if random.Random(f"{i}_{j}_reject").random() < reject_prob:
                    continue
                # Up-weight survivors so skipping does not bias the weighted mean.
                weight = 1 / (1 - reject_prob)
                delta_value = value_fn(
                    [x for x, m in zip(xs, keep_with) if m]
                ) - value_fn([x for x, m in zip(xs, keep_without) if m])
                shap[i].append((delta_value * binomial, weight))
    # weighted mean of shap
    return [
        sum(delta_value * weight for delta_value, weight in samples)
        / sum(weight for _, weight in samples)
        if samples
        else 0
        for samples in shap
    ]
# Module-level memo used by oa_value: maps (engine, api_base, *token_ids) to the
# logprob of the last token of that token sequence.
oa_value_cache = dict()
def oa_value(
    engine: str,
    tokens: List[int],
    masked_positions: List[int],
    mask_fn: Callable[[int, tiktoken.Encoding], int],
    tokenizer,
    api_base=None,
) -> float:
    """
    Compute the logprob on the final token when the given positions are masked. Caches results for all
    intermediate tokens, which lets us get shapley values at all intermediate tokens basically for free.

    A single echo request (max_tokens=0, logprobs=10) scores every position of the masked sequence.
    For each position i >= 1, oa_value_cache receives the logprob of the token actually sent after its
    masked prefix. If position i was masked and the original token appears among the returned top-10
    alternatives, the cache also receives the logprob of the original token after that same masked
    prefix. That second entry is what a later query ending in the unmasked token needs.

    :param engine: api engine
    :param tokens: list (or tuple) of tokens; it is not modified
    :param masked_positions: list of positions to mask
    :param mask_fn: function that takes a token and the tokenizer and returns a masked token (useful for custom masking)
    :param tokenizer: tokenizer
    :param api_base: api_base for openai api
    :return: logprob of the final token given the (masked) prefix
    :raises ValueError: if fewer than two tokens are given (the first token never has a logprob)
    :raises RuntimeError: if the API response does not contain one logprob per prompt token with None first
    """
    if len(tokens) < 2:
        # Cache keys always span >= 2 tokens, so this would otherwise end in a KeyError after an API call.
        raise ValueError("oa_value needs at least two tokens to score the final token")
    orig_tokens = tokens
    masked = list(tokens)
    for pos in masked_positions:
        masked[pos] = mask_fn(masked[pos], tokenizer)
    cache_key = (engine, api_base)
    full_key = cache_key + tuple(masked)
    if full_key in oa_value_cache:
        return oa_value_cache[full_key]
    res = oa_completion(
        engine=engine, api_base=api_base, prompt=tuple(masked), max_tokens=0, echo=True, logprobs=10
    )
    lp = res["choices"][0]["logprobs"]
    logprobs = lp["token_logprobs"]
    top_logprobs = lp["top_logprobs"]
    if len(logprobs) != len(masked) or logprobs[0] is not None:
        raise RuntimeError(
            "unexpected logprobs in API response: expected one entry per prompt token, first one None"
        )
    # todo: implement with something efficient like a trie
    for i in range(1, len(masked)):
        prefix = cache_key + tuple(masked[:i])
        oa_value_cache[prefix + (masked[i],)] = logprobs[i]
        if orig_tokens[i] == masked[i]:
            # Unmasked position: the exact logprob was just recorded; do not overwrite it
            # with a top_logprobs entry looked up by decoded text (texts can collide).
            continue
        # if the unmasked version is in logprobs in res, then cache that as well
        # this is necessary for the efficient caching trick because we always
        # want to record the unmasked logprob for the final token
        orig_text = tokenizer.decode([orig_tokens[i]])
        alternatives = top_logprobs[i]
        if alternatives and orig_text in alternatives:
            oa_value_cache[prefix + (orig_tokens[i],)] = alternatives[orig_text]
    return oa_value_cache[full_key]
@lru_cache(maxsize=None)
def logfactorial(x):
    """Return log(x!) as the log-gamma function evaluated at x + 1.

    Memoized with an unbounded ``lru_cache``, so ``x`` must be hashable.
    """
    shifted = x + 1
    return scipy.special.gammaln(shifted)
@lru_cache(maxsize=None)
def oa_completion(**kwargs):
    """Query OpenAI API for completion.

    Transient failures are retried with exponential back-off, capped at 60 s
    between attempts. Errors caused by the request itself are re-raised
    immediately: InvalidRequestError, AuthenticationError and PermissionError.
    Results are memoized by ``lru_cache`` on the exact keyword arguments, so every
    value must be hashable (e.g. pass ``prompt`` as a tuple).

    :param api_base: optional API base URL; absent or None means the public OpenAI endpoint
    :param kwargs: all other keyword arguments are forwarded to ``openai.Completion.create``
    :return: the raw ``openai.Completion.create`` response
    """
    import openai
    import openai.error

    openai.api_key = os.getenv("OPENAI_API_KEY")
    default_api_base = "https://api.openai.com/v1"
    openai.api_base = kwargs.pop("api_base", default_api_base) or default_api_base
    # Retrying an identical bad request can never succeed; surface these instead of looping forever.
    non_retryable = (
        openai.error.InvalidRequestError,
        openai.error.AuthenticationError,
        openai.error.PermissionError,
    )
    backoff_time = 3
    max_backoff_time = 60
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except non_retryable:
            raise
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time = min(backoff_time * 1.5, max_backoff_time)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment