Last active
February 23, 2020 06:36
-
-
Save erkyrath/3499afbdd650c97fd04fbcc6b95b6e4b to your computer and use it in GitHub Desktop.
Weighted choice with temperature parameter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
import math
import random
# This code is in the public domain.
def weighted_choice(ls, temp=1.0, verbose=False):
    """Select one value from a weighted list of (value, weight) tuples.

    The argument to this function should be a list of tuples, representing
    a weighted list of options. The weights must be non-negative numbers.
    For example:

        weighted_choice([ ('A', 0.5), ('B', 2.0), ('C', 2.5) ])

    This will return 'A' roughly 10% of the time, 'B' roughly 40% of the
    time, and 'C' roughly 50% of the time.

    The temperature parameter adjusts the probabilities. If temp < 1.0,
    the weights are treated as extra-important. As temp approaches zero,
    the highest-weighted option becomes inevitable.

    If temp > 1.0, the weights become less important. As temp becomes high
    (more than twenty), all options become about equally likely.

    Set verbose to True to see the probabilities go by.

    Returns None if the list is empty. If every weight is zero, the
    options are treated as equally likely.
    """
    count = len(ls)
    if count == 0:
        if verbose:
            print('No options')
        return None
    values = [ tup[0] for tup in ls ]
    origweights = [ float(tup[1]) for tup in ls ]
    if count == 1:
        if verbose:
            print('Only one option')
        return values[0]
    if temp < 0.05:
        # Below 0.05, we risk numeric overflow in math.exp, and the chance
        # of the highest-weighted option approaches 100% anyhow. So we
        # switch to a deterministic choice. Ties go to the earliest entry,
        # matching a strict ">" scan.
        if verbose:
            print('Temperature is close to zero; no randomness')
        bestix = max(range(count), key=lambda ix: origweights[ix])
        return values[bestix]
    totalweight = sum(origweights)
    if totalweight <= 0:
        # All weights are zero: the normalization below would divide by
        # zero. Every option is equally (un)weighted, so pick uniformly.
        if verbose:
            print('All weights are zero; choosing uniformly')
        return random.choice(values)
    # Normalize the weights (so that they add up to 1.0).
    adjustweights = [ val / totalweight for val in origweights ]
    # Perform the softmax operation. I throw in an extra factor of "count"
    # in order to keep the behavior sensible around temp 1.0.
    expweights = [ math.exp(val * count / temp) for val in adjustweights ]
    # Normalize the weights again. Yes, we normalize twice.
    totalweight = sum(expweights)
    normweights = [ val / totalweight for val in expweights ]
    if verbose:
        vals = [ '%.4f' % val for val in normweights ]
        print('Adjusted weights:', ', '.join(vals))
    # Select according to the new weights by walking the cumulative
    # distribution.
    val = random.uniform(0, 1)
    for ix in range(0, count):
        if val < normweights[ix]:
            return values[ix]
        val -= normweights[ix]
    # Floating-point rounding can leave a sliver of unassigned probability
    # mass; fall back to the last option.
    return values[-1]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment