@SharanSMenon
Last active December 5, 2021 22:40
A simple implementation of evolution strategies to train a classification network. Based on https://openai.com/blog/evolution-strategies/
{
"cells": [
{
"cell_type": "markdown",
"id": "35194e0e-1d70-43d1-a4d2-3cf5aeaf922e",
"metadata": {},
"source": [
"# Implementing Evolution Strategies for classification\n",
"\n",
"This notebook implements the evolution strategies (ES) algorithm described in OpenAI's blog post:\n",
"\n",
"[https://openai.com/blog/evolution-strategies/](https://openai.com/blog/evolution-strategies/)\n",
"\n",
"We use this algorithm to train a small neural network to classify points from scikit-learn's `make_moons` dataset."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9c47fc44-973f-4518-abb9-a4d19d1cee4a",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn import functional as F"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3a36b483-f921-4724-a120-88420823f3fe",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7293a500-4194-4958-8367-7777979b02f0",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_moons"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e743c05e-6173-4812-bc04-db15c5ce44f0",
"metadata": {},
"outputs": [],
"source": [
"X, Y = make_moons(noise=0.15, random_state=1, n_samples=500)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "eab9d1d7-8ac0-4da3-8aeb-0107dff77dc4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x12f8b0d60>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(X[:,0], X[:,1], c=Y)"
]
},
{
"cell_type": "markdown",
"id": "f9c16543-2be0-4c60-be33-b5692acac74c",
"metadata": {},
"source": [
"Create training and test sets, then convert them into PyTorch tensors (`float32` features, `int64` labels)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3f231db9-abf4-4da4-b2b1-6c2fe3dd3c1f",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ce5ae456-361c-4a8a-8a18-9f9c7cde70b5",
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.36, random_state=42)\n",
"X_train, X_test, y_train, y_test = torch.FloatTensor(X_train), torch.FloatTensor(X_test), torch.LongTensor(y_train), torch.LongTensor(y_test)"
]
},
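{
"cell_type": "markdown",
"id": "es-sketch-split",
"metadata": {},
"source": [
"As a quick check (a sketch that repeats the notebook's `make_moons` and `train_test_split` calls under new, illustrative variable names): with `test_size=0.36`, the 500 samples split into 320 training and 180 test points. These sizes match the accuracies reported at the end of the notebook, e.g. 97.8125% = 313/320 on the training set and 97.22% ≈ 175/180 on the test set.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_moons\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Same calls as the notebook, stored under illustrative names\n",
"Xs, Ys = make_moons(noise=0.15, random_state=1, n_samples=500)\n",
"Xa, Xb, ya, yb = train_test_split(Xs, Ys, test_size=0.36, random_state=42)\n",
"print(len(Xa), len(Xb))  # 320 180\n",
"```"
]
},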
{
"cell_type": "code",
"execution_count": 8,
"id": "ded34ad7-00dd-45cf-80a0-09897177f17d",
"metadata": {},
"outputs": [],
"source": [
"n_input_dim = X_train.shape[1]"
]
},
{
"cell_type": "markdown",
"id": "d7cfb9c8-b043-4bf7-a1f8-bba1eea2b040",
"metadata": {},
"source": [
"## Model\n",
"\n",
"Create the model and initialize its weights. `nn.Linear` parameters are `float32` by default; `model.float()` is kept as a safeguard so that the model's dtype matches the `float32` input tensors rather than `double` (`float64`)."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1867d3c2-5851-4c76-9444-47f24adc0877",
"metadata": {},
"outputs": [],
"source": [
"def weights_init(m):\n",
"    # Draw Linear-layer weights from N(0, 0.02^2); biases keep PyTorch's default initialization\n",
"    classname = m.__class__.__name__\n",
"    if classname.find('Linear') != -1:\n",
"        m.weight.data.normal_(0.0, 0.02)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f1e31148-6d05-4582-b184-2999219a783b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Sequential(\n",
" (0): Linear(in_features=2, out_features=40, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=40, out_features=20, bias=True)\n",
" (3): ReLU()\n",
" (4): Linear(in_features=20, out_features=2, bias=True)\n",
" (5): Sigmoid()\n",
")"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = nn.Sequential(\n",
"    nn.Linear(n_input_dim, 40),\n",
"    nn.ReLU(),\n",
"    nn.Linear(40, 20),\n",
"    nn.ReLU(),\n",
"    nn.Linear(20, 2),\n",
"    nn.Sigmoid()\n",
")\n",
"model = model.float()\n",
"model.apply(weights_init)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "59274e62-85d7-4144-b51e-7cb8b43bb4c8",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: this forward pass raises an error if the model and input dtypes do not match (float vs. double).\n",
"_ = model(X_train)\n",
"# If this cell runs without errors, the rest of the notebook should run as well."
]
},
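{
"cell_type": "markdown",
"id": "es-sketch-dtype",
"metadata": {},
"source": [
"The float/double mismatch that this sanity check guards against can be reproduced in isolation. In the sketch below, a standalone `nn.Linear(40, 20)` layer (an illustrative stand-in, not the notebook's model) has `float32` parameters by default, the same `normal_(0.0, 0.02)` call used in `weights_init` gives weights with a standard deviation close to 0.02, and a `float64` input raises a `RuntimeError`.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"torch.manual_seed(0)\n",
"layer = nn.Linear(40, 20)\n",
"print(layer.weight.dtype)  # torch.float32\n",
"\n",
"# Same initialization as weights_init: weights drawn from N(0, 0.02^2)\n",
"layer.weight.data.normal_(0.0, 0.02)\n",
"print(layer.weight.std().item())  # close to 0.02\n",
"\n",
"# A double (float64) input does not match the float32 weights\n",
"try:\n",
"    layer(torch.zeros(1, 40, dtype=torch.float64))\n",
"    mismatch_error = False\n",
"except RuntimeError:\n",
"    mismatch_error = True\n",
"print(mismatch_error)  # True\n",
"```"
]
},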
{
"cell_type": "markdown",
"id": "88fd9410-fd3c-4a37-b17f-3c12e44051fa",
"metadata": {},
"source": [
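"OpenAI's blog post describes ES as starting from an initial set of parameters and iteratively optimizing it. We flatten all of the model's weights and biases into a single vector, `mother_vector`. The model is small, with 982 parameters in total: (2×40 + 40) + (40×20 + 20) + (20×2 + 2) = 120 + 820 + 42."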
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cfbff143-4b04-426e-9126-eb8d1d24a3f0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([982])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mother_params = model.parameters()\n",
"mother_vector = nn.utils.parameters_to_vector(mother_params)\n",
"mother_vector.shape # torch.Size([982])"
]
},
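{
"cell_type": "markdown",
"id": "es-sketch-params",
"metadata": {},
"source": [
"The parameter count can be verified directly, along with the flatten/restore round trip that ES relies on. The sketch below builds a separate network `net` with the same layer sizes (an illustrative copy, not the trained model), counts its parameters, and confirms that `vector_to_parameters` writes a modified vector back into the layers.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Same layer sizes as the notebook's model: 2 -> 40 -> 20 -> 2\n",
"net = nn.Sequential(nn.Linear(2, 40), nn.ReLU(), nn.Linear(40, 20), nn.ReLU(), nn.Linear(20, 2), nn.Sigmoid())\n",
"\n",
"# Each Linear layer has in_features * out_features weights plus out_features biases\n",
"n_total = sum(p.numel() for p in net.parameters())\n",
"print(n_total)  # 982\n",
"\n",
"# Flatten all parameters into one vector, shift it, and write it back into the layers\n",
"vec = nn.utils.parameters_to_vector(net.parameters())\n",
"nn.utils.vector_to_parameters(vec + 1.0, net.parameters())\n",
"restored = nn.utils.parameters_to_vector(net.parameters())\n",
"print(torch.allclose(restored, vec + 1.0))  # True\n",
"```"
]
},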
{
"cell_type": "markdown",
"id": "20052687-208c-4304-b975-22fbdd760e1c",
"metadata": {},
"source": [
"Evolution strategies maximize a reward (fitness) instead of minimizing a loss, so we use the reciprocal of `CrossEntropyLoss` as the reward: the lower the loss, the higher the reward."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "4eb38437-def5-45eb-8fd7-7a056bd91011",
"metadata": {},
"outputs": [],
"source": [
"loss_func = nn.CrossEntropyLoss()\n",
"def loss(y_pred, y_true):\n",
"    return 1 / loss_func(y_pred, y_true)"
]
},
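{
"cell_type": "markdown",
"id": "es-sketch-reward",
"metadata": {},
"source": [
"A minimal sketch of how the reciprocal loss behaves, using hypothetical two-class outputs: outputs that favor the correct class more strongly give a lower loss and therefore a higher reward. One consequence is worth noting: because the model ends in `nn.Sigmoid()`, its outputs lie in [0, 1], while `CrossEntropyLoss` treats them as logits. The lowest achievable loss per sample is then log(1 + e^-1) ≈ 0.3133, so the reward cannot exceed about 3.19, which is consistent with the rewards just under 3 in the training log.\n",
"\n",
"```python\n",
"import math\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"loss_func = nn.CrossEntropyLoss()\n",
"target = torch.tensor([0, 1])\n",
"\n",
"# Hypothetical outputs in [0, 1] that weakly vs. strongly favor the correct class\n",
"weak = torch.tensor([[0.6, 0.4], [0.4, 0.6]])\n",
"strong = torch.tensor([[1.0, 0.0], [0.0, 1.0]])\n",
"\n",
"rewards = {}\n",
"for name, out in [('weak', weak), ('strong', strong)]:\n",
"    ce = loss_func(out, target)\n",
"    rewards[name] = (1 / ce).item()  # reciprocal loss, as in loss()\n",
"    print(name, 'loss:', round(ce.item(), 4), 'reward:', round(rewards[name], 4))\n",
"\n",
"# With outputs in [0, 1], the smallest possible loss per sample is log(1 + e^-1)\n",
"print('reward ceiling:', round(1 / math.log(1 + math.exp(-1)), 4))\n",
"```"
]
},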
{
"cell_type": "markdown",
"id": "2251cb75-1dfa-4a9c-bf76-8226cc67bcf5",
"metadata": {},
"source": [
"Next, we implement the fitness function. It loads a candidate parameter vector into the model and returns that candidate's reward on the training set."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "85cd162f-a805-48ff-8d5c-fb226eb95581",
"metadata": {},
"outputs": [],
"source": [
"def fitness_func(solution):\n",
"    nn.utils.vector_to_parameters(solution, model.parameters())\n",
"    return loss(model(X_train), y_train) + 1e-8"
]
},
{
"cell_type": "markdown",
"id": "458355a6-2842-4c21-b9c6-53f36505a758",
"metadata": {},
"source": [
"In ES, the \"population\" is a set of slightly perturbed copies of `mother_vector`. The `jitter` function creates one such copy by adding Gaussian noise, scaled by `SIGMA`, to the mother vector."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "4735c389-9b23-4b8a-9c2a-ea840c932173",
"metadata": {},
"outputs": [],
"source": [
"def jitter(mother_params, noise):\n",
"    # Perturb the mother vector with noise scaled by the step size SIGMA\n",
"    params_try = mother_params + SIGMA * noise\n",
"    return params_try"
]
},
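{
"cell_type": "markdown",
"id": "es-sketch-jitter",
"metadata": {},
"source": [
"A sketch of a single `jitter` step in isolation, assuming a zero mother vector of the model's size (982) and the notebook's `SIGMA = 0.01`: every parameter is shifted by Gaussian noise, and the spread of the perturbation is about `SIGMA`.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"SIGMA = 0.01\n",
"torch.manual_seed(0)\n",
"mother = torch.zeros(982)  # hypothetical mother vector with the model's 982 parameters\n",
"noise = torch.randn(982)  # one population member: standard Gaussian noise\n",
"candidate = mother + SIGMA * noise  # same formula as jitter()\n",
"\n",
"# Every parameter moves a little: the perturbation's standard deviation is close to SIGMA\n",
"print((candidate - mother).std().item())\n",
"```"
]
},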
{
"cell_type": "markdown",
"id": "38ac303e-9818-4b46-8009-341c6fd97061",
"metadata": {},
"source": [
"`calc_population_fitness` takes a population of random noise vectors, uses each one to jitter the mother vector, and returns the fitness of every jittered candidate."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "ce1e7fa0-435f-4298-bf63-03bbe0733715",
"metadata": {},
"outputs": [],
"source": [
"def calc_population_fitness(pop, mother_vector):\n",
"    fitness = torch.zeros(pop.shape[0])\n",
"    for i, noise in enumerate(pop):\n",
"        # Evaluate the mother vector perturbed by the i-th noise vector\n",
"        p_try = jitter(mother_vector, noise)\n",
"        fitness[i] = fitness_func(p_try)\n",
"    return fitness"
]
},
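{
"cell_type": "markdown",
"id": "es-sketch-population",
"metadata": {},
"source": [
"The same population loop can be run against a cheap, hypothetical fitness function to see what it returns. The sketch below replaces `fitness_func` with an assumed `toy_fitness` (the negative squared distance to `[1, 2, 3]`) and evaluates 50 jittered copies of a 3-dimensional mother vector, giving one fitness value per population member.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"SIGMA = 0.1\n",
"torch.manual_seed(0)\n",
"target = torch.tensor([1.0, 2.0, 3.0])\n",
"\n",
"def toy_fitness(w):\n",
"    # Hypothetical fitness: 0 at w == target, negative everywhere else\n",
"    return -((w - target) ** 2).sum()\n",
"\n",
"mother = torch.zeros(3)\n",
"pop = torch.randn(50, 3)  # 50 noise vectors, one per population member\n",
"fitness = torch.zeros(pop.shape[0])\n",
"for i, noise in enumerate(pop):\n",
"    fitness[i] = toy_fitness(mother + SIGMA * noise)  # fitness of the jittered candidate\n",
"\n",
"print(fitness.shape)  # torch.Size([50])\n",
"```"
]
},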
{
"cell_type": "markdown",
"id": "740649f8-4c5d-4438-aecb-75aae3494142",
"metadata": {},
"source": [
"A utility function that loads a parameter vector into the model and returns its accuracy on the test set."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "005c29ec-9c87-4f9e-b202-5970164eb741",
"metadata": {},
"outputs": [],
"source": [
"def test(mother_params):\n",
"    nn.utils.vector_to_parameters(mother_params, model.parameters())\n",
"    return (((torch.max(model(X_test), 1)[1] == y_test).sum()) / len(y_test)).item()"
]
},
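{
"cell_type": "markdown",
"id": "es-sketch-accuracy",
"metadata": {},
"source": [
"A tiny worked example of the accuracy computation in `test`, using hypothetical scores for four points: `torch.max(..., 1)[1]` picks the index of the larger score in each row as the predicted class, and the fraction of matches with the labels is the accuracy. The sketch uses `.float().mean()`, which gives the same fraction as dividing the number of correct predictions by `len(y_test)`.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# Hypothetical model outputs for four points (one row per point, one column per class)\n",
"scores = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])\n",
"labels = torch.tensor([0, 1, 1, 1])\n",
"pred = torch.max(scores, 1)[1]  # index of the larger score is the predicted class\n",
"acc = (pred == labels).float().mean().item()\n",
"print(acc)  # 0.75\n",
"```"
]
},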
{
"cell_type": "markdown",
"id": "629d3c9d-b3f3-4baf-b378-99e86390e6ee",
"metadata": {},
"source": [
"### Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "5ca96f06-7658-4b3a-95e3-ad512b53ef83",
"metadata": {},
"outputs": [],
"source": [
"SIGMA = 0.01\n",
"LR = 0.001\n",
"POPULATION_SIZE = 50\n",
"ITERATIONS = 500"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "cb7ae4b8-5ab7-4edb-831e-fe36b134c15b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of parameters: 982\n"
]
}
],
"source": [
"n_params = nn.utils.parameters_to_vector(model.parameters()).shape[0]\n",
"print(f\"Number of parameters: {n_params}\") # Number of parameters: 982"
]
},
{
"cell_type": "markdown",
"id": "2f849c42-d134-41cc-b7dc-fb070d25b3bc",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "886ba6b5-5b79-4cc5-b9f6-c91bc3102e4a",
"metadata": {},
"outputs": [],
"source": [
"from tqdm.notebook import tqdm  # Optional progress bar; if you remove it, also remove the tqdm(...) wrapper in the training loop"
]
},
{
"cell_type": "markdown",
"id": "0cdff4ee",
"metadata": {},
"source": [
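"The training code is short. We disable autograd because ES never backpropagates gradients, which also makes training faster.\n",
"We run for `ITERATIONS` iterations. In each iteration, we sample a random population of Gaussian noise vectors and compute the fitness of each jittered candidate. We then standardize the fitness values (subtract the mean, divide by the standard deviation) and update the mother vector with a fitness-weighted sum of the noise vectors: candidates with above-average fitness pull the mother vector toward their noise direction, and below-average candidates push it away, so the fittest candidates have the greatest influence."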
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "17e06b20-9c21-4a65-940e-0fca4e9b8ed3",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "664f58978695472e9b3fd582155fa47f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/500 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 0, Reward: 1.4428613185882568, Accuracy: 48.88888895511627\n",
"Iteration: 50, Reward: 2.35758376121521, Accuracy: 84.44444537162781\n",
"Iteration: 100, Reward: 2.4251275062561035, Accuracy: 87.77777552604675\n",
"Iteration: 150, Reward: 2.4666550159454346, Accuracy: 88.33333253860474\n",
"Iteration: 200, Reward: 2.5329749584198, Accuracy: 90.55555462837219\n",
"Iteration: 250, Reward: 2.6997175216674805, Accuracy: 93.33333373069763\n",
"Iteration: 300, Reward: 2.810124397277832, Accuracy: 96.66666388511658\n",
"Iteration: 350, Reward: 2.8788623809814453, Accuracy: 97.22222089767456\n",
"Iteration: 400, Reward: 2.9336156845092773, Accuracy: 98.33333492279053\n",
"Iteration: 450, Reward: 2.9613733291625977, Accuracy: 97.22222089767456\n"
]
}
],
"source": [
"with torch.no_grad():  # ES needs no gradients; disabling autograd is faster and avoids autograd-related errors\n",
"    for iteration in tqdm(range(ITERATIONS)):\n",
"        # Sample POPULATION_SIZE standard Gaussian noise vectors\n",
"        pop = torch.from_numpy(np.random.randn(POPULATION_SIZE, n_params)).float()\n",
"        fitness = calc_population_fitness(pop, mother_vector)\n",
"        # Standardize fitness so the step size does not depend on the reward scale\n",
"        normalized_fitness = (fitness - torch.mean(fitness)) / torch.std(fitness)\n",
"        # Move the mother vector along the fitness-weighted sum of the noise vectors\n",
"        mother_vector = mother_vector + (LR / (POPULATION_SIZE * SIGMA)) * (pop.t() @ normalized_fitness)\n",
"        if iteration % 50 == 0:\n",
"            reward = fitness_func(mother_vector)\n",
"            acc = test(mother_vector)\n",
"            print(f\"Iteration: {iteration}, Reward: {reward}, Accuracy: {acc * 100}\")"
]
},
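{
"cell_type": "markdown",
"id": "es-sketch-update",
"metadata": {},
"source": [
"The update in the training loop can be written as follows, with standardized fitness values $\\hat{F}_i$ and noise vectors $\\epsilon_i$:\n",
"\n",
"$$\\theta_{t+1} = \\theta_t + \\frac{\\alpha}{n\\sigma} \\sum_{i=1}^{n} \\hat{F}_i \\, \\epsilon_i$$\n",
"\n",
"where $\\alpha$ is `LR`, $\\sigma$ is `SIGMA`, and $n$ is `POPULATION_SIZE`. The sketch below applies the same update to a hypothetical 3-dimensional toy problem (maximizing the negative squared distance to an assumed target `[1, 2, 3]`) so that convergence is easy to check; the hyperparameter values are assumptions chosen for this toy problem, not the notebook's.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"target = torch.tensor([1.0, 2.0, 3.0])  # hypothetical optimum\n",
"\n",
"def toy_fitness(w):\n",
"    # Hypothetical fitness: maximized (value 0) when w equals target\n",
"    return -((w - target) ** 2).sum()\n",
"\n",
"SIGMA, LR, POP, ITERS = 0.1, 0.01, 50, 300  # assumed toy hyperparameters\n",
"w = torch.zeros(3)  # initial mother vector\n",
"for it in range(ITERS):\n",
"    noise = torch.randn(POP, 3)  # one Gaussian noise vector per population member\n",
"    fitness = torch.stack([toy_fitness(w + SIGMA * eps) for eps in noise])\n",
"    # Standardize so the step size does not depend on the fitness scale\n",
"    f_hat = (fitness - fitness.mean()) / fitness.std()\n",
"    # Same form as the notebook's update: w + LR / (n * SIGMA) * sum_i f_hat_i * eps_i\n",
"    w = w + LR / (POP * SIGMA) * (noise.t() @ f_hat)\n",
"\n",
"print(w)  # should end up close to target\n",
"```\n",
"\n",
"Because the fitness values are standardized, each step has a length on the order of `LR / SIGMA`, regardless of the raw reward scale."
]
},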
{
"cell_type": "code",
"execution_count": 22,
"id": "cb962be7-da75-4c97-ad56-1f4beda63d30",
"metadata": {},
"outputs": [],
"source": [
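"# After training, the model still holds the last jittered candidate, so load the final mother vector first\n",
"nn.utils.vector_to_parameters(mother_vector, model.parameters())\n",
"preds = model(X_test)\n",
"_, preds_classes = torch.max(preds, 1)"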
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "8a47f56b-965c-4249-837f-97a8f1da8293",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x12faee490>"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Color the test points by predicted class. If this plot closely matches the original dataset, the model has learned the two moons.\n",
"plt.scatter(X_test[:,0], X_test[:,1], c=preds_classes.detach().numpy())"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "05fe8e4d-e6c8-4c47-9e92-576794827ad5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X_train Accuracy: 97.81249761581421%\n",
"X_test Accuracy: 97.22222089767456%\n"
]
}
],
"source": [
"print(f\"X_train Accuracy: {(((torch.max(model(X_train), 1)[1] == y_train).sum())/len(y_train)).item() * 100}%\")\n",
"print(f\"X_test Accuracy: {(((preds_classes == y_test).sum())/len(y_test)).item() * 100}%\")"
]
},
{
"cell_type": "markdown",
"id": "0fee437d",
"metadata": {},
"source": [
"The evolution strategy works: the model reaches about 97% accuracy on the test set."
]
},
{
"cell_type": "markdown",
"id": "efde5b13",
"metadata": {},
"source": [
"## Decision Boundary Function\n",
"\n",
"A utility function that plots the model's decision boundary over a grid covering the data. The learned boundary closely follows the curved shape of the two moons."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "c532f921-8555-4b3b-a043-fcb97b8965cc",
"metadata": {},
"outputs": [],
"source": [
"def boundary(X, ax):  # determine boundary between different colored dots\n",
"    x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1\n",
"    y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1\n",
"    spacing = min(x_max - x_min, y_max - y_min) / 100\n",
"    XX, YY = np.meshgrid(np.arange(x_min, x_max, spacing), np.arange(y_min, y_max, spacing))\n",
"    data = np.hstack((XX.ravel().reshape(-1, 1), YY.ravel().reshape(-1, 1)))\n",
"    data_t = torch.FloatTensor(data)\n",
"    db_prob = model(data_t)\n",
"    _, clf = torch.max(db_prob, 1)\n",
"    Z = clf.reshape(XX.shape)\n",
"    return ax.contourf(XX, YY, Z, cmap=plt.cm.Accent, alpha=0.6)"
]
},
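{
"cell_type": "markdown",
"id": "es-sketch-boundary",
"metadata": {},
"source": [
"A sketch of the grid-classification step inside `boundary`, using an assumed `toy_model` (class 1 wherever x + y > 0.1) instead of the trained network and a coarse 5 × 5 grid: every grid point is classified with `torch.max(..., 1)`, and the class indices are reshaped back to the grid's shape, which is what `ax.contourf` shades.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import torch\n",
"\n",
"def toy_model(points):\n",
"    # Hypothetical stand-in classifier: class 1 scores higher when x + y > 0.1\n",
"    s = points[:, 0] + points[:, 1] - 0.1\n",
"    return torch.stack([-s, s], dim=1)  # two scores per point, like the notebook's model\n",
"\n",
"# Coarse 5 x 5 grid; boundary() builds a finer one from the data range\n",
"XX, YY = np.meshgrid(np.linspace(-1, 1, 5), np.linspace(-1, 1, 5))\n",
"grid = torch.tensor(np.column_stack([XX.ravel(), YY.ravel()]), dtype=torch.float32)\n",
"_, clf = torch.max(toy_model(grid), 1)  # predicted class for every grid point\n",
"Z = clf.reshape(XX.shape).numpy()  # back to grid shape, ready for ax.contourf(XX, YY, Z)\n",
"print(Z)\n",
"```"
]
},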
{
"cell_type": "code",
"execution_count": 26,
"id": "058d63ca-f9fc-40b6-bdd1-e69837142de6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x12fb48fa0>"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1008x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(figsize=(14, 6))\n",
"_ = boundary(X_test, ax)\n",
"plt.scatter(X_test[:,0], X_test[:,1], c=y_test, cmap='viridis')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}