@erkyrath
Last active February 23, 2020 06:36
Weighted choice with temperature parameter
#!/usr/bin/env python3

import math
import random

# This code is in the public domain.

def weighted_choice(ls, temp=1.0, verbose=False):
    """The argument to this function should be a list of tuples, representing
    a weighted list of options. The weights must be non-negative numbers,
    not all zero.

    For example:
        weighted_choice([ ('A', 0.5), ('B', 2.0), ('C', 2.5) ])
    The raw weights here are in the proportion 10% / 40% / 50%. At the
    default temp of 1.0 the softmax flattens them slightly: this will
    return 'A' roughly 15% of the time, 'B' roughly 36% of the time, and
    'C' roughly 49% of the time.

    The temperature parameter adjusts the probabilities. If temp < 1.0,
    the weights are treated as extra-important. As temp approaches zero,
    the highest-weighted option becomes inevitable.

    If temp > 1.0, the weights become less important. As temp becomes high
    (more than twenty), all options become about equally likely.

    Set verbose to True to see the probabilities go by.
    """
    count = len(ls)
    if count == 0:
        if verbose:
            print('No options')
        return None

    values = [ tup[0] for tup in ls ]
    origweights = [ float(tup[1]) for tup in ls ]

    if count == 1:
        if verbose:
            print('Only one option')
        return values[0]

    if temp < 0.05:
        # Below 0.05, the softmax exponents get very large, and the chance
        # of the highest-weighted option approaches 100% anyhow. So we
        # switch to a deterministic choice. (Ties go to the earliest option.)
        if verbose:
            print('Temperature is close to zero; no randomness')
        bestval = values[0]
        bestwgt = origweights[0]
        for ix in range(1, count):
            if origweights[ix] > bestwgt:
                bestwgt = origweights[ix]
                bestval = values[ix]
        return bestval

    # Normalize the weights (so that they add up to 1.0).
    totalweight = sum(origweights)
    adjustweights = [ val / totalweight for val in origweights ]

    # Perform the softmax operation. I throw in an extra factor of "count"
    # in order to keep the behavior sensible around temp 1.0.
    # Subtracting the largest weight first leaves the result unchanged after
    # the second normalization, but keeps math.exp() from overflowing when
    # there are many options.
    peak = max(adjustweights)
    expweights = [ math.exp((val - peak) * count / temp) for val in adjustweights ]

    # Normalize the weights again. Yes, we normalize twice.
    totalweight = sum(expweights)
    normweights = [ val / totalweight for val in expweights ]

    if verbose:
        vals = [ '%.4f' % val for val in normweights ]
        print('Adjusted weights:', ', '.join(vals))

    # Select according to the new weights.
    val = random.uniform(0, 1)
    for ix in range(0, count):
        if val < normweights[ix]:
            return values[ix]
        val -= normweights[ix]
    # Floating-point rounding can leave a sliver of val unconsumed;
    # fall back to the last option.
    return values[-1]
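
To see how the temperature reshapes the odds without sampling, the normalize / scale by count / exponentiate / normalize steps can be pulled out into a small helper. This is only a sketch: `temperature_probs` is a hypothetical name that does not appear in the gist, and it assumes at least two weights, a positive total, and temp > 0. It subtracts the largest share before calling `math.exp`, a common softmax stabilization that leaves the normalized result unchanged.

```python
import math

def temperature_probs(weights, temp=1.0):
    """Return selection probabilities for raw weights at a given temp.

    Hypothetical helper, not part of the gist. Assumes at least two
    weights, a positive total weight, and temp > 0.
    """
    count = len(weights)
    total = sum(weights)
    # First normalization: shares that add up to 1.0.
    shares = [w / total for w in weights]
    # Softmax with the extra factor of count, as in weighted_choice.
    # Subtracting the peak share only guards against overflow.
    peak = max(shares)
    exps = [math.exp((s - peak) * count / temp) for s in shares]
    # Second normalization.
    norm = sum(exps)
    return [e / norm for e in exps]

if __name__ == '__main__':
    weights = [0.5, 2.0, 2.5]
    for temp in (0.1, 1.0, 20.0):
        probs = temperature_probs(weights, temp)
        print(temp, ', '.join('%.4f' % p for p in probs))
```

For the example weights, temp 1.0 gives roughly 15% / 36% / 49%. A low temp such as 0.1 pushes nearly all the probability onto the heaviest option, and temp 20.0 leaves the three options within a couple of percentage points of each other.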