Skip to content

Instantly share code, notes, and snippets.

@adoskk
Created October 9, 2020 19:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save adoskk/5dba64243eb639c598dcbba685411a06 to your computer and use it in GitHub Desktop.
Save adoskk/5dba64243eb639c598dcbba685411a06 to your computer and use it in GitHub Desktop.
ax_bayesianopt.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "ax_bayesianopt.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMi6AbjGDXrfFl8Uvqo7Gev",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/adoskk/5dba64243eb639c598dcbba685411a06/ax_bayesianopt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "_BIcfXKu2uSf",
"outputId": "aa716e60-3320-4ccf-b399-6de8936179ae",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 547
}
},
"source": [
"# Original Code here:\n",
"# https://github.com/pytorch/examples/blob/master/mnist/main.py\n",
"import os\n",
"import argparse\n",
"from filelock import FileLock\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torchvision import datasets, transforms\n",
"\n",
"!pip install ax-platform\n",
"from ax.service.managed_loop import optimize\n",
"from ax.utils.notebook.plotting import render\n",
"from ax.utils.tutorials.cnn_utils import train, evaluate\n"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting ax-platform\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/c3/e5/defa97540bf23447f15d142a644eed9a9d9fd1925cf1e3c4f47a49282ec0/ax_platform-0.1.9-py3-none-any.whl (499kB)\n",
"\u001b[K |████████████████████████████████| 501kB 2.8MB/s \n",
"\u001b[?25hRequirement already satisfied: jinja2 in /usr/local/lib/python3.6/dist-packages (from ax-platform) (2.11.2)\n",
"Collecting botorch==0.2.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/64/e4/d696b12a84d505e9592fb6f8458a968b19efc22e30cc517dd2d2817e27e4/botorch-0.2.1-py3-none-any.whl (221kB)\n",
"\u001b[K |████████████████████████████████| 225kB 12.1MB/s \n",
"\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from ax-platform) (1.1.2)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from ax-platform) (0.22.2.post1)\n",
"Requirement already satisfied: plotly in /usr/local/lib/python3.6/dist-packages (from ax-platform) (4.4.1)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from ax-platform) (1.4.1)\n",
"Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.6/dist-packages (from jinja2->ax-platform) (1.1.1)\n",
"Collecting gpytorch>=1.0.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/6f/2f/6343548d88284ebf18d241dee12d0975cd7dbdee63c0fb749b23c8f536a1/gpytorch-1.2.0.tar.gz (274kB)\n",
"\u001b[K |████████████████████████████████| 276kB 8.4MB/s \n",
"\u001b[?25hRequirement already satisfied: torch>=1.3.1 in /usr/local/lib/python3.6/dist-packages (from botorch==0.2.1->ax-platform) (1.6.0+cu101)\n",
"Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->ax-platform) (2018.9)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.6/dist-packages (from pandas->ax-platform) (2.8.1)\n",
"Requirement already satisfied: numpy>=1.15.4 in /usr/local/lib/python3.6/dist-packages (from pandas->ax-platform) (1.18.5)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->ax-platform) (0.16.0)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from plotly->ax-platform) (1.15.0)\n",
"Requirement already satisfied: retrying>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from plotly->ax-platform) (1.3.3)\n",
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch>=1.3.1->botorch==0.2.1->ax-platform) (0.16.0)\n",
"Building wheels for collected packages: gpytorch\n",
" Building wheel for gpytorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for gpytorch: filename=gpytorch-1.2.0-py2.py3-none-any.whl size=459510 sha256=29906707d58c16570f045ead6c33bacdb5d0e1773fdb971d4d28bea68bfdee77\n",
" Stored in directory: /root/.cache/pip/wheels/e8/eb/36/f415815e8a8b66c1f1d5a3534718c39c2d83501051f1ab604e\n",
"Successfully built gpytorch\n",
"Installing collected packages: gpytorch, botorch, ax-platform\n",
"Successfully installed ax-platform-0.1.9 botorch-0.2.1 gpytorch-1.2.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8Kv8ILE6Patf"
},
"source": [
"## **Create CNN Model**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "kI_Sgs9o3Lg5"
},
"source": [
"# Change these values if you want the training to run quicker or slower.\n",
"EPOCH_SIZE = 512\n",
"TEST_SIZE = 256\n",
"\n",
"# define the network with 1 convolutional layer + 2 FC layers\n",
"class ConvNet(nn.Module):\n",
" def __init__(self):\n",
" super(ConvNet, self).__init__()\n",
" self.conv1 = nn.Conv2d(1, 3, kernel_size=3)\n",
" self.fc = nn.Linear(192, 10)\n",
"\n",
" def forward(self, x):\n",
" x = F.relu(F.max_pool2d(self.conv1(x), 3))\n",
" x = x.view(-1, 192)\n",
" x = self.fc(x)\n",
" return F.log_softmax(x, dim=1)\n"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "8N-e23SzPtJG"
},
"source": [
"## **Define Training and Testing Functions**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "4JhBdRFc3oWY"
},
"source": [
"def train_fun(model, optimizer, train_loader, device=None):\n",
" device = device or torch.device(\"cpu\")\n",
" model.train()\n",
" for batch_idx, (data, target) in enumerate(train_loader):\n",
" if batch_idx * len(data) > EPOCH_SIZE:\n",
" return\n",
" data, target = data.to(device), target.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = F.nll_loss(output, target)\n",
" loss.backward()\n",
" optimizer.step()"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "di1XyDl032XX"
},
"source": [
"def test_fun(model, data_loader, device=None):\n",
" device = device or torch.device(\"cpu\")\n",
" model.eval()\n",
" correct = 0\n",
" total = 0\n",
" with torch.no_grad():\n",
" for batch_idx, (data, target) in enumerate(data_loader):\n",
" if batch_idx * len(data) > TEST_SIZE:\n",
" break\n",
" data, target = data.to(device), target.to(device)\n",
" outputs = model(data)\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" total += target.size(0)\n",
" correct += (predicted == target).sum().item()\n",
"\n",
" return correct / total\n"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "6aZOxN5uP2zH"
},
"source": [
"## **Data Loader**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "tYW4B6WC3_5X"
},
"source": [
"def get_data_loaders():\n",
" mnist_transforms = transforms.Compose(\n",
" [transforms.ToTensor(),\n",
" transforms.Normalize((0.1307, ), (0.3081, ))])\n",
"\n",
" # We add FileLock here because multiple workers will want to\n",
" # download data, and this may cause overwrites since\n",
" # DataLoader is not threadsafe.\n",
" with FileLock(os.path.expanduser(\"~/data.lock\")):\n",
" train_loader = torch.utils.data.DataLoader(\n",
" datasets.MNIST(\n",
" \"~/data\",\n",
" train=True,\n",
" download=True,\n",
" transform=mnist_transforms),\n",
" batch_size=64,\n",
" shuffle=True)\n",
" test_loader = torch.utils.data.DataLoader(\n",
" datasets.MNIST(\"~/data\", train=False, transform=mnist_transforms),\n",
" batch_size=1,\n",
" shuffle=True)\n",
" return train_loader, test_loader"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "9HVdUXndP6Pf"
},
"source": [
"## **Train Your Model, Collect Testing Accuracy, and Return a Value that can be Evaluated Toward the Best Trial**\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "5fvJDmyY4HQ8"
},
"source": [
"def evaluate_mnist(parameters):\n",
" use_cuda = torch.cuda.is_available()\n",
" device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
" train_loader, test_loader = get_data_loaders()\n",
" model = ConvNet().to(device)\n",
"\n",
" optimizer = optim.SGD(\n",
" model.parameters(), lr=parameters.get(\"lr\", 0.001), momentum=parameters.get(\"momentum\", 0.95))\n",
"\n",
" for epoch in range(100):\n",
" train_fun(model, optimizer, train_loader, device)\n",
" \n",
" acc = test_fun(model, test_loader, device)\n",
" return acc"
],
"execution_count": 12,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "nHXSw3KsRJMp"
},
"source": [
"\n",
"## **Configuration of the Ax**\n",
"\n",
"Notes:\n",
"1. The objective name should the function you want to optimize;\n",
"2. Unlike the Ray, the Ax seems to discourage running trials in parallel. You have to manually change the AxClient setting to run parallel trials. \n",
"3. Set minimize=False if your objective functions need to be maximized. \n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "wUGuXbPl7LB-",
"outputId": "d99be376-7fd3-43c0-a6e7-8b6d703229b5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"from ax.service.ax_client import AxClient\n",
"\n",
"ax_client = AxClient()\n",
"ax_client.create_experiment(name='my_bayesianopt',\n",
" parameters=[{\"name\": \"lr\", \"type\": \"range\", \"bounds\": [1e-4, 1e-2], \"log_scale\": True},\n",
" {\"name\": \"momentum\", \"type\": \"range\", \"bounds\": [0.1, 0.9]}],\n",
" objective_name='evaluate_mnist',\n",
" minimize=False)\n",
"\n",
"for _ in range(50):\n",
" parameters, trial_index = ax_client.get_next_trial()\n",
" ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate_mnist(parameters))\n",
"\n",
"best_parameters, metrics = ax_client.get_best_parameters()\n",
"\n",
"print(best_parameters)"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": [
"[INFO 10-09 19:00:09] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 2 decimal points.\n",
"[INFO 10-09 19:00:09] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 arms, GPEI for subsequent arms], generated 0 arm(s) so far). Iterations after 5 will take longer to generate due to model-fitting.\n",
"[INFO 10-09 19:00:09] ax.service.ax_client: Generated new trial 0 with parameters {'lr': 0.0, 'momentum': 0.18}.\n",
"[INFO 10-09 19:00:23] ax.service.ax_client: Completed trial 0 with data: {'evaluate_mnist': (0.25, None)}.\n",
"[INFO 10-09 19:00:23] ax.service.ax_client: Generated new trial 1 with parameters {'lr': 0.0, 'momentum': 0.51}.\n",
"[INFO 10-09 19:00:37] ax.service.ax_client: Completed trial 1 with data: {'evaluate_mnist': (0.81, None)}.\n",
"[INFO 10-09 19:00:38] ax.service.ax_client: Generated new trial 2 with parameters {'lr': 0.01, 'momentum': 0.57}.\n",
"[INFO 10-09 19:00:52] ax.service.ax_client: Completed trial 2 with data: {'evaluate_mnist': (0.94, None)}.\n",
"[INFO 10-09 19:00:52] ax.service.ax_client: Generated new trial 3 with parameters {'lr': 0.0, 'momentum': 0.59}.\n",
"[INFO 10-09 19:01:06] ax.service.ax_client: Completed trial 3 with data: {'evaluate_mnist': (0.59, None)}.\n",
"[INFO 10-09 19:01:06] ax.service.ax_client: Generated new trial 4 with parameters {'lr': 0.0, 'momentum': 0.8}.\n",
"[INFO 10-09 19:01:20] ax.service.ax_client: Completed trial 4 with data: {'evaluate_mnist': (0.83, None)}.\n",
"[INFO 10-09 19:01:20] ax.service.ax_client: Generated new trial 5 with parameters {'lr': 0.0, 'momentum': 0.73}.\n",
"[INFO 10-09 19:01:34] ax.service.ax_client: Completed trial 5 with data: {'evaluate_mnist': (0.89, None)}.\n",
"[INFO 10-09 19:01:35] ax.service.ax_client: Generated new trial 6 with parameters {'lr': 0.0, 'momentum': 0.56}.\n",
"[INFO 10-09 19:01:49] ax.service.ax_client: Completed trial 6 with data: {'evaluate_mnist': (0.88, None)}.\n",
"[INFO 10-09 19:01:50] ax.service.ax_client: Generated new trial 7 with parameters {'lr': 0.01, 'momentum': 0.9}.\n",
"[INFO 10-09 19:02:04] ax.service.ax_client: Completed trial 7 with data: {'evaluate_mnist': (0.93, None)}.\n",
"[INFO 10-09 19:02:04] ax.service.ax_client: Generated new trial 8 with parameters {'lr': 0.01, 'momentum': 0.71}.\n",
"[INFO 10-09 19:02:19] ax.service.ax_client: Completed trial 8 with data: {'evaluate_mnist': (0.91, None)}.\n",
"[INFO 10-09 19:02:19] ax.service.ax_client: Generated new trial 9 with parameters {'lr': 0.01, 'momentum': 0.33}.\n",
"[INFO 10-09 19:02:34] ax.service.ax_client: Completed trial 9 with data: {'evaluate_mnist': (0.92, None)}.\n",
"[INFO 10-09 19:02:34] ax.service.ax_client: Generated new trial 10 with parameters {'lr': 0.01, 'momentum': 0.48}.\n",
"[INFO 10-09 19:02:48] ax.service.ax_client: Completed trial 10 with data: {'evaluate_mnist': (0.93, None)}.\n",
"[INFO 10-09 19:02:49] ax.service.ax_client: Generated new trial 11 with parameters {'lr': 0.0, 'momentum': 0.9}.\n",
"[INFO 10-09 19:03:03] ax.service.ax_client: Completed trial 11 with data: {'evaluate_mnist': (0.89, None)}.\n",
"[INFO 10-09 19:03:04] ax.service.ax_client: Generated new trial 12 with parameters {'lr': 0.01, 'momentum': 0.9}.\n",
"[INFO 10-09 19:03:19] ax.service.ax_client: Completed trial 12 with data: {'evaluate_mnist': (0.91, None)}.\n",
"[INFO 10-09 19:03:20] ax.service.ax_client: Generated new trial 13 with parameters {'lr': 0.01, 'momentum': 0.1}.\n",
"[INFO 10-09 19:03:34] ax.service.ax_client: Completed trial 13 with data: {'evaluate_mnist': (0.88, None)}.\n",
"[INFO 10-09 19:03:35] ax.service.ax_client: Generated new trial 14 with parameters {'lr': 0.01, 'momentum': 0.41}.\n",
"[INFO 10-09 19:03:50] ax.service.ax_client: Completed trial 14 with data: {'evaluate_mnist': (0.86, None)}.\n",
"[INFO 10-09 19:03:50] ax.service.ax_client: Generated new trial 15 with parameters {'lr': 0.0, 'momentum': 0.71}.\n",
"[INFO 10-09 19:04:04] ax.service.ax_client: Completed trial 15 with data: {'evaluate_mnist': (0.91, None)}.\n",
"[INFO 10-09 19:04:05] ax.service.ax_client: Generated new trial 16 with parameters {'lr': 0.01, 'momentum': 0.58}.\n",
"[INFO 10-09 19:04:19] ax.service.ax_client: Completed trial 16 with data: {'evaluate_mnist': (0.93, None)}.\n",
"[INFO 10-09 19:04:20] ax.service.ax_client: Generated new trial 17 with parameters {'lr': 0.01, 'momentum': 0.69}.\n",
"[INFO 10-09 19:04:34] ax.service.ax_client: Completed trial 17 with data: {'evaluate_mnist': (0.91, None)}.\n",
"[INFO 10-09 19:04:35] ax.service.ax_client: Generated new trial 18 with parameters {'lr': 0.0, 'momentum': 0.9}.\n",
"[INFO 10-09 19:04:49] ax.service.ax_client: Completed trial 18 with data: {'evaluate_mnist': (0.63, None)}.\n",
"[INFO 10-09 19:04:50] ax.service.ax_client: Generated new trial 19 with parameters {'lr': 0.01, 'momentum': 0.84}.\n",
"[INFO 10-09 19:05:04] ax.service.ax_client: Completed trial 19 with data: {'evaluate_mnist': (0.96, None)}.\n",
"[INFO 10-09 19:05:05] ax.service.ax_client: Generated new trial 20 with parameters {'lr': 0.01, 'momentum': 0.8}.\n",
"[INFO 10-09 19:05:19] ax.service.ax_client: Completed trial 20 with data: {'evaluate_mnist': (0.91, None)}.\n",
"[INFO 10-09 19:05:20] ax.service.ax_client: Generated new trial 21 with parameters {'lr': 0.01, 'momentum': 0.9}.\n",
"[INFO 10-09 19:05:34] ax.service.ax_client: Completed trial 21 with data: {'evaluate_mnist': (0.93, None)}.\n",
"[INFO 10-09 19:05:35] ax.service.ax_client: Generated new trial 22 with parameters {'lr': 0.01, 'momentum': 0.8}.\n",
"[INFO 10-09 19:05:49] ax.service.ax_client: Completed trial 22 with data: {'evaluate_mnist': (0.92, None)}.\n",
"[INFO 10-09 19:05:50] ax.service.ax_client: Generated new trial 23 with parameters {'lr': 0.01, 'momentum': 0.25}.\n",
"[INFO 10-09 19:06:04] ax.service.ax_client: Completed trial 23 with data: {'evaluate_mnist': (0.89, None)}.\n",
"[INFO 10-09 19:06:06] ax.service.ax_client: Generated new trial 24 with parameters {'lr': 0.0, 'momentum': 0.9}.\n",
"[INFO 10-09 19:06:21] ax.service.ax_client: Completed trial 24 with data: {'evaluate_mnist': (0.94, None)}.\n",
"[INFO 10-09 19:06:22] ax.service.ax_client: Generated new trial 25 with parameters {'lr': 0.0, 'momentum': 0.9}.\n",
"[INFO 10-09 19:06:36] ax.service.ax_client: Completed trial 25 with data: {'evaluate_mnist': (0.95, None)}.\n",
"[INFO 10-09 19:06:39] ax.service.ax_client: Generated new trial 26 with parameters {'lr': 0.0, 'momentum': 0.82}.\n",
"[INFO 10-09 19:06:53] ax.service.ax_client: Completed trial 26 with data: {'evaluate_mnist': (0.94, None)}.\n",
"[INFO 10-09 19:06:55] ax.service.ax_client: Generated new trial 27 with parameters {'lr': 0.01, 'momentum': 0.54}.\n",
"[INFO 10-09 19:07:10] ax.service.ax_client: Completed trial 27 with data: {'evaluate_mnist': (0.93, None)}.\n",
"[INFO 10-09 19:07:13] ax.service.ax_client: Generated new trial 28 with parameters {'lr': 0.0, 'momentum': 0.84}.\n",
"[INFO 10-09 19:07:27] ax.service.ax_client: Completed trial 28 with data: {'evaluate_mnist': (0.89, None)}.\n",
"[INFO 10-09 19:07:28] ax.service.ax_client: Generated new trial 29 with parameters {'lr': 0.01, 'momentum': 0.63}.\n",
"[INFO 10-09 19:07:43] ax.service.ax_client: Completed trial 29 with data: {'evaluate_mnist': (0.96, None)}.\n",
"[INFO 10-09 19:07:45] ax.service.ax_client: Generated new trial 30 with parameters {'lr': 0.01, 'momentum': 0.9}.\n",
"[INFO 10-09 19:07:59] ax.service.ax_client: Completed trial 30 with data: {'evaluate_mnist': (0.97, None)}.\n",
"[INFO 10-09 19:08:01] ax.service.ax_client: Generated new trial 31 with parameters {'lr': 0.01, 'momentum': 0.86}.\n",
"[INFO 10-09 19:08:15] ax.service.ax_client: Completed trial 31 with data: {'evaluate_mnist': (0.88, None)}.\n",
"[INFO 10-09 19:08:17] ax.service.ax_client: Generated new trial 32 with parameters {'lr': 0.0, 'momentum': 0.77}.\n",
"[INFO 10-09 19:08:31] ax.service.ax_client: Completed trial 32 with data: {'evaluate_mnist': (0.91, None)}.\n",
"[INFO 10-09 19:08:35] ax.service.ax_client: Generated new trial 33 with parameters {'lr': 0.0, 'momentum': 0.9}.\n",
"[INFO 10-09 19:08:49] ax.service.ax_client: Completed trial 33 with data: {'evaluate_mnist': (0.92, None)}.\n",
"[INFO 10-09 19:08:52] ax.service.ax_client: Generated new trial 34 with parameters {'lr': 0.01, 'momentum': 0.64}.\n",
"[INFO 10-09 19:09:06] ax.service.ax_client: Completed trial 34 with data: {'evaluate_mnist': (0.94, None)}.\n",
"[INFO 10-09 19:09:08] ax.service.ax_client: Generated new trial 35 with parameters {'lr': 0.01, 'momentum': 0.41}.\n",
"[INFO 10-09 19:09:22] ax.service.ax_client: Completed trial 35 with data: {'evaluate_mnist': (0.91, None)}.\n",
"[INFO 10-09 19:09:27] ax.service.ax_client: Generated new trial 36 with parameters {'lr': 0.01, 'momentum': 0.67}.\n",
"[INFO 10-09 19:09:41] ax.service.ax_client: Completed trial 36 with data: {'evaluate_mnist': (0.89, None)}.\n",
"[INFO 10-09 19:09:44] ax.service.ax_client: Generated new trial 37 with parameters {'lr': 0.01, 'momentum': 0.74}.\n",
"[INFO 10-09 19:09:57] ax.service.ax_client: Completed trial 37 with data: {'evaluate_mnist': (0.93, None)}.\n",
"[INFO 10-09 19:10:00] ax.service.ax_client: Generated new trial 38 with parameters {'lr': 0.01, 'momentum': 0.62}.\n",
"[INFO 10-09 19:10:14] ax.service.ax_client: Completed trial 38 with data: {'evaluate_mnist': (0.91, None)}.\n",
"[INFO 10-09 19:10:17] ax.service.ax_client: Generated new trial 39 with parameters {'lr': 0.0, 'momentum': 0.65}.\n",
"[INFO 10-09 19:10:31] ax.service.ax_client: Completed trial 39 with data: {'evaluate_mnist': (0.85, None)}.\n",
"[INFO 10-09 19:10:34] ax.service.ax_client: Generated new trial 40 with parameters {'lr': 0.0, 'momentum': 0.9}.\n",
"[INFO 10-09 19:10:48] ax.service.ax_client: Completed trial 40 with data: {'evaluate_mnist': (0.9, None)}.\n",
"[INFO 10-09 19:10:51] ax.service.ax_client: Generated new trial 41 with parameters {'lr': 0.0, 'momentum': 0.1}.\n",
"[INFO 10-09 19:11:05] ax.service.ax_client: Completed trial 41 with data: {'evaluate_mnist': (0.89, None)}.\n",
"[INFO 10-09 19:11:08] ax.service.ax_client: Generated new trial 42 with parameters {'lr': 0.0, 'momentum': 0.9}.\n",
"[INFO 10-09 19:11:22] ax.service.ax_client: Completed trial 42 with data: {'evaluate_mnist': (0.91, None)}.\n",
"[INFO 10-09 19:11:25] ax.service.ax_client: Generated new trial 43 with parameters {'lr': 0.01, 'momentum': 0.21}.\n",
"[INFO 10-09 19:11:39] ax.service.ax_client: Completed trial 43 with data: {'evaluate_mnist': (0.9, None)}.\n",
"[INFO 10-09 19:11:43] ax.service.ax_client: Generated new trial 44 with parameters {'lr': 0.0, 'momentum': 0.78}.\n",
"[INFO 10-09 19:11:57] ax.service.ax_client: Completed trial 44 with data: {'evaluate_mnist': (0.86, None)}.\n",
"[INFO 10-09 19:12:01] ax.service.ax_client: Generated new trial 45 with parameters {'lr': 0.01, 'momentum': 0.1}.\n",
"[INFO 10-09 19:12:15] ax.service.ax_client: Completed trial 45 with data: {'evaluate_mnist': (0.86, None)}.\n",
"[INFO 10-09 19:12:18] ax.service.ax_client: Generated new trial 46 with parameters {'lr': 0.0, 'momentum': 0.31}.\n",
"[INFO 10-09 19:12:32] ax.service.ax_client: Completed trial 46 with data: {'evaluate_mnist': (0.9, None)}.\n",
"[INFO 10-09 19:12:34] ax.service.ax_client: Generated new trial 47 with parameters {'lr': 0.0, 'momentum': 0.9}.\n",
"[INFO 10-09 19:12:48] ax.service.ax_client: Completed trial 47 with data: {'evaluate_mnist': (0.89, None)}.\n",
"[INFO 10-09 19:12:51] ax.service.ax_client: Generated new trial 48 with parameters {'lr': 0.01, 'momentum': 0.77}.\n",
"[INFO 10-09 19:13:05] ax.service.ax_client: Completed trial 48 with data: {'evaluate_mnist': (0.95, None)}.\n",
"[INFO 10-09 19:13:08] ax.service.ax_client: Generated new trial 49 with parameters {'lr': 0.0, 'momentum': 0.69}.\n",
"[INFO 10-09 19:13:22] ax.service.ax_client: Completed trial 49 with data: {'evaluate_mnist': (0.87, None)}.\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"{'lr': 0.007584330389670517, 'momentum': 0.9}\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment