Skip to content

Instantly share code, notes, and snippets.

@hvasbath
Created May 20, 2018 04:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hvasbath/ba6bcc5aee6f0c620d9f51053f243793 to your computer and use it in GitHub Desktop.
Save hvasbath/ba6bcc5aee6f0c620d9f51053f243793 to your computer and use it in GitHub Desktop.
MPI skeleton for PT algorithm
#!/usr/bin/env python
"""
Parallel Tempering algorithm with mpi4py
"""
from __future__ import print_function
from mpi4py import MPI
from numpy import random
# Target length of the master's recorded chain: the master keeps swapping
# until the chain holds at least this many samples, then shuts down.
nsamples = 100
def enum(*sequential, **named):
    """Build a lightweight enumeration class.

    Handy way to fake an enumerated type in Python, see
    http://stackoverflow.com/questions/36932/how-can-i-represent-an-enum-in-python

    Parameters
    ----------
    *sequential : str
        Member names; they are numbered 0, 1, 2, ... in the order given.
    **named
        Members with explicitly chosen values. On a name clash the explicit
        value replaces the automatically assigned number.

    Returns
    -------
    type
        A new class called ``Enum`` exposing every member as a class
        attribute, e.g. ``enum('A', 'B').B == 1``.
    """
    enums = dict(zip(sequential, range(len(sequential))), **named)
    return type('Enum', (), enums)
def metrop_select(m1, m2):
    """Decide whether two states are swapped.

    Placeholder acceptance rule: a uniform number is drawn from numpy's
    global random state and the swap is accepted when the draw is at least
    0.5; the values of the states themselves do not influence the decision.
    The outcome is also printed.

    Parameters
    ----------
    m1, m2
        The two candidate states.

    Returns
    -------
    tuple
        ``(m1, m2)`` if the swap was rejected, ``(m2, m1)`` if it was
        accepted.
    """
    u = random.rand()
    if u < 0.5:
        print('Rejected swap')
        return m1, m2
    else:
        print('Accepted swap')
        return m2, m1
# Message tags used between the master (rank 0) and the workers
tags = enum('READY', 'DONE', 'EXIT', 'START')

# MPI set-up shared by every rank
comm = MPI.COMM_WORLD    # communicator spanning all processes
size = comm.Get_size()   # how many processes take part
rank = comm.Get_rank()   # index of the current process
status = MPI.Status()    # filled in by recv calls with source and tag
if rank == 0:
    # Master process: hands every worker an initial state, then repeatedly
    # takes two finished samples, proposes a swap between them and sends the
    # (possibly exchanged) states back, until `nsamples` samples are recorded.
    num_workers = size - 1
    tasks = range(num_workers)
    chain = []
    active_workers = 0
    # Ranks that were sent a START and whose DONE result is still outstanding.
    busy_workers = set()

    # Wait for each worker to report READY and give it its initial state.
    print("Master starting with %d workers" % num_workers)
    for i in range(num_workers):
        comm.recv(source=MPI.ANY_SOURCE, tag=tags.READY, status=status)
        source = status.Get_source()
        comm.send(tasks[i], dest=source, tag=tags.START)
        busy_workers.add(source)
        print("Sent task to worker %i" % source)
        active_workers += 1

    print("Parallel tempering ...")
    print("----------------------")
    if num_workers < 2:
        # A swap needs two samples from two different workers; with fewer
        # workers the second receive below could never complete.
        print("Need at least 2 workers to swap states, got %d" % num_workers)
    else:
        while True:
            # All READY messages were consumed above, so only results arrive.
            m1 = comm.recv(
                source=MPI.ANY_SOURCE, tag=tags.DONE, status=status)
            source1 = status.Get_source()
            busy_workers.discard(source1)
            print("Got sample 1 from worker %i" % source1)

            m2 = comm.recv(
                source=MPI.ANY_SOURCE, tag=tags.DONE, status=status)
            source2 = status.Get_source()
            busy_workers.discard(source2)
            print("Got sample 2 from worker %i" % source2)

            m1, m2 = metrop_select(m1, m2)
            print('samples 1, 2 %i %i' % (m1, m2))
            chain.extend([m1, m2])

            if len(chain) < nsamples:
                print("Sending states back to workers ...")
                comm.send(m1, dest=source1, tag=tags.START)
                comm.send(m2, dest=source2, tag=tags.START)
                busy_workers.update((source1, source2))
            else:
                print('Requested number of samples reached!')
                break

    print("Master finishing, recorded chain:")
    print(chain)

    # Collect the results that are still in flight before shutting down: a
    # worker whose send has not been matched may otherwise stay blocked in
    # it and never get to see the EXIT message.
    for worker in sorted(busy_workers):
        comm.recv(source=worker, tag=tags.DONE)
    busy_workers.clear()

    print("Closing ...")
    for i in range(1, size):
        print('sending signal to close to %i' % i)
        comm.send(None, dest=i, tag=tags.EXIT)
        print("CLosed worker %i" % i)
        active_workers -= 1
else:
    # Worker process: announce readiness, then answer every START message
    # with a result until the master sends EXIT.
    name = MPI.Get_processor_name()
    print("I am a worker with rank %d on %s." % (rank, name))
    comm.send(None, dest=0, tag=tags.READY)
    while True:
        print('Recieving ...')
        task = comm.recv(source=0, tag=MPI.ANY_TAG, status=status)
        print('received!')
        tag = status.Get_tag()
        if tag == tags.START:
            # Do the work here (placeholder: advance the state by one)
            result = task + 1
            print('attempting to send ...')
            comm.send(result, dest=0, tag=tags.DONE)
            print('sending worked ...')
        elif tag == tags.EXIT:
            print('went through exit')
            break

# NOTE(review): the original nesting of these two prints was lost; they are
# kept at module level, i.e. every rank prints them once its role has ended.
print('Closed all workers!')
print('running?')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment