Last active
October 31, 2020 13:40
-
-
Save BSVogler/fcd685f80bb8723fb6711b5b7453a35f to your computer and use it in GitHub Desktop.
dump and load: helpers for serializing a NEST network state to a dictionary and restoring it.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
import nest | |
from typing import Dict, List | |
def dump(selections: List = ("nodes", "synapses")) -> Dict:
    """Return a dictionary containing the current net in memory for serialization.

    Only a subset of the data can be serialized.

    Parameters
    ----------
    selections : List[str], optional
        Only obtain a subset of the net parameters.
        Options are "nodes" and/or "synapses".

    Returns
    -------
    Dict:
        Object containing the specified serialization sorted by dict key
        like "nodes" or "synapses".

    See Also
    --------
    load : Function to construct the net from the data.
    """
    nest.ll_api.sli_func(
        "({}) M_WARNING message".format("Dumping is currently only supported for a single process.")
    )
    dumpdata = dict()
    numnetwork = nest.GetKernelStatus("network_size")
    # A fresh kernel already contains one element; < 2 means nothing user-created.
    if numnetwork < 2:
        return dumpdata
    nodes = nest.NodeCollection(range(1, numnetwork + 1))
    if "nodes" in selections:
        dumpdata["nodes"] = nest.GetStatus(nodes)
    if "synapses" in selections:
        # Query connections only when requested — GetConnections can be expensive
        # on large networks and its result was previously computed unconditionally.
        syn_ids = nest.GetConnections(source=nodes)
        dumpdata["synapses"] = nest.GetStatus(syn_ids)
    return dumpdata
def load(data: Dict) -> Dict:
    """Load a dictionary obtained by the dump method.

    Repeated loading will add to the network size. To overwrite, clear first
    with a kernel reset. To directly overwrite the binary memory state a
    lower level access needs to be developed in the future.

    Parameters
    ----------
    data : Dict
        The data to be loaded, as produced by :func:`dump`.

    Returns
    -------
    Dict:
        The created nest objects (key "nodes") and all synapses (key "synapses").

    See Also
    --------
    dump : Function to obtain a structured memory dump.
    """
    created = dict()
    nest.ll_api.sli_func(
        "({}) M_WARNING message".format("Loading is currently only supported for a single process.")
    )
    if "nodes" in data:
        # Status dumps contain read-only entries; tolerate them during Create.
        # NOTE: GetKernelStatus expects a string or list of keys — the original
        # passed a set, which only worked by accident of iteration order.
        dictmissbefore = nest.GetKernelStatus(["dict_miss_is_error"])[0]
        nest.SetKernelStatus({"dict_miss_is_error": False})
        verbose = nest.get_verbosity()
        nest.set_verbosity("M_ERROR")
        newnodes = nest.NodeCollection()
        try:
            for d in data["nodes"]:
                newnodes += nest.Create(d["model"], d)
        finally:
            # Restore kernel flag and verbosity even if node creation fails.
            nest.SetKernelStatus({"dict_miss_is_error": dictmissbefore})
            nest.set_verbosity(verbose)
        created["nodes"] = newnodes
    if "synapses" in data:
        try:
            for conn in data["synapses"]:
                source, target = conn["source"], conn["target"]
                # Copy to avoid mutating the caller's data; drop keys that
                # nest.Connect does not accept. pop(key, None) tolerates dumps
                # that lack some of these keys (previously a KeyError).
                specs = copy.copy(conn)
                for key in ("port", "receptor", "synapse_id",
                            "target_thread", "source", "target"):
                    specs.pop(key, None)
                nest.Connect([source], [target], syn_spec=specs)
        except Exception as e:
            # Best-effort loading: report and return whatever was created.
            print("Error during synapse loading: " + str(e))
        # We only want the newly created connections, but nest.Connect does not
        # return a SynapseCollection object, so return all connections instead.
        created["synapses"] = nest.GetConnections()
    return created
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import nest | |
from dumpload import * | |
class TestSerializationMethods(unittest.TestCase):
    """Unit tests for the dump/load serialization helpers."""

    def setUp(self) -> None:
        # Start every test from a quiet, empty kernel.
        nest.set_verbosity("M_ERROR")
        nest.ResetKernel()

    def test_dump(self):
        nest.ResetKernel()
        pre = nest.Create("aeif_psc_alpha", 5)
        post = nest.Create("aeif_psc_alpha", 5)
        nest.Connect(pre, post)
        full_dump = dump()
        self.assertIsNotNone(full_dump)
        # Restricting the selection yields exactly one top-level key each.
        self.assertEqual(len(dump(selections=["synapses"])), 1)
        self.assertEqual(len(dump(selections=["nodes"])), 1)
        # Growing the network must change the dump contents.
        nest.Create("aeif_psc_alpha", 5)
        self.assertNotEqual(full_dump, dump())

    def test_dump_emtpy(self):
        nest.ResetKernel()
        # An empty kernel dumps to an empty dict, regardless of selection.
        self.assertEqual(dump(), {})
        self.assertEqual(dump(selections=["nodes"]), {})

    def test_load_repeated(self):
        # Loading without a reset must grow the network each time.
        nest.ResetKernel()
        pre = nest.Create("aeif_psc_alpha", 5)
        post = nest.Create("aeif_psc_alpha", 5)
        nest.Connect(pre, post)
        snapshot = dump()
        size_before = nest.GetKernelStatus("network_size")
        load(snapshot)
        size_once = nest.GetKernelStatus("network_size")
        self.assertGreater(size_once, size_before)
        load(snapshot)
        size_twice = nest.GetKernelStatus("network_size")
        self.assertGreater(size_twice, size_once)

    def test_load_empty(self):
        nest.ResetKernel()
        # Loading an empty dump must leave the network size untouched.
        size_before = nest.GetKernelStatus("network_size")
        load({})
        self.assertEqual(size_before, nest.GetKernelStatus("network_size"))

    def test_integration(self):
        # Full round trip: dump, reset, then load back.
        nest.ResetKernel()
        pre = nest.Create("aeif_psc_alpha", 5)
        post = nest.Create("aeif_psc_alpha", 5)
        nest.Connect(pre, post)
        snapshot = dump()
        nest.ResetKernel()
        load(snapshot)
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment