Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save Kenneth-T-Moore/fc240ba09def05d38a3a8c0bbb4de79e to your computer and use it in GitHub Desktop.
Save Kenneth-T-Moore/fc240ba09def05d38a3a8c0bbb4de79e to your computer and use it in GitHub Desktop.
Distributed example.
import numpy as np
import openmdao.api as om
from openmdao.utils.array_utils import evenly_distrib_idxs
from openmdao.utils.mpi import MPI
class Distrib_Derivs(om.ExplicitComponent):
    """Simplest example that combines distributed and serial inputs and outputs.

    Inputs
    ------
    in_dist : distributed input; its local shape comes from the connected source
        (``shape_by_conn=True``).
    in_serial : serial (non-distributed) input; shape comes from the connected source.

    Outputs
    -------
    out_dist : distributed output with the same shape as ``in_dist``.
    out_serial_wrong : serial output computed from only the *local* part of
        ``in_dist`` -- deliberately incorrect, so its value differs between ranks.
    out_serial : serial output computed from the sum of ``in_dist`` over all ranks.
    """

    def setup(self):
        """Add the distributed and serial variables; shapes are resolved by connection."""
        # Distributed input
        self.add_input('in_dist', shape_by_conn=True, distributed=True)
        # Serial input
        self.add_input('in_serial', shape_by_conn=True)
        # Distributed output
        self.add_output('out_dist', copy_shape='in_dist', distributed=True)
        # Serial outputs
        self.add_output('out_serial_wrong', copy_shape='in_serial')
        self.add_output('out_serial', copy_shape='in_serial')

    def setup_partials(self):
        """Declare the sparsity pattern of every partial derivative.

        Input shapes come from the connected sources (``shape_by_conn``), so the
        local sizes are read from the resolved metadata here. Everything needed
        is derived from this component's own metadata; no module-level globals
        are referenced.
        """
        meta = self.get_io_metadata(metadata_keys=['shape'])
        local_size = meta['in_dist']['shape'][0]
        serial_size = meta['in_serial']['shape'][0]

        # Element-wise terms only touch matching indices: declare them diagonal.
        row_col_d = np.arange(local_size)
        row_col_s = np.arange(serial_size)
        self.declare_partials('out_dist', 'in_dist', rows=row_col_d, cols=row_col_d)
        self.declare_partials('out_serial', 'in_serial', rows=row_col_s, cols=row_col_s)

        # Cross terms go through a sum, so every output entry depends on every
        # input entry: declare them dense.
        self.declare_partials('out_dist', 'in_serial')
        self.declare_partials('out_serial', 'in_dist')
        # NOTE(review): no partials are declared for 'out_serial_wrong';
        # presumably intentional because it only illustrates the incorrect
        # approach -- confirm.

    def compute(self, inputs, outputs):
        """Evaluate the outputs; ``out_serial`` requires a reduction across ranks."""
        comm = self.comm
        x = inputs['in_dist']
        y = inputs['in_serial']

        # "Computationally Intensive" operation that we wish to parallelize.
        f_x = x**2 - 2.0*x + 4.0
        f_y = y ** 0.5
        g_x = x ** 0.5
        g_y = y**2 + 3.0*y - 5.0

        # The local distributed output is a function of the local distributed
        # input computed above. It is also a function of the serial input.
        outputs['out_dist'] = f_x + np.sum(f_y)

        # The serial output is a function of the serial input and the
        # distributed input. However, this is the wrong way because only the
        # local part of the distributed input is used instead of all of it.
        # As a result, this output ends up with different values on each proc.
        outputs['out_serial_wrong'] = g_y + np.sum(g_x)

        local_sum = np.sum(g_x)
        if MPI and comm.size > 1:
            # Gather every rank's partial sum so all ranks see the same total.
            all_local_sums = np.zeros(comm.size)
            comm.Allgather(np.array([local_sum]), all_local_sums)
            total_sum = np.sum(all_local_sums)
        else:
            total_sum = local_sum
        outputs['out_serial'] = g_y + total_sum

    def compute_partials(self, inputs, partials):
        """Fill in the partials declared in ``setup_partials``."""
        x = inputs['in_dist']
        y = inputs['in_serial']
        serial_size = len(y)
        local_size = len(x)

        # Diagonal (sparse) terms: one value per declared (row, col) pair.
        partials['out_dist', 'in_dist'] = 2.0 * x - 2.0
        partials['out_serial', 'in_serial'] = 2.0 * y + 3.0

        # Dense cross terms: every row repeats the same gradient vector.
        df_dy = 0.5 / y ** 0.5
        partials['out_dist', 'in_serial'] = np.tile(df_dy, (local_size, 1))

        # Only the local slice of in_dist is visible on this rank, so this is
        # the local block of the serial-output-wrt-distributed-input Jacobian.
        dg_dx = 0.5 / x ** 0.5
        partials['out_serial', 'in_dist'] = np.tile(dg_dx, (serial_size, 1))
# Length of the serial vector, and of the distributed vector summed over all ranks.
size = 5

if not MPI:
    # No MPI available: a single rank owns the whole distributed vector.
    rank = 0
    sizes, offsets = {rank: size}, {rank: 0}
else:
    comm = MPI.COMM_WORLD
    rank = comm.rank
    # Per-rank sizes and starting offsets of the distributed vector.
    sizes, offsets = evenly_distrib_idxs(comm.size, size)

# This rank's share of the distributed vector.
my_size = sizes[rank]
my_offset = offsets[rank]

ivc = om.IndepVarComp()
ivc.add_output('x_dist', np.zeros(my_size), distributed=True)
ivc.add_output('x_serial', np.zeros(size))

model = om.Group()
model.add_subsystem("indep", ivc)
model.add_subsystem("D1", Distrib_Derivs())
model.connect('indep.x_dist', 'D1.in_dist')
model.connect('indep.x_serial', 'D1.in_serial')

prob = om.Problem(model)
prob.setup()

# Each rank builds only its own slice of the global values 3, 4, 5, ...
x_dist_init = 3.0 + np.arange(my_offset, my_offset + my_size)
x_serial_init = 2.0 * np.arange(size) + 1.0

# This set operates on the entire vector.
prob.set_val('indep.x_dist', x_dist_init)
prob.set_val('indep.x_serial', x_serial_init)
prob.run_model()

for var in ('indep.x_dist', 'D1.out_dist', 'D1.out_serial_wrong', 'D1.out_serial'):
    local_val = prob.get_val(var)
    print(var, local_val)
    remote_val = prob.get_val(var, get_remote=True)
    print('remote', var, remote_val)

print('check derivs', flush=True)

# Check Partials story (hang: exception on a rank) -- uncomment to run.
# prob.check_partials()
# prob.check_totals(of=['D1.out_serial'], wrt=['indep.x_serial'])

# Check Totals story (hang: exception on a rank) -- currently enabled.
prob.check_totals(of=['D1.out_dist'], wrt=['indep.x_dist'])

# Compute Totals story (wrong values, remote values not gathered) -- uncomment to run.
# J = prob.compute_totals(of=['D1.out_serial'], wrt=['indep.x_dist'])
# print(J)
# J = prob.compute_totals(of=['D1.out_serial', 'D1.out_dist'],
#                         wrt=['indep.x_serial', 'indep.x_dist'])
# print(J)

print('done')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment