Created
October 31, 2017 01:21
-
-
Save evanfrisch/06098b4c2dd31bba1c2bb571600b12a1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def predict_from_standalone_rf(self, train, test, valid, x, y, prediction_field_name):
    """Train a standalone random forest and return its predictions for ``test``.

    Fits an ``H2ORandomForestEstimator`` on ``train`` (using ``valid`` as the
    validation frame), prints diagnostic metrics, scores ``test`` and hands
    the prediction frame to ``self.set_prediction_field_name``.

    Metrics printed depend on the trained model's ``type``:

    * classifier whose target has at most two levels: training and
      validation AUC;
    * classifier with more levels: confusion matrix and hit-ratio table on
      the validation data;
    * anything else: training and validation r2.

    :param train: the training H2O dataframe
    :param test: the testing H2O dataframe (the frame that gets scored)
    :param valid: the validation H2O dataframe
    :param x: the feature variables
    :param y: the target variable
    :param prediction_field_name: the name to use for the field that holds
        the predictions
    :returns: the result of ``self.set_prediction_field_name`` applied to the
        prediction frame. NOTE(review): that helper is defined outside this
        block; presumably it renames the prediction column -- confirm
        whether the other columns of ``test`` are carried along.
    """
    print("Random Forest")
    rf_standalone = H2ORandomForestEstimator(
        model_id="rf",
        stopping_rounds=2,
        score_each_iteration=True,
        sample_rate=.7,
        col_sample_rate_per_tree=.7,
        max_depth=16,
        ntrees=500,
        nfolds=5,
        seed=1000000,
    )
    rf_standalone.train(x=x, y=y, training_frame=train, validation_frame=valid)

    print("Model type:", rf_standalone.type)
    if rf_standalone.type == "classifier":
        print("train[y].levels():", train[y].levels()[0])
        y_level_count = train[y].nlevels()[0]
        print("y_level_count:", y_level_count)
        if y_level_count <= 2:
            # AUC is reported only for targets with at most two levels.
            print("AUC (training):", rf_standalone.auc(train=True))
            print("AUC (validation):", rf_standalone.auc(valid=True))
        else:
            print("Confusion Matrix (validation):", rf_standalone.confusion_matrix(data=valid))
            print("Hit Ratio Table (validation):", rf_standalone.hit_ratio_table(valid=True))
    else:
        print("r2 (training):", rf_standalone.r2(train=True))
        print("r2 (validation):", rf_standalone.r2(valid=True))

    rf_standalone_predictions = rf_standalone.predict(test)

    print("Random Forest Variable Importances")
    # Use the public varimp() accessor instead of reaching into the private
    # _model_json payload, whose layout is an implementation detail.
    print(rf_standalone.varimp(use_pandas=True))

    print("Best Random Forest Predictions:")
    print(rf_standalone_predictions.head(rows=5))
    return self.set_prediction_field_name(rf_standalone_predictions, prediction_field_name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment