Skip to content

Instantly share code, notes, and snippets.

@jdiez17
Last active July 10, 2019 15:28
Show Gist options
  • Save jdiez17/6927c3ece844ebbe881e1294b38b6d7e to your computer and use it in GitHub Desktop.
Save jdiez17/6927c3ece844ebbe881e1294b38b6d7e to your computer and use it in GitHub Desktop.
from collections import defaultdict, OrderedDict
from scipy.integrate import solve_ivp
import networkx as nx
import numpy as np
import pdb
class Node:
    """Base class for graph nodes whose ports are declared as ``Value`` descriptors.

    Subclasses declare ports as class attributes (``r = State()``) and
    implement :meth:`solve`. Every instance keeps its current port values in
    ``_values`` and its outgoing connections in ``_connections``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise every declared value to ``None``, then apply overrides.

        Args:
            *args: Accepted for signature compatibility; not used here.
            **kwargs: Initial values keyed by the name of a declared value.

        Raises:
            TypeError: If a keyword does not name a declared value. Such
                keywords used to be dropped silently, which hid typos until
                the solver failed on an unset state.
        """
        self._values = {}
        self._connections = defaultdict(list)

        # Walk the MRO base-first so values declared on parent classes are
        # initialised too; a plain attribute that overrides a Value in a
        # subclass removes it from the declared set again.
        declared = {}
        for klass in reversed(type(self).__mro__):
            for name, attr in vars(klass).items():
                if isinstance(attr, Value):
                    declared[name] = attr
                else:
                    declared.pop(name, None)

        # Assigning through the descriptor is what creates the `_values` entry.
        for name in declared:
            setattr(self, name, None)

        # Validate all overrides before applying any of them.
        unknown = sorted(name for name in kwargs if name not in self._values)
        if unknown:
            raise TypeError(
                "{}() got value name(s) that are not declared: {}".format(
                    type(self).__name__, ", ".join(unknown)
                )
            )
        for name, val in kwargs.items():
            setattr(self, name, val)

    def solve(self, *args, **kwargs):
        """Evaluate the node; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Function solve() not implemented for Node class {}".format(self.__class__.__name__))
class Value:
    """Data descriptor that stores a per-instance value and pushes it downstream.

    The owning instance must provide ``_values`` (a mapping from value name to
    current value) and ``_connections`` (a mapping from value name to a list of
    ``(target, target_attribute)`` pairs).
    """

    def __init__(self, initval=0):
        """Remember ``initval``; nothing in this class reads it back."""
        self.initval = initval

    def __set_name__(self, inst, name):
        """Record the attribute name this descriptor was bound to.

        Python passes the owning class as the second argument; the parameter
        keeps its original name ``inst``.
        """
        self.name = name

    def __get__(self, inst, objtype=None):
        """Return the value stored on ``inst``.

        When accessed on the class itself (``inst is None``) the descriptor is
        returned, instead of failing on ``None.__dict__``.

        Raises:
            KeyError: If no value has been stored for this name on ``inst``.
        """
        if inst is None:
            return self
        return inst.__dict__['_values'][self.name]

    def __set__(self, inst, val):
        """Store ``val`` on ``inst`` and assign it to every connected target."""
        inst.__dict__['_values'][self.name] = val
        # Each connection is (target object, attribute name on the target).
        for target, target_val in inst._connections[self.name]:
            setattr(target, target_val, val)
class Input(Value):
    """Marker subclass of ``Value`` for node inputs; adds no behaviour."""
class Output(Value):
    """Marker subclass of ``Value`` for node outputs; adds no behaviour."""
class State(Output):
    """Marker subclass of ``Output`` for values that ``Solver`` integrates.

    ``Solver.add`` selects a node's states with ``isinstance(..., State)``;
    the class itself adds no behaviour.
    """
class NodeGraph:
    """Directed graph of nodes, with edges for value connections."""

    def __init__(self, root_node):
        """Create a graph that contains only ``root_node``.

        The root is registered immediately so that traversal from it works
        even before any connection has been made.
        """
        self.G = nx.DiGraph()
        self._root_node = root_node
        self.G.add_node(root_node)

    def get_execution_order(self):
        """Return every node of the graph exactly once, root first.

        Nodes reachable from the root come first, in breadth-first order.
        Nodes that are in the graph but not reachable from the root (for
        example a pure source feeding a downstream node) are appended
        afterwards in graph insertion order, so they are no longer silently
        left out.

        NOTE(review): breadth-first order is not a topological order; a node
        may run before one of its upstream nodes.
        """
        execution_order = [self._root_node]
        seen = {self._root_node}

        def visit(node):
            # Set membership keeps this linear in the number of nodes.
            if node not in seen:
                seen.add(node)
                execution_order.append(node)

        for edge in nx.bfs_edges(self.G, self._root_node):
            for node in edge:
                visit(node)
        for node in self.G.nodes():
            visit(node)
        return execution_order

    def connect(self, obj, prop, target):
        """Forward ``obj.<prop>`` to the attribute of the same name on ``target``.

        Adds the edge ``obj -> target`` labelled ``prop`` and records the
        connection on ``obj`` so later assignments are propagated.

        NOTE: the graph keeps one edge per node pair, so connecting the same
        pair a second time overwrites the edge label; the propagation entry is
        still added.
        """
        self.G.add_edge(obj, target, label=prop)
        obj._connections[prop].append((target, prop))

    def get_human_readable_graph(self):
        """Return a pydot graph (left-to-right) with nodes labelled by class name.

        Relabelling with a mapping that is not one-to-one merges nodes, so the
        second and later instances of a class get a numeric suffix
        (``Orbit``, ``Orbit_2``, ...). Graphs with one instance per class are
        labelled exactly as before.
        """
        # TODO use record-based nodes https://www.graphviz.org/doc/info/shapes.html#record
        mapping = {}
        per_class = {}
        for node in self.G.nodes():
            base = node.__class__.__name__
            per_class[base] = per_class.get(base, 0) + 1
            if per_class[base] == 1:
                mapping[node] = base
            else:
                mapping[node] = "{}_{}".format(base, per_class[base])

        g = nx.relabel_nodes(self.G, mapping, copy=True)
        p = nx.nx_pydot.to_pydot(g)
        p.set_rankdir("LR")
        return p
class Solver:
    """Integrates the ``State`` values of a set of connected nodes over time."""

    def __init__(self, *nodes):
        """Register ``nodes``; the first one is the root of the node graph.

        Raises:
            IndexError: If no node is given.
        """
        self._execution_order = []
        self._state_lengths = []
        self.state_map = OrderedDict()
        self.node_graph = NodeGraph(nodes[0])  # First node is considered to be the root node
        for node in nodes:
            self.add(node)

    def connect(self, *args, **kwargs):
        """Delegate to ``NodeGraph.connect`` with the same arguments."""
        return self.node_graph.connect(*args, **kwargs)

    def add(self, model):
        """Register ``model`` and record the names of its ``State`` attributes.

        Adding a model that is already registered is a no-op, so its states
        are never listed twice. The model is also added to the node graph so
        that it is known there even before it is connected to anything.
        """
        if model in self.state_map:
            return
        self.state_map[model] = [
            name
            for name, attr in vars(model.__class__).items()
            if isinstance(attr, State)
        ]
        self.node_graph.G.add_node(model)

    def _rhs(self, t, x):
        """Right-hand side passed to ``solve_ivp``: return dx/dt for state ``x``.

        ``x`` is sliced back into the individual states (layout fixed by
        :meth:`solve`), each slice is assigned to its node, then every node's
        ``solve(t)`` runs and the results of nodes that own states are
        concatenated.

        Raises:
            TypeError: If a node with states returns something that cannot be
                iterated (for example ``None``).
            ValueError: If the concatenated derivative does not have as many
                entries as ``x``; without this check a length-1 result could
                be broadcast silently by the integrator.
        """
        # Unpack states, put them into the model's attributes
        cnt = 0
        state_idx = 0
        for model in self._execution_order:
            for state in self.state_map[model]:
                state_length = self._state_lengths[state_idx]
                state_idx += 1
                setattr(model, state, x[cnt:cnt + state_length])
                cnt += state_length

        # Run the diff eqs for each model
        diffs = []
        for model in self._execution_order:
            results = model.solve(t)
            if not self.state_map[model]:
                # Model has no states, so don't include its changes over time
                continue
            try:
                diffs.extend(results)
            except TypeError as err:
                raise TypeError(
                    "{}.solve() must return an iterable of state derivatives, got {}".format(
                        type(model).__name__, type(results).__name__
                    )
                ) from err

        derivative = np.hstack(diffs)
        if derivative.size != np.size(x):
            raise ValueError(
                "derivative has {} entries but the state vector has {}".format(
                    derivative.size, np.size(x)
                )
            )
        return derivative

    def solve(self, start, end, rtol=1e-9, **ivp_options):
        """Integrate all registered states from ``start`` to ``end``.

        Args:
            start: Initial time.
            end: Final time.
            rtol: Relative tolerance forwarded to ``solve_ivp``.
            **ivp_options: Any further keyword arguments for ``solve_ivp``.

        Returns:
            The object returned by ``scipy.integrate.solve_ivp``.

        Raises:
            ValueError: If a state has no initial value (is ``None``) or if no
                node in the execution order owns a state.
        """
        self._execution_order = self.node_graph.get_execution_order()
        self._state_lengths = []

        # First, gather all current states
        all_states = []
        for model in self._execution_order:
            for state in self.state_map[model]:
                model_state = getattr(model, state)
                if model_state is None:
                    raise ValueError(
                        "state '{}' of {} has no initial value".format(
                            state, type(model).__name__
                        )
                    )
                # Figure out how many entries in the state vector this state
                # will occupy. Assumes scalars or 1-D sequences.
                try:
                    # If it's a list of some sort, just take its `len`
                    length = len(model_state)
                except TypeError:
                    # Single values take 1 entry
                    length = 1
                self._state_lengths.append(length)
                all_states.append(model_state)

        if not all_states:
            raise ValueError("no states to integrate: no node in the execution order declares a State")

        return solve_ivp(
            self._rhs, (start, end), np.hstack(all_states), rtol=rtol, **ivp_options
        )
from lib import Node, Input, Output, State, Solver
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from astropy.constants import GM_earth
class Orbit(Node):
    """Point mass under inverse-square attraction scaled by ``GM_earth``.

    States are the position ``r`` and the velocity ``v``.
    """

    r = State()
    v = State()

    def solve(self, t):
        """Return the time derivatives ``[dr/dt, dv/dt]``; ``t`` is unused."""
        velocity = self.v
        position = self.r
        distance = np.linalg.norm(position)
        acceleration = -GM_earth.value * position / distance ** 3
        return [velocity, acceleration]
class MagneticField(Node):
    """Placeholder node with a single input ``r``; the model is not implemented."""

    r = Input()

    def solve(self, t):
        """Do nothing and return ``None``."""
        # Disabled sketch kept from the original:
        #loc = coord.EarthLocation.from_geocentric(self.r[0], self.r[1], self.r[2], unit=u.m)
        #print(loc.lat.value, loc.lon.value, loc.height)
        return None
class SunModel(Node):
    """Placeholder node with outputs ``S`` and ``F``; the model is not implemented."""

    S = Output()
    F = Output()

    def solve(self, t):
        """Do nothing and return ``None``."""
        return None
class Eclipse(Node):
    """Placeholder node with inputs ``r`` and ``S`` and output ``O``; not implemented."""

    r = Input()
    S = Input()
    O = Output()

    def solve(self, t):
        """Do nothing and return ``None``."""
        return None
if __name__ == '__main__':
    # Initial conditions: position along x, velocity along y with magnitude
    # sqrt(GM / radius). Units presumably metres and m/s -- TODO confirm.
    radius = 7018136.3
    speed = np.sqrt(GM_earth.value / radius)

    orbit = Orbit(r=np.array([radius, 0, 0]), v=np.array([0, speed, 0]))
    magnetic_field = MagneticField()
    sun = SunModel()
    eclipse = Eclipse()

    # The first node passed to the solver becomes the graph root.
    solver = Solver(orbit, magnetic_field, sun, eclipse)
    wiring = (
        (orbit, 'r', magnetic_field),
        (orbit, 'r', eclipse),
        (sun, 'S', eclipse),
    )
    for source, prop, sink in wiring:
        solver.connect(source, prop, sink)

    # Integrate from t=0 to t=3600 and dump the solver result.
    result = solver.solve(0, 3600)
    print(result)

    # Plot the first two rows of the solution against each other.
    plt.figure()
    xs = result.y[0, :]
    ys = result.y[1, :]
    plt.plot(xs, ys)
    plt.axis('equal')
    plt.show()

    # Export the node graph as an image and print its description.
    graph = solver.node_graph.get_human_readable_graph()
    graph.write_png("m1-pydot.png")
    print(graph)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment