Skip to content

Instantly share code, notes, and snippets.

@shawntan
Created March 30, 2022 02:28
Show Gist options
  • Save shawntan/b6b28b3f16d54f5e5f1668026f70cabc to your computer and use it in GitHub Desktop.
Save shawntan/b6b28b3f16d54f5e5f1668026f70cabc to your computer and use it in GitHub Desktop.
Parity problem with PFSA
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "0e842850",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import random\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d0ae549a",
"metadata": {},
"outputs": [],
"source": [
"# Bit string with even no. of ones\n",
"def generate_parity(generate_even=True):\n",
"    \"\"\"Sample a random bit string whose count of ones has the requested parity.\n",
"\n",
"    The first bit is drawn uniformly. After that, whenever the running parity\n",
"    matches the target, the string stops with probability 0.1; otherwise one\n",
"    more uniformly random bit is appended.\n",
"\n",
"    Args:\n",
"        generate_even: If True the result holds an even number of ones,\n",
"            otherwise an odd number.\n",
"\n",
"    Returns:\n",
"        A non-empty list of ints, each 0 or 1.\n",
"    \"\"\"\n",
"    end_state = \"EVEN\" if generate_even else \"ODD\"\n",
"\n",
"    # Seed the string with one bit so the result is never empty.\n",
"    if random.random() < 0.5:\n",
"        result, state = [0], \"EVEN\"\n",
"    else:\n",
"        result, state = [1], \"ODD\"\n",
"\n",
"    while True:\n",
"        # Stopping is only allowed while the parity matches the target.\n",
"        if state == end_state and random.random() < 0.1:\n",
"            break\n",
"\n",
"        # A one flips the parity, a zero leaves it unchanged.\n",
"        if random.random() < 0.5:\n",
"            result.append(1)\n",
"            state = \"ODD\" if state == \"EVEN\" else \"EVEN\"\n",
"        else:\n",
"            result.append(0)\n",
"    return result"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "79d6c5dd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"58 34 0100010100001111111011110110111011100100101101011011010101\n",
"29 16 10010000111010110001111111001\n",
"1 0 0\n",
"8 6 11011011\n",
"4 2 1001\n"
]
}
],
"source": [
"# Print a few samples as: length, number of ones, bit string.\n",
"for _ in range(5):\n",
"    string = generate_parity(generate_even=True)\n",
"    print(len(string), sum(string), ''.join(str(n) for n in string))\n",
"\n",
"\n",
"def create_batch(batch_size, device=torch.device('cpu'), generate_even=True, min_length=0, max_length=100000):\n",
"    \"\"\"Sample parity strings and pack them into one padded tensor.\n",
"\n",
"    Every string gets the end symbol 2 appended and shorter columns are\n",
"    padded with -1. Strings whose length (end symbol included) falls outside\n",
"    [min_length, max_length] are rejected and resampled.\n",
"\n",
"    Args:\n",
"        batch_size: Number of strings in the batch.\n",
"        device: Device the returned tensor is placed on.\n",
"        generate_even: Passed through to generate_parity.\n",
"        min_length: Smallest accepted length, end symbol included.\n",
"        max_length: Largest accepted length, end symbol included.\n",
"\n",
"    Returns:\n",
"        An int64 tensor of shape (longest, batch_size), one string per column.\n",
"\n",
"    Raises:\n",
"        ValueError: If no string can satisfy the length bounds.\n",
"    \"\"\"\n",
"    # Every instance holds at least one bit plus the end symbol, so bounds\n",
"    # below that (or an empty range) would make the sampling loop spin forever.\n",
"    if min_length > max_length or max_length < 2:\n",
"        raise ValueError('no string can satisfy the requested length bounds')\n",
"\n",
"    batch_list = []\n",
"    while len(batch_list) < batch_size:\n",
"        instance = generate_parity(generate_even) + [2]\n",
"        if min_length <= len(instance) <= max_length:\n",
"            batch_list.append(instance)\n",
"\n",
"    # Use a separate name so the max_length argument is not overwritten.\n",
"    longest = max((len(s) for s in batch_list), default=0)\n",
"    batch = np.full((longest, batch_size), -1, dtype=np.int64)\n",
"    for i, instance in enumerate(batch_list):\n",
"        batch[:len(instance), i] = instance\n",
"    return torch.tensor(batch, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b43eba5c",
"metadata": {},
"outputs": [],
"source": [
"def restricted_mask():\n",
"    \"\"\"Build a boolean mask with a fixed set of permitted transitions.\n",
"\n",
"    The array has shape (4, 3, 4). Its axes presumably follow the\n",
"    (state, symbol, next_state) layout of PFSA.transition_logits -- confirm\n",
"    before using it to mask that parameter. Index 3 on the first axis has\n",
"    no permitted entries.\n",
"\n",
"    Returns:\n",
"        A (4, 3, 4) array of dtype np.bool_. As a side effect the mask is\n",
"        printed as integers.\n",
"    \"\"\"\n",
"    transitions = np.zeros((4, 3, 4), dtype=np.bool_)\n",
"\n",
"    # Index triples that are switched on; everything else stays False.\n",
"    allowed = [\n",
"        (0, 0, 2), (0, 1, 1),\n",
"        (1, 0, 1), (1, 1, 2),\n",
"        (2, 0, 2), (2, 1, 1), (2, 2, 3),\n",
"    ]\n",
"    for first, second, third in allowed:\n",
"        transitions[first, second, third] = True\n",
"\n",
"    print(transitions.astype(np.int32))\n",
"    return transitions"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "96dadd7f",
"metadata": {},
"outputs": [],
"source": [
"class PFSA(nn.Module):\n",
"    \"\"\"Probabilistic finite-state automaton with learnable logits.\n",
"\n",
"    forward() scores a batch of symbol strings by propagating log state\n",
"    weights through the transition table one time step at a time.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, n_states, n_symbols, end_symbol):\n",
"        \"\"\"Create randomly initialised parameters.\n",
"\n",
"        Args:\n",
"            n_states: Number of automaton states.\n",
"            n_symbols: Number of distinct symbols.\n",
"            end_symbol: Symbol index that marks the end of a string.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.initial = nn.Parameter(torch.randn(n_states))\n",
"        self.transition_logits = nn.Parameter(torch.randn(n_states, n_symbols, n_states))\n",
"        self.end_symbol = end_symbol\n",
"\n",
"    def _normalise(self):\n",
"        \"\"\"Return (log_init, log_probs) computed from the raw logits.\n",
"\n",
"        log_init is a log-softmax over states. log_probs is normalised\n",
"        jointly over the last two axes (symbol, next state), separately for\n",
"        every source state.\n",
"        \"\"\"\n",
"        log_init = torch.log_softmax(self.initial, dim=-1)\n",
"        z = torch.logsumexp(self.transition_logits, dim=(-2, -1), keepdim=True)\n",
"        return log_init, self.transition_logits - z\n",
"\n",
"    def log_mult(self, state, exp_T):\n",
"        \"\"\"Advance log state weights by one step.\n",
"\n",
"        Args:\n",
"            state: (batch, n_states) log weights.\n",
"            exp_T: (batch, n_states, n_states) transition probabilities\n",
"                (not logs) for the symbol read at this step.\n",
"\n",
"        Returns:\n",
"            (batch, n_states) log weights after the step.\n",
"        \"\"\"\n",
"        # Shift by the row maximum before exponentiating to limit underflow.\n",
"        state_k, _ = torch.max(state, dim=-1, keepdim=True)\n",
"        p_state = torch.exp(state - state_k)\n",
"        n_state = torch.einsum('bi,bij->bj', p_state, exp_T)\n",
"        return torch.log(n_state) + state_k\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return one log score per string in the batch.\n",
"\n",
"        Args:\n",
"            x: (time, batch) integer tensor of symbol indices. Padding of -1\n",
"                wraps around to the last symbol when indexing; padded steps\n",
"                come after the end symbol and never change the result.\n",
"\n",
"        Returns:\n",
"            (batch,) float tensor holding the log weight of the last state\n",
"            right after the end symbol was read, or 0 for columns that\n",
"            contain no end symbol.\n",
"        \"\"\"\n",
"        log_init, log_probs = self._normalise()\n",
"        log_state = log_init[None, :].expand(x.size(1), -1)\n",
"        probs = torch.exp(log_probs)\n",
"        final = torch.zeros_like(x[0], dtype=torch.float)\n",
"        for t in range(x.size(0)):\n",
"            # Gather this step's (batch, state, state) slice only, rather than\n",
"            # materialising a (time, batch, state, state) tensor up front.\n",
"            step = probs[:, x[t], :].permute(1, 0, 2)\n",
"            log_state = self.log_mult(log_state, step)\n",
"            e = x[t] == self.end_symbol\n",
"            final[e] = log_state[e, -1]\n",
"        return final"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2191a1e",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"pfsa = PFSA(50, 3, end_symbol=2)\n",
"\n",
"# Fixed evaluation sets. The *_length values count all non-padding symbols\n",
"# (end symbols included) and are used to normalise the summed loss.\n",
"id_test = create_batch(100)\n",
"id_length = torch.sum(id_test != -1)\n",
"pos_test = create_batch(100, min_length=21)\n",
"pos_length = torch.sum(pos_test != -1)\n",
"neg_test = create_batch(100, generate_even=False)\n",
"neg_length = torch.sum(neg_test != -1)\n",
"\n",
"optimizer = torch.optim.Adam(pfsa.parameters(), lr=1e-3)\n",
"log = []\n",
"for i in range(20000):\n",
"    # Training only ever sees strings of length <= 20 (end symbol included).\n",
"    x = create_batch(100, max_length=20)\n",
"    lengths = torch.sum(x != -1)\n",
"    log_prob = pfsa(x)\n",
"    loss = -log_prob.sum() / lengths\n",
"\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()\n",
"    if i % 200 == 0:\n",
"        pfsa.eval()\n",
"        # Evaluation needs no gradients; skip building the autograd graph.\n",
"        with torch.no_grad():\n",
"            id_test_loss = -pfsa(id_test).sum() / id_length\n",
"            pos_test_loss = -pfsa(pos_test).sum() / pos_length\n",
"            neg_test_loss = -pfsa(neg_test).sum() / neg_length\n",
"        log.append((id_test_loss.item(), pos_test_loss.item(), neg_test_loss.item()))\n",
"        pfsa.train()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a362d057",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# The training loop appends one entry to log every 200 iterations, so scale\n",
"# the entry index to match the 'Iterations' axis label.\n",
"t = [200 * step for step in range(len(log))]\n",
"plt.figure(figsize=(10,5))\n",
"# Raw string: '\\l' is not a valid escape sequence in a normal string literal.\n",
"plt.plot(t, [entry[0] for entry in log], label=r'In-distribution ($\\leq$ 20)')\n",
"plt.plot(t, [entry[1] for entry in log], label='OOD ($>$ 20)')\n",
"plt.plot(t, [entry[2] for entry in log], label='Non-even bits')\n",
"plt.ylabel('Loss')\n",
"plt.xlabel('Iterations')\n",
"plt.legend()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@JacksonConvis
Copy link

WOW

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment