# GitHub gist by @alexirpan (last active April 1, 2021 07:49).
# Gist ID: 54ac855db7e0d017656645ef1475ac08
# Code for binary tree environment
r"""Define a tree env like so.

       1
     /   \
    2     3
   / \   / \
  4   5 6   7

where leaves are terminal and the binary tree is full.
States are numbered from 1 so that the children of state x are 2*x and
2*x + 1; numbering from 0 would make this indexing awkward.
A tree is represented by a numpy array, where the index is the state ID, and the
value in the array is the reward (which is always 0 or 1).
Initial state distribution = uniform over any non-terminal state.
"""
# pylint: disable=redefined-outer-name
# pylint: disable=invalid-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
# Actions are 0 and 1: next_state moves from state s to child 2*s + a.
NUM_ACTIONS = 2
# Controls how stochastic the tree is.
# With this probability, the action taken is flipped (see next_state).
# 0.5 for full randomness: the child reached is then equally likely to be
# either one, whatever action was chosen.
# NOTE: in the paper, epsilon = 2 * STOCHASTICITY, because the paper defines
# it as probability of random action. Here, it's the probability of flipping
# the action. Defining it this way makes the coding easier.
# The __main__ block reassigns this global to sweep several values.
STOCHASTICITY = 0.2
def next_state(s, a, tree, stochasticity=None):
    """Samples the successor of state ``s`` after taking action ``a``.

    With probability ``stochasticity`` the action is flipped (0 <-> 1) before
    moving to child ``2*s + a``.

    Args:
      s: current state ID; expected to be non-terminal.
      a: action, 0 (child 2*s) or 1 (child 2*s + 1).
      tree: reward array. Not needed to pick the successor; kept so the
        signature stays unchanged for existing callers.
      stochasticity: action-flip probability. Defaults to the module-level
        STOCHASTICITY, looked up at call time so reassigning the global still
        takes effect.

    Returns:
      The successor state ID.
    """
    if stochasticity is None:
        stochasticity = STOCHASTICITY
    if np.random.rand() < stochasticity:
        # Flip action.
        a = 1 - a
    return 2 * s + a
def terminal(s, tree):
    """Returns whether state ``s`` is a leaf of ``tree``.

    The children of ``s`` live at indices ``2*s`` and ``2*s + 1``, so ``s`` is
    a leaf exactly when its left-child index is past the end of the array.
    """
    return len(tree) <= 2 * s
def eval_policy(q_function, s, tree):
    """Rolls out the greedy (argmax-Q) policy from ``s`` once.

    Transitions are sampled with next_state, so the outcome is random whenever
    the action-flip probability is non-zero.

    Returns:
      The reward stored at the terminal state reached.
    """
    state = s
    while True:
        if terminal(state, tree):
            return tree[state]
        greedy_action = np.argmax(q_function[state])
        state = next_state(state, greedy_action, tree)
def make_tree(levels, num_bad=1):
    r"""Makes a tree, num_bad leaves have reward 0 and the rest have reward 1.

    For levels = 3 the result is a length-8 array; index 0 is unused and
    indices 1..7 are the states:

           1
         /   \
        2     3
       / \   / \
      4   5 6   7

    Interior states always hold 0. The ``num_bad`` zero-reward leaves are
    drawn uniformly without replacement, so they are scattered at random
    rather than piled into one subtree.
    """
    size = 2 ** levels
    first_leaf = 2 ** (levels - 1)
    tree = np.zeros(size)
    leaves = range(first_leaf, size)
    bad_leaves = set(np.random.choice(leaves, size=num_bad, replace=False))
    # Every leaf starts as a success, then the sampled bad ones are zeroed.
    tree[first_leaf:] = 1.0
    for leaf in bad_leaves:
        tree[leaf] = 0.0
    return tree
def get_return(q_function, tree):
    """Estimates the greedy policy's return with one rollout per start state.

    Start states are every non-terminal state. A single rollout each is noisy
    in stochastic trees; prefer get_exact_return for an exact value.

    Returns:
      Mean leaf reward over the rollouts.
    """
    # Non-terminal states form a prefix of 1, 2, 3, ..., so stop at the first
    # leaf.
    start_states = []
    for s in range(1, len(tree)):
        if terminal(s, tree):
            break
        start_states.append(s)
    successes = sum(eval_policy(q_function, s, tree) for s in start_states)
    return float(successes) / len(start_states)
def get_exact_return(q_function, tree, stochasticity=None):
    """Computes the exact expected return of the greedy policy via DP.

    Start states are uniform over non-terminal states, so the result is the
    mean of the per-state expected returns.

    Args:
      q_function: table indexed as q_function[s]; the policy takes
        np.argmax(q_function[s]) at state s.
      tree: reward array (see make_tree). Presumably a power-of-two length, as
        make_tree produces, so both children of every non-terminal state are
        in range.
      stochasticity: action-flip probability. Defaults to the module-level
        STOCHASTICITY, looked up at call time.

    Returns:
      Expected return averaged over all non-terminal start states.
    """
    if stochasticity is None:
        stochasticity = STOCHASTICITY
    p_a = 1 - stochasticity
    p_nota = stochasticity
    n = len(tree)
    returns = [0] * n  # return if initialized from this state.
    # Children (2*s, 2*s + 1) have larger IDs than s, so one sweep from the
    # last valid ID n - 1 down to 1 fills both children before their parent.
    for s in range(n - 1, 0, -1):
        if terminal(s, tree):
            returns[s] = tree[s]
            continue
        a = np.argmax(q_function[s])
        returns[s] = p_a * returns[2 * s + a] + p_nota * returns[2 * s + (1 - a)]
    # Average over all non-terminals.
    values = [returns[s] for s in range(1, n) if not terminal(s, tree)]
    return sum(values) / len(values)
def test(q_function, tree):
    """Prints the exact return next to Monte Carlo means of get_return.

    The printed means are averages over the first 10, 100 and all 1000
    single-rollout estimates.
    """
    exact = get_exact_return(q_function, tree)
    samples = []
    for _ in range(1000):
        samples.append(get_return(q_function, tree))
    means = [sum(samples[:n]) / float(n) for n in (10, 100, 1000)]
    print(exact, *means)
def get_dataset(tree, NUM_EPS=1000):
    """Generate dataset of NUM_EPS episodes from uniform behavior policy.

    Start states cycle deterministically through the non-terminal states
    instead of being sampled, to reduce variance in the generated episodes.

    Args:
      tree: reward array (see make_tree).
      NUM_EPS: number of episodes to generate.

    Returns:
      A tuple (sequences, positive_sequences, rewards):
        sequences: one list of (state, action) pairs per episode.
        positive_sequences: the episodes whose final reward is 1.0.
        rewards: final leaf reward of each episode, aligned with sequences.
    """
    sequences = []
    positive_sequences = []
    rewards = []
    nonterminal = [s for s in range(1, len(tree)) if not terminal(s, tree)]
    # `range`, not the Python-2-only `xrange`, so this also runs on Python 3.
    for i in range(NUM_EPS):
        s = nonterminal[i % len(nonterminal)]
        seq = []
        while not terminal(s, tree):
            a = np.random.randint(NUM_ACTIONS)  # uniformly random action
            seq.append((s, a))
            s = next_state(s, a, tree)
        sequences.append(seq)
        rewards.append(tree[s])
        if tree[s] == 1.0:
            positive_sequences.append(seq)
    return sequences, positive_sequences, rewards
def opc(positives, all_values, prior=1.0):
    """Computes OPC score given list of positive Q(s,a) and all Q(s,a).

    The score is the minimum over thresholds b of

        E_all[Q > b] - prior * E_pos[Q > b],

    where each expectation is the empirical fraction of the respective list
    above b. With b below every point the score is 1 - prior. The result is
    not normalized to lie in [0, 1].

    Args:
      positives: Q-values of transitions from positive episodes (these also
        appear in ``all_values``).
      all_values: Q-values of all transitions.
      prior: weight on the positive term.

    Returns:
      The minimum score over all thresholds.
    """
    # step[x] = total amount the score drops when b moves from just below x
    # to just above it. Repeated Q-values share one entry, which matters in
    # the tree env because there are only finitely many (s, a) pairs.
    step = collections.defaultdict(float)
    for x in positives:
        step[x] += -prior / len(positives)
    for x in all_values:
        step[x] += 1.0 / len(all_values)

    score = 1.0 - prior
    lowest = score
    for _, drop in sorted(step.items(), key=lambda item: item[0]):
        score -= drop
        lowest = min(score, lowest)
    return lowest
def opc_score(dataset, q_function, prior=1.0):
    """Computes the OPC score of ``q_function`` on a get_dataset-style dataset.

    Every (state, action) step is scored with q_function[s, a]. The positive
    list comes from the positive episodes and the full list from all episodes.
    """
    episodes, positive_episodes, _ = dataset
    every_q = [q_function[s, a] for episode in episodes for s, a in episode]
    positive_q = [q_function[s, a]
                  for episode in positive_episodes for s, a in episode]
    return opc(positive_q, every_q, prior)
def sample_qfunction(levels):
    """Generates random Q-function from U[0,1].

    Returns:
      Array of shape (2 ** levels, 2): one row per state ID, matching
      make_tree's layout (row 0 unused), and one column per action.
    """
    num_states = 2 ** levels
    return np.random.uniform(low=0.0, high=1.0, size=(num_states, 2))
def soft_opc(dataset, q_function, prior=1.0):
    """Computes SoftOPC score.

    Each episode is summarized by its mean Q(s, a) over its steps. The score is
    prior * (mean over positive episodes) - (mean over all episodes).
    """
    episodes, positive_episodes, _ = dataset

    def episode_mean_q(episode):
        # Per-episode average Q-value, summed in step order.
        total = sum(q_function[s, a] for s, a in episode)
        return total / float(len(episode))

    pos_mean = np.mean([episode_mean_q(ep) for ep in positive_episodes])
    all_mean = np.mean([episode_mean_q(ep) for ep in episodes])
    return prior * pos_mean - all_mean
# This implementation is just for undiscounted, 0/1 reward case.
# Bellman, disc sum adv, MCC can be implemented in other cases with other code.
def bellman(dataset, q_function):
    """Average TD Error.

    Squared TD errors are averaged within each episode, then across episodes.
    Intermediate rewards are always 0, so the TD target for a step is the next
    step's Q-value. For the final step the target is the episode's 0/1 reward.
    """
    episodes, _, final_rewards = dataset
    per_episode = []
    for episode, reward in zip(episodes, final_rewards):
        q = [q_function[s, a] for s, a in episode]
        targets = q[1:] + [reward]
        squared = [(q_t - y) ** 2 for q_t, y in zip(q, targets)]
        per_episode.append(np.mean(squared))
    return np.mean(per_episode)
def advantages(seq, q_function):
    """Helper function for computing advantage.

    For each (s, a) step, the advantage is Q(s, a) - max_a' Q(s, a'), i.e. Q
    minus the greedy value V(s).

    Returns:
      A numpy array with one advantage per step of ``seq``.
    """
    return np.array([q_function[s, a] - q_function[s].max() for s, a in seq])
def disc_sum_adv(dataset, q_function):
    """Discounted sum of advantages, discount=1.

    Every suffix of every episode is treated as its own sample. This mimics
    starting from each visited state and gives a lower-variance estimate.
    """
    episodes, _, _ = dataset
    suffix_sums = []
    for episode in episodes:
        adv = advantages(episode, q_function)
        suffix_sums.extend(adv[start:].sum() for start in range(len(adv)))
    return np.mean(suffix_sums)
def mcc(dataset, q_function):
    """Monte Carlo corrected error.

    For step t of an episode with final reward R, the target for Q(s_t, a_t)
    is R minus the sum of the advantages of the steps after t. The squared
    error is averaged over every step of every episode.
    """
    episodes, _, final_rewards = dataset
    squared_errors = []
    for episode, reward in zip(episodes, final_rewards):
        adv = advantages(episode, q_function)
        for t, (s, a) in enumerate(episode):
            later_adv = adv[t + 1:].sum()
            squared_errors.append((reward - later_adv - q_function[s, a]) ** 2)
    return np.mean(squared_errors)
def correlations(returns, metrics):
    """Measures how well ``metrics`` tracks ``returns``.

    Returns:
      (R^2, rho): the squared Pearson correlation coefficient (R^2, not r) and
      the Spearman rank correlation.
    """
    r = scipy.stats.pearsonr(returns, metrics)[0]
    rho = scipy.stats.spearmanr(returns, metrics).correlation
    return r ** 2, rho
def ranks(returns, metrics, NUM_FUNCS=None):
    """Reports fraction of pairs with incorrect rank, unused.

    A pair (i, j) with i < j is comparable when returns[i] != returns[j]. It
    is an error when "metrics[i] < metrics[j]" disagrees with
    "returns[i] < returns[j]". Disagreement is counted in both directions.

    Args:
      returns: true return of each Q-function.
      metrics: metric value of each Q-function.
      NUM_FUNCS: number of leading entries to compare; defaults to
        len(returns).

    Returns:
      errors / comparable pairs, or 0.0 when no pair is comparable.
    """
    if NUM_FUNCS is None:
        NUM_FUNCS = len(returns)
    tot = 0
    errors = 0
    for i in range(NUM_FUNCS - 1):
        for j in range(i + 1, NUM_FUNCS):
            # if equal, don't care
            if returns[i] == returns[j]:
                continue
            tot += 1
            if (metrics[i] < metrics[j]) != (returns[i] < returns[j]):
                errors += 1
    if tot == 0:
        return 0.0
    return float(errors) / tot
if __name__ == '__main__':
    # Ensure repeatable results.
    np.random.seed(123)
    LEVELS = 6
    # State numbering (first four levels shown; deeper levels continue the
    # pattern, with the children of x at 2*x and 2*x + 1):
    #                  1
    #           /             \
    #          2               3
    #        /   \           /   \
    #       4     5         6     7
    #      / \   / \       / \   / \
    #     8   9 10  11    12 13 14  15
    #
    ## 1 success leaf.
    NUM_BAD = 2 ** (LEVELS - 1) - 1
    ## 1 failure leaf.
    # NUM_BAD = 1
    NUM_FUNCS = 1000
    tree = make_tree(levels=LEVELS, num_bad=NUM_BAD)
    ## Default plot.
    # `range` rather than the Python-2-only `xrange`, so the script also runs
    # on Python 3.
    q_functions = [sample_qfunction(LEVELS) for _ in range(NUM_FUNCS)]
    ## Q-functions of different magnitude. Use with 1 success leaf.
    # q_functions = []
    # for i in range(1, NUM_FUNCS+1):
    #     q_functions.append(i * sample_qfunction(LEVELS))
    ## Big Q-functions. Use with 1 success leaf.
    # q_functions = [1000 * sample_qfunction(LEVELS) for _ in range(NUM_FUNCS)]
    for attempt in [0, 0.1, 0.2, 0.3, 0.4]:
        # Rebinds the module-level flip probability that next_state and
        # get_exact_return read.
        STOCHASTICITY = attempt
        dataset = get_dataset(tree, NUM_EPS=1000)
        returns = [get_exact_return(q_func, tree) for q_func in q_functions]
        print('Return of 1st Q-func for debugging', returns[0])
        # Baselines
        bellman_score = [bellman(dataset, q_func) for q_func in q_functions]
        disc_sum_adv_score = [disc_sum_adv(dataset, q_func) for q_func in q_functions]
        mcc_score = [mcc(dataset, q_func) for q_func in q_functions]
        print('Flip probability', STOCHASTICITY)
        print('Bad states', NUM_BAD)
        print('Bellman', correlations(returns, bellman_score))
        print('disc_sum_adv', correlations(returns, disc_sum_adv_score))
        print('mcc', correlations(returns, mcc_score))
        # New ones
        #priors = np.arange(1+20) / 20.0
        priors = [1.0]
        soft_opc_rank_scores = []
        opc_rank_scores = []
        for prior in priors:
            soft_opcs = [soft_opc(dataset, q_func, prior) for q_func in q_functions]
            # min opc -> max -opc, easier for ranking
            opcs = [-opc_score(dataset, q_func, prior) for q_func in q_functions]
            # Pearson, Spearman
            soft_opc_r2, soft_opc_corr = correlations(returns, soft_opcs)
            opc_r2, opc_corr = correlations(returns, opcs)
            print('Prior %f, SoftOPC correlations' % prior, soft_opc_r2, soft_opc_corr)
            print('Prior %f, OPC correlations' % prior, opc_r2, opc_corr)
            soft_opc_rank_scores.append(soft_opc_corr)
            opc_rank_scores.append(opc_corr)
        #plt.show()
        # Disabled plotting code, kept as an unused string literal (a no-op).
        """
        plt.plot(priors, soft_opc_rank_scores)
        plt.plot(priors, opc_rank_scores)
        plt.rc('text', usetex=True)
        # If U[0,k], best possible total max is 1000 * 1001 / 2 = about 500 * 1000.
        # If U[0,1000], best possible total max is 1000 * 1000.
        # We expect max_{s,a} Q(s,a) > 500 so this threshold is enough to figure out
        # correct plot title.
        if sum(q_func.max() for q_func in q_functions) > (1000 * 1001 / 2):
            plt.title('Binary Tree, %d Levels, %d Success State\nQ(s,a) ~ U[0,1000]' % (
                LEVELS, 2 ** (LEVELS - 1) - NUM_BAD))
        elif max(q_func.max() for q_func in q_functions) > 1:
            plt.title('Binary Tree, %d Levels, %d Success State\nQ(s,a) ~ U[0,k]' % (
                LEVELS, 2 ** (LEVELS - 1) - NUM_BAD))
        elif NUM_BAD == 1:
            plt.title('Binary Tree, Flip Prob %f, %d Levels, %d Fail State' % (STOCHASTICITY, LEVELS, NUM_BAD))
        elif NUM_BAD == 2 ** (LEVELS - 1) - 1:
            plt.title('Binary Tree, Flip Prob %f, %d Levels, %d Success State' % (STOCHASTICITY, LEVELS, 2 ** (LEVELS - 1) - NUM_BAD))
        else:
            plt.title('Binary Tree, Flip Prob %f, %d Levels, %d Fail States' % (STOCHASTICITY, LEVELS, NUM_BAD))
        plt.xlabel('Prior $p(y=1)$')
        plt.ylabel('Spearman Correlation To Episode Return')
        plt.legend(['SoftOPC', 'OPC'])
        bs = correlations(returns, bellman_score)[1]
        plt.axhline(y=bs, linestyle='--', color='gray')
        plt.text(x=0.75, y=bs + 0.007, s='TD Error')
        ds = correlations(returns, disc_sum_adv_score)[1]
        plt.axhline(y=ds, linestyle='--', color='gray')
        plt.text(x=0.75, y=ds + 0.007, s='Sum Advantages')
        ms = correlations(returns, mcc_score)[1]
        plt.axhline(y=ms, linestyle='--', color='gray')
        plt.text(x=0.75, y=ms + 0.007, s='MCC Error')
        plt.show()
        """
# (End of gist. GitHub page footer: sign-up prompt to join the comment thread.)