Skip to content

Instantly share code, notes, and snippets.

@iczero
Created March 10, 2024 00:21
Show Gist options
  • Save iczero/b778d0a77573d19d4f4e61b3b9a01b84 to your computer and use it in GitHub Desktop.
Save iczero/b778d0a77573d19d4f4e61b3b9a01b84 to your computer and use it in GitHub Desktop.
FizzBuzz in PyTorch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "3cc970e3-a4e2-40ae-a4da-db15abcd550c",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import numpy as np\n",
"from tqdm.notebook import trange\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ea5f1ab4",
"metadata": {},
"outputs": [],
"source": [
"# Prefer the GPU, but fall back to the CPU so the notebook still runs on\n",
"# machines without CUDA.\n",
"torch.set_default_device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"# dtype that encoded inputs and training targets are cast to\n",
"model_dtype = torch.float32\n",
"\n",
"# divisors whose multiples are labelled 'fizz' and 'buzz' respectively\n",
"FIZZ = 3\n",
"BUZZ = 5"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7b1f6362-de4b-4e09-84dc-7ccff11cd2be",
"metadata": {},
"outputs": [],
"source": [
"class Network(nn.Module):\n",
"    \"\"\"Fully connected net: 32 inputs -> 1024 -> 256 -> 3 outputs in (0, 1).\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.stack = nn.Sequential(\n",
"            nn.Linear(32, 1024),\n",
"            nn.Tanh(),\n",
"            nn.Linear(1024, 256),\n",
"            nn.Tanh(),\n",
"            nn.Linear(256, 3),\n",
"            # sigmoid keeps each output in (0, 1), which nn.BCELoss expects\n",
"            nn.Sigmoid(),\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Run x through the layer stack and return its output.\"\"\"\n",
"        return self.stack(x)\n",
"\n",
"# currently unused\n",
"def log_loss_old(out: torch.Tensor, expected: torch.Tensor):\n",
"    \"\"\"Hand-written binary cross-entropy.\n",
"\n",
"    Element-wise losses are summed along dim 1 (per data row) and the row\n",
"    totals are then averaged.\n",
"    \"\"\"\n",
"    # clamp away from 0 and 1 so neither log() below receives zero\n",
"    eps = 1e-7\n",
"    out = torch.clip(out, eps, 1 - eps)\n",
"    errors = -expected * torch.log(out) - (1 - expected) * torch.log(1 - out)\n",
"    # sum horizontally (each data row)\n",
"    return torch.mean(torch.sum(errors, dim=1))\n",
"\n",
"# loss actually used for training\n",
"loss_fn = torch.nn.BCELoss()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cb24b1e4",
"metadata": {},
"outputs": [],
"source": [
"# bit positions 0..31, least significant first\n",
"shift_by = torch.arange(32, dtype=torch.int64)\n",
"mask = 1\n",
"\n",
"def split_nums(nums: torch.Tensor):\n",
"    \"\"\"Split each u32 in nums into its 32 bits, least significant first.\n",
"\n",
"    A trailing dimension of size 32 is appended. Set bits are encoded as\n",
"    +10 and clear bits as -10, cast to model_dtype.\n",
"    \"\"\"\n",
"    bits = nums.unsqueeze(-1).bitwise_right_shift(shift_by).bitwise_and(mask)\n",
"    # map {0, 1} -> {-10, +10}\n",
"    return bits.type(model_dtype) * 20 - 10\n",
"\n",
"def make_data_set(min: int, max: int, count: int | None) -> tuple[torch.Tensor, torch.Tensor]:\n",
"    \"\"\"Build encoded inputs and target labels for integers in [min, max).\n",
"\n",
"    max is exclusive in both modes. With a count, that many integers are\n",
"    drawn at random via torch.randint; with count=None every integer in the\n",
"    range is used, in ascending order.\n",
"\n",
"    Returns (inputs, outputs): inputs come from split_nums, and outputs has\n",
"    one row per number with columns\n",
"    [multiple of FIZZ, multiple of BUZZ, multiple of neither] as 0/1 values.\n",
"    \"\"\"\n",
"    if count is not None:\n",
"        nums = torch.randint(min, max, (count,), dtype=torch.int64)\n",
"    else:\n",
"        nums = torch.arange(min, max, dtype=torch.int64)\n",
"    inputs = split_nums(nums)\n",
"    out_fizz = nums.remainder(FIZZ).eq(0)\n",
"    out_buzz = nums.remainder(BUZZ).eq(0)\n",
"    out_none = torch.logical_and(out_fizz.logical_not(), out_buzz.logical_not())\n",
"    outputs = torch.stack((out_fizz, out_buzz, out_none), dim=1).type(model_dtype)\n",
"    return inputs, outputs"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f4f49008",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10.],\n",
" [ 10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10.],\n",
" [-10., 10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10.],\n",
" [ 10., 10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10.],\n",
" [-10., -10., 10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10.],\n",
" [ 10., -10., 10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10.],\n",
" [-10., 10., 10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10.],\n",
" [ 10., 10., 10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10.],\n",
" [-10., -10., -10., 10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10.],\n",
" [ 10., -10., -10., 10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,\n",
" -10., -10., -10., -10., -10., -10., -10., -10.]], device='cuda:0'),\n",
" tensor([[1., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [1., 0., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.]], device='cuda:0'))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"make_data_set(0, 10, None)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f7b3ff70",
"metadata": {},
"outputs": [],
"source": [
"def make_step(optimizer, batch_size, min, max):\n",
"    \"\"\"Return a step(model) closure that performs one optimisation step.\n",
"\n",
"    Every call to step generates a new random batch of batch_size numbers\n",
"    from [min, max) with make_data_set, so batches are never reused.\n",
"    step returns the batch loss as a detached tensor.\n",
"    \"\"\"\n",
"    def step(model):\n",
"        train_in, train_out = make_data_set(min, max, batch_size)\n",
"        prediction = model(train_in)\n",
"        loss = loss_fn(prediction, train_out)\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        optimizer.zero_grad()\n",
"\n",
"        # the caller only needs the value, so drop the autograd reference\n",
"        return loss.detach()\n",
"\n",
"    return step"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e64deb5e",
"metadata": {},
"outputs": [],
"source": [
"def model_eval(model, input):\n",
"    \"\"\"Return model(input) computed in eval mode with autograd disabled.\n",
"\n",
"    The model's previous train/eval mode is restored before returning, so\n",
"    calling this in the middle of training does not leave it in eval mode.\n",
"    \"\"\"\n",
"    was_training = model.training\n",
"    model.eval()\n",
"    try:\n",
"        with torch.no_grad():\n",
"            # call the module itself rather than .forward() so hooks run\n",
"            return model(input)\n",
"    finally:\n",
"        model.train(was_training)\n",
"\n",
"# an output is treated as asserted when it is >= this value\n",
"test_threshold = 0.75\n",
"\n",
"def test_fizzbuzz(model, input, quiet=False):\n",
"    \"\"\"Print every disagreement between the model and real FizzBuzz.\n",
"\n",
"    input is a 1-D integer tensor of numbers to check. For each number the\n",
"    three model outputs (fizz, buzz, neither) are compared against the true\n",
"    remainders using test_threshold. Unless quiet is true, the raw outputs\n",
"    for every number are printed as well. Nothing is returned.\n",
"    \"\"\"\n",
"    out = model_eval(model, split_nums(input))\n",
"    # .tolist() converts everything to Python numbers in one go, instead of\n",
"    # doing tensor arithmetic and comparisons element by element\n",
"    for i, (fizz, buzz, neither) in zip(input.tolist(), out.tolist()):\n",
"        fizzmod = i % FIZZ\n",
"        buzzmod = i % BUZZ\n",
"        if not quiet:\n",
"            print(\n",
"                f'{i} fizz: {fizz:.4f}, fizzmod: {fizzmod}, ' +\n",
"                f'buzz: {buzz:.4f}, buzzmod: {buzzmod}, ' +\n",
"                f'neither: {neither:.4f}')\n",
"        if fizzmod == 0 and fizz < test_threshold:\n",
"            print(i, 'fizz wrong: expect fizz got', f'{fizz:.4f}')\n",
"        if fizz >= test_threshold and fizzmod != 0:\n",
"            print(i, 'fizz wrong: expect not fizz got', f'{fizz:.4f}')\n",
"        if buzzmod == 0 and buzz < test_threshold:\n",
"            print(i, 'buzz wrong: expect buzz got', f'{buzz:.4f}')\n",
"        if buzz >= test_threshold and buzzmod != 0:\n",
"            print(i, 'buzz wrong: expect not buzz got', f'{buzz:.4f}')\n",
"        if fizzmod != 0 and buzzmod != 0 and neither < test_threshold:\n",
"            print(i, 'neither wrong: expect neither got', f'{neither:.4f}')\n",
"        if (fizzmod == 0 or buzzmod == 0) and neither >= test_threshold:\n",
"            print(i, 'neither wrong: expect not neither got', f'{neither:.4f}')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "178e7313",
"metadata": {},
"outputs": [],
"source": [
"def train_stage(model, optimizer, batch_size, iterations, min_val, max_val):\n",
"    \"\"\"Train model for up to iterations steps and plot the loss curve.\n",
"\n",
"    Each step uses a fresh random batch of batch_size numbers from\n",
"    [min_val, max_val). Training stops early as soon as a batch loss drops\n",
"    below 0.001. The loss of every step that ran is then plotted.\n",
"    \"\"\"\n",
"    model.train()\n",
"    step = make_step(optimizer, batch_size, min_val, max_val)\n",
"    progress = trange(iterations)\n",
"    lossplot_v = np.zeros(iterations)\n",
"    # Count steps explicitly instead of reading progress.n afterwards: after\n",
"    # an early break the bar's counter may not include the final step.\n",
"    steps_done = 0\n",
"    for i in progress:\n",
"        # convert once; a plain float also renders cleanly in the postfix\n",
"        loss_v = float(step(model))\n",
"        lossplot_v[i] = loss_v\n",
"        steps_done = i + 1\n",
"        if i % 100 == 0:\n",
"            progress.set_postfix({'loss': loss_v})\n",
"        if loss_v < 0.001:\n",
"            break\n",
"\n",
"    plt.plot(np.arange(steps_done), lossplot_v[:steps_done])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "dc6facba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Network(\n",
" (stack): Sequential(\n",
" (0): Linear(in_features=32, out_features=1024, bias=True)\n",
" (1): Tanh()\n",
" (2): Linear(in_features=1024, out_features=256, bias=True)\n",
" (3): Tanh()\n",
" (4): Linear(in_features=256, out_features=3, bias=True)\n",
" (5): Sigmoid()\n",
" )\n",
")"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = Network()\n",
"display(model)\n",
"def make_optimizer(model):\n",
" return torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.001)\n",
"batch_size = 1024"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "ac8334dd",
"metadata": {},
"outputs": [],
"source": [
"# stage 1: train to 16 bits\n",
"#optimizer = make_optimizer(model)\n",
"#train_stage(model, optimizer, batch_size,\n",
"# iterations=16384, min_val=0, max_val=1 << 16)\n",
"#test_fizzbuzz(model, torch.arange(25))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b109fb31",
"metadata": {},
"outputs": [],
"source": [
"# stage 2: train to 24 bits\n",
"#optimizer = make_optimizer(model)\n",
"#train_stage(model, optimizer, batch_size,\n",
"# iterations=32768, min_val=0, max_val=(1 << 24) - 1)\n",
"#test_fizzbuzz(model, torch.tensor([i for i in range(int(1e2), int(1e2) + 25)]))\n",
"#test_fizzbuzz(model, torch.tensor([i for i in range(int(1e6), int(1e6) + 25)]))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "63728fdc",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a8a8b4b487c74908af2f572170236611",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/65536 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"200 fizz: 0.0001, fizzmod: 2, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"201 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 1, neither: 0.0000\n",
"202 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 2, neither: 0.9999\n",
"203 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 3, neither: 0.9998\n",
"204 fizz: 1.0000, fizzmod: 0, buzz: 0.0001, buzzmod: 4, neither: 0.0000\n",
"205 fizz: 0.0000, fizzmod: 1, buzz: 1.0000, buzzmod: 0, neither: 0.0005\n",
"206 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 1, neither: 1.0000\n",
"207 fizz: 1.0000, fizzmod: 0, buzz: 0.0002, buzzmod: 2, neither: 0.0000\n",
"208 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 3, neither: 0.9999\n",
"209 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 4, neither: 1.0000\n",
"210 fizz: 1.0000, fizzmod: 0, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"211 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 1, neither: 1.0000\n",
"212 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 2, neither: 1.0000\n",
"213 fizz: 1.0000, fizzmod: 0, buzz: 0.0001, buzzmod: 3, neither: 0.0000\n",
"214 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 4, neither: 1.0000\n",
"215 fizz: 0.0000, fizzmod: 2, buzz: 1.0000, buzzmod: 0, neither: 0.0002\n",
"216 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 1, neither: 0.0000\n",
"217 fizz: 0.0000, fizzmod: 1, buzz: 0.0001, buzzmod: 2, neither: 1.0000\n",
"218 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 3, neither: 0.9999\n",
"219 fizz: 1.0000, fizzmod: 0, buzz: 0.0001, buzzmod: 4, neither: 0.0000\n",
"220 fizz: 0.0000, fizzmod: 1, buzz: 1.0000, buzzmod: 0, neither: 0.0004\n",
"221 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 1, neither: 0.9999\n",
"222 fizz: 1.0000, fizzmod: 0, buzz: 0.0002, buzzmod: 2, neither: 0.0000\n",
"223 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 3, neither: 1.0000\n",
"224 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 4, neither: 0.9999\n",
"1000000000 fizz: 0.0000, fizzmod: 1, buzz: 1.0000, buzzmod: 0, neither: 0.0003\n",
"1000000001 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 1, neither: 1.0000\n",
"1000000002 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 2, neither: 0.0000\n",
"1000000003 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 3, neither: 1.0000\n",
"1000000004 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 4, neither: 1.0000\n",
"1000000005 fizz: 1.0000, fizzmod: 0, buzz: 0.9999, buzzmod: 0, neither: 0.0000\n",
"1000000006 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 1, neither: 1.0000\n",
"1000000007 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 2, neither: 1.0000\n",
"1000000008 fizz: 1.0000, fizzmod: 0, buzz: 0.0001, buzzmod: 3, neither: 0.0000\n",
"1000000009 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 4, neither: 0.9999\n",
"1000000010 fizz: 0.0000, fizzmod: 2, buzz: 0.9999, buzzmod: 0, neither: 0.0004\n",
"1000000011 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 1, neither: 0.0000\n",
"1000000012 fizz: 0.0000, fizzmod: 1, buzz: 0.0001, buzzmod: 2, neither: 0.9991\n",
"1000000013 fizz: 0.0000, fizzmod: 2, buzz: 0.0002, buzzmod: 3, neither: 0.9992\n",
"1000000014 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 4, neither: 0.0000\n",
"1000000015 fizz: 0.0000, fizzmod: 1, buzz: 0.9998, buzzmod: 0, neither: 0.0023\n",
"1000000016 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 1, neither: 1.0000\n",
"1000000017 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 2, neither: 0.0001\n",
"1000000018 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 3, neither: 1.0000\n",
"1000000019 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 4, neither: 1.0000\n",
"1000000020 fizz: 1.0000, fizzmod: 0, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"1000000021 fizz: 0.0001, fizzmod: 1, buzz: 0.0001, buzzmod: 1, neither: 0.9996\n",
"1000000022 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 2, neither: 1.0000\n",
"1000000023 fizz: 1.0000, fizzmod: 0, buzz: 0.0001, buzzmod: 3, neither: 0.0001\n",
"1000000024 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 4, neither: 0.9999\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# stage 3: train to 32 bits\n",
"optimizer = make_optimizer(model)\n",
"train_stage(model, optimizer, batch_size,\n",
" iterations=65536, min_val=0, max_val=(1 << 32) - 1)\n",
"test_fizzbuzz(model, torch.tensor([i for i in range(int(2e2), int(2e2) + 25)]))\n",
"test_fizzbuzz(model, torch.tensor([i for i in range(int(1e9), int(1e9) + 25)]))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "cb69a399",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6275b40a60df4d25b321500c3f43bc7d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/16 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(28229346, device='cuda:0') fizz wrong: expect fizz got tensor(0.1587, device='cuda:0')\n",
"tensor(53023808, device='cuda:0') neither wrong: expect neither got tensor(0.7494, device='cuda:0')\n",
"tensor(53023810, device='cuda:0') buzz wrong: expect buzz got tensor(0.2398, device='cuda:0')\n",
"tensor(53023840, device='cuda:0') buzz wrong: expect buzz got tensor(0.2424, device='cuda:0')\n",
"tensor(53023850, device='cuda:0') buzz wrong: expect buzz got tensor(0.4570, device='cuda:0')\n",
"tensor(53023970, device='cuda:0') buzz wrong: expect buzz got tensor(0.1307, device='cuda:0')\n",
"tensor(53024320, device='cuda:0') buzz wrong: expect buzz got tensor(0.2518, device='cuda:0')\n",
"tensor(53024330, device='cuda:0') buzz wrong: expect buzz got tensor(0.4427, device='cuda:0')\n",
"tensor(53024354, device='cuda:0') neither wrong: expect neither got tensor(0.7040, device='cuda:0')\n",
"tensor(53024355, device='cuda:0') buzz wrong: expect buzz got tensor(0.2424, device='cuda:0')\n",
"tensor(53024360, device='cuda:0') buzz wrong: expect buzz got tensor(0.4293, device='cuda:0')\n",
"tensor(53024362, device='cuda:0') neither wrong: expect neither got tensor(0.4302, device='cuda:0')\n",
"tensor(53024370, device='cuda:0') buzz wrong: expect buzz got tensor(0.2436, device='cuda:0')\n",
"tensor(53024450, device='cuda:0') buzz wrong: expect buzz got tensor(0.1235, device='cuda:0')\n",
"tensor(53024458, device='cuda:0') neither wrong: expect neither got tensor(0.5055, device='cuda:0')\n",
"tensor(53024480, device='cuda:0') buzz wrong: expect buzz got tensor(0.1266, device='cuda:0')\n",
"tensor(53024482, device='cuda:0') neither wrong: expect neither got tensor(0.5953, device='cuda:0')\n",
"tensor(53024483, device='cuda:0') neither wrong: expect neither got tensor(0.6983, device='cuda:0')\n",
"tensor(53024488, device='cuda:0') neither wrong: expect neither got tensor(0.6188, device='cuda:0')\n",
"tensor(53024490, device='cuda:0') buzz wrong: expect buzz got tensor(0.2725, device='cuda:0')\n",
"tensor(53024498, device='cuda:0') neither wrong: expect neither got tensor(0.7045, device='cuda:0')\n",
"tensor(53024610, device='cuda:0') buzz wrong: expect buzz got tensor(0.2425, device='cuda:0')\n"
]
}
],
"source": [
"test_fizzbuzz(model, torch.arange(0, 101), quiet=True)\n",
"for i in trange(0, int(64e6), int(4e6)):\n",
" start = np.random.randint(i, i + 1048576 - 1024)\n",
" test_fizzbuzz(model, torch.arange(start, start + 1024), quiet=True)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f4ddb9bd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(1, device='cuda:0')\n",
"tensor(2, device='cuda:0')\n",
"fizz\n",
"tensor(4, device='cuda:0')\n",
"buzz\n",
"fizz\n",
"tensor(7, device='cuda:0')\n",
"tensor(8, device='cuda:0')\n",
"fizz\n",
"buzz\n",
"tensor(11, device='cuda:0')\n",
"fizz\n",
"tensor(13, device='cuda:0')\n",
"tensor(14, device='cuda:0')\n",
"fizzbuzz\n",
"tensor(16, device='cuda:0')\n",
"tensor(17, device='cuda:0')\n",
"fizz\n",
"tensor(19, device='cuda:0')\n",
"buzz\n",
"fizz\n",
"tensor(22, device='cuda:0')\n",
"tensor(23, device='cuda:0')\n",
"fizz\n",
"buzz\n",
"tensor(26, device='cuda:0')\n",
"fizz\n",
"tensor(28, device='cuda:0')\n",
"tensor(29, device='cuda:0')\n",
"fizzbuzz\n",
"tensor(31, device='cuda:0')\n",
"tensor(32, device='cuda:0')\n",
"fizz\n",
"tensor(34, device='cuda:0')\n",
"buzz\n",
"fizz\n",
"tensor(37, device='cuda:0')\n",
"tensor(38, device='cuda:0')\n",
"fizz\n",
"buzz\n",
"tensor(41, device='cuda:0')\n",
"fizz\n",
"tensor(43, device='cuda:0')\n",
"tensor(44, device='cuda:0')\n",
"fizzbuzz\n",
"tensor(46, device='cuda:0')\n",
"tensor(47, device='cuda:0')\n",
"fizz\n",
"tensor(49, device='cuda:0')\n",
"buzz\n",
"fizz\n",
"tensor(52, device='cuda:0')\n",
"tensor(53, device='cuda:0')\n",
"fizz\n",
"buzz\n",
"tensor(56, device='cuda:0')\n",
"fizz\n",
"tensor(58, device='cuda:0')\n",
"tensor(59, device='cuda:0')\n",
"fizzbuzz\n",
"tensor(61, device='cuda:0')\n",
"tensor(62, device='cuda:0')\n",
"fizz\n",
"tensor(64, device='cuda:0')\n",
"buzz\n",
"fizz\n",
"tensor(67, device='cuda:0')\n",
"tensor(68, device='cuda:0')\n",
"fizz\n",
"buzz\n",
"tensor(71, device='cuda:0')\n",
"fizz\n",
"tensor(73, device='cuda:0')\n",
"tensor(74, device='cuda:0')\n",
"fizzbuzz\n",
"tensor(76, device='cuda:0')\n",
"tensor(77, device='cuda:0')\n",
"fizz\n",
"tensor(79, device='cuda:0')\n",
"buzz\n",
"fizz\n",
"tensor(82, device='cuda:0')\n",
"tensor(83, device='cuda:0')\n",
"fizz\n",
"buzz\n",
"tensor(86, device='cuda:0')\n",
"fizz\n",
"tensor(88, device='cuda:0')\n",
"tensor(89, device='cuda:0')\n",
"fizzbuzz\n",
"tensor(91, device='cuda:0')\n",
"tensor(92, device='cuda:0')\n",
"fizz\n",
"tensor(94, device='cuda:0')\n",
"buzz\n",
"fizz\n",
"tensor(97, device='cuda:0')\n",
"tensor(98, device='cuda:0')\n",
"fizz\n",
"buzz\n"
]
}
],
"source": [
"infer_threshold = 0.75\n",
"def fizzbuzz(model, min, max, quiet=False):\n",
" nums = torch.arange(min, max)\n",
" inputs = split_nums(nums)\n",
" outputs = model.forward(inputs)\n",
" collected = []\n",
" for num, (fizz, buzz, neither) in zip(nums, outputs):\n",
" out = ''\n",
" if fizz > infer_threshold:\n",
" out += 'fizz'\n",
" if buzz > infer_threshold:\n",
" out += 'buzz'\n",
" if neither > infer_threshold:\n",
" out += str(num)\n",
" collected.append(out)\n",
" if not quiet:\n",
" print(out)\n",
" \n",
" return collected\n",
"\n",
"fizzbuzz(model, 1, 101);"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment