Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Randomly partitions a set of elements using the Dirichlet process
#! /usr/bin/env python
###############################################################################
##
## Copyright 2017 Jeet Sukumaran.
##
## This program is free software; you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation; either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License along
## with this program. If not, see <http://www.gnu.org/licenses/>.
##
###############################################################################
"""
Randomly partitions a set of elements using the Dirichlet process.
"""
import argparse
import random
def weighted_index_choice(weights, sum_of_weights, rng):
    """
    Return an index into ``weights`` chosen at random, with probability
    proportional to the weight at that index.

    Adapted from:
    http://eli.thegreenplace.net/2010/01/22/weighted-random-generation-in-python/

    For example, given [2, 3, 5] it returns 0 (the index of the first element)
    with probability 0.2, 1 with probability 0.3 and 2 with probability 0.5.
    The weights need not sum to anything in particular and may be arbitrary
    non-negative floating point numbers. Provided ``sum_of_weights`` is
    non-negative, an index whose weight is zero is never returned.

    Sorting the weights in descending order before calling this makes the
    linear scan finish sooner on average: the uniform draw is consumed by the
    large weights first.

    Parameters
    ----------
    weights : iterable of float
        Non-negative weights; iterated exactly once.
    sum_of_weights : float
        Total of ``weights``, supplied by the caller so it need not be
        recomputed on every draw.
    rng : object
        Random source providing ``uniform(a, b)`` (e.g. ``random.Random``).

    Returns
    -------
    int
        The selected index.

    Raises
    ------
    ValueError
        If ``weights`` contains no positive value (including when empty).
    """
    rnd = rng.uniform(0, 1) * sum_of_weights
    last_positive_idx = None
    for i, w in enumerate(weights):
        rnd -= w
        if rnd < 0:
            return i
        if w > 0:
            last_positive_idx = i
    # Reaching this point means floating-point round-off (or a
    # ``sum_of_weights`` marginally larger than the true total) left ``rnd``
    # non-negative after every weight was subtracted. Attribute that
    # leftover sliver of probability mass to the last index that could
    # legitimately have been chosen instead of silently returning None.
    if last_positive_idx is None:
        raise ValueError("weights must contain at least one positive value")
    return last_positive_idx
def sample_partition(
        number_of_elements,
        scaling_parameter,
        rng,):
    """
    Draw one random partition of the integers 1..``number_of_elements``
    using the Chinese restaurant process representation of a Dirichlet
    process.

    The element ids are shuffled, then assigned one at a time. When ``i``
    elements have already been placed, the next one starts a new group with
    probability ``scaling_parameter / (scaling_parameter + i)`` or joins an
    existing group ``k`` with probability ``len(group_k) /
    (scaling_parameter + i)``.

    Parameters
    ----------
    number_of_elements : int
        Size of the set being partitioned; values below 1 yield ``[]``.
    scaling_parameter : float
        Concentration parameter (alpha). Larger values favour more, smaller
        groups; 0 places every element in a single group.
    rng : object
        Random source providing ``shuffle`` and ``uniform`` (e.g.
        ``random.Random``).

    Returns
    -------
    list of list of int
        The groups, in order of creation; each group lists its element ids
        in the order they were assigned.

    Raises
    ------
    ValueError
        If ``scaling_parameter`` is negative or NaN.
    """
    # ``not x >= 0`` also rejects NaN, which would poison every weight.
    if not scaling_parameter >= 0:
        raise ValueError(
            "scaling_parameter must be non-negative, got {!r}".format(
                scaling_parameter))
    element_ids = list(range(1, number_of_elements + 1))
    rng.shuffle(element_ids)
    groups = []
    for num_placed, element_id in enumerate(element_ids):
        if not groups:
            groups.append([element_id])
            continue
        # Unnormalized CRP weights: index 0 = "new group" (alpha), index k =
        # existing group k-1 (its size). Their exact total is alpha plus the
        # number of elements already placed, so no division (and no
        # normalization round-off) is needed.
        weights = [scaling_parameter] + [len(group) for group in groups]
        selected_idx = weighted_index_choice(
            weights=weights,
            sum_of_weights=scaling_parameter + num_placed,
            rng=rng)
        if selected_idx is None:
            # Defensive: a draw that round-off pushes past the final weight
            # belongs to the last group.
            selected_idx = len(weights) - 1
        if selected_idx == 0:
            groups.append([element_id])
        else:
            groups[selected_idx - 1].append(element_id)
    return groups
def main():
    """
    Command-line entry point: draw ``--num-replicates`` random partitions and
    report summary statistics.

    Each partition is printed when ``--verbosity`` is at least 1. After a
    ``---`` separator, the mean number of subsets per partition and the mean
    number of elements per subset are always printed. Invalid arguments are
    reported through ``argparse`` (usage message and exit status 2).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("-K", "--number-of-elements",
            type=int,
            default=10,
            help="Number of elements in the set. Default: %(default)s.")
    parser.add_argument("-a", "--scaling-parameter", "--alpha",
            type=float,
            default=1.5,
            help="(Anti-)Concentration or scaling parameter:"
                 " low values result in a more clumpier/clustered"
                 " partitions, while higher values result in a more"
                 " dispersed partitions. Default: %(default)s."
            )
    parser.add_argument("-n", "--num-replicates",
            type=int,
            default=10,
            help="How many draws to run. Default: %(default)s.")
    parser.add_argument("-v", "--verbosity",
            type=int,
            default=1,
            help="How much to report about what is going on"
                 " with 0 being almost completely quiet and"
                 " higher numbers reporting more and more."
                 " Default: %(default)s.")
    parser.add_argument("-z", "--random-seed",
            type=int,
            default=None,
            help="Seed for random number generator.")
    args = parser.parse_args()
    # Without these checks, an empty set or zero replicates crash the
    # summary below with ZeroDivisionError, and a negative (or NaN) alpha
    # produces meaningless negative selection weights.
    if args.number_of_elements < 1:
        parser.error("--number-of-elements must be at least 1")
    if args.num_replicates < 1:
        parser.error("--num-replicates must be at least 1")
    if not args.scaling_parameter >= 0:
        parser.error("--scaling-parameter must be non-negative")
    rng = random.Random(args.random_seed)
    num_subsets = []
    num_elements_in_subsets = []
    for _ in range(args.num_replicates):
        partition = sample_partition(
                number_of_elements=args.number_of_elements,
                scaling_parameter=args.scaling_parameter,
                rng=rng)
        if args.verbosity >= 1:
            print(partition)
        num_subsets.append(len(partition))
        num_elements_in_subsets.append(
                sum(len(s) for s in partition) / float(len(partition)))
    print("---")
    print("Mean number of subsets per partition: {}".format(
        sum(num_subsets) / float(len(num_subsets))))
    print(" Mean number of elements per subset: {}".format(
        sum(num_elements_in_subsets) / float(len(num_elements_in_subsets))))
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not
    # when this file is imported as a module.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment