Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save evanfrisch/06098b4c2dd31bba1c2bb571600b12a1 to your computer and use it in GitHub Desktop.
Save evanfrisch/06098b4c2dd31bba1c2bb571600b12a1 to your computer and use it in GitHub Desktop.
def predict_from_standalone_rf(self, train, test, valid, x, y, prediction_field_name):
    """Train a standalone H2O random forest and return its predictions on ``test``.

    Trains an ``H2ORandomForestEstimator`` with fixed hyperparameters on
    ``train`` (with ``valid`` as the validation frame), prints diagnostics
    appropriate to the model type, predicts on ``test``, and hands the
    prediction frame to ``self.set_prediction_field_name``.

    :param train: the training H2O dataframe
    :param test: the H2O dataframe to generate predictions for
    :param valid: the validation H2O dataframe; if None, the multinomial
        validation confusion matrix (which needs a real frame) is skipped
    :param x: the feature variables passed to ``train``
    :param y: the target variable passed to ``train``
    :param prediction_field_name: the name to use for the field that will
        contain the predictions
    :returns: the result of ``self.set_prediction_field_name`` applied to the
        frame returned by ``predict(test)`` -- presumably that frame with its
        prediction field renamed to ``prediction_field_name`` (TODO confirm
        against that helper)
    """
    print("Random Forest")
    rf_standalone = H2ORandomForestEstimator(
        # NOTE(review): fixed model id -- presumably each call replaces any
        # existing "rf" model on the H2O cluster; confirm this is intended.
        model_id="rf",
        stopping_rounds=2,
        score_each_iteration=True,
        sample_rate=.7,
        col_sample_rate_per_tree=.7,
        max_depth=16,
        ntrees=500,
        nfolds=5,
        seed=1000000)
    rf_standalone.train(x=x, y=y, training_frame=train, validation_frame=valid)
    print("Model type:", rf_standalone.type)

    if rf_standalone.type == "classifier":
        print("train[y].levels():", train[y].levels()[0])
        y_level_count = train[y].nlevels()[0]
        print("y_level_count:", y_level_count)
        if y_level_count <= 2:
            # Two-level response: report AUC on training and validation data.
            print("AUC (training):", rf_standalone.auc(train=True))
            print("AUC (validation):", rf_standalone.auc(valid=True))
        else:
            # More than two levels. Unlike the valid=True accessors,
            # confusion_matrix(data=...) must be given an actual frame, so it
            # is only computed when a validation frame was supplied.
            if valid is not None:
                print("Confusion Matrix (validation):",
                      rf_standalone.confusion_matrix(data=valid))
            print("Hit Ratio Table (validation):",
                  rf_standalone.hit_ratio_table(valid=True))
    else:
        print("r2 (training):", rf_standalone.r2(train=True))
        print("r2 (validation):", rf_standalone.r2(valid=True))

    rf_standalone_predictions = rf_standalone.predict(test)
    print("Random Forest Variable Importances")
    # Public accessor rather than reaching into the private _model_json dict.
    print(rf_standalone.varimp(use_pandas=True))
    print("Best Random Forest Predictions:")
    print(rf_standalone_predictions.head(rows=5))
    return self.set_prediction_field_name(rf_standalone_predictions, prediction_field_name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment