@dhpollack
Forked from awjuliani/ContextualPolicy.ipynb
Created February 9, 2017 14:06
A Policy-Gradient algorithm that solves Contextual Bandit problems.
@dhpollack
Author

Here's a python3 version of the tutorial.

import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np

class contextual_bandit(object):
    def __init__(self):
        self.state=0
        self.bandits=np.array([[0.2,0.,-0.0,-5.],[0.1,-5.,1.,0.25],[-5.,5.,5.,5.]])
        self.num_bandits=self.bandits.shape[0]
        self.num_actions=self.bandits.shape[1]
        
    def getBandit(self):
        #Pick a random bandit; its index is the state for this round.
        self.state = np.random.randint(self.num_bandits)
        return self.state
    
    def pullArm(self, action):
        #Lower bandit values beat the random draw more often, so they yield positive rewards more often.
        bandit = self.bandits[self.state, action]
        result = np.random.randn()
        if result > bandit:
            return 1.
        else:
            return -1.

class agent(object):
    def __init__(self, lr, s_size, a_size):
        #These lines establish the feed-forward part of the network. The agent takes a state and produces an action.
        self.state_in=tf.placeholder(shape=[1], dtype=tf.int32)
        state_in_OH=slim.one_hot_encoding(self.state_in, s_size)
        output=slim.fully_connected(state_in_OH, a_size, 
                                    biases_initializer=None, 
                                    activation_fn=tf.nn.sigmoid, 
                                    weights_initializer=tf.ones_initializer())
        self.output=tf.reshape(output, [-1])
        self.chosen_action=tf.argmax(self.output, 0)
        
        #The next six lines establish the training procedure. We feed the reward and chosen action into the network
        #to compute the loss, and use it to update the network.
    
        self.reward_holder = tf.placeholder(shape=[1],dtype=tf.float32)
        self.action_holder = tf.placeholder(shape=[1],dtype=tf.int32)
        self.responsible_weight = tf.slice(self.output,self.action_holder,[1])
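        #REINFORCE-style surrogate loss for a single sample: -log(output for the chosen action) * reward.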
        self.loss = -(tf.log(self.responsible_weight)*self.reward_holder)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=lr)
        self.update = optimizer.minimize(self.loss)
        

tf.reset_default_graph()

cBandit = contextual_bandit()
myAgent = agent(lr=0.001, s_size=cBandit.num_bandits, a_size=cBandit.num_actions)
weights = tf.trainable_variables()[0]

total_rounds = 10000
#total_reward = tf.Variable(np.zeros((cBandit.num_bandits, cBandit.num_actions)))
#update_reward = tf.scatter_add(total_reward,[action_holder],[reward_holder])
total_reward = np.zeros((cBandit.num_bandits, cBandit.num_actions))
e = 0.1

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    i = 0
    while i < total_rounds:
        s = cBandit.getBandit() # get state
        if np.random.rand() < e:
            action = np.random.randint(cBandit.num_actions) # explore with a random action
        else:
            action = sess.run(myAgent.chosen_action, feed_dict={myAgent.state_in:[s]})
        
        reward = cBandit.pullArm(action)
        
        #Update the network.
        fd = {myAgent.reward_holder:[reward], myAgent.action_holder:[action], myAgent.state_in:[s]}
        _, ww = sess.run([myAgent.update, weights], feed_dict = fd)
        
        #Update our running tally of scores.
        #sess.run(update_reward, feed_dict = {reward_holder: [reward], action_holder: [action]}) # need to feed variables into scoreboard update
        total_reward[s,action] += reward
        
        if i%(total_rounds//10) == 0:
            #print("Running reward for the " + str(cBandit.num_bandits) + " bandits: " + str(sess.run(total_reward))) # using sess.run to print variable
            print("Running reward for the " + str(cBandit.num_bandits) + " bandits: " + str(total_reward)) # why a mean?
        i += 1

for a in range(cBandit.num_bandits):
    print("The agent thinks action " + str(np.argmax(ww[a])+1) + " for bandit " + str(a+1) + " is the most promising....")
    if np.argmax(ww[a]) == np.argmin(cBandit.bandits[a]):
        print("...and it was right!")
    else:
        print("...and it was wrong!")
