Skip to content

Instantly share code, notes, and snippets.

@wd15
Last active May 23, 2022 20:10
Show Gist options
  • Save wd15/3cfcde2aaf878ba74184a9dee42efeec to your computer and use it in GitHub Desktop.
Save wd15/3cfcde2aaf878ba74184a9dee42efeec to your computer and use it in GitHub Desktop.
misc
dask-worker-space
.local

Active Learning Hacking

Repository for active learning hacks

Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import dask.array as da\n",
"from dask_ml.model_selection import train_test_split\n",
"from sklearn.pipeline import Pipeline\n",
"from dask_ml.decomposition import IncrementalPCA\n",
"from dask_ml.preprocessing import PolynomialFeatures\n",
"from sklearn.linear_model import LinearRegression\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.gaussian_process import GaussianProcessRegressor\n",
"from sklearn.gaussian_process.kernels import RBF\n",
"from sklearn.metrics import mean_squared_error as mse\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"from pymks import (\n",
" generate_multiphase,\n",
" plot_microstructures,\n",
" PrimitiveTransformer,\n",
" TwoPointCorrelation,\n",
" GenericTransformer,\n",
" solve_fe\n",
")\n",
"\n",
"from toolz.curried import curry, pipe, valmap, itemmap, iterate\n",
"from modAL.models import ActiveLearner, CommitteeRegressor\n",
"from modAL.disagreement import max_std_sampling\n",
"from modAL.models import BayesianOptimizer\n",
"from modAL.acquisition import max_EI\n",
"import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"@curry\n",
"def iterate_times(func, times, value):\n",
" iter_ = iterate(func, value)\n",
" for _ in tqdm.tqdm(range(times)):\n",
" next(iter_)\n",
" return next(iter_)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"@curry\n",
"def update_learner(oracle_func, x_pool, evaluate_func, accuracy_learner):\n",
" accuracies, learner = accuracy_learner\n",
" query_idx, query_instance = learner.query(x_pool)\n",
" query_y = oracle_func(x_pool[query_idx].reshape(1, -1))\n",
" learner.teach(x_pool[query_idx].reshape(1, -1), query_y)\n",
" return accuracies + [evaluate_func(learner)], learner\n",
"\n",
"@curry\n",
"def evaluate_learner(oracle_func, x_pool, evaluate_func, n_query, learner):\n",
" return iterate_times(\n",
" update_learner(oracle_func, x_pool, evaluate_func),\n",
" n_query,\n",
" ([evaluate_func(learner)], learner)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def steps():\n",
" return (\n",
" (\"reshape\", GenericTransformer(\n",
" lambda x: x.reshape(x.shape[0], x_data.shape[1], x_data.shape[2])\n",
" )),\n",
" (\"discritize\",PrimitiveTransformer(n_state=2, min_=0.0, max_=1.0)),\n",
" (\"correlations\",TwoPointCorrelation(periodic_boundary=True, cutoff=31, correlations=[(0, 1), (1, 1)])),\n",
" ('flatten', GenericTransformer(lambda x: x.reshape(x.shape[0], -1))),\n",
" ('pca', IncrementalPCA(n_components=3, svd_solver='full')),\n",
" ('poly', PolynomialFeatures(degree=3))\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def make_gp_model():\n",
" kernel = 1 * RBF(length_scale=1.0, length_scale_bounds=(1e-1, 1e2))\n",
" regressor = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9)\n",
" return Pipeline(steps=steps() + (\n",
" ('regressor', regressor),\n",
" ))\n",
"\n",
"def make_linear_model():\n",
" return Pipeline(steps=steps() + (\n",
" ('regressor', LinearRegression()),\n",
" ))\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"@curry\n",
"def oracle_func(shape, x_data):\n",
" y_stress = solve_fe(x_data.reshape((-1,) + shape),\n",
" elastic_modulus=(1.3, 2.5),\n",
" poissons_ratio=(0.42, 0.35),\n",
" macro_strain=0.001)['stress'][..., 0]\n",
"\n",
" return np.array(da.average(y_stress.reshape(y_stress.shape[0], -1), axis=1))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def plot_parity(y_test, y_predict):\n",
" pred_data = np.array([y_test, y_predict])\n",
" line = np.min(pred_data), np.max(pred_data)\n",
" plt.plot(pred_data[0], pred_data[1], 'o', color='#f46d43', label='Testing Data')\n",
" plt.plot(line, line, '-', linewidth=3, color='#000000')\n",
" plt.title('Goodness of Fit', fontsize=20)\n",
" plt.xlabel('Actual', fontsize=18)\n",
" plt.ylabel('Predicted', fontsize=18)\n",
" plt.legend(loc=2, fontsize=15)\n",
" return plt"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def three_way_split(x_data, props, random_state):\n",
" x_0, x_ = train_test_split(x_data, train_size=props[0], random_state=random_state)\n",
" x_1, x_2 = train_test_split(x_, train_size=props[1] / (1 - props[0]), random_state=random_state)\n",
" return flatten(x_0), flatten(x_1), flatten(x_2)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def flatten(x_data):\n",
" return x_data.reshape(x_data.shape[0], -1)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"@curry\n",
"def evaluate_func(x_test, y_test, learner):\n",
" return mse(y_test, learner.predict(x_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Active Learning!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Data"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# Microstructures are read lazily from the zarr store in chunks of 100 samples.\n",
"x_data = da.from_zarr(\"data/x_data.zarr\", chunks=(100, -1))\n",
"# `da.from_array` on a path string wraps the string itself (0-d array, shape `()`)\n",
"# instead of reading the file, so load the saved responses with NumPy directly.\n",
"y_data = np.load(\"data/y_data.npy\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"()"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_data.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# Split into pool / test / train (80% / 16% / remainder) with a fixed random state.\n",
"x_pool, x_test, x_train = three_way_split(x_data, (0.8, 0.16), 99)\n",
"# `grid_shape` was never defined in this notebook; the oracle reshapes flattened\n",
"# samples back to `(-1,) + shape`, so use the per-sample shape of `x_data`.\n",
"oracle = oracle_func(tuple(x_data.shape[1:]))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 2880x288 with 10 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_microstructures(*x_data[:10], cmap='gray', colorbar=False);"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(96, 1681)\n",
"(24, 1681)\n",
"(480, 1681)\n",
"(600, 41, 41)\n"
]
}
],
"source": [
"print(x_test.shape)\n",
"print(x_train.shape)\n",
"print(x_pool.shape)\n",
"print(x_data.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate the necessary oracle data"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 17min 28s, sys: 14.5 s, total: 17min 43s\n",
"Wall time: 1min 41s\n"
]
}
],
"source": [
"%%time\n",
"y_test = oracle(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6min 31s, sys: 4.13 s, total: 6min 36s\n",
"Wall time: 38.5 s\n"
]
}
],
"source": [
"%%time\n",
"y_train = oracle(x_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set up the active learners\n",
"\n",
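"Four learners are defined: a GPR that queries the pool sample with the maximum predictive standard deviation, a GPR that queries a random pool sample, a committee of GPRs using max-std sampling, and a Bayesian optimizer using expected improvement. Only the learners left uncommented in the `learners` dict are run."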
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"std_learner = ActiveLearner(\n",
" estimator=make_gp_model(),\n",
" query_strategy=lambda model, x_: pipe(\n",
" model.predict(x_, return_std=True)[1],\n",
" np.argmax,\n",
" lambda i: (i, x_[i])\n",
" ),\n",
" X_training=x_train,\n",
" y_training=y_train\n",
")\n",
"\n",
"random_learner = ActiveLearner(\n",
" estimator=make_gp_model(),\n",
" query_strategy=lambda model, x_: pipe(\n",
" np.random.randint(0, len(x_)),\n",
" lambda i: (i, x_[i])\n",
" ),\n",
" X_training=x_train,\n",
" y_training=y_train\n",
")\n",
"\n",
"ensemble_learner = CommitteeRegressor(\n",
" learner_list=[\n",
" ActiveLearner(\n",
" estimator=make_gp_model(),\n",
" X_training=x_train_,\n",
" y_training=y_train_\n",
" )\n",
" for x_train_, y_train_ in zip(np.array_split(x_train, 5), np.array_split(y_train, 5))\n",
" ],\n",
" query_strategy=max_std_sampling\n",
")\n",
"\n",
"bayes_learner = BayesianOptimizer(\n",
" estimator=make_gp_model(),\n",
" X_training=x_train,\n",
" y_training=y_train,\n",
" query_strategy=max_EI\n",
")\n",
"\n",
"learners = dict(\n",
"# std=std_learner,\n",
"# random=random_learner,\n",
" ensemble=ensemble_learner,\n",
"# bayes=bayes_learner\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run the learners"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"evaluating ensemble\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 30/30 [02:00<00:00, 4.03s/it]\n"
]
}
],
"source": [
"def evaluate_item(item):\n",
" name, learner = item\n",
" print('evaluating', name)\n",
" return name, evaluate_learner(oracle, x_pool, evaluate_func(x_test, y_test), 30, learner)\n",
"\n",
"learner_accuracy = itemmap(evaluate_item, learners)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f049c5d61c0>"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"for k, v in learner_accuracy.items():\n",
" plt.semilogy(v[0], label=k)\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The results"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f049c14a700>"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"for k, v in learner_accuracy.items():\n",
" plt.semilogy(v[0], label=k)\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Check what the accuracy actually looks like"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"y_pred = learner_accuracy['std'][1].predict(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<module 'matplotlib.pyplot' from '/nix/store/c8sgkmibi2vyfw75w9vai2917j5smvq7-python3.8-matplotlib-3.3.1/lib/python3.8/site-packages/matplotlib/pyplot.py'>"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_parity(y_test, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sandbox"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
#
# Development shell for the active-learning notebook, built on top of the
# pymks package definition.
#
# Usage:
#
# $ nix-shell --pure --arg withBoost false --argstr tag 20.09
#
# NOTE(review): `withBoost` is not among the arguments declared below (only
# `tag` and `pymksVersion` are) -- confirm whether that flag is still needed.
#
{
# nixpkgs archive tag to fetch from GitHub.
tag ? "20.09",
# pymks commit hash to fetch from GitHub.
pymksVersion ? "cf653e004848c9c68ca31a85add0d1ac8611a93f"
}:
let
# Package set imported from the nixpkgs tarball for the requested tag.
pkgs = import (builtins.fetchTarball "https://github.com/NixOS/nixpkgs/archive/${tag}.tar.gz") {};
# pymks source tarball at the requested commit.
pymkssrc = builtins.fetchTarball "https://github.com/materialsinnovation/pymks/archive/${pymksVersion}.tar.gz";
# pymks built from the default.nix inside that tarball, with graspi passed as null.
pymks = pypkgs.callPackage "${pymkssrc}/default.nix" { graspi = null; };
pypkgs = pkgs.python3Packages;
# Extra Python packages (linters, widgets, storage libraries, pymks itself) added to the shell.
extra = with pypkgs; [ black pylint flake8 ipywidgets zarr pymks h5py ];
in
# Override the pymks derivation so that entering its shell also provides the
# extra packages and runs the hook below.
(pymks.overridePythonAttrs (old: rec {
propagatedBuildInputs = old.propagatedBuildInputs;
# `rec` lets this refer to the propagatedBuildInputs attribute defined just above.
nativeBuildInputs = propagatedBuildInputs ++ extra;
# Shell hook, in order:
#  - exports OMPI_MCA_plm_rsh_agent pointing at the ssh binary from pkgs.openssh;
#  - sets SOURCE_DATE_EPOCH to the current time;
#  - points PYTHONUSERBASE at $PWD/.local and adds the user site directory to
#    PYTHONPATH and the user bin directory to PATH;
#  - installs and enables Jupyter notebook extensions with all output discarded;
#  - pip-installs nbqa, tqdm and modAL with --user.
# The pip/jupyter commands run every time the hook runs.
postShellHook = ''
export OMPI_MCA_plm_rsh_agent=${pkgs.openssh}/bin/ssh
SOURCE_DATE_EPOCH=$(date +%s)
export PYTHONUSERBASE=$PWD/.local
export USER_SITE=`python -c "import site; print(site.USER_SITE)"`
export PYTHONPATH=$PYTHONPATH:$USER_SITE
export PATH=$PATH:$PYTHONUSERBASE/bin
jupyter nbextension install --py widgetsnbextension --user > /dev/null 2>&1
jupyter nbextension enable widgetsnbextension --user --py > /dev/null 2>&1
pip install jupyter_contrib_nbextensions --user > /dev/null 2>&1
jupyter contrib nbextension install --user > /dev/null 2>&1
jupyter nbextension enable spellchecker/main > /dev/null 2>&1
pip install --user nbqa
pip install --user tqdm
pip install --user modAL
'';
}))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment