@bbrelje
Created May 8, 2020 13:26
Parallel coloring duplicate calls to matrix-free API
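
# Summary (inferred from the gist title and the code below): reproduction case for
# duplicate calls to the matrix-free derivative API (compute_jacvec_product; the
# overridden _apply_linear adds a print and a delay to make each call visible and
# costly) when parallel derivative coloring is used in rev mode. Uncomment
# `color = 'parcon'` in PartialDependGroup.setup() to enable coloring on the
# constraints/objective and watch the per-rank prints to see how many times each
# component's derivative routine is called. Run under MPI so ParallelGroup1 is
# distributed across processes, e.g. `mpiexec -n 3 python <this script>`
# (the process count here is only an example).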
import openmdao.api as om
import numpy as np
import time
from mpi4py import MPI


class SumComp(om.ExplicitComponent):
    def __init__(self, size):
        super(SumComp, self).__init__()
        self.size = size

    def setup(self):
        self.add_input('x', val=np.zeros(self.size))
        self.add_output('y', val=0.0)
        self.declare_partials(of='*', wrt='*')

    def compute(self, inputs, outputs):
        outputs['y'] = np.sum(inputs['x'])

    def compute_partials(self, inputs, partials):
        partials['y', 'x'] = np.ones(inputs['x'].size)


class SlowCompComputePartials(om.ExplicitComponent):
    """
    Component that multiplies its scalar input by a multiplier, with an
    artificial delay added in _apply_linear.
    """

    def __init__(self, delay=1.0, size=3, mult=2.0):
        super(SlowCompComputePartials, self).__init__()
        self.delay = delay
        self.size = size
        self.mult = mult

    def setup(self):
        self.add_input('x', val=0.0)
        self.add_output('y', val=np.zeros(self.size))
        self.declare_partials(of='*', wrt='*')

    def compute(self, inputs, outputs):
        outputs['y'] = inputs['x'] * self.mult

    def compute_partials(self, inputs, partials):
        partials['y', 'x'] = self.mult

    def _apply_linear(self, jac, vec_names, rel_systems, mode, scope_out=None, scope_in=None):
        print('applying_linear on rank ' + str(MPI.COMM_WORLD.rank))
        time.sleep(self.delay)
        super(SlowCompComputePartials, self)._apply_linear(jac, vec_names, rel_systems, mode, scope_out, scope_in)


class SlowCompComputeJacvec(om.ExplicitComponent):
    """
    Matrix-free version of the slow component: derivatives are provided through
    compute_jacvec_product, with an artificial delay added in _apply_linear.
    """

    def __init__(self, delay=1.0, size=3, mult=2.0):
        super(SlowCompComputeJacvec, self).__init__()
        self.delay = delay
        self.size = size
        self.mult = mult

    def setup(self):
        self.add_input('x', val=0.0)
        self.add_output('y', val=np.zeros(self.size))
        self.declare_partials(of='*', wrt='*')

    def compute(self, inputs, outputs):
        outputs['y'] = inputs['x'] * self.mult

    def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode):
        print('Jacvec on rank ' + str(MPI.COMM_WORLD.rank))
        # y = mult * x with scalar x broadcast to the output vector, so dy/dx
        # is a column of self.mult.
        if mode == 'fwd':
            if 'x' in d_inputs:
                d_outputs['y'] += self.mult * d_inputs['x']
        elif mode == 'rev':
            if 'x' in d_inputs:
                d_inputs['x'] += self.mult * np.sum(d_outputs['y'])

    def _apply_linear(self, jac, vec_names, rel_systems, mode, scope_out=None, scope_in=None):
        print('applying_linear on rank ' + str(MPI.COMM_WORLD.rank))
        time.sleep(self.delay)
        super(SlowCompComputeJacvec, self)._apply_linear(jac, vec_names, rel_systems, mode, scope_out, scope_in)


class PartialDependGroup(om.Group):
    def setup(self):
        size = 4
        SlowComp = SlowCompComputeJacvec

        Indep1 = self.add_subsystem('Indep1', om.IndepVarComp('x', np.arange(size, dtype=float) + 1.0))
        Comp1 = self.add_subsystem('Comp1', SumComp(size))
        pargroup = self.add_subsystem('ParallelGroup1', om.ParallelGroup())

        self.linear_solver = om.LinearRunOnce()
        self.linear_solver.options['iprint'] = -1
        pargroup.linear_solver = om.LinearRunOnce()
        pargroup.linear_solver.options['iprint'] = -1

        delay = 0.1
        Con1 = pargroup.add_subsystem('Con1', SlowComp(delay=delay, size=2, mult=2.0))
        Con2 = pargroup.add_subsystem('Con2', SlowComp(delay=delay, size=2, mult=-3.0))
        Con3 = pargroup.add_subsystem('Con3', SlowComp(delay=delay, size=2, mult=-3.0))

        self.connect('Indep1.x', 'Comp1.x')
        self.connect('Comp1.y', 'ParallelGroup1.Con1.x')
        self.connect('Comp1.y', 'ParallelGroup1.Con2.x')
        self.connect('Comp1.y', 'ParallelGroup1.Con3.x')

        color = None
        # color = 'parcon'
        self.add_design_var('Indep1.x')
        self.add_constraint('ParallelGroup1.Con1.y', lower=0.0, parallel_deriv_color=color)
        self.add_constraint('ParallelGroup1.Con3.y', lower=0.0, parallel_deriv_color=color)
        self.add_objective('ParallelGroup1.Con2.y', parallel_deriv_color=color)


if __name__ == "__main__":
    of = ['ParallelGroup1.Con1.y', 'ParallelGroup1.Con2.y', 'ParallelGroup1.Con3.y']
    wrt = ['Indep1.x']

    # run in rev mode
    p = om.Problem(model=PartialDependGroup())
    p.setup(mode='rev')
    p.run_model()

    start = time.time()
    J = p.compute_totals(of, wrt, return_format='dict')
    elapsed = time.time() - start
    print(elapsed)
    print(J['ParallelGroup1.Con1.y']['Indep1.x'][0])
    print(J['ParallelGroup1.Con2.y']['Indep1.x'][0])