Skip to content

Instantly share code, notes, and snippets.

@samskalicky
Last active May 1, 2020 07:08
Show Gist options
  • Save samskalicky/5f44e159e9f1b04237eed8d20e5d9f28 to your computer and use it in GitHub Desktop.
Save samskalicky/5f44e159e9f1b04237eed8d20e5d9f28 to your computer and use it in GitHub Desktop.
Example graph pass
#include <math.h>
#include <algorithm>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lib_api.h"
class Node;

/*!
 * \brief One end of an edge between two Nodes.
 *
 * The meaning of `entry` depends on which list the NodeEntry lives in:
 *  - in Node::inputs  : `node` is the producer, `entry` is the producer's output index;
 *  - in Node::outputs : `node` is the consumer, `entry` is the index of the
 *                       consumer's input slot that this edge feeds.
 */
struct NodeEntry {
  Node* node;  // node at the other end of the edge
  int entry;   // index whose meaning depends on the containing list (see above)
};
/*!
 * \brief A single vertex of the graph: an operator, or a graph input when op == "null".
 */
class Node {
 public:
  std::string op;                  // operator type from the JSON "op" field
  std::string name;                // node name from the JSON "name" field
  std::vector<NodeEntry> inputs;   // producers feeding this node, in input-slot order
  std::vector<NodeEntry> outputs;  // consumers reading from this node
  std::unordered_map<std::string, std::string> attrs;  // operator attributes, as strings
};
class Graph {
public:
Graph() {}
static Graph fromString(const std::string& json) {
JsonParser parser;
JsonVal val = parser.parse_to_json(json);
return fromJson(val);
}
~Graph() {
for(int i=0; i<nodes.size(); i++)
delete nodes[i];
}
static Graph fromJson(JsonVal val) {
// get nodes list
JsonVal nodes = val.map[JsonVal("nodes")];
Graph g;
std::map<int, Node*> nodeMap;
// loop over nodes
for(int i=0; i<nodes.list.size(); i++) {
Node* n = new Node();
g.nodes.push_back(n);
JsonVal node = nodes.list[i];
// set the op info
n->op = node.map[JsonVal("op")].str;
n->name = node.map[JsonVal("name")].str;
// if op is null its an input to the graph
if(n->op.compare("null") == 0)
g.inputs.push_back(n);
// set attrs
JsonVal attributes = node.map[JsonVal("attrs")];
for(auto& kv : attributes.map) {
n->attrs[kv.first.str] = kv.second.str;
}
// set node inputs
JsonVal node_inputs = node.map[JsonVal("inputs")];
n->inputs.resize(node_inputs.list.size());
for(int j=0; j<node_inputs.list.size(); j++) {
JsonVal input = node_inputs.list[j];
NodeEntry& entry = n->inputs[j];
//get pointer to other node
entry.node = nodeMap[input.list[0].num];
//get the other node's output index
entry.entry = input.list[1].num;
//set other nodes output as connected to this node
entry.node->outputs.push_back({n,j});
}
nodeMap[i] = n;
}
JsonVal& heads = val.map[JsonVal("heads")];
g.outputs.resize(heads.list.size());
for(int i=0; i<heads.list.size(); i++) {
JsonVal head = heads.list[i];
g.outputs[i].node = nodeMap[head.list[0].num];
g.outputs[i].entry = head.list[1].num;
}
JsonParser parser;
for(auto& kv : val.map) {
if(kv.first.str.compare("nodes") != 0 &&
kv.first.str.compare("heads") != 0 &&
kv.first.str.compare("node_row_ptr") != 0 &&
kv.first.str.compare("arg_nodes") != 0) {
g.attrs[kv.first.str] = kv.second;
}
}
return g;
}
JsonVal toJson() {
JsonVal val(MAP);
for(auto& kv : attrs) {
val.map[JsonVal(kv.first)] = kv.second;
}
std::map<Node*, int> nodeMap;
std::vector<Node*> sorted = topological_sort();
for(int i=sorted.size()-1; i>=0; i--) {
nodeMap[sorted[i]] = sorted.size()-1-i;
}
val.map[JsonVal("node_row_ptr")] = JsonVal(LIST);
JsonVal& node_row_ptr = val.map[JsonVal("node_row_ptr")];
for(int i=0; i<nodes.size(); i++)
node_row_ptr.list.push_back(JsonVal(i));
val.map[JsonVal("arg_nodes")] = JsonVal(LIST);
JsonVal& arg_nodes = val.map[JsonVal("arg_nodes")];
for(int i=0; i<inputs.size(); i++)
arg_nodes.list.push_back(JsonVal(nodeMap[inputs[i]]));
val.map[JsonVal("heads")] = JsonVal(LIST);
JsonVal& heads = val.map[JsonVal("heads")];
for(int i=0; i<outputs.size(); i++) {
heads.list.push_back(JsonVal(LIST));
JsonVal& out = heads.list[i];
out.list.push_back(JsonVal(nodeMap[outputs[i].node]));
out.list.push_back(JsonVal(outputs[i].entry));
out.list.push_back(JsonVal(0));
}
val.map[JsonVal("nodes")] = JsonVal(LIST);
JsonVal& nodes_ = val.map[JsonVal("nodes")];
for(int i=sorted.size()-1; i>=0; i--) {
nodes_.list.push_back(JsonVal(MAP));
Node* n = sorted[i];
JsonVal& n_ = nodes_.list[nodes_.list.size()-1];
n_.map[JsonVal("op")] = JsonVal(n->op);
n_.map[JsonVal("name")] = JsonVal(n->name);
n_.map[JsonVal("inputs")] = JsonVal(LIST);
JsonVal& inputs_ = n_.map[JsonVal("inputs")];
for(int j=0; j<n->inputs.size(); j++) {
inputs_.list.push_back(JsonVal(LIST));
NodeEntry& entry = n->inputs[j];
JsonVal& in = inputs_.list[j];
in.list.push_back(JsonVal(nodeMap[entry.node]));
in.list.push_back(JsonVal(entry.entry));
in.list.push_back(JsonVal(0));
}
n_.map[JsonVal("attrs")] = JsonVal(MAP);
JsonVal& attrs_ = n_.map[JsonVal("attrs")];
for(auto& kv : n->attrs) {
attrs_.map[JsonVal(kv.first)] = JsonVal(kv.second);
}
}
return val;
}
std::string toString() {
JsonParser parser;
return parser.dump(toJson());
}
void _dfs_util(Node* n, std::unordered_set<Node*>* to_visit,
std::function<void(Node*)> handler) {
to_visit->erase(n);
for(NodeEntry& e : n->outputs) {
Node* o = e.node;
if(to_visit->count(o) != 0) {
_dfs_util(o,to_visit,handler);
}
}
handler(n);
}
void DFS(std::function<void(Node*)> handler) {
std::unordered_set<Node*> to_visit;
//put all nodes in set to visit
for(auto& n : nodes)
to_visit.insert(n);
//visit all inputs first
for(auto& i : inputs)
if(to_visit.count(i) != 0)
_dfs_util(i, &to_visit, handler);
//visit any nodes left
while(to_visit.size() > 0)
_dfs_util(*(to_visit.begin()), &to_visit, handler);
}
std::vector<Node*> topological_sort() {
std::vector<Node*> sorted;
auto handler = [&](Node* n) {
sorted.push_back(n);
};
DFS(handler);
return sorted;
}
void print() {
std::cout << "########### Graph #############" << std::endl;
std::cout << "inputs: " << inputs.size() << std::endl;
std::cout << "outputs: " << outputs.size() << std::endl;
std::cout << "nodes: " << nodes.size() << std::endl;
std::vector<Node*> sorted;
auto handler = [&](Node* n) {
sorted.push_back(n);
};
DFS(handler);
for(int i=sorted.size()-1; i>=0; i--) {
std::cout << "Node: " << sorted[i]->name << std::endl;
for(int j=0; j<sorted[i]->inputs.size(); j++) {
std::cout << "\tInput: " << sorted[i]->inputs[j].node->name << " " << sorted[i]->inputs[j].entry << std::endl;
}
for(int j=0; j<sorted[i]->outputs.size(); j++) {
std::cout << "\tOutput: " << sorted[i]->outputs[j].node->name << " " << sorted[i]->outputs[j].entry << std::endl;
}
}
std::cout << "###############################" << std::endl;
}
std::vector<Node*> nodes;
std::vector<Node*> inputs;
std::vector<NodeEntry> outputs;
std::map<std::string, JsonVal> attrs;
};
/*!
 * \brief Example graph pass that rewires an `elemwise_add` node onto a new argument.
 *
 * Steps:
 *  - parses `in_graph` into a Graph;
 *  - allocates a new (3,2) float32 CPU arg param named "test_arg";
 *  - adds a matching "null" input node;
 *  - reconnects input 0 of the last `elemwise_add` node in `g.nodes` to that new input;
 *  - writes the serialized result into a newly allocated string at `*out_graph`.
 *
 * Only the single edge old-producer -> add[0] is removed. Other consumers of the
 * old producer keep their edges.
 * `options`, `args` and `aux` are not used by this example.
 *
 * \return MX_SUCCESS, or MX_FAIL if the graph has no `elemwise_add` node with an input.
 */
MXReturnValue graphPass(const std::string& in_graph, const std::string** out_graph,
                        const std::unordered_map<std::string, std::string>& options,
                        const std::unordered_map<std::string, MXTensor>& args,
                        const std::unordered_map<std::string, MXTensor>& aux,
                        const PassResource& res) {
  // convert graph from JSON string to Graph/Node data structure
  Graph g = Graph::fromString(in_graph);
  // g.print();  // uncomment to inspect the initial graph

  // find the node with 'elemwise_add' op type (the last match in g.nodes wins)
  auto it = std::find_if(g.nodes.rbegin(), g.nodes.rend(), [](Node* node) {
    return node->op.compare("elemwise_add") == 0;
  });
  if (it == g.nodes.rend() || (*it)->inputs.empty()) {
    std::cerr << "graphPass: no 'elemwise_add' node with an input to rewire" << std::endl;
    return MX_FAIL;
  }
  Node* add = *it;

  // create a new arg param
  // NOTE(review): its data is never written here — confirm how MXNet initializes it
  MXTensor* arg_ = res.alloc_arg("test_arg", {3, 2}, MXContext::CPU(0), kFloat32);
  (void)arg_;

  // create a new input Node and add it to the graph (g takes ownership)
  Node* n = new Node();
  n->name = "test_arg";
  n->op = "null";
  g.nodes.push_back(n);
  g.inputs.push_back(n);

  // disconnect the old input from add: erase only the {add, 0} consumer edge,
  // so any other consumers of the old producer stay connected
  NodeEntry& slot = add->inputs[0];
  std::vector<NodeEntry>& old_outs = slot.node->outputs;
  old_outs.erase(std::remove_if(old_outs.begin(), old_outs.end(),
                                [add](const NodeEntry& e) {
                                  return e.node == add && e.entry == 0;
                                }),
                 old_outs.end());

  // connect add's input 0 to output 0 of the new input, and record the reverse edge
  slot.node = n;
  slot.entry = 0;
  n->outputs.push_back({add, 0});
  // g.print();  // uncomment to inspect the modified graph

  // convert back to JSON string from Graph/Node
  *out_graph = new std::string(g.toString());
  return MX_SUCCESS;
}
// Register graphPass with MXNet and attach the function above as its body.
// NOTE(review): the Python driver selects it via optimize_for('graphPass', ...),
// so the registered name presumably comes from the macro argument — confirm in lib_api.h.
REGISTER_PASS(graphPass)
.setBody(graphPass);
/*!
 * \brief Library load hook: accept MXNet version codes of 10700 and above.
 *
 * Prints whether `version` is supported.
 * NOTE(review): 10700 presumably encodes MXNet 1.7.0 — confirm against MXNet's version macro.
 */
MXReturnValue initialize(int version) {
  const bool supported = !(version < 10700);
  std::cout << "MXNet version " << version
            << (supported ? " supported" : " not supported") << std::endl;
  return supported ? MX_SUCCESS : MX_FAIL;
}
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 14944a5..8157bd4 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1485,17 +1485,17 @@ class Symbol(SymbolBase):
assert isinstance(backend, str)
if args is None or len(args) == 0:
- args = []
+ args_ = []
args_handle = c_array(NDArrayHandle, [])
else:
- args_handle, args = self._get_ndarray_inputs('args', args,
+ args_handle, args_ = self._get_ndarray_inputs('args', args,
self.list_arguments(), False)
if aux is None or len(aux) == 0:
- aux = []
+ aux_ = []
aux_handle = c_array(NDArrayHandle, [])
else:
- aux_handle, aux = self._get_ndarray_inputs('aux_states', aux,
+ aux_handle, aux_ = self._get_ndarray_inputs('aux_states', aux,
self.list_auxiliary_states(), False)
if ctx is None:
ctx = current_context()
@@ -1517,9 +1517,9 @@ class Symbol(SymbolBase):
c_str(backend),
ctypes.c_int(ctx.device_typeid),
ctypes.byref(out),
- mx_uint(len(args)),
+ mx_uint(len(args_)),
args_handle,
- mx_uint(len(aux)),
+ mx_uint(len(aux_)),
aux_handle,
mx_uint(len(key_list)),
c_str_array(key_list),
import os, ctypes
import mxnet as mx
from mxnet.gluon import nn
from mxnet import nd
from mxnet.base import _LIB, check_call, mx_uint, c_str, c_str_array, SymbolHandle
# Load the compiled pass library from the current working directory.
# Only POSIX (.so) and Windows (.dll) are handled; on other platforms nothing is loaded.
if os.name in ('posix', 'nt'):
    path = os.path.abspath('libpass_lib.so' if os.name == 'posix' else 'libpass_lib.dll')
    mx.library.load(path)
###############################################
# Test with not consuming params
###############################################
# Example model: sym = log(exp(a + b)).
# Apart from the two input variables, no op takes its own parameter:
# each op consumes the previous op's output.
# The `a + b` addition is the node graphPass looks for (op type "elemwise_add").
# NOTE(review): presumably Symbol `+` serializes as "elemwise_add" — confirm in sym.tojson().
a = mx.sym.var('a')
b = mx.sym.var('b')
c = a + b
d = mx.sym.exp(c)
sym = mx.sym.log(d)
def test_graph():
    """Run the example model with plain MXNet, then through the custom 'graphPass', printing both results."""
    # baseline: regular MXNet execution
    print('-------------------------------')
    print('Testing regular MXNet execution')
    baseline_exe = sym.bind(ctx=mx.cpu(),
                            args={'a': mx.nd.ones((3, 2)), 'b': mx.nd.ones((3, 2))})
    baseline_out = baseline_exe.forward()
    print(baseline_out)

    # Symbol.optimize_for with shape/type information propagated from the args.
    # The same `inputs` dict is passed to optimize_for and later to bind.
    print('-------------------------------')
    print('Testing graphPass with shapes/types')
    inputs = {'a': mx.nd.ones((3, 2)), 'b': mx.nd.ones((3, 2))}
    aux_states = {}
    print(sym.tojson())
    for arg_name, arr in inputs.items():
        print('%s: %s' % (arg_name, arr.shape))
    optimized = sym.optimize_for('graphPass', inputs, aux_states)
    print(optimized.tojson())
    optimized_exe = optimized.bind(ctx=mx.cpu(), args=inputs)
    optimized_out = optimized_exe.forward()
    print(optimized_out)


test_graph()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment