Created
May 30, 2019 18:58
-
-
Save richardliaw/f5c3002bf4a84574496937671ddcf0a6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
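"# Tuning a PyTorch CNN on MNIST with the Ax Service API and Ray Tune\n", | |
"\n", | |
"Ax's `AxClient` suggests `lr` and `momentum` values, Ray Tune runs each trial through `AxSearch`, and we then inspect the best parameters, a contour plot of the fitted model, and the optimization trace." | |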
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import numpy as np\n", | |
"\n", | |
"from ax.plot.contour import plot_contour\n", | |
"from ax.plot.trace import optimization_trace_single_method\n", | |
"from ax.service.ax_client import AxClient\n", | |
"from ax.utils.notebook.plotting import render, init_notebook_plotting\n", | |
"from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate\n", | |
"\n", | |
"\n", | |
"init_notebook_plotting()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# NOTE: `ax` here is the AxClient instance, not the `ax` package.\n", | |
"ax = AxClient(\n", | |
"    # Presumably disabled so Ray Tune can request new trials while earlier\n", | |
"    # ones are still running, instead of Ax waiting for each to complete.\n", | |
"    enforce_sequential_optimization=False,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ax.create_experiment(\n", | |
"    name=\"mnist_experiment\",\n", | |
"    # Should match the keyword that train_evaluate passes to track.log.\n", | |
"    objective_name=\"mean_accuracy\",\n", | |
"    parameters=[\n", | |
"        # Learning rate, searched on a log scale over [1e-6, 0.4].\n", | |
"        {\n", | |
"            \"name\": \"lr\",\n", | |
"            \"type\": \"range\",\n", | |
"            \"bounds\": [1e-6, 0.4],\n", | |
"            \"log_scale\": True,\n", | |
"        },\n", | |
"        # Momentum over [0.0, 1.0] (no log_scale requested).\n", | |
"        {\"name\": \"momentum\", \"type\": \"range\", \"bounds\": [0.0, 1.0]},\n", | |
"    ],\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from ray import tune\n", | |
"from ray.tune import track\n", | |
"from ray.tune.suggest.ax import AxSearch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def train_evaluate(parameterization):\n", | |
"    \"\"\"Train a CNN with one Ax-suggested configuration and report its score to Tune.\n", | |
"\n", | |
"    Args:\n", | |
"        parameterization: hyperparameter values for this trial, forwarded to\n", | |
"            cnn_utils.train. Presumably a dict holding the \"lr\" and \"momentum\"\n", | |
"            keys declared in create_experiment -- TODO confirm against AxSearch.\n", | |
"\n", | |
"    Reports:\n", | |
"        mean_accuracy: the value returned by evaluate() on the validation\n", | |
"            loader, logged via track.log (the function itself returns None).\n", | |
"    \"\"\"\n", | |
"    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", | |
"    # The test split is not used during tuning.\n", | |
"    train_loader, valid_loader, _test_loader = load_mnist()\n", | |
"    net = train(\n", | |
"        train_loader=train_loader,\n", | |
"        parameters=parameterization,\n", | |
"        dtype=torch.float,\n", | |
"        device=device,\n", | |
"    )\n", | |
"    accuracy = evaluate(\n", | |
"        net=net,\n", | |
"        data_loader=valid_loader,\n", | |
"        dtype=torch.float,\n", | |
"        device=device,\n", | |
"    )\n", | |
"    # The keyword must match objective_name=\"mean_accuracy\" in create_experiment;\n", | |
"    # presumably AxSearch looks the reported result up under that name.\n", | |
"    track.log(mean_accuracy=accuracy)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Launch num_samples trials whose hyperparameters are suggested by the Ax client.\n", | |
"tune.run(\n", | |
"    train_evaluate,\n", | |
"    num_samples=30,\n", | |
"    search_alg=AxSearch(ax),\n", | |
"    # To give each trial a GPU, uncomment the next line (it is inside the call,\n", | |
"    # so uncommenting it no longer produces a SyntaxError).\n", | |
"    # resources_per_trial={\"gpu\": 1},\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# The second returned value is unpacked below as (means, covariances).\n", | |
"(best_parameters, values) = ax.get_best_parameters()\n", | |
"best_parameters" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Split the second value from get_best_parameters and display the means.\n", | |
"(means, covariances) = values\n", | |
"means" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# metric_name must be a metric on this experiment; the only one defined is the\n", | |
"# objective \"mean_accuracy\" (there is no metric called \"Accuracy\").\n", | |
"render(\n", | |
"    plot_contour(\n", | |
"        model=ax.generation_strategy.model,\n", | |
"        param_x='lr',\n", | |
"        param_y='momentum',\n", | |
"        metric_name='mean_accuracy',\n", | |
"    )\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We can also plot the optimization trace, which shows the best objective value found so far after each trial:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# `optimization_trace_single_method` expects a 2-d array of means, because it is meant to\n", | |
"# average over multiple optimization runs, so we wrap our single run's per-trial\n", | |
"# objectives in an outer list (shape: 1 x number of trials).\n", | |
"# Objectives are multiplied by 100 to plot percentages -- assumes evaluate() returns a\n", | |
"# fraction in [0, 1]; TODO confirm.\n", | |
"trial_objectives = np.array([[trial.objective_mean * 100 for trial in ax.experiment.trials.values()]])\n", | |
"best_objective_plot = optimization_trace_single_method(\n", | |
"    # Running maximum along the trial axis = best objective found so far.\n", | |
"    y=np.maximum.accumulate(trial_objectives, axis=1),\n", | |
"    title=\"Model performance vs. # of iterations\",\n", | |
"    ylabel=\"Classification Accuracy, %\",\n", | |
")\n", | |
"render(best_objective_plot)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment