Created May 13, 2012 14:14
Save anonymous/2688645 to your computer and use it in GitHub Desktop.
matrix_shuffle.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
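Computes the probability distribution of states created by
the shuffle2 algorithm. This should return results similar to
enumerate_shuffle.py.

A state is a 2-tuple: (swapped, deck), where `swapped` holds one 0/1 flag
per deck position and `deck` is the current ordering of the cards.

The probability of transitioning from one state to another can be
represented by a Markov matrix M, where M[row, col] is the probability of
moving from states[col] to states[row].

The probability distribution of states after N iterations of the shuffle2
algorithm is therefore the matrix power M**N -- in the linear-algebra sense,
np.linalg.matrix_power(M, N), not NumPy's elementwise `**` -- applied to the
initial distribution vector.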
""" | |
import numpy as np | |
import itertools | |
def test_uniform(M, start, maxiter=20):
    """Push the initial distribution `start` through `maxiter` applications
    of the Markov matrix M, reporting and checking uniformity after each one.

    "Final" states are those whose swap flags are all set. After every
    application this prints the probability mass still outside the final
    states (`missing`) and the smallest/largest final-state probabilities
    (`lo`/`hi`).

    Assuming final states are absorbing (as makeM builds them), the limiting
    distribution over final states is provably nonuniform if either:

    * some final state already holds more than 1/len(final_states), since
      that mass can only grow; or
    * even moving all `missing` mass onto the least likely final state
      cannot raise it to 1/len(final_states).

    Parameters
    ----------
    M : ndarray
        Column-stochastic transition matrix over `states` (see makeM).
    start : ndarray
        Initial probability distribution over `states`.
    maxiter : int
        Number of times M is applied. Iteration k reports M**k @ start.

    Returns
    -------
    ndarray
        The distribution after the last application. If maxiter < 1, this
        is a float copy of `start`.

    Raises
    ------
    AssertionError
        If the distribution stops summing to 1, or is provably nonuniform.
    """
    final_states = sorted(i for i, (swapped, deck) in enumerate(states)
                          if all(swapped))
    final_set = set(final_states)
    # States that still have an unswapped position; computed once rather
    # than on every iteration.
    nonfinal_states = np.array([i for i in range(len(states))
                                if i not in final_set], dtype=int)
    uniform = 1. / len(final_states)
    # Begin at `start` itself (not M @ start), so that iteration k really
    # corresponds to k applications of M.
    distribution = np.array(start, dtype=float)
    for iteration in range(1, maxiter + 1):
        distribution = np.dot(M, distribution)
        total = distribution.sum()
        if not np.allclose(total, 1.0):
            raise AssertionError(
                'distribution sums to {} instead of 1'.format(total))
        # Probability mass that has not yet reached a final state.
        missing = float(distribution[nonfinal_states].sum())
        final_probs = distribution[final_states]
        lo = float(final_probs.min())
        hi = float(final_probs.max())
        ml = missing + lo
        results = {deck_state(states[i]): float(distribution[i])
                   for i in final_states}
        fields = dict(iteration=iteration, missing=missing, lo=lo, hi=hi,
                      ml=ml, uniform=uniform, results=results)
        print(report.format(**fields))
        # Explicit raises (rather than `assert`) so the checks still run
        # under `python -O`.
        if hi > uniform:
            raise AssertionError(too_hi.format(**fields))
        if ml < uniform:
            raise AssertionError(too_lo.format(**fields))
    return distribution
def deck_state(state):
    """Extract the card ordering from a ``(swapped, deck)`` state pair.

    The swap flags are discarded. Tuple unpacking is used (rather than
    indexing) so a `state` that is not exactly a pair is rejected.
    """
    _flags, ordering = state
    return ordering
def initial_state():
    """Return a probability vector putting all mass on states[0].

    states[0] is the state with swap flags (0, 0, 0) and deck (0, 1, 2),
    i.e. nothing has been swapped yet and the deck is in its original order.
    """
    n = len(states)
    point_mass = np.zeros(n)
    # Certainty of starting in the untouched state.
    point_mass[0] = 1.0
    return point_mass
def test():
    """Step the module-level `start` distribution through M ten times.

    After each step this checks that the total probability is still 1,
    i.e. that M behaves as a column-stochastic matrix, and prints the
    step number.
    """
    current = start[:]
    for step in range(10):
        current = np.dot(M, current)
        total = sum(current)
        assert np.allclose(total, 1.0)
        print('passed: {}'.format(step))
    return 'tests passed!'
def makeM(state_space=None):
    """
    Return the Markov transition matrix, M, where
    M[row, col] is the probability of moving from states[col] to states[row].

    Every column sums to 1 (M is column-stochastic), so a distribution
    vector `x` advances one step as ``np.dot(M, x)``.

    Transition rule modelled for one step of shuffle2:

    * While some position is still unswapped, choose positions i and j
      uniformly and independently (N**2 equally likely pairs), swap deck[i]
      with deck[j], and mark position i as swapped.
    * Once every position is swapped, the state is absorbing: it maps to
      itself with probability 1.

    Parameters
    ----------
    state_space : sequence of (swapped, deck) tuple pairs, optional
        The states to build M over; defaults to the module-level `states`.
        States must be unique, and every state reachable in one step must
        be present, otherwise KeyError is raised.
    """
    if state_space is None:
        state_space = states
    # Precompute each state's row/column so each transition is an O(1)
    # lookup instead of a linear list.index() scan.
    position = {state: k for k, state in enumerate(state_space)}
    L = len(state_space)
    M = np.zeros((L, L))
    for col, (swapped, deck) in enumerate(state_space):
        if all(swapped):
            # Fully swapped: absorbing state.
            M[col, col] = 1.0
            continue
        N = len(deck)
        p = 1. / (N ** 2)
        for i in range(N):
            flags = list(swapped)
            flags[i] = 1
            flags = tuple(flags)
            for j in range(N):
                newdeck = list(deck)
                newdeck[i], newdeck[j] = deck[j], deck[i]
                M[position[(flags, tuple(newdeck))], col] += p
    return M
# `states` lists every (swapped, deck) pair for a three-card deck:
# 2**3 swap-flag tuples x 3! orderings = 48 states. Its order fixes the
# row/column layout of M; states[0] is ((0, 0, 0), (0, 1, 2)).
states = list(itertools.product(
    itertools.product((0, 1), repeat=3),
    itertools.permutations(range(3)),
))
M = makeM()
start = initial_state()
# Per-iteration progress report, filled in via str.format by test_uniform.
report = (
    'iteration: {iteration}\n'
    'missing: {missing}\n'
    'lo = {lo}\n'
    'hi = {hi}\n'
    'results = {results}\n'
)
# Failure messages for test_uniform's two nonuniformity checks.
too_lo = ('distribution is not uniform since missing+lo = {ml} '
          '< {uniform}')
too_hi = ('distribution is not uniform since hi = {hi} '
          '> {uniform}')
if __name__ == '__main__':
    print(test())
    test_uniform(M, start)

# Output recorded from an earlier run of this script (the uniformity check
# fails):
#
#   iteration: 15
#   missing: 0.00456724682932
#   lo = 0.16092710389
#   hi = 0.16691555987
#   results = {(2, 1, 0): 0.1669155598702729, (0, 1, 2): 0.16092710389008436, (1, 0, 2): 0.1669155598702729, (2, 0, 1): 0.16687948483488679, (0, 2, 1): 0.1669155598702729, (1, 2, 0): 0.16687948483488679}
#   AssertionError: distribution is not uniform since hi = 0.16691555987 > 0.166666666667
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.