Last active
March 11, 2021 14:58
TorchCDE Demo - Time Series Classification.ipynb
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "TorchCDE Demo - Time Series Classification.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"toc_visible": true, | |
"authorship_tag": "ABX9TyN6Q0sKkPBrwIoTafSrKS4a", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/howardjp/6f2d0620c3fb91afaf6855e87305fe28/torchcde-demo-time-series-classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "OOcqpVGgq9Me" | |
},
"source": [
"Configuration: `num_epochs` is the number of passes over the training set."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2ohPTiJ5q7QJ" | |
}, | |
"source": [ | |
"num_epochs = 30" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "vTne4xrBblx_", | |
"outputId": "7a84bb27-eafc-4bcc-8ba8-fc2525889051" | |
}, | |
"source": [ | |
"!pip install git+https://github.com/patrick-kidger/torchcde.git" | |
], | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting git+https://github.com/patrick-kidger/torchcde.git\n", | |
" Cloning https://github.com/patrick-kidger/torchcde.git to /tmp/pip-req-build-gfq8hhl6\n", | |
" Running command git clone -q https://github.com/patrick-kidger/torchcde.git /tmp/pip-req-build-gfq8hhl6\n", | |
"Requirement already satisfied (use --upgrade to upgrade): torchcde==0.2.0 from git+https://github.com/patrick-kidger/torchcde.git in /usr/local/lib/python3.7/dist-packages\n", | |
"Requirement already satisfied: torch>=1.7.0 in /usr/local/lib/python3.7/dist-packages (from torchcde==0.2.0) (1.8.0+cu101)\n", | |
"Requirement already satisfied: torchdiffeq>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from torchcde==0.2.0) (0.2.1)\n", | |
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.7.0->torchcde==0.2.0) (3.7.4.3)\n", | |
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torch>=1.7.0->torchcde==0.2.0) (1.19.5)\n", | |
"Building wheels for collected packages: torchcde\n", | |
" Building wheel for torchcde (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for torchcde: filename=torchcde-0.2.0-cp37-none-any.whl size=27210 sha256=1dc0c7c49cffc7ae26b7c79e7f0007d1d15d93e8817e7ec0a69c9ecf2e3893a0\n", | |
" Stored in directory: /tmp/pip-ephem-wheel-cache-ufingodg/wheels/27/70/62/fcc2954fe81b4263df3751dbd62599080933abee0a3f4736b4\n", | |
"Successfully built torchcde\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "9mycm2eGo9Ab" | |
}, | |
"source": [ | |
"So you want to train a Neural CDE model?\n", | |
"Let's get started!" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dpvLuvYll8pR" | |
}, | |
"source": [ | |
"import math\n", | |
"import torch\n", | |
"import torchcde\n", | |
"import matplotlib.pyplot as plt " | |
], | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "ErY5qykwpNix" | |
},
"source": [
"A Neural CDE model has the form:\n",
"\n",
"$z_t = z_0 + \\int_0^t f_\\theta(z_s) dX_s$\n",
"\n",
"where $X$ is the continuous path built from your data and $f_\\theta$ is a neural network. For each hidden state, $f_\\theta(z_s)$ is a matrix of shape (hidden_channels, input_channels) that is multiplied against the increment $dX_s$. So the first thing we need to do is define such an $f_\\theta$, which is what the `CDEFunc` class does. Here it is a small neural network with a single hidden layer of width 128."
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QwRxbNTRpbTW" | |
},
"source": [
"class CDEFunc(torch.nn.Module):\n",
"    def __init__(self, input_channels, hidden_channels):\n",
"        ######################\n",
"        # input_channels is the number of input channels in the data X. (Determined by the data.)\n",
"        # hidden_channels is the number of channels for z_t. (Determined by you!)\n",
"        ######################\n",
"        super(CDEFunc, self).__init__()\n",
"        self.input_channels = input_channels\n",
"        self.hidden_channels = hidden_channels\n",
"\n",
"        self.linear1 = torch.nn.Linear(hidden_channels, 128)\n",
"        self.linear2 = torch.nn.Linear(128, input_channels * hidden_channels)\n",
"\n",
"    ######################\n",
"    # For most purposes the t argument can probably be ignored; unless you want your CDE to behave differently at\n",
"    # different times, which would be unusual. But it's there if you need it!\n",
"    ######################\n",
"    def forward(self, t, z):\n",
"        # z has shape (batch, hidden_channels)\n",
"        z = self.linear1(z)\n",
"        z = z.relu()\n",
"        z = self.linear2(z)\n",
"        ######################\n",
"        # Easy-to-forget gotcha: Best results tend to be obtained by adding a final tanh nonlinearity.\n",
"        ######################\n",
"        z = z.tanh()\n",
"        ######################\n",
"        # Ignoring the batch dimension, the shape of the output tensor must be a matrix,\n",
"        # because we need it to represent a linear map from R^input_channels to R^hidden_channels.\n",
"        ######################\n",
"        z = z.view(z.size(0), self.hidden_channels, self.input_channels)\n",
"        return z"
], | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "g-Xgy5-vpzo1" | |
}, | |
"source": [ | |
"Next, we need to package `CDEFunc` up into a model that computes the integral.\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Ovb6I1dPpwZY" | |
},
"source": [
"class NeuralCDE(torch.nn.Module):\n",
"    def __init__(self, input_channels, hidden_channels, output_channels):\n",
"        super(NeuralCDE, self).__init__()\n",
"\n",
"        self.func = CDEFunc(input_channels, hidden_channels)\n",
"        self.initial = torch.nn.Linear(input_channels, hidden_channels)\n",
"        self.readout = torch.nn.Linear(hidden_channels, output_channels)\n",
"\n",
"    def forward(self, coeffs):\n",
"        X = torchcde.NaturalCubicSpline(coeffs)\n",
"\n",
"        ######################\n",
"        # Easy to forget gotcha: Initial hidden state should be a function of the first observation.\n",
"        ######################\n",
"        X0 = X.evaluate(X.interval[0])\n",
"        z0 = self.initial(X0)\n",
"\n",
"        ######################\n",
"        # Actually solve the CDE.\n",
"        ######################\n",
"        z_T = torchcde.cdeint(X=X,\n",
"                              z0=z0,\n",
"                              func=self.func,\n",
"                              t=X.interval)\n",
"\n",
"        ######################\n",
"        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,\n",
"        # and then apply a linear map.\n",
"        ######################\n",
"        z_T = z_T[:, 1]\n",
"        pred_y = self.readout(z_T)\n",
"        return pred_y"
], | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "QsNwgcfQqBqY" | |
}, | |
"source": [ | |
"Now we need some data. Here we have a simple example which generates some spirals, some going clockwise, some going anticlockwise." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "CzLt1LPyp7B4" | |
},
"source": [
"def get_data():\n",
"    t = torch.linspace(0., 4 * math.pi, 100)\n",
"\n",
"    start = torch.rand(128) * 2 * math.pi\n",
"    x_pos = torch.cos(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)\n",
"    x_pos[:64] *= -1\n",
"    y_pos = torch.sin(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)\n",
"    x_pos += 0.01 * torch.randn_like(x_pos)\n",
"    y_pos += 0.01 * torch.randn_like(y_pos)\n",
"    ######################\n",
"    # Easy to forget gotcha: time should be included as a channel; Neural CDEs need to be explicitly told the\n",
"    # rate at which time passes. Here, we have a regularly sampled dataset, so appending time is pretty simple.\n",
"    ######################\n",
"    X = torch.stack([t.unsqueeze(0).repeat(128, 1), x_pos, y_pos], dim=2)\n",
"    y = torch.zeros(128)\n",
"    y[:64] = 1\n",
"\n",
"    perm = torch.randperm(128)\n",
"    X = X[perm]\n",
"    y = y[perm]\n",
"\n",
"    ######################\n",
"    # X is a tensor of observations, of shape (batch=128, sequence=100, channels=3)\n",
"    # y is a tensor of labels, of shape (batch=128,), either 0 or 1 corresponding to anticlockwise or clockwise respectively.\n",
"    ######################\n",
"    return X, y"
], | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "mk5cJGuOqdQo" | |
}, | |
"source": [ | |
"train_X, train_y = get_data()" | |
], | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "BcETtCKOp-fX" | |
},
"source": [
"`input_channels=3` because each observation holds time plus the horizontal and vertical position of a point on the spiral.\n",
"\n",
"`hidden_channels=8` is the number of hidden channels for the evolving $z_t$, which we get to choose.\n",
"\n",
"`output_channels=1` because we're doing binary classification: the model outputs a single logit per sequence.\n"
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "PovXT-zCqeu3" | |
}, | |
"source": [ | |
"model = NeuralCDE(input_channels=3, hidden_channels=8, output_channels=1)\n", | |
"optimizer = torch.optim.Adam(model.parameters())" | |
], | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "v1DeWaLrqgBd" | |
}, | |
"source": [ | |
"Now we turn our dataset into a continuous path. We do this here via natural cubic spline interpolation. The resulting `train_coeffs` is a tensor describing the path. For most problems, it's probably easiest to save this tensor and treat it as the dataset.\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "uXXfl3W1qP6k" | |
}, | |
"source": [ | |
"train_coeffs = torchcde.natural_cubic_coeffs(train_X)\n", | |
"\n", | |
"train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)\n", | |
"train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)" | |
], | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "tMrsjxVhqqu1", | |
"outputId": "63bf52c7-1490-47a3-a7dc-5d4633e16d04" | |
},
"source": [
"for epoch in range(num_epochs):\n",
"    for batch in train_dataloader:\n",
"        batch_coeffs, batch_y = batch\n",
"        pred_y = model(batch_coeffs).squeeze(-1)\n",
"        loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        optimizer.zero_grad()\n",
"    print('Epoch: {} Training loss: {}'.format(epoch, loss.item()))"
], | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 0 Training loss: 0.6910864114761353\n", | |
"Epoch: 1 Training loss: 0.6720166206359863\n", | |
"Epoch: 2 Training loss: 0.614913821220398\n", | |
"Epoch: 3 Training loss: 0.5624614953994751\n", | |
"Epoch: 4 Training loss: 0.5064358711242676\n", | |
"Epoch: 5 Training loss: 0.44263210892677307\n", | |
"Epoch: 6 Training loss: 0.3414338231086731\n", | |
"Epoch: 7 Training loss: 0.235631063580513\n", | |
"Epoch: 8 Training loss: 0.12428899109363556\n", | |
"Epoch: 9 Training loss: 0.04819280654191971\n", | |
"Epoch: 10 Training loss: 0.017104361206293106\n", | |
"Epoch: 11 Training loss: 0.006561676971614361\n", | |
"Epoch: 12 Training loss: 0.0032162293791770935\n", | |
"Epoch: 13 Training loss: 0.0019024275243282318\n", | |
"Epoch: 14 Training loss: 0.0012913704849779606\n", | |
"Epoch: 15 Training loss: 0.000980702112428844\n", | |
"Epoch: 16 Training loss: 0.0007996470667421818\n", | |
"Epoch: 17 Training loss: 0.0006868061609566212\n", | |
"Epoch: 18 Training loss: 0.0006104775820858777\n", | |
"Epoch: 19 Training loss: 0.0005563427694141865\n", | |
"Epoch: 20 Training loss: 0.0005130225908942521\n", | |
"Epoch: 21 Training loss: 0.0004782585892826319\n", | |
"Epoch: 22 Training loss: 0.0004496659094002098\n", | |
"Epoch: 23 Training loss: 0.0004244595183990896\n", | |
"Epoch: 24 Training loss: 0.0004041045904159546\n", | |
"Epoch: 25 Training loss: 0.00038389809196814895\n", | |
"Epoch: 26 Training loss: 0.0003656374174170196\n", | |
"Epoch: 27 Training loss: 0.00034697819501161575\n", | |
"Epoch: 28 Training loss: 0.0003301869728602469\n", | |
"Epoch: 29 Training loss: 0.0003143766080029309\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "xsy8DpwoqsuR", | |
"outputId": "ae913dc4-18a6-4384-9a57-7fd10ec7d379" | |
}, | |
"source": [ | |
"test_X, test_y = get_data()\n", | |
"test_coeffs = torchcde.natural_cubic_coeffs(test_X)\n", | |
"pred_y = model(test_coeffs).squeeze(-1)\n", | |
"binary_prediction = (torch.sigmoid(pred_y) > 0.5).to(test_y.dtype)\n", | |
"prediction_matches = (binary_prediction == test_y).to(test_y.dtype)\n", | |
"proportion_correct = prediction_matches.sum() / test_y.size(0)\n", | |
"print('Test Accuracy: {}'.format(proportion_correct))" | |
], | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Test Accuracy: 1.0\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
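The CDE in the notebook, $z_t = z_0 + \int_0^t f_\theta(z_s) dX_s$, may be easier to read through a discretisation: at each step, the matrix $f_\theta(z)$ of shape (hidden_channels, input_channels) multiplies the increment of the path $X$. The following is a minimal pure-Python sketch of that idea using an explicit Euler scheme. It is not how `torchcde.cdeint` is implemented; the helper `euler_cde`, the vector field `f`, and the toy path are hypothetical and chosen only so that the result can be checked by hand.

```python
import math


def euler_cde(f, z0, path):
    """Approximate z_T = z_0 + int f(z_s) dX_s with an explicit Euler scheme.

    f maps a hidden state (list of length h) to an h x c matrix (list of rows).
    path is a list of observations, each a list of length c.
    """
    z = list(z0)
    for x_prev, x_next in zip(path[:-1], path[1:]):
        dx = [b - a for a, b in zip(x_prev, x_next)]  # increment of the path
        matrix = f(z)  # shape (hidden_channels, input_channels)
        # Matrix-vector product: one update per hidden channel.
        z = [z_i + sum(m_ij * dx_j for m_ij, dx_j in zip(row, dx))
             for z_i, row in zip(z, matrix)]
    return z


def f(z):
    # Hypothetical vector field with 1 hidden channel and 2 input channels
    # (time, value): dz = z * dt + 0 * dvalue, i.e. the ODE dz/dt = z.
    return [[z[0], 0.0]]


# Toy path with channels (time, value); time runs from 0 to 1.
n_steps = 1000
path = [[k / n_steps, math.sin(k / n_steps)] for k in range(n_steps + 1)]

z_T = euler_cde(f, z0=[1.0], path=path)
print(z_T[0])  # close to e, since this CDE reduces to dz/dt = z on [0, 1]
```

Because the update only ever sees increments of $X$, the hidden state has no notion of elapsed time unless time is itself one of the channels, which is why `get_data` stacks `t` alongside the spiral coordinates.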