@TomAugspurger
Created February 27, 2018 19:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import dask.dataframe as dd\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"pd.options.display.max_rows = 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Script to download the data: `download.py`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load download.py"
]
},
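{
"cell_type": "markdown",
"metadata": {},
"source": [
"The contents of `download.py` aren't rendered above. Below is a minimal, hypothetical sketch of what such a script might look like; the S3 URL pattern and file layout are assumptions for illustration, not taken from the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch of a download script -- not the original download.py.\n",
"# Assumes the 2009 yellow-taxi CSVs are available at this URL pattern.\n",
"from pathlib import Path\n",
"from urllib.request import urlretrieve\n",
"\n",
"BASE = 'https://s3.amazonaws.com/nyc-tlc/trip+data/yellow_tripdata_2009-{:02d}.csv'\n",
"\n",
"Path('data').mkdir(exist_ok=True)\n",
"for month in range(1, 13):\n",
"    url = BASE.format(month)\n",
"    out = Path('data') / url.rsplit('/', 1)[-1]\n",
"    if not out.exists():\n",
"        urlretrieve(url, str(out))"
]
},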
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the first `DataFrame` into memory."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1min 7s, sys: 8.77 s, total: 1min 16s\n",
"Wall time: 1min 17s\n"
]
}
],
"source": [
"%%time\n",
"dtype = {\n",
" 'vendor_name': 'category',\n",
" 'Payment_Type': 'category',\n",
"}\n",
"\n",
"df = pd.read_csv(\"data/yellow_tripdata_2009-01.csv\", dtype=dtype,\n",
" parse_dates=['Trip_Pickup_DateTime', 'Trip_Dropoff_DateTime'],)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>vendor_name</th>\n",
" <th>Trip_Pickup_DateTime</th>\n",
" <th>Trip_Dropoff_DateTime</th>\n",
" <th>Passenger_Count</th>\n",
" <th>Trip_Distance</th>\n",
" <th>Start_Lon</th>\n",
" <th>Start_Lat</th>\n",
" <th>Rate_Code</th>\n",
" <th>store_and_forward</th>\n",
" <th>End_Lon</th>\n",
" <th>End_Lat</th>\n",
" <th>Payment_Type</th>\n",
" <th>Fare_Amt</th>\n",
" <th>surcharge</th>\n",
" <th>mta_tax</th>\n",
" <th>Tip_Amt</th>\n",
" <th>Tolls_Amt</th>\n",
" <th>Total_Amt</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>VTS</td>\n",
" <td>2009-01-04 02:52:00</td>\n",
" <td>2009-01-04 03:02:00</td>\n",
" <td>1</td>\n",
" <td>2.63</td>\n",
" <td>-73.991957</td>\n",
" <td>40.721567</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>-73.993803</td>\n",
" <td>40.695922</td>\n",
" <td>CASH</td>\n",
" <td>8.9</td>\n",
" <td>0.5</td>\n",
" <td>NaN</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>9.40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>VTS</td>\n",
" <td>2009-01-04 03:31:00</td>\n",
" <td>2009-01-04 03:38:00</td>\n",
" <td>3</td>\n",
" <td>4.55</td>\n",
" <td>-73.982102</td>\n",
" <td>40.736290</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>-73.955850</td>\n",
" <td>40.768030</td>\n",
" <td>Credit</td>\n",
" <td>12.1</td>\n",
" <td>0.5</td>\n",
" <td>NaN</td>\n",
" <td>2.00</td>\n",
" <td>0.0</td>\n",
" <td>14.60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>VTS</td>\n",
" <td>2009-01-03 15:43:00</td>\n",
" <td>2009-01-03 15:57:00</td>\n",
" <td>5</td>\n",
" <td>10.35</td>\n",
" <td>-74.002587</td>\n",
" <td>40.739748</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>-73.869983</td>\n",
" <td>40.770225</td>\n",
" <td>Credit</td>\n",
" <td>23.7</td>\n",
" <td>0.0</td>\n",
" <td>NaN</td>\n",
" <td>4.74</td>\n",
" <td>0.0</td>\n",
" <td>28.44</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>DDS</td>\n",
" <td>2009-01-01 20:52:58</td>\n",
" <td>2009-01-01 21:14:00</td>\n",
" <td>1</td>\n",
" <td>5.00</td>\n",
" <td>-73.974267</td>\n",
" <td>40.790955</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>-73.996558</td>\n",
" <td>40.731849</td>\n",
" <td>CREDIT</td>\n",
" <td>14.9</td>\n",
" <td>0.5</td>\n",
" <td>NaN</td>\n",
" <td>3.05</td>\n",
" <td>0.0</td>\n",
" <td>18.45</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>DDS</td>\n",
" <td>2009-01-24 16:18:23</td>\n",
" <td>2009-01-24 16:24:56</td>\n",
" <td>1</td>\n",
" <td>0.40</td>\n",
" <td>-74.001580</td>\n",
" <td>40.719382</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>-74.008378</td>\n",
" <td>40.720350</td>\n",
" <td>CASH</td>\n",
" <td>3.7</td>\n",
" <td>0.0</td>\n",
" <td>NaN</td>\n",
" <td>0.00</td>\n",
" <td>0.0</td>\n",
" <td>3.70</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" vendor_name Trip_Pickup_DateTime Trip_Dropoff_DateTime Passenger_Count \\\n",
"0 VTS 2009-01-04 02:52:00 2009-01-04 03:02:00 1 \n",
"1 VTS 2009-01-04 03:31:00 2009-01-04 03:38:00 3 \n",
"2 VTS 2009-01-03 15:43:00 2009-01-03 15:57:00 5 \n",
"3 DDS 2009-01-01 20:52:58 2009-01-01 21:14:00 1 \n",
"4 DDS 2009-01-24 16:18:23 2009-01-24 16:24:56 1 \n",
"\n",
" Trip_Distance Start_Lon Start_Lat Rate_Code store_and_forward \\\n",
"0 2.63 -73.991957 40.721567 NaN NaN \n",
"1 4.55 -73.982102 40.736290 NaN NaN \n",
"2 10.35 -74.002587 40.739748 NaN NaN \n",
"3 5.00 -73.974267 40.790955 NaN NaN \n",
"4 0.40 -74.001580 40.719382 NaN NaN \n",
"\n",
" End_Lon End_Lat Payment_Type Fare_Amt surcharge mta_tax Tip_Amt \\\n",
"0 -73.993803 40.695922 CASH 8.9 0.5 NaN 0.00 \n",
"1 -73.955850 40.768030 Credit 12.1 0.5 NaN 2.00 \n",
"2 -73.869983 40.770225 Credit 23.7 0.0 NaN 4.74 \n",
"3 -73.996558 40.731849 CREDIT 14.9 0.5 NaN 3.05 \n",
"4 -74.008378 40.720350 CASH 3.7 0.0 NaN 0.00 \n",
"\n",
" Tolls_Amt Total_Amt \n",
"0 0.0 9.40 \n",
"1 0.0 14.60 \n",
"2 0.0 28.44 \n",
"3 0.0 18.45 \n",
"4 0.0 3.70 "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
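{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough check (not in the original notebook), here's the in-memory footprint of this single month, which is why the remaining months are handled with dask further down."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Approximate memory footprint of one month of data, in gigabytes.\n",
"df.memory_usage(deep=True).sum() / 1e9"
]
},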
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's predict whether or not the person tips."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"X = df.drop(\"Tip_Amt\", axis=1)\n",
"y = df['Tip_Amt'] > 0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We're in-memory, so all this is normal."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10569309"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(X_train)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3523104"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I notice that there are some minor differences in the spelling on \"Payment Type\":"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['CASH', 'CREDIT', 'Cash', 'Credit', 'Dispute', 'No Charge'], dtype='object')"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.Payment_Type.cat.categories"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll consolidate those by just lower-casing them:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 cash\n",
"1 credit\n",
"2 credit\n",
"3 credit\n",
"4 cash\n",
" ... \n",
"14092408 cash\n",
"14092409 credit\n",
"14092410 cash\n",
"14092411 cash\n",
"14092412 credit\n",
"Name: Payment_Type, Length: 14092413, dtype: object"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.Payment_Type.str.lower()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And since we're good sci-kittens, we'll package all this up in a pipeline."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import FunctionTransformer\n",
"\n",
"from pandas.api.types import CategoricalDtype\n",
"\n",
"from dask_ml.linear_model import LogisticRegression\n",
"from dask_ml.preprocessing import Categorizer, StandardScaler, DummyEncoder"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class ColumnSelector(TransformerMixin, BaseEstimator):\n",
" \"Select `columns` from `X`\"\n",
" def __init__(self, columns=None):\n",
" self.columns = columns\n",
"\n",
" def fit(self, X, y=None):\n",
" return self\n",
"\n",
" def transform(self, X, y=None):\n",
" if self.columns:\n",
" return X[self.columns]\n",
" else:\n",
" return X\n",
" \n",
"\n",
"class HourExtractor(TransformerMixin, BaseEstimator):\n",
" \"Transform each datetime64 column in `columns` to integer hours\"\n",
" def __init__(self, columns):\n",
" self.columns = columns\n",
"\n",
" def fit(self, X, y=None):\n",
" return self\n",
"\n",
" def transform(self, X, y=None):\n",
" return X.assign(**{col: lambda x: x[col].dt.hour for col in self.columns})\n",
"\n",
"\n",
"def payment_lowerer(X):\n",
" \"\"\"Lowercase all the Payment_Type values\"\"\"\n",
" return X.assign(Payment_Type=X.Payment_Type.str.lower())"
]
},
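{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of `HourExtractor` on a tiny hand-made frame (illustrative only, not part of the original run): the datetime column should be replaced by its integer hour."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Two pickup times; after the transform the column should hold 2 and 20.\n",
"sample = pd.DataFrame({\n",
"    'Trip_Pickup_DateTime': pd.to_datetime(['2009-01-04 02:52:00', '2009-01-01 20:52:58'])\n",
"})\n",
"HourExtractor(['Trip_Pickup_DateTime']).fit_transform(sample)"
]
},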
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(memory=None,\n",
" steps=[('columnselector', ColumnSelector(columns=['vendor_name', 'Trip_Pickup_DateTime', 'Passenger_Count', 'Trip_Distance', 'Payment_Type', 'Fare_Amt', 'surcharge'])), ('hourextractor', HourExtractor(columns=['Trip_Pickup_DateTime'])), ('functiontransformer-1', FunctionTransformer(accept_sparse=Fal..._state=None, solver='admm',\n",
" solver_kwargs=None, tol=0.0001, verbose=0, warm_start=False))])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The columns at the start of the pipeline\n",
"columns = ['vendor_name', 'Trip_Pickup_DateTime',\n",
" 'Passenger_Count', 'Trip_Distance',\n",
" 'Payment_Type', 'Fare_Amt', 'surcharge']\n",
"\n",
"# The mapping of {column: set of categories}\n",
"categories = {\n",
" 'vendor_name': CategoricalDtype(['CMT', 'DDS', 'VTS']),\n",
" 'Payment_Type': CategoricalDtype(['cash', 'credit', 'dispute', 'no charge']),\n",
"}\n",
"\n",
"scale = ['Trip_Distance', 'Fare_Amt', 'surcharge']\n",
"\n",
"pipe = make_pipeline(\n",
" ColumnSelector(columns=columns),\n",
" HourExtractor(['Trip_Pickup_DateTime']),\n",
" FunctionTransformer(payment_lowerer, validate=False),\n",
" Categorizer(categories=categories),\n",
" DummyEncoder(),\n",
" StandardScaler(scale),\n",
" FunctionTransformer(lambda x: x.values, validate=False),\n",
" LogisticRegression(),\n",
")\n",
"pipe"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 12min 7s, sys: 1min 11s, total: 13min 18s\n",
"Wall time: 8min 1s\n"
]
},
{
"data": {
"text/plain": [
"Pipeline(memory=None,\n",
" steps=[('columnselector', ColumnSelector(columns=['vendor_name', 'Trip_Pickup_DateTime', 'Passenger_Count', 'Trip_Distance', 'Payment_Type', 'Fare_Amt', 'surcharge'])), ('hourextractor', HourExtractor(columns=['Trip_Pickup_DateTime'])), ('functiontransformer-1', FunctionTransformer(accept_sparse=Fal..._state=None, solver='admm',\n",
" solver_kwargs=None, tol=0.0001, verbose=0, warm_start=False))])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time pipe.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.99314373342666018"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pipe.score(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.99316284730737436"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pipe.score(X_test, y_test)"
]
},
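{
"cell_type": "markdown",
"metadata": {},
"source": [
"Those accuracy numbers only mean something relative to the class balance, so it's worth checking the base rate of tipping (this check isn't in the original notebook). In this era of the data, cash tips were generally not recorded, so `Payment_Type` alone is highly predictive."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fraction of trips with a recorded tip, as a baseline for the scores above.\n",
"y_train.mean(), y_test.mean()"
]
},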
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Scaling it Out"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import dask.dataframe as dd"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%%time\n",
"df = dd.read_csv(\"data/*.csv\", dtype=dtype,\n",
" parse_dates=['Trip_Pickup_DateTime', 'Trip_Dropoff_DateTime'],)\n",
"\n",
"X = df.drop(\"Tip_Amt\", axis=1)\n",
"y = df['Tip_Amt'] > 0"
]
},
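{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unlike the pandas version, `dd.read_csv` is mostly lazy: the cell above just builds the task graph (dask reads only a small sample to infer metadata), and the full CSVs aren't processed until a computation is triggered. A quick look at the partition count (not part of the original run):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Number of partitions dask will process; each is roughly one chunk of the CSVs.\n",
"df.npartitions"
]
},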
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the scikit-learn world isn't really \"dask-aware\" at the moment, we'll use the `map_partitions` method. This is a good escape hatch for dealing with non-daskified code."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"yhat = X.map_partitions(lambda x: pd.Series(pipe.predict_proba(x)[:, 1], name='yhat'),\n",
" meta=('yhat', 'f8'))"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 17min 52s, sys: 2min 35s, total: 20min 27s\n",
"Wall time: 8min 49s\n"
]
}
],
"source": [
"%time yhat.to_frame().to_parquet(\"data/predictions.parq\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}