Skip to content

Instantly share code, notes, and snippets.

@barusan
Created February 28, 2018 13:13
Show Gist options
  • Save barusan/696d1579ba2887871a68efeeacc71ab1 to your computer and use it in GitHub Desktop.
Save barusan/696d1579ba2887871a68efeeacc71ab1 to your computer and use it in GitHub Desktop.
Two-class, two-feature (2-D) random forest example using the iris dataset
#!/usr/bin/env python3
#-*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_graphviz
# Set to False to skip writing one Graphviz .dot file per tree.
export_dot = True

iris = load_iris()
# Drop every sample of class 2 so the problem has only two classes (0 and 1).
mask = iris.target != 2
# Keep two feature columns so each tree's splits can be drawn in the plane:
# column 1 becomes the x-axis and column 0 the y-axis
# (sepal width and sepal length in scikit-learn's iris feature order).
X = iris.data[mask][:, [1, 0]]
Y = iris.target[mask]

# Plot extent: the data's bounding box padded by 0.1 on every side.
xmin, ymin = np.min(X, axis=0) - 0.1
xmax, ymax = np.max(X, axis=0) + 0.1

# Four trees, one per panel of the 2x2 figure. A fixed random_state makes
# the bootstrap samples and feature draws (and hence the figure and the
# exported .dot files) identical from run to run.
clf = RandomForestClassifier(n_estimators=4, random_state=0)
clf.fit(X, Y)
def plot_boundary(ax, tree, node_id, xmi, xma, ymi, yma):
    """Draw the decision regions of a fitted two-feature tree on *ax*.

    Recursively walks ``tree`` (an estimator's ``tree_`` attribute) from
    ``node_id``, shrinking the rectangle [xmi, xma] x [ymi, yma] at each
    split:

    * a split on feature 0 draws a vertical green line at its threshold;
    * a split on feature 1 draws a horizontal green line at its threshold;
    * a leaf fills its rectangle red when the argmax of its ``value`` entry
      is 1, otherwise blue.

    Raises:
        ValueError: if a split uses a feature other than 0 or 1, i.e. the
            tree was not fitted on exactly two features.
    """
    left = tree.children_left[node_id]
    right = tree.children_right[node_id]
    if left == right:
        # Leaf: scikit-learn stores the same sentinel (-1) as both child ids.
        color = "r" if np.argmax(tree.value[node_id]) == 1 else "b"
        ax.add_patch(
            plt.Rectangle(xy=[xmi, ymi], width=xma - xmi, height=yma - ymi,
                          linewidth=0, alpha=0.2, facecolor=color)
        )
        return

    thres = tree.threshold[node_id]
    feat = tree.feature[node_id]
    # Samples with feature value <= threshold go to the left child.
    if feat == 0:
        ax.plot([thres, thres], [ymi, yma], "g-")
        plot_boundary(ax, tree, left, xmi, thres, ymi, yma)
        plot_boundary(ax, tree, right, thres, xma, ymi, yma)
    elif feat == 1:
        ax.plot([xmi, xma], [thres, thres], "g-")
        plot_boundary(ax, tree, left, xmi, xma, ymi, thres)
        plot_boundary(ax, tree, right, xmi, xma, thres, yma)
    else:
        raise ValueError(
            f"node {node_id} splits on feature {feat}; "
            "only trees fitted on 2 features can be drawn"
        )


fig, axes = plt.subplots(2, 2)
# One panel per tree. zip stops at whichever runs out first, so a forest with
# fewer than four trees leaves panels empty instead of raising IndexError.
for i, (ax, estimator) in enumerate(zip(axes.flat, clf.estimators_)):
    if export_dot:
        # Render the files afterwards in bash with:
        #   $ for file in *.dot; do dot -Tpdf $file -o $file.pdf; done
        export_graphviz(estimator, out_file="tree%d.dot" % i)
    plot_boundary(ax, estimator.tree_, 0, xmin, xmax, ymin, ymax)
    # Training points: class 0 as blue "_" markers, class 1 as red "+".
    ax.plot(X[Y == 0, 0], X[Y == 0, 1], "b_")
    ax.plot(X[Y == 1, 0], X[Y == 1, 1], "r+")
    ax.set_xlim((xmin, xmax))
    ax.set_ylim((ymin, ymax))

# Lay out only after the axes are populated so tick labels are accounted for.
fig.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment