
# alexirpan/tree.py

Last active Apr 1, 2021
Code for binary tree environment
 r"""Define a tree env like so. 1 / \ 2 3 / \ / \ 4 5 6 7 where leaves are terminal and the binary tree is full. Starting from 1 lets us do 2*x, 2*x+1. From 0 is annoying. A tree is represented by a numpy array, where the index is the state ID, and the value in the array is the reward (which is always 0 or 1). Initial state distribution = uniform over any non-terminal state. """ # pylint: disable=redefined-outer-name # pylint: disable=invalid-name from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import matplotlib.pyplot as plt import numpy as np import scipy.stats NUM_ACTIONS = 2 # Controls how stochastic the tree is. # With this probability, the action taken is flipped. # 0.5 for full randomness. # NOTE: in the paper, epsilon = 2 * STOCHASTICITY, because the paper defines # it as probability of random action. Here, it's the probability of flipping # the action. Defining it this way makes the coding easier. STOCHASTICITY = 0.2 def next_state(s, a, tree): next_is_terminal = terminal(2 * s, tree) if np.random.rand() < STOCHASTICITY: # Flip action. a = 1 - a next_ind = 2*s + a return next_ind def terminal(s, tree): next_ind = 2 * s # return iff out of index return next_ind >= len(tree) def eval_policy(q_function, s, tree): while not terminal(s, tree): a = np.argmax(q_function[s]) s = next_state(s, a, tree) return tree[s] def make_tree(levels, num_bad=1): r"""Makes a tree, num_bad leaves have reward 0 and the rest have reward 1.""" # 3 levels = (0) 1 2 3 4 5 6 7 = length 8 array, corresponding to tree # 1 # / \ # 2 3 # / \ / \ # 4 5 6 7 # # Don't pile all bad into one subtree, that's not interesting. Distribute # randomly. tree = np.zeros(2 ** levels) leaves = range(2 ** (levels - 1), 2 ** levels) bad_leaves = set(np.random.choice(leaves, size=num_bad, replace=False)) for i in leaves: if i not in bad_leaves: tree[i] = 1.0 else: tree[i] = 0.0 return tree def get_return(q_function, tree): """Gets return of the argmax policy of the Q-functions.""" # Computes expected return of argmax policy for Q-function. # Only does 1 rollout, not good for stochastic envs. Use get_exact_return. total = 0 count = 0 for s in range(1, len(tree)): if terminal(s, tree): break total += 1 count += eval_policy(q_function, s, tree) return float(count) / total def get_exact_return(q_function, tree): # Compute true return using DP. # initialized uniformly at random from any non-terminal state returns =  * len(tree) # return if initialized from this state. # init DP for s in range(1, len(tree)): if terminal(s, tree): returns[s] = tree[s] # iterate backwards for s in range(len(tree), 0, -1): if terminal(s, tree): continue a = np.argmax(q_function[s]) p_a = 1 - STOCHASTICITY p_nota = STOCHASTICITY returns[s] = p_a * returns[2*s + a] + p_nota * returns[2*s + (1-a)] # average over all non-terminals values = [] for s in range(1, len(tree)): if terminal(s, tree): break values.append(returns[s]) return sum(values) / len(values) def test(q_function, tree): exact = get_exact_return(q_function, tree) returns = [get_return(q_function, tree) for _ in range(1000)] print(exact, sum(returns[:10]) / 10.0, sum(returns[:100]) / 100.0, sum(returns) / 1000.0) def get_dataset(tree, NUM_EPS=1000): """Generate dataset of NUM_EPS episodes from uniform behavior policy.""" sequences = [] positive_sequences = [] rewards = [] # Loop over initial state exactly to reduce variance in generated epsodes. 


def opc(positives, all_values, prior=1.0):
  """Computes OPC score given lists of positive Q(s,a) and all Q(s,a)."""
  # OPC score:
  #   E_all[Q > b] - prior * E_pos[Q > b]
  # Find the minimum over baselines b.
  #
  # |------b-------|
  #
  # Score: for points ahead of b, positives give +1/all - prior/pos,
  # and all other points give +1/all.
  #
  # Initial value (b below every point): 1 - prior.
  #
  # This is not normalized to be between 0 and 1.
  num_pos = len(positives)
  num_all = len(all_values)
  # score = contribution if x > b (1st term, 2nd term).
  # Positive values will also be inside all_values.
  pos_ = [(x, -prior / num_pos) for x in positives]
  all_ = [(x, 1.0 / num_all) for x in all_values]
  # Make sure each x is unique.
  # Collapse (x,1), (x,2), (x,3) into (x,6) to make sure the total score is
  # right. Important for the tree env since there are finitely many (s,a).
  unique_points = collections.defaultdict(list)
  for k, v in pos_ + all_:
    unique_points[k].append(v)
  # Aggregate + sort by x.
  reduced_scores = sorted([(k, sum(v)) for k, v in unique_points.items()],
                          key=lambda pair: pair[0])
  # curr_score = value if b is smaller than all points.
  # Moving b past a point subtracts that point's contribution.
  curr_score = 1.0 - prior
  best = curr_score
  for qval, score in reduced_scores:  # pylint: disable=unused-variable
    curr_score -= score
    best = min(curr_score, best)
  return best


def opc_score(dataset, q_function, prior=1.0):
  """Computes OPC score of a Q-function over a dataset."""
  all_seq, pos_seq, _ = dataset
  all_transitions = []
  pos_transitions = []
  for seq in all_seq:
    for s, a in seq:
      all_transitions.append(q_function[s, a])
  for seq in pos_seq:
    for s, a in seq:
      pos_transitions.append(q_function[s, a])
  return opc(pos_transitions, all_transitions, prior)


def sample_qfunction(levels):
  """Generates a random Q-function with entries from U[0,1]."""
  len_ = 2 ** levels
  return np.random.uniform(size=(len_, 2))


def soft_opc(dataset, q_function, prior=1.0):
  """Computes SoftOPC score."""
  all_qmeans = []
  pos_qmeans = []
  all_seq, pos_seq, _ = dataset
  for seq in all_seq:
    q_total = 0
    for s, a in seq:
      q_total += q_function[s, a]
    all_qmeans.append(q_total / float(len(seq)))
  for seq in pos_seq:
    q_total = 0
    for s, a in seq:
      q_total += q_function[s, a]
    pos_qmeans.append(q_total / float(len(seq)))
  pos_mean = np.mean(pos_qmeans)
  all_mean = np.mean(all_qmeans)
  return prior * pos_mean - all_mean
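

# Hand-worked sanity check for opc() (added, not in the original gist): with
# one positive Q-value 0.9, all-values [0.1, 0.9], and prior 1.0, placing the
# baseline b between 0.1 and 0.9 gives
#   E_all[Q > b] - prior * E_pos[Q > b] = 0.5 - 1.0 = -0.5,
# which is the minimum over b, so opc() should return -0.5.
def _demo_opc():
  score = opc(positives=[0.9], all_values=[0.1, 0.9], prior=1.0)
  print('OPC of a perfectly separating Q-function: %.1f' % score)  # -0.5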


# This implementation is just for the undiscounted, 0/1 reward case.
# Bellman, disc sum adv, and MCC can be implemented for other cases with
# other code.
def bellman(dataset, q_function):
  """Average TD error."""
  err = []
  all_seq, _, rewards = dataset
  for seq, rew in zip(all_seq, rewards):
    # Q(s0,a0) Q(s1,a1) Q(s2,a2) ...
    q_values = [q_function[s, a] for s, a in seq]
    # Intermediate TD errors (reward is always 0 mid-episode).
    errors = [(q1 - q2) ** 2 for q1, q2 in zip(q_values[:-1], q_values[1:])]
    # Final TD error (target = 0 or 1).
    errors.append((q_values[-1] - rew) ** 2)
    # Average over the episode.
    err.append(np.mean(errors))
  return np.mean(err)


def advantages(seq, q_function):
  """Helper function for computing advantages along an episode."""
  # Q(s0,a0) Q(s1,a1) Q(s2,a2) ...
  q_values = [q_function[s, a] for s, a in seq]
  # V(s0) V(s1) V(s2) ...
  values = [q_function[s].max() for s, _ in seq]
  return np.array(q_values) - np.array(values)


def disc_sum_adv(dataset, q_function):
  """Discounted sum of advantages, discount=1."""
  total = []
  all_seq, _, _ = dataset
  for seq in all_seq:
    advant = advantages(seq, q_function)
    # For a better estimate, take the sum over every start time (equivalent
    # to initializing in a random start state).
    for t in range(len(advant)):
      total.append(advant[t:].sum())
  return np.mean(total)


def mcc(dataset, q_function):
  """Monte Carlo corrected error."""
  errs = []
  all_seq, _, rewards = dataset
  for seq, rew in zip(all_seq, rewards):
    advant = advantages(seq, q_function)
    q_values = np.array([q_function[s, a] for s, a in seq])
    # Error between Q_t and [(0/1 reward) - sum of advantages from t+1].
    for t in range(len(advant)):
      target = rew - advant[t+1:].sum()
      q_val = q_values[t]
      errs.append((target - q_val) ** 2)
  return np.mean(errs)


def correlations(returns, metrics):
  """Returns (Pearson R^2, Spearman correlation) of metrics vs. returns."""
  # pearsonr returns (r, p-value); keep only r.
  pearson = scipy.stats.pearsonr(returns, metrics)[0]
  spearman = scipy.stats.spearmanr(returns, metrics).correlation
  # Report R^2, not r.
  return pearson ** 2, spearman


def ranks(returns, metrics, NUM_FUNCS):
  """Reports fraction of pairs ranked incorrectly by the metric. Unused."""
  tot = 0
  errors = 0
  for i in range(NUM_FUNCS - 1):
    for j in range(i + 1, NUM_FUNCS):
      # If the returns are equal, we don't care about the ordering.
      if returns[i] == returns[j]:
        continue
      tot += 1
      errors += ((metrics[i] < metrics[j]) != (returns[i] < returns[j]))
  return float(errors) / tot
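

# Quick self-check for correlations() (added for illustration): a metric that
# is an exact negation of the returns has Pearson r = -1 (so R^2 = 1) and
# Spearman correlation -1.
def _demo_correlations():
  rets = [0.1, 0.4, 0.7, 0.9]
  metric = [-r for r in rets]
  r2, spearman = correlations(rets, metric)
  print('R^2 = %.3f, Spearman = %.3f' % (r2, spearman))  # 1.000, -1.000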


if __name__ == '__main__':
  # Ensure repeatable results.
  np.random.seed(123)
  LEVELS = 6
  #        1
  #       / \
  #      2   3
  #     / \ / \
  #    4  5 6  7
  #   /\  /\  /\  /\
  #  8 9 10 11 12 13 14 15
  #
  ## 1 success leaf.
  NUM_BAD = 2 ** (LEVELS - 1) - 1
  ## 1 failure leaf.
  # NUM_BAD = 1
  NUM_FUNCS = 1000
  tree = make_tree(levels=LEVELS, num_bad=NUM_BAD)
  ## Default plot.
  q_functions = [sample_qfunction(LEVELS) for _ in range(NUM_FUNCS)]
  ## Q-functions of different magnitude. Use with 1 success leaf.
  # q_functions = []
  # for i in range(1, NUM_FUNCS + 1):
  #   q_functions.append(i * sample_qfunction(LEVELS))
  ## Big Q-functions. Use with 1 success leaf.
  # q_functions = [1000 * sample_qfunction(LEVELS) for _ in range(NUM_FUNCS)]

  for attempt in [0, 0.1, 0.2, 0.3, 0.4]:
    STOCHASTICITY = attempt
    dataset = get_dataset(tree, NUM_EPS=1000)
    returns = [get_exact_return(q_func, tree) for q_func in q_functions]
    print('Return of 1st Q-func for debugging', returns[0])
    # Baselines.
    bellman_score = [bellman(dataset, q_func) for q_func in q_functions]
    disc_sum_adv_score = [disc_sum_adv(dataset, q_func)
                          for q_func in q_functions]
    mcc_score = [mcc(dataset, q_func) for q_func in q_functions]
    print('Flip probability', STOCHASTICITY)
    print('Bad states', NUM_BAD)
    print('Bellman', correlations(returns, bellman_score))
    print('disc_sum_adv', correlations(returns, disc_sum_adv_score))
    print('mcc', correlations(returns, mcc_score))
    # New ones.
    # priors = np.arange(1 + 20) / 20.0
    priors = [1.0]
    soft_opc_rank_scores = []
    opc_rank_scores = []
    for prior in priors:
      soft_opcs = [soft_opc(dataset, q_func, prior) for q_func in q_functions]
      # min OPC -> max -OPC, easier for ranking.
      opcs = [-opc_score(dataset, q_func, prior) for q_func in q_functions]
      # Pearson, Spearman.
      soft_opc_r2, soft_opc_corr = correlations(returns, soft_opcs)
      opc_r2, opc_corr = correlations(returns, opcs)
      print('Prior %f, SoftOPC correlations' % prior,
            soft_opc_r2, soft_opc_corr)
      print('Prior %f, OPC correlations' % prior, opc_r2, opc_corr)
      soft_opc_rank_scores.append(soft_opc_corr)
      opc_rank_scores.append(opc_corr)
  # plt.show()

  # Plotting code, disabled by default.
  """
  plt.plot(priors, soft_opc_rank_scores)
  plt.plot(priors, opc_rank_scores)
  plt.rc('text', usetex=True)
  # If U[0,k], best possible total max is 1000 * 1001 / 2 = about 500 * 1000.
  # If U[0,1000], best possible total max is 1000 * 1000.
  # We expect max_{s,a} Q(s,a) > 500 so this threshold is enough to figure out
  # the correct plot title.
  if sum(q_func.max() for q_func in q_functions) > (1000 * 1001 / 2):
    plt.title('Binary Tree, %d Levels, %d Success State\nQ(s,a) ~ U[0,1000]'
              % (LEVELS, 2 ** (LEVELS - 1) - NUM_BAD))
  elif max(q_func.max() for q_func in q_functions) > 1:
    plt.title('Binary Tree, %d Levels, %d Success State\nQ(s,a) ~ U[0,k]'
              % (LEVELS, 2 ** (LEVELS - 1) - NUM_BAD))
  elif NUM_BAD == 1:
    plt.title('Binary Tree, Flip Prob %f, %d Levels, %d Fail State'
              % (STOCHASTICITY, LEVELS, NUM_BAD))
  elif NUM_BAD == 2 ** (LEVELS - 1) - 1:
    plt.title('Binary Tree, Flip Prob %f, %d Levels, %d Success State'
              % (STOCHASTICITY, LEVELS, 2 ** (LEVELS - 1) - NUM_BAD))
  else:
    plt.title('Binary Tree, Flip Prob %f, %d Levels, %d Fail States'
              % (STOCHASTICITY, LEVELS, NUM_BAD))
  plt.xlabel('Prior \$p(y=1)\$')
  plt.ylabel('Spearman Correlation To Episode Return')
  plt.legend(['SoftOPC', 'OPC'])
  # correlations() returns (R^2, Spearman); plot the Spearman value.
  bs = correlations(returns, bellman_score)[1]
  plt.axhline(y=bs, linestyle='--', color='gray')
  plt.text(x=0.75, y=bs + 0.007, s='TD Error')
  ds = correlations(returns, disc_sum_adv_score)[1]
  plt.axhline(y=ds, linestyle='--', color='gray')
  plt.text(x=0.75, y=ds + 0.007, s='Sum Advantages')
  ms = correlations(returns, mcc_score)[1]
  plt.axhline(y=ms, linestyle='--', color='gray')
  plt.text(x=0.75, y=ms + 0.007, s='MCC Error')
  plt.show()
  """
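
# Note (added): running `python tree.py` sweeps the flip probability over
# [0, 0.1, 0.2, 0.3, 0.4] and, for each setting, prints the R^2 and Spearman
# correlation between each metric (TD error, sum of advantages, MCC, SoftOPC,
# OPC) and the exact return of 1000 random Q-functions.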