Skip to content

Instantly share code, notes, and snippets.

@armgilles
Created August 31, 2016 14:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save armgilles/e2ce207283813d0f91f60d870df416be to your computer and use it in GitHub Desktop.
Save armgilles/e2ce207283813d0f91f60d870df416be to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Testing https://github.com/rushter/heamy\n",
"\n",
"import logging\n",
"import numpy as np\n",
"\n",
"from sklearn.cross_validation import train_test_split\n",
"from sklearn.datasets import load_boston\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.linear_model import LinearRegression, LogisticRegression\n",
"from sklearn.metrics import mean_absolute_error\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
"from xgboost import XGBClassifier\n",
"\n",
"from heamy.dataset import Dataset\n",
"from heamy.estimator import Regressor, Classifier\n",
"from heamy.pipeline import ModelsPipeline\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"%matplotlib inline\n",
"\n",
"np.set_printoptions(precision=6, suppress=True)  # 6-digit, non-scientific array printing\n",
"\n",
"np.random.seed(1000)  # seed NumPy's global RNG so re-runs are reproducible\n",
"logging.basicConfig(level=logging.DEBUG)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Load the Boston housing regression data (features X, target y).\n",
"# NOTE: load_boston was removed in scikit-learn 1.2; this cell needs an older release.\n",
"data = load_boston()\n",
"X, y = data['data'], data['target']\n",
"\n",
"VALIDATION_SIZE = 0.2  # fraction of ALL rows held out for the final validation score\n",
"TEST_SIZE = 0.25       # fraction of the remaining learning rows used as heamy's test set\n",
"SPLIT_SEED = 111       # same seed for both splits -> 60/20/20 train/test/validation\n",
"\n",
"# 1) Hold out a validation set that none of the models below are trained on.\n",
"X_learning, X_validation, y_learning, y_validation = train_test_split(\n",
"    X, y, test_size=VALIDATION_SIZE, random_state=SPLIT_SEED)\n",
"\n",
"# 2) Split the learning portion (not the full data) into train/test for heamy.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X_learning, y_learning, test_size=TEST_SIZE, random_state=SPLIT_SEED)\n",
"\n",
"# Dataset(train features, train target, test features)\n",
"dataset = Dataset(X_train, y_train, X_test)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Metric: mean_absolute_error\n",
"Folds accuracy: [2.7613431018613728, 2.532476431230593, 2.2803574011792391, 2.2500467215873949, 2.0561370636921401, 3.1650090591301701, 2.2334568908274335, 1.6800618137128551, 1.8328443836484318, 1.8852668227858682]\n",
"Mean accuracy: 2.26769996897\n",
"Standard Deviation: 0.4296545255\n",
"Variance: 0.184603011283\n"
]
}
],
"source": [
"# First stage: RandomForest and LinearRegression, both trained on `dataset`.\n",
"# NOTE: LinearRegression's `normalize` parameter was removed in scikit-learn 1.2.\n",
"model_rf = Regressor(dataset=dataset, estimator=RandomForestRegressor, parameters={'n_estimators': 50}, name='rf')\n",
"model_lr = Regressor(dataset=dataset, estimator=LinearRegression, parameters={'normalize': True}, name='lr')\n",
"\n",
"# Stack the two models: returns a new Dataset whose features are the\n",
"# first-stage out-of-fold predictions (10 folds).\n",
"pipeline = ModelsPipeline(model_rf, model_lr)\n",
"stack_ds = pipeline.stack(k=10, seed=111)\n",
"\n",
"# Second stage: LinearRegression on the stacked features.\n",
"stacker = Regressor(dataset=stack_ds, estimator=LinearRegression)\n",
"# Keep the stacker's predictions in their own name; previously they were\n",
"# overwritten by the validate() call below and lost.\n",
"predictions = stacker.predict()\n",
"\n",
"# Score the second-stage model with 10-fold cross-validation (MAE).\n",
"results = stacker.validate(k=10, scorer=mean_absolute_error)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
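"### Predict on the held-out validation set (`X_validation`)\n",
"\n",
"`X_validation` (20% of the rows) was split off before any model was trained, so its MAE estimates performance on unseen data."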
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# `predict` was never defined here. heamy estimators predict only on their\n",
"# Dataset's own test split, so rebuild the same two-stage stack (same\n",
"# parameters, same training rows) with X_validation as the test set.\n",
"validation_ds = Dataset(X_train, y_train, X_validation)\n",
"val_rf = Regressor(dataset=validation_ds, estimator=RandomForestRegressor, parameters={'n_estimators': 50}, name='rf')\n",
"val_lr = Regressor(dataset=validation_ds, estimator=LinearRegression, parameters={'normalize': True}, name='lr')\n",
"val_stack_ds = ModelsPipeline(val_rf, val_lr).stack(k=10, seed=111)\n",
"val_stacker = Regressor(dataset=val_stack_ds, estimator=LinearRegression)\n",
"y_validation_pred = val_stacker.predict()\n",
"\n",
"mean_absolute_error(y_validation, y_validation_pred)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment