Created February 26, 2015 22:27
Save thearn/81fd678beed92705239b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def doctor(self):
    """Print a short diagnostic report for this assembly and dump its graph.

    The report lists every input returned by ``self.get_unconnected_inputs()``
    and warns when ``self.recorders`` is empty.  The SVG produced by
    ``self._repr_svg_()`` is then written to ``<ClassName>_depgraph.html``
    in the current working directory.
    """
    # Single-argument print(...) behaves identically on Python 2 and 3.
    print("-- OpenMDAO Doctor --")
    print(30*"-")

    uncon_inputs = self.get_unconnected_inputs()
    if uncon_inputs:
        print("- Assembly contains %i unconnected inputs:" % len(uncon_inputs))
        for name in uncon_inputs:
            print(" " + name)

    if not self.recorders:
        print("- No case recorders have been set")

    svg = self._repr_svg_()
    # The file is opened in binary mode.  IPython-style ``_repr_svg_`` hooks
    # usually return text, which must be encoded before a binary write on
    # Python 3; bytes are written unchanged.
    if not isinstance(svg, bytes):
        svg = svg.encode("utf-8")
    filename = "%s_depgraph.html" % self.__class__.__name__
    with open(filename, "wb") as f:
        f.write(svg)
def get_unconnected_inputs(self):
    """Return the dotted names of component inputs that have no connection.

    Every component named by ``self.list_components()`` except ``driver``
    (and ``self``) is inspected.  Its inputs, minus a fixed set of
    framework-level variables, are reported as ``"<comp>.<input>"`` unless
    that path is the destination of a connection in
    ``self.list_connections()``.

    Returns:
        list of str, ordered by component and then by input.
    """
    # Framework-level inputs carried by every component; never reported.
    defaults = frozenset(['itername', 'force_execute', 'directory',
                          'exec_count', 'derivative_exec_count',
                          'fixed_external_vars', 'missing_deriv_policy',
                          'force_fd'])
    # Destination side of each connection; a set makes lookups O(1).
    connected_inputs = set(conn[1] for conn in self.list_connections())

    unconnected_inputs = []
    for compname in self.list_components():
        if compname in ("self", "driver"):
            continue
        comp = self.get(compname)
        for var in comp.list_inputs():
            if var in defaults:
                continue
            fullname = '.'.join([compname, var])
            if fullname not in connected_inputs:
                unconnected_inputs.append(fullname)
    return unconnected_inputs
def create(self, inst, name=None):
    """Add *inst* to this assembly and append it to the driver's workflow.

    Args:
        inst: the object to add.
        name: name to register it under.  Defaults to the class name of
            *inst*; pass an explicit name to add several instances of the
            same class, which would otherwise all get the same default name.

    Returns:
        The name under which *inst* was added.
    """
    if name is None:
        name = inst.__class__.__name__
    self.add(name, inst)
    self.driver.workflow.add(name)
    return name
def auto_connect(self, print_only=False):
    """Connect same-named output and input variables across the assembly.

    Collects the names of all non-framework input and output variables of
    every component in the assembly (the ``driver`` is skipped), plus the
    assembly's own boundary variables: a boundary input acts as a source
    and a boundary output acts as a destination.  Every source is then
    connected to each destination variable of the same name.

    When a name has more than one source, only the first one registered
    (components in ``list_components()`` order, then boundary inputs) is
    used; the others are ignored.

    Args:
        print_only: if true, connect nothing and instead print the
            connections that would be made, sorted and grouped by source.
    """
    # variable name -> owning component names; '' denotes the assembly itself.
    inputs, outputs = {}, {}

    def _non_framework(comp, names):
        # Drop framework bookkeeping variables, flagged in trait metadata.
        return [n for n in names
                if not comp._trait_metadata[n].get('framework_var')]

    for compname in self.list_components():
        if compname == "driver":
            # Drivers are excluded, matching get_unconnected_inputs.
            continue
        comp = self.get(compname)
        for name in _non_framework(comp, comp.list_inputs()):
            inputs.setdefault(name, []).append(compname)
        for name in _non_framework(comp, comp.list_outputs()):
            outputs.setdefault(name, []).append(compname)

    # Assembly boundary variables take part too.  Filter the framework names
    # instead of list.remove(): a missing name cannot raise, and the lists
    # returned by list_inputs()/list_outputs() are never mutated.
    skip_boundary_inputs = ('directory',)
    skip_boundary_outputs = ('derivative_exec_count', 'exec_count', 'itername')
    for name in self.list_inputs():
        if name not in skip_boundary_inputs:
            outputs.setdefault(name, []).append('')
    for name in self.list_outputs():
        if name not in skip_boundary_outputs:
            inputs.setdefault(name, []).append('')

    connections = []
    for varname, sources in outputs.items():
        if varname not in inputs:
            continue
        # First registered source wins when several share this name.
        frompath = '.'.join([sources[0], varname]) if sources[0] else varname
        for compname in inputs[varname]:
            topath = '.'.join([compname, varname]) if compname else varname
            if print_only:
                connections.append("connect(%s, %s)" % (frompath, topath))
            else:
                self.connect(frompath, topath)

    if print_only:
        print(30*"-" + "\nConnections:\n" + 30*"-")
        previous_source = None
        for con in sorted(connections):
            source = con.split(",")[0]
            # Blank line between groups of connections sharing a source.
            if previous_source is not None and source != previous_source:
                print("")
            print("self." + con)
            previous_source = source
        print(30*"-")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.