Skip to content

Instantly share code, notes, and snippets.

@bbrelje
Last active May 6, 2020 20:32
Show Gist options
  • Save bbrelje/402910db2b7f3fc7076763aadbda3c99 to your computer and use it in GitHub Desktop.
Save bbrelje/402910db2b7f3fc7076763aadbda3c99 to your computer and use it in GitHub Desktop.
Can't declare empty rows/cols in declare_partials
import openmdao.api as om
import numpy as np
from openmdao.utils.array_utils import evenly_distrib_idxs
import mpi4py.MPI as MPI
# Global length of DistribComp's distributed output 'y' (split across the
# ranks of the component's communicator) and of execcomp's input 'y'.
N: int = 3
class DistribComp(om.ExplicitComponent):
    """Distributed component computing ``y = x * ones(local_size)``.

    The output ``y`` has ``N`` entries in total, divided among the ranks of
    ``self.comm`` by ``evenly_distrib_idxs``; the scalar input ``x`` is the
    same on every rank. On rank 0 the first local entry of ``y`` is
    overwritten with the constant 2.3, so its derivative with respect to
    ``x`` is zero.
    """

    def initialize(self):
        # Each rank owns only its slice of 'y'.
        self.options['distributed'] = True

    def _local_size(self):
        """Return how many entries of ``y`` this rank owns."""
        sizes, _ = evenly_distrib_idxs(self.comm.size, N)
        return sizes[self.comm.rank]

    def setup(self):
        self.add_input('x', shape=1)
        local_size = self._local_size()
        self.add_output('y', shape=local_size)

        # Sparse pattern: every local entry y[i] depends on the single entry
        # of x, i.e. (row i, col 0). Column indices must be integers.
        rows = np.arange(local_size)
        cols = np.zeros(local_size, dtype=int)
        # NOTE: if the number of ranks exceeds N, some rank has
        # local_size == 0 and rows/cols are empty here -- this is the
        # situation this script is meant to reproduce.
        self.declare_partials('y', 'x', rows=rows, cols=cols)

    def compute(self, inputs, outputs):
        outputs['y'] = inputs['x'] * np.ones(self._local_size())
        # Rank 0's first entry is a constant, independent of x.
        if self.comm.rank == 0:
            outputs['y'][0] = 2.3

    def compute_partials(self, inputs, J):
        # dy/dx is 1 for every local entry ...
        derivs = np.ones(self._local_size())
        # ... except rank 0's first entry, which compute() sets to a constant.
        # Use self.comm (not COMM_WORLD) so this agrees with compute().
        if self.comm.rank == 0:
            derivs[0] = 0.0
        J['y', 'x'] = derivs
class GroupThatDoesWork(om.Group):
    """Model wiring: independent 'x' -> DistribComp 'y' -> ExecComp sum 'z'."""

    def setup(self):
        ivc = om.IndepVarComp()
        ivc.add_output('x', val=6.0)
        summer = om.ExecComp('z = sum(y)', y=np.zeros((N,)), z=0.0)

        # Promoting dvs's outputs and distcomp's inputs with '*' links the
        # independent 'x' to DistribComp's input 'x' by name.
        self.add_subsystem('dvs', ivc, promotes_outputs=['*'])
        self.add_subsystem('distcomp', DistribComp(), promotes_inputs=['*'])
        self.add_subsystem('execcomp', summer)

        # Feed the N-entry output 'y' into the summing component.
        self.connect('distcomp.y', 'execcomp.y')
def mainfunc():
    """Build, run, and derivative-check the GroupThatDoesWork model.

    Uses 'x' as the design variable (bounds [-1, 7]) and 'execcomp.z' as the
    objective, sets up in forward mode, runs once with x = 7.0, prints all
    outputs, then checks total derivatives.
    """
    problem = om.Problem()
    model = GroupThatDoesWork()
    problem.model = model

    model.add_design_var('x', lower=-1.0, upper=7.0)
    model.add_objective('execcomp.z')

    problem.setup(mode='fwd')
    problem['dvs.x'] = 7.0
    problem.run_model()

    model.list_outputs(print_arrays=True)
    problem.check_totals()
if __name__ == "__main__":
    # Intended to be launched under MPI (e.g. `mpirun -n 3 python <script>`).
    # With 3 or fewer processes the run completes. With 4 or more, N = 3
    # entries cannot give every rank an entry of 'y', so one rank has no local
    # outputs; its declare_partials call gets empty rows/cols and raises a
    # ValueError.
    mainfunc()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment