Skip to content

Instantly share code, notes, and snippets.

@iczero
Last active September 25, 2023 02:56
Show Gist options
  • Save iczero/e9328cecb3ffb7fc7a169c5d93387e9a to your computer and use it in GitHub Desktop.
Save iczero/e9328cecb3ffb7fc7a169c5d93387e9a to your computer and use it in GitHub Desktop.
fizzbuzz in jax
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "3cc970e3-a4e2-40ae-a4da-db15abcd550c",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"from jax import numpy as jnp\n",
"from jax.tree_util import register_pytree_node_class\n",
"import optax\n",
"from tqdm.notebook import trange\n",
"import matplotlib.pyplot as plt\n",
"\n",
"jax.config.update('jax_debug_nans', True)\n",
"\n",
"model_dtype = jnp.float32"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7b1f6362-de4b-4e09-84dc-7ccff11cd2be",
"metadata": {},
"outputs": [],
"source": [
"# a layer\n",
"@register_pytree_node_class\n",
"class Layer:\n",
"    \"\"\"Dense (affine) layer that JAX can treat as a pytree node.\n",
"\n",
"    `init_size` creates `weights` with shape (inputs, nodes), so each column\n",
"    holds the incoming weights of one node and `forward` computes\n",
"    `input @ weights + biases`.\n",
"    \"\"\"\n",
"\n",
"    weights: jax.Array\n",
"    biases: jax.Array\n",
"\n",
"    def __init__(self, weights: jax.Array, biases: jax.Array):\n",
"        # weights array: columns are nodes, rows are weights\n",
"        self.weights = weights\n",
"        self.biases = biases\n",
"\n",
"    @staticmethod\n",
"    def init_size(randkey, inputs: int, nodes: int):\n",
"        \"\"\"Create a layer with normal weights scaled by sqrt(2 / inputs) and zero biases.\"\"\"\n",
"        scale = jnp.sqrt(2 / inputs)\n",
"        weights = scale * jax.random.normal(randkey, shape=(inputs, nodes))\n",
"        biases = jnp.zeros(nodes)\n",
"        return Layer(jnp.asarray(weights, dtype=model_dtype), jnp.asarray(biases, dtype=model_dtype))\n",
"\n",
"    def forward(self, input: jax.Array) -> jax.Array:\n",
"        \"\"\"Apply the affine map to `input`; no activation is applied here.\"\"\"\n",
"        return input @ self.weights + self.biases\n",
"\n",
"    def __repr__(self):\n",
"        return f'<Layer weights={self.weights!r} biases={self.biases!r}>'\n",
"\n",
"    def tree_flatten(self):\n",
"        \"\"\"Return (children, aux_data): both arrays are leaves, no static data.\"\"\"\n",
"        return ((self.weights, self.biases), None)\n",
"\n",
"    @classmethod\n",
"    def tree_unflatten(cls, aux_data, children):\n",
"        \"\"\"Rebuild a layer from the children produced by `tree_flatten`.\"\"\"\n",
"        weights, biases = children\n",
"        return cls(weights, biases)\n",
" \n",
"@register_pytree_node_class\n",
"class Network:\n",
"    \"\"\"Stack of `Layer`s: tanh after every hidden layer, sigmoid after the last.\"\"\"\n",
"\n",
"    layers: list[Layer]\n",
"\n",
"    def __init__(self, layers):\n",
"        self.layers = layers\n",
"\n",
"    @staticmethod\n",
"    def new(randkey):\n",
"        \"\"\"Build the 32 -> 1024 -> 512 -> 3 network, one PRNG subkey per layer.\"\"\"\n",
"        sizes = (32, 1024, 512, 3)\n",
"        keys = jax.random.split(randkey, len(sizes) - 1)\n",
"        return Network([\n",
"            Layer.init_size(key, fan_in, fan_out)\n",
"            for key, fan_in, fan_out in zip(keys, sizes, sizes[1:])\n",
"        ])\n",
"\n",
"    def forward(self, value: jax.Array) -> jax.Array:\n",
"        \"\"\"Run `value` through all layers and return the sigmoid outputs.\"\"\"\n",
"        # works for any depth instead of assuming exactly three layers\n",
"        *hidden, head = self.layers\n",
"        for layer in hidden:\n",
"            value = jax.nn.tanh(layer.forward(value))\n",
"        return jax.nn.sigmoid(head.forward(value))\n",
"\n",
"    def tree_flatten(self):\n",
"        \"\"\"Return (children, aux_data): the layers are the children, no static data.\"\"\"\n",
"        return (self.layers, None)\n",
"\n",
"    @classmethod\n",
"    def tree_unflatten(cls, aux_data, children):\n",
"        \"\"\"Rebuild the network from the children produced by `tree_flatten`.\"\"\"\n",
"        # keep `layers` a list, as annotated, whatever sequence type is passed in\n",
"        return cls(list(children))\n",
"\n",
"def log_loss(out: jax.Array, expected: jax.Array):\n",
"    \"\"\"Return the batch mean of the per-row summed binary cross-entropy.\n",
"\n",
"    `out` holds predicted probabilities and `expected` the 0/1 targets; both\n",
"    need an axis 1 (one column per output), which is summed for each row.\n",
"    \"\"\"\n",
"    eps = 1e-7\n",
"    # keep probabilities away from 0 and 1 so neither log() returns -inf\n",
"    out = jnp.clip(out, eps, 1 - eps)\n",
"    errors = -expected * jnp.log(out) - (1 - expected) * jnp.log(1 - out)\n",
"    # sum horizontally (each data row)\n",
"    return jnp.mean(jnp.sum(errors, axis=1))\n",
"\n",
"def loss(model: Network, x, y):\n",
"    \"\"\"Return `log_loss` of the model's predictions for inputs `x` against targets `y`.\"\"\"\n",
"    return log_loss(model.forward(x), y)\n",
"\n",
"# differentiates with respect to the first argument (the model pytree);\n",
"# calling it returns (loss value, gradients)\n",
"loss_grads = jax.value_and_grad(loss)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cb24b1e4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"I0000 00:00:1695609639.709761 125941 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n"
]
}
],
"source": [
"bit_width = 1\n",
"# one shift per bit_width-wide digit of a u32, lowest digit first; the shift\n",
"# must advance by bit_width per digit or digits overlap when bit_width > 1\n",
"shift_by = jnp.arange(32 // bit_width, dtype=jnp.uint32) * bit_width\n",
"mask = (1 << bit_width) - 1\n",
"\n",
"def split_nums(nums: jax.Array):\n",
"    \"\"\"Split each u32 into its bit_width-wide digits (bits when bit_width is 1).\n",
"\n",
"    Returns one row per number and one column per digit, lowest digit first,\n",
"    cast to `model_dtype`.\n",
"    \"\"\"\n",
"    column = jnp.reshape(nums, (-1, 1))\n",
"    digits = jnp.bitwise_and(jnp.right_shift(column, shift_by), mask)\n",
"    return digits.astype(model_dtype)\n",
"\n",
"def make_data_set(randkey: jax.Array, min: int, max: int, count: int | None) -> tuple[jax.Array, jax.Array]:\n",
"    \"\"\"Build (inputs, targets) for fizzbuzz on integers in [min, max).\n",
"\n",
"    With `count` set, draw that many integers at random using `randkey`\n",
"    (`max` is exclusive); with `count=None`, enumerate the whole range in\n",
"    order and ignore `randkey`. Inputs come from `split_nums`; the target\n",
"    columns are (divisible by 3, divisible by 5, divisible by neither).\n",
"    \"\"\"\n",
"    if count is not None:\n",
"        minval = jnp.array(min, dtype=jnp.uint32)\n",
"        maxval = jnp.array(max, dtype=jnp.uint32)\n",
"        nums = jax.random.randint(randkey, (count,), minval, maxval, dtype=jnp.uint32)\n",
"    else:\n",
"        nums = jnp.arange(min, max, dtype=jnp.uint32)\n",
"    inputs = split_nums(nums)\n",
"    out_fizz = jnp.equal(jnp.mod(nums, 3), 0)\n",
"    out_buzz = jnp.equal(jnp.mod(nums, 5), 0)\n",
"    out_none = jnp.logical_not(jnp.logical_or(out_fizz, out_buzz))\n",
"    outputs = jnp.stack((out_fizz, out_buzz, out_none), axis=1).astype(model_dtype)\n",
"    return inputs, outputs\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f7b3ff70",
"metadata": {},
"outputs": [],
"source": [
"def make_step(optimizer, batch_size, min, max):\n",
"    \"\"\"Return a jitted training step bound to `optimizer` and the data range.\n",
"\n",
"    Each call of the step splits the key, generates `batch_size` samples\n",
"    with `make_data_set(key, min, max, batch_size)`, applies one optimizer\n",
"    update and returns (next_key, model, opt_state, loss).\n",
"    \"\"\"\n",
"    @jax.jit\n",
"    def step(model, randkey, opt_state):\n",
"        data_key, next_key = jax.random.split(randkey)\n",
"        train_in, train_out = make_data_set(data_key, min, max, batch_size)\n",
"        loss_v, grads = loss_grads(model, train_in, train_out)\n",
"        updates, opt_state = optimizer.update(grads, opt_state, model)\n",
"        model = optax.apply_updates(model, updates)\n",
"        return next_key, model, opt_state, loss_v\n",
"\n",
"    return step"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e64deb5e",
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def model_eval(model, input):\n",
"    \"\"\"Jitted forward pass: return `model.forward(input)`.\"\"\"\n",
"    return model.forward(input)\n",
"\n",
"test_threshold = 0.75\n",
"def test_fizzbuzz(model, input, quiet=False):\n",
"    \"\"\"Run the model on every number in `input` and report wrong outputs.\n",
"\n",
"    An output counts as set once it reaches `test_threshold`. Each number's\n",
"    raw outputs are printed as well unless `quiet` is true.\n",
"    \"\"\"\n",
"    # uint32 keeps values >= 2**31 intact; in JAX's default 32-bit mode\n",
"    # jnp.array() on plain ints would pick int32\n",
"    nums = np.asarray(input, dtype=np.uint32)\n",
"    # fetch all predictions to the host once instead of indexing row by row\n",
"    out = np.asarray(model_eval(model, split_nums(jnp.asarray(nums))))\n",
"    for i, (fizz, buzz, neither) in zip(nums.tolist(), out):\n",
"        fizzmod = i % 3\n",
"        buzzmod = i % 5\n",
"        if not quiet:\n",
"            print(\n",
"                f'{i} fizz: {fizz:.4f}, fizzmod: {fizzmod}, ' +\n",
"                f'buzz: {buzz:.4f}, buzzmod: {buzzmod}, ' +\n",
"                f'neither: {neither:.4f}')\n",
"        # (label, model output, whether that output should be set)\n",
"        checks = (\n",
"            ('fizz', fizz, fizzmod == 0),\n",
"            ('buzz', buzz, buzzmod == 0),\n",
"            ('neither', neither, fizzmod != 0 and buzzmod != 0),\n",
"        )\n",
"        for label, score, expected in checks:\n",
"            if expected and score < test_threshold:\n",
"                print(i, f'{label} wrong: expect {label} got', score)\n",
"            if not expected and score >= test_threshold:\n",
"                print(i, f'{label} wrong: expect not {label} got', score)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "178e7313",
"metadata": {},
"outputs": [],
"source": [
"def train_stage(randkey, model, optimizer, batch_size, iterations, min_val, max_val):\n",
"    \"\"\"Train `model` on the range given by `min_val`/`max_val` and plot the loss curve.\n",
"\n",
"    Runs at most `iterations` steps with a newly initialised optimizer\n",
"    state, stops early once the batch loss drops below 0.005, and returns\n",
"    the trained model.\n",
"    \"\"\"\n",
"    step = make_step(optimizer, batch_size, min_val, max_val)\n",
"    progress = trange(iterations)\n",
"    opt_state = optimizer.init(model)\n",
"    # collect losses ourselves: the bar's counter only covers completed\n",
"    # iterations, so slicing by it drops the final loss after an early break\n",
"    losses = []\n",
"    for _ in progress:\n",
"        randkey, model, opt_state, loss_v = step(model, randkey, opt_state)\n",
"        loss_v = float(loss_v)\n",
"        progress.set_postfix({'loss': loss_v})\n",
"        losses.append(loss_v)\n",
"        if loss_v < 0.005:\n",
"            break\n",
"\n",
"    plt.plot(np.arange(len(losses)), np.asarray(losses, dtype=model_dtype))\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b2df9625",
"metadata": {},
"outputs": [],
"source": [
"# state shared by all training stages below\n",
"batch_size = 16384\n",
"optimizer = optax.adam(learning_rate=0.0005)\n",
"model = Network.new(jax.random.PRNGKey(42))\n",
"randkey = jax.random.PRNGKey(42)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ac8334dd",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c51f20058c174440ae27fc3e778fac62",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/16384 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 fizz: 0.9997, fizzmod: 0, buzz: 0.9998, buzzmod: 0, neither: 0.0000\n",
"1 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 1, neither: 1.0000\n",
"2 fizz: 0.0000, fizzmod: 2, buzz: 0.0001, buzzmod: 2, neither: 1.0000\n",
"3 fizz: 0.9999, fizzmod: 0, buzz: 0.0002, buzzmod: 3, neither: 0.0000\n",
"4 fizz: 0.0001, fizzmod: 1, buzz: 0.0000, buzzmod: 4, neither: 1.0000\n",
"5 fizz: 0.0090, fizzmod: 2, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"6 fizz: 0.9999, fizzmod: 0, buzz: 0.0002, buzzmod: 1, neither: 0.0000\n",
"7 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 2, neither: 1.0000\n",
"8 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 3, neither: 1.0000\n",
"9 fizz: 0.9999, fizzmod: 0, buzz: 0.0001, buzzmod: 4, neither: 0.0000\n",
"10 fizz: 0.0056, fizzmod: 1, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"11 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 1, neither: 1.0000\n",
"12 fizz: 0.9999, fizzmod: 0, buzz: 0.0001, buzzmod: 2, neither: 0.0001\n",
"13 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 3, neither: 1.0000\n",
"14 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 4, neither: 1.0000\n",
"15 fizz: 0.9999, fizzmod: 0, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"16 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 1, neither: 1.0000\n",
"17 fizz: 0.0090, fizzmod: 2, buzz: 0.0004, buzzmod: 2, neither: 0.9880\n",
"18 fizz: 0.9999, fizzmod: 0, buzz: 0.0003, buzzmod: 3, neither: 0.0000\n",
"19 fizz: 0.0001, fizzmod: 1, buzz: 0.0016, buzzmod: 4, neither: 0.9986\n",
"20 fizz: 0.0106, fizzmod: 2, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"21 fizz: 0.9729, fizzmod: 0, buzz: 0.0000, buzzmod: 1, neither: 0.0226\n",
"22 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 2, neither: 0.9999\n",
"23 fizz: 0.0006, fizzmod: 2, buzz: 0.0000, buzzmod: 3, neither: 0.9997\n",
"24 fizz: 0.9999, fizzmod: 0, buzz: 0.0001, buzzmod: 4, neither: 0.0000\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# stage 1: train to 16 bits, then print the outputs for 0..24\n",
"randkey, trainkey = jax.random.split(randkey)\n",
"model = train_stage(\n",
"    trainkey, model, optimizer, batch_size,\n",
"    iterations=16384, min_val=0, max_val=1 << 16,\n",
")\n",
"test_fizzbuzz(model, list(range(25)))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b109fb31",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b1961c00900d489fb4fc3e64dab44aed",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/16384 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"100 fizz: 0.0000, fizzmod: 1, buzz: 1.0000, buzzmod: 0, neither: 0.0001\n",
"101 fizz: 0.0006, fizzmod: 2, buzz: 0.0000, buzzmod: 1, neither: 0.9998\n",
"102 fizz: 1.0000, fizzmod: 0, buzz: 0.0008, buzzmod: 2, neither: 0.0000\n",
"103 fizz: 0.0000, fizzmod: 1, buzz: 0.0001, buzzmod: 3, neither: 1.0000\n",
"104 fizz: 0.0002, fizzmod: 2, buzz: 0.0001, buzzmod: 4, neither: 0.9990\n",
"105 fizz: 0.9999, fizzmod: 0, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"106 fizz: 0.0003, fizzmod: 1, buzz: 0.0001, buzzmod: 1, neither: 0.9992\n",
"107 fizz: 0.0001, fizzmod: 2, buzz: 0.0000, buzzmod: 2, neither: 0.9998\n",
"108 fizz: 1.0000, fizzmod: 0, buzz: 0.0006, buzzmod: 3, neither: 0.0000\n",
"109 fizz: 0.0001, fizzmod: 1, buzz: 0.0001, buzzmod: 4, neither: 1.0000\n",
"110 fizz: 0.0000, fizzmod: 2, buzz: 1.0000, buzzmod: 0, neither: 0.0001\n",
"111 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 1, neither: 0.0000\n",
"112 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 2, neither: 1.0000\n",
"113 fizz: 0.0013, fizzmod: 2, buzz: 0.0000, buzzmod: 3, neither: 0.9996\n",
"114 fizz: 1.0000, fizzmod: 0, buzz: 0.0009, buzzmod: 4, neither: 0.0000\n",
"115 fizz: 0.0001, fizzmod: 1, buzz: 0.9998, buzzmod: 0, neither: 0.0008\n",
"116 fizz: 0.0004, fizzmod: 2, buzz: 0.0001, buzzmod: 1, neither: 0.9997\n",
"117 fizz: 0.9954, fizzmod: 0, buzz: 0.0000, buzzmod: 2, neither: 0.0259\n",
"118 fizz: 0.0000, fizzmod: 1, buzz: 0.0001, buzzmod: 3, neither: 0.9999\n",
"119 fizz: 0.0004, fizzmod: 2, buzz: 0.0001, buzzmod: 4, neither: 0.9991\n",
"120 fizz: 0.9999, fizzmod: 0, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"121 fizz: 0.0001, fizzmod: 1, buzz: 0.0000, buzzmod: 1, neither: 0.9997\n",
"122 fizz: 0.0001, fizzmod: 2, buzz: 0.0001, buzzmod: 2, neither: 0.9990\n",
"123 fizz: 0.9999, fizzmod: 0, buzz: 0.0000, buzzmod: 3, neither: 0.0003\n",
"124 fizz: 0.0000, fizzmod: 1, buzz: 0.0001, buzzmod: 4, neither: 1.0000\n",
"1000000 fizz: 0.0000, fizzmod: 1, buzz: 1.0000, buzzmod: 0, neither: 0.0001\n",
"1000001 fizz: 0.0003, fizzmod: 2, buzz: 0.0000, buzzmod: 1, neither: 0.9999\n",
"1000002 fizz: 1.0000, fizzmod: 0, buzz: 0.0003, buzzmod: 2, neither: 0.0000\n",
"1000003 fizz: 0.0000, fizzmod: 1, buzz: 0.0001, buzzmod: 3, neither: 0.9999\n",
"1000004 fizz: 0.0003, fizzmod: 2, buzz: 0.0006, buzzmod: 4, neither: 0.9988\n",
"1000005 fizz: 0.9983, fizzmod: 0, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"1000006 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 1, neither: 0.9998\n",
"1000007 fizz: 0.0001, fizzmod: 2, buzz: 0.0003, buzzmod: 2, neither: 0.9984\n",
"1000008 fizz: 1.0000, fizzmod: 0, buzz: 0.0002, buzzmod: 3, neither: 0.0000\n",
"1000009 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 4, neither: 1.0000\n",
"1000010 fizz: 0.0004, fizzmod: 2, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"1000011 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 1, neither: 0.0000\n",
"1000012 fizz: 0.0000, fizzmod: 1, buzz: 0.0009, buzzmod: 2, neither: 0.9999\n",
"1000013 fizz: 0.0001, fizzmod: 2, buzz: 0.0001, buzzmod: 3, neither: 0.9999\n",
"1000014 fizz: 1.0000, fizzmod: 0, buzz: 0.0006, buzzmod: 4, neither: 0.0000\n",
"1000015 fizz: 0.0000, fizzmod: 1, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"1000016 fizz: 0.0002, fizzmod: 2, buzz: 0.0001, buzzmod: 1, neither: 0.9999\n",
"1000017 fizz: 0.9994, fizzmod: 0, buzz: 0.0002, buzzmod: 2, neither: 0.0005\n",
"1000018 fizz: 0.0001, fizzmod: 1, buzz: 0.0003, buzzmod: 3, neither: 0.9998\n",
"1000019 fizz: 0.0001, fizzmod: 2, buzz: 0.0003, buzzmod: 4, neither: 0.9993\n",
"1000020 fizz: 0.9996, fizzmod: 0, buzz: 0.9999, buzzmod: 0, neither: 0.0000\n",
"1000021 fizz: 0.0027, fizzmod: 1, buzz: 0.0001, buzzmod: 1, neither: 0.9936\n",
"1000022 fizz: 0.0001, fizzmod: 2, buzz: 0.0002, buzzmod: 2, neither: 0.9994\n",
"1000023 fizz: 0.9996, fizzmod: 0, buzz: 0.0002, buzzmod: 3, neither: 0.0004\n",
"1000024 fizz: 0.0001, fizzmod: 1, buzz: 0.0001, buzzmod: 4, neither: 0.9999\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# stage 2: train to 24 bits, then print 25 outputs from 100 and from 1,000,000\n",
"randkey, trainkey = jax.random.split(randkey)\n",
"model = train_stage(\n",
"    trainkey, model, optimizer, batch_size,\n",
"    iterations=16384, min_val=0, max_val=(1 << 24) - 1,\n",
")\n",
"test_fizzbuzz(model, list(range(100, 125)))\n",
"test_fizzbuzz(model, list(range(1_000_000, 1_000_025)))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "63728fdc",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b67500c28f10414090e446f488dc3394",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/32768 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"200 fizz: 0.0000, fizzmod: 2, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"201 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 1, neither: 0.0000\n",
"202 fizz: 0.0000, fizzmod: 1, buzz: 0.0013, buzzmod: 2, neither: 0.9997\n",
"203 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 3, neither: 1.0000\n",
"204 fizz: 1.0000, fizzmod: 0, buzz: 0.0022, buzzmod: 4, neither: 0.0000\n",
"205 fizz: 0.0000, fizzmod: 1, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"206 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 1, neither: 1.0000\n",
"207 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 2, neither: 0.0000\n",
"208 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 3, neither: 1.0000\n",
"209 fizz: 0.0000, fizzmod: 2, buzz: 0.0007, buzzmod: 4, neither: 0.9951\n",
"210 fizz: 1.0000, fizzmod: 0, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"211 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 1, neither: 0.9997\n",
"212 fizz: 0.0000, fizzmod: 2, buzz: 0.0008, buzzmod: 2, neither: 0.9842\n",
"213 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 3, neither: 0.0000\n",
"214 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 4, neither: 0.9999\n",
"215 fizz: 0.0000, fizzmod: 2, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"216 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 1, neither: 0.0000\n",
"217 fizz: 0.0000, fizzmod: 1, buzz: 0.0182, buzzmod: 2, neither: 0.9998\n",
"218 fizz: 0.0001, fizzmod: 2, buzz: 0.0000, buzzmod: 3, neither: 1.0000\n",
"219 fizz: 1.0000, fizzmod: 0, buzz: 0.0001, buzzmod: 4, neither: 0.0000\n",
"220 fizz: 0.0000, fizzmod: 1, buzz: 1.0000, buzzmod: 0, neither: 0.0000\n",
"221 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 1, neither: 1.0000\n",
"222 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 2, neither: 0.0000\n",
"223 fizz: 0.0000, fizzmod: 1, buzz: 0.0000, buzzmod: 3, neither: 1.0000\n",
"224 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 4, neither: 0.9999\n",
"1000000000 fizz: 0.0002, fizzmod: 1, buzz: 0.9998, buzzmod: 0, neither: 0.0016\n",
"1000000001 fizz: 0.0001, fizzmod: 2, buzz: 0.0029, buzzmod: 1, neither: 0.9992\n",
"1000000002 fizz: 0.9942, fizzmod: 0, buzz: 0.0000, buzzmod: 2, neither: 0.0374\n",
"1000000003 fizz: 0.0001, fizzmod: 1, buzz: 0.0001, buzzmod: 3, neither: 0.9997\n",
"1000000004 fizz: 0.0001, fizzmod: 2, buzz: 0.0003, buzzmod: 4, neither: 0.9969\n",
"1000000005 fizz: 1.0000, fizzmod: 0, buzz: 0.9991, buzzmod: 0, neither: 0.0000\n",
"1000000006 fizz: 0.0003, fizzmod: 1, buzz: 0.0190, buzzmod: 1, neither: 0.9696\n",
"1000000007 fizz: 0.0000, fizzmod: 2, buzz: 0.0000, buzzmod: 2, neither: 1.0000\n",
"1000000008 fizz: 0.9991, fizzmod: 0, buzz: 0.0342, buzzmod: 3, neither: 0.0000\n",
"1000000009 fizz: 0.0002, fizzmod: 1, buzz: 0.0003, buzzmod: 4, neither: 0.9999\n",
"1000000010 fizz: 0.0145, fizzmod: 2, buzz: 0.9907, buzzmod: 0, neither: 0.0010\n",
"1000000011 fizz: 0.9938, fizzmod: 0, buzz: 0.0046, buzzmod: 1, neither: 0.0015\n",
"1000000012 fizz: 0.0002, fizzmod: 1, buzz: 0.0001, buzzmod: 2, neither: 1.0000\n",
"1000000013 fizz: 0.0001, fizzmod: 2, buzz: 0.0238, buzzmod: 3, neither: 0.9997\n",
"1000000014 fizz: 0.9980, fizzmod: 0, buzz: 0.0003, buzzmod: 4, neither: 0.0040\n",
"1000000015 fizz: 0.0003, fizzmod: 1, buzz: 0.9948, buzzmod: 0, neither: 0.0747\n",
"1000000016 fizz: 0.0002, fizzmod: 2, buzz: 0.0012, buzzmod: 1, neither: 0.9977\n",
"1000000017 fizz: 1.0000, fizzmod: 0, buzz: 0.0000, buzzmod: 2, neither: 0.0009\n",
"1000000018 fizz: 0.0002, fizzmod: 1, buzz: 0.0002, buzzmod: 3, neither: 0.9987\n",
"1000000019 fizz: 0.0001, fizzmod: 2, buzz: 0.0000, buzzmod: 4, neither: 1.0000\n",
"1000000020 fizz: 1.0000, fizzmod: 0, buzz: 0.9998, buzzmod: 0, neither: 0.0000\n",
"1000000021 fizz: 0.0001, fizzmod: 1, buzz: 0.0017, buzzmod: 1, neither: 0.9983\n",
"1000000022 fizz: 0.0001, fizzmod: 2, buzz: 0.0000, buzzmod: 2, neither: 1.0000\n",
"1000000023 fizz: 1.0000, fizzmod: 0, buzz: 0.0002, buzzmod: 3, neither: 0.0000\n",
"1000000024 fizz: 0.0004, fizzmod: 1, buzz: 0.0001, buzzmod: 4, neither: 0.9999\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# stage 3: train to 32 bits, then print 25 outputs from 200 and from 1,000,000,000\n",
"randkey, trainkey = jax.random.split(randkey)\n",
"model = train_stage(\n",
"    trainkey, model, optimizer, batch_size,\n",
"    iterations=32768, min_val=0, max_val=(1 << 32) - 1,\n",
")\n",
"test_fizzbuzz(model, list(range(200, 225)))\n",
"test_fizzbuzz(model, list(range(1_000_000_000, 1_000_000_025)))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "cb69a399",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "72c15724a9494c519bbef619216c5567",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/16 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"48949678 neither wrong: expect neither got 0.64756584\n",
"48949738 neither wrong: expect neither got 0.63570035\n",
"48949802 neither wrong: expect neither got 0.23076619\n",
"48949818 fizz wrong: expect fizz got 0.728774\n",
"48949918 neither wrong: expect neither got 0.48839158\n",
"48949922 neither wrong: expect neither got 0.31886557\n",
"48949931 neither wrong: expect neither got 0.38822156\n",
"48949938 fizz wrong: expect fizz got 0.7156811\n",
"48949946 neither wrong: expect neither got 0.2861112\n",
"48949947 fizz wrong: expect fizz got 0.722468\n",
"48949950 fizz wrong: expect fizz got 0.523748\n",
"48949978 neither wrong: expect neither got 0.7377738\n",
"48949994 neither wrong: expect neither got 0.56062484\n",
"48950008 neither wrong: expect neither got 0.6009756\n",
"48950010 fizz wrong: expect fizz got 0.68679327\n",
"48950158 neither wrong: expect neither got 0.45459938\n",
"48950186 neither wrong: expect neither got 0.34532258\n",
"48950202 fizz wrong: expect fizz got 0.6460029\n",
"48950248 neither wrong: expect neither got 0.602318\n",
"52235340 buzz wrong: expect buzz got 0.6042997\n"
]
}
],
"source": [
"# check 0..100, then one random 1024-number window starting within 2**20\n",
"# of each multiple of 4,000,000 below 64,000,000; only mistakes are printed\n",
"test_fizzbuzz(model, jnp.arange(101), quiet=True)\n",
"for i in trange(0, 64_000_000, 4_000_000):\n",
"    start = np.random.randint(i, i + (1 << 20) - 1024)\n",
"    test_fizzbuzz(model, jnp.arange(start, start + 1024), quiet=True)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f4ddb9bd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"2\n",
"fizz\n",
"4\n",
"buzz\n",
"fizz\n",
"7\n",
"8\n",
"fizz\n",
"buzz\n",
"11\n",
"fizz\n",
"13\n",
"14\n",
"fizzbuzz\n",
"16\n",
"17\n",
"fizz\n",
"19\n",
"buzz\n",
"fizz\n",
"22\n",
"23\n",
"fizz\n",
"buzz\n",
"26\n",
"fizz\n",
"28\n",
"29\n",
"fizzbuzz\n",
"31\n",
"32\n",
"fizz\n",
"34\n",
"buzz\n",
"fizz\n",
"37\n",
"38\n",
"fizz\n",
"buzz\n",
"41\n",
"fizz\n",
"43\n",
"44\n",
"fizzbuzz\n",
"46\n",
"47\n",
"fizz\n",
"49\n",
"buzz\n",
"fizz\n",
"52\n",
"53\n",
"fizz\n",
"buzz\n",
"56\n",
"fizz\n",
"58\n",
"59\n",
"fizzbuzz\n",
"61\n",
"62\n",
"fizz\n",
"64\n",
"buzz\n",
"fizz\n",
"67\n",
"68\n",
"fizz\n",
"buzz\n",
"71\n",
"fizz\n",
"73\n",
"74\n",
"fizzbuzz\n",
"76\n",
"77\n",
"fizz\n",
"79\n",
"buzz\n",
"fizz\n",
"82\n",
"83\n",
"fizz\n",
"buzz\n",
"86\n",
"fizz\n",
"88\n",
"89\n",
"fizzbuzz\n",
"91\n",
"92\n",
"fizz\n",
"94\n",
"buzz\n",
"fizz\n",
"97\n",
"98\n",
"fizz\n",
"buzz\n"
]
}
],
"source": [
"infer_threshold = 0.75\n",
"def fizzbuzz(model, min, max):\n",
"    \"\"\"Print one fizzbuzz line for every integer in [min, max) from the model's outputs.\n",
"\n",
"    Each of the three outputs contributes its text ('fizz', 'buzz', the\n",
"    number itself) when it exceeds `infer_threshold`; the line stays empty\n",
"    if none does.\n",
"    \"\"\"\n",
"    nums = jnp.arange(min, max)\n",
"    # fetch all predictions to the host once instead of reading scalars one by one\n",
"    outputs = np.asarray(model.forward(split_nums(nums)))\n",
"    for num, (fizz, buzz, neither) in zip(nums.tolist(), outputs):\n",
"        parts = []\n",
"        if fizz > infer_threshold:\n",
"            parts.append('fizz')\n",
"        if buzz > infer_threshold:\n",
"            parts.append('buzz')\n",
"        if neither > infer_threshold:\n",
"            parts.append(str(num))\n",
"        print(''.join(parts))\n",
"\n",
"fizzbuzz(model, 1, 101)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "38423617",
"metadata": {},
"outputs": [],
"source": [
"# save/load\n",
"import pickle\n",
"from uuid import uuid4\n",
"\n",
"def save_model(model, name=None):\n",
"    \"\"\"Pickle `model` to a file and return the file name used.\n",
"\n",
"    Without `name` (or with an empty one) the file is called\n",
"    'model.<uuid4>.pickle'.\n",
"    \"\"\"\n",
"    out_name = name or f'model.{uuid4()}.pickle'\n",
"    with open(out_name, 'wb') as file:\n",
"        pickle.dump(model, file, protocol=pickle.HIGHEST_PROTOCOL)\n",
"\n",
"    return out_name\n",
"\n",
"def load_model(name):\n",
"    \"\"\"Load and return the object pickled in the file `name`.\n",
"\n",
"    NOTE: unpickling can run arbitrary code, so only load files you trust\n",
"    (for example ones written by `save_model`).\n",
"    \"\"\"\n",
"    with open(name, 'rb') as file:\n",
"        return pickle.load(file)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment