Skip to content

Instantly share code, notes, and snippets.

@mlaves
Last active November 1, 2020 13:08
Show Gist options
  • Save mlaves/77efca31e4fc76aed725cae3eb67c4cf to your computer and use it in GitHub Desktop.
Save mlaves/77efca31e4fc76aed725cae3eb67c4cf to your computer and use it in GitHub Desktop.
Unsupervised toy experiment on the "two moons" data set.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "1LwIMQaG7JtM"
},
"source": [
"# Self-Supervised Consistency Learning with Pseudo-Labeling from Bayesian Uncertainty\n",
"## A *Two Moons* Example"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "8rHBnKPzatqn"
},
"outputs": [],
"source": [
"import torch\n",
"from sklearn.datasets import make_moons\n",
"import matplotlib\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"import sys\n",
"from copy import deepcopy\n",
"import seaborn as sns\n",
"sns.set()\n",
"sns.set_context('paper')\n",
"\n",
"matplotlib.rcParams['text.usetex'] = True\n",
"matplotlib.rcParams['text.latex.preamble'] = [\n",
" r'\\usepackage{bm}']\n",
"\n",
"torch.manual_seed(0)\n",
"np.random.seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "ueq7DjJAoozH"
},
"outputs": [],
"source": [
"from torch.nn.functional import one_hot\n",
"\n",
"def mutual_information(z, zt, eps=sys.float_info.epsilon):\n",
" _, k = z.size()\n",
"\n",
" P = (z.unsqueeze(2) * zt.unsqueeze(1)).sum(dim=0)\n",
" P = ((P + P.t()) / 2) / P.sum()\n",
" P[(P < eps).data] = eps\n",
" assert (P.size() == (k, k))\n",
"\n",
" Pi = P.sum(dim=1).view(k, 1).expand(k, k)\n",
" Pj = P.sum(dim=0).view(1, k).expand(k, k)\n",
"\n",
" return (-P * (P.log() - Pi.log() - Pj.log())).sum()\n",
"\n",
"\n",
"def entropy(p):\n",
" return -(p*p.log()).sum(dim=1)\n",
"\n",
"\n",
"def batch_pl(p_B_C, eps=sys.float_info.epsilon):\n",
" p_B_C.add_(eps)\n",
" e = entropy(p_B_C)\n",
" marginal = p_B_C.mean(dim=0) # compute marginal\n",
" kld = (marginal*(marginal.log()-p_B_C.log())).sum(dim=1) # KL(marginal || p) per sample; TODO: use nansum instead of zeroing NaNs below\n",
" kld[kld != kld] = 0\n",
" return e - kld\n",
"\n",
"\n",
"def mixup(x1, x2, y1, y2, lamb=0.5, num_classes=None):\n",
" if not num_classes:\n",
" num_classes = y1.max()\n",
" x = lamb*x1 + (1-lamb)*x2\n",
" y = lamb*one_hot(y1, num_classes) + (1-lamb)*one_hot(y2, num_classes)\n",
"\n",
" return x, y\n",
"\n",
"\n",
"def extract_pseudo_labels(net, x, T):\n",
" y_p_a_mean = []\n",
" y_p_b_mean = []\n",
" for i in range(15):\n",
" y_p_a, y_p_b = net(x)\n",
" y_p_a_mean.append((T*y_p_a).softmax(1).detach().unsqueeze(0))\n",
" y_p_b_mean.append((y_p_b).softmax(1).detach().unsqueeze(0))\n",
" y_p_a_mean = torch.cat(y_p_a_mean, dim=0).mean(dim=0)\n",
" y_p_b_mean = torch.cat(y_p_b_mean, dim=0).mean(dim=0)\n",
" pseudo_targets = y_p_a_mean.argmax(dim=1)\n",
"\n",
" # split samples into uncertain (score > 0) and confident (score < 0)\n",
" batchpl = batch_pl(y_p_a_mean)\n",
" uncert = x[torch.where(batchpl > 0.0)].detach()\n",
" conf = x[torch.where(batchpl < 0.0)].detach()\n",
" conf_target = y_p_a_mean[\n",
" torch.where(batchpl < 0.0)\n",
" ].detach().argmax(dim=1)\n",
"\n",
" return uncert, conf, pseudo_targets, conf_target\n",
"\n",
"\n",
"def extract_pseudo_labels_conf(net, x, T):\n",
" y_p_a_mean = []\n",
" y_p_b_mean = []\n",
" for i in range(15):\n",
" y_p_a, y_p_b = net(x)\n",
" y_p_a_mean.append((T*y_p_a).softmax(1).detach().unsqueeze(0))\n",
" y_p_b_mean.append((y_p_b).softmax(1).detach().unsqueeze(0))\n",
" y_p_a_mean = torch.cat(y_p_a_mean, dim=0).mean(dim=0)\n",
" y_p_b_mean = torch.cat(y_p_b_mean, dim=0).mean(dim=0)\n",
" pseudo_targets = y_p_a_mean.argmax(dim=1)\n",
"\n",
" # uncertain if 0.1 < p(c=0) < 0.9; confident if p(c=0) > 0.9 or p(c=0) < 0.1\n",
" uncert = x[torch.where((y_p_a_mean[:,0] > 0.1) & (y_p_a_mean[:,0] < 0.9))].detach()\n",
" conf = x[torch.where((y_p_a_mean[:,0] > 0.9) | (y_p_a_mean[:,0] < 0.1))].detach()\n",
" conf_target = y_p_a_mean[\n",
" torch.where((y_p_a_mean[:,0] > 0.9) | (y_p_a_mean[:,0] < 0.1))\n",
" ].detach().argmax(dim=1)\n",
"\n",
" return uncert, conf, pseudo_targets, conf_target\n",
"\n",
"\n",
"def create_plot_data(net, T):\n",
" xx = np.linspace(-2.2, 2.2, 100)\n",
" yy = np.linspace(-2.3, 2.3, 100)\n",
" XX, YY = np.meshgrid(xx, yy)\n",
"\n",
" inp = np.concatenate([XX.reshape(-1, 1), YY.reshape(-1, 1)], axis=1)\n",
"\n",
" with torch.no_grad():\n",
" y_p_test_n = []\n",
" for n in range(15):\n",
" y_p_test_n.append((T*net(torch.FloatTensor(inp))[0].detach()).softmax(1)[:,0].unsqueeze(0))\n",
" y_p_test = torch.cat(y_p_test_n, dim=0).mean(dim=0).cpu().numpy()\n",
"\n",
" return xx, yy, y_p_test.reshape(100, 100)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"x_moons, y_moons = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=0)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "DsXigrdha3Os"
},
"outputs": [],
"source": [
"x_moons_t = torch.FloatTensor(x_moons[np.where(y_moons == 0)].tolist() + x_moons[np.where(y_moons == 1)].tolist()[:400])\n",
"y_moons_t = torch.LongTensor(y_moons[np.where(y_moons == 0)].tolist() + y_moons[np.where(y_moons == 1)].tolist()[:400])\n",
"\n",
"x_moons_t -= x_moons_t.mean(axis=0)\n",
"x_moons_t /= x_moons_t.std(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 268
},
"id": "gcNYdc_tbEye",
"outputId": "69bc70f9-3d95-480f-97c6-b8eaab457a7c"
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure()\n",
"plt.scatter(x_moons_t[:,0], x_moons_t[:,1])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "zk9ISmgzbHEW"
},
"outputs": [],
"source": [
"class NetBayes(torch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" self.base = net = torch.nn.Sequential(\n",
" torch.nn.Linear(2, 32),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Linear(32, 32),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Dropout(0.3)\n",
" )\n",
" \n",
" self.head_b = torch.nn.Linear(32,10)\n",
" self.head_a = torch.nn.Linear(32,2)\n",
" \n",
" def forward(self, x):\n",
" x = self.base(x)\n",
" return self.head_a(x), self.head_b(x)\n",
"\n",
"net_bayes = NetBayes()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 102
},
"id": "BvZDnWq-l9vi",
"outputId": "d6154bde-9b6b-44fd-b7f1-51693f856fda"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-3.281165845692158e-07\n",
"-0.0023017022758722305\n",
"-0.14624759554862976\n",
"-0.3479679822921753\n",
"-0.43934932351112366\n"
]
}
],
"source": [
"opt = torch.optim.Adam(net_bayes.parameters(), lr=1e-4, weight_decay=1e-6)\n",
"net_bayes.train()\n",
"\n",
"# bootstrap both heads (b, then a) with the IIC mutual-information loss\n",
"for e in range(500):\n",
" opt.zero_grad()\n",
" _, y_p_b_1 = net_bayes(x_moons_t + torch.randn_like(x_moons_t)*0.05)\n",
" _, y_p_b_2 = net_bayes(x_moons_t + torch.randn_like(x_moons_t)*0.15)\n",
" \n",
" loss = mutual_information(y_p_b_1.softmax(1), y_p_b_2.softmax(1))\n",
" loss.backward()\n",
" opt.step()\n",
"\n",
" opt.zero_grad()\n",
" y_p_a_1, _ = net_bayes(x_moons_t + torch.randn_like(x_moons_t)*0.05)\n",
" y_p_a_2, _ = net_bayes(x_moons_t + torch.randn_like(x_moons_t)*0.15)\n",
" \n",
" loss = mutual_information(y_p_a_1.softmax(1), y_p_a_2.softmax(1))\n",
" loss.backward()\n",
" opt.step()\n",
" \n",
" if e % 100 == 0:\n",
" print(loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "-yWaUywOdAxl"
},
"outputs": [],
"source": [
"T = 0.35\n",
"net_cl = deepcopy(net_bayes)\n",
"net_conf = deepcopy(net_bayes)\n",
"net_ours = deepcopy(net_bayes)\n",
"\n",
"epochs = 4000"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 289
},
"id": "rkcILFrbmEYA",
"outputId": "4a45b096-76ca-4953-a927-bcdbbc6078ff"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 0.13111111521720886\n",
"-0.4840652346611023\n",
"500 0.13555555045604706\n",
"-0.5790530443191528\n"
]
}
],
"source": [
"opt_cl = torch.optim.Adam(net_cl.parameters(), lr=1e-2, weight_decay=1e-6)\n",
"\n",
"for e in range(epochs//4): # far fewer iterations are needed here\n",
" # shuffle batch\n",
" N = x_moons_t.size(0)\n",
" batch = torch.randperm(N)\n",
"\n",
" opt_cl.zero_grad()\n",
" y_p_a_1, _ = net_cl(x_moons_t[batch] + torch.randn_like(x_moons_t)*0.05)\n",
" y_p_a_2, _ = net_cl(x_moons_t[batch] + torch.randn_like(x_moons_t)*0.15)\n",
" \n",
" loss = mutual_information(y_p_a_1.softmax(1), y_p_a_2.softmax(1))\n",
"\n",
" loss.backward()\n",
" opt_cl.step()\n",
"\n",
" if e % 500 == 0:\n",
" with torch.no_grad():\n",
" y = net_cl(x_moons_t)[0]\n",
" print(e, ((y.argmax(1) == y_moons_t).sum().float() / y_moons_t.size(0)).item())\n",
"\n",
" print(loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "cv1odwr3gPsv"
},
"outputs": [],
"source": [
"_, _, y_test_cl = create_plot_data(net_cl, T)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 289
},
"id": "iQy0Hzo9hr_J",
"outputId": "8460368e-d62f-4508-d7e4-a01c74a468c5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 0.12999999523162842\n",
"298 323 0.03344118222594261\n",
"500 0.09777777642011642\n",
"344 391 0.0033677611500024796\n",
"1000 0.08666666597127914\n",
"340 428 -0.01942259818315506\n",
"1500 0.07666666805744171\n",
"340 435 0.013390600681304932\n",
"2000 0.07999999821186066\n",
"341 429 -0.022545505315065384\n",
"2500 0.06111111119389534\n",
"344 435 -0.028293190523982048\n",
"3000 0.07666666805744171\n",
"343 435 -0.028790870681405067\n",
"3500 0.052222222089767456\n",
"347 472 -0.008332509547472\n"
]
}
],
"source": [
"opt_conf = torch.optim.Adam(net_conf.parameters(), lr=1e-2, weight_decay=1e-6)\n",
"\n",
"for e in range(epochs):\n",
" # shuffle batch\n",
" N = x_moons_t.size(0)\n",
" batch = torch.randperm(N)\n",
"\n",
" uncert, conf, pseudo_targets, conf_target = extract_pseudo_labels(net_conf, x_moons_t, T)\n",
"\n",
" opt_conf.zero_grad()\n",
" y_p_a_1, _ = net_conf(x_moons_t[batch] + torch.randn_like(x_moons_t)*0.05)\n",
" y_p_a_2, _ = net_conf(x_moons_t[batch] + torch.randn_like(x_moons_t)*0.15)\n",
" \n",
" loss_mi2 = mutual_information(y_p_a_1.softmax(1), y_p_a_2.softmax(1))\n",
" \n",
" # mixup training\n",
" lamb = torch.rand(1)\n",
" x_mixup, y_mixup = mixup(x_moons_t[batch][:N//2], x_moons_t[batch][N//2:(N//2)*2],\n",
" pseudo_targets[batch][:N//2], pseudo_targets[batch][N//2:(N//2)*2],\n",
" lamb=lamb,\n",
" num_classes=2)\n",
" y_p_1 = net_conf(x_moons_t[batch][:N//2])[0].softmax(1)\n",
" y_p_2 = net_conf(x_moons_t[batch][N//2:(N//2)*2])[0].softmax(1)\n",
" y_p_mixup = net_conf(x_mixup)[0].softmax(1)\n",
"\n",
" y_p_mixup2 = lamb*y_p_1 + (1-lamb)*y_p_2\n",
" loss_mixup = torch.nn.functional.mse_loss(y_p_mixup, y_p_mixup2)\n",
"\n",
" # pseudo-label training\n",
" y_pseudo = net_conf(conf + torch.randn_like(conf)*0.15)[0]\n",
" loss_pseudo = torch.nn.functional.cross_entropy(y_pseudo, conf_target)\n",
"\n",
" loss = 0.07*loss_mi2 + np.cos(e/epochs)*loss_mixup + 1.0*loss_pseudo\n",
" loss.backward()\n",
" opt_conf.step()\n",
"\n",
" if e % 500 == 0:\n",
" with torch.no_grad():\n",
" y = net_conf(x_moons_t)[0]\n",
" print(e, ((y.argmax(1) == y_moons_t).sum().float() / y_moons_t.size(0)).item())\n",
"\n",
" N_0 = conf[torch.where(conf_target == 0)].size(0)\n",
" N_1 = conf[torch.where(conf_target == 1)].size(0)\n",
" print(N_0, N_1, loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "iDerX-4-h36p"
},
"outputs": [],
"source": [
"_, _, y_test_conf = create_plot_data(net_conf, T)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 289
},
"id": "8Y8KKkuPgQAO",
"outputId": "2a738beb-ec78-4691-847f-d5345da252ff"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 0.13111111521720886\n",
"297 326 0.012329605408012867\n",
"500 0.09222222119569778\n",
"345 396 -0.005128343589603901\n",
"1000 0.09444444626569748\n",
"346 398 0.007566150743514299\n",
"1500 0.08444444090127945\n",
"340 428 -0.0097690774127841\n",
"2000 0.08222222328186035\n",
"341 429 -0.020236210897564888\n",
"2500 0.07000000029802322\n",
"344 437 0.006941571831703186\n",
"3000 0.06888888776302338\n",
"344 436 -0.030865471810102463\n",
"3500 0.06333333253860474\n",
"343 473 -0.026111770421266556\n"
]
}
],
"source": [
"opt_ours = torch.optim.Adam(net_ours.parameters(), lr=1e-2, weight_decay=1e-6)\n",
"\n",
"for e in range(epochs):\n",
" # shuffle batch\n",
" N = x_moons_t.size(0)\n",
" batch = torch.randperm(N)\n",
"\n",
" uncert, conf, pseudo_targets, conf_target = extract_pseudo_labels(net_ours, x_moons_t, T)\n",
"\n",
" opt_ours.zero_grad()\n",
" y_p_a_1, _ = net_ours(x_moons_t[batch] + torch.randn_like(x_moons_t)*0.05)\n",
" y_p_a_2, _ = net_ours(x_moons_t[batch] + torch.randn_like(x_moons_t)*0.15)\n",
" \n",
" loss_mi2 = mutual_information(y_p_a_1.softmax(1), y_p_a_2.softmax(1))\n",
" \n",
" # mixup training\n",
" lamb = torch.rand(1)\n",
" x_mixup, y_mixup = mixup(x_moons_t[batch][:N//2], x_moons_t[batch][N//2:(N//2)*2],\n",
" pseudo_targets[batch][:N//2], pseudo_targets[batch][N//2:(N//2)*2],\n",
" lamb=lamb,\n",
" num_classes=2)\n",
" y_p_1 = net_ours(x_moons_t[batch][:N//2])[0].softmax(1)\n",
" y_p_2 = net_ours(x_moons_t[batch][N//2:(N//2)*2])[0].softmax(1)\n",
" y_p_mixup = net_ours(x_mixup)[0].softmax(1)\n",
"\n",
" y_p_mixup2 = lamb*y_p_1 + (1-lamb)*y_p_2\n",
" loss_mixup = torch.nn.functional.mse_loss(y_p_mixup, y_p_mixup2)\n",
"\n",
" # pseudo-label training\n",
" y_pseudo = net_ours(conf + torch.randn_like(conf)*0.15)[0]\n",
" loss_pseudo = torch.nn.functional.cross_entropy(y_pseudo, conf_target)\n",
"\n",
" loss = 0.07*loss_mi2 + np.cos(e/epochs)*loss_mixup + 1.0*loss_pseudo\n",
" loss.backward()\n",
" opt_ours.step()\n",
"\n",
" if e % 500 == 0:\n",
" with torch.no_grad():\n",
" y = net_ours(x_moons_t)[0]\n",
" print(e, ((y.argmax(1) == y_moons_t).sum().float() / y_moons_t.size(0)).item())\n",
"\n",
" N_0 = conf[torch.where(conf_target == 0)].size(0)\n",
" N_1 = conf[torch.where(conf_target == 1)].size(0)\n",
" print(N_0, N_1, loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "5EyVK4KdgO2z"
},
"outputs": [],
"source": [
"xx, yy, y_test_ours = create_plot_data(net_ours, T)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 199
},
"id": "-RCwH7pCi3a0",
"outputId": "37962a2c-550a-4a69-ce4d-eba3b1e139ac"
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 504x162 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 3, figsize=(7,2.25), sharey=True)\n",
"\n",
"cs0 = ax[0].contourf(xx, yy, y_test_cl, levels=np.linspace(0, 1, 11))\n",
"cs02 = ax[0].contour(cs0, levels=[0, 0.5, 1], colors='tab:green', linewidths=1.5)\n",
"ax[0].plot(np.nan, color='tab:green', label='decision boundary')\n",
"ax[0].scatter(x_moons_t[::2,0], x_moons_t[::2,1], s=8)\n",
"ax[0].set_title('Consistency Learning')\n",
"ax[0].set_xlabel('x')\n",
"\n",
"cs1 = ax[1].contourf(xx, yy, y_test_conf, levels=np.linspace(0, 1, 11))\n",
"cs12 = ax[1].contour(cs1, levels=[0, 0.5, 1], colors='tab:green', linewidths=1.5)\n",
"ax[1].plot(np.nan, color='tab:green', label='decision boundary')\n",
"ax[1].scatter(x_moons_t[::2,0], x_moons_t[::2,1], s=8)\n",
"ax[1].set_title('CL+ConfPL')\n",
"ax[1].set_xlabel('x')\n",
"\n",
"cs2 = ax[2].contourf(xx, yy, y_test_ours, levels=np.linspace(0, 1, 11))\n",
"cs22 = ax[2].contour(cs2, levels=[0, 0.5, 1], colors='tab:green', linewidths=1.5)\n",
"ax[2].plot(np.nan, color='tab:green', label='decision boundary')\n",
"ax[2].scatter(x_moons_t[::2,0], x_moons_t[::2,1], s=8)\n",
"ax[2].set_title('CL+BatchPL (ours)')\n",
"ax[2].set_xlabel('x')\n",
"\n",
"fig.subplots_adjust(wspace=0.1, hspace=0.1,right=0.8)\n",
"cbar_ax = fig.add_axes([0.82, 0.13, 0.02, 0.67])\n",
"cbar = fig.colorbar(cs2, cax=cbar_ax)\n",
"cbar_ax.set_title(r'$p(c=0 \\vert \\ldots)$', loc='left', fontsize=8)\n",
"cbar.add_lines(cs22)\n",
"\n",
"ax[0].set_ylabel('y')\n",
"ax[0].legend(prop={'size': 8})\n",
"\n",
"fig.savefig(\"self_supervised_opener.pdf\", bbox_inches='tight', pad_inches=0.01)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rIHEWqRWkn6g"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "two_moons_create_opener.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment