Skip to content

Instantly share code, notes, and snippets.

@maxentile
Created August 29, 2023 19:29
Show Gist options
  • Save maxentile/9a90704b6df4f8cd2e255a033ee25e10 to your computer and use it in GitHub Desktop.
Save maxentile/9a90704b6df4f8cd2e255a033ee25e10 to your computer and use it in GitHub Desktop.
side exploration -- pre-enumerated neighborhoods for permutation MCMC
def enumerate_restricted_permutations(n_states=50, num_neighbor_swaps=1):
    """Enumerate all permutations reachable within ``num_neighbor_swaps`` adjacent swaps.

    Starting from the identity permutation of ``range(n_states)``, the move
    that swaps the entries at positions ``(i, i + 1)`` is applied repeatedly.
    Every permutation reachable with *at most* ``num_neighbor_swaps`` such moves
    is returned, including the identity itself.

    The search is breadth-first: each round expands only the permutations
    first reached in the previous round, so earlier layers are never
    re-expanded and no duplicate tuples are accumulated.

    Args:
        n_states: Length of each permutation.
        num_neighbor_swaps: Maximum number of adjacent swaps. Values <= 0
            yield only the identity permutation.

    Returns:
        A list of distinct permutations (tuples of ints) in sorted order, so
        the result is deterministic.
    """
    identity_permutation = tuple(range(n_states))

    def neighbor_swap(perm, i):
        """Return a copy of ``perm`` with entries ``i`` and ``i + 1`` exchanged."""
        return perm[:i] + (perm[i + 1], perm[i]) + perm[i + 2:]

    reached = {identity_permutation}
    frontier = {identity_permutation}
    for _ in range(num_neighbor_swaps):
        next_frontier = set()
        for perm in frontier:
            for i in range(n_states - 1):
                candidate = neighbor_swap(perm, i)
                # Only genuinely new permutations form the next BFS layer.
                if candidate not in reached:
                    next_frontier.add(candidate)
        if not next_frontier:
            # Every reachable permutation has already been found; more rounds
            # cannot add anything.
            break
        reached |= next_frontier
        frontier = next_frontier
    return sorted(reached)
def _count_inversions(perm):
    """Return the number of index pairs ``a < b`` with ``perm[a] > perm[b]``."""
    n = len(perm)
    return sum(1 for a in range(n) for b in range(a + 1, n) if perm[a] > perm[b])


def _count_permutations_within(n, k):
    """Return how many permutations of length ``n`` have at most ``k`` inversions.

    Adjacent swaps change the inversion count by exactly one, so this equals
    the number of permutations reachable within ``k`` adjacent swaps. It is
    computed with the Mahonian recurrence, truncated to ``0..k`` inversions.
    """
    # counts[j] = number of permutations of the current length with exactly j inversions.
    counts = [1] + [0] * k
    for m in range(2, n + 1):
        # Inserting the m-th element adds between 0 and m - 1 new inversions.
        counts = [sum(counts[max(0, j - m + 1): j + 1]) for j in range(k + 1)]
    return sum(counts)


def test_restricted_permutations():
    """Check enumerate_restricted_permutations for validity and exact counts."""
    for n_states in [10, 20, 50]:
        perms = enumerate_restricted_permutations(n_states, 1)
        # Identity plus one swap at each of the n_states - 1 adjacent positions.
        assert len(perms) == n_states
        assert len(set(perms)) == len(perms)
        assert tuple(range(n_states)) in perms
        for p in perms:
            # Each entry must be a true permutation of range(n_states).
            assert sorted(p) == list(range(n_states))
            assert _count_inversions(p) <= 1

    n_perms = 0
    n_states = 10
    for n_swaps in [1, 2, 3, 4]:
        perms = enumerate_restricted_permutations(n_states, n_swaps)
        # The neighborhood must strictly grow and match the exact Mahonian count.
        assert len(perms) > n_perms
        assert len(perms) == _count_permutations_within(n_states, n_swaps)
        assert len(set(perms)) == len(perms)
        n_perms = len(perms)
        for p in perms:
            assert sorted(p) == list(range(n_states))
            assert _count_inversions(p) <= n_swaps
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment