Skip to content

Instantly share code, notes, and snippets.

@pmineiro
Last active February 17, 2022 17:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pmineiro/171edfa6963b7d14e6f3d10dc38af9a4 to your computer and use it in GitHub Desktop.
Save pmineiro/171edfa6963b7d14e6f3d10dc38af9a4 to your computer and use it in GitHub Desktop.
IGL with action-dependent feedback, MNIST demo
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "287318ac",
"metadata": {},
"source": [
"# Supervised"
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "7fb42e91",
"metadata": {
"code_folding": [
0,
6,
35,
36,
58,
63,
70,
84,
114,
121
],
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"n \tmean \tsince \tacc \tsince \n",
"1 \t2.30257 \t2.30257 \t0.15625 \t0.15625 \n",
"2 \t2.40082 \t2.49907 \t0.21875 \t0.28125 \n",
"3 \t2.21948 \t1.85680 \t0.29167 \t0.43750 \n",
"5 \t1.95383 \t1.55536 \t0.40000 \t0.56250 \n",
"9 \t1.56551 \t1.08011 \t0.50694 \t0.64062 \n",
"17 \t1.18324 \t0.75318 \t0.63419 \t0.77734 \n",
"33 \t0.87001 \t0.53720 \t0.73059 \t0.83301 \n",
"65 \t0.65744 \t0.43823 \t0.79615 \t0.86377 \n",
"129 \t0.49248 \t0.32494 \t0.84726 \t0.89917 \n",
"257 \t0.39149 \t0.28972 \t0.87840 \t0.90979 \n",
"513 \t0.31768 \t0.24357 \t0.90244 \t0.92657 \n",
"938 \t0.27401 \t0.22131 \t0.91613 \t0.93265 \n",
"testacc 0.9558735489845276 testloss 0.14295095205307007\n"
]
}
],
"source": [
"def supervisedLearn():\n",
"    \"\"\"Supervised baseline: fit an RFF softmax classifier to MNIST using the true labels.\n",
"\n",
"    Makes one shuffled pass over the training set (batches of 64) minimizing cross-entropy,\n",
"    printing cumulative and since-last-print loss/accuracy after batches 0, 1, 2, 4, 8, ...\n",
"    and after the last batch, then prints accuracy and loss on the MNIST test split.\n",
"    \"\"\"\n",
"    import itertools\n",
"    import numpy\n",
"    import torch\n",
"    import torchvision\n",
"\n",
"    class EasyAcc:\n",
"        \"\"\"Running accumulator of a count, a sum and a sum of squares.\"\"\"\n",
"\n",
"        def __init__(self):\n",
"            self.n = 0\n",
"            self.sum = 0\n",
"            self.sumsq = 0\n",
"\n",
"        def __iadd__(self, other):\n",
"            self.n += 1\n",
"            self.sum += other\n",
"            self.sumsq += other*other\n",
"            return self\n",
"\n",
"        def __isub__(self, other):\n",
"            self.n += 1\n",
"            self.sum -= other\n",
"            self.sumsq += other*other\n",
"            return self\n",
"\n",
"        def mean(self):\n",
"            return self.sum / max(self.n, 1)\n",
"\n",
"        def var(self):\n",
"            # NOTE: despite its name this returns the standard deviation\n",
"            from math import sqrt\n",
"            return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n",
"\n",
"        def semean(self):\n",
"            from math import sqrt\n",
"            return self.var() / sqrt(max(self.n, 1))\n",
"\n",
"    class RFFSoftmax(torch.nn.Module):\n",
"        \"\"\"Trainable linear layer on frozen random Fourier features cos(W x + b) / sqrt(numrff).\"\"\"\n",
"\n",
"        def __init__(self, hilo, naction, numrff, sigma, seed):\n",
"            from math import pi\n",
"            import numpy as np\n",
"\n",
"            super(RFFSoftmax, self).__init__()\n",
"\n",
"            torch.manual_seed(seed)\n",
"            nobs = hilo.shape[1]\n",
"            high = hilo[1, :]\n",
"            low = hilo[0, :]\n",
"\n",
"            # frozen projection: Cauchy weights (scale sigma), input column j scaled by 1 / (high[j] - low[j]);\n",
"            # columns whose range is <= 1e-6 are zeroed out\n",
"            self.rff = torch.nn.Linear(nobs, numrff)\n",
"            self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma),\n",
"                                                torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n",
"            self.rff.weight.requires_grad = False\n",
"            self.rff.bias.data = 2 * pi * torch.rand(numrff)\n",
"            self.rff.bias.requires_grad = False\n",
"            self.sqrtrff = np.sqrt(numrff)\n",
"            # only this output layer is trained; the 0.01 scaling keeps initial outputs small\n",
"            self.final = torch.nn.Linear(numrff, naction)\n",
"            self.final.weight.data *= 0.01\n",
"            self.final.bias.data *= 0.01\n",
"\n",
"        def logits(self, x):\n",
"            with torch.no_grad():\n",
"                rff = self.rff(x).cos() / self.sqrtrff\n",
"            return self.final(rff)\n",
"\n",
"    transform = torchvision.transforms.Compose([\n",
"        torchvision.transforms.ToTensor(),\n",
"        torchvision.transforms.Normalize((0.1307,), (0.3081,))\n",
"    ])\n",
"    mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n",
"\n",
"    # per-pixel 1% / 99% quantiles of one shuffled batch of 10000 training images set the RFF input scaling\n",
"    quantile_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10000, shuffle=True)\n",
"    for bno, (images, labels) in enumerate(quantile_loader):\n",
"        flat = images.reshape(images.shape[0], -1)\n",
"        hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n",
"        pi = RFFSoftmax(hilo, 10, 2000, 0.01, 45)\n",
"        break\n",
"\n",
"    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n",
"    mnist_test = torchvision.datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transform)\n",
"    test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)\n",
"\n",
"    opt = torch.optim.Adam(( p for p in pi.parameters() if p.requires_grad ), lr=0.1)\n",
"    loss = torch.nn.CrossEntropyLoss()\n",
"    acc, accsincelast, avloss, avlosssincelast = [ EasyAcc() for _ in range(4) ]\n",
"\n",
"    print('{:<5s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}'.format(\n",
"            'n', 'mean', 'since',\n",
"            'acc', 'since',\n",
"          ),\n",
"          flush=True)\n",
"\n",
"    for bno, (images, labels) in enumerate(itertools.chain(*[ train_loader for _ in range(1) ])):\n",
"        flat = images.reshape(images.shape[0], -1)\n",
"\n",
"        opt.zero_grad()\n",
"        ld = pi.logits(flat)\n",
"        output = loss(ld, labels)\n",
"        output.backward()\n",
"        opt.step()\n",
"\n",
"        with torch.no_grad():\n",
"            pred = ld.argmax(dim=1)\n",
"            acc += torch.mean((labels == pred).float())\n",
"            accsincelast += torch.mean((labels == pred).float())\n",
"            avloss += output\n",
"            avlosssincelast += output\n",
"\n",
"        # report after batches 0, 1, 2, 4, 8, ... (bno is 0 or a power of two)\n",
"        if (bno & (bno - 1) == 0):\n",
"            print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n",
"                    avloss.n, avloss.mean(), avlosssincelast.mean(),\n",
"                    acc.mean(), accsincelast.mean(),\n",
"                  ),\n",
"                  flush=True)\n",
"            accsincelast, avlosssincelast = EasyAcc(), EasyAcc()\n",
"\n",
"    print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n",
"            avloss.n, avloss.mean(), avlosssincelast.mean(),\n",
"            acc.mean(), accsincelast.mean(),\n",
"          ),\n",
"          flush=True)\n",
"\n",
"    # evaluate on the held-out MNIST test split (not the training data)\n",
"    testacc, testloss = EasyAcc(), EasyAcc()\n",
"    with torch.no_grad():\n",
"        for ti, tl in test_loader:\n",
"            flat = ti.reshape(ti.shape[0], -1)\n",
"            ld = pi.logits(flat)\n",
"            output = loss(ld, tl)\n",
"            testloss += output\n",
"            testpred = ld.argmax(dim=1)\n",
"            testacc += torch.mean((tl == testpred).float())\n",
"\n",
"    print(f'testacc {testacc.mean()} testloss {testloss.mean()}')\n",
"\n",
"supervisedLearn()"
]
},
{
"cell_type": "markdown",
"id": "55cda05e",
"metadata": {},
"source": [
"# CB"
]
},
{
"cell_type": "code",
"execution_count": 80,
"id": "de5fa0ff",
"metadata": {
"code_folding": [
0,
6,
7,
10,
20,
49,
50,
73,
78
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"n \tloss \tsince \tacc \tsince \treward \tsince \n",
"1 \t0.69313 \t0.69313 \t0.15625 \t0.15625 \t0.06250 \t0.06250 \n",
"2 \t0.55508 \t0.41703 \t0.10156 \t0.04688 \t0.05469 \t0.04688 \n",
"3 \t0.46854 \t0.29545 \t0.09896 \t0.09375 \t0.05729 \t0.06250 \n",
"5 \t0.46672 \t0.46399 \t0.11563 \t0.14062 \t0.10000 \t0.16406 \n",
"9 \t0.42025 \t0.36216 \t0.18229 \t0.26562 \t0.16493 \t0.24609 \n",
"17 \t0.37235 \t0.31846 \t0.22151 \t0.26562 \t0.20956 \t0.25977 \n",
"33 \t0.34175 \t0.30924 \t0.27415 \t0.33008 \t0.26089 \t0.31543 \n",
"65 \t0.34546 \t0.34929 \t0.38726 \t0.50391 \t0.36250 \t0.46729 \n",
"129 \t0.31504 \t0.28415 \t0.61216 \t0.84058 \t0.56468 \t0.77002 \n",
"257 \t0.26828 \t0.22116 \t0.75298 \t0.89490 \t0.69127 \t0.81885 \n",
"513 \t0.22252 \t0.17659 \t0.83458 \t0.91650 \t0.76459 \t0.83820 \n",
"938 \t0.19493 \t0.16162 \t0.87785 \t0.93007 \t0.80360 \t0.85070 \n",
"testacc 0.9445962309837341\n"
]
}
],
"source": [
"def cbLearn():\n",
"    \"\"\"Contextual-bandit baseline on MNIST: the learner only observes whether its chosen digit was right.\n",
"\n",
"    Makes one shuffled pass over the training set (batches of 64). For each image an action is\n",
"    sampled with FastCB, the reward is 1 if the action equals the label and 0 otherwise, and the\n",
"    sampled action's logit is fit to that reward with binary cross-entropy. Progress is printed\n",
"    after batches 0, 1, 2, 4, 8, ... and after the last batch; finally greedy accuracy on the\n",
"    MNIST test split is printed.\n",
"    \"\"\"\n",
"    import itertools\n",
"    import numpy\n",
"    import torch\n",
"    import torchvision\n",
"\n",
"    class FastCB:\n",
"        \"\"\"Sampler that proposes a uniformly random action and accepts it with a probability\n",
"        that decreases with gamma times the proposal's relative gap to the greedy estimate.\"\"\"\n",
"\n",
"        def __init__(self, gamma):\n",
"            self.gamma = gamma\n",
"\n",
"        def sample(self, fhat):\n",
"            # fhat: (N, K) estimates; the ratio below assumes they are positive (sigmoid outputs in this cell)\n",
"            N, K = fhat.shape\n",
"            rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n",
"            fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n",
"            fhatrando = torch.gather(input=fhat, dim=1, index=rando)\n",
"            # acceptance probability: 1 when fhatrando == fhatstar, smaller as the relative gap grows\n",
"            probs = K / (K + self.gamma * (1 - fhatrando / fhatstar))\n",
"            unif = torch.rand(size=(N, 1), device=fhat.device)\n",
"            shouldexplore = (unif <= probs).long()\n",
"            # play the proposal if accepted, otherwise the greedy action\n",
"            return (ahatstar + shouldexplore * (rando - ahatstar)).squeeze(1)\n",
"\n",
"    class EasyAcc:\n",
"        \"\"\"Running accumulator of a count, a sum and a sum of squares.\"\"\"\n",
"\n",
"        def __init__(self):\n",
"            self.n = 0\n",
"            self.sum = 0\n",
"            self.sumsq = 0\n",
"\n",
"        def __iadd__(self, other):\n",
"            self.n += 1\n",
"            self.sum += other\n",
"            self.sumsq += other*other\n",
"            return self\n",
"\n",
"        def __isub__(self, other):\n",
"            self.n += 1\n",
"            self.sum -= other\n",
"            self.sumsq += other*other\n",
"            return self\n",
"\n",
"        def mean(self):\n",
"            return self.sum / max(self.n, 1)\n",
"\n",
"        def var(self):\n",
"            # NOTE: despite its name this returns the standard deviation\n",
"            from math import sqrt\n",
"            return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n",
"\n",
"        def semean(self):\n",
"            from math import sqrt\n",
"            return self.var() / sqrt(max(self.n, 1))\n",
"\n",
"    class RFFSoftmax(torch.nn.Module):\n",
"        \"\"\"Trainable linear layer on frozen random Fourier features cos(W x + b) / sqrt(numrff).\"\"\"\n",
"\n",
"        def __init__(self, hilo, naction, numrff, sigma, seed):\n",
"            from math import pi\n",
"            import numpy as np\n",
"\n",
"            super(RFFSoftmax, self).__init__()\n",
"\n",
"            torch.manual_seed(seed)\n",
"            nobs = hilo.shape[1]\n",
"            high = hilo[1, :]\n",
"            low = hilo[0, :]\n",
"\n",
"            # frozen projection: Cauchy weights (scale sigma), input column j scaled by 1 / (high[j] - low[j]);\n",
"            # columns whose range is <= 1e-6 are zeroed out\n",
"            self.rff = torch.nn.Linear(nobs, numrff)\n",
"            self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma),\n",
"                                                torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n",
"            self.rff.weight.requires_grad = False\n",
"            self.rff.bias.data = 2 * pi * torch.rand(numrff)\n",
"            self.rff.bias.requires_grad = False\n",
"            self.sqrtrff = np.sqrt(numrff)\n",
"            # only this output layer is trained; the 0.01 scaling keeps initial outputs small\n",
"            self.final = torch.nn.Linear(numrff, naction)\n",
"            self.final.weight.data *= 0.01\n",
"            self.final.bias.data *= 0.01\n",
"            self.sigmoid = torch.nn.Sigmoid()\n",
"\n",
"        def forward(self, x):\n",
"            with torch.no_grad():\n",
"                rff = self.rff(x).cos() / self.sqrtrff\n",
"            return self.final(rff)\n",
"\n",
"        def density(self, logits):\n",
"            return self.sigmoid(logits)\n",
"\n",
"    transform = torchvision.transforms.Compose([\n",
"        torchvision.transforms.ToTensor(),\n",
"        torchvision.transforms.Normalize((0.1307,), (0.3081,))\n",
"    ])\n",
"    mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n",
"\n",
"    # per-pixel 1% / 99% quantiles of one shuffled batch of 10000 training images set the RFF input scaling\n",
"    quantile_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10000, shuffle=True)\n",
"    for bno, (images, labels) in enumerate(quantile_loader):\n",
"        flat = images.reshape(images.shape[0], -1)\n",
"        hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n",
"        pi = RFFSoftmax(hilo, 10, 2000, 0.01, 45)\n",
"        break\n",
"\n",
"    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n",
"    mnist_test = torchvision.datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transform)\n",
"    test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)\n",
"    sampler = FastCB(gamma=100)\n",
"\n",
"    opt = torch.optim.Adam(( p for p in pi.parameters() if p.requires_grad ), lr=1e-1)\n",
"    log_loss = torch.nn.BCEWithLogitsLoss()\n",
"    acc, accsincelast, avloss, avlosssincelast, avreward, avrewardsincelast = [ EasyAcc() for _ in range(6) ]\n",
"\n",
"    print('{:<5s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}'.format(\n",
"            'n', 'loss', 'since',\n",
"            'acc', 'since',\n",
"            'reward', 'since',\n",
"          ),\n",
"          flush=True)\n",
"\n",
"    for bno, (images, labels) in enumerate(itertools.chain(*[ train_loader for _ in range(1) ])):\n",
"        flat = images.reshape(images.shape[0], -1)\n",
"\n",
"        opt.zero_grad()\n",
"        logit = pi(flat)\n",
"        with torch.no_grad():\n",
"            fhat = pi.density(logit)\n",
"            sample = sampler.sample(fhat)\n",
"            # bandit feedback: 1 only when the sampled digit equals the true label\n",
"            reward = (sample == labels).unsqueeze(1).float()\n",
"\n",
"        # only the sampled action's logit enters the loss\n",
"        samplelogit = torch.gather(input=logit, index=sample.unsqueeze(1), dim=1)\n",
"        loss = log_loss(samplelogit, reward)\n",
"        loss.backward()\n",
"        opt.step()\n",
"\n",
"        with torch.no_grad():\n",
"            pred = logit.argmax(dim=1)\n",
"            acc += torch.mean((labels == pred).float())\n",
"            accsincelast += torch.mean((labels == pred).float())\n",
"            avloss += loss\n",
"            avlosssincelast += loss\n",
"            avreward += torch.mean(reward)\n",
"            avrewardsincelast += torch.mean(reward)\n",
"\n",
"        # report after batches 0, 1, 2, 4, 8, ... (bno is 0 or a power of two)\n",
"        if (bno & (bno - 1) == 0):\n",
"            print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n",
"                    avloss.n, avloss.mean(), avlosssincelast.mean(),\n",
"                    acc.mean(), accsincelast.mean(),\n",
"                    avreward.mean(), avrewardsincelast.mean(),\n",
"                  ),\n",
"                  flush=True)\n",
"            accsincelast, avlosssincelast, avrewardsincelast = EasyAcc(), EasyAcc(), EasyAcc()\n",
"\n",
"    print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n",
"            avloss.n, avloss.mean(), avlosssincelast.mean(),\n",
"            acc.mean(), accsincelast.mean(),\n",
"            avreward.mean(), avrewardsincelast.mean(),\n",
"          ),\n",
"          flush=True)\n",
"\n",
"    # greedy accuracy on the held-out MNIST test split (not the training data)\n",
"    testacc = EasyAcc()\n",
"    with torch.no_grad():\n",
"        for ti, tl in test_loader:\n",
"            flat = ti.reshape(ti.shape[0], -1)\n",
"            logit = pi(flat)\n",
"            testpred = logit.argmax(dim=1)\n",
"            testacc += torch.mean((tl == testpred).float())\n",
"\n",
"    print(f'testacc {testacc.mean()}')\n",
"\n",
"cbLearn()"
]
},
{
"cell_type": "markdown",
"id": "09186826",
"metadata": {},
"source": [
"# IGL ($y_a \\perp x, a|r_a$)\n",
"$y_a$ is a randomly selected \"zero\" image when $r_a = 0$ and a randomly selected \"one\" image when $r_a = 1$; it depends only upon $r_a$."
]
},
{
"cell_type": "code",
"execution_count": 113,
"id": "67459e03",
"metadata": {
"code_folding": [
0,
6,
22,
51,
83,
90,
100,
113,
139,
188,
197
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"n \tloss \tsince \tacc \tsince \treward \tsince \tfake \tsince \n",
"1 \t1.38627 \t1.38627 \t0.32812 \t0.32812 \t0.09375 \t0.09375 \t0.49997 \t0.49997 \n",
"2 \t1.37295 \t1.35963 \t0.18750 \t0.04688 \t0.07812 \t0.06250 \t0.48236 \t0.46476 \n",
"3 \t1.35572 \t1.32125 \t0.17188 \t0.14062 \t0.08333 \t0.09375 \t0.46612 \t0.43364 \n",
"5 \t1.33542 \t1.30497 \t0.14687 \t0.10938 \t0.09375 \t0.10938 \t0.43557 \t0.38975 \n",
"9 \t1.28511 \t1.22223 \t0.14583 \t0.14453 \t0.11111 \t0.13281 \t0.38302 \t0.31733 \n",
"17 \t1.19374 \t1.09095 \t0.15533 \t0.16602 \t0.11949 \t0.12891 \t0.31242 \t0.23300 \n",
"33 \t1.19390 \t1.19407 \t0.16714 \t0.17969 \t0.12689 \t0.13477 \t0.26854 \t0.22192 \n",
"65 \t1.25434 \t1.31667 \t0.25024 \t0.33594 \t0.18438 \t0.24365 \t0.43038 \t0.59728 \n",
"129 \t1.12546 \t0.99457 \t0.45094 \t0.65479 \t0.33285 \t0.48364 \t0.54036 \t0.65205 \n",
"257 \t0.88065 \t0.63392 \t0.65394 \t0.85852 \t0.51283 \t0.69421 \t0.65974 \t0.78006 \n",
"513 \t0.68728 \t0.49315 \t0.77656 \t0.89966 \t0.63411 \t0.75586 \t0.74226 \t0.82510 \n",
"938 \t0.56352 \t0.41413 \t0.84097 \t0.91871 \t0.70281 \t0.78574 \t0.78930 \t0.84607 \n",
"testacc 0.928521454334259\n"
]
}
],
"source": [
"def iglLearn():\n",
"    \"\"\"Interaction-grounded learning (IGL) on MNIST with feedback that depends only on the latent reward.\n",
"\n",
"    The true reward (chosen digit == label) is never used as a training target. Instead the learner\n",
"    receives a feedback image: a random \"one\" image when the reward is 1 and a random \"zero\" image\n",
"    otherwise. After a two-batch supervised warm start, the policy and a reward decoder are trained\n",
"    jointly for one shuffled pass over the training set; progress is printed after batches\n",
"    0, 1, 2, 4, 8, ... and after the last batch, then greedy accuracy on the MNIST test split is printed.\n",
"    \"\"\"\n",
"    import itertools\n",
"    import numpy\n",
"    import torch\n",
"    import torchvision\n",
"\n",
"    class SquareCB(object):\n",
"        \"\"\"Sampler that proposes a uniformly random action and accepts it with probability\n",
"        K / (K + gamma * (fhatstar - fhat[proposal])); otherwise it plays the greedy action.\"\"\"\n",
"\n",
"        def __init__(self, gamma):\n",
"            super(SquareCB, self).__init__()\n",
"\n",
"            self.gamma = gamma\n",
"\n",
"        def sample(self, fhat):\n",
"            N, K = fhat.shape\n",
"            rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n",
"            fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n",
"            fhatrando = torch.gather(input=fhat, dim=1, index=rando)\n",
"            probs = K / (K + self.gamma * (fhatstar - fhatrando))\n",
"            unif = torch.rand(size=(N, 1), device=fhat.device)\n",
"            shouldexplore = (unif <= probs).long()\n",
"            return (ahatstar + shouldexplore * (rando - ahatstar)).squeeze(1)\n",
"\n",
"    class EasyAcc:\n",
"        \"\"\"Running accumulator of a count, a sum and a sum of squares.\"\"\"\n",
"\n",
"        def __init__(self):\n",
"            self.n = 0\n",
"            self.sum = 0\n",
"            self.sumsq = 0\n",
"\n",
"        def __iadd__(self, other):\n",
"            self.n += 1\n",
"            self.sum += other\n",
"            self.sumsq += other*other\n",
"            return self\n",
"\n",
"        def __isub__(self, other):\n",
"            self.n += 1\n",
"            self.sum -= other\n",
"            self.sumsq += other*other\n",
"            return self\n",
"\n",
"        def mean(self):\n",
"            return self.sum / max(self.n, 1)\n",
"\n",
"        def var(self):\n",
"            # NOTE: despite its name this returns the standard deviation\n",
"            from math import sqrt\n",
"            return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n",
"\n",
"        def semean(self):\n",
"            from math import sqrt\n",
"            return self.var() / sqrt(max(self.n, 1))\n",
"\n",
"    class RFFSoftmax(torch.nn.Module):\n",
"        \"\"\"Trainable linear layer on frozen random Fourier features cos(W x + b) / sqrt(numrff).\"\"\"\n",
"\n",
"        def __init__(self, hilo, naction, numrff, sigma, seed):\n",
"            from math import pi\n",
"            import numpy as np\n",
"\n",
"            super(RFFSoftmax, self).__init__()\n",
"\n",
"            torch.manual_seed(seed)\n",
"            nobs = hilo.shape[1]\n",
"            high = hilo[1, :]\n",
"            low = hilo[0, :]\n",
"\n",
"            # frozen projection: Cauchy weights (scale sigma), input column j scaled by 1 / (high[j] - low[j]);\n",
"            # columns whose range is <= 1e-6 are zeroed out\n",
"            self.rff = torch.nn.Linear(nobs, numrff)\n",
"            self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma),\n",
"                                                torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n",
"            self.rff.weight.requires_grad = False\n",
"            self.rff.bias.data = 2 * pi * torch.rand(numrff)\n",
"            self.rff.bias.requires_grad = False\n",
"            self.sqrtrff = np.sqrt(numrff)\n",
"            # only this output layer is trained; the 0.01 scaling keeps initial outputs small\n",
"            self.final = torch.nn.Linear(numrff, naction)\n",
"            self.final.weight.data *= 0.01\n",
"            self.final.bias.data *= 0.01\n",
"            self.sigmoid = torch.nn.Sigmoid()\n",
"\n",
"        def forward(self, x):\n",
"            with torch.no_grad():\n",
"                rff = self.rff(x).cos() / self.sqrtrff\n",
"            return self.final(rff)\n",
"\n",
"        def density(self, logits):\n",
"            return self.sigmoid(logits)\n",
"\n",
"    transform = torchvision.transforms.Compose([\n",
"        torchvision.transforms.ToTensor(),\n",
"        torchvision.transforms.Normalize((0.1307,), (0.3081,))\n",
"    ])\n",
"    mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n",
"\n",
"    # per-pixel 1% / 99% quantiles of one shuffled batch of 10000 training images set the RFF input scaling\n",
"    quantile_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10000, shuffle=True)\n",
"    for bno, (images, labels) in enumerate(quantile_loader):\n",
"        flat = images.reshape(images.shape[0], -1)\n",
"        hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n",
"        pi = RFFSoftmax(hilo, 10, 2000, 0.01, 45)\n",
"        # reward decoder: a single logit computed from the feedback image\n",
"        decoder = RFFSoftmax(hilo, 1, 2000, 0.01, 2112)\n",
"        break\n",
"\n",
"    # collect more than 100 \"zero\" and more than 100 \"one\" training images to serve as feedback\n",
"    zero_one_loader = torch.utils.data.DataLoader(mnist_train, batch_size=1, shuffle=True)\n",
"    zeros = []\n",
"    ones = []\n",
"    for bno, (images, labels) in enumerate(zero_one_loader):\n",
"        flat = images.reshape(images.shape[0], -1)\n",
"        if labels[0] == 0:\n",
"            zeros.append(flat)\n",
"        elif labels[0] == 1:\n",
"            ones.append(flat)\n",
"\n",
"        if len(zeros) > 100 and len(ones) > 100:\n",
"            break\n",
"    zeros = torch.cat(zeros, dim=0)\n",
"    ones = torch.cat(ones, dim=0)\n",
"\n",
"    # pre-train to get policy \"better than random\": two supervised mini-batch steps\n",
"    if True:\n",
"        preopt = torch.optim.Adam(( p for p in pi.parameters() if p.requires_grad ), lr=1e-3) # 0.1\n",
"        preloss = torch.nn.CrossEntropyLoss()\n",
"        pretrain_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n",
"        for bno, (images, labels) in enumerate(itertools.chain(*[ pretrain_loader for _ in range(1) ])):\n",
"            flat = images.reshape(images.shape[0], -1)\n",
"\n",
"            preopt.zero_grad()\n",
"            ld = pi.forward(flat)\n",
"            output = preloss(ld, labels)\n",
"            output.backward()\n",
"            preopt.step()\n",
"\n",
"            if bno > 0:\n",
"                break\n",
"\n",
"    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n",
"    mnist_test = torchvision.datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transform)\n",
"    test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)\n",
"\n",
"    opt = torch.optim.Adam(( p for p in itertools.chain(pi.parameters(), decoder.parameters()) if p.requires_grad ), lr=1e-2)\n",
"    # per-element losses (`reduction`, not the deprecated boolean `reduce`); they are averaged explicitly below\n",
"    log_loss = torch.nn.BCEWithLogitsLoss(reduction='none')\n",
"    sampler = SquareCB(gamma=100)\n",
"    acc, accsincelast, avloss, avlosssincelast = [ EasyAcc() for _ in range(4) ]\n",
"    avreward, avrewardsincelast, avfake, avfakesincelast = [ EasyAcc() for _ in range(4) ]\n",
"\n",
"    print('{:<5s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}'.format(\n",
"            'n', 'loss', 'since',\n",
"            'acc', 'since',\n",
"            'reward', 'since',\n",
"            'fake', 'since',\n",
"          ),\n",
"          flush=True)\n",
"\n",
"    for bno, (images, labels) in enumerate(itertools.chain(*[ train_loader for _ in range(1) ])):\n",
"        flat = images.reshape(images.shape[0], -1)\n",
"\n",
"        opt.zero_grad()\n",
"        logit = pi(flat)\n",
"        with torch.no_grad():\n",
"            fhat = pi.density(logit)\n",
"            sample = sampler.sample(fhat)\n",
"            # latent reward: only selects the feedback image and is reported, never trained on\n",
"            reward = (sample == labels).unsqueeze(1).float()\n",
"            pred = logit.argmax(dim=1)\n",
"            ispred = (sample == pred).unsqueeze(1).float()\n",
"            antipred = logit.argmin(dim=1)\n",
"            isantipred = (sample == antipred).unsqueeze(1).float()\n",
"            # one random \"zero\" and one random \"one\" image per example; the \"one\" image is kept where reward == 1\n",
"            zerossample = torch.randint(low=0, high=zeros.shape[0], size=(fhat.shape[0], 1))\n",
"            zerofeedback = torch.gather(input=zeros, index=zerossample.expand(-1, zeros.shape[1]), dim=0)\n",
"            onessample = torch.randint(low=0, high=ones.shape[0], size=(fhat.shape[0], 1))\n",
"            onefeedback = torch.gather(input=ones, index=onessample.expand(-1, ones.shape[1]), dim=0)\n",
"            feedback = zerofeedback + reward * (onefeedback - zerofeedback)\n",
"\n",
"        samplelogit = torch.gather(input=logit, index=sample.unsqueeze(1), dim=1)\n",
"        fakelogit = decoder(feedback)\n",
"        fakereward = decoder.density(fakelogit)\n",
"        # hypothesis A: the decoder predicts from the feedback whether the sampled action was the\n",
"        # policy's argmax, and the sampled action's logit is fit to the decoded reward\n",
"        predloss = torch.mean(log_loss(fakelogit, ispred) + log_loss(samplelogit, fakereward.detach()))\n",
"        # hypothesis B: the same with the policy's argmin action and `1 - logit` in place of the logits\n",
"        # NOTE(review): 1 - logit is not the logit of the complementary probability (that is -logit); confirm intended\n",
"        antipredloss = torch.mean(log_loss(1 - fakelogit, isantipred) + log_loss(1 - samplelogit, fakereward.detach()))\n",
"        # train on whichever hypothesis currently has the smaller loss\n",
"        loss = torch.min(predloss, antipredloss)\n",
"\n",
"        loss.backward()\n",
"        opt.step()\n",
"\n",
"        with torch.no_grad():\n",
"            pred = logit.argmax(dim=1)\n",
"            acc += torch.mean((labels == pred).float())\n",
"            accsincelast += torch.mean((labels == pred).float())\n",
"            avloss += loss\n",
"            avlosssincelast += loss\n",
"            avreward += torch.mean(reward)\n",
"            avrewardsincelast += torch.mean(reward)\n",
"            avfake += torch.mean(fakereward)\n",
"            avfakesincelast += torch.mean(fakereward)\n",
"\n",
"        # report after batches 0, 1, 2, 4, 8, ... (bno is 0 or a power of two)\n",
"        if (bno & (bno - 1) == 0):\n",
"            print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n",
"                    avloss.n, avloss.mean(), avlosssincelast.mean(),\n",
"                    acc.mean(), accsincelast.mean(),\n",
"                    avreward.mean(), avrewardsincelast.mean(),\n",
"                    avfake.mean(), avfakesincelast.mean(),\n",
"                  ),\n",
"                  flush=True)\n",
"            accsincelast, avlosssincelast, avrewardsincelast, avfakesincelast = [ EasyAcc() for _ in range(4) ]\n",
"\n",
"    print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n",
"            avloss.n, avloss.mean(), avlosssincelast.mean(),\n",
"            acc.mean(), accsincelast.mean(),\n",
"            avreward.mean(), avrewardsincelast.mean(),\n",
"            avfake.mean(), avfakesincelast.mean(),\n",
"          ),\n",
"          flush=True)\n",
"\n",
"    # greedy accuracy on the held-out MNIST test split (not the training data)\n",
"    testacc = EasyAcc()\n",
"    with torch.no_grad():\n",
"        for ti, tl in test_loader:\n",
"            flat = ti.reshape(ti.shape[0], -1)\n",
"            logit = pi(flat)\n",
"            testpred = logit.argmax(dim=1)\n",
"            testacc += torch.mean((tl == testpred).float())\n",
"\n",
"    print(f'testacc {testacc.mean()}')\n",
"\n",
"iglLearn()"
]
},
{
"cell_type": "markdown",
"id": "90994504",
"metadata": {},
"source": [
"# IGL ($y_a \\perp x|r_a$)\n",
"$y_a$ is an image of the action taken if $r_a = 1$, e.g., if $a=3$, a \"three\" image; otherwise if $r_a = 0$, an image of $(9-a)$, e.g., if $a=3$, a \"six\" image."
]
},
{
"cell_type": "code",
"execution_count": 288,
"id": "4d4e6631",
"metadata": {
"code_folding": [
6,
16,
26,
31,
55,
84,
85,
108,
113,
116,
117,
140,
145,
148,
155,
164,
172,
199,
213,
233,
271,
305,
314
],
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"n \tloss \tsince \tacc \tsince \treward \tsince \tfake \tsince \n",
"1 \t1.38891 \t1.38891 \t0.28125 \t0.28125 \t0.23438 \t0.23438 \t0.49940 \t0.49940 \n",
"2 \t1.38223 \t1.37554 \t0.29688 \t0.31250 \t0.19531 \t0.15625 \t0.49962 \t0.49983 \n",
"3 \t1.38101 \t1.37858 \t0.31771 \t0.35938 \t0.18750 \t0.17188 \t0.49995 \t0.50062 \n",
"5 \t1.38767 \t1.39765 \t0.29063 \t0.25000 \t0.14687 \t0.08594 \t0.49908 \t0.49777 \n",
"9 \t1.38487 \t1.38138 \t0.19965 \t0.08594 \t0.13368 \t0.11719 \t0.50189 \t0.50540 \n",
"17 \t1.35851 \t1.32885 \t0.18199 \t0.16211 \t0.13787 \t0.14258 \t0.50388 \t0.50613 \n",
"33 \t1.34694 \t1.33465 \t0.20028 \t0.21973 \t0.14725 \t0.15723 \t0.50259 \t0.50121 \n",
"65 \t1.33522 \t1.32314 \t0.27428 \t0.35059 \t0.19447 \t0.24316 \t0.50141 \t0.50020 \n",
"129 \t1.26931 \t1.20236 \t0.35913 \t0.44531 \t0.26211 \t0.33081 \t0.50427 \t0.50716 \n",
"257 \t1.14736 \t1.02446 \t0.52420 \t0.69055 \t0.40096 \t0.54089 \t0.53432 \t0.56461 \n",
"513 \t0.99472 \t0.84148 \t0.65068 \t0.77765 \t0.52537 \t0.65027 \t0.57622 \t0.61829 \n",
"938 \t0.86970 \t0.71879 \t0.73391 \t0.83438 \t0.61042 \t0.71309 \t0.61330 \t0.65806 \n",
"testacc 0.8389525413513184\n"
]
}
],
"source": [
"def iglADepLearn():\n",
" import itertools\n",
" import numpy\n",
" import torch\n",
" import torchvision\n",
" \n",
" class WeightedReservoir(object):\n",
" def __init__(self, n, seed):\n",
" import random\n",
" \n",
" super().__init__()\n",
" self.n = n\n",
" self.items = []\n",
" self.wsum = 0\n",
" self.gen = random.Random(seed) \n",
" \n",
" def insert(self, item, weight):\n",
" if weight > 0:\n",
" self.wsum += weight\n",
" if self.wsum * self.gen.random() < weight:\n",
" if len(self.items) < self.n:\n",
" self.items.append(item)\n",
" else:\n",
" index = self.gen.randrange(0, self.n) \n",
" self.items[index] = item\n",
" \n",
" def sample(self):\n",
" assert len(self.items) > 0\n",
" index = self.gen.randrange(0, len(self.items))\n",
" return self.items[index]\n",
" \n",
"    class SquareCB(object):\n",
"        \"\"\"SquareCB-style sampler that also returns the probability of each sampled action.\n",
"\n",
"        A non-greedy action a is played with probability 1 / (K + gamma * (fhatstar - fhat[a]));\n",
"        the greedy action receives the remaining probability mass.\n",
"        \"\"\"\n",
"\n",
"        def __init__(self, gamma):\n",
"            super().__init__()\n",
"\n",
"            self.gamma = gamma\n",
"\n",
"        def sample(self, fhat, *, keepdim=False):\n",
"            \"\"\"Return (actions, action probabilities) for an (N, K) tensor of estimates fhat.\n",
"\n",
"            Both outputs have shape (N, 1) when keepdim is True and (N,) otherwise.\n",
"            \"\"\"\n",
"            N, K = fhat.shape\n",
"            fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n",
"            probs = 1 / (K + self.gamma * (fhatstar - fhat))\n",
"            psum = torch.sum(probs, dim=1, keepdim=True)\n",
"            # greedy action gets 1 minus the non-greedy mass: 1 - (psum - probs[ahatstar])\n",
"            phatstar = 1 - psum + torch.gather(input=probs, dim=1, index=ahatstar)\n",
"\n",
"            # propose a uniform action and accept it with probability K * prando, so a non-greedy\n",
"            # action a is played with probability probs[a]; rejected proposals fall back to the greedy action\n",
"            rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n",
"            prando = torch.gather(input=probs, dim=1, index=rando)\n",
"            unif = torch.rand(size=(N, 1), device=fhat.device)\n",
"            shouldexplore = (unif <= K * prando).long()\n",
"            actions = ahatstar + shouldexplore * (rando - ahatstar)\n",
"            # report phatstar whenever the greedy action is played, including when it was the proposal itself\n",
"            pactions = torch.where(actions == ahatstar, phatstar, prando)\n",
"            if not keepdim:\n",
"                actions = actions.squeeze(1)\n",
"                pactions = pactions.squeeze(1)\n",
"            return actions, pactions\n",
" \n",
" class EasyAcc:\n",
" def __init__(self):\n",
" self.n = 0\n",
" self.sum = 0\n",
" self.sumsq = 0\n",
"\n",
" def __iadd__(self, other):\n",
" self.n += 1\n",
" self.sum += other\n",
" self.sumsq += other*other\n",
" return self\n",
"\n",
" def __isub__(self, other):\n",
" self.n += 1\n",
" self.sum -= other\n",
" self.sumsq += other*other\n",
" return self\n",
"\n",
" def mean(self):\n",
" return self.sum / max(self.n, 1)\n",
"\n",
" def var(self):\n",
" from math import sqrt\n",
" return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n",
"\n",
" def semean(self):\n",
" from math import sqrt\n",
" return self.var() / sqrt(max(self.n, 1))\n",
"\n",
" class RFFBilinearSoftmax(torch.nn.Module):\n",
" def __init__(self, hilo, naction, numrff, sigma, seed):\n",
" from math import pi\n",
" import numpy as np\n",
"\n",
" super().__init__()\n",
"\n",
" torch.manual_seed(seed)\n",
" nobs = hilo.shape[1]\n",
" high = hilo[1, :]\n",
" low = hilo[0, :]\n",
" \n",
" self.rff = torch.nn.Linear(nobs, numrff)\n",
" self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n",
" torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n",
" self.rff.weight.requires_grad = False\n",
" self.rff.bias.data = 2 * pi * torch.rand(numrff)\n",
" self.rff.bias.requires_grad = False\n",
" self.sqrtrff = np.sqrt(numrff)\n",
" self.final = torch.nn.Bilinear(naction, numrff, 1)\n",
" self.final.weight.data *= 0.01\n",
" self.final.bias.data *= 0.01\n",
" self.sigmoid = torch.nn.Sigmoid()\n",
"\n",
" def forward(self, a, y):\n",
" with torch.no_grad():\n",
" rff = self.rff(y).cos() / self.sqrtrff\n",
" return self.final(a, rff)\n",
" \n",
" def density(self, logits):\n",
" return self.sigmoid(logits)\n",
"\n",
" class RFFSoftmax(torch.nn.Module):\n",
" def __init__(self, hilo, naction, numrff, sigma, seed):\n",
" from math import pi\n",
" import numpy as np\n",
"\n",
" super().__init__()\n",
"\n",
" torch.manual_seed(seed)\n",
" nobs = hilo.shape[1]\n",
" high = hilo[1, :]\n",
" low = hilo[0, :]\n",
" \n",
" self.rff = torch.nn.Linear(nobs, numrff)\n",
" self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n",
" torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n",
" self.rff.weight.requires_grad = False\n",
" self.rff.bias.data = 2 * pi * torch.rand(numrff)\n",
" self.rff.bias.requires_grad = False\n",
" self.sqrtrff = np.sqrt(numrff)\n",
" self.final = torch.nn.Linear(numrff, naction)\n",
" self.final.weight.data *= 0.01\n",
" self.final.bias.data *= 0.01\n",
" self.sigmoid = torch.nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
" with torch.no_grad():\n",
" rff = self.rff(x).cos() / self.sqrtrff\n",
" return self.final(rff)\n",
" \n",
" def preq1(self, logits):\n",
" return self.sigmoid(logits)\n",
"\n",
" transform = torchvision.transforms.Compose([\n",
" torchvision.transforms.ToTensor(),\n",
" torchvision.transforms.Normalize((0.1307,), (0.3081,))\n",
" ])\n",
" mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n",
" \n",
" quantile_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10000, shuffle=True)\n",
" for bno, (images, labels) in enumerate(quantile_loader):\n",
" flat = images.reshape(images.shape[0], -1)\n",
" hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n",
" pi = RFFSoftmax(hilo, 10, 2000, 0.01, 45)\n",
" decoder = RFFBilinearSoftmax(hilo, 10, 2000, 0.01, 2112)\n",
" break\n",
" \n",
" feedback_loader = torch.utils.data.DataLoader(mnist_train, batch_size=1, shuffle=True)\n",
" feedbacks = [ [] for _ in range(10) ]\n",
" for bno, (images, labels) in enumerate(feedback_loader):\n",
" flat = images.reshape(images.shape[0], -1)\n",
" feedbacks[labels[0]].append(flat)\n",
" if all(len(x) > 100 for x in feedbacks):\n",
" break \n",
" feedbacks = torch.cat([ torch.cat(x[:100], dim=0).unsqueeze(0) for x in feedbacks ], dim=0)\n",
" \n",
" # pre-train to get policy \"better than random\"\n",
" if True:\n",
" preopt = torch.optim.Adam(( p for p in pi.parameters() if p.requires_grad ), lr=1e-2) # 0.1\n",
" preloss = torch.nn.CrossEntropyLoss()\n",
" pretrain_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n",
" for bno, (images, labels) in enumerate(itertools.chain(*[ pretrain_loader for _ in range(1) ])):\n",
" flat = images.reshape(images.shape[0], -1)\n",
"\n",
" preopt.zero_grad()\n",
" ld = pi.forward(flat)\n",
" output = preloss(ld, labels)\n",
" output.backward()\n",
" preopt.step()\n",
"\n",
" if bno > 0:\n",
" break\n",
" \n",
" train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n",
" mnist_test = torchvision.datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transform)\n",
" test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)\n",
" \n",
" opt = torch.optim.Adam(( p for p in itertools.chain(pi.parameters(), decoder.parameters()) if p.requires_grad ), lr=1e-2)\n",
" log_loss = torch.nn.BCEWithLogitsLoss(reduce='none')\n",
" sampler = SquareCB(gamma=100)\n",
" acc, accsincelast, avloss, avlosssincelast = [ EasyAcc() for _ in range(4) ]\n",
" avreward, avrewardsincelast, avfake, avfakesincelast = [ EasyAcc() for _ in range(4) ]\n",
" reservoirs = [ WeightedReservoir(20, 1973+a) for a in range(10) ]\n",
" \n",
" print('{:<5s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}'.format(\n",
" 'n', 'loss', 'since', \n",
" 'acc', 'since',\n",
" 'reward', 'since',\n",
" 'fake', 'since',\n",
" ),\n",
" flush=True)\n",
" \n",
" for bno, (images, labels) in enumerate(itertools.chain(*[ train_loader for _ in range(1) ])):\n",
" flatimage = images.reshape(images.shape[0], -1)\n",
" \n",
" opt.zero_grad()\n",
" logit = pi(flatimage)\n",
" \n",
" with torch.no_grad():\n",
" fhat = pi.preq1(logit)\n",
" sample, probs = sampler.sample(fhat, keepdim=True)\n",
" \n",
" reward = (sample == labels.unsqueeze(1)).float()\n",
" pred = logit.argmax(dim=1, keepdim=True)\n",
" ispred = (sample == pred).float()\n",
" antipred = logit.argmin(dim=1, keepdim=True)\n",
" isantipred = (sample == antipred).float()\n",
" \n",
" # this assumes a particular majorization (Torch tensors are row-major)\n",
" bigfeedbacks = feedbacks.unsqueeze(0).expand(fhat.shape[0], -1, -1, -1).reshape(fhat.shape[0], -1, flatimage.shape[1]) # Batch x (A x Rep) x Pixels\n",
" nreps = feedbacks.shape[1]\n",
" goodwhich = feedbacks.shape[1] * sample.squeeze(1) + torch.randint(low=0, high=feedbacks.shape[1], size=(fhat.shape[0],))\n",
" goodwhich = goodwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, flatimage.shape[1])\n",
" goodfeedbacks = torch.gather(input=bigfeedbacks, index=goodwhich, dim=1).squeeze(1)\n",
" badwhich = feedbacks.shape[1] * (9-sample).squeeze(1) + torch.randint(low=0, high=feedbacks.shape[1], size=(fhat.shape[0],))\n",
" badwhich = badwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, flatimage.shape[1])\n",
" badfeedbacks = torch.gather(input=bigfeedbacks, index=badwhich, dim=1).squeeze(1)\n",
" \n",
" if False:\n",
" import matplotlib.pyplot as plt\n",
"\n",
" fig, axs = plt.subplots(1, 10)\n",
" for n, (s, f) in enumerate(zip(sample, goodfeedbacks)):\n",
" if n > 9:\n",
" break\n",
" axs[n].imshow(f.reshape(28, 28))\n",
" axs[n].set_title(f'{s.item()}')\n",
" \n",
" plt.show()\n",
" \n",
" fig, axs = plt.subplots(1, 10)\n",
" for n, (s, f) in enumerate(zip(sample, badfeedbacks)):\n",
" if n > 9:\n",
" break\n",
" axs[n].imshow(f.reshape(28, 28))\n",
" axs[n].set_title(f'{s.item()}')\n",
" \n",
" plt.show()\n",
" assert False\n",
" \n",
" feedback = badfeedbacks + reward * (goodfeedbacks - badfeedbacks)\n",
" onehotsample = torch.nn.functional.one_hot(sample.squeeze(1), num_classes=fhat.shape[1]).float()\n",
" \n",
" # insert then sample ... means the first time we play an action there will be no update, that's ok\n",
" for s, p, r, f in zip(sample, probs, reward, feedback):\n",
" reservoirs[s.item()].insert((f, r), 1/p)\n",
" \n",
" compfeedback = []\n",
" compreward = []\n",
" for s in sample:\n",
" f, r = reservoirs[s.item()].sample()\n",
" compfeedback.append(f.unsqueeze(0))\n",
" compreward.append(r.unsqueeze(0))\n",
" compfeedback = torch.cat(compfeedback, dim=0)\n",
" compreward = torch.cat(compreward, dim=0)\n",
" \n",
" if False:\n",
" import matplotlib.pyplot as plt\n",
"\n",
" fig, axs = plt.subplots(1, 10)\n",
" for n, (s, f, r) in enumerate(zip(sample, compfeedback, compreward)):\n",
" if n > 9:\n",
" break\n",
" axs[n].imshow(f.reshape(28, 28))\n",
" axs[n].set_title(f'{s.item()} {r.long().item()}')\n",
" \n",
" plt.show()\n",
" assert False\n",
"\n",
" samplelogit = torch.gather(input=logit, index=sample, dim=1)\n",
" fakelogit = decoder(onehotsample, feedback)\n",
" fakereward = decoder.density(fakelogit)\n",
" fakecomplogit = decoder(onehotsample, compfeedback)\n",
" predloss = torch.mean(log_loss(fakelogit - fakecomplogit, ispred) + log_loss(samplelogit, fakereward.detach()))\n",
" antipredloss = torch.mean(log_loss(fakecomplogit - fakelogit, isantipred) + log_loss(1 - samplelogit, fakereward.detach()))\n",
" loss = torch.min(predloss, antipredloss)\n",
" loss.backward()\n",
" opt.step()\n",
" \n",
" with torch.no_grad():\n",
" acc += torch.mean((labels.unsqueeze(1) == pred).float())\n",
" accsincelast += torch.mean((labels.unsqueeze(1) == pred).float())\n",
" avloss += loss\n",
" avlosssincelast += loss\n",
" avreward += torch.mean(reward)\n",
" avrewardsincelast += torch.mean(reward)\n",
" avfake += torch.mean(fakereward)\n",
" avfakesincelast += torch.mean(fakereward)\n",
" \n",
" if (bno & (bno - 1) == 0):\n",
" print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n",
" avloss.n, avloss.mean(), avlosssincelast.mean(), \n",
" acc.mean(), accsincelast.mean(), \n",
" avreward.mean(), avrewardsincelast.mean(),\n",
" avfake.mean(), avfakesincelast.mean(),\n",
" ),\n",
" flush=True)\n",
" accsincelast, avlosssincelast, avrewardsincelast, avfakesincelast = [ EasyAcc() for _ in range(4) ]\n",
" \n",
" print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n",
" avloss.n, avloss.mean(), avlosssincelast.mean(), \n",
" acc.mean(), accsincelast.mean(), \n",
" avreward.mean(), avrewardsincelast.mean(),\n",
" avfake.mean(), avfakesincelast.mean(),\n",
" ),\n",
" flush=True)\n",
" accsincelast, avlosssincelast, avrewardsincelast = EasyAcc(), EasyAcc(), EasyAcc()\n",
" testacc = EasyAcc()\n",
" with torch.no_grad():\n",
" for ti, tl in train_loader:\n",
" flat = ti.reshape(ti.shape[0], -1)\n",
" logit = pi(flat)\n",
" testpred = logit.argmax(dim=1)\n",
" testacc += torch.mean((tl == testpred).float())\n",
"\n",
" print(f'testacc {testacc.mean()}')\n",
"\n",
"iglADepLearn()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef8c19c9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment