{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Timeseries classification with a Transformer model\n",
"\n",
"**Author:** [Theodoros Ntakouris](https://github.com/ntakouris)<br>\n",
"**Date created:** 2021/06/25<br>\n",
"**Last modified:** 2021/08/05<br>\n",
"**Description:** This notebook demonstrates how to do timeseries classification using a Transformer model."
],
"metadata": {
"id": "toxQaKPncw1a"
}
},
{
"cell_type": "markdown",
"source": [
"## Introduction\n",
"\n",
"This is the Transformer architecture from\n",
"[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n",
"applied to timeseries instead of natural language.\n",
"\n",
"This example requires TensorFlow 2.4 or higher.\n",
"\n",
"## Load the dataset\n",
"\n",
"We are going to use the same dataset and preprocessing as the\n",
"[TimeSeries Classification from Scratch](https://keras.io/examples/timeseries/timeseries_classification_from_scratch)\n",
"example."
],
"metadata": {
"id": "E-bMQOtAcw1g"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import numpy as np\n",
"\n",
"\n",
"def readucr(filename):\n",
"    data = np.loadtxt(filename, delimiter=\"\\t\")\n",
"    y = data[:, 0]\n",
"    x = data[:, 1:]\n",
"    return x, y.astype(int)\n",
"\n",
"\n",
"root_url = \"https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/\"\n",
"\n",
"x_train, y_train = readucr(root_url + \"FordA_TRAIN.tsv\")\n",
"x_test, y_test = readucr(root_url + \"FordA_TEST.tsv\")\n",
"\n",
"x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))\n",
"x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))\n",
"\n",
"n_classes = len(np.unique(y_train))\n",
"\n",
"idx = np.random.permutation(len(x_train))\n",
"x_train = x_train[idx]\n",
"y_train = y_train[idx]\n",
"\n",
"y_train[y_train == -1] = 0\n",
"y_test[y_test == -1] = 0"
],
"outputs": [],
"metadata": {
"id": "Yii3Ire7cw1i"
}
},
{
"cell_type": "markdown",
"source": [
"### Load the WISDM watch accelerometer dataset"
],
"metadata": {
"id": "DVP3QdoMttpo"
}
},
{
"cell_type": "code",
"metadata": {
"id": "eRP_SJBcyTfi",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "1ac250d0-673f-41ba-98b7-9516a831de90"
},
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7c4810e2-4b99-4cfd-e40c-86815469b0b1",
"id": "z7Wdj7wjsksr"
},
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import os\n",
"import numpy as np\n",
"import glob\n",
"import pandas as pd"
],
"metadata": {
"id": "XY7vJj9jJ0GB"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = 'Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel' \n",
"GOOGLE_DRIVE_PATH = os.path.join('drive', 'My Drive', GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)\n",
"print(os.listdir(GOOGLE_DRIVE_PATH))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Zz7CPI5ktyQn",
"outputId": "08b2af5e-e6aa-4be6-b9d3-05fb67f51abe"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['data_1603_accel_watch.arff', 'data_1601_accel_watch.arff', '.DS_Store', 'data_1602_accel_watch.arff', 'data_1600_accel_watch.arff', 'data_1606_accel_watch.arff', 'data_1620_accel_watch.arff', 'data_1608_accel_watch.arff', 'data_1629_accel_watch.arff', 'data_1623_accel_watch.arff', 'data_1618_accel_watch.arff', 'data_1604_accel_watch.arff', 'data_1622_accel_watch.arff', 'data_1610_accel_watch.arff', 'data_1625_accel_watch.arff', 'data_1626_accel_watch.arff', 'data_1617_accel_watch.arff', 'data_1619_accel_watch.arff', 'data_1630_accel_watch.arff', 'data_1605_accel_watch.arff', 'data_1613_accel_watch.arff', 'data_1612_accel_watch.arff', 'data_1609_accel_watch.arff', 'data_1621_accel_watch.arff', 'data_1628_accel_watch.arff', 'data_1607_accel_watch.arff', 'data_1616_accel_watch.arff', 'data_1611_accel_watch.arff', 'data_1627_accel_watch.arff', 'data_1615_accel_watch.arff', 'data_1624_accel_watch.arff', 'data_1636_accel_watch.arff', 'data_1634_accel_watch.arff', 'data_1635_accel_watch.arff', 'data_1633_accel_watch.arff', 'data_1631_accel_watch.arff', 'data_1637_accel_watch.arff', 'data_1638_accel_watch.arff', 'data_1639_accel_watch.arff', 'data_1632_accel_watch.arff', 'data_1643_accel_watch.arff', 'data_1649_accel_watch.arff', 'data_1642_accel_watch.arff', 'data_1641_accel_watch.arff', 'data_1645_accel_watch.arff', 'data_1650_accel_watch.arff', 'data_1644_accel_watch.arff', 'data_1640_accel_watch.arff', 'data_1647_accel_watch.arff', 'data_1646_accel_watch.arff', 'data_1648_accel_watch.arff']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"features = ['ACTIVITY',\n",
" 'X0', # 1st bin fraction of x axis acceleration distribution\n",
" 'X1', # 2nd bin fraction ...\n",
" 'X2',\n",
" 'X3',\n",
" 'X4',\n",
" 'X5',\n",
" 'X6',\n",
" 'X7',\n",
" 'X8',\n",
" 'X9',\n",
" 'Y0', # 1st bin fraction of y axis acceleration distribution\n",
" 'Y1', # 2nd bin fraction ...\n",
" 'Y2',\n",
" 'Y3',\n",
" 'Y4',\n",
" 'Y5',\n",
" 'Y6',\n",
" 'Y7',\n",
" 'Y8',\n",
" 'Y9',\n",
" 'Z0', # 1st bin fraction of z axis acceleration distribution\n",
" 'Z1', # 2nd bin fraction ...\n",
" 'Z2',\n",
" 'Z3',\n",
" 'Z4',\n",
" 'Z5',\n",
" 'Z6',\n",
" 'Z7',\n",
" 'Z8',\n",
" 'Z9',\n",
" 'XAVG', # average sensor value over the window (per axis)\n",
" 'YAVG',\n",
" 'ZAVG',\n",
" 'XPEAK', # Time in milliseconds between peaks in the wave associated with most activities, heuristically determined (per axis)\n",
" 'YPEAK',\n",
" 'ZPEAK',\n",
" 'XABSOLDEV', # Average absolute difference between each of the 200 readings and their mean (per axis)\n",
" 'YABSOLDEV',\n",
" 'ZABSOLDEV',\n",
" 'XSTANDDEV', # Standard deviation of the 200 values in the window (per axis) ***BUG!***\n",
" 'YSTANDDEV',\n",
" 'ZSTANDDEV',\n",
" 'XVAR', # Variance of the 200 values in the window (per axis) ***BUG!***\n",
" 'YVAR',\n",
" 'ZVAR',\n",
" 'XMFCC0', # short-term power spectrum of a wave, based on a linear cosine transform of a log power spectrum on a non-linear mel scale of frequency (13 values per axis)\n",
" 'XMFCC1',\n",
" 'XMFCC2',\n",
" 'XMFCC3',\n",
" 'XMFCC4',\n",
" 'XMFCC5',\n",
" 'XMFCC6',\n",
" 'XMFCC7',\n",
" 'XMFCC8',\n",
" 'XMFCC9',\n",
" 'XMFCC10',\n",
" 'XMFCC11',\n",
" 'XMFCC12',\n",
" 'YMFCC0', # short-term power spectrum of a wave, based on a linear cosine transform of a log power spectrum on a non-linear mel scale of frequency (13 values per axis)\n",
" 'YMFCC1',\n",
" 'YMFCC2',\n",
" 'YMFCC3',\n",
" 'YMFCC4',\n",
" 'YMFCC5',\n",
" 'YMFCC6',\n",
" 'YMFCC7',\n",
" 'YMFCC8',\n",
" 'YMFCC9',\n",
" 'YMFCC10',\n",
" 'YMFCC11',\n",
" 'YMFCC12',\n",
" 'ZMFCC0', # short-term power spectrum of a wave, based on a linear cosine transform of a log power spectrum on a non-linear mel scale of frequency (13 values per axis)\n",
" 'ZMFCC1',\n",
" 'ZMFCC2',\n",
" 'ZMFCC3',\n",
" 'ZMFCC4',\n",
" 'ZMFCC5',\n",
" 'ZMFCC6',\n",
" 'ZMFCC7',\n",
" 'ZMFCC8',\n",
" 'ZMFCC9',\n",
" 'ZMFCC10',\n",
" 'ZMFCC11',\n",
" 'ZMFCC12',\n",
" 'XYCOS', # The cosine distances between sensor values for pairs of axes (three pairs of axes)\n",
" 'XZCOS',\n",
" 'YZCOS',\n",
" 'XYCOR', # The correlation between sensor values for pairs of axes (three pairs of axes)\n",
" 'XZCOR',\n",
" 'YZCOR',\n",
" 'RESULTANT', # Average resultant value, computed by squaring each matching x, y, and z value, summing them, taking the square root, and then averaging these values over the 200 readings\n",
" 'PARTICIPANT'] # Categorical participant ID: 1600-1650\n",
"\n",
"len(features)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HBulO6aqMvsY",
"outputId": "18e3bd46-2ac3-4be2-b39e-23eeafca673c"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"93"
]
},
"metadata": {},
"execution_count": 41
}
]
},
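{
"cell_type": "markdown",
"source": [
"The list above follows a regular pattern: 10 histogram bins per axis, then per-axis average, peak, absolute-deviation, standard-deviation and variance values, 13 MFCCs per axis, and finally the pairwise and global features. The cell below is a minimal sketch, not part of the original notebook, that rebuilds the same 93 names in the same order and asserts that they match `features`; `axes` and `generated_features` are illustrative names."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Illustrative sketch: rebuild the WISDM feature names programmatically.\n",
"axes = ['X', 'Y', 'Z']\n",
"generated_features = (\n",
"    ['ACTIVITY']\n",
"    + [f'{a}{i}' for a in axes for i in range(10)]      # X0..X9, Y0..Y9, Z0..Z9\n",
"    + [f'{a}AVG' for a in axes]\n",
"    + [f'{a}PEAK' for a in axes]\n",
"    + [f'{a}ABSOLDEV' for a in axes]\n",
"    + [f'{a}STANDDEV' for a in axes]\n",
"    + [f'{a}VAR' for a in axes]\n",
"    + [f'{a}MFCC{i}' for a in axes for i in range(13)]  # 13 MFCCs per axis\n",
"    + ['XYCOS', 'XZCOS', 'YZCOS', 'XYCOR', 'XZCOR', 'YZCOR']\n",
"    + ['RESULTANT', 'PARTICIPANT']\n",
")\n",
"\n",
"# Same names, same order as the hand-written list.\n",
"assert generated_features == features\n",
"len(generated_features)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},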
{
"cell_type": "code",
"source": [
"activity_codes_mapping = {'A': 'walking',\n",
" 'B': 'jogging',\n",
" 'C': 'stairs',\n",
" 'D': 'sitting',\n",
" 'E': 'standing',\n",
" 'F': 'typing',\n",
" 'G': 'brushing teeth',\n",
" 'H': 'eating soup',\n",
" 'I': 'eating chips',\n",
" 'J': 'eating pasta',\n",
" 'K': 'drinking from cup',\n",
" 'L': 'eating sandwich',\n",
" 'M': 'kicking soccer ball',\n",
" 'O': 'playing catch tennis ball',\n",
" 'P': 'dribbling basket ball',\n",
" 'Q': 'writing',\n",
" 'R': 'clapping',\n",
" 'S': 'folding clothes'}"
],
"metadata": {
"id": "2FxBH7bnNvsH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Read watch accelerometer values"
],
"metadata": {
"id": "yMUWqR7Oovt5"
}
},
{
"cell_type": "code",
"source": [
"GOOGLE_DRIVE_PATH"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"id": "eyIrxc_sZsQ7",
"outputId": "906e9e60-837d-4492-a955-63be08786046"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel'"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
}
},
"metadata": {},
"execution_count": 43
}
]
},
{
"cell_type": "code",
"source": [
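"all_files = glob.glob(GOOGLE_DRIVE_PATH + \"/*.arff\")\n",
"\n",
"# One DataFrame per participant file. Despite the 'phone' in the name, these are the watch accelerometer files.\n",
"list_dfs_phone_accel = []\n",
"\n",
"for filename in all_files:\n",
"    print(filename)\n",
"    # Skip the ARFF preamble; with header=0, pandas also treats the first remaining\n",
"    # line as a header row and replaces it with the names in `features`.\n",
"    df = pd.read_csv(filename, names=features, skiprows=96, index_col=None, header=0)\n",
"    list_dfs_phone_accel.append(df)\n"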
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2JUPucHRlocb",
"outputId": "4ef99ffe-89f3-4e70-8996-9f9fddcb9096"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1603_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1601_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1602_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1600_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1606_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1620_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1608_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1629_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1623_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1618_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1604_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1622_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1610_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1625_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1626_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1617_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1619_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1630_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1605_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1613_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1612_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1609_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1621_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1628_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1607_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1616_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1611_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1627_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1615_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1624_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1636_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1634_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1635_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1633_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1631_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1637_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1638_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1639_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1632_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1643_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1649_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1642_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1641_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1645_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1650_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1644_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1640_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1647_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1646_accel_watch.arff\n",
"drive/My Drive/Colab Notebooks/dataset/wisdm-dataset/arff_files/watch/accel/data_1648_accel_watch.arff\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"all_phone_accel = pd.concat(list_dfs_phone_accel, axis=0, ignore_index=True, sort=False)\n",
"all_phone_accel"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 467
},
"id": "lEcbQ_ZNMTOx",
"outputId": "5bd18864-751f-48fe-ae61-2ee38795ed86"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" ACTIVITY X0 X1 X2 X3 X4 X5 X6 X7 X8 \\\n",
"0 A 0.000 0.000 0.000 0.000 0.020 0.195 0.275 0.240 0.205 \n",
"1 A 0.000 0.000 0.000 0.000 0.000 0.190 0.465 0.310 0.035 \n",
"2 A 0.000 0.000 0.000 0.000 0.000 0.100 0.610 0.270 0.020 \n",
"3 A 0.000 0.000 0.000 0.000 0.000 0.135 0.585 0.240 0.035 \n",
"4 A 0.000 0.000 0.000 0.000 0.000 0.115 0.555 0.265 0.065 \n",
"... ... ... ... ... ... ... ... ... ... ... \n",
"18206 S 0.035 0.065 0.340 0.260 0.145 0.155 0.000 0.000 0.000 \n",
"18207 S 0.005 0.380 0.445 0.140 0.015 0.015 0.000 0.000 0.000 \n",
"18208 S 0.025 0.135 0.365 0.390 0.065 0.020 0.000 0.000 0.000 \n",
"18209 S 0.030 0.235 0.495 0.200 0.040 0.000 0.000 0.000 0.000 \n",
"18210 S 0.000 0.045 0.220 0.605 0.130 0.000 0.000 0.000 0.000 \n",
"\n",
" ... ZMFCC11 ZMFCC12 XYCOS XZCOS YZCOS XYCOR \\\n",
"0 ... 0.482083 0.475888 -0.424800 -0.540664 0.742245 0.180966 \n",
"1 ... 0.408916 0.403662 -0.543209 -0.657680 0.661040 0.223324 \n",
"2 ... 0.362887 0.358224 -0.526723 -0.542331 0.768224 0.268513 \n",
"3 ... 0.358655 0.354046 -0.562874 -0.573174 0.750614 0.270042 \n",
"4 ... 0.372644 0.367856 -0.455862 -0.561776 0.766964 0.382146 \n",
"... ... ... ... ... ... ... ... \n",
"18206 ... 0.539420 0.532489 -0.449785 0.458159 -0.528255 0.407894 \n",
"18207 ... 0.526219 0.519457 -0.130254 0.400250 -0.535456 0.263006 \n",
"18208 ... 0.595985 0.588327 -0.586402 0.478979 -0.440846 -0.142285 \n",
"18209 ... 0.572834 0.565473 -0.429104 0.237741 -0.568138 0.030976 \n",
"18210 ... 0.546479 0.539457 -0.629039 0.794984 -0.583910 0.052829 \n",
"\n",
" XZCOR YZCOR RESULTANT PARTICIPANT \n",
"0 0.028993 0.653510 13.69870 1603 \n",
"1 -0.043878 0.455660 12.37600 1603 \n",
"2 0.066037 0.664313 12.43600 1603 \n",
"3 -0.030761 0.619433 12.28300 1603 \n",
"4 0.091794 0.672474 12.50750 1603 \n",
"... ... ... ... ... \n",
"18206 -0.235094 0.436199 9.95470 1648 \n",
"18207 -0.133095 0.586823 9.90372 1648 \n",
"18208 -0.374878 0.629600 10.06100 1648 \n",
"18209 -0.420649 0.571948 9.89928 1648 \n",
"18210 -0.280061 0.593019 9.95396 1648 \n",
"\n",
"[18211 rows x 93 columns]"
],
"text/html": [
"\n",
" <div id=\"df-93505de1-0f0d-4407-83f8-2611183449b5\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ACTIVITY</th>\n",
" <th>X0</th>\n",
" <th>X1</th>\n",
" <th>X2</th>\n",
" <th>X3</th>\n",
" <th>X4</th>\n",
" <th>X5</th>\n",
" <th>X6</th>\n",
" <th>X7</th>\n",
" <th>X8</th>\n",
" <th>...</th>\n",
" <th>ZMFCC11</th>\n",
" <th>ZMFCC12</th>\n",
" <th>XYCOS</th>\n",
" <th>XZCOS</th>\n",
" <th>YZCOS</th>\n",
" <th>XYCOR</th>\n",
" <th>XZCOR</th>\n",
" <th>YZCOR</th>\n",
" <th>RESULTANT</th>\n",
" <th>PARTICIPANT</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>A</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.020</td>\n",
" <td>0.195</td>\n",
" <td>0.275</td>\n",
" <td>0.240</td>\n",
" <td>0.205</td>\n",
" <td>...</td>\n",
" <td>0.482083</td>\n",
" <td>0.475888</td>\n",
" <td>-0.424800</td>\n",
" <td>-0.540664</td>\n",
" <td>0.742245</td>\n",
" <td>0.180966</td>\n",
" <td>0.028993</td>\n",
" <td>0.653510</td>\n",
" <td>13.69870</td>\n",
" <td>1603</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>A</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.190</td>\n",
" <td>0.465</td>\n",
" <td>0.310</td>\n",
" <td>0.035</td>\n",
" <td>...</td>\n",
" <td>0.408916</td>\n",
" <td>0.403662</td>\n",
" <td>-0.543209</td>\n",
" <td>-0.657680</td>\n",
" <td>0.661040</td>\n",
" <td>0.223324</td>\n",
" <td>-0.043878</td>\n",
" <td>0.455660</td>\n",
" <td>12.37600</td>\n",
" <td>1603</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>A</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.100</td>\n",
" <td>0.610</td>\n",
" <td>0.270</td>\n",
" <td>0.020</td>\n",
" <td>...</td>\n",
" <td>0.362887</td>\n",
" <td>0.358224</td>\n",
" <td>-0.526723</td>\n",
" <td>-0.542331</td>\n",
" <td>0.768224</td>\n",
" <td>0.268513</td>\n",
" <td>0.066037</td>\n",
" <td>0.664313</td>\n",
" <td>12.43600</td>\n",
" <td>1603</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>A</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.135</td>\n",
" <td>0.585</td>\n",
" <td>0.240</td>\n",
" <td>0.035</td>\n",
" <td>...</td>\n",
" <td>0.358655</td>\n",
" <td>0.354046</td>\n",
" <td>-0.562874</td>\n",
" <td>-0.573174</td>\n",
" <td>0.750614</td>\n",
" <td>0.270042</td>\n",
" <td>-0.030761</td>\n",
" <td>0.619433</td>\n",
" <td>12.28300</td>\n",
" <td>1603</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>A</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.115</td>\n",
" <td>0.555</td>\n",
" <td>0.265</td>\n",
" <td>0.065</td>\n",
" <td>...</td>\n",
" <td>0.372644</td>\n",
" <td>0.367856</td>\n",
" <td>-0.455862</td>\n",
" <td>-0.561776</td>\n",
" <td>0.766964</td>\n",
" <td>0.382146</td>\n",
" <td>0.091794</td>\n",
" <td>0.672474</td>\n",
" <td>12.50750</td>\n",
" <td>1603</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18206</th>\n",
" <td>S</td>\n",
" <td>0.035</td>\n",
" <td>0.065</td>\n",
" <td>0.340</td>\n",
" <td>0.260</td>\n",
" <td>0.145</td>\n",
" <td>0.155</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>...</td>\n",
" <td>0.539420</td>\n",
" <td>0.532489</td>\n",
" <td>-0.449785</td>\n",
" <td>0.458159</td>\n",
" <td>-0.528255</td>\n",
" <td>0.407894</td>\n",
" <td>-0.235094</td>\n",
" <td>0.436199</td>\n",
" <td>9.95470</td>\n",
" <td>1648</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18207</th>\n",
" <td>S</td>\n",
" <td>0.005</td>\n",
" <td>0.380</td>\n",
" <td>0.445</td>\n",
" <td>0.140</td>\n",
" <td>0.015</td>\n",
" <td>0.015</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>...</td>\n",
" <td>0.526219</td>\n",
" <td>0.519457</td>\n",
" <td>-0.130254</td>\n",
" <td>0.400250</td>\n",
" <td>-0.535456</td>\n",
" <td>0.263006</td>\n",
" <td>-0.133095</td>\n",
" <td>0.586823</td>\n",
" <td>9.90372</td>\n",
" <td>1648</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18208</th>\n",
" <td>S</td>\n",
" <td>0.025</td>\n",
" <td>0.135</td>\n",
" <td>0.365</td>\n",
" <td>0.390</td>\n",
" <td>0.065</td>\n",
" <td>0.020</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>...</td>\n",
" <td>0.595985</td>\n",
" <td>0.588327</td>\n",
" <td>-0.586402</td>\n",
" <td>0.478979</td>\n",
" <td>-0.440846</td>\n",
" <td>-0.142285</td>\n",
" <td>-0.374878</td>\n",
" <td>0.629600</td>\n",
" <td>10.06100</td>\n",
" <td>1648</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18209</th>\n",
" <td>S</td>\n",
" <td>0.030</td>\n",
" <td>0.235</td>\n",
" <td>0.495</td>\n",
" <td>0.200</td>\n",
" <td>0.040</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>...</td>\n",
" <td>0.572834</td>\n",
" <td>0.565473</td>\n",
" <td>-0.429104</td>\n",
" <td>0.237741</td>\n",
" <td>-0.568138</td>\n",
" <td>0.030976</td>\n",
" <td>-0.420649</td>\n",
" <td>0.571948</td>\n",
" <td>9.89928</td>\n",
" <td>1648</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18210</th>\n",
" <td>S</td>\n",
" <td>0.000</td>\n",
" <td>0.045</td>\n",
" <td>0.220</td>\n",
" <td>0.605</td>\n",
" <td>0.130</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>0.000</td>\n",
" <td>...</td>\n",
" <td>0.546479</td>\n",
" <td>0.539457</td>\n",
" <td>-0.629039</td>\n",
" <td>0.794984</td>\n",
" <td>-0.583910</td>\n",
" <td>0.052829</td>\n",
" <td>-0.280061</td>\n",
" <td>0.593019</td>\n",
" <td>9.95396</td>\n",
" <td>1648</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>18211 rows × 93 columns</p>\n",
"</div>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 45
}
]
},
{
"cell_type": "code",
"source": [
"# Drop the standard-deviation and variance features, which `features` marks as ***BUG!***.\n",
"all_phone_accel.drop(['XSTANDDEV', 'YSTANDDEV', 'ZSTANDDEV', 'XVAR', 'YVAR', 'ZVAR'], axis=1, inplace=True)"
],
"metadata": {
"id": "mhYU_3cDlmNT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"all_phone_accel['ACTIVITY'].map(activity_codes_mapping).value_counts()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nAmXTjN9NV0A",
"outputId": "a6d2637e-dbdb-425a-e375-26c21c1ce9a6"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"standing 1046\n",
"drinking from cup 1044\n",
"writing 1038\n",
"sitting 1028\n",
"dribbling basket ball 1027\n",
"folding clothes 1019\n",
"playing catch tennis ball 1015\n",
"eating soup 1012\n",
"walking 1011\n",
"eating chips 1011\n",
"kicking soccer ball 1009\n",
"clapping 1009\n",
"brushing teeth 1006\n",
"stairs 997\n",
"jogging 993\n",
"typing 988\n",
"eating sandwich 980\n",
"eating pasta 978\n",
"Name: ACTIVITY, dtype: int64"
]
},
"metadata": {},
"execution_count": 47
}
]
},
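{
"cell_type": "markdown",
"source": [
"Before training a classifier, the activity letters need integer labels, much as the FordA labels are remapped from -1/1 to 0/1 earlier in this notebook. The cell below is a minimal sketch, not part of the original notebook: `letter_to_id`, `x_wisdm` and `y_wisdm` are illustrative names, and ordering the letters alphabetically is an arbitrary choice."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Illustrative sketch: map the 18 activity letters (A-S, no N) to class ids 0..17.\n",
"activity_letters = sorted(activity_codes_mapping)\n",
"letter_to_id = {letter: i for i, letter in enumerate(activity_letters)}\n",
"y_wisdm = all_phone_accel['ACTIVITY'].map(letter_to_id).to_numpy()\n",
"\n",
"# Numeric feature matrix without the label and participant-id columns.\n",
"x_wisdm = all_phone_accel.drop(columns=['ACTIVITY', 'PARTICIPANT']).to_numpy(dtype='float32')\n",
"\n",
"x_wisdm.shape, y_wisdm.shape"
],
"metadata": {},
"execution_count": null,
"outputs": []
},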
{
"cell_type": "code",
"source": [
"all_phone_accel.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 235
},
"id": "XsP2NVvIVUNC",
"outputId": "bfc5487d-6df3-4953-fee3-8247cff2f99a"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" ACTIVITY X0 X1 X2 X3 X4 X5 X6 X7 X8 ... \\\n",
"0 A 0.0 0.0 0.0 0.0 0.02 0.195 0.275 0.240 0.205 ... \n",
"1 A 0.0 0.0 0.0 0.0 0.00 0.190 0.465 0.310 0.035 ... \n",
"2 A 0.0 0.0 0.0 0.0 0.00 0.100 0.610 0.270 0.020 ... \n",
"3 A 0.0 0.0 0.0 0.0 0.00 0.135 0.585 0.240 0.035 ... \n",
"4 A 0.0 0.0 0.0 0.0 0.00 0.115 0.555 0.265 0.065 ... \n",
"\n",
" ZMFCC11 ZMFCC12 XYCOS XZCOS YZCOS XYCOR XZCOR \\\n",
"0 0.482083 0.475888 -0.424800 -0.540664 0.742245 0.180966 0.028993 \n",
"1 0.408916 0.403662 -0.543209 -0.657680 0.661040 0.223324 -0.043878 \n",
"2 0.362887 0.358224 -0.526723 -0.542331 0.768224 0.268513 0.066037 \n",
"3 0.358655 0.354046 -0.562874 -0.573174 0.750614 0.270042 -0.030761 \n",
"4 0.372644 0.367856 -0.455862 -0.561776 0.766964 0.382146 0.091794 \n",
"\n",
" YZCOR RESULTANT PARTICIPANT \n",
"0 0.653510 13.6987 1603 \n",
"1 0.455660 12.3760 1603 \n",
"2 0.664313 12.4360 1603 \n",
"3 0.619433 12.2830 1603 \n",
"4 0.672474 12.5075 1603 \n",
"\n",
"[5 rows x 87 columns]"
],
"text/html": [
"\n",
" <div id=\"df-b94eb78d-8afd-4535-bb43-983e1ae13e5a\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ACTIVITY</th>\n",
" <th>X0</th>\n",
" <th>X1</th>\n",
" <th>X2</th>\n",
" <th>X3</th>\n",
" <th>X4</th>\n",
" <th>X5</th>\n",
" <th>X6</th>\n",
" <th>X7</th>\n",
" <th>X8</th>\n",
" <th>...</th>\n",
" <th>ZMFCC11</th>\n",
" <th>ZMFCC12</th>\n",
" <th>XYCOS</th>\n",
" <th>XZCOS</th>\n",
" <th>YZCOS</th>\n",
" <th>XYCOR</th>\n",
" <th>XZCOR</th>\n",
" <th>YZCOR</th>\n",
" <th>RESULTANT</th>\n",
" <th>PARTICIPANT</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>A</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.02</td>\n",
" <td>0.195</td>\n",
" <td>0.275</td>\n",
" <td>0.240</td>\n",
" <td>0.205</td>\n",
" <td>...</td>\n",
" <td>0.482083</td>\n",
" <td>0.475888</td>\n",
" <td>-0.424800</td>\n",
" <td>-0.540664</td>\n",
" <td>0.742245</td>\n",
" <td>0.180966</td>\n",
" <td>0.028993</td>\n",
" <td>0.653510</td>\n",
" <td>13.6987</td>\n",
" <td>1603</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>A</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.190</td>\n",
" <td>0.465</td>\n",
" <td>0.310</td>\n",
" <td>0.035</td>\n",
" <td>...</td>\n",
" <td>0.408916</td>\n",
" <td>0.403662</td>\n",
" <td>-0.543209</td>\n",
" <td>-0.657680</td>\n",
" <td>0.661040</td>\n",
" <td>0.223324</td>\n",
" <td>-0.043878</td>\n",
" <td>0.455660</td>\n",
" <td>12.3760</td>\n",
" <td>1603</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>A</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.100</td>\n",
" <td>0.610</td>\n",
" <td>0.270</td>\n",
" <td>0.020</td>\n",
" <td>...</td>\n",
" <td>0.362887</td>\n",
" <td>0.358224</td>\n",
" <td>-0.526723</td>\n",
" <td>-0.542331</td>\n",
" <td>0.768224</td>\n",
" <td>0.268513</td>\n",
" <td>0.066037</td>\n",
" <td>0.664313</td>\n",
" <td>12.4360</td>\n",
" <td>1603</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>A</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.135</td>\n",
" <td>0.585</td>\n",
" <td>0.240</td>\n",
" <td>0.035</td>\n",
" <td>...</td>\n",
" <td>0.358655</td>\n",
" <td>0.354046</td>\n",
" <td>-0.562874</td>\n",
" <td>-0.573174</td>\n",
" <td>0.750614</td>\n",
" <td>0.270042</td>\n",
" <td>-0.030761</td>\n",
" <td>0.619433</td>\n",
" <td>12.2830</td>\n",
" <td>1603</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>A</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0.115</td>\n",
" <td>0.555</td>\n",
" <td>0.265</td>\n",
" <td>0.065</td>\n",
" <td>...</td>\n",
" <td>0.372644</td>\n",
" <td>0.367856</td>\n",
" <td>-0.455862</td>\n",
" <td>-0.561776</td>\n",
" <td>0.766964</td>\n",
" <td>0.382146</td>\n",
" <td>0.091794</td>\n",
" <td>0.672474</td>\n",
" <td>12.5075</td>\n",
" <td>1603</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 87 columns</p>\n",
"</div>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 48
}
]
},
{
"cell_type": "code",
"source": [
"all_phone_accel.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jlXKEWoDwUtK",
"outputId": "e31c845d-26b9-483d-de77-8e7294af0497"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(18211, 87)"
]
},
"metadata": {},
"execution_count": 49
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"y = all_phone_accel.ACTIVITY\n",
"x = all_phone_accel.drop('ACTIVITY', axis=1)\n",
"\n",
"x_train, x_test, y_train, y_test = train_test_split(\n",
"    x,\n",
"    y,\n",
"    train_size=0.75,\n",
"    test_size=0.25,\n",
"    shuffle=True,\n",
"    stratify=all_phone_accel.ACTIVITY,\n",
")"
],
"metadata": {
"id": "oGskFZs6czlY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"y_train.reset_index(drop = True, inplace = True)\n",
"y_test.reset_index(drop = True, inplace = True)"
],
"metadata": {
"id": "b2Sloq97362c"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"x_train = np.array(x_train)\n",
"x_test = np.array(x_test)\n",
"y_train = np.array(y_train)\n",
"y_test = np.array(y_test)"
],
"metadata": {
"id": "bx9x2ymBtvYc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))\n",
"x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))\n",
"\n",
"n_classes = len(np.unique(y_train))\n",
"\n",
"idx = np.random.permutation(len(x_train))\n",
"x_train = x_train[idx]\n",
"y_train = y_train[idx]\n",
"\n",
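"# Carried over from the FordA example, whose labels are -1/1. The labels here\n",
"# are letters, so these two lines leave y_train and y_test unchanged.\n",
"y_train[y_train == -1] = 0\n",
"y_test[y_test == -1] = 0"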
],
"metadata": {
"id": "wKyIAUe6tm0h"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"type(y_test)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gHKxmOjzzrHD",
"outputId": "7e8ff035-d8d2-458a-90f3-12dfe2ff8cc0"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"numpy.ndarray"
]
},
"metadata": {},
"execution_count": 54
}
]
},
{
"cell_type": "code",
"source": [
"n_classes = len(np.unique(y_train))\n",
"n_classes"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DRBTCMFkzSyx",
"outputId": "affdfbc1-fac2-4857-b58a-8e2909a14347"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"18"
]
},
"metadata": {},
"execution_count": 55
}
]
},
{
"cell_type": "markdown",
"source": [
"### Create train and test data"
],
"metadata": {
"id": "heaaCWQ1bNeb"
}
},
{
"cell_type": "code",
"source": [
""
],
"metadata": {
"id": "oMZ1hNnXbQkX"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Build the model\n",
"\n",
"Our model processes a tensor of shape `(batch size, sequence length, features)`,\n",
"where `sequence length` is the number of time steps and `features` is the number of\n",
"input timeseries (channels).\n",
"\n",
"You can replace your classification RNN layers with this one: the\n",
"inputs are fully compatible!"
],
"metadata": {
"id": "kxFVsu0Jcw1l"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"from tensorflow import keras\n",
"from tensorflow.keras import layers"
],
"outputs": [],
"metadata": {
"id": "xtpyEgJRcw1m"
}
},
{
"cell_type": "markdown",
"source": [
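"We include residual connections, layer normalization, and dropout.\n",
"The resulting block can be stacked multiple times.\n",
"\n",
"The projection layers of the feed-forward part are implemented with `keras.layers.Conv1D`\n",
"using `kernel_size=1`, i.e. the same dense projection applied at every time step."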
],
"metadata": {
"id": "T3t2d-ZKcw1n"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"\n",
"def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):\n",
"    # Normalization and Attention\n",
"    x = layers.LayerNormalization(epsilon=1e-6)(inputs)\n",
"    x = layers.MultiHeadAttention(\n",
"        key_dim=head_size, num_heads=num_heads, dropout=dropout\n",
"    )(x, x)\n",
"    x = layers.Dropout(dropout)(x)\n",
"    res = x + inputs\n",
"\n",
"    # Feed Forward Part\n",
"    x = layers.LayerNormalization(epsilon=1e-6)(res)\n",
"    x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation=\"relu\")(x)\n",
"    x = layers.Dropout(dropout)(x)\n",
"    x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)\n",
"    return x + res\n"
],
"outputs": [],
"metadata": {
"id": "rYBMYuXBcw1o"
}
},
{
"cell_type": "markdown",
"source": [
"The main part of our model is now complete. We can stack several of these\n",
"`transformer_encoder` blocks and then add the final Multi-Layer Perceptron\n",
"classification head. Before the stack of `Dense` layers, we need to reduce the\n",
"output tensor of the `transformer_encoder` blocks to a single feature vector per\n",
"data point in the current batch. A common way to achieve this is a pooling layer;\n",
"for this example, a `GlobalAveragePooling1D` layer is sufficient."
],
"metadata": {
"id": "gzNm2_bdcw1r"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"def build_model(\n",
"    input_shape,\n",
"    head_size,\n",
"    num_heads,\n",
"    ff_dim,\n",
"    num_transformer_blocks,\n",
"    mlp_units,\n",
"    dropout=0,\n",
"    mlp_dropout=0,\n",
"):\n",
"    inputs = keras.Input(shape=input_shape)\n",
"    x = inputs\n",
"    for _ in range(num_transformer_blocks):\n",
"        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)\n",
"\n",
"    # channels_first: average over the last axis, leaving one value per time step\n",
"    x = layers.GlobalAveragePooling1D(data_format=\"channels_first\")(x)\n",
"    for dim in mlp_units:\n",
"        x = layers.Dense(dim, activation=\"relu\")(x)\n",
"        x = layers.Dropout(mlp_dropout)(x)\n",
"    # n_classes is the global computed from y_train earlier in the notebook\n",
"    outputs = layers.Dense(n_classes, activation=\"softmax\")(x)\n",
"    return keras.Model(inputs, outputs)"
],
"outputs": [],
"metadata": {
"id": "gZFXtZzIcw1s"
}
},
{
"cell_type": "markdown",
"source": [
"## Train and evaluate"
],
"metadata": {
"id": "-2n5kjV4cw1t"
}
},
{
"cell_type": "code",
"source": [
"input_shape = x_train.shape[1:]\n",
"input_shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ojPEdQxObJ3A",
"outputId": "2dd74b0a-5eb5-45f9-94e1-4e301e4f09bf"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(86, 1)"
]
},
"metadata": {},
"execution_count": 59
}
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"model = build_model(\n",
"    input_shape,\n",
"    head_size=256,\n",
"    num_heads=4,\n",
"    ff_dim=4,\n",
"    num_transformer_blocks=4,\n",
"    mlp_units=[128],\n",
"    mlp_dropout=0.4,\n",
"    dropout=0.25,\n",
")\n",
"\n",
"model.compile(\n",
"    loss=\"sparse_categorical_crossentropy\",\n",
"    optimizer=keras.optimizers.Adam(learning_rate=1e-4),\n",
"    metrics=[\"sparse_categorical_accuracy\"],\n",
")\n",
"model.summary()\n"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"model_1\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" input_2 (InputLayer) [(None, 86, 1)] 0 [] \n",
" \n",
" layer_normalization_8 (LayerNo (None, 86, 1) 2 ['input_2[0][0]'] \n",
" rmalization) \n",
" \n",
" multi_head_attention_4 (MultiH (None, 86, 1) 7169 ['layer_normalization_8[0][0]', \n",
" eadAttention) 'layer_normalization_8[0][0]'] \n",
" \n",
" dropout_9 (Dropout) (None, 86, 1) 0 ['multi_head_attention_4[0][0]'] \n",
" \n",
" tf.__operators__.add_8 (TFOpLa (None, 86, 1) 0 ['dropout_9[0][0]', \n",
" mbda) 'input_2[0][0]'] \n",
" \n",
" layer_normalization_9 (LayerNo (None, 86, 1) 2 ['tf.__operators__.add_8[0][0]'] \n",
" rmalization) \n",
" \n",
" conv1d_8 (Conv1D) (None, 86, 4) 8 ['layer_normalization_9[0][0]'] \n",
" \n",
" dropout_10 (Dropout) (None, 86, 4) 0 ['conv1d_8[0][0]'] \n",
" \n",
" conv1d_9 (Conv1D) (None, 86, 1) 5 ['dropout_10[0][0]'] \n",
" \n",
" tf.__operators__.add_9 (TFOpLa (None, 86, 1) 0 ['conv1d_9[0][0]', \n",
" mbda) 'tf.__operators__.add_8[0][0]'] \n",
" \n",
" layer_normalization_10 (LayerN (None, 86, 1) 2 ['tf.__operators__.add_9[0][0]'] \n",
" ormalization) \n",
" \n",
" multi_head_attention_5 (MultiH (None, 86, 1) 7169 ['layer_normalization_10[0][0]', \n",
" eadAttention) 'layer_normalization_10[0][0]'] \n",
" \n",
" dropout_11 (Dropout) (None, 86, 1) 0 ['multi_head_attention_5[0][0]'] \n",
" \n",
" tf.__operators__.add_10 (TFOpL (None, 86, 1) 0 ['dropout_11[0][0]', \n",
" ambda) 'tf.__operators__.add_9[0][0]'] \n",
" \n",
" layer_normalization_11 (LayerN (None, 86, 1) 2 ['tf.__operators__.add_10[0][0]']\n",
" ormalization) \n",
" \n",
" conv1d_10 (Conv1D) (None, 86, 4) 8 ['layer_normalization_11[0][0]'] \n",
" \n",
" dropout_12 (Dropout) (None, 86, 4) 0 ['conv1d_10[0][0]'] \n",
" \n",
" conv1d_11 (Conv1D) (None, 86, 1) 5 ['dropout_12[0][0]'] \n",
" \n",
" tf.__operators__.add_11 (TFOpL (None, 86, 1) 0 ['conv1d_11[0][0]', \n",
" ambda) 'tf.__operators__.add_10[0][0]']\n",
" \n",
" layer_normalization_12 (LayerN (None, 86, 1) 2 ['tf.__operators__.add_11[0][0]']\n",
" ormalization) \n",
" \n",
" multi_head_attention_6 (MultiH (None, 86, 1) 7169 ['layer_normalization_12[0][0]', \n",
" eadAttention) 'layer_normalization_12[0][0]'] \n",
" \n",
" dropout_13 (Dropout) (None, 86, 1) 0 ['multi_head_attention_6[0][0]'] \n",
" \n",
" tf.__operators__.add_12 (TFOpL (None, 86, 1) 0 ['dropout_13[0][0]', \n",
" ambda) 'tf.__operators__.add_11[0][0]']\n",
" \n",
" layer_normalization_13 (LayerN (None, 86, 1) 2 ['tf.__operators__.add_12[0][0]']\n",
" ormalization) \n",
" \n",
" conv1d_12 (Conv1D) (None, 86, 4) 8 ['layer_normalization_13[0][0]'] \n",
" \n",
" dropout_14 (Dropout) (None, 86, 4) 0 ['conv1d_12[0][0]'] \n",
" \n",
" conv1d_13 (Conv1D) (None, 86, 1) 5 ['dropout_14[0][0]'] \n",
" \n",
" tf.__operators__.add_13 (TFOpL (None, 86, 1) 0 ['conv1d_13[0][0]', \n",
" ambda) 'tf.__operators__.add_12[0][0]']\n",
" \n",
" layer_normalization_14 (LayerN (None, 86, 1) 2 ['tf.__operators__.add_13[0][0]']\n",
" ormalization) \n",
" \n",
" multi_head_attention_7 (MultiH (None, 86, 1) 7169 ['layer_normalization_14[0][0]', \n",
" eadAttention) 'layer_normalization_14[0][0]'] \n",
" \n",
" dropout_15 (Dropout) (None, 86, 1) 0 ['multi_head_attention_7[0][0]'] \n",
" \n",
" tf.__operators__.add_14 (TFOpL (None, 86, 1) 0 ['dropout_15[0][0]', \n",
" ambda) 'tf.__operators__.add_13[0][0]']\n",
" \n",
" layer_normalization_15 (LayerN (None, 86, 1) 2 ['tf.__operators__.add_14[0][0]']\n",
" ormalization) \n",
" \n",
" conv1d_14 (Conv1D) (None, 86, 4) 8 ['layer_normalization_15[0][0]'] \n",
" \n",
" dropout_16 (Dropout) (None, 86, 4) 0 ['conv1d_14[0][0]'] \n",
" \n",
" conv1d_15 (Conv1D) (None, 86, 1) 5 ['dropout_16[0][0]'] \n",
" \n",
" tf.__operators__.add_15 (TFOpL (None, 86, 1) 0 ['conv1d_15[0][0]', \n",
" ambda) 'tf.__operators__.add_14[0][0]']\n",
" \n",
" global_average_pooling1d_1 (Gl (None, 86) 0 ['tf.__operators__.add_15[0][0]']\n",
" obalAveragePooling1D) \n",
" \n",
" dense_2 (Dense) (None, 128) 11136 ['global_average_pooling1d_1[0][0\n",
" ]'] \n",
" \n",
" dropout_17 (Dropout) (None, 128) 0 ['dense_2[0][0]'] \n",
" \n",
" dense_3 (Dense) (None, 18) 2322 ['dropout_17[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 42,202\n",
"Trainable params: 42,202\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
],
"metadata": {
"id": "_bzYSsnjcw1x",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "526552a6-6a46-4606-b2ca-5f4e146cec1b"
}
},
{
"cell_type": "markdown",
"source": [
"### Encode the string labels of y_train and y_test as integers"
],
"metadata": {
"id": "LYY2EwsrxME0"
}
},
{
"cell_type": "code",
"source": [
"_, y_train = np.unique(y_train, return_inverse = True)"
],
"metadata": {
"id": "2cmpnBSsvf0P"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"_, y_test = np.unique(y_test, return_inverse = True)"
],
"metadata": {
"id": "AuWmYwlDyFs1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Train the model"
],
"metadata": {
"id": "z6L9MlpLyfos"
}
},
{
"cell_type": "code",
"source": [
"callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]\n",
"\n",
"model.fit(\n",
"    x_train,\n",
"    y_train,\n",
"    validation_split=0.2,\n",
"    epochs=50,\n",
"    batch_size=64,\n",
"    callbacks=callbacks,\n",
")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5u3VAUbMuPjc",
"outputId": "501b3379-fd3d-47a0-abe0-dafc4cd878d9"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/50\n",
"171/171 [==============================] - 42s 167ms/step - loss: 245.1680 - sparse_categorical_accuracy: 0.0542 - val_loss: 61.0063 - val_sparse_categorical_accuracy: 0.0571\n",
"Epoch 2/50\n",
"171/171 [==============================] - 27s 158ms/step - loss: 152.2004 - sparse_categorical_accuracy: 0.0533 - val_loss: 12.2472 - val_sparse_categorical_accuracy: 0.0439\n",
"Epoch 3/50\n",
"171/171 [==============================] - 27s 158ms/step - loss: 42.6060 - sparse_categorical_accuracy: 0.0546 - val_loss: 2.8920 - val_sparse_categorical_accuracy: 0.0622\n",
"Epoch 4/50\n",
"171/171 [==============================] - 27s 157ms/step - loss: 3.0482 - sparse_categorical_accuracy: 0.0520 - val_loss: 2.8910 - val_sparse_categorical_accuracy: 0.0622\n",
"Epoch 5/50\n",
"171/171 [==============================] - 27s 157ms/step - loss: 2.9342 - sparse_categorical_accuracy: 0.0556 - val_loss: 2.8907 - val_sparse_categorical_accuracy: 0.0501\n",
"Epoch 6/50\n",
"171/171 [==============================] - 27s 157ms/step - loss: 2.9182 - sparse_categorical_accuracy: 0.0580 - val_loss: 2.8907 - val_sparse_categorical_accuracy: 0.0501\n",
"Epoch 7/50\n",
"171/171 [==============================] - 27s 157ms/step - loss: 2.9030 - sparse_categorical_accuracy: 0.0597 - val_loss: 2.8908 - val_sparse_categorical_accuracy: 0.0501\n",
"Epoch 8/50\n",
"171/171 [==============================] - 27s 157ms/step - loss: 2.9007 - sparse_categorical_accuracy: 0.0591 - val_loss: 2.8908 - val_sparse_categorical_accuracy: 0.0501\n",
"Epoch 9/50\n",
"171/171 [==============================] - 27s 157ms/step - loss: 2.9003 - sparse_categorical_accuracy: 0.0590 - val_loss: 2.8908 - val_sparse_categorical_accuracy: 0.0501\n",
"Epoch 10/50\n",
"171/171 [==============================] - 27s 156ms/step - loss: 2.9011 - sparse_categorical_accuracy: 0.0589 - val_loss: 2.8909 - val_sparse_categorical_accuracy: 0.0501\n",
"Epoch 11/50\n",
"171/171 [==============================] - 27s 159ms/step - loss: 2.8951 - sparse_categorical_accuracy: 0.0595 - val_loss: 2.8909 - val_sparse_categorical_accuracy: 0.0501\n",
"Epoch 12/50\n",
"171/171 [==============================] - 27s 157ms/step - loss: 2.8978 - sparse_categorical_accuracy: 0.0593 - val_loss: 2.8909 - val_sparse_categorical_accuracy: 0.0501\n",
"Epoch 13/50\n",
"171/171 [==============================] - 27s 156ms/step - loss: 2.8964 - sparse_categorical_accuracy: 0.0589 - val_loss: 2.8910 - val_sparse_categorical_accuracy: 0.0501\n",
"Epoch 14/50\n",
"171/171 [==============================] - 27s 156ms/step - loss: 2.8942 - sparse_categorical_accuracy: 0.0589 - val_loss: 2.8910 - val_sparse_categorical_accuracy: 0.0501\n",
"Epoch 15/50\n",
"171/171 [==============================] - 28s 162ms/step - loss: 2.8949 - sparse_categorical_accuracy: 0.0590 - val_loss: 2.8910 - val_sparse_categorical_accuracy: 0.0501\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f386b9b2c90>"
]
},
"metadata": {},
"execution_count": 65
}
]
},
{
"cell_type": "code",
"source": [
"model.evaluate(x_test, y_test, verbose=1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "A25yPDK5uTED",
"outputId": "e430de1e-164e-4c7d-8efa-2bc532e6a9b9"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"143/143 [==============================] - 5s 33ms/step - loss: 2.8912 - sparse_categorical_accuracy: 0.0578\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[2.891247510910034, 0.05776411294937134]"
]
},
"metadata": {},
"execution_count": 66
}
]
},
{
"cell_type": "code",
"source": [
"model2 = build_model(\n",
"    input_shape,\n",
"    head_size=256,\n",
"    num_heads=4,\n",
"    ff_dim=4,\n",
"    num_transformer_blocks=4,\n",
"    mlp_units=[128],\n",
"    mlp_dropout=0.4,\n",
"    dropout=0.25,\n",
")\n",
"\n",
"model2.compile(\n",
"    loss=\"sparse_categorical_crossentropy\",\n",
"    optimizer=keras.optimizers.Adam(learning_rate=1e-4),\n",
"    metrics=[\"sparse_categorical_accuracy\"],\n",
")\n",
"model2.fit(\n",
"    x_train,\n",
"    y_train,\n",
"    validation_split=0.2,\n",
"    epochs=200,\n",
"    batch_size=64,\n",
"    # no EarlyStopping callback this time: train for all 200 epochs\n",
")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 245
},
"id": "wa3o4VD9z6Le",
"outputId": "9d3fd020-87a3-4ce4-f18e-d5188f28ac4d"
},
"execution_count": null,
"outputs": [
{
"output_type": "error",
"ename": "NameError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-1-2a229574af7e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m model2 = build_model(\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mhead_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mnum_heads\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mff_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'build_model' is not defined"
]
}
]
},
{
"cell_type": "code",
"source": [
"model2.evaluate(x_test, y_test, verbose=1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dap1Osuv0DGE",
"outputId": "67a5c158-8fcb-4cb2-db1c-a6d7ddb69400"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"143/143 [==============================] - 5s 33ms/step - loss: 2.8909 - sparse_categorical_accuracy: 0.0573\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[2.890895128250122, 0.0573248416185379]"
]
},
"metadata": {},
"execution_count": 69
}
]
},
{
"cell_type": "code",
"source": [
"model2.save('drive/My Drive/Colab Notebooks/dataset')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2O-h-ug91CEW",
"outputId": "111afd84-471f-4eba-ec0c-974d29f3ec6a"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:absl:Found untraced functions such as query_layer_call_fn, query_layer_call_and_return_conditional_losses, key_layer_call_fn, key_layer_call_and_return_conditional_losses, value_layer_call_fn while saving (showing 5 of 48). These functions will not be directly callable after loading.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"INFO:tensorflow:Assets written to: drive/My Drive/Colab Notebooks/dataset/assets\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:tensorflow:Assets written to: drive/My Drive/Colab Notebooks/dataset/assets\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# How to draw the validation-training loss"
],
"metadata": {
"id": "bOoqP02F3TPh"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Conclusions\n",
"\n",
"On the FordA dataset of the original Keras example, the model reaches a training\n",
"accuracy of ~0.95, a validation accuracy of ~0.84 and a test accuracy of ~0.85 in\n",
"about 110-120 epochs (25s each on Colab), without hyperparameter tuning, and with\n",
"fewer than 100k parameters. Both parameter count and accuracy could be improved\n",
"by a hyperparameter search, a more sophisticated learning rate schedule, or a\n",
"different optimizer.\n",
"\n",
"On the phone-accelerometer features used in this notebook, the evaluations above\n",
"report a test accuracy of only about 0.06, roughly chance level for 18 classes (1/18).\n",
"\n",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/timeseries_transformer_classification) and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/timeseries_transformer_classification)."
],
"metadata": {
"id": "s2oxF3IOcw1z"
}
},
{
"cell_type": "code",
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers"
],
"metadata": {
"id": "ich2rfZ7nqGM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"vocab_size = 20000 # Only consider the top 20k words\n",
"maxlen = 200 # Only consider the first 200 words of each movie review\n",
"(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "B6PcGBo0nqJU",
"outputId": "31dd6cdf-59d9-4074-89f5-e1dc21bd34fd"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz\n",
"17465344/17464789 [==============================] - 0s 0us/step\n",
"17473536/17464789 [==============================] - 0s 0us/step\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"dicti = tf.keras.datasets.imdb.get_word_index(path=\"imdb_word_index.json\")\n",
"type(dicti)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iK-iFI90oU8b",
"outputId": "4234da03-f08e-4e2c-d98a-e37f2f5ae3af"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"dict"
]
},
"metadata": {},
"execution_count": 14
}
]
},
{
"cell_type": "code",
"source": [
"y_val"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ri5jtG_NqMtn",
"outputId": "e54d5c07-f724-4140-e197-b2c393f32443"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0, 1, 1, ..., 0, 0, 0])"
]
},
"metadata": {},
"execution_count": 19
}
]
},
{
"cell_type": "code",
"source": [
""
],
"metadata": {
"id": "pWR_4rufvUa1"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Copia de timeseries_transformer_classification",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 0
}
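A possible preprocessing step for the phone-accelerometer cells in the notebook, sketched under assumptions. The feature matrix `x` keeps the `PARTICIPANT` column, whose values (around 1600) dwarf most other features, and no feature is standardized before training; this may contribute to the very large first-epoch loss (245) and the chance-level accuracy, although that has not been verified here. The hypothetical helper `prepare_features` standardizes each column with training-set statistics, integer-encodes the letter labels with one mapping shared by both splits (instead of two separate `np.unique` calls), and adds the trailing channel axis the model expects. It assumes `PARTICIPANT` was dropped beforehand, e.g. with `all_phone_accel.drop(columns=["ACTIVITY", "PARTICIPANT"])`.

```python
import numpy as np


def prepare_features(x_train, x_test, y_train, y_test):
    """Hypothetical helper: standardize features and integer-encode labels.

    x_train, x_test: 2-D numeric arrays of shape (samples, features).
    y_train, y_test: 1-D arrays of string labels such as "A", "B", ...
    """
    # One shared label mapping, built from the sorted training labels.
    classes = np.unique(y_train)
    if not np.all(np.isin(y_test, classes)):
        raise ValueError("y_test contains labels that do not occur in y_train")
    y_train_int = np.searchsorted(classes, y_train)
    y_test_int = np.searchsorted(classes, y_test)

    # Per-feature standardization using training-set statistics only,
    # so nothing from the test split leaks into the scaling.
    x_train = np.asarray(x_train, dtype=np.float64)
    x_test = np.asarray(x_test, dtype=np.float64)
    mean = x_train.mean(axis=0)
    std = x_train.std(axis=0)
    std[std == 0] = 1.0  # constant columns: avoid dividing by zero
    x_train = (x_train - mean) / std
    x_test = (x_test - mean) / std

    # Add the trailing channel axis: (samples, features) -> (samples, features, 1).
    return (
        x_train[..., np.newaxis],
        x_test[..., np.newaxis],
        y_train_int,
        y_test_int,
        classes,
    )


# Toy usage with made-up numbers (not the notebook's data).
rng = np.random.default_rng(0)
xtr = rng.normal(100.0, 5.0, size=(6, 4))
xte = rng.normal(100.0, 5.0, size=(3, 4))
ytr = np.array(["A", "B", "C", "A", "B", "C"])
yte = np.array(["C", "A", "B"])
xtr3, xte3, ytr_int, yte_int, classes = prepare_features(xtr, xte, ytr, yte)
print(xtr3.shape, yte_int)  # (6, 4, 1) [2 0 1]
```

The mapping is the same sorted order that `np.unique(..., return_inverse=True)` produces, so with the stratified split above both approaches should agree; the shared mapping simply fails loudly if a label ever appears only in the test split.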
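The notebook leaves the training/validation loss plot as a to-do. A minimal sketch with matplotlib, assuming the return value of `fit` is kept (for example `history = model2.fit(...)`; the notebook currently discards it). Keras stores per-epoch metrics in `History.history`, a dict with keys such as `"loss"` and, when `validation_split` is set, `"val_loss"`. The function name `plot_history` and its `logy` option are hypothetical, chosen for illustration.

```python
import matplotlib.pyplot as plt


def plot_history(history_dict, metric="loss", logy=False):
    """Plot a Keras training metric and its validation counterpart per epoch.

    history_dict is the `.history` attribute of the object returned by fit().
    """
    train_values = history_dict[metric]
    val_values = history_dict.get("val_" + metric)  # None without validation data
    epochs = range(1, len(train_values) + 1)

    fig, ax = plt.subplots()
    ax.plot(epochs, train_values, label="training " + metric)
    if val_values is not None:
        ax.plot(epochs, val_values, label="validation " + metric)
    if logy:
        ax.set_yscale("log")  # helpful when the loss drops by orders of magnitude
    ax.set_xlabel("epoch")
    ax.set_ylabel(metric)
    ax.legend()
    return fig


# Usage with the first three epochs logged for `model` in the notebook:
demo = {
    "loss": [245.1680, 152.2004, 42.6060],
    "val_loss": [61.0063, 12.2472, 2.8920],
}
fig = plot_history(demo, "loss", logy=True)
# In the notebook: fig = plot_history(history.history); plt.show()
```

Since the logged loss falls from about 245 to about 2.89, `logy=True` keeps the later epochs readable; pass `metric="sparse_categorical_accuracy"` to plot accuracy instead.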