Created
September 14, 2017 16:54
-
-
Save TomAugspurger/94ee62127bbc8e20223f97ebd7d29191 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"import dask.dataframe as dd\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"import seaborn as sns" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Script to download the data: `download.py`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# %load download.py\n", | |
"\"\"\"\n", | |
"Download taxi data from S3 to local\n", | |
"\"\"\"\n", | |
"from pathlib import Path\n", | |
"import sys\n", | |
"import argparse\n", | |
"import s3fs\n", | |
"from distributed import Client, wait\n", | |
"\n", | |
"\n", | |
"def parse_args(args=None):\n", | |
" # Parse CLI options; passing an explicit ``args`` list eases testing.\n", | |
" parser = argparse.ArgumentParser(description=__doc__)\n", | |
" parser.add_argument('-s', '--scheduler', default=None,\n", | |
" help='Scheduler address')\n", | |
" return parser.parse_args(args)\n", | |
"\n", | |
"\n", | |
"def fetch(key):\n", | |
" \"\"\"Download one S3 ``key`` into ./data/, returning the key.\"\"\"\n", | |
" # A fresh anonymous filesystem per call keeps the task self-contained\n", | |
" # when it is shipped to a dask worker via client.map below.\n", | |
" fs = s3fs.S3FileSystem(anon=True)\n", | |
" dest = Path('data').joinpath(Path(key).name)\n", | |
" dest.parent.mkdir(exist_ok=True)\n", | |
" fs.get(key, str(dest))\n", | |
" return key\n", | |
"\n", | |
"\n", | |
"def main(args=None):\n", | |
" \"\"\"Fetch all 12 months of 2009 yellow-taxi CSVs in parallel.\"\"\"\n", | |
" args = parse_args(args)\n", | |
" client = Client(args.scheduler)\n", | |
" keys = [\n", | |
" f'nyc-tlc/trip data/yellow_tripdata_2009-{m:0>2}.csv'\n", | |
" for m in range(1, 13)\n", | |
" ]\n", | |
" results = client.map(fetch, keys)\n", | |
" # Block until every download task has finished (or raised).\n", | |
" wait(results)\n", | |
" # NOTE(review): main() returns None, so sys.exit(main()) always exits 0,\n", | |
" # and the Client is never closed -- confirm both are intended.\n", | |
"\n", | |
"\n", | |
"if __name__ == '__main__':\n", | |
" sys.exit(main())\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"pd.options.display.max_rows = 10" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"-rw-r--r-- 1 taugspurger staff 2.4G Sep 9 06:14 data/yellow_tripdata_2009-01.csv\r\n", | |
"-rw-r--r-- 1 taugspurger staff 2.2G Sep 9 10:56 data/yellow_tripdata_2009-02.csv\r\n", | |
"-rw-r--r-- 1 taugspurger staff 2.4G Sep 9 11:03 data/yellow_tripdata_2009-03.csv\r\n", | |
"-rw-r--r-- 1 taugspurger staff 2.4G Sep 9 11:10 data/yellow_tripdata_2009-04.csv\r\n", | |
"-rw-r--r-- 1 taugspurger staff 2.5G Sep 9 11:17 data/yellow_tripdata_2009-05.csv\r\n", | |
"-rw-r--r-- 1 taugspurger staff 2.4G Sep 9 11:23 data/yellow_tripdata_2009-06.csv\r\n", | |
"-rw-r--r-- 1 taugspurger staff 2.3G Sep 9 11:30 data/yellow_tripdata_2009-07.csv\r\n", | |
"-rw-r--r-- 1 taugspurger staff 2.3G Sep 9 11:36 data/yellow_tripdata_2009-08.csv\r\n", | |
"-rw-r--r-- 1 taugspurger staff 2.4G Sep 9 11:44 data/yellow_tripdata_2009-09.csv\r\n", | |
"-rw-r--r-- 1 taugspurger staff 2.6G Sep 9 11:52 data/yellow_tripdata_2009-10.csv\r\n", | |
"-rw-r--r-- 1 taugspurger staff 2.4G Sep 9 11:59 data/yellow_tripdata_2009-11.csv\r\n", | |
"-rw-r--r-- 1 taugspurger staff 2.5G Sep 9 12:07 data/yellow_tripdata_2009-12.csv\r\n" | |
] | |
} | |
], | |
"source": [ | |
"ls -lh data/*.csv" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Load the first `DataFrame` into memory." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1min 6s, sys: 9.81 s, total: 1min 16s\n", | |
"Wall time: 1min 17s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"# Parse the two timestamp columns up front; read the low-cardinality string\n", | |
"# columns as 'category' dtype to reduce memory usage.\n", | |
"dtype = {\n", | |
" 'vendor_name': 'category',\n", | |
" 'Payment_Type': 'category',\n", | |
"}\n", | |
"\n", | |
"df = pd.read_csv(\"data/yellow_tripdata_2009-01.csv\", dtype=dtype,\n", | |
" parse_dates=['Trip_Pickup_DateTime', 'Trip_Dropoff_DateTime'],)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>vendor_name</th>\n", | |
" <th>Trip_Pickup_DateTime</th>\n", | |
" <th>Trip_Dropoff_DateTime</th>\n", | |
" <th>Passenger_Count</th>\n", | |
" <th>Trip_Distance</th>\n", | |
" <th>Start_Lon</th>\n", | |
" <th>Start_Lat</th>\n", | |
" <th>Rate_Code</th>\n", | |
" <th>store_and_forward</th>\n", | |
" <th>End_Lon</th>\n", | |
" <th>End_Lat</th>\n", | |
" <th>Payment_Type</th>\n", | |
" <th>Fare_Amt</th>\n", | |
" <th>surcharge</th>\n", | |
" <th>mta_tax</th>\n", | |
" <th>Tip_Amt</th>\n", | |
" <th>Tolls_Amt</th>\n", | |
" <th>Total_Amt</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>VTS</td>\n", | |
" <td>2009-01-04 02:52:00</td>\n", | |
" <td>2009-01-04 03:02:00</td>\n", | |
" <td>1</td>\n", | |
" <td>2.63</td>\n", | |
" <td>-73.991957</td>\n", | |
" <td>40.721567</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>-73.993803</td>\n", | |
" <td>40.695922</td>\n", | |
" <td>CASH</td>\n", | |
" <td>8.9</td>\n", | |
" <td>0.5</td>\n", | |
" <td>NaN</td>\n", | |
" <td>0.00</td>\n", | |
" <td>0.0</td>\n", | |
" <td>9.40</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>VTS</td>\n", | |
" <td>2009-01-04 03:31:00</td>\n", | |
" <td>2009-01-04 03:38:00</td>\n", | |
" <td>3</td>\n", | |
" <td>4.55</td>\n", | |
" <td>-73.982102</td>\n", | |
" <td>40.736290</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>-73.955850</td>\n", | |
" <td>40.768030</td>\n", | |
" <td>Credit</td>\n", | |
" <td>12.1</td>\n", | |
" <td>0.5</td>\n", | |
" <td>NaN</td>\n", | |
" <td>2.00</td>\n", | |
" <td>0.0</td>\n", | |
" <td>14.60</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>VTS</td>\n", | |
" <td>2009-01-03 15:43:00</td>\n", | |
" <td>2009-01-03 15:57:00</td>\n", | |
" <td>5</td>\n", | |
" <td>10.35</td>\n", | |
" <td>-74.002587</td>\n", | |
" <td>40.739748</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>-73.869983</td>\n", | |
" <td>40.770225</td>\n", | |
" <td>Credit</td>\n", | |
" <td>23.7</td>\n", | |
" <td>0.0</td>\n", | |
" <td>NaN</td>\n", | |
" <td>4.74</td>\n", | |
" <td>0.0</td>\n", | |
" <td>28.44</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>DDS</td>\n", | |
" <td>2009-01-01 20:52:58</td>\n", | |
" <td>2009-01-01 21:14:00</td>\n", | |
" <td>1</td>\n", | |
" <td>5.00</td>\n", | |
" <td>-73.974267</td>\n", | |
" <td>40.790955</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>-73.996558</td>\n", | |
" <td>40.731849</td>\n", | |
" <td>CREDIT</td>\n", | |
" <td>14.9</td>\n", | |
" <td>0.5</td>\n", | |
" <td>NaN</td>\n", | |
" <td>3.05</td>\n", | |
" <td>0.0</td>\n", | |
" <td>18.45</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>DDS</td>\n", | |
" <td>2009-01-24 16:18:23</td>\n", | |
" <td>2009-01-24 16:24:56</td>\n", | |
" <td>1</td>\n", | |
" <td>0.40</td>\n", | |
" <td>-74.001580</td>\n", | |
" <td>40.719382</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>-74.008378</td>\n", | |
" <td>40.720350</td>\n", | |
" <td>CASH</td>\n", | |
" <td>3.7</td>\n", | |
" <td>0.0</td>\n", | |
" <td>NaN</td>\n", | |
" <td>0.00</td>\n", | |
" <td>0.0</td>\n", | |
" <td>3.70</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" vendor_name Trip_Pickup_DateTime Trip_Dropoff_DateTime Passenger_Count \\\n", | |
"0 VTS 2009-01-04 02:52:00 2009-01-04 03:02:00 1 \n", | |
"1 VTS 2009-01-04 03:31:00 2009-01-04 03:38:00 3 \n", | |
"2 VTS 2009-01-03 15:43:00 2009-01-03 15:57:00 5 \n", | |
"3 DDS 2009-01-01 20:52:58 2009-01-01 21:14:00 1 \n", | |
"4 DDS 2009-01-24 16:18:23 2009-01-24 16:24:56 1 \n", | |
"\n", | |
" Trip_Distance Start_Lon Start_Lat Rate_Code store_and_forward \\\n", | |
"0 2.63 -73.991957 40.721567 NaN NaN \n", | |
"1 4.55 -73.982102 40.736290 NaN NaN \n", | |
"2 10.35 -74.002587 40.739748 NaN NaN \n", | |
"3 5.00 -73.974267 40.790955 NaN NaN \n", | |
"4 0.40 -74.001580 40.719382 NaN NaN \n", | |
"\n", | |
" End_Lon End_Lat Payment_Type Fare_Amt surcharge mta_tax Tip_Amt \\\n", | |
"0 -73.993803 40.695922 CASH 8.9 0.5 NaN 0.00 \n", | |
"1 -73.955850 40.768030 Credit 12.1 0.5 NaN 2.00 \n", | |
"2 -73.869983 40.770225 Credit 23.7 0.0 NaN 4.74 \n", | |
"3 -73.996558 40.731849 CREDIT 14.9 0.5 NaN 3.05 \n", | |
"4 -74.008378 40.720350 CASH 3.7 0.0 NaN 0.00 \n", | |
"\n", | |
" Tolls_Amt Total_Amt \n", | |
"0 0.0 9.40 \n", | |
"1 0.0 14.60 \n", | |
"2 0.0 28.44 \n", | |
"3 0.0 18.45 \n", | |
"4 0.0 3.70 " | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df.head()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's predict whether or not the person tips." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Features: everything except the tip amount itself.\n", | |
"X = df.drop(\"Tip_Amt\", axis=1)\n", | |
"# Binary target: True when the rider tipped at all.\n", | |
"y = df['Tip_Amt'] > 0" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The data fits in memory, so the standard scikit-learn train/test workflow applies unchanged." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.model_selection import train_test_split" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"X_train, X_test, y_train, y_test = train_test_split(X, y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"10569309" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(X_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"3523104" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(X_test)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"I notice that there are some minor differences in the capitalization of the \"Payment_Type\" values:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Index(['CASH', 'CREDIT', 'Cash', 'Credit', 'Dispute', 'No Charge'], dtype='object')" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df.Payment_Type.cat.categories" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We'll consolidate those by just lower-casing them:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0 cash\n", | |
"1 credit\n", | |
"2 credit\n", | |
"3 credit\n", | |
"4 cash\n", | |
" ... \n", | |
"14092408 cash\n", | |
"14092409 credit\n", | |
"14092410 cash\n", | |
"14092411 cash\n", | |
"14092412 credit\n", | |
"Name: Payment_Type, Length: 14092413, dtype: object" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df.Payment_Type.str.lower()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"And since we're good sci-kittens, we'll package all this up in a pipeline." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.base import BaseEstimator, TransformerMixin\n", | |
"from sklearn.pipeline import make_pipeline\n", | |
"from sklearn.preprocessing import FunctionTransformer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class ColumnSelector(TransformerMixin):\n", | |
" \"Select `columns` from `X`\"\n", | |
" def __init__(self, columns):\n", | |
" self.columns = columns\n", | |
"\n", | |
" def fit(self, X, y=None):\n", | |
" return self\n", | |
"\n", | |
" def transform(self, X, y=None):\n", | |
" return X[self.columns]\n", | |
" \n", | |
"\n", | |
"class HourExtractor(TransformerMixin):\n", | |
" \"Transform each datetime64 column in `columns` to integer hours\"\n", | |
" def __init__(self, columns):\n", | |
" self.columns = columns\n", | |
"\n", | |
" def fit(self, X, y=None):\n", | |
" return self\n", | |
"\n", | |
" def transform(self, X, y=None):\n", | |
" return X.assign(**{col: lambda x: x[col].dt.hour for col in self.columns})\n", | |
"\n", | |
"\n", | |
"def payment_lowerer(X):\n", | |
" \"\"\"Lowercase all the Payment_Type values\"\"\"\n", | |
" return X.assign(Payment_Type=X.Payment_Type.str.lower())\n", | |
"\n", | |
"\n", | |
"class CategoricalEncoder(TransformerMixin):\n", | |
" \"\"\"Convert to Categorical with specific `categories`\"\"\"\n", | |
" def __init__(self, categories):\n", | |
" self.categories = categories\n", | |
" \n", | |
" def fit(self, X, y=None):\n", | |
" return self\n", | |
" \n", | |
" def transform(self, X, y=None):\n", | |
" for col, categories in self.categories.items():\n", | |
" X[col] = X[col].astype('category').cat.set_categories(categories)\n", | |
" return X\n", | |
" \n", | |
"class StandardScaler(TransformerMixin):\n", | |
" \"Scale a subset of the columns in a DataFrame\"\n", | |
" def __init__(self, columns):\n", | |
" self.columns = columns\n", | |
" \n", | |
" def fit(self, X, y=None):\n", | |
" self.μs = X[self.columns].mean()\n", | |
" self.σs = X[self.columns].std()\n", | |
" return self\n", | |
"\n", | |
" def transform(self, X, y=None):\n", | |
" X = X.copy()\n", | |
" X[self.columns] = X[self.columns].sub(self.μs).div(self.σs)\n", | |
" return X" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Pipeline(memory=None,\n", | |
" steps=[('columnselector', <__main__.ColumnSelector object at 0x1561b2dd8>), ('hourextractor', <__main__.HourExtractor object at 0x1561b2278>), ('functiontransformer-1', FunctionTransformer(accept_sparse=False,\n", | |
" func=<function payment_lowerer at 0x1f3508e18>, inv_kw_args=None,\n", | |
" inve...ty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", | |
" verbose=0, warm_start=False))])" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# The columns at the start of the pipeline\n", | |
"columns = ['vendor_name', 'Trip_Pickup_DateTime',\n", | |
" 'Passenger_Count', 'Trip_Distance',\n", | |
" 'Payment_Type', 'Fare_Amt', 'surcharge']\n", | |
"\n", | |
"# The mapping of {column: set of categories}\n", | |
"categories = {\n", | |
" 'vendor_name': ['CMT', 'DDS', 'VTS'],\n", | |
" 'Payment_Type': ['cash', 'credit', 'dispute', 'no charge'],\n", | |
"}\n", | |
"\n", | |
"# Numeric columns to standardize to zero mean / unit variance.\n", | |
"scale = ['Trip_Distance', 'Fare_Amt', 'surcharge']\n", | |
"\n", | |
"# select columns -> datetimes to hours -> lowercase Payment_Type ->\n", | |
"# fixed categoricals -> one-hot encode -> scale numerics -> classifier.\n", | |
"# Fixing the category sets up front makes pd.get_dummies emit the same\n", | |
"# dummy columns even for data that happens to be missing some category.\n", | |
"pipe = make_pipeline(\n", | |
" ColumnSelector(columns),\n", | |
" HourExtractor(['Trip_Pickup_DateTime']),\n", | |
" FunctionTransformer(payment_lowerer, validate=False),\n", | |
" CategoricalEncoder(categories),\n", | |
" FunctionTransformer(pd.get_dummies, validate=False),\n", | |
" StandardScaler(scale),\n", | |
" LogisticRegression(),\n", | |
")\n", | |
"pipe" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[('columnselector', <__main__.ColumnSelector at 0x1561b2dd8>),\n", | |
" ('hourextractor', <__main__.HourExtractor at 0x1561b2278>),\n", | |
" ('functiontransformer-1', FunctionTransformer(accept_sparse=False,\n", | |
" func=<function payment_lowerer at 0x1f3508e18>, inv_kw_args=None,\n", | |
" inverse_func=None, kw_args=None, pass_y='deprecated',\n", | |
" validate=False)),\n", | |
" ('categoricalencoder', <__main__.CategoricalEncoder at 0x1561b2668>),\n", | |
" ('functiontransformer-2', FunctionTransformer(accept_sparse=False,\n", | |
" func=<function get_dummies at 0x111a19e18>, inv_kw_args=None,\n", | |
" inverse_func=None, kw_args=None, pass_y='deprecated',\n", | |
" validate=False)),\n", | |
" ('standardscaler', <__main__.StandardScaler at 0x1561b2198>),\n", | |
" ('logisticregression',\n", | |
" LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", | |
" intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n", | |
" penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", | |
" verbose=0, warm_start=False))]" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pipe.steps" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 59.6 s, sys: 5.81 s, total: 1min 5s\n", | |
"Wall time: 1min 7s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"Pipeline(memory=None,\n", | |
" steps=[('columnselector', <__main__.ColumnSelector object at 0x1561b2dd8>), ('hourextractor', <__main__.HourExtractor object at 0x1561b2278>), ('functiontransformer-1', FunctionTransformer(accept_sparse=False,\n", | |
" func=<function payment_lowerer at 0x1f3508e18>, inv_kw_args=None,\n", | |
" inve...ty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", | |
" verbose=0, warm_start=False))])" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%time pipe.fit(X_train, y_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.99314373342666018" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pipe.score(X_train, y_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.99316284730737436" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pipe.score(X_test, y_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def mkpipe():\n", | |
" \"\"\"Return a fresh, unfitted copy of the pipeline built earlier.\n", | |
"\n", | |
" Reads ``columns``, ``categories`` and ``scale`` from the notebook's\n", | |
" global scope, so those cells must have run first.\n", | |
" \"\"\"\n", | |
" pipe = make_pipeline(\n", | |
" ColumnSelector(columns),\n", | |
" HourExtractor(['Trip_Pickup_DateTime']),\n", | |
" FunctionTransformer(payment_lowerer, validate=False),\n", | |
" CategoricalEncoder(categories),\n", | |
" FunctionTransformer(pd.get_dummies, validate=False),\n", | |
" StandardScaler(scale),\n", | |
" LogisticRegression(),\n", | |
" )\n", | |
" # NOTE(review): defined but never called in the cells shown -- confirm\n", | |
" # it is still needed.\n", | |
" return pipe" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Scaling it Out" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import dask.dataframe as dd" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"%%time\n", | |
"# Lazy, parallel version of the earlier single-file load: dd.read_csv only\n", | |
"# builds a task graph over all twelve CSVs; nothing is read yet.\n", | |
"df = dd.read_csv(\"data/*.csv\", dtype=dtype,\n", | |
" parse_dates=['Trip_Pickup_DateTime', 'Trip_Dropoff_DateTime'],)\n", | |
"\n", | |
"# NOTE(review): df/X/y are rebound here from the in-memory pandas objects to\n", | |
"# lazy dask collections; earlier cells behave differently if re-run after this.\n", | |
"X = df.drop(\"Tip_Amt\", axis=1)\n", | |
"y = df['Tip_Amt'] > 0" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Since the scikit-learn world isn't really \"dask-aware\" at the moment, we'll use the `map_partitions` method. This is a good escape hatch for dealing with non-daskified code." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Score each partition with the pipeline already fitted in memory.\n", | |
"# predict_proba(...)[:, 1] is the probability of the positive (True) class;\n", | |
"# meta=('yhat', 'f8') declares the output shape (a float64 Series named\n", | |
"# 'yhat') so dask need not evaluate the lambda just to infer it.\n", | |
"yhat = X.map_partitions(lambda x: pd.Series(pipe.predict_proba(x)[:, 1], name='yhat'),\n", | |
" meta=('yhat', 'f8'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 17min 52s, sys: 2min 35s, total: 20min 27s\n", | |
"Wall time: 8min 49s\n" | |
] | |
} | |
], | |
"source": [ | |
"%time yhat.to_frame().to_parquet(\"data/predictions.parq\")" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment