Skip to content

Instantly share code, notes, and snippets.

@pmineiro
Last active March 5, 2022 21:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pmineiro/902b40b3054a77a1e85af6d5ffd469fe to your computer and use it in GitHub Desktop.
Save pmineiro/902b40b3054a77a1e85af6d5ffd469fe to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "f90e9f40",
"metadata": {},
"source": [
"# Util"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "393278c2",
"metadata": {
"code_folding": [
2,
18,
26,
47,
48,
54,
69,
82,
91,
100,
104,
117,
156,
157,
163,
174,
187,
195,
196,
208,
228,
261,
267,
268,
274,
285,
298,
306,
307,
319,
340,
373,
379,
395,
401,
411,
412,
444,
447,
465,
468,
475,
485,
512,
515,
518,
539,
549,
559,
591,
594,
602,
603,
609,
627,
630,
631,
638,
662,
665
]
},
"outputs": [],
"source": [
"from abc import ABC, abstractmethod\n",
"\n",
"class Batch(object):\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" @abstractmethod\n",
" def getContext(self):\n",
" pass\n",
"\n",
" @abstractmethod\n",
" def getFeedback(self, action):\n",
" pass\n",
" \n",
" @abstractmethod\n",
" def getReward(self, action):\n",
" pass\n",
"\n",
"class Simulator(object):\n",
" def __init__(self):\n",
" super().__init__()\n",
" \n",
" @abstractmethod\n",
" def trainIterator(self):\n",
" pass\n",
" \n",
"class MnistSimulator(Simulator):\n",
" def __init__(self, batch_size):\n",
" import torchvision\n",
"\n",
" super().__init__()\n",
" self.batch_size = batch_size\n",
" \n",
" transform = torchvision.transforms.Compose([\n",
" torchvision.transforms.ToTensor(),\n",
" torchvision.transforms.Normalize((0.1307,), (0.3081,))\n",
" ])\n",
" self.mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n",
" \n",
" @abstractmethod\n",
" def computeHilo(self):\n",
" pass\n",
" \n",
" @abstractmethod\n",
" def trainIterator(self):\n",
" pass\n",
" \n",
"class MnistFullCI(MnistSimulator):\n",
" def __init__(self, *, batch_size, decodability):\n",
" super().__init__(batch_size)\n",
" self.decodability = decodability\n",
" \n",
" self._makeFeedbacks()\n",
" \n",
" def _makeFeedbacks(self):\n",
" import torch\n",
" zero_one_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=1, shuffle=True)\n",
" zeros, ones = [], []\n",
" for bno, (images, labels) in enumerate(zero_one_loader):\n",
" flat = images.reshape(images.shape[0], -1)\n",
" if labels[0] == 0:\n",
" zeros.append(flat)\n",
" elif labels[0] == 1:\n",
" ones.append(flat)\n",
"\n",
" if len(zeros) > 100 and len(ones) > 100:\n",
" break \n",
" self.zeros, self.ones = torch.cat(zeros[:100], dim=0), torch.cat(ones[:100], dim=0)\n",
"\n",
" def computeHilo(self):\n",
" import numpy\n",
" import torch\n",
" \n",
" with torch.no_grad():\n",
" quantile_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=10000, shuffle=True)\n",
" for bno, (images, labels) in enumerate(quantile_loader):\n",
" flat = images.view(images.shape[0], -1)\n",
" hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n",
" break\n",
" \n",
" return hilo, hilo\n",
" \n",
" def trainIterator(self):\n",
" import torch\n",
" \n",
" train_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)\n",
"\n",
" for images, labels in train_loader:\n",
" yield self.MyBatch(images, labels, self.zeros, self.ones, self.decodability)\n",
" \n",
" class MyBatch(Batch):\n",
" def __init__(self, images, labels, zeros, ones, decodability):\n",
" super().__init__()\n",
"\n",
" self.images = images\n",
" self.labels = labels\n",
" self.zeros = zeros\n",
" self.ones = ones\n",
" self.decodability = decodability\n",
" \n",
" def getContext(self):\n",
" return self.images.view(self.images.shape[0], -1)\n",
" \n",
" # feedback is an image of 1 if correct else 0\n",
" def getFeedback(self, action):\n",
" import torch\n",
" with torch.no_grad():\n",
" reward = (action == self.labels.unsqueeze(1)).float()\n",
" zerossample = torch.randint(low=0, high=self.zeros.shape[0], size=(action.shape[0], 1))\n",
" goodfeedbacks = torch.gather(input=self.zeros, index=zerossample.expand(-1, self.zeros.shape[1]), dim=0)\n",
" onessample = torch.randint(low=0, high=self.ones.shape[0], size=(action.shape[0], 1))\n",
" badfeedbacks = torch.gather(input=self.ones, index=onessample.expand(-1, self.ones.shape[1]), dim=0)\n",
" noise = torch.rand(size=(action.shape[0], 1), device=action.device)\n",
" shouldflip = (noise <= ((1 + self.decodability)/2)).long()\n",
" noisyreward = reward + shouldflip * (1 - 2 * reward)\n",
" feedback = badfeedbacks + noisyreward * (goodfeedbacks - badfeedbacks)\n",
" \n",
" if False:\n",
" import matplotlib.pyplot as plt\n",
"\n",
" fig, axs = plt.subplots(1, 10)\n",
" for n, (s, f) in enumerate(zip(self.labels, goodfeedbacks)):\n",
" if n > 9:\n",
" break\n",
" axs[n].imshow(f.reshape(28, 28))\n",
" axs[n].set_title(f'{s.item()}')\n",
"\n",
" plt.show()\n",
"\n",
" fig, axs = plt.subplots(1, 10)\n",
" for n, (s, f) in enumerate(zip(self.labels, badfeedbacks)):\n",
" if n > 9:\n",
" break\n",
" axs[n].imshow(f.reshape(28, 28))\n",
" axs[n].set_title(f'{s.item()}')\n",
"\n",
" plt.show()\n",
" \n",
" fig, axs = plt.subplots(1, 10)\n",
" for n, (a, r, nr, f) in enumerate(zip(action, reward, noisyreward, feedback)):\n",
" if n > 9:\n",
" break\n",
" axs[n].imshow(f.reshape(28, 28))\n",
" axs[n].set_title(f'{a.item()} {r.long().item()} {nr.long().item()}')\n",
"\n",
" plt.show()\n",
" assert False\n",
" \n",
" return feedback\n",
" \n",
" def getReward(self, action):\n",
" import torch\n",
" with torch.no_grad():\n",
" reward = (action == self.labels.unsqueeze(1)).float()\n",
" return torch.mean(reward)\n",
"\n",
"class MnistActionCI(MnistSimulator):\n",
" def __init__(self, *, batch_size, decodability):\n",
" super().__init__(batch_size)\n",
" self.decodability = decodability\n",
" \n",
" self._makeFeedbacks()\n",
" \n",
" def _makeFeedbacks(self):\n",
" import torch\n",
" feedback_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=1, shuffle=True)\n",
" feedbacks = [ [] for _ in range(10) ]\n",
" for bno, (images, labels) in enumerate(feedback_loader):\n",
" flat = images.reshape(images.shape[0], -1)\n",
" feedbacks[labels[0]].append(flat)\n",
" if all(len(x) > 100 for x in feedbacks):\n",
" break \n",
" self.feedbacks = torch.cat([ torch.cat(x[:100], dim=0).unsqueeze(0) for x in feedbacks ], dim=0)\n",
"\n",
" def computeHilo(self):\n",
" import numpy\n",
" import torch\n",
" \n",
" with torch.no_grad():\n",
" quantile_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=10000, shuffle=True)\n",
" for bno, (images, labels) in enumerate(quantile_loader):\n",
" flat = images.view(images.shape[0], -1)\n",
" hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n",
" break\n",
" \n",
" return hilo, hilo\n",
" \n",
" def trainIterator(self):\n",
" import torch\n",
" \n",
" train_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)\n",
"\n",
" for images, labels in train_loader:\n",
" yield self.MyBatch(images, labels, self.feedbacks, self.decodability)\n",
" \n",
" class MyBatch(Batch):\n",
" def __init__(self, images, labels, feedbacks, decodability):\n",
" super().__init__()\n",
"\n",
" self.images = images\n",
" self.labels = labels\n",
" self.feedbacks = feedbacks\n",
" self.decodability = decodability\n",
" \n",
" def getContext(self):\n",
" return self.images.view(self.images.shape[0], -1)\n",
" \n",
" # feedback is an image of (x + 1) % 10 if correct else (x - 1) % 10\n",
" def getFeedback(self, action):\n",
" import torch\n",
" with torch.no_grad():\n",
" reward = (action == self.labels.unsqueeze(1)).float()\n",
" pixels = self.getContext().shape[1]\n",
" \n",
" # this assumes a particular majorization (Torch tensors are row-major)\n",
" bigfeedbacks = self.feedbacks.unsqueeze(0).expand(action.shape[0], -1, -1, -1).reshape(action.shape[0], -1, pixels) # Batch x (A x Rep) x Pixels\n",
" goodwhich = self.feedbacks.shape[1] * torch.remainder(self.labels + 1, 10) + torch.randint(low=0, high=self.feedbacks.shape[1], size=(action.shape[0],))\n",
" goodwhich = goodwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, pixels)\n",
" goodfeedbacks = torch.gather(input=bigfeedbacks, index=goodwhich, dim=1).squeeze(1)\n",
" badwhich = self.feedbacks.shape[1] * torch.remainder(self.labels - 1, 10) + torch.randint(low=0, high=self.feedbacks.shape[1], size=(action.shape[0],))\n",
" badwhich = badwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, pixels)\n",
" badfeedbacks = torch.gather(input=bigfeedbacks, index=badwhich, dim=1).squeeze(1)\n",
" \n",
" noise = torch.rand(size=(action.shape[0], 1), device=action.device)\n",
" shouldflip = (noise <= ((1 + self.decodability)/2)).long()\n",
" noisyreward = reward + shouldflip * (1 - 2 * reward)\n",
" feedback = badfeedbacks + noisyreward * (goodfeedbacks - badfeedbacks)\n",
" \n",
" if False:\n",
" import matplotlib.pyplot as plt\n",
" \n",
" fig, axs = plt.subplots(1, 10)\n",
" for n, (s, f) in enumerate(zip(self.labels, goodfeedbacks)):\n",
" if n > 9:\n",
" break\n",
" axs[n].imshow(f.reshape(28, 28))\n",
" axs[n].set_title(f'{s.item()}')\n",
"\n",
" plt.show()\n",
"\n",
" fig, axs = plt.subplots(1, 10)\n",
" for n, (s, f) in enumerate(zip(self.labels, badfeedbacks)):\n",
" if n > 9:\n",
" break\n",
" axs[n].imshow(f.reshape(28, 28))\n",
" axs[n].set_title(f'{s.item()}')\n",
"\n",
" plt.show()\n",
" \n",
" fig, axs = plt.subplots(1, 10)\n",
" for n, (a, r, nr, f) in enumerate(zip(action, reward, noisyreward, feedback)):\n",
" if n > 9:\n",
" break\n",
" axs[n].imshow(f.reshape(28, 28))\n",
" axs[n].set_title(f'{a.item()} {r.long().item()} {nr.long().item()}')\n",
"\n",
" plt.show()\n",
" assert False\n",
" \n",
" return feedback\n",
" \n",
" def getReward(self, action):\n",
" import torch\n",
" with torch.no_grad():\n",
" reward = (action == self.labels.unsqueeze(1)).float()\n",
" return torch.mean(reward)\n",
"\n",
"class MnistContextCI(MnistSimulator):\n",
" def __init__(self, *, batch_size, decodability):\n",
" super().__init__(batch_size)\n",
" self.decodability = decodability\n",
" \n",
" self._makeFeedbacks()\n",
" \n",
" def _makeFeedbacks(self):\n",
" import torch\n",
" feedback_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=1, shuffle=True)\n",
" feedbacks = [ [] for _ in range(10) ]\n",
" for bno, (images, labels) in enumerate(feedback_loader):\n",
" flat = images.reshape(images.shape[0], -1)\n",
" feedbacks[labels[0]].append(flat)\n",
" if all(len(x) > 100 for x in feedbacks):\n",
" break \n",
" self.feedbacks = torch.cat([ torch.cat(x[:100], dim=0).unsqueeze(0) for x in feedbacks ], dim=0)\n",
"\n",
" def computeHilo(self):\n",
" import numpy\n",
" import torch\n",
" \n",
" with torch.no_grad():\n",
" quantile_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=10000, shuffle=True)\n",
" for bno, (images, labels) in enumerate(quantile_loader):\n",
" flat = images.view(images.shape[0], -1)\n",
" hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n",
" break\n",
" \n",
" return hilo, hilo\n",
" \n",
" def trainIterator(self):\n",
" import torch\n",
" \n",
" train_loader = torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)\n",
"\n",
" for images, labels in train_loader:\n",
" yield self.MyBatch(images, labels, self.feedbacks, self.decodability)\n",
" \n",
" class MyBatch(Batch):\n",
" def __init__(self, images, labels, feedbacks, decodability):\n",
" super().__init__()\n",
"\n",
" self.images = images\n",
" self.labels = labels\n",
" self.feedbacks = feedbacks\n",
" self.decodability = decodability\n",
" \n",
" def getContext(self):\n",
" return self.images.view(self.images.shape[0], -1)\n",
" \n",
" # feedback is an image of (a + 1) % 10 if correct else (a - 1) % 10\n",
" def getFeedback(self, action):\n",
" import torch\n",
" with torch.no_grad():\n",
" reward = (action == self.labels.unsqueeze(1)).float()\n",
" pixels = self.getContext().shape[1]\n",
" \n",
" # this assumes a particular majorization (Torch tensors are row-major)\n",
" shortaction = action.squeeze(1)\n",
" bigfeedbacks = self.feedbacks.unsqueeze(0).expand(action.shape[0], -1, -1, -1).reshape(action.shape[0], -1, pixels) # Batch x (A x Rep) x Pixels\n",
" goodwhich = self.feedbacks.shape[1] * torch.remainder(shortaction + 1, 10) + torch.randint(low=0, high=self.feedbacks.shape[1], size=(action.shape[0],))\n",
" goodwhich = goodwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, pixels)\n",
" goodfeedbacks = torch.gather(input=bigfeedbacks, index=goodwhich, dim=1).squeeze(1)\n",
" badwhich = self.feedbacks.shape[1] * torch.remainder(shortaction - 1, 10) + torch.randint(low=0, high=self.feedbacks.shape[1], size=(action.shape[0],))\n",
" badwhich = badwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, pixels)\n",
" badfeedbacks = torch.gather(input=bigfeedbacks, index=badwhich, dim=1).squeeze(1)\n",
" \n",
" noise = torch.rand(size=(action.shape[0], 1), device=action.device)\n",
" shouldflip = (noise <= ((1 + self.decodability)/2)).long()\n",
" noisyreward = reward + shouldflip * (1 - 2 * reward)\n",
" feedback = badfeedbacks + noisyreward * (goodfeedbacks - badfeedbacks)\n",
" \n",
" if False:\n",
" import matplotlib.pyplot as plt\n",
" \n",
" fig, axs = plt.subplots(1, 10)\n",
" for n, (s, f) in enumerate(zip(shortaction, goodfeedbacks)):\n",
" if n > 9:\n",
" break\n",
" axs[n].imshow(f.reshape(28, 28))\n",
" axs[n].set_title(f'{s.item()}')\n",
"\n",
" plt.show()\n",
"\n",
" fig, axs = plt.subplots(1, 10)\n",
" for n, (s, f) in enumerate(zip(shortaction, badfeedbacks)):\n",
" if n > 9:\n",
" break\n",
" axs[n].imshow(f.reshape(28, 28))\n",
" axs[n].set_title(f'{s.item()}')\n",
"\n",
" plt.show()\n",
" \n",
" fig, axs = plt.subplots(1, 10)\n",
" for n, (a, r, nr, f) in enumerate(zip(action, reward, noisyreward, feedback)):\n",
" if n > 9:\n",
" break\n",
" axs[n].imshow(f.reshape(28, 28))\n",
" axs[n].set_title(f'{a.item()} {r.long().item()} {nr.long().item()}')\n",
"\n",
" plt.show()\n",
" assert False\n",
" \n",
" return feedback\n",
" \n",
" def getReward(self, action):\n",
" import torch\n",
" with torch.no_grad():\n",
" reward = (action == self.labels.unsqueeze(1)).float()\n",
" return torch.mean(reward)\n",
"\n",
"class Algorithm(object):\n",
" def __init__(self):\n",
" super().__init__()\n",
" \n",
" @abstractmethod\n",
" def sample(self, x):\n",
" pass\n",
" \n",
" @abstractmethod\n",
" def greedy(self, x):\n",
" pass\n",
" \n",
" @abstractmethod\n",
" def update(self, sample, feedback):\n",
" pass\n",
"\n",
"class Util(object):\n",
" import torch\n",
" \n",
" def __init__(self):\n",
" super().__init__()\n",
" \n",
" class Sample(object):\n",
" def __init__(self, x, action, probs):\n",
" super().__init__()\n",
" self.x = x\n",
" self.action = action\n",
" self.probs = probs\n",
"\n",
" def getAction(self):\n",
" return self.action\n",
" \n",
" class RFFSoftmax(torch.nn.Module):\n",
" def __init__(self, hilo, naction, numrff, sigma):\n",
" from math import pi\n",
" import numpy as np\n",
" import torch\n",
"\n",
" super().__init__()\n",
"\n",
" nobs = hilo.shape[1]\n",
" high = hilo[1, :]\n",
" low = hilo[0, :]\n",
"\n",
" self.rff = torch.nn.Linear(nobs, numrff)\n",
" self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n",
" torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n",
" self.rff.weight.requires_grad = False\n",
" self.rff.bias.data = 2 * pi * torch.rand(numrff)\n",
" self.rff.bias.requires_grad = False\n",
" self.sqrtrff = np.sqrt(numrff)\n",
" self.final = torch.nn.Linear(numrff, naction)\n",
" self.final.weight.data *= 0\n",
" self.final.bias.data *= 0\n",
" self.sigmoid = torch.nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
" import torch\n",
" with torch.no_grad():\n",
" rff = self.rff(x).cos() / self.sqrtrff\n",
" return self.final(rff)\n",
" \n",
" def predict(self, logits):\n",
" return self.sigmoid(logits)\n",
"\n",
"class IKAlgorithm(Algorithm):\n",
" import torch\n",
" \n",
" def __init__(self, *, hilo, sampler, lr):\n",
" from math import log\n",
" import itertools\n",
" import numpy\n",
" import torch\n",
" \n",
" super().__init__()\n",
" self.sampler = sampler\n",
" \n",
" util = Util()\n",
" \n",
" self.pi = util.RFFSoftmax(hilo[0], 10, 2000, 0.01)\n",
" doublehilo = numpy.concatenate((hilo[0], hilo[1]), axis=1)\n",
" self.decoder = util.RFFSoftmax(doublehilo, 10, 2000, 0.01)\n",
" self.alpha = 1/3\n",
" self.decoder.final.bias.data.fill_(log(self.alpha / (1 - self.alpha)))\n",
" self.opt = torch.optim.Adam(( p for p in itertools.chain(self.pi.parameters(), \n",
" self.decoder.parameters()) \n",
" if p.requires_grad ), \n",
" lr=lr)\n",
" \n",
" def sample(self, x):\n",
" import torch\n",
" \n",
" with torch.no_grad():\n",
" fhatlogit = self.pi(x)\n",
" fhat = self.pi.predict(fhatlogit)\n",
" sample, probs = self.sampler.sample(fhat, keepdim=True)\n",
" \n",
" return Util().Sample(x, sample, probs)\n",
"\n",
" def greedy(self, x):\n",
" import torch\n",
" \n",
" with torch.no_grad():\n",
" fhatlogit = self.pi(x)\n",
" _, pred = torch.max(fhatlogit, dim=1, keepdim=True)\n",
" _, anti = torch.min(fhatlogit, dim=1, keepdim=True)\n",
" \n",
" return pred, anti\n",
" \n",
" def update(self, sample, feedback):\n",
" import torch\n",
" \n",
" self.opt.zero_grad()\n",
" \n",
" fhatlogit = self.pi(sample.x)\n",
" \n",
" with torch.no_grad():\n",
" probs = sample.probs\n",
" d = torch.nn.functional.one_hot(sample.action.squeeze(1), num_classes=fhatlogit.shape[1]).float()\n",
" dprobs = self.alpha * d + ((1 - self.alpha) / (fhatlogit.shape[1] - 1)) * (1 - d)\n",
" dweights = (dprobs / probs) / torch.mean(dprobs / probs) \n",
" \n",
" dhatlogit = self.decoder(torch.cat((sample.x, feedback), dim=1))\n",
" dhat_log_loss = torch.nn.BCEWithLogitsLoss(weight=dweights)\n",
" dhat_loss = dhat_log_loss(dhatlogit, d)\n",
" \n",
" with torch.no_grad():\n",
" fakereward = self.decoder.sigmoid(dhatlogit)\n",
" loggedfakereward = torch.gather(input=fakereward, index=sample.action, dim=1)\n",
" \n",
" fhat_log_loss = torch.nn.BCEWithLogitsLoss()\n",
" loggedfhatlogit = torch.gather(input=fhatlogit, index=sample.action, dim=1)\n",
" fhat_loss = fhat_log_loss(loggedfhatlogit, loggedfakereward)\n",
"\n",
" loss = dhat_loss + fhat_loss\n",
" loss.backward()\n",
" \n",
" self.opt.step()\n",
" \n",
" return loss.item(), torch.mean(loggedfakereward)\n",
"\n",
" def __str__(self):\n",
" return f'IKAlgorithm(lr={self.opt.defaults[\"lr\"]} sampler={self.sampler})'\n",
"\n",
"class Sampler(object):\n",
" def __init__(self):\n",
" super().__init__()\n",
" \n",
" @abstractmethod\n",
" def sample(fhat, *, keepdim):\n",
" pass\n",
"\n",
"class EpsilonGreedy(Sampler):\n",
" def __init__(self, *, t0):\n",
" super().__init__()\n",
"\n",
" self.t0 = t0\n",
" self.t = 0\n",
"\n",
" def sample(self, fhat, *, keepdim=False):\n",
" import torch\n",
" N, K = fhat.shape\n",
" epsilon = (self.t0 / (self.t0 + self.t))**(1/3)\n",
" self.t += 1\n",
" fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n",
" rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n",
" unif = torch.rand(size=(N, 1), device=fhat.device)\n",
" shouldexplore = (unif <= epsilon).long()\n",
" actions = ahatstar + shouldexplore * (rando - ahatstar)\n",
" phatstar = (1 - epsilon) + epsilon / K\n",
" prando = epsilon / K\n",
" pactions = phatstar + shouldexplore * (prando - phatstar)\n",
" if not keepdim:\n",
" actions = actions.squeeze(1)\n",
" pactions = pactions.squeeze(1)\n",
" return actions, pactions\n",
" \n",
" def __str__(self):\n",
" return f'EpsilonGreedy(t0={self.t0})'\n",
"\n",
"class SquareCB(Sampler):\n",
" def __init__(self, *, t0, gamma0):\n",
" super().__init__()\n",
"\n",
" self.t0 = t0\n",
" self.gamma0 = gamma0\n",
" self.t = 0\n",
"\n",
" def sample(self, fhat, *, keepdim=False):\n",
" import torch\n",
" \n",
" self.t += 1\n",
" gamma = self.gamma0 * ((self.t0 + self.t) / self.t0)**(1/2)\n",
" \n",
" N, K = fhat.shape\n",
" fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n",
" probs = 1 / (K + gamma * (fhatstar - fhat))\n",
" #probs = (1 - fhat) / (K * (1 - fhat) + gamma * (fhatstar - fhat))\n",
" psum = torch.sum(probs, dim=1, keepdim=True)\n",
" phatstar = (1 - psum) + torch.gather(input=probs, dim=1, index=ahatstar)\n",
"\n",
" rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n",
" prando = torch.gather(input=probs, dim=1, index=rando)\n",
" unif = torch.rand(size=(N, 1), device=fhat.device)\n",
" shouldexplore = (unif <= K * prando).long()\n",
" actions = ahatstar + shouldexplore * (rando - ahatstar)\n",
" pactions = phatstar + shouldexplore * (prando - phatstar)\n",
" if not keepdim:\n",
" actions = actions.squeeze(1)\n",
" pactions = pactions.squeeze(1)\n",
" return actions, pactions\n",
" \n",
" def __str__(self):\n",
" return f'SquareCB(gamma0={self.gamma0} t0={self.t0})'\n",
"\n",
"def run_sim_helper(*, passes, simulator, algorithm):\n",
" import itertools\n",
" \n",
" class EasyAcc:\n",
" def __init__(self):\n",
" self.n = 0\n",
" self.sum = 0\n",
" self.sumsq = 0\n",
"\n",
" def __iadd__(self, other):\n",
" self.n += 1\n",
" self.sum += other\n",
" self.sumsq += other*other\n",
" return self\n",
"\n",
" def __isub__(self, other):\n",
" self.n += 1\n",
" self.sum -= other\n",
" self.sumsq += other*other\n",
" return self\n",
"\n",
" def mean(self):\n",
" return self.sum / max(self.n, 1)\n",
"\n",
" def var(self):\n",
" from math import sqrt\n",
" return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n",
"\n",
" def semean(self):\n",
" from math import sqrt\n",
" return self.var() / sqrt(max(self.n, 1))\n",
"\n",
" print('{:<5s}\\t{:<8s} {:<8s}\\t{:<8s} {:<8s}\\t{:<8s} {:<8s}\\t{:<8s} {:<8s}\\t{:<8s} {:<8s}'.format(\n",
" 'bno', \n",
" 'loss', 'since', \n",
" 'pred', 'since',\n",
" 'anti', 'since',\n",
" 'reward', 'since',\n",
" 'fake', 'since',\n",
" ),\n",
" flush=True)\n",
" \n",
" avloss, avreward, avfakereward, avpred, avanti = [ EasyAcc() for _ in range(5) ]\n",
" avlosssl, avrewardsl, avfakerewardsl, avpredsl, avantisl = [ EasyAcc() for _ in range(5) ]\n",
" \n",
" for bno, batch in enumerate(itertools.chain(*[ simulator.trainIterator() for _ in range(passes) ])):\n",
" x = batch.getContext()\n",
" sample = algorithm.sample(x)\n",
" feedback = batch.getFeedback(sample.getAction())\n",
" loss, fakereward = algorithm.update(sample, feedback)\n",
" \n",
" avloss += loss\n",
" avlosssl += loss\n",
"\n",
" avfakereward += fakereward\n",
" avfakerewardsl += fakereward\n",
" \n",
" reward = batch.getReward(sample.getAction())\n",
" pred, anti = algorithm.greedy(x)\n",
" predreward = batch.getReward(pred)\n",
" antireward = batch.getReward(anti)\n",
" \n",
" avreward += reward\n",
" avrewardsl += reward\n",
" avpred += predreward\n",
" avpredsl += predreward \n",
" avanti += antireward\n",
" avantisl += antireward\n",
" \n",
" if (bno & (bno - 1) == 0):\n",
" print('{:<5d}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}'.format(\n",
" bno, \n",
" avloss.mean(), avlosssl.mean(),\n",
" avpred.mean(), avpredsl.mean(),\n",
" avanti.mean(), avantisl.mean(),\n",
" avreward.mean(), avrewardsl.mean(),\n",
" avfakereward.mean(), avfakerewardsl.mean(),\n",
" ),\n",
" flush=True)\n",
" avlosssl, avrewardsl, avfakerewardsl, avpredsl, avantisl = [ EasyAcc() for _ in range(5) ]\n",
"\n",
" print('{:<5d}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}\\t{:<8.5f} {:<8.5f}'.format(\n",
" bno, \n",
" avloss.mean(), avlosssl.mean(),\n",
" avpred.mean(), avpredsl.mean(),\n",
" avanti.mean(), avantisl.mean(),\n",
" avreward.mean(), avrewardsl.mean(),\n",
" avfakereward.mean(), avfakerewardsl.mean(),\n",
" ),\n",
" flush=True)"
]
},
{
"cell_type": "markdown",
"id": "d1a9fb44",
"metadata": {},
"source": [
"# Mnist Action CI $(y_a \\perp a | r_a, x)$\n",
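"Feedback is an image of the digit $(y - 1) \\mod 10$ if the action is correct and of the digit $(y + 1) \\mod 10$ if it is incorrect (with decodability $1$), where $y$ is the true label of the context image $x$. Negative decodability reverses this mapping, and for fractional decodability $d$ the mapping corresponding to the sign of $d$ is used with probability $(1 + |d|)/2$."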
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "fd88f658",
"metadata": {
"code_folding": [
0
]
},
"outputs": [],
"source": [
"def run_mnist_action_ci(*, seed, alglambda):\n",
" import torch\n",
" print(f'***** seed = {seed} *****')\n",
" torch.manual_seed(seed)\n",
" hilo = MnistActionCI(batch_size=64, decodability=1).computeHilo()\n",
" \n",
" for decodability in (1, -1, 0.5, -0.5,):\n",
" import copy\n",
" torch.manual_seed(seed)\n",
" sim = MnistActionCI(batch_size=64, decodability=decodability)\n",
" print(f'***** decodability = {decodability} *****')\n",
" with torch.random.fork_rng():\n",
" alg = alglambda(hilo)\n",
" print(f'***** alg = {alg} *****')\n",
" with torch.random.fork_rng():\n",
" run_sim_helper(passes=1, simulator=copy.deepcopy(sim), algorithm=alg)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "5e596247",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"***** seed = 13 *****\n",
"***** decodability = 1 *****\n",
"***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n",
"bno \tloss since \tpred since \tanti since \treward since \tfake since \n",
"0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n",
"1 \t1.30444 1.27921 \t0.08594 0.06250 \t0.10938 0.12500 \t0.11719 0.12500 \t0.33463 0.33592 \n",
"2 \t1.30393 1.30293 \t0.06771 0.03125 \t0.11979 0.14062 \t0.08854 0.03125 \t0.33477 0.33506 \n",
"4 \t1.30399 1.30408 \t0.06562 0.06250 \t0.11250 0.10156 \t0.09687 0.10938 \t0.33652 0.33914 \n",
"8 \t1.29672 1.28764 \t0.08160 0.10156 \t0.09549 0.07422 \t0.10590 0.11719 \t0.34312 0.35137 \n",
"16 \t1.29588 1.29494 \t0.11121 0.14453 \t0.09099 0.08594 \t0.11029 0.11523 \t0.35378 0.36579 \n",
"32 \t1.29838 1.30103 \t0.15436 0.20020 \t0.08523 0.07910 \t0.13968 0.17090 \t0.38594 0.42010 \n",
"64 \t1.30807 1.31806 \t0.17596 0.19824 \t0.07067 0.05566 \t0.16226 0.18555 \t0.41429 0.44352 \n",
"128 \t1.30980 1.31156 \t0.20591 0.23633 \t0.06202 0.05322 \t0.18302 0.20410 \t0.43660 0.45927 \n",
"256 \t1.31787 1.32600 \t0.26204 0.31860 \t0.05052 0.03894 \t0.22781 0.27295 \t0.45920 0.48197 \n",
"512 \t1.32409 1.33034 \t0.30431 0.34674 \t0.04173 0.03290 \t0.26687 0.30609 \t0.49435 0.52965 \n",
"937 \t1.33276 1.34322 \t0.39582 0.50629 \t0.03077 0.01754 \t0.35419 0.45960 \t0.52439 0.56064 \n",
"***** decodability = -1 *****\n",
"***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n",
"bno \tloss since \tpred since \tanti since \treward since \tfake since \n",
"0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n",
"1 \t1.30487 1.28007 \t0.09375 0.07812 \t0.11719 0.14062 \t0.11719 0.12500 \t0.33465 0.33597 \n",
"2 \t1.30430 1.30317 \t0.07812 0.04688 \t0.11979 0.12500 \t0.08854 0.03125 \t0.33357 0.33142 \n",
"4 \t1.30306 1.30120 \t0.11250 0.16406 \t0.11250 0.10156 \t0.10000 0.11719 \t0.33489 0.33687 \n",
"8 \t1.29671 1.28877 \t0.10243 0.08984 \t0.09375 0.07031 \t0.11458 0.13281 \t0.34906 0.36677 \n",
"16 \t1.29965 1.30295 \t0.11581 0.13086 \t0.09007 0.08594 \t0.11857 0.12305 \t0.38112 0.41718 \n",
"32 \t1.30370 1.30800 \t0.14205 0.16992 \t0.08002 0.06934 \t0.13258 0.14746 \t0.39104 0.40158 \n",
"64 \t1.30846 1.31338 \t0.15986 0.17822 \t0.06731 0.05420 \t0.14543 0.15869 \t0.40492 0.41923 \n",
"128 \t1.31436 1.32034 \t0.19077 0.22217 \t0.05790 0.04834 \t0.16982 0.19458 \t0.43110 0.45769 \n",
"256 \t1.32377 1.33325 \t0.21948 0.24841 \t0.05235 0.04675 \t0.19303 0.21643 \t0.45743 0.48396 \n",
"512 \t1.33082 1.33791 \t0.29898 0.37878 \t0.04075 0.02911 \t0.26495 0.33716 \t0.48782 0.51834 \n",
"937 \t1.33090 1.33100 \t0.39592 0.51294 \t0.03305 0.02375 \t0.35553 0.46485 \t0.52269 0.56478 \n",
"***** decodability = 0.5 *****\n",
"***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n",
"bno \tloss since \tpred since \tanti since \treward since \tfake since \n",
"0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n",
"1 \t1.30485 1.28004 \t0.09375 0.07812 \t0.11719 0.14062 \t0.11719 0.12500 \t0.33435 0.33537 \n",
"2 \t1.30426 1.30308 \t0.07812 0.04688 \t0.13542 0.17188 \t0.08854 0.03125 \t0.33383 0.33279 \n",
"4 \t1.30349 1.30235 \t0.08438 0.09375 \t0.12188 0.10156 \t0.09687 0.10938 \t0.33555 0.33813 \n",
"8 \t1.29700 1.28888 \t0.09201 0.10156 \t0.09896 0.07031 \t0.10590 0.11719 \t0.34245 0.35107 \n",
"16 \t1.29599 1.29486 \t0.10662 0.12305 \t0.10386 0.10938 \t0.10938 0.11328 \t0.35196 0.36266 \n",
"32 \t1.30241 1.30924 \t0.13068 0.15625 \t0.09991 0.09570 \t0.12926 0.15039 \t0.37391 0.39723 \n",
"64 \t1.31644 1.33092 \t0.11827 0.10547 \t0.08702 0.07373 \t0.12043 0.11133 \t0.39610 0.41900 \n",
"128 \t1.32278 1.32921 \t0.12718 0.13623 \t0.08006 0.07300 \t0.12476 0.12915 \t0.41618 0.43656 \n",
"256 \t1.33804 1.35343 \t0.12239 0.11755 \t0.07551 0.07092 \t0.11977 0.11475 \t0.43168 0.44731 \n",
"512 \t1.35660 1.37523 \t0.12601 0.12964 \t0.07657 0.07764 \t0.12031 0.12085 \t0.45100 0.47039 \n",
"937 \t1.37982 1.40785 \t0.14880 0.17632 \t0.06935 0.06063 \t0.14169 0.16750 \t0.45475 0.45927 \n",
"***** decodability = -0.5 *****\n",
"***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n",
"bno \tloss since \tpred since \tanti since \treward since \tfake since \n",
"0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n",
"1 \t1.30532 1.28098 \t0.10156 0.09375 \t0.11719 0.14062 \t0.11719 0.12500 \t0.33389 0.33444 \n",
"2 \t1.30427 1.30217 \t0.08333 0.04688 \t0.11458 0.10938 \t0.08854 0.03125 \t0.33239 0.32940 \n",
"4 \t1.30410 1.30385 \t0.08438 0.08594 \t0.10938 0.10156 \t0.09375 0.10156 \t0.33464 0.33801 \n",
"8 \t1.29698 1.28807 \t0.09201 0.10156 \t0.09201 0.07031 \t0.10243 0.11328 \t0.34187 0.35091 \n",
"16 \t1.30001 1.30342 \t0.10662 0.12305 \t0.08272 0.07227 \t0.10754 0.11328 \t0.36231 0.38530 \n",
"32 \t1.30547 1.31126 \t0.13210 0.15918 \t0.07812 0.07324 \t0.12311 0.13965 \t0.37633 0.39124 \n",
"64 \t1.31463 1.32407 \t0.12476 0.11719 \t0.07548 0.07275 \t0.12308 0.12305 \t0.39423 0.41269 \n",
"128 \t1.32324 1.33199 \t0.12270 0.12061 \t0.08031 0.08521 \t0.11664 0.11011 \t0.41592 0.43795 \n",
"256 \t1.33952 1.35593 \t0.12044 0.11816 \t0.08208 0.08386 \t0.11278 0.10889 \t0.42968 0.44354 \n",
"512 \t1.35598 1.37249 \t0.11882 0.11719 \t0.07782 0.07355 \t0.11355 0.11432 \t0.44759 0.46556 \n",
"937 \t1.37896 1.40669 \t0.13551 0.15566 \t0.07348 0.06824 \t0.12883 0.14728 \t0.45464 0.46314 \n"
]
}
],
"source": [
"run_mnist_action_ci(seed=13, alglambda = lambda hilo: IKAlgorithm(hilo=hilo, sampler=EpsilonGreedy(t0=1), lr=5e-2))"
]
},
{
"cell_type": "markdown",
"id": "25aaf5cf",
"metadata": {},
"source": [
"# Mnist Full CI ($y_a \\perp x, a | r_a$)\n",
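"Feedback is an image of the digit \"1\" if the action is correct and of the digit \"0\" if it is incorrect (with decodability $1$). Negative decodability reverses this mapping, and for fractional decodability $d$ the mapping corresponding to the sign of $d$ is used with probability $(1 + |d|)/2$."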
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "be02f211",
"metadata": {
"code_folding": [
0
]
},
"outputs": [],
"source": [
"def run_mnist_full_ci(*, seed, alglambda):\n",
" import torch\n",
" print(f'***** seed = {seed} *****')\n",
" torch.manual_seed(seed)\n",
" hilo = MnistFullCI(batch_size=64, decodability=1).computeHilo()\n",
" \n",
" for decodability in (1, -1, 0.5, -0.5,):\n",
" import copy\n",
" torch.manual_seed(seed)\n",
" sim = MnistFullCI(batch_size=64, decodability=decodability)\n",
" print(f'***** decodability = {decodability} *****')\n",
" with torch.random.fork_rng():\n",
" alg = alglambda(hilo)\n",
" print(f'***** alg = {alg} *****')\n",
" with torch.random.fork_rng():\n",
" run_sim_helper(passes=1, simulator=copy.deepcopy(sim), algorithm=alg)"
]
},
{
"cell_type": "markdown",
"id": "ce18fc54",
"metadata": {
"heading_collapsed": true
},
"source": [
"## IK, Epsilon-Greedy"
]
},
{
"cell_type": "code",
"execution_count": 120,
"id": "12c6fc3c",
"metadata": {
"hidden": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"***** seed = 13 *****\n",
"***** decodability = 1 *****\n",
"***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n",
"bno \tloss since \tpred since \tanti since \treward since \tfake since \n",
"0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n",
"1 \t1.30507 1.28049 \t0.09375 0.07812 \t0.12500 0.15625 \t0.11719 0.12500 \t0.33461 0.33589 \n",
"2 \t1.30434 1.30288 \t0.08333 0.06250 \t0.13021 0.14062 \t0.09375 0.04688 \t0.33435 0.33383 \n",
"4 \t1.30183 1.29805 \t0.08750 0.09375 \t0.11875 0.10156 \t0.10000 0.10938 \t0.33558 0.33743 \n",
"8 \t1.29784 1.29286 \t0.08854 0.08984 \t0.09722 0.07031 \t0.10590 0.11328 \t0.34345 0.35328 \n",
"16 \t1.30046 1.30342 \t0.13695 0.19141 \t0.09743 0.09766 \t0.12040 0.13672 \t0.35886 0.37619 \n",
"32 \t1.30231 1.30427 \t0.16998 0.20508 \t0.09470 0.09180 \t0.14867 0.17871 \t0.37876 0.39991 \n",
"64 \t1.31192 1.32184 \t0.21274 0.25684 \t0.07933 0.06348 \t0.18293 0.21826 \t0.40531 0.43269 \n",
"128 \t1.31666 1.32146 \t0.25787 0.30371 \t0.06831 0.05713 \t0.22081 0.25928 \t0.43104 0.45718 \n",
"256 \t1.32381 1.33101 \t0.29754 0.33752 \t0.05952 0.05066 \t0.25517 0.28979 \t0.45657 0.48229 \n",
"512 \t1.32461 1.32541 \t0.36528 0.43329 \t0.05172 0.04388 \t0.32002 0.38513 \t0.49113 0.52582 \n",
"937 \t1.32471 1.32484 \t0.45266 0.55813 \t0.03711 0.01949 \t0.40302 0.50320 \t0.52373 0.56309 \n",
"***** decodability = -1 *****\n",
"***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n",
"bno \tloss since \tpred since \tanti since \treward since \tfake since \n",
"0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n",
"1 \t1.30546 1.28126 \t0.07812 0.04688 \t0.11719 0.14062 \t0.11719 0.12500 \t0.33346 0.33360 \n",
"2 \t1.30742 1.31135 \t0.08854 0.10938 \t0.11979 0.12500 \t0.09375 0.04688 \t0.33497 0.33798 \n",
"4 \t1.30894 1.31121 \t0.07500 0.05469 \t0.10938 0.09375 \t0.09687 0.10156 \t0.34675 0.36442 \n",
"8 \t1.30966 1.31057 \t0.09549 0.12109 \t0.09375 0.07422 \t0.10417 0.11328 \t0.35984 0.37621 \n",
"16 \t1.30897 1.30820 \t0.10570 0.11719 \t0.09835 0.10352 \t0.11121 0.11914 \t0.35881 0.35765 \n",
"32 \t1.31093 1.31301 \t0.09991 0.09375 \t0.10701 0.11621 \t0.10511 0.09863 \t0.39271 0.42872 \n",
"64 \t1.32061 1.33060 \t0.10577 0.11182 \t0.12812 0.14990 \t0.11322 0.12158 \t0.41979 0.44773 \n",
"128 \t1.32243 1.32428 \t0.09460 0.08325 \t0.13227 0.13647 \t0.09629 0.07910 \t0.44220 0.46496 \n",
"256 \t1.33451 1.34668 \t0.10372 0.11292 \t0.12099 0.10962 \t0.10044 0.10461 \t0.46115 0.48024 \n",
"512 \t1.34569 1.35691 \t0.12850 0.15338 \t0.10304 0.08502 \t0.12141 0.14246 \t0.48409 0.50712 \n",
"937 \t1.36260 1.38302 \t0.18943 0.26298 \t0.09245 0.07967 \t0.17769 0.24563 \t0.50181 0.52320 \n",
"***** decodability = 0.5 *****\n",
"***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n",
"bno \tloss since \tpred since \tanti since \treward since \tfake since \n",
"0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n",
"1 \t1.30474 1.27983 \t0.09375 0.07812 \t0.12500 0.15625 \t0.11719 0.12500 \t0.33459 0.33585 \n",
"2 \t1.30399 1.30249 \t0.07292 0.03125 \t0.13542 0.15625 \t0.08854 0.03125 \t0.33390 0.33250 \n",
"4 \t1.30216 1.29940 \t0.08750 0.10938 \t0.12188 0.10156 \t0.10000 0.11719 \t0.33365 0.33329 \n",
"8 \t1.29573 1.28770 \t0.09722 0.10938 \t0.10069 0.07422 \t0.10243 0.10547 \t0.34649 0.36253 \n",
"16 \t1.30262 1.31038 \t0.12776 0.16211 \t0.09743 0.09375 \t0.12040 0.14062 \t0.36855 0.39337 \n",
"32 \t1.30842 1.31457 \t0.14015 0.15332 \t0.10275 0.10840 \t0.12879 0.13770 \t0.37469 0.38122 \n",
"64 \t1.31988 1.33170 \t0.13534 0.13037 \t0.08702 0.07080 \t0.13197 0.13525 \t0.39178 0.40940 \n",
"128 \t1.32899 1.33825 \t0.15540 0.17578 \t0.07982 0.07251 \t0.14462 0.15747 \t0.40806 0.42460 \n",
"256 \t1.34702 1.36519 \t0.14105 0.12659 \t0.08554 0.09131 \t0.13132 0.11792 \t0.42628 0.44465 \n",
"512 \t1.36833 1.38972 \t0.15278 0.16455 \t0.07392 0.06226 \t0.14215 0.15302 \t0.44862 0.47105 \n",
"937 \t1.39125 1.41892 \t0.17161 0.19434 \t0.06875 0.06250 \t0.16081 0.18335 \t0.45148 0.45493 \n",
"***** decodability = -0.5 *****\n",
"***** alg = IKAlgorithm(lr=0.05 sampler=EpsilonGreedy(t0=1)) *****\n",
"bno \tloss since \tpred since \tanti since \treward since \tfake since \n",
"0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n",
"1 \t1.30487 1.28009 \t0.09375 0.07812 \t0.10156 0.10938 \t0.11719 0.12500 \t0.33412 0.33491 \n",
"2 \t1.30595 1.30810 \t0.07812 0.04688 \t0.11458 0.14062 \t0.08854 0.03125 \t0.33473 0.33595 \n",
"4 \t1.30764 1.31018 \t0.06250 0.03906 \t0.10938 0.10156 \t0.09375 0.10156 \t0.33889 0.34512 \n",
"8 \t1.30273 1.29660 \t0.08160 0.10547 \t0.09375 0.07422 \t0.10243 0.11328 \t0.34147 0.34469 \n",
"16 \t1.30562 1.30887 \t0.08548 0.08984 \t0.08915 0.08398 \t0.10386 0.10547 \t0.34891 0.35729 \n",
"32 \t1.31029 1.31525 \t0.07955 0.07324 \t0.10417 0.12012 \t0.09659 0.08887 \t0.36639 0.38495 \n",
"64 \t1.32097 1.33199 \t0.07187 0.06396 \t0.11010 0.11621 \t0.08918 0.08154 \t0.40254 0.43981 \n",
"128 \t1.32663 1.33237 \t0.07001 0.06812 \t0.11810 0.12622 \t0.08043 0.07153 \t0.42110 0.43996 \n",
"256 \t1.34333 1.36015 \t0.06980 0.06958 \t0.11570 0.11328 \t0.07496 0.06946 \t0.44035 0.45975 \n",
"512 \t1.36080 1.37834 \t0.08114 0.09253 \t0.11507 0.11444 \t0.08373 0.09253 \t0.45726 0.47424 \n",
"937 \t1.38823 1.42134 \t0.09522 0.11221 \t0.10439 0.09151 \t0.09637 0.11162 \t0.46219 0.46815 \n"
]
}
],
"source": [
"run_mnist_full_ci(\n",
"    seed=13,\n",
"    alglambda=lambda hilo: IKAlgorithm(hilo=hilo, sampler=EpsilonGreedy(t0=1), lr=5e-2),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5ffdb723",
"metadata": {
"heading_collapsed": true
},
"source": [
"## IK, SquareCB"
]
},
{
"cell_type": "code",
"execution_count": 101,
"id": "11219b00",
"metadata": {
"hidden": true,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"***** seed = 13 *****\n",
"***** decodability = 1 *****\n",
"***** alg = IKAlgorithm(lr=0.05 sampler=SquareCB(gamma0=10 t0=10)) *****\n",
"bno \tloss since \tpred since \tanti since \treward since \tfake since \n",
"0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n",
"1 \t1.30494 1.28021 \t0.09375 0.07812 \t0.13281 0.17188 \t0.10938 0.10938 \t0.33776 0.34219 \n",
"2 \t1.30257 1.29784 \t0.08854 0.07812 \t0.14062 0.15625 \t0.08333 0.03125 \t0.33513 0.32985 \n",
"4 \t1.30280 1.30315 \t0.10312 0.12500 \t0.12500 0.10156 \t0.09687 0.11719 \t0.33583 0.33688 \n",
"8 \t1.29567 1.28675 \t0.12153 0.14453 \t0.10069 0.07031 \t0.11458 0.13672 \t0.33917 0.34336 \n",
"16 \t1.29371 1.29151 \t0.16085 0.20508 \t0.10938 0.11914 \t0.12592 0.13867 \t0.34092 0.34287 \n",
"32 \t1.29202 1.29022 \t0.21070 0.26367 \t0.10890 0.10840 \t0.15152 0.17871 \t0.35087 0.36145 \n",
"64 \t1.29390 1.29584 \t0.25697 0.30469 \t0.09207 0.07471 \t0.16779 0.18457 \t0.36396 0.37746 \n",
"128 \t1.29290 1.29189 \t0.34969 0.44385 \t0.06480 0.03711 \t0.21427 0.26147 \t0.38296 0.40226 \n",
"256 \t1.28826 1.28358 \t0.43020 0.51135 \t0.04596 0.02698 \t0.27243 0.33105 \t0.41878 0.45488 \n",
"512 \t1.28455 1.28083 \t0.50433 0.57874 \t0.03326 0.02051 \t0.34957 0.42700 \t0.46282 0.50702 \n",
"937 \t1.27793 1.26993 \t0.57862 0.66831 \t0.02434 0.01357 \t0.43420 0.53636 \t0.50870 0.56408 \n",
"***** decodability = -1 *****\n",
"***** alg = IKAlgorithm(lr=0.05 sampler=SquareCB(gamma0=10 t0=10)) *****\n",
"bno \tloss since \tpred since \tanti since \treward since \tfake since \n",
"0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n",
"1 \t1.30513 1.28060 \t0.08594 0.06250 \t0.10938 0.12500 \t0.10938 0.10938 \t0.33781 0.34229 \n",
"2 \t1.30370 1.30084 \t0.09375 0.10938 \t0.11458 0.12500 \t0.08333 0.03125 \t0.33464 0.32829 \n",
"4 \t1.30378 1.30391 \t0.08125 0.06250 \t0.10938 0.10156 \t0.09687 0.11719 \t0.33765 0.34217 \n",
"8 \t1.29949 1.29413 \t0.09028 0.10156 \t0.09028 0.06641 \t0.10590 0.11719 \t0.34235 0.34822 \n",
"16 \t1.29919 1.29884 \t0.10202 0.11523 \t0.10294 0.11719 \t0.11857 0.13281 \t0.34206 0.34174 \n",
"32 \t1.29828 1.29731 \t0.11932 0.13770 \t0.12074 0.13965 \t0.12358 0.12891 \t0.34872 0.35580 \n",
"64 \t1.29811 1.29794 \t0.12212 0.12500 \t0.13197 0.14355 \t0.11731 0.11084 \t0.36252 0.37675 \n",
"128 \t1.30163 1.30521 \t0.12875 0.13550 \t0.10938 0.08643 \t0.11810 0.11890 \t0.37397 0.38560 \n",
"256 \t1.30894 1.31631 \t0.16786 0.20728 \t0.08810 0.06665 \t0.14014 0.16235 \t0.39465 0.41550 \n",
"512 \t1.31513 1.32135 \t0.24050 0.31342 \t0.06768 0.04718 \t0.19219 0.24445 \t0.42662 0.45871 \n",
"937 \t1.31834 1.32222 \t0.31413 0.40301 \t0.05549 0.04077 \t0.25568 0.33232 \t0.46456 0.51035 \n",
"***** decodability = 0.5 *****\n",
"***** alg = IKAlgorithm(lr=0.05 sampler=SquareCB(gamma0=10 t0=10)) *****\n",
"bno \tloss since \tpred since \tanti since \treward since \tfake since \n",
"0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n",
"1 \t1.30494 1.28022 \t0.09375 0.07812 \t0.14062 0.18750 \t0.10938 0.10938 \t0.33699 0.34065 \n",
"2 \t1.30267 1.29813 \t0.09375 0.09375 \t0.14062 0.14062 \t0.08333 0.03125 \t0.33467 0.33003 \n",
"4 \t1.30326 1.30416 \t0.11250 0.14062 \t0.12500 0.10156 \t0.09687 0.11719 \t0.33453 0.33433 \n",
"8 \t1.29485 1.28433 \t0.12500 0.14062 \t0.10243 0.07422 \t0.11458 0.13672 \t0.34006 0.34697 \n",
"16 \t1.29546 1.29614 \t0.14522 0.16797 \t0.11213 0.12305 \t0.12132 0.12891 \t0.33862 0.33701 \n",
"32 \t1.29664 1.29790 \t0.16146 0.17871 \t0.09517 0.07715 \t0.13494 0.14941 \t0.34379 0.34927 \n",
"64 \t1.29866 1.30075 \t0.15986 0.15820 \t0.08053 0.06543 \t0.12933 0.12354 \t0.34898 0.35433 \n",
"128 \t1.30344 1.30830 \t0.19634 0.23340 \t0.06734 0.05396 \t0.13929 0.14941 \t0.35562 0.36236 \n",
"256 \t1.31659 1.32984 \t0.19163 0.18689 \t0.06463 0.06189 \t0.13831 0.13733 \t0.36600 0.37646 \n",
"512 \t1.33155 1.34657 \t0.21040 0.22925 \t0.05781 0.05096 \t0.15850 0.17877 \t0.37977 0.39359 \n",
"937 \t1.34935 1.37084 \t0.21973 0.23099 \t0.05574 0.05324 \t0.17377 0.19221 \t0.39498 0.41335 \n",
"***** decodability = -0.5 *****\n",
"***** alg = IKAlgorithm(lr=0.05 sampler=SquareCB(gamma0=10 t0=10)) *****\n",
"bno \tloss since \tpred since \tanti since \treward since \tfake since \n",
"0 \t1.32966 1.32966 \t0.10938 0.10938 \t0.09375 0.09375 \t0.10938 0.10938 \t0.33333 0.33333 \n",
"1 \t1.30492 1.28018 \t0.09375 0.07812 \t0.08594 0.07812 \t0.10938 0.10938 \t0.33714 0.34094 \n",
"2 \t1.30358 1.30090 \t0.08333 0.06250 \t0.09896 0.12500 \t0.08333 0.03125 \t0.33515 0.33118 \n",
"4 \t1.30400 1.30463 \t0.08438 0.08594 \t0.10000 0.10156 \t0.10000 0.12500 \t0.33447 0.33346 \n",
"8 \t1.29777 1.28998 \t0.09722 0.11328 \t0.08507 0.06641 \t0.10764 0.11719 \t0.33946 0.34568 \n",
"16 \t1.29687 1.29587 \t0.09743 0.09766 \t0.08548 0.08594 \t0.11489 0.12305 \t0.33790 0.33615 \n",
"32 \t1.29696 1.29706 \t0.11222 0.12793 \t0.09612 0.10742 \t0.12311 0.13184 \t0.34567 0.35394 \n",
"64 \t1.30000 1.30314 \t0.10481 0.09717 \t0.09832 0.10059 \t0.11034 0.09717 \t0.35347 0.36151 \n",
"128 \t1.30487 1.30981 \t0.09448 0.08398 \t0.09278 0.08716 \t0.10514 0.09985 \t0.35846 0.36353 \n",
"256 \t1.31660 1.32842 \t0.10171 0.10901 \t0.09168 0.09058 \t0.10445 0.10376 \t0.37122 0.38409 \n",
"512 \t1.33261 1.34868 \t0.12189 0.14215 \t0.08784 0.08398 \t0.11364 0.12286 \t0.38679 0.40243 \n",
"937 \t1.34934 1.36953 \t0.13471 0.15018 \t0.08136 0.07353 \t0.12328 0.13493 \t0.39942 0.41465 \n"
]
}
],
"source": [
"run_mnist_full_ci(\n",
"    seed=13,\n",
"    alglambda=lambda hilo: IKAlgorithm(hilo=hilo, sampler=SquareCB(gamma0=10, t0=10), lr=5e-2),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c587b93",
"metadata": {
"hidden": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment