@MaxHalford
Created June 7, 2020 14:50
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Predicting taxi trip durations with creme and chantilly"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example we'll build a model to predict the duration of taxi trips in New York City (dataset [here](https://www.kaggle.com/c/nyc-taxi-trip-duration))."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's first install the necessary dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install creme chantilly dill"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's now take a look at the data."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Taxis dataset\n",
"\n",
" Task Regression \n",
" Number of samples 1,458,644 \n",
"Number of features 8 \n",
" Sparse False \n",
" Path /Users/mhalford/creme_data/Taxis/train.csv \n",
" URL https://maxhalford.github.io/files/datasets/nyc_taxis.zip\n",
" Size 186.23 MB \n",
" Downloaded True "
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from creme import datasets\n",
"\n",
"trips = datasets.Taxis()\n",
"trips"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'vendor_id': '2',\n",
" 'pickup_datetime': datetime.datetime(2016, 1, 1, 0, 0, 17),\n",
" 'passenger_count': 5,\n",
" 'pickup_longitude': -73.98174285888672,\n",
" 'pickup_latitude': 40.71915817260742,\n",
" 'dropoff_longitude': -73.93882751464845,\n",
" 'dropoff_latitude': 40.82918167114258,\n",
" 'store_and_fwd_flag': 'N'}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x, y = next(iter(trips))\n",
"x"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"849"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It seems reasonable to use the distance in order to predict the duration.\n",
"\n",
"With `creme`, we're working with dictionaries. Therefore, a simple way to go about extracting features is to write a function."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def distances(trip):\n",
" lat_dist = trip['dropoff_latitude'] - trip['pickup_latitude']\n",
" lon_dist = trip['dropoff_longitude'] - trip['pickup_longitude']\n",
" return {\n",
" 'manhattan_distance': abs(lat_dist) + abs(lon_dist),\n",
" 'euclidean_distance': math.sqrt(lat_dist ** 2 + lon_dist ** 2)\n",
" }"
]
},
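{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, these distances are computed directly on latitude/longitude degrees, which ignores the Earth's curvature. If need be, a haversine distance (in kilometers) could be added as an extra feature. The following function is just a sketch and is not part of the original pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def haversine(trip):\n",
"    # Great-circle distance in kilometers between pickup and dropoff\n",
"    lat1 = math.radians(trip['pickup_latitude'])\n",
"    lon1 = math.radians(trip['pickup_longitude'])\n",
"    lat2 = math.radians(trip['dropoff_latitude'])\n",
"    lon2 = math.radians(trip['dropoff_longitude'])\n",
"    a = (\n",
"        math.sin((lat2 - lat1) / 2) ** 2\n",
"        + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2\n",
"    )\n",
"    return {'haversine_distance': 2 * 6371 * math.asin(math.sqrt(a))}\n",
"\n",
"haversine(trip=x)"
]
},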
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can verify that this function works on the first sample."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'manhattan_distance': 0.1529388427734233,\n",
" 'euclidean_distance': 0.11809698133739274}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"distances(trip=x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Additionally, it should be worthwhile to extract temporal information."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'hour': 0, 'day': 'Friday'}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import calendar\n",
"\n",
"def datetime_info(trip):\n",
" day_no = trip['pickup_datetime'].weekday()\n",
" return {\n",
" 'hour': trip['pickup_datetime'].hour,\n",
" 'day': calendar.day_name[day_no]\n",
" }\n",
"\n",
"datetime_info(trip=x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now assemble these steps into a `TransformerUnion`."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TransformerUnion (\n",
" FuncTransformer (\n",
" func=\"distances\"\n",
" ),\n",
" FuncTransformer (\n",
" func=\"datetime_info\"\n",
" )\n",
")"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from creme import compose\n",
"\n",
"extract_features = compose.TransformerUnion(distances, datetime_info)\n",
"extract_features"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`TransformerUnion` is a `Transformer`, which means that it has a `transform_one` method."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'hour': 0,\n",
" 'day': 'Friday',\n",
" 'manhattan_distance': 0.1529388427734233,\n",
" 'euclidean_distance': 0.11809698133739274}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"features = extract_features.transform_one(x)\n",
"features"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also call `fit_one`, but in this case it is unnecessary because our feature extractors are stateless.\n",
"\n",
"We would now like to train a linear regression. The problem is that the `day` feature is categorical, whilst a linear regression only accepts numeric data. A simple way to circumvent this issue is to use one-hot encoding, which involves replacing the `day` feature with a binary feature per day of the week."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'hour': 0.0,\n",
" 'manhattan_distance': 0.1529388427734233,\n",
" 'euclidean_distance': 0.11809698133739274,\n",
" 'day_Friday': 1}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numbers \n",
"from creme import preprocessing\n",
"\n",
"cat = compose.SelectType(str) | preprocessing.OneHotEncoder()\n",
"num = compose.SelectType(numbers.Number) | preprocessing.StandardScaler()\n",
"\n",
"preprocess = compose.TransformerUnion(cat, num)\n",
"preprocess.transform_one(features)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now assemble these steps into a pipeline."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'hour': 0.0,\n",
" 'manhattan_distance': 0.1529388427734233,\n",
" 'euclidean_distance': 0.11809698133739274,\n",
" 'day_Friday': 1}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pipeline = compose.Pipeline(\n",
" extract_features,\n",
" preprocess\n",
")\n",
"\n",
"pipeline.transform_one(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We're now ready to append a linear regression to our pipeline."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from creme import linear_model\n",
"\n",
"pipeline |= linear_model.LinearRegression()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at what our pipeline looks like."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.42.3 (20191010.1750)\n",
" -->\n",
"<!-- Title: %3 Pages: 1 -->\n",
"<svg width=\"288pt\" height=\"404pt\"\n",
" viewBox=\"0.00 0.00 288.21 404.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 400)\">\n",
"<title>%3</title>\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-400 284.21,-400 284.21,4 -4,4\"/>\n",
"<!-- x -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>x</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"141.57\" cy=\"-378\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"141.57\" y=\"-373.8\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n",
"</g>\n",
"<!-- distances -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>distances</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"81.57\" cy=\"-306\" rx=\"42.55\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"81.57\" y=\"-301.8\" font-family=\"Times,serif\" font-size=\"14.00\">distances</text>\n",
"</g>\n",
"<!-- x&#45;&gt;distances -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>x&#45;&gt;distances</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M128.55,-361.81C120.83,-352.8 110.86,-341.18 102.1,-330.95\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"104.56,-328.45 95.4,-323.13 99.25,-333 104.56,-328.45\"/>\n",
"</g>\n",
"<!-- datetime_info -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>datetime_info</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"201.57\" cy=\"-306\" rx=\"59.44\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"201.57\" y=\"-301.8\" font-family=\"Times,serif\" font-size=\"14.00\">datetime_info</text>\n",
"</g>\n",
"<!-- x&#45;&gt;datetime_info -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>x&#45;&gt;datetime_info</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M154.59,-361.81C162.2,-352.93 171.99,-341.5 180.66,-331.39\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"183.46,-333.5 187.32,-323.63 178.15,-328.94 183.46,-333.5\"/>\n",
"</g>\n",
"<!-- Select(str) -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>Select(str)</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"76.57\" cy=\"-234\" rx=\"46.4\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"76.57\" y=\"-229.8\" font-family=\"Times,serif\" font-size=\"14.00\">Select(str)</text>\n",
"</g>\n",
"<!-- distances&#45;&gt;Select(str) -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>distances&#45;&gt;Select(str)</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M80.33,-287.7C79.78,-279.98 79.12,-270.71 78.5,-262.11\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"81.99,-261.83 77.79,-252.1 75.01,-262.33 81.99,-261.83\"/>\n",
"</g>\n",
"<!-- Select(Number) -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>Select(Number)</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"207.57\" cy=\"-234\" rx=\"66.67\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"207.57\" y=\"-229.8\" font-family=\"Times,serif\" font-size=\"14.00\">Select(Number)</text>\n",
"</g>\n",
"<!-- distances&#45;&gt;Select(Number) -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>distances&#45;&gt;Select(Number)</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M106.46,-291.17C124.83,-280.97 150.12,-266.92 170.81,-255.42\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"172.61,-258.42 179.66,-250.51 169.21,-252.3 172.61,-258.42\"/>\n",
"</g>\n",
"<!-- datetime_info&#45;&gt;Select(str) -->\n",
"<g id=\"edge10\" class=\"edge\">\n",
"<title>datetime_info&#45;&gt;Select(str)</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M174.44,-289.81C155.76,-279.35 130.8,-265.37 110.76,-254.15\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"112.37,-251.04 101.93,-249.21 108.95,-257.14 112.37,-251.04\"/>\n",
"</g>\n",
"<!-- datetime_info&#45;&gt;Select(Number) -->\n",
"<g id=\"edge11\" class=\"edge\">\n",
"<title>datetime_info&#45;&gt;Select(Number)</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M203.05,-287.7C203.71,-279.98 204.51,-270.71 205.24,-262.11\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"208.73,-262.37 206.1,-252.1 201.76,-261.77 208.73,-262.37\"/>\n",
"</g>\n",
"<!-- OneHotEncoder -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>OneHotEncoder</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"67.57\" cy=\"-162\" rx=\"67.64\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"67.57\" y=\"-157.8\" font-family=\"Times,serif\" font-size=\"14.00\">OneHotEncoder</text>\n",
"</g>\n",
"<!-- Select(str)&#45;&gt;OneHotEncoder -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>Select(str)&#45;&gt;OneHotEncoder</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M74.34,-215.7C73.35,-207.98 72.16,-198.71 71.05,-190.11\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"74.51,-189.58 69.77,-180.1 67.57,-190.47 74.51,-189.58\"/>\n",
"</g>\n",
"<!-- LinearRegression -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>LinearRegression</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"141.57\" cy=\"-90\" rx=\"72.46\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"141.57\" y=\"-85.8\" font-family=\"Times,serif\" font-size=\"14.00\">LinearRegression</text>\n",
"</g>\n",
"<!-- OneHotEncoder&#45;&gt;LinearRegression -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>OneHotEncoder&#45;&gt;LinearRegression</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M85.1,-144.41C94.44,-135.58 106.08,-124.57 116.36,-114.84\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"119.04,-117.13 123.9,-107.71 114.23,-112.04 119.04,-117.13\"/>\n",
"</g>\n",
"<!-- StandardScaler -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>StandardScaler</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"216.57\" cy=\"-162\" rx=\"63.78\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"216.57\" y=\"-157.8\" font-family=\"Times,serif\" font-size=\"14.00\">StandardScaler</text>\n",
"</g>\n",
"<!-- Select(Number)&#45;&gt;StandardScaler -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>Select(Number)&#45;&gt;StandardScaler</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M209.79,-215.7C210.78,-207.98 211.98,-198.71 213.08,-190.11\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"216.56,-190.47 214.37,-180.1 209.62,-189.58 216.56,-190.47\"/>\n",
"</g>\n",
"<!-- StandardScaler&#45;&gt;LinearRegression -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>StandardScaler&#45;&gt;LinearRegression</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M198.79,-144.41C189.17,-135.43 177.13,-124.19 166.58,-114.34\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"168.91,-111.73 159.21,-107.47 164.14,-116.85 168.91,-111.73\"/>\n",
"</g>\n",
"<!-- y -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>y</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"141.57\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"141.57\" y=\"-13.8\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n",
"</g>\n",
"<!-- LinearRegression&#45;&gt;y -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>LinearRegression&#45;&gt;y</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M141.57,-71.7C141.57,-63.98 141.57,-54.71 141.57,-46.11\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"145.07,-46.1 141.57,-36.1 138.07,-46.1 145.07,-46.1\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.dot.Digraph at 0x1a30d1b910>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pipeline.draw()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us now use progressive validation to evaluate the performance of our model. This will loop through the data and make a prediction for each sample before learning from it. This is the canonical way of evaluating online machine learning models."
]
},
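{
"cell_type": "markdown",
"metadata": {},
"source": [
"To give an idea, progressive validation boils down to a predict-then-fit loop. Here is a rough sketch of what `progressive_val_score` does under the hood (the real implementation also takes care of printing, delays, etc.):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def progressive_val_sketch(model, stream, metric):\n",
"    for x, y in stream:\n",
"        y_pred = model.predict_one(x)  # predict before seeing the label\n",
"        metric.update(y_true=y, y_pred=y_pred)\n",
"        model.fit_one(x, y)  # then learn from the sample\n",
"    return metric"
]
},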
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[10,000] MAE: 534.984054\n",
"[20,000] MAE: 537.327384\n",
"[30,000] MAE: 865.921832\n",
"[40,000] MAE: 759.319743\n",
"[50,000] MAE: 903.466296\n"
]
},
{
"data": {
"text/plain": [
"MAE: 903.466296"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from creme import metrics\n",
"from creme import model_selection\n",
"\n",
"model_selection.progressive_val_score(\n",
" X_y=trips.take(50_000),\n",
" model=pipeline,\n",
" metric=metrics.MAE(),\n",
" print_every=10_000\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we might want to look at tuning some hyperparameters. This is quite different from batch learning, because here we want to find the best parameters on the fly. To start off, we can enumerate a list of hyperparameter combinations we want to try out. The `expand_param_grid` function is really practical for doing so."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from creme import optim\n",
"\n",
"param_grid = model_selection.expand_param_grid({\n",
" 'LinearRegression': {\n",
" 'optimizer': [\n",
" (optim.SGD, {'lr': [.1, .01, .005]}),\n",
" (optim.Adam, {'beta_1': [.01, .001], 'lr': [.1, .01, .001]}),\n",
" (optim.Adam, {'beta_1': [.1], 'lr': [.001]}),\n",
" ]\n",
" }\n",
"})\n",
"\n",
"models = [\n",
" pipeline._set_params(params)\n",
" for params in param_grid\n",
"]\n",
"\n",
"len(models)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As of writing this document, the only available model selection tool is [successive halving](https://arxiv.org/pdf/1502.07943.pdf). In our case we're doing regression, so we'll use `SuccessiveHalvingRegressor`. You can treat it like any other model, as it implements `fit_one` and `predict_one`."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[10,000] MAE: 550.542469\n",
"[20,000] MAE: 556.9258\n",
"[30,000] MAE: 901.923797\n",
"[40,000] MAE: 799.411683\n",
"[50,000] MAE: 873.054589\n"
]
},
{
"data": {
"text/plain": [
"MAE: 873.054589"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sh = model_selection.SuccessiveHalvingRegressor(\n",
" models=models,\n",
" metric=metrics.MAE(),\n",
" budget=10000\n",
")\n",
"\n",
"model_selection.progressive_val_score(\n",
" X_y=trips.take(50_000),\n",
" model=sh,\n",
" metric=metrics.MAE(),\n",
" print_every=10_000\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline (\n",
" TransformerUnion (\n",
" FuncTransformer (\n",
" func=\"distances\"\n",
" ),\n",
" FuncTransformer (\n",
" func=\"datetime_info\"\n",
" )\n",
" ),\n",
" TransformerUnion (\n",
" Pipeline (\n",
" Select (\n",
" <class 'str'>\n",
" ),\n",
" OneHotEncoder (\n",
" sparse=False\n",
" )\n",
" ),\n",
" Pipeline (\n",
" Select (\n",
" <class 'numbers.Number'>\n",
" ),\n",
" StandardScaler (\n",
" with_mean=True\n",
" with_std=True\n",
" )\n",
" )\n",
" ),\n",
" LinearRegression (\n",
" optimizer=SGD (\n",
" lr=Constant (\n",
" learning_rate=0.005\n",
" )\n",
" )\n",
" loss=Squared ()\n",
" l2=0.\n",
" intercept=750.856353\n",
" intercept_lr=Constant (\n",
" learning_rate=0.01\n",
" )\n",
" clip_gradient=1e+12\n",
" initializer=Zeros ()\n",
" )\n",
")"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sh.best_model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now how about deploying our model? Well the `creme` team has developed a little tool called [`chantilly`](https://github.com/creme-ml/chantilly) to simplify the process. It is essentially a [Flask](https://flask.palletsprojects.com/en/1.1.x/) app, and so is very simple to install. The source code is also very easy to delve into."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/bin/sh: chantilly: command not found\n"
]
}
],
"source": [
"!chantilly run # run this in a terminal session"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We first have to tell Chantilly what \"flavor\" we want to use. In this case we're doing regression so we'll use the \"regression\" flavor."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Response [201]>"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import requests\n",
"\n",
"host = 'http://localhost:5000'\n",
"\n",
"requests.post(host + '/api/init', json={'flavor': 'regression'})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us now upload the model. We need to make a couple of changes first:\n",
"\n",
"- At the moment, importing external dependencies inside user-defined functions has to be done within each function. This might change in a future release.\n",
"- The `creme` dataset already takes care of parsing the datetimes. However, `chantilly` will assume that JSON data is provided, so the datetime parsing has to be handled within the function itself."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def distances(trip):\n",
" import math\n",
" \n",
" lat_dist = trip['dropoff_latitude'] - trip['pickup_latitude']\n",
" lon_dist = trip['dropoff_longitude'] - trip['pickup_longitude']\n",
" \n",
" return {\n",
" 'manhattan_distance': abs(lat_dist) + abs(lon_dist),\n",
" 'euclidean_distance': math.sqrt(lat_dist ** 2 + lon_dist ** 2)\n",
" }\n",
"\n",
"def datetime_info(trip):\n",
" import calendar\n",
" import datetime as dt\n",
" \n",
" day = dt.datetime.fromisoformat(trip['pickup_datetime'])\n",
" \n",
" return {\n",
" 'hour': day.hour,\n",
" 'day': calendar.day_name[day.weekday()]\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model can now be uploaded."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Response [201]>"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import dill\n",
"\n",
"extract_features = compose.TransformerUnion(distances, datetime_info)\n",
"\n",
"pipeline = compose.Pipeline(\n",
" extract_features,\n",
" preprocess,\n",
" linear_model.LinearRegression()\n",
")\n",
"\n",
"requests.post(host + '/api/model', data=dill.dumps(pipeline))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make things realistic, we'll run a simulation where the taxis leave and arrive in the order given in the dataset. Indeed, we can reproduce a live workload from a historical dataset, thereby producing an environment which is very close to what happens in a production setting."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"#0000000 departs at 2016-01-01 00:00:17\n",
"#0000001 departs at 2016-01-01 00:00:53\n",
"#0000002 departs at 2016-01-01 00:01:01\n",
"#0000003 departs at 2016-01-01 00:01:14\n",
"#0000004 departs at 2016-01-01 00:01:20\n",
"#0000005 departs at 2016-01-01 00:01:33\n",
"#0000006 departs at 2016-01-01 00:01:37\n",
"#0000007 departs at 2016-01-01 00:01:47\n",
"#0000008 departs at 2016-01-01 00:02:06\n",
"#0000009 departs at 2016-01-01 00:02:45\n",
"#0000010 departs at 2016-01-01 00:03:02\n",
"#0000006 arrives at 2016-01-01 00:03:31 - average error: 0:01:54\n",
"#0000011 departs at 2016-01-01 00:03:31\n",
"#0000012 departs at 2016-01-01 00:03:35\n",
"#0000013 departs at 2016-01-01 00:04:42\n",
"#0000014 departs at 2016-01-01 00:04:57\n",
"#0000015 departs at 2016-01-01 00:05:07\n",
"#0000016 departs at 2016-01-01 00:05:08\n",
"#0000017 departs at 2016-01-01 00:05:18\n",
"#0000018 departs at 2016-01-01 00:05:35\n",
"#0000019 departs at 2016-01-01 00:05:39\n",
"#0000003 arrives at 2016-01-01 00:05:54 - average error: 0:03:17\n",
"#0000020 departs at 2016-01-01 00:06:04\n",
"#0000021 departs at 2016-01-01 00:06:12\n",
"#0000022 departs at 2016-01-01 00:06:22\n",
"#0000023 departs at 2016-01-01 00:06:24\n",
"#0000024 departs at 2016-01-01 00:06:47\n",
"#0000025 departs at 2016-01-01 00:06:56\n",
"#0000026 departs at 2016-01-01 00:06:59\n",
"#0000027 departs at 2016-01-01 00:07:04\n",
"#0000028 departs at 2016-01-01 00:07:06\n",
"#0000029 departs at 2016-01-01 00:07:07\n",
"#0000021 arrives at 2016-01-01 00:07:13 - average error: 0:02:16.660295\n",
"#0000030 departs at 2016-01-01 00:07:22\n",
"#0000010 arrives at 2016-01-01 00:07:25 - average error: 0:02:48.245221\n",
"#0000031 departs at 2016-01-01 00:07:27\n",
"#0000032 departs at 2016-01-01 00:07:29\n",
"#0000033 departs at 2016-01-01 00:07:34\n",
"#0000034 departs at 2016-01-01 00:07:46\n",
"#0000035 departs at 2016-01-01 00:07:47\n",
"#0000002 arrives at 2016-01-01 00:07:49 - average error: 0:03:36.196177\n",
"#0000036 departs at 2016-01-01 00:07:52\n",
"#0000037 departs at 2016-01-01 00:08:07\n",
"#0000038 departs at 2016-01-01 00:08:09\n",
"#0000039 departs at 2016-01-01 00:08:11\n",
"#0000040 departs at 2016-01-01 00:08:15\n",
"#0000041 departs at 2016-01-01 00:08:29\n",
"#0000014 arrives at 2016-01-01 00:08:37 - average error: 0:03:34.342094\n",
"#0000042 departs at 2016-01-01 00:08:37\n",
"#0000043 departs at 2016-01-01 00:08:38\n",
"#0000044 departs at 2016-01-01 00:08:40\n",
"#0000045 departs at 2016-01-01 00:08:46\n",
"#0000046 departs at 2016-01-01 00:08:47\n",
"#0000047 departs at 2016-01-01 00:08:49\n",
"#0000048 departs at 2016-01-01 00:08:52\n",
"#0000049 departs at 2016-01-01 00:08:53\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-28-15f9fc6bbe36>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;31m# Wait\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mnap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrip\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'pickup_datetime'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mnow\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mnow\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrip\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'pickup_datetime'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-28-15f9fc6bbe36>\u001b[0m in \u001b[0;36mnap\u001b[0;34m(td)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mnap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtd\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimedelta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msleep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mseconds\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"import datetime as dt\n",
"import time\n",
"from creme import datasets\n",
"from creme import stream\n",
"import requests\n",
"\n",
"\n",
"# Use the first trip's departure time as a reference time\n",
"taxis = datasets.Taxis()\n",
"now = next(iter(taxis))[0]['pickup_datetime']\n",
"mae = metrics.MAE() \n",
"predictions = {}\n",
"\n",
"\n",
"def nap(td: dt.timedelta):\n",
" time.sleep(td.seconds / 10)\n",
"\n",
"\n",
"for trip_no, trip, duration in stream.simulate_qa(\n",
" taxis,\n",
" moment='pickup_datetime',\n",
" delay=lambda _, duration: dt.timedelta(seconds=duration)\n",
"):\n",
"\n",
" trip_no = str(trip_no).zfill(len(str(taxis.n_samples)))\n",
"\n",
" # Taxi trip starts\n",
"\n",
" if duration is None:\n",
"\n",
" # Wait\n",
" nap(trip['pickup_datetime'] - now)\n",
" now = trip['pickup_datetime']\n",
"\n",
" # Ask chantilly to make a prediction\n",
" r = requests.post(host + '/api/predict', json={\n",
" 'id': trip_no,\n",
" 'features': {**trip, 'pickup_datetime': trip['pickup_datetime'].isoformat()}\n",
" })\n",
"\n",
" # Store the prediction\n",
" predictions[trip_no] = r.json()['prediction']\n",
"\n",
" print(f'#{trip_no} departs at {now}')\n",
" continue\n",
"\n",
" # Taxi trip ends\n",
"\n",
" # Wait\n",
" arrival_time = trip['pickup_datetime'] + dt.timedelta(seconds=duration)\n",
" nap(arrival_time - now)\n",
" now = arrival_time\n",
"\n",
" # Ask chantilly to update the model\n",
" requests.post(host + '/api/learn', json={'id': trip_no, 'ground_truth': duration})\n",
"\n",
" # Update the metric\n",
" mae.update(y_true=duration, y_pred=predictions.pop(trip_no))\n",
"\n",
" msg = f'#{trip_no} arrives at {now} - average error: {dt.timedelta(seconds=mae.get())}'\n",
" print(msg)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'MAE': 214.34209431927968,\n",
" 'RMSE': 248.1057514493265,\n",
" 'SMAPE': 167.4549197133317}"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"requests.get(host + '/api/metrics').json()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'learn': {'ewm_duration': 4426537,\n",
" 'ewm_duration_human': '4ms426μs537ns',\n",
" 'mean_duration': 4422086,\n",
" 'mean_duration_human': '4ms422μs86ns',\n",
" 'n_calls': 6},\n",
" 'predict': {'ewm_duration': 3974240,\n",
" 'ewm_duration_human': '3ms974μs240ns',\n",
" 'mean_duration': 4081125,\n",
" 'mean_duration_human': '4ms81μs125ns',\n",
" 'n_calls': 62}}"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"requests.get(host + '/api/stats').json()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}