@FlorisCalkoen
Last active October 7, 2020 15:03
es_rnn_colab_nb_example.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "es-rnn-colab-nb-example.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyO3yaYBU+ZfWti+MxkoFxGY",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/florisrc/e9bc4cb054c6ad1c8f1ac43a9c21d09f/es-rnn-colab-nb-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IBEgiTXP4u4o",
"colab_type": "text"
},
"source": [
"# ES-RNN Colab NB Example\n",
"\n",
"A GPU-enabled version of the hybrid ES-RNN model by Slawek Smyl that won the M4 time-series forecasting competition by a large margin, run here in a Google Colab environment. The implementation details and results are discussed in this [paper](https://arxiv.org/abs/1907.03329).\n",
"\n"
]
},
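{
"cell_type": "markdown",
"metadata": {},
"source": [
"Roughly, the hybrid couples classical exponential smoothing with a shared recurrent network (a simplified sketch of the paper's notation): each series $y_t$ gets its own multiplicative Holt-Winters level $l_t$ and seasonality $s_t$ with learned smoothing coefficients $\\alpha$, $\\beta$ and seasonality period $K$,\n",
"\n",
"$$l_t = \\alpha \\frac{y_t}{s_t} + (1-\\alpha)\\, l_{t-1}, \\qquad s_{t+K} = \\beta \\frac{y_t}{l_t} + (1-\\beta)\\, s_t,$$\n",
"\n",
"while a single dilated LSTM stack, shared across all series, is trained on the normalized, deseasonalized inputs $x_t = y_t / (l_t\\, s_t)$; its outputs are re-scaled by the level and seasonality factors to give the final forecast.\n"
]
},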
{
"cell_type": "markdown",
"metadata": {
"id": "FuudmEWFycm5",
"colab_type": "text"
},
"source": [
"## Get data and code\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ovCK3_O-xU5E",
"colab_type": "code",
"outputId": "572ca8da-cab3-4050-bf55-50ca703f2b3b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"# get data\n",
"%cd /content\n",
"!mkdir /content/m4_data \n",
"%cd /content/m4_data\n",
"!wget https://www.m4.unic.ac.cy/wp-content/uploads/2017/12/M4DataSet.zip\n",
"!wget https://www.m4.unic.ac.cy/wp-content/uploads/2018/07/M-test-set.zip\n",
"!wget https://github.com/M4Competition/M4-methods/raw/master/Dataset/M4-info.csv\n",
"!mkdir ./Train && cd ./Train && unzip ../M4DataSet.zip && cd ..\n",
"!mkdir ./Test && cd ./Test && unzip ../M-test-set.zip && cd ..\n",
"\n",
"# clone git repo\n",
"%cd /content\n",
"!git clone https://github.com/damitkwr/ESRNN-GPU.git\n",
"\n",
"# copy data to repo\n",
"%cd /content/ESRNN-GPU/\n",
"!mkdir ./data\n",
"%cd data/\n",
"!mkdir ./Train && cp /content/m4_data/Train/* ./Train/\n",
"!mkdir ./Test && cp /content/m4_data/Test/* ./Test/\n",
"!cp /content/m4_data/M4-info.csv ./info.csv\n",
"!cd ../.."
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"/content\n",
"/content/m4_data\n",
"--2020-04-20 10:14:51-- https://www.m4.unic.ac.cy/wp-content/uploads/2017/12/M4DataSet.zip\n",
"Resolving www.m4.unic.ac.cy (www.m4.unic.ac.cy)... 35.177.142.35, 35.176.90.68\n",
"Connecting to www.m4.unic.ac.cy (www.m4.unic.ac.cy)|35.177.142.35|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 66613994 (64M) [application/zip]\n",
"Saving to: ‘M4DataSet.zip’\n",
"\n",
"M4DataSet.zip 100%[===================>] 63.53M 19.3MB/s in 3.6s \n",
"\n",
"2020-04-20 10:14:55 (17.9 MB/s) - ‘M4DataSet.zip’ saved [66613994/66613994]\n",
"\n",
"--2020-04-20 10:14:56-- https://www.m4.unic.ac.cy/wp-content/uploads/2018/07/M-test-set.zip\n",
"Resolving www.m4.unic.ac.cy (www.m4.unic.ac.cy)... 35.177.142.35, 35.176.90.68\n",
"Connecting to www.m4.unic.ac.cy (www.m4.unic.ac.cy)|35.177.142.35|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 3723045 (3.5M) [application/zip]\n",
"Saving to: ‘M-test-set.zip’\n",
"\n",
"M-test-set.zip 100%[===================>] 3.55M 4.20MB/s in 0.8s \n",
"\n",
"2020-04-20 10:14:57 (4.20 MB/s) - ‘M-test-set.zip’ saved [3723045/3723045]\n",
"\n",
"--2020-04-20 10:14:58-- https://github.com/M4Competition/M4-methods/raw/master/Dataset/M4-info.csv\n",
"Resolving github.com (github.com)... 192.30.255.112\n",
"Connecting to github.com (github.com)|192.30.255.112|:443... connected.\n",
"HTTP request sent, awaiting response... 301 Moved Permanently\n",
"Location: https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/M4-info.csv [following]\n",
"--2020-04-20 10:14:58-- https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/M4-info.csv\n",
"Reusing existing connection to github.com:443.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://raw.githubusercontent.com/Mcompetitions/M4-methods/master/Dataset/M4-info.csv [following]\n",
"--2020-04-20 10:14:59-- https://raw.githubusercontent.com/Mcompetitions/M4-methods/master/Dataset/M4-info.csv\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 4335598 (4.1M) [text/plain]\n",
"Saving to: ‘M4-info.csv’\n",
"\n",
"M4-info.csv 100%[===================>] 4.13M --.-KB/s in 0.1s \n",
"\n",
"2020-04-20 10:14:59 (31.7 MB/s) - ‘M4-info.csv’ saved [4335598/4335598]\n",
"\n",
"Archive: ../M4DataSet.zip\n",
" inflating: Daily-train.csv \n",
" inflating: Hourly-train.csv \n",
" inflating: Monthly-train.csv \n",
" inflating: Quarterly-train.csv \n",
" inflating: Weekly-train.csv \n",
" inflating: Yearly-train.csv \n",
"Archive: ../M-test-set.zip\n",
" inflating: Daily-test.csv \n",
" inflating: Hourly-test.csv \n",
" inflating: Monthly-test.csv \n",
" inflating: Quarterly-test.csv \n",
" inflating: Weekly-test.csv \n",
" inflating: Yearly-test.csv \n",
"/content\n",
"Cloning into 'ESRNN-GPU'...\n",
"remote: Enumerating objects: 34, done.\u001b[K\n",
"remote: Counting objects: 100% (34/34), done.\u001b[K\n",
"remote: Compressing objects: 100% (25/25), done.\u001b[K\n",
"remote: Total 521 (delta 18), reused 23 (delta 9), pack-reused 487\u001b[K\n",
"Receiving objects: 100% (521/521), 76.96 MiB | 21.29 MiB/s, done.\n",
"Resolving deltas: 100% (330/330), done.\n",
"/content/ESRNN-GPU\n",
"/content/ESRNN-GPU/data\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BhYOwpDjyiqP",
"colab_type": "text"
},
"source": [
"## Create a Colab environment with the correct library versions"
]
},
{
"cell_type": "code",
"metadata": {
"id": "c9l1BWaOxUuw",
"colab_type": "code",
"outputId": "c53f5d40-2f6e-4574-db11-4ebc2dae8029",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 268
}
},
"source": [
"# uninstall the preinstalled torch (-y skips the confirmation prompt)\n",
"!pip uninstall -y torch\n",
"!pip uninstall -y torch  # run twice (PyTorch forums recommendation)\n",
"\n",
"# and re-install torch 0.4.1, which this repo targets\n",
"from os import path\n",
"from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag\n",
"platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())\n",
"\n",
"accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'\n",
"\n",
"!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision\n",
"\n",
"# tensorflow version 1 \n",
"%tensorflow_version 1.x\n",
"\n",
"import torch\n",
"import tensorflow as tf \n",
"print(f'Torch version: {torch.__version__}')\n",
"print(f'Tensorflow version: {tf.__version__}')\n",
"print(f'Torch.cuda.is_available: {torch.cuda.is_available()}')"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Uninstalling torch-0.4.1:\n",
" Would remove:\n",
" /usr/local/lib/python3.6/dist-packages/torch-0.4.1.dist-info/*\n",
" /usr/local/lib/python3.6/dist-packages/torch/*\n",
"Proceed (y/n)? y\n",
" Successfully uninstalled torch-0.4.1\n",
"\u001b[33mWARNING: Skipping torch as it is not installed.\u001b[0m\n",
"\u001b[K |████████████████████████████████| 483.0MB 1.2MB/s \n",
"\u001b[31mERROR: torchvision 0.5.0 has requirement torch==1.4.0, but you'll have torch 0.4.1 which is incompatible.\u001b[0m\n",
"\u001b[31mERROR: fastai 1.0.60 has requirement torch>=1.0.0, but you'll have torch 0.4.1 which is incompatible.\u001b[0m\n",
"\u001b[?25hTensorFlow 1.x selected.\n",
"Torch version: 0.4.1\n",
"Tensorflow version: 1.15.2\n",
"Torch.cuda.is_available: True\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9pVt5QPS0FN_",
"colab_type": "text"
},
"source": [
"## Check model configurations (optional)\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "EYXHAyL55YGx",
"colab_type": "code",
"outputId": "4cf207f2-b5da-4440-dd02-dc59679de4ef",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 627
}
},
"source": [
"# move to project working directory\n",
"%cd /content/ESRNN-GPU/\n",
"\n",
"# Check configuration\n",
"import pprint\n",
"from es_rnn.config import get_config\n",
"\n",
"config = get_config('Monthly')  # one of 'Quarterly', 'Monthly', 'Daily', 'Yearly'\n",
"pprint.pprint(config)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"/content/ESRNN-GPU\n",
"{'add_nl_layer': True,\n",
" 'batch_size': 1024,\n",
" 'c_state_penalty': 0,\n",
" 'chop_val': 72,\n",
" 'device': 'cuda',\n",
" 'dilations': ((1, 3), (6, 12)),\n",
" 'gradient_clipping': 20,\n",
" 'input_size': 12,\n",
" 'input_size_i': 12,\n",
" 'learning_rate': 0.001,\n",
" 'learning_rates': (10, 0.0001),\n",
" 'level_variability_penalty': 50,\n",
" 'lr_anneal_rate': 0.5,\n",
" 'lr_anneal_step': 5,\n",
" 'lr_ratio': 3.1622776601683795,\n",
" 'lr_tolerance_multip': 1.005,\n",
" 'min_epochs_before_changing_lrate': 2,\n",
" 'min_learning_rate': 0.0001,\n",
" 'num_of_categories': 6,\n",
" 'num_of_train_epochs': 15,\n",
" 'output_size': 18,\n",
" 'output_size_i': 18,\n",
" 'percentile': 50,\n",
" 'print_output_stats': 3,\n",
" 'print_train_batch_every': 5,\n",
" 'prod': True,\n",
" 'rnn_cell_type': 'LSTM',\n",
" 'seasonality': 12,\n",
" 'state_hsize': 50,\n",
" 'tau': 0.5,\n",
" 'training_percentile': 45,\n",
" 'training_tau': 0.45,\n",
" 'variable': 'Monthly'}\n"
],
"name": "stdout"
}
]
},
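{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since `get_config` returns a plain `dict`, entries can also be overridden in place before training, instead of editing `config.py` (the values below are illustrative, not tuned):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# override individual config entries in place (illustrative values)\n",
"config = get_config('Monthly')\n",
"config.update({\n",
"    'num_of_train_epochs': 5,\n",
"    'batch_size': 512\n",
"})\n",
"pprint.pprint({k: config[k] for k in ('num_of_train_epochs', 'batch_size')})"
],
"execution_count": 0,
"outputs": []
},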
{
"cell_type": "markdown",
"metadata": {
"id": "Ehz5haXl9MKA",
"colab_type": "text"
},
"source": [
"## Edit model configurations (optional) "
]
},
{
"cell_type": "code",
"metadata": {
"id": "cERmd0848p61",
"colab_type": "code",
"outputId": "c1fff371-08ec-4ae7-ba1e-62ecf1b34fcc",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"# print the current config.py as a starting point for edits\n",
"!cat /content/ESRNN-GPU/es_rnn/config.py"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3baLvW4A9eD9",
"colab_type": "code",
"outputId": "70d614fa-87c0-41cf-a52e-76678a5fb3ff",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"%%writefile /content/ESRNN-GPU/es_rnn/config.py\n",
"\n",
"from math import sqrt\n",
"\n",
"import torch\n",
"\n",
"\n",
"def get_config(interval):\n",
" config = {\n",
" 'prod': True,\n",
" 'device': (\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
" 'percentile': 50,\n",
" 'training_percentile': 45,\n",
" 'add_nl_layer': True,\n",
" 'rnn_cell_type': 'LSTM',\n",
" 'learning_rate': 1e-3,\n",
" 'learning_rates': ((10, 1e-4)),\n",
" 'num_of_train_epochs': 5,\n",
" 'num_of_categories': 6, # in data provided\n",
" 'batch_size': 1024,\n",
" 'gradient_clipping': 20,\n",
" 'c_state_penalty': 0,\n",
" 'min_learning_rate': 0.0001,\n",
" 'lr_ratio': sqrt(10),\n",
" 'lr_tolerance_multip': 1.005,\n",
" 'min_epochs_before_changing_lrate': 2,\n",
" 'print_train_batch_every': 5,\n",
" 'print_output_stats': 3,\n",
" 'lr_anneal_rate': 0.5,\n",
" 'lr_anneal_step': 5\n",
" }\n",
"\n",
" if interval == 'Quarterly':\n",
" config.update({\n",
" 'chop_val': 72,\n",
" 'variable': \"Quarterly\",\n",
" 'dilations': ((1, 2), (4, 8)),\n",
" 'state_hsize': 40,\n",
" 'seasonality': 4,\n",
" 'input_size': 4,\n",
" 'output_size': 8,\n",
" 'level_variability_penalty': 80\n",
" })\n",
" elif interval == 'Monthly':\n",
" config.update({\n",
" # RUNTIME PARAMETERS\n",
" 'chop_val': 72,\n",
" 'variable': \"Monthly\",\n",
" 'dilations': ((1, 3), (6, 12)),\n",
" 'state_hsize': 50,\n",
" 'seasonality': 12,\n",
" 'input_size': 12,\n",
" 'output_size': 18,\n",
" 'level_variability_penalty': 50\n",
" })\n",
" elif interval == 'Daily':\n",
" config.update({\n",
" # RUNTIME PARAMETERS\n",
" 'chop_val': 200,\n",
" 'variable': \"Daily\",\n",
" 'dilations': ((1, 7), (14, 28)),\n",
" 'state_hsize': 50,\n",
" 'seasonality': 7,\n",
" 'input_size': 7,\n",
" 'output_size': 14,\n",
" 'level_variability_penalty': 50\n",
" })\n",
" elif interval == 'Yearly':\n",
"\n",
" config.update({\n",
" # RUNTIME PARAMETERS\n",
" 'chop_val': 25,\n",
" 'variable': \"Yearly\",\n",
" 'dilations': ((1, 2), (2, 6)),\n",
" 'state_hsize': 30,\n",
" 'seasonality': 1,\n",
" 'input_size': 4,\n",
" 'output_size': 6,\n",
" 'level_variability_penalty': 0\n",
" })\n",
" else:\n",
"        raise ValueError('No config defined for interval: %s' % interval)\n",
"\n",
" config['input_size_i'] = config['input_size']\n",
" config['output_size_i'] = config['output_size']\n",
" config['tau'] = config['percentile'] / 100\n",
" config['training_tau'] = config['training_percentile'] / 100\n",
"\n",
" if not config['prod']:\n",
" config['batch_size'] = 10\n",
" config['num_of_train_epochs'] = 15\n",
"\n",
" return config"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Overwriting /content/ESRNN-GPU/es_rnn/config.py\n"
],
"name": "stdout"
}
]
},
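{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that Python caches imported modules, so after overwriting `es_rnn/config.py` with `%%writefile`, a plain `import` still returns the previously loaded module (which may explain why the training run below uses the repository's original 15 epochs rather than the 5 written above). Reload the module explicitly, or restart the runtime, before continuing:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# pick up the rewritten config.py without restarting the runtime\n",
"import importlib\n",
"import es_rnn.config\n",
"importlib.reload(es_rnn.config)\n",
"from es_rnn.config import get_config"
],
"execution_count": 0,
"outputs": []
},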
{
"cell_type": "code",
"metadata": {
"id": "dD8_1psvxUka",
"colab_type": "code",
"outputId": "fd852628-9951-46ef-c791-294970a8e5d0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"# move to project working directory\n",
"%cd /content/ESRNN-GPU/\n",
"\n",
"import pandas as pd\n",
"from torch.utils.data import DataLoader\n",
"from es_rnn.data_loading import create_datasets, SeriesDataset\n",
"from es_rnn.config import get_config\n",
"from es_rnn.trainer import ESRNNTrainer\n",
"from es_rnn.model import ESRNN\n",
"import time\n",
"\n",
"print('loading config')\n",
"config = get_config('Monthly')\n",
"\n",
"print('loading data')\n",
"info = pd.read_csv('/content/ESRNN-GPU/data/info.csv')\n",
"\n",
"train_path = '/content/ESRNN-GPU/data/Train/%s-train.csv' % (config['variable'])\n",
"test_path = '/content/ESRNN-GPU/data/Test/%s-test.csv' % (config['variable'])\n",
"\n",
"train, val, test = create_datasets(train_path, test_path, config['output_size'])\n",
"\n",
"dataset = SeriesDataset(train, val, test, info, config['variable'], config['chop_val'], config['device'])\n",
"dataloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)\n",
"\n",
"run_id = str(int(time.time()))\n",
"model = ESRNN(num_series=len(dataset), config=config)\n",
"tr = ESRNNTrainer(model, dataloader, run_id, config, ohe_headers=dataset.dataInfoCatHeaders)\n",
"tr.train_epochs() "
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"/content/ESRNN-GPU\n",
"loading config\n",
"loading data\n",
"WARNING:tensorflow:From /content/ESRNN-GPU/utils/logger.py:9: The name tf.summary.FileWriter is deprecated. Please use tf.compat.v1.summary.FileWriter instead.\n",
"\n",
"Train_batch: 1\n",
"WARNING:tensorflow:From /content/ESRNN-GPU/utils/logger.py:20: The name tf.Summary is deprecated. Please use tf.compat.v1.Summary instead.\n",
"\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [1/15] Loss: 23.5622\n",
"WARNING:tensorflow:From /content/ESRNN-GPU/utils/logger.py:32: The name tf.HistogramProto is deprecated. Please use tf.compat.v1.HistogramProto instead.\n",
"\n",
"Loss decreased, saving model!\n",
"{'Demographic': 9.967270851135254, 'Finance': 14.129257202148438, 'Industry': 14.062990188598633, 'Macro': 14.504400253295898, 'Micro': 12.656795501708984, 'Other': 13.857166290283203, 'Overall': 13.255921363830566, 'hold_out_loss': 9.428352355957031}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [2/15] Loss: 6.0620\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.617069244384766, 'Finance': 10.89091968536377, 'Industry': 10.572635650634766, 'Macro': 11.038947105407715, 'Micro': 8.760212898254395, 'Other': 10.539490699768066, 'Overall': 9.613520622253418, 'hold_out_loss': 6.416837692260742}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [3/15] Loss: 5.0237\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.255251884460449, 'Finance': 10.675373077392578, 'Industry': 10.28479290008545, 'Macro': 10.747631072998047, 'Micro': 8.443456649780273, 'Other': 10.35895824432373, 'Overall': 9.323665618896484, 'hold_out_loss': 5.998364448547363}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [4/15] Loss: 4.7732\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.173051834106445, 'Finance': 10.586264610290527, 'Industry': 10.207746505737305, 'Macro': 10.641335487365723, 'Micro': 8.287897109985352, 'Other': 10.417609214782715, 'Overall': 9.224047660827637, 'hold_out_loss': 5.846914291381836}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [5/15] Loss: 4.6542\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.164522171020508, 'Finance': 10.578543663024902, 'Industry': 10.187837600708008, 'Macro': 10.623577117919922, 'Micro': 8.26281452178955, 'Other': 10.45866870880127, 'Overall': 9.20844554901123, 'hold_out_loss': 5.779313087463379}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [6/15] Loss: 4.5945\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.154304027557373, 'Finance': 10.566814422607422, 'Industry': 10.174098014831543, 'Macro': 10.608725547790527, 'Micro': 8.236871719360352, 'Other': 10.460794448852539, 'Overall': 9.193401336669922, 'hold_out_loss': 5.757056713104248}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [7/15] Loss: 4.5661\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.155148506164551, 'Finance': 10.550554275512695, 'Industry': 10.173151016235352, 'Macro': 10.588346481323242, 'Micro': 8.188667297363281, 'Other': 10.482839584350586, 'Overall': 9.177229881286621, 'hold_out_loss': 5.744265556335449}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [8/15] Loss: 4.5431\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.154897212982178, 'Finance': 10.566178321838379, 'Industry': 10.165489196777344, 'Macro': 10.598137855529785, 'Micro': 8.2199068069458, 'Other': 10.474693298339844, 'Overall': 9.186197280883789, 'hold_out_loss': 5.735326766967773}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [9/15] Loss: 4.5236\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.148446083068848, 'Finance': 10.555113792419434, 'Industry': 10.15269947052002, 'Macro': 10.580646514892578, 'Micro': 8.18532943725586, 'Other': 10.47098445892334, 'Overall': 9.170007705688477, 'hold_out_loss': 5.723298072814941}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [10/15] Loss: 4.5057\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.152772426605225, 'Finance': 10.554600715637207, 'Industry': 10.157341957092285, 'Macro': 10.57533073425293, 'Micro': 8.166055679321289, 'Other': 10.48807430267334, 'Overall': 9.167271614074707, 'hold_out_loss': 5.719578266143799}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [11/15] Loss: 4.4916\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.1514434814453125, 'Finance': 10.558562278747559, 'Industry': 10.150896072387695, 'Macro': 10.575738906860352, 'Micro': 8.171943664550781, 'Other': 10.482748985290527, 'Overall': 9.167464256286621, 'hold_out_loss': 5.716712951660156}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [12/15] Loss: 4.4839\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.161124229431152, 'Finance': 10.557811737060547, 'Industry': 10.139664649963379, 'Macro': 10.573782920837402, 'Micro': 8.167551040649414, 'Other': 10.473078727722168, 'Overall': 9.164985656738281, 'hold_out_loss': 5.712883472442627}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [13/15] Loss: 4.4770\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.154560089111328, 'Finance': 10.553889274597168, 'Industry': 10.143238067626953, 'Macro': 10.56783676147461, 'Micro': 8.161402702331543, 'Other': 10.475221633911133, 'Overall': 9.161617279052734, 'hold_out_loss': 5.711756229400635}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [14/15] Loss: 4.4707\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.163309097290039, 'Finance': 10.553582191467285, 'Industry': 10.137224197387695, 'Macro': 10.565842628479004, 'Micro': 8.156747817993164, 'Other': 10.47332763671875, 'Overall': 9.16030216217041, 'hold_out_loss': 5.708845138549805}\n",
"Train_batch: 1\n",
"Train_batch: 2\n",
"Train_batch: 3\n",
"Train_batch: 4\n",
"Train_batch: 5\n",
"Train_batch: 6\n",
"Train_batch: 7\n",
"Train_batch: 8\n",
"Train_batch: 9\n",
"Train_batch: 10\n",
"Train_batch: 11\n",
"Train_batch: 12\n",
"Train_batch: 13\n",
"Train_batch: 14\n",
"Train_batch: 15\n",
"Train_batch: 16\n",
"Train_batch: 17\n",
"Train_batch: 18\n",
"Train_batch: 19\n",
"Train_batch: 20\n",
"Train_batch: 21\n",
"Train_batch: 22\n",
"Train_batch: 23\n",
"Train_batch: 24\n",
"Train_batch: 25\n",
"Train_batch: 26\n",
"Train_batch: 27\n",
"Train_batch: 28\n",
"Train_batch: 29\n",
"Train_batch: 30\n",
"Train_batch: 31\n",
"Train_batch: 32\n",
"Train_batch: 33\n",
"Train_batch: 34\n",
"Train_batch: 35\n",
"[TRAIN] Epoch [15/15] Loss: 4.4641\n",
"Loss decreased, saving model!\n",
"{'Demographic': 5.177516460418701, 'Finance': 10.559904098510742, 'Industry': 10.120597839355469, 'Macro': 10.571897506713867, 'Micro': 8.172297477722168, 'Other': 10.448819160461426, 'Overall': 9.163888931274414, 'hold_out_loss': 5.705763339996338}\n",
"Total Training Mins: 64.15\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Nh2_90Fm-nPX",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}