Code for binary tree environment
r"""Define a tree env like so. | |
1 | |
/ \ | |
2 3 | |
/ \ / \ | |
4 5 6 7 | |
where leaves are terminal and the binary tree is full. | |
Starting from 1 lets us do 2*x, 2*x+1. From 0 is annoying. | |
A tree is represented by a numpy array, where the index is the state ID, and the | |
value in the array is the reward (which is always 0 or 1). | |
Initial state distribution = uniform over any non-terminal state. | |
""" | |
# pylint: disable=redefined-outer-name | |
# pylint: disable=invalid-name | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import collections | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import scipy.stats | |
NUM_ACTIONS = 2 | |
# Controls how stochastic the tree is. | |
# With this probability, the action taken is flipped. | |
# 0.5 for full randomness. | |
# NOTE: in the paper, epsilon = 2 * STOCHASTICITY, because the paper defines | |
# it as probability of random action. Here, it's the probability of flipping | |
# the action. Defining it this way makes the coding easier. | |
STOCHASTICITY = 0.2 | |


def next_state(s, a, tree):
  """Steps the env: from state s, action a leads to child 2 * s + a, with the
  action flipped with probability STOCHASTICITY. (tree is unused here.)"""
  if np.random.rand() < STOCHASTICITY:
    # Flip action.
    a = 1 - a
  next_ind = 2 * s + a
  return next_ind


def terminal(s, tree):
  """A state is terminal iff its children 2*s, 2*s + 1 fall outside the array."""
  next_ind = 2 * s
  return next_ind >= len(tree)
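
# Example (illustrative): with a 3-level tree, the array has length 8 and
# states 4..7 are the leaves.
#   tree = np.zeros(8)
#   terminal(4, tree)        # True: 2 * 4 = 8 is out of range.
#   terminal(3, tree)        # False: the children are states 6 and 7.
#   next_state(3, 0, tree)   # 6 with prob 1 - STOCHASTICITY, else 7.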


def eval_policy(q_function, s, tree):
  """Rolls out the argmax policy of q_function from state s, returns the reward."""
  while not terminal(s, tree):
    a = np.argmax(q_function[s])
    s = next_state(s, a, tree)
  return tree[s]


def make_tree(levels, num_bad=1):
  r"""Makes a tree, num_bad leaves have reward 0 and the rest have reward 1."""
  # 3 levels = (0) 1 2 3 4 5 6 7 = length 8 array, corresponding to tree
  #
  #      1
  #     / \
  #    2   3
  #   / \ / \
  #  4  5 6  7
  #
  # Don't pile all the bad leaves into one subtree, that's not interesting.
  # Distribute them randomly.
  tree = np.zeros(2 ** levels)
  leaves = range(2 ** (levels - 1), 2 ** levels)
  bad_leaves = set(np.random.choice(leaves, size=num_bad, replace=False))
  for i in leaves:
    if i not in bad_leaves:
      tree[i] = 1.0
    else:
      tree[i] = 0.0
  return tree
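
# Example (illustrative): make_tree(levels=3, num_bad=1) returns a length-8
# array whose last four entries are the leaf rewards, e.g. tree[4:] could be
# array([1., 0., 1., 1.]) -- exactly one randomly chosen leaf is bad.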


def get_return(q_function, tree):
  """Gets return of the argmax policy of the Q-function."""
  # Estimates the expected return of the argmax policy with a single rollout
  # per initial state, so it's noisy in stochastic envs. Prefer
  # get_exact_return.
  total = 0
  count = 0
  for s in range(1, len(tree)):
    if terminal(s, tree):
      break
    total += 1
    count += eval_policy(q_function, s, tree)
  return float(count) / total


def get_exact_return(q_function, tree):
  """Computes the true expected return of the argmax policy with DP."""
  # The initial state is uniformly random over all non-terminal states.
  returns = [0] * len(tree)  # Return if initialized from this state.
  # Init DP: a terminal state is worth its reward.
  for s in range(1, len(tree)):
    if terminal(s, tree):
      returns[s] = tree[s]
  # Iterate backwards over the non-terminal states.
  for s in range(len(tree) - 1, 0, -1):
    if terminal(s, tree):
      continue
    a = np.argmax(q_function[s])
    p_a = 1 - STOCHASTICITY
    p_nota = STOCHASTICITY
    returns[s] = p_a * returns[2*s + a] + p_nota * returns[2*s + (1 - a)]
  # Average over all non-terminal states.
  values = []
  for s in range(1, len(tree)):
    if terminal(s, tree):
      break
    values.append(returns[s])
  return sum(values) / len(values)
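
# Example (illustrative): in a 2-level tree with leaf rewards tree[2] and
# tree[3], a greedy policy that picks action a in state 1 has exact return
# (1 - STOCHASTICITY) * tree[2 + a] + STOCHASTICITY * tree[3 - a], which is
# what the backward recursion above computes level by level.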


def test(q_function, tree):
  """Compares the exact return against Monte Carlo estimates of it."""
  exact = get_exact_return(q_function, tree)
  returns = [get_return(q_function, tree) for _ in range(1000)]
  print(exact, sum(returns[:10]) / 10.0, sum(returns[:100]) / 100.0,
        sum(returns) / 1000.0)


def get_dataset(tree, NUM_EPS=1000):
  """Generates a dataset of NUM_EPS episodes from the uniform behavior policy."""
  sequences = []
  positive_sequences = []
  rewards = []
  # Cycle over the initial states deterministically to reduce variance in the
  # generated episodes.
  nonterminal = [s for s in range(1, len(tree)) if not terminal(s, tree)]
  for i in range(NUM_EPS):
    s = nonterminal[i % len(nonterminal)]
    seq = []
    while not terminal(s, tree):
      a = np.random.randint(NUM_ACTIONS)  # 0 or 1
      seq.append((s, a))
      s = next_state(s, a, tree)
    sequences.append(seq)
    rewards.append(tree[s])
    if tree[s] == 1.0:
      positive_sequences.append(seq)
  return sequences, positive_sequences, rewards
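
# Example (illustrative): each episode is a list of (state, action) pairs from
# a non-terminal start state down to a leaf. With no action flips, the episode
# [(1, 0), (2, 1)] ends in leaf 5 and records reward tree[5]; it is also added
# to positive_sequences iff tree[5] == 1.0.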


def opc(positives, all_values, prior=1.0):
  """Computes OPC score given list of positive Q(s,a) and all Q(s,a)."""
  # OPC score:
  #   E_all[Q > b] - prior * E_pos[Q > b]
  # minimized over the threshold b.
  #
  # |------b-------|
  #
  # Contribution of points ahead of (above) b:
  #   a positive point gives +1/num_all (it also appears in all_values)
  #   and -prior/num_pos; every other point gives +1/num_all.
  #
  # Initial value (b below every point): 1 - prior.
  #
  # This is not normalized to be between 0 and 1.
  num_pos = len(positives)
  num_all = len(all_values)
  # score = contribution if x > b.
  # Positive values will also be inside all_values.
  pos_ = [(x, -prior / num_pos) for x in positives]
  all_ = [(x, 1.0 / num_all) for x in all_values]
  # Make sure each x is unique.
  # Collapse (x, 1), (x, 2), (x, 3) into (x, 6) to make sure the total score is
  # right. Important for the tree env since there are finitely many (s, a).
  unique_points = collections.defaultdict(list)
  for k, v in pos_ + all_:
    unique_points[k].append(v)
  # Aggregate + sort by x.
  reduced_scores = sorted([(k, sum(v)) for k, v in unique_points.items()],
                          key=lambda pair: pair[0])
  # curr_score starts as the score when b is smaller than all points; each step
  # moves b past the next value and subtracts that value's contribution.
  curr_score = 1.0 - prior
  best = curr_score
  for qval, score in reduced_scores:  # pylint: disable=unused-variable
    curr_score -= score
    best = min(curr_score, best)
  return best
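
# Worked example (illustrative): opc(positives=[0.9], all_values=[0.1, 0.9],
# prior=1.0) returns -0.5. The best threshold b sits between 0.1 and 0.9,
# where E_all[Q > b] - prior * E_pos[Q > b] = 0.5 - 1.0 = -0.5.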


def opc_score(dataset, q_function, prior=1.0):
  """Computes OPC score of Q-function over dataset."""
  all_seq, pos_seq, _ = dataset
  all_transitions = []
  pos_transitions = []
  for seq in all_seq:
    for s, a in seq:
      all_transitions.append(q_function[s, a])
  for seq in pos_seq:
    for s, a in seq:
      pos_transitions.append(q_function[s, a])
  return opc(pos_transitions, all_transitions, prior)


def sample_qfunction(levels):
  """Generates a random Q-function with entries from U[0, 1]."""
  len_ = 2 ** levels
  return np.random.uniform(size=(len_, 2))


def soft_opc(dataset, q_function, prior=1.0):
  """Computes SoftOPC score: prior * E_pos[Q] - E_all[Q] over per-episode means."""
  all_qmeans = []
  pos_qmeans = []
  all_seq, pos_seq, _ = dataset
  for seq in all_seq:
    q_total = 0
    for s, a in seq:
      q_total += q_function[s, a]
    all_qmeans.append(q_total / float(len(seq)))
  for seq in pos_seq:
    q_total = 0
    for s, a in seq:
      q_total += q_function[s, a]
    pos_qmeans.append(q_total / float(len(seq)))
  pos_mean = np.mean(pos_qmeans)
  all_mean = np.mean(all_qmeans)
  return prior * pos_mean - all_mean
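
# Example (illustrative numbers): with prior = 1.0, per-episode mean Q-values
# of [0.8, 0.6] for positive episodes and [0.8, 0.6, 0.2, 0.4] over all
# episodes give SoftOPC = 1.0 * 0.7 - 0.5 = 0.2.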


# This implementation only handles the undiscounted, 0/1 reward case.
# Bellman error, disc sum adv, and MCC can be implemented for other cases with
# other code.
def bellman(dataset, q_function):
  """Average TD error."""
  err = []
  all_seq, _, rewards = dataset
  for seq, rew in zip(all_seq, rewards):
    # Q(s0,a0) Q(s1,a1) Q(s2,a2)
    q_values = [q_function[s, a] for s, a in seq]
    # Intermediate TD errors (the reward is always 0 before the end).
    errors = [(q1 - q2) ** 2 for q1, q2 in zip(q_values[:-1], q_values[1:])]
    # Final TD error (target = terminal reward of 0 or 1).
    errors.append((q_values[-1] - rew) ** 2)
    # Average over the episode.
    err.append(np.mean(errors))
  return np.mean(err)
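
# Example (illustrative numbers): an episode with Q-values [0.5, 0.7] and
# final reward 1 contributes the mean of (0.5 - 0.7) ** 2 and (0.7 - 1) ** 2,
# i.e. (0.04 + 0.09) / 2 = 0.065.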


def advantages(seq, q_function):
  """Helper for computing advantages A(s, a) = Q(s, a) - V(s)."""
  # Q(s0,a0) Q(s1,a1) Q(s2,a2)
  q_values = [q_function[s, a] for s, a in seq]
  # V(s0) V(s1) V(s2), where V(s) = max_a Q(s, a).
  values = [q_function[s].max() for s, _ in seq]
  return np.array(q_values) - np.array(values)


def disc_sum_adv(dataset, q_function):
  """Discounted sum of advantages, discount=1."""
  total = []
  all_seq, _, _ = dataset
  for seq in all_seq:
    advant = advantages(seq, q_function)
    # To get a better estimate, take the sum from every start time (equivalent
    # to initializing in a random start state along the trajectory).
    for t in range(len(advant)):
      total.append(advant[t:].sum())
  return np.mean(total)
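
# Example (illustrative numbers): an episode whose advantages are [-0.1, 0.2]
# contributes the partial sums 0.1 (from t=0) and 0.2 (from t=1) to the
# average.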


def mcc(dataset, q_function):
  """Monte Carlo corrected error."""
  errs = []
  all_seq, _, rewards = dataset
  for seq, rew in zip(all_seq, rewards):
    advant = advantages(seq, q_function)
    q_values = np.array([q_function[s, a] for s, a in seq])
    # Error between Q_t and the corrected target
    # [(0/1 reward) - sum of advantages from t+1 onwards].
    for t in range(len(advant)):
      target = rew - advant[t + 1:].sum()
      q_val = q_values[t]
      errs.append((target - q_val) ** 2)
  return np.mean(errs)
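
# Example (illustrative numbers): with reward 1, Q-values [0.5, 0.7] and
# advantages [-0.1, 0.2], the targets are 1 - 0.2 = 0.8 at t=0 and 1 at t=1,
# giving errors (0.8 - 0.5) ** 2 = 0.09 and (1 - 0.7) ** 2 = 0.09.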


def correlations(returns, metrics):
  """Returns (Pearson R^2, Spearman correlation) between returns and metrics."""
  pearson = scipy.stats.pearsonr(returns, metrics)[0]
  spearman = scipy.stats.spearmanr(returns, metrics).correlation
  # Report R^2, not r.
  return pearson ** 2, spearman


def ranks(returns, metrics, NUM_FUNCS):
  """Reports fraction of pairs with incorrect rank, unused."""
  tot = 0
  errors = 0
  for i in range(NUM_FUNCS - 1):
    for j in range(i + 1, NUM_FUNCS):
      # If the returns are equal, we don't care about the ordering.
      if returns[i] == returns[j]:
        continue
      tot += 1
      # A pair is mis-ranked when the metric orders it differently from the
      # return.
      errors += ((metrics[i] < metrics[j]) != (returns[i] < returns[j]))
  return float(errors) / tot


if __name__ == '__main__':
  # Ensure repeatable results.
  np.random.seed(123)
  LEVELS = 6
  #              1
  #            /   \
  #           2     3
  #          / \   / \
  #         4   5 6   7
  #        /\  /\ /\  /\
  #       8 9 10 11 12 13 14 15
  #
  ## 1 success leaf.
  NUM_BAD = 2 ** (LEVELS - 1) - 1
  ## 1 failure leaf.
  # NUM_BAD = 1
  NUM_FUNCS = 1000
  tree = make_tree(levels=LEVELS, num_bad=NUM_BAD)
  ## Default plot.
  q_functions = [sample_qfunction(LEVELS) for _ in range(NUM_FUNCS)]
  ## Q-functions of different magnitude. Use with 1 success leaf.
  # q_functions = []
  # for i in range(1, NUM_FUNCS + 1):
  #   q_functions.append(i * sample_qfunction(LEVELS))
  ## Big Q-functions. Use with 1 success leaf.
  # q_functions = [1000 * sample_qfunction(LEVELS) for _ in range(NUM_FUNCS)]
  for attempt in [0, 0.1, 0.2, 0.3, 0.4]:
    STOCHASTICITY = attempt
    dataset = get_dataset(tree, NUM_EPS=1000)
    returns = [get_exact_return(q_func, tree) for q_func in q_functions]
    print('Return of 1st Q-func for debugging', returns[0])
    # Baselines.
    bellman_score = [bellman(dataset, q_func) for q_func in q_functions]
    disc_sum_adv_score = [
        disc_sum_adv(dataset, q_func) for q_func in q_functions]
    mcc_score = [mcc(dataset, q_func) for q_func in q_functions]
    print('Flip probability', STOCHASTICITY)
    print('Bad states', NUM_BAD)
    print('Bellman', correlations(returns, bellman_score))
    print('disc_sum_adv', correlations(returns, disc_sum_adv_score))
    print('mcc', correlations(returns, mcc_score))
    # New metrics (OPC and SoftOPC).
    # priors = np.arange(1 + 20) / 20.0
    priors = [1.0]
    soft_opc_rank_scores = []
    opc_rank_scores = []
    for prior in priors:
      soft_opcs = [soft_opc(dataset, q_func, prior) for q_func in q_functions]
      # min OPC -> max of -OPC, which is easier for ranking.
      opcs = [-opc_score(dataset, q_func, prior) for q_func in q_functions]
      # Pearson, Spearman.
      soft_opc_r2, soft_opc_corr = correlations(returns, soft_opcs)
      opc_r2, opc_corr = correlations(returns, opcs)
      print('Prior %f, SoftOPC correlations' % prior, soft_opc_r2, soft_opc_corr)
      print('Prior %f, OPC correlations' % prior, opc_r2, opc_corr)
      soft_opc_rank_scores.append(soft_opc_corr)
      opc_rank_scores.append(opc_corr)
    # plt.show()
""" | |
plt.plot(priors, soft_opc_rank_scores) | |
plt.plot(priors, opc_rank_scores) | |
plt.rc('text', usetex=True) | |
# If U[0,k], best possible total max is 1000 * 1001 / 2 = about 500 * 1000. | |
# If U[0,1000], best possible total max is 1000 * 1000. | |
# We expect max_{s,a} Q(s,a) > 500 so this threshold is enough to figure out | |
# correct plot title. | |
if sum(q_func.max() for q_func in q_functions) > (1000 * 1001 / 2): | |
plt.title('Binary Tree, %d Levels, %d Success State\nQ(s,a) ~ U[0,1000]' % ( | |
LEVELS, 2 ** (LEVELS - 1) - NUM_BAD)) | |
elif max(q_func.max() for q_func in q_functions) > 1: | |
plt.title('Binary Tree, %d Levels, %d Success State\nQ(s,a) ~ U[0,k]' % ( | |
LEVELS, 2 ** (LEVELS - 1) - NUM_BAD)) | |
elif NUM_BAD == 1: | |
plt.title('Binary Tree, Flip Prob %f, %d Levels, %d Fail State' % (STOCHASTICITY, LEVELS, NUM_BAD)) | |
elif NUM_BAD == 2 ** (LEVELS - 1) - 1: | |
plt.title('Binary Tree, Flip Prob %f, %d Levels, %d Success State' % (STOCHASTICITY, LEVELS, 2 ** (LEVELS - 1) - NUM_BAD)) | |
else: | |
plt.title('Binary Tree, Flip Prob %f, %d Levels, %d Fail States' % (STOCHASTICITY, LEVELS, NUM_BAD)) | |
plt.xlabel('Prior $p(y=1)$') | |
plt.ylabel('Spearman Correlation To Episode Return') | |
plt.legend(['SoftOPC', 'OPC']) | |
bs = correlations(returns, bellman_score)[1] | |
plt.axhline(y=bs, linestyle='--', color='gray') | |
plt.text(x=0.75, y=bs + 0.007, s='TD Error') | |
ds = correlations(returns, disc_sum_adv_score)[1] | |
plt.axhline(y=ds, linestyle='--', color='gray') | |
plt.text(x=0.75, y=ds + 0.007, s='Sum Advantages') | |
ms = correlations(returns, mcc_score)[1] | |
plt.axhline(y=ms, linestyle='--', color='gray') | |
plt.text(x=0.75, y=ms + 0.007, s='MCC Error') | |
plt.show() | |
""" |