Skip to content

Instantly share code, notes, and snippets.

@yaronvel
Created October 17, 2023 10:27
Show Gist options
  • Save yaronvel/1202bc94935afdd267e026a468adfe0b to your computer and use it in GitHub Desktop.
Save yaronvel/1202bc94935afdd267e026a468adfe0b to your computer and use it in GitHub Desktop.
from __future__ import annotations
import sys
import numpy as np
import networkx as nx
import rich_click as click
import woke.ir as ir
import woke.ir.types as types
from rich import print
from woke.ir.statements.expression_statement import ExpressionStatement
from woke.ir.statements.abc import StatementAbc
from woke.ir.declarations.contract_definition import ContractDefinition, FunctionDefinition
from woke.ir.statements.block import Block
from woke.printers import Printer, printer
from woke.ir.expressions.function_call import FunctionCall
# importing matplotlib.pyplot
import matplotlib.pyplot as plt
class CfgStoragePrinter(Printer):
    """Build a simplified control-flow graph per source unit and compare complexities.

    For every visited source unit a directed graph is built from the IR nodes
    accepted by :meth:`is_relevant_node` (statements, contract and function
    definitions, and plain calls to function definitions). The cyclomatic
    complexity of that graph ("standard") is compared with the complexity of a
    compressed graph that keeps only entry vertices, branching vertices and
    state-modifying vertices ("smart"). Both graphs are also drawn to PNG files
    in the current working directory.
    """

    # Seed for the spring layout so repeated runs produce the same drawing.
    _LAYOUT_SEED = 3113794652

    def print(self) -> None:
        """Nothing to emit at the end; per-source-unit results are printed in process_node."""

    def is_relevant_node(self, node) -> bool:
        """Return True if *node* should become a vertex of the control-flow graph.

        Statements, contract definitions and function definitions are always
        relevant. A ``FunctionCall`` is relevant only when it is a plain call
        (``kind == "functionCall"``) whose callee is a ``FunctionDefinition``.
        """
        if isinstance(node, (StatementAbc, ContractDefinition, FunctionDefinition)):
            return True
        if isinstance(node, FunctionCall) and node.kind == "functionCall":
            return isinstance(node.function_called, FunctionDefinition)
        return False

    def find_parents(self, node, relevant_nodes, graph) -> list:
        """Return the vertex ids that should get an edge into *node*.

        Walks up the IR tree to the nearest ancestor whose id is in
        *relevant_nodes*:

        * if that ancestor is a ``Block`` that already has successors in
          *graph*, the predecessors are the sink vertices reachable from its
          last successor (this chains consecutive statements of the block);
        * otherwise the ancestor itself is the single predecessor.

        Returns ``["null"]`` (the artificial root) when no relevant ancestor
        exists.
        """
        parent = node.parent
        while parent:
            parent_id = self.node_id(parent)
            if parent_id in relevant_nodes:
                # has_node guard: a relevant id may never have been added to
                # the graph (all its predecessor candidates were itself), and
                # graph.successors would raise NetworkXError for it.
                if isinstance(parent, Block) and graph.has_node(parent_id):
                    successors = list(graph.successors(parent_id))
                    if successors:
                        return self.find_terminal_nodes(successors[-1], graph)
                return [parent_id]
            parent = parent.parent
        return ["null"]

    def node_id(self, node) -> str:
        """Return the graph vertex id for *node*; currently identical to its label.

        NOTE(review): ids are the raw source text, so distinct IR nodes with
        identical text (e.g. two ``return true;`` statements) collapse into a
        single vertex. A per-node unique id would avoid that, but would change
        the reported complexities.
        """
        return self.node_label(node)

    def node_label(self, node) -> str:
        """Return the source text of *node*, or ``"null"`` for a missing node."""
        if node:
            return node.source
        return "null"

    def find_terminal_nodes(self, node_name, graph) -> list:
        """Return the sink vertices (out-degree 0) reachable from *node_name*.

        Falls back to ``[node_name]`` when no reachable vertex is a sink, e.g.
        when *node_name* has no descendants at all.
        """
        terminals = [
            x for x in nx.descendants(graph, node_name) if graph.out_degree(x) == 0
        ]
        return terminals or [node_name]

    def process_node(self, node) -> None:
        """Build, measure and draw the control-flow graph of one source unit.

        Prints the cyclomatic complexity of the full graph ("standard") and of
        the compressed graph ("smart"), and saves both drawings as
        ``<file>_non_compressed.png`` and ``<file>_compressed.png`` in the
        current working directory.
        """
        graph, storage_write_nodes = self._build_graph(node)

        # compress_graph mutates its argument, so keep an untouched copy for
        # the "standard" measurement and drawing.
        full_graph = graph.copy()
        standard_complexity = self.complexity(full_graph, "null", "end")
        compressed_graph = self.compress_graph(graph, storage_write_nodes, "null")
        smart_complexity = self.complexity(compressed_graph, "null", "end")

        # Base file name. The previous split("/")[1] raised IndexError for
        # paths without a directory and returned a directory name for nested
        # paths (so different files overwrote each other's images).
        name = node.source_unit_name.rsplit("/", 1)[-1]
        print(
            "complexity results:", name,
            "standard", standard_complexity,
            "smart", smart_complexity,
        )

        # Lay out the full graph once and reuse the positions for the
        # compressed graph (its vertices are a subset), so the drawings line up.
        pos = nx.spring_layout(full_graph, seed=self._LAYOUT_SEED)
        self._draw(full_graph, pos, storage_write_nodes, name + "_non_compressed.png")
        self._draw(compressed_graph, pos, storage_write_nodes, name + "_compressed.png")

    def _build_graph(self, node):
        """Build the raw CFG of *node*; return ``(graph, ids of state-modifying vertices)``."""
        graph = nx.DiGraph()
        relevant_nodes = set()
        storage_write_nodes = set()
        call_edges = []
        for new_node in node:
            if not self.is_relevant_node(new_node):
                continue
            v1 = self.node_id(new_node)
            predecessors = self.find_parents(new_node, relevant_nodes, graph)
            if isinstance(new_node, (ExpressionStatement, FunctionCall)) and new_node.modifies_state:
                storage_write_nodes.add(v1)
            if isinstance(new_node, FunctionCall):
                # Call-site -> callee edges are deferred until after the
                # traversal, so find_parents never sees them.
                call_edges.append((v1, self.node_id(new_node.function_called)))
            for v2 in predecessors:
                if v2 != v1:  # skip self loops
                    graph.add_edge(v2, v1)
            relevant_nodes.add(v1)
        graph.add_edges_from(call_edges)
        return graph, storage_write_nodes

    def _draw(self, graph, pos, storage_write_nodes, filename) -> None:
        """Draw *graph* at positions *pos*, coloured by _node_color, and save it to *filename*."""
        colors = [self._node_color(n, storage_write_nodes) for n in graph]
        nx.draw(graph, pos=pos, node_color=colors)
        plt.savefig(filename)
        plt.close()

    @staticmethod
    def _node_color(node_name, storage_write_nodes) -> str:
        """Red: state-modifying; green: artificial root; black: "end"; blue: anything else."""
        if node_name in storage_write_nodes:
            return "red"
        if node_name == "null":
            return "green"
        if node_name == "end":
            return "black"
        return "blue"

    def compress_graph(self, G, important_nodes, root):
        """Drop every vertex that is not an entry, a branch or important; return *G*.

        A vertex is kept when its in-degree is 0, its out-degree is > 1, or it
        is in *important_nodes*. Any other vertex is bypassed (each predecessor
        is connected directly to each successor) and removed. After every
        removal the scan restarts from the first vertex, because the degrees of
        the remaining vertices may have changed. *G* is modified in place.

        This is written as a loop; the previous version recursed once per
        removed vertex and exceeded Python's recursion limit on large graphs.
        *root* is unused.
        """
        while True:
            for node in G:
                if self._should_keep(G, node, important_nodes):
                    continue
                predecessors = [u for u, _ in G.in_edges(node)]
                successors = list(G.successors(node))
                for u in predecessors:
                    for n in successors:
                        G.add_edge(u, n)
                G.remove_node(node)
                break  # graph changed: rescan from the first vertex
            else:
                return G

    @staticmethod
    def _should_keep(G, node, important_nodes) -> bool:
        """Return True for entry vertices, branching vertices and important vertices."""
        return (
            G.in_degree(node) == 0
            or G.out_degree(node) > 1
            or node in important_nodes
        )

    def add_terminal(self, G, root):
        """Return *G* unchanged.

        Adding an artificial "end" vertex is currently disabled; the previous
        body returned before doing any work, so this is a no-op placeholder.
        """
        return G

    def complexity(self, G, root, terminal_node):
        """Return McCabe's cyclomatic complexity ``E - N + 2P`` of *G*.

        ``P`` is the number of weakly connected components. The previous
        version counted strongly connected components, which for an acyclic
        graph equals N and turned the metric into ``E + N``.
        *root* and *terminal_node* are unused.
        """
        p = nx.number_weakly_connected_components(G)
        return G.number_of_edges() - G.number_of_nodes() + 2 * p

    def visit_source_unit(self, node):
        """Analyse one source unit (presumably invoked by woke's visitor)."""
        return self.process_node(node)

    @printer.command(name="cfg-storage")
    def cli(self) -> None:
        """Register this printer as the ``cfg-storage`` command; it takes no options."""
'''
TODO: additional CFG parsing still needed:
- for loop: body (and possibly the loop condition)
- while loop: body (and possibly the loop condition)
- if statement: true and false branches (and possibly the condition)
- inline assembly: ignore
'''
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment