This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class PowerSocket:
    """Base power socket: one candidate the socket tester can charge from.

    Stores the socket's true reward value ``q`` alongside the running reward
    estimate ``Q`` and the trial counter ``n``; the latter two are (re)set by
    :meth:`initialize`.
    """

    def __init__(self, q):
        # the socket's true reward value
        self.q = q
        # start with a fresh estimate and zero trials
        self.initialize()

    def initialize(self):
        """Reset the reward estimate ``Q`` and the trial count ``n`` to zero."""
        self.Q, self.n = 0, 0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fixed order in which the five sockets are created.
socket_order = [2, 1, 3, 5, 4]

# One socket per entry. A socket's true mean is 2*q + 2: doubling the order
# index keeps the means distinct, and the +2 offset keeps every mean above zero.
sockets = [PowerSocket(2 * q + 2) for q in socket_order]

# How many sockets are under test.
NUM_SOCKETS = len(socket_order)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Optimistic power socket: a standard PowerSocket whose reward estimate starts
# at a caller-supplied value instead of zero.
class OptimisticPowerSocket(PowerSocket):
    """Power socket whose estimate ``Q`` starts at ``initial_estimate``.

    The initial estimate is stored and re-applied by :meth:`initialize`, so
    resetting the socket restores the optimistic starting value rather than
    the base class's zero.

    Args:
        q: the true reward value, passed to the base PowerSocket.
        initial_estimate: starting value for the reward estimate ``Q``.
    """

    def __init__(self, q, initial_estimate):
        # Must be set before the base __init__: PowerSocket.__init__ calls
        # self.initialize(), and the override below reads this attribute.
        self.initial_estimate = initial_estimate
        # pass the true reward value to the base PowerSocket
        super().__init__(q)

    def initialize(self):
        """Reset the socket, restoring the optimistic initial estimate."""
        # base reset zeroes Q and n ...
        super().initialize()
        # ... then Q is put back to the supplied optimistic value
        self.Q = self.initial_estimate
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class EpsilonGreedySocketTester(SocketTester):
    """Socket tester that also records an exploration probability ``epsilon``.

    Args:
        epsilon: probability of selecting the non-greedy action (default 0).
    """

    def __init__(self, epsilon=0.):
        # build the standard set of sockets with SocketTester's defaults
        super().__init__()
        # probability of taking the non-greedy (exploratory) action
        self.epsilon = epsilon
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class UCBSocketTester(SocketTester):
    """Socket tester that scores sockets with an Upper Confidence Bound (UCB).

    Args:
        socket_order: order indices used to build the sockets.
        confidence_level: weight ``c`` of the exploration term in :meth:`ucb`.
    """

    def __init__(self, socket_order, confidence_level=2.0):
        # Forward socket_order by keyword: SocketTester's first parameter is
        # the socket *class*, so passing the list positionally would bind it
        # to `socket` and SocketTester would then try to call the list.
        super().__init__(socket_order=socket_order)
        # weight of the exploration term in ucb()
        self.confidence_level = confidence_level

    def ucb(self, Q, t, n):
        """Return the UCB score ``Q + c * sqrt(ln(t) / n)``.

        Args:
            Q: current reward estimate of the socket.
            t: current time step; ln(t) must be defined, so t is expected to
               be >= 1 (NOTE(review): confirm with the caller).
            n: number of times the socket has been tried.

        Returns:
            ``inf`` when ``n == 0``, so an untried socket outranks every tried
            one; otherwise the estimate plus the weighted exploration bonus.
        """
        if n == 0:
            return float('inf')
        return Q + self.confidence_level * np.sqrt(np.log(t) / n)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class GaussianThompsonSocket( PowerSocket ): | |
def __init__(self, q):
    # Posterior parameters are set before the base constructor, which calls
    # self.initialize() — presumably so an overridden initialize() could read
    # them (no such override is visible here; confirm).
    self.μ_0 = 1        # posterior mean
    self.τ_0 = 0.0001   # posterior precision (small precision = wide posterior)
    # hand the true reward value to the base PowerSocket
    super().__init__(q)
def sample(self): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BernoulliThompsonSocket( PowerSocket ): | |
def __init__(self, q):
    # Charge / no-charge counters. Both start at 1 rather than 0, so (assuming
    # they are incremented once per outcome — not visible here) each holds
    # its outcome count plus one; presumably the α, β of a Beta posterior.
    self.α = 1  # times this socket returned a charge (+1)
    self.β = 1  # times no charge was returned (+1)
    # hand the true reward value to the base PowerSocket
    super().__init__(q)
def charge(self): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class UCBSocket(PowerSocket):
    """Power socket that carries its own UCB exploration weight.

    Args:
        q: the true reward value, passed to the base PowerSocket.
        confidence_level: keyword-only weight of the exploration term
            (default 2.0).
    """

    def __init__(self, q, *, confidence_level=2.0):
        """Initialize the UCB socket."""
        # confidence_level is an explicit keyword-only parameter (instead of
        # being popped from **kwargs) so a misspelt or unexpected keyword
        # raises TypeError rather than silently falling back to 2.0.
        # It is stored before the base __init__, which calls self.initialize().
        self.confidence_level = confidence_level
        # pass the true reward value to the base PowerSocket
        super().__init__(q)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SocketTester(): | |
""" create and test a set of sockets over a single test run """ | |
def __init__(self, socket=PowerSocket, socket_order=socket_order, **kwargs):
    """Create one ``socket`` instance per entry of ``socket_order``.

    Each socket's true reward value is ``2*q + 2`` for order index ``q``.
    Extra keyword arguments are forwarded unchanged to every socket
    constructor. The default ``socket_order`` is the module-level list as it
    was bound when this method was defined.
    """
    created = []
    for index in socket_order:
        created.append(socket(index * 2 + 2, **kwargs))
    self.sockets = created
def charge_and_update(self,socket_index): | |
""" charge from the chosen socket and update its mean reward value """ | |
reward = self.sockets[socket_index].charge() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Return the index of the largest value in the supplied list, choosing
# uniformly at random among tied maxima (np.argmax always returns the first).
def random_argmax(value_list, rng=None):
    """Return the index of a maximal element, breaking ties uniformly at random.

    Args:
        value_list: array-like of numbers. Multi-dimensional input yields a
            flat index, as with ``np.argmax``.
        rng: optional ``numpy.random.Generator`` (or ``RandomState``) used for
            the tie-break. Defaults to NumPy's global random state, so
            ``np.random.seed`` still controls the result.

    Returns:
        The (flat) index of one of the maximal elements.

    Raises:
        ValueError: if ``value_list`` is empty, or contains NaN (then no
            element compares equal to the maximum).
    """
    values = np.asarray(value_list)
    # every index holding the maximum is a candidate
    candidates = np.flatnonzero(values == values.max())
    if candidates.size == 0:
        raise ValueError("random_argmax: no maximal element (input contains NaN)")
    chooser = np.random if rng is None else rng
    # Uniform choice among true maxima; unlike scaling a 0/1 mask by random
    # draws, this can never return a non-maximal index.
    return chooser.choice(candidates)
OlderNewer