Skip to content

Instantly share code, notes, and snippets.

@jay-trivedi
Last active June 12, 2017 11:09
Show Gist options
  • Save jay-trivedi/2e10a1547bd5967797f8b891884115d0 to your computer and use it in GitHub Desktop.
Save jay-trivedi/2e10a1547bd5967797f8b891884115d0 to your computer and use it in GitHub Desktop.
#Random Forest
# Sweep the random forest's max_depth and, for each depth, record training and
# validation accuracy plus the names of the three most important features.
# Results are collected column-wise in `acc_vs_depth_result_rf` (one list per
# column) so they can be turned into a DataFrame afterwards.
seed = 1
depth_range = range(1, 30, 1)  # depths 1..29 (range's stop value is exclusive)
acc_vs_depth_result_rf = {
    "depth": [],
    "train_acc": [],
    "valid_acc": [],
    "top_feature": [],
    "second_feature": [],
    "third_feature": [],
}
for depth in depth_range:
    # The same model_id is reused on every iteration; presumably this makes H2O
    # replace the previous model rather than keep all 29 -- TODO confirm.
    model = H2ORandomForestEstimator(model_id="model",
                                     sample_rate=1,
                                     ntrees=200,
                                     max_depth=depth,
                                     seed=seed)
    model.train(x=x, y=y, training_frame=train)
    predict_valid = model.predict(valid[x])
    predict_train = model.predict(train[x])
    acc_vs_depth_result_rf["depth"].append(depth)
    # Accuracy = mean of the element-wise "prediction equals label" column.
    # NOTE(review): labels are read from the hard-coded "Survived" column while
    # training uses `y`; assumes y == "Survived" -- confirm.
    # `.mean()` appears to return a list here, hence the trailing [0].
    acc_vs_depth_result_rf["valid_acc"].append(
        (predict_valid["predict"] == valid["Survived"]).mean()[0])
    acc_vs_depth_result_rf["train_acc"].append(
        (predict_train["predict"] == train["Survived"]).mean()[0])
    # varimp() rows are indexed as (name, ...); take the first three names once
    # instead of calling varimp() three times. Pad with None if fewer than three
    # features are reported so the result lists keep equal lengths.
    rf_top3 = [row[0] for row in (model.varimp() or [])[:3]]
    rf_top3 += [None] * (3 - len(rf_top3))
    acc_vs_depth_result_rf["top_feature"].append(rf_top3[0])
    acc_vs_depth_result_rf["second_feature"].append(rf_top3[1])
    acc_vs_depth_result_rf["third_feature"].append(rf_top3[2])
#Converting Results to DataFrame
# Fixed display order for the results table.
cols = ["depth", "train_acc", "valid_acc", "top_feature", "second_feature", "third_feature"]
# Build the table directly in that column order rather than constructing it
# and then re-selecting the columns.
acc_vs_depth_result_df_rf = pd.DataFrame(acc_vs_depth_result_rf, columns=cols)
# Bare expression: presumably here to display the table as a notebook cell's output.
acc_vs_depth_result_df_rf
#Plotting results
# Overlay the RF accuracy-vs-depth curves with the DT ones
# (from `acc_vs_depth_result_df`, defined elsewhere) on one figure.
fig = plt.figure(figsize=(10, 7))
# One (results table, accuracy column, legend label) entry per curve, in legend order.
for _results, _column, _label in (
        (acc_vs_depth_result_df_rf, "train_acc", "train accuracy (RF)"),
        (acc_vs_depth_result_df_rf, "valid_acc", "validation accuracy (RF)"),
        (acc_vs_depth_result_df, "train_acc", "train accuracy (DT)"),
        (acc_vs_depth_result_df, "valid_acc", "validation accuracy (DT)"),
):
    plt.plot(_results["depth"], _results[_column], label=_label)
plt.legend(loc='upper left', frameon=False)
plt.xlabel('Tree Depth')
plt.ylabel('Accuracy')
# No file extension: matplotlib chooses the format from rcParams["savefig.format"].
# Assumes a "figures/" directory already exists under the working directory.
plt.savefig("figures/Titanic_RF")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment