Skip to content

Instantly share code, notes, and snippets.

@miyakelp
Last active December 14, 2022 20:45
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save miyakelp/7f7678055b72e85278018b0a3ce6f12a to your computer and use it in GitHub Desktop.
Save miyakelp/7f7678055b72e85278018b0a3ce6f12a to your computer and use it in GitHub Desktop.
scikit_learn_randomforest_to_verilog.py
def create_tree_function_wire(name, tree, num_skip=2):
    """Render one decision tree as the text of a Verilog ``function``.

    The generated function takes a 784-bit ``pixel`` input and returns a
    10-bit vector ``{result9, ..., result0}``; bit ``i`` is the OR of every
    leaf whose predicted class index is ``i``.

    Args:
        name: Identifier used for the Verilog function (and its return value).
        tree: Object exposing ``feature``, ``children_left``,
            ``children_right`` and ``value``. A node is treated as a split
            when ``feature[node] >= 0`` and as a leaf otherwise. A child is
            linked to its parent only if the parent has a lower index,
            because parent links are recorded while walking the nodes in
            index order. NOTE(review): this matches the layout of
            scikit-learn's ``Tree`` object -- confirm against the caller.
        num_skip: Number of spaces per indentation level in the output.

    Returns:
        The Verilog source of the function as a single string.

    Note:
        ``tree.threshold`` is never read: every split is emitted as a test of
        a single ``pixel`` bit, with the right branch taken when the bit is 1
        and the left branch when it is 0.
    """
    num_classes = 10
    indent1 = ' ' * num_skip
    indent2 = indent1 * 2
    indent3 = indent1 * 3
    node_count = len(tree.feature)

    lines = [
        "%sfunction [9:0] %s;\n" % (indent1, name),
        "%sinput [783:0] pixel;\n" % indent2,
        "\n",
    ]

    # Declarations: split nodes get one reg per outgoing branch, leaves get
    # a single reg, followed by one reg per class.
    for node in range(node_count):
        if tree.feature[node] >= 0:
            lines.append("%sreg node%d_r;\n" % (indent2, node))
            lines.append("%sreg node%d_l;\n" % (indent2, node))
        else:
            lines.append("%sreg node%d;\n" % (indent2, node))
    lines.extend("%sreg result%d;\n" % (indent2, cls) for cls in range(num_classes))
    lines.append("\n")
    lines.append(indent2 + "begin\n")

    # Lookup tables used to follow the branch entering a node back to the
    # preceding (parent) node; -1 means "no incoming branch recorded".
    parent_via_left = [-1] * node_count
    parent_via_right = [-1] * node_count
    leaves = []
    for node in range(node_count):
        # A node is active when the parent branch leading to it is active;
        # that branch signal already folds in every ancestor's condition.
        conditions = []
        if parent_via_left[node] >= 0:
            conditions.append("node" + str(parent_via_left[node]) + "_l")
        if parent_via_right[node] >= 0:
            conditions.append("node" + str(parent_via_right[node]) + "_r")

        if tree.feature[node] >= 0:
            pixel = "pixel[" + str(tree.feature[node]) + "]"
            lines.append("%snode%d_r = %s;\n"
                         % (indent3, node, " & ".join(conditions + [pixel])))
            lines.append("%snode%d_l = %s;\n"
                         % (indent3, node, " & ".join(conditions + ["~" + pixel])))
            parent_via_left[tree.children_left[node]] = node
            parent_via_right[tree.children_right[node]] = node
        else:
            leaves.append(node)
            # A leaf with no incoming branch (a single-node tree) is always
            # reached; emit a constant instead of an empty expression.
            reached = " & ".join(conditions) or "1'b1"
            lines.append("%snode%d = %s;\n" % (indent3, node, reached))

    # Per-node index of the largest entry of ``value`` along the class axis.
    # NOTE(review): assumes ``value`` is laid out as (node, output, class)
    # and only output 0 is used -- confirm against the caller.
    labels = np.argmax(tree.value.T, axis=0)[0]
    leaves_by_class = [[] for _ in range(num_classes)]
    for leaf in leaves:
        leaves_by_class[labels[leaf]].append("node" + str(leaf))

    for cls, terms in enumerate(leaves_by_class):
        # A class that no leaf predicts must still drive its bit; an empty
        # right-hand side would be invalid Verilog.
        lines.append("%sresult%d = %s;\n" % (indent3, cls, " | ".join(terms) or "1'b0"))
    lines.append("\n")
    lines.append(
        "%s%s = {result9, result8, result7, result6, result5, result4, result3, result2, result1, result0};\n"
        % (indent3, name)
    )
    lines.append(indent2 + "end\n")
    lines.append(indent1 + "endfunction\n")
    return "".join(lines)
def create_rf_module(name, rf):
    """Render a whole forest as the text of a Verilog module.

    The module has a 784-bit ``image`` input and a 4-bit ``result`` output.
    It contains one function per tree (produced by
    ``create_tree_function_wire``), sums the trees' votes for each of the 10
    classes, and selects the class with the most votes through a
    tournament of pairwise comparisons.

    Every comparison is a strict ``>``, so when two counts are equal the
    side listed first (the lower class indices) is kept.

    Args:
        name: Identifier of the generated Verilog module.
        rf: Object exposing ``estimators_``, a sequence whose items each have
            a ``tree_`` attribute accepted by ``create_tree_function_wire``.

    Returns:
        The Verilog source of the module as a single string.
    """
    skip = ' '
    num_classes = 10
    estimators = rf.estimators_
    tree_count = len(estimators)

    # Each vote counter must be able to hold ``tree_count``. Eight bits are
    # kept as the minimum so output for forests of up to 255 trees is
    # unchanged; larger forests get wider counters instead of overflowing.
    width = max(8, tree_count.bit_length())
    msb = width - 1
    zero_pad = "%d'b%s" % (msb, "0" * msb)

    parts = ["module %s (input wire [783:0] image, output wire [3:0] result);\n" % name]
    for i, estimator in enumerate(estimators):
        parts.append(create_tree_function_wire("tree_%d" % i, estimator.tree_))
    parts.append("\n")

    # One 10-bit wire per tree carrying that tree's class bits.
    for i in range(tree_count):
        parts.append(skip + "wire [9:0] res_tree_%d;\n" % i)
        parts.append(skip + "assign res_tree_%d = tree_%i(image);\n" % (i, i))
    parts.append("\n")

    # Per-class vote count: zero-extend each tree's bit and add them up.
    for cls in range(num_classes):
        votes = ["{%s, res_tree_%d[%d]}" % (zero_pad, j, cls) for j in range(tree_count)]
        # With no trees there is nothing to add; drive the counter with zero
        # rather than emitting an empty expression.
        total = " + ".join(votes) if votes else "%d'd0" % width
        parts.append(skip + "wire [%d:0] res_sum_%d;\n" % (msb, cls))
        parts.append(skip + "assign res_sum_%d = %s;\n" % (cls, total))
    parts.append("\n")

    # Wires for the tournament: a running best count and the class holding it.
    brackets = ["0vs1", "2vs3", "4vs5", "6vs7", "8vs9", "01vs23", "45vs67", "0123vs4567"]
    for tag in brackets:
        parts.append(skip + "wire [%d:0] winner_cnt_%s;\n" % (msb, tag))
    parts.append("\n")
    for tag in brackets:
        parts.append(skip + "wire [3:0] winner_%s;\n" % tag)
    parts.append(skip + "wire [3:0] winner;\n")
    parts.append("\n")

    # Round 1: compare neighbouring classes (0 vs 1, 2 vs 3, ...).
    for low in range(0, num_classes, 2):
        high = low + 1
        parts.append(
            skip + "assign winner_cnt_%dvs%d = res_sum_%d > res_sum_%d ? res_sum_%d : res_sum_%d;\n"
            % (low, high, high, low, high, low)
        )
        parts.append(
            skip + "assign winner_%dvs%d = res_sum_%d > res_sum_%d ? %d : %d;\n"
            % (low, high, high, low, high, low)
        )
    parts.append("\n")

    # Later rounds: merge two earlier brackets, keeping the larger count.
    later_rounds = [
        [("01vs23", "0vs1", "2vs3"), ("45vs67", "4vs5", "6vs7")],
        [("0123vs4567", "01vs23", "45vs67")],
    ]
    for matches in later_rounds:
        for merged, first, second in matches:
            challenger_wins = "winner_cnt_%s > winner_cnt_%s" % (second, first)
            parts.append(
                skip + "assign winner_cnt_%s = %s ? winner_cnt_%s : winner_cnt_%s;\n"
                % (merged, challenger_wins, second, first)
            )
            parts.append(
                skip + "assign winner_%s = %s ? winner_%s : winner_%s;\n"
                % (merged, challenger_wins, second, first)
            )
        parts.append("\n")

    # Final: the 8-vs-9 bracket against the best of classes 0-7.
    parts.append(skip + "assign winner = winner_cnt_8vs9 > winner_cnt_0123vs4567 ? winner_8vs9 : winner_0123vs4567;\n")
    parts.append(skip + "assign result = winner;\n\n")
    parts.append("endmodule\n")
    return "".join(parts)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment