Skip to content

Instantly share code, notes, and snippets.

@malkin1729
Created December 22, 2022 01:41
Show Gist options
  • Save malkin1729/88227a1e451596e1ea1fc7d4e0a7ae09 to your computer and use it in GitHub Desktop.
Save malkin1729/88227a1e451596e1ea1fc7d4e0a7ae09 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch as T\n",
"import numpy as np\n",
"from matplotlib import pyplot as pt\n",
"import tqdm\n",
"\n",
"device = T.device('cuda')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class Policy(T.nn.Module):\n",
" \n",
" def __init__(self, n_classes, n_mixture_components, n_hidden_layers, hidden_dim, context_dim, uniform_pb=False): \n",
" super(Policy, self).__init__()\n",
" \n",
" self.n_classes = n_classes\n",
" self.n_mixture_components = n_mixture_components\n",
" self.context_dim = context_dim\n",
" self.uniform_pb = uniform_pb\n",
" \n",
" layers = [ T.nn.Linear(2*n_classes+1+context_dim, hidden_dim), T.nn.ELU() ]\n",
" for _ in range(n_hidden_layers-1):\n",
" layers += [ T.nn.Linear(hidden_dim, hidden_dim), T.nn.ELU() ]\n",
" layers += [ T.nn.Linear(hidden_dim, n_classes*(2+3*n_mixture_components)+1) ]\n",
" \n",
" self.model = T.nn.Sequential(*layers)\n",
" \n",
" def forward(self, theta, set_mask, remaining, context):\n",
" # theta: b * n_cl\n",
" # set_mask: b * n_cl\n",
" # remaining: b\n",
" # context: b * context_dim\n",
" \n",
" x = T.cat([theta, set_mask, remaining.unsqueeze(1), context], 1)\n",
" x = self.model(x)\n",
" \n",
" pb_logits = (x[:, :self.n_classes] * (0 if self.uniform_pb else 1) - 1e9*(1-set_mask)).log_softmax(1)\n",
" pf_class_logits = (x[:, self.n_classes:2*self.n_classes] - 1e9*set_mask).log_softmax(1)\n",
" pf_beta_parameters = x[:, 2*self.n_classes:self.n_classes*(2+3*self.n_mixture_components)]\n",
" log_flow = x[:, -1]\n",
"\n",
" return pb_logits, \\\n",
" pf_class_logits, \\\n",
" pf_beta_parameters.view(-1, self.n_classes, self.n_mixture_components, 3), \\\n",
" log_flow \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def sample_forward(context, temperature=1., uniform_prob=0., max_steps=None):\n",
" bs = context.shape[0]\n",
" assert context.shape==(bs,policy.context_dim)\n",
" \n",
" if max_steps == None: \n",
" max_steps = policy.n_classes\n",
"\n",
" remaining = T.ones((bs,)).to(device)\n",
" theta = T.zeros((bs, policy.n_classes)).to(device)\n",
" set_mask = T.zeros_like(theta)\n",
" logPF = T.zeros_like(remaining)\n",
" logPB = T.zeros_like(remaining)\n",
" \n",
" for step in range(max_steps+1):\n",
" pb_logits, pf_class_logits, pf_beta_parameters, log_flow = policy(theta, set_mask, remaining, context)\n",
" \n",
" if step==0:\n",
" logZ = log_flow\n",
" else:\n",
" logPB += pb_logits.gather(1, positions).squeeze(1)\n",
" \n",
" if step < max_steps:\n",
" sampling_probs = (1-uniform_prob) * (pf_class_logits/temperature).softmax(1) + uniform_prob * (1-set_mask) / (1-set_mask).sum(1).unsqueeze(1)\n",
" positions = T.multinomial(sampling_probs, 1)\n",
" \n",
" logPF += pf_class_logits.gather(1, positions).squeeze(1)\n",
" \n",
" if step < policy.n_classes - 1:\n",
" pos_beta_mixture_parameters = pf_beta_parameters.gather(1, positions[...,None,None].repeat(*((1,1)+pf_beta_parameters.shape[-2:]))).squeeze(1)\n",
" mixture_components_logits = pos_beta_mixture_parameters[...,0].log_softmax(1)\n",
" mixture_components = T.multinomial(mixture_components_logits.exp(), 1)\n",
" logPF += mixture_components_logits.gather(1, mixture_components).squeeze(1)\n",
" pos_beta_parameters = pos_beta_mixture_parameters.gather(1, mixture_components[...,None].repeat(*((1,1)+pf_beta_parameters.shape[-1:]))).squeeze(1)\n",
" betas = T.distributions.Beta(pos_beta_parameters[:, 1].exp()+1e-2, pos_beta_parameters[:, 2].exp()+1e-2)\n",
" samples = betas.sample()\n",
" \n",
" all_betas = T.distributions.Beta(pos_beta_mixture_parameters[...,1].exp()+1e-2, pos_beta_mixture_parameters[...,2].exp()+1e-2)\n",
" all_log_probs = all_betas.log_prob(samples.unsqueeze(1))\n",
" logPF += (all_log_probs + mixture_components_logits).logsumexp(1) - remaining.detach().log()\n",
" else:\n",
" samples = T.ones_like(samples)\n",
" \n",
" theta.scatter_(1, positions, (samples * remaining).unsqueeze(1))\n",
" remaining = remaining - samples * remaining\n",
" set_mask.scatter_(1, positions, T.ones(bs,1).to(device))\n",
" \n",
" return theta, logZ, logPF, logPB"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"n_topics = 3\n",
"n_vocab = 100\n",
"doc_size = 16"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"prior = T.distributions.Dirichlet(T.full((n_topics,), 0.1).to(device))\n",
"def log_reward(theta,context,tw):\n",
" theta_soft = theta*0.99 + 0.01/n_topics\n",
" lp = prior.log_prob(theta_soft)\n",
" lpw = theta_soft @ tw.softmax(1)\n",
" ll = (lpw.log() * context).sum(1)\n",
" return lp + ll"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# generate a dataset\n",
"word_prior = T.distributions.Dirichlet(T.full((n_vocab,), 0.1).to(device))\n",
"topic_word = word_prior.sample((n_topics,))\n",
"\n",
"bs = 256\n",
"topics = prior.sample((bs,))\n",
"dists = topics @ topic_word\n",
"docs = T.distributions.Multinomial(probs=dists, total_count=doc_size).sample()\n",
"gt_topic_word = topic_word.log().clone()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"topic_word = T.nn.Parameter(word_prior.sample((n_topics,)).log()*0.01)\n",
"# topic_word.data = gt_topic_word\n",
"policy = Policy(n_classes=n_topics, \n",
" n_mixture_components=4, \n",
" n_hidden_layers=3, \n",
" hidden_dim=32, \n",
" context_dim=n_vocab).to(device)\n",
"optimizer = T.optim.Adam(policy.parameters(), 0.001)\n",
"optimizer_gen = T.optim.Adam([topic_word], 0.01)\n",
"losses, log_rewards, updates = [], [], []"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5339.19677734375 -75.00799560546875 0.0 -0.17898723483085632\n",
"tensor([0.3113, 0.3744, 0.3144], device='cuda:0')\n",
"3546.5011989450454 -73.70572082519531 0.0 -88.59672546386719\n",
"tensor([9.7947e-06, 6.7550e-01, 3.2449e-01], device='cuda:0')\n",
"11.263638963699341 -72.77050445556641 0.0 -77.75553894042969\n",
"tensor([0.0338, 0.6186, 0.3476], device='cuda:0')\n",
"1.232149149775505 -73.63213912963867 0.12 -73.81747436523438\n",
"tensor([0.2774, 0.3218, 0.4008], device='cuda:0')\n",
"1.095009415745735 -70.41562278747558 0.24 -70.29661560058594\n",
"tensor([0.3188, 0.2862, 0.3950], device='cuda:0')\n",
"1.1437681084871292 -67.94061233520507 0.14 -68.41439819335938\n",
"tensor([0.3710, 0.2747, 0.3543], device='cuda:0')\n",
"1.1356153529882431 -66.29721214294433 0.15 -66.96688842773438\n",
"tensor([0.3281, 0.2894, 0.3825], device='cuda:0')\n",
"1.1351548039913177 -64.84039974212646 0.18 -65.48225402832031\n",
"tensor([0.3407, 0.2649, 0.3944], device='cuda:0')\n",
"1.1267199742794036 -63.53774150848389 0.2 -64.29792022705078\n",
"tensor([0.3482, 0.2941, 0.3577], device='cuda:0')\n",
"1.1324942630529404 -62.38467308044434 0.25 -63.13263702392578\n",
"tensor([0.3262, 0.3246, 0.3492], device='cuda:0')\n",
"1.143118695616722 -61.31421634674072 0.27 -62.254608154296875\n",
"tensor([0.3030, 0.3183, 0.3786], device='cuda:0')\n",
"1.1150964844226836 -60.55814685821533 0.42 -61.5101318359375\n",
"tensor([0.3323, 0.2919, 0.3758], device='cuda:0')\n",
"1.023587293624878 -59.797543373107914 0.6 -60.83245086669922\n",
"tensor([0.3067, 0.3009, 0.3924], device='cuda:0')\n",
"0.9061188751459122 -59.317057075500486 0.72 -60.38682556152344\n",
"tensor([0.3014, 0.3380, 0.3606], device='cuda:0')\n",
"0.7718061292171479 -59.03282218933106 0.92 -60.226985931396484\n",
"tensor([0.2953, 0.3278, 0.3769], device='cuda:0')\n",
"0.6978682881593704 -58.97709060668945 0.94 -60.081329345703125\n",
"tensor([0.2654, 0.3753, 0.3594], device='cuda:0')\n",
"0.6522869244217873 -58.85392658233643 0.96 -60.108116149902344\n",
"tensor([0.3207, 0.3704, 0.3089], device='cuda:0')\n",
"0.6289197561144829 -58.743394165039064 0.97 -59.9983024597168\n",
"tensor([0.2902, 0.2981, 0.4117], device='cuda:0')\n",
"0.5816634127497673 -58.73764385223389 0.97 -60.01136779785156\n",
"tensor([0.3489, 0.3280, 0.3232], device='cuda:0')\n",
"0.5562542042136193 -58.72943668365478 0.99 -60.056365966796875\n",
"tensor([0.3079, 0.3579, 0.3342], device='cuda:0')\n",
"0.5859266641736031 -58.76437438964844 0.97 -60.106407165527344\n",
"tensor([0.2555, 0.3772, 0.3673], device='cuda:0')\n",
"0.6630006548762322 -58.77452781677246 0.95 -60.04123306274414\n",
"tensor([0.2941, 0.2906, 0.4153], device='cuda:0')\n",
"0.6043819969892502 -58.74245792388916 0.96 -60.04582595825195\n",
"tensor([0.3065, 0.3022, 0.3914], device='cuda:0')\n",
"0.6941645961999893 -58.71961921691894 0.96 -59.98759460449219\n",
"tensor([0.2827, 0.3007, 0.4166], device='cuda:0')\n",
"0.6233581843972206 -58.746786727905274 0.97 -59.98252868652344\n",
"tensor([0.2997, 0.3276, 0.3726], device='cuda:0')\n",
"0.6827544385194778 -58.758174934387206 0.98 -60.07716369628906\n",
"tensor([0.2958, 0.3441, 0.3601], device='cuda:0')\n",
"0.8040225741267204 -58.74992374420166 0.91 -59.965362548828125\n",
"tensor([0.2929, 0.3405, 0.3666], device='cuda:0')\n",
"0.9204131281375885 -58.74249607086182 0.72 -59.993736267089844\n",
"tensor([0.3245, 0.3548, 0.3207], device='cuda:0')\n",
"1.170855165719986 -58.77177585601807 0.11 -59.88347625732422\n",
"tensor([0.2422, 0.3131, 0.4447], device='cuda:0')\n",
"1.1648672425746918 -58.70951473236084 0.15 -59.92580032348633\n",
"tensor([0.2931, 0.3494, 0.3575], device='cuda:0')\n",
"1.187340236902237 -58.70164463043213 0.15 -59.771148681640625\n",
"tensor([0.2706, 0.3577, 0.3717], device='cuda:0')\n",
"1.2327655619382858 -58.72432807922363 0.09 -59.748844146728516\n",
"tensor([0.2577, 0.3286, 0.4137], device='cuda:0')\n",
"1.2280465269088745 -58.66496212005615 0.04 -59.663368225097656\n",
"tensor([0.2090, 0.4177, 0.3733], device='cuda:0')\n",
"1.2034409683942795 -58.570076065063475 0.04 -59.65102767944336\n",
"tensor([0.2559, 0.3828, 0.3613], device='cuda:0')\n",
"1.227398519515991 -58.53832061767578 0.07 -59.479557037353516\n",
"tensor([0.2413, 0.3840, 0.3747], device='cuda:0')\n",
"1.255662345290184 -58.47613731384277 0.05 -59.378055572509766\n",
"tensor([0.2325, 0.3726, 0.3949], device='cuda:0')\n",
"1.2991685396432877 -58.424017639160155 0.04 -59.340728759765625\n",
"tensor([0.2000, 0.3796, 0.4204], device='cuda:0')\n",
"1.2509387838840484 -58.38878238677979 0.03 -59.24190139770508\n",
"tensor([0.2263, 0.3132, 0.4606], device='cuda:0')\n",
"1.2423586374521256 -58.33663829803467 0.05 -59.16193389892578\n",
"tensor([0.2503, 0.3429, 0.4068], device='cuda:0')\n",
"1.3024506282806396 -58.06736526489258 0.06 -59.02488708496094\n",
"tensor([0.1837, 0.3821, 0.4342], device='cuda:0')\n",
"1.3143541789054871 -57.42053646087646 0.06 -58.57398986816406\n",
"tensor([0.1645, 0.3260, 0.5095], device='cuda:0')\n",
"1.2668221169710159 -56.48176433563233 0.07 -58.14472198486328\n",
"tensor([0.1534, 0.2907, 0.5559], device='cuda:0')\n",
"1.303649778366089 -55.632459526062014 0.08 -57.7420539855957\n",
"tensor([0.1045, 0.3489, 0.5466], device='cuda:0')\n",
"1.246590762734413 -54.75730804443359 0.12 -57.4212646484375\n",
"tensor([0.1030, 0.3304, 0.5666], device='cuda:0')\n",
"1.186555580496788 -54.08460330963135 0.36 -56.502723693847656\n",
"tensor([0.0891, 0.3491, 0.5618], device='cuda:0')\n",
"0.9469049417972565 -53.72010948181153 0.78 -56.472625732421875\n",
"tensor([0.1164, 0.3223, 0.5613], device='cuda:0')\n",
"0.6925678265094757 -53.612727661132816 0.96 -56.296485900878906\n",
"tensor([0.1294, 0.3005, 0.5701], device='cuda:0')\n",
"0.6352209240198136 -53.604561462402344 0.99 -56.410926818847656\n",
"tensor([0.1317, 0.2872, 0.5811], device='cuda:0')\n",
"0.7304302099347114 -53.61720794677734 0.98 -56.252532958984375\n",
"tensor([0.1220, 0.2601, 0.6179], device='cuda:0')\n",
"0.6374397739768028 -53.64296215057373 0.99 -56.23942565917969\n",
"tensor([0.1195, 0.2807, 0.5999], device='cuda:0')\n",
"0.6619596907496452 -53.63779563903809 0.95 -56.35626983642578\n",
"tensor([0.1156, 0.2950, 0.5895], device='cuda:0')\n",
"0.6234841993451119 -53.64442481994629 0.92 -56.36970138549805\n",
"tensor([0.1137, 0.2877, 0.5987], device='cuda:0')\n",
"0.5973330575227738 -53.64245052337647 0.98 -56.4075813293457\n",
"tensor([0.0905, 0.2973, 0.6122], device='cuda:0')\n",
"0.6158758586645127 -53.624692039489744 0.94 -56.541683197021484\n",
"tensor([0.0865, 0.3106, 0.6029], device='cuda:0')\n",
"0.5967774161696434 -53.61479591369629 0.99 -56.396446228027344\n",
"tensor([0.0901, 0.3083, 0.6016], device='cuda:0')\n",
"0.6080861777067185 -53.558166389465335 1.0 -56.43267059326172\n",
"tensor([0.0923, 0.3227, 0.5849], device='cuda:0')\n",
"0.7368822130560875 -53.54587490081787 0.89 -56.5095100402832\n",
"tensor([0.0728, 0.3296, 0.5977], device='cuda:0')\n",
"0.7293363624811172 -53.54059223175049 0.93 -56.49578857421875\n",
"tensor([0.0789, 0.3251, 0.5960], device='cuda:0')\n",
"0.7664073100686073 -53.53928524017334 0.93 -56.550113677978516\n",
"tensor([0.0654, 0.3441, 0.5905], device='cuda:0')\n",
"0.7412702116370201 -53.51655475616455 0.95 -56.53882598876953\n",
"tensor([0.0720, 0.3339, 0.5940], device='cuda:0')\n",
"0.7214934992790222 -53.52108829498291 0.96 -56.602134704589844\n",
"tensor([0.0603, 0.3622, 0.5775], device='cuda:0')\n",
"0.707744922041893 -53.53437286376953 0.95 -56.60513687133789\n",
"tensor([0.0540, 0.3462, 0.5998], device='cuda:0')\n",
"0.6582868728041649 -53.52054286956787 0.98 -56.70066833496094\n",
"tensor([0.0469, 0.3683, 0.5849], device='cuda:0')\n",
"0.8879157817363739 -53.52068420410156 0.84 -56.774166107177734\n",
"tensor([0.0524, 0.3551, 0.5924], device='cuda:0')\n",
"0.7335088667273522 -53.52858638763428 0.87 -56.66447448730469\n",
"tensor([0.0441, 0.3739, 0.5820], device='cuda:0')\n",
"0.4448272404074669 -53.54760578155518 1.0 -56.678627014160156\n",
"tensor([0.0637, 0.3510, 0.5853], device='cuda:0')\n",
"0.4459055018424988 -53.51938419342041 1.0 -56.800987243652344\n",
"tensor([0.0612, 0.3370, 0.6018], device='cuda:0')\n",
"0.374909031689167 -53.50309471130371 1.0 -56.85752868652344\n",
"tensor([0.0523, 0.3462, 0.6015], device='cuda:0')\n",
"0.33665544584393503 -53.45789794921875 1.0 -56.89958953857422\n",
"tensor([0.0584, 0.3590, 0.5826], device='cuda:0')\n",
"0.30595080524683 -53.47039882659912 1.0 -56.85301208496094\n",
"tensor([0.0536, 0.3424, 0.6040], device='cuda:0')\n",
"0.2681262767314911 -53.47688762664795 1.0 -57.00203323364258\n",
"tensor([0.0665, 0.3597, 0.5738], device='cuda:0')\n",
"0.2664834652841091 -53.497044258117675 1.0 -57.07732391357422\n",
"tensor([0.0537, 0.3591, 0.5872], device='cuda:0')\n",
"0.24812621906399726 -53.517969627380374 1.0 -57.02836990356445\n",
"tensor([0.0503, 0.3629, 0.5868], device='cuda:0')\n",
"0.25389285638928416 -53.499766464233396 1.0 -56.9744987487793\n",
"tensor([0.0542, 0.3453, 0.6005], device='cuda:0')\n",
"0.2338979324698448 -53.49336658477783 1.0 -57.014556884765625\n",
"tensor([0.0645, 0.3461, 0.5894], device='cuda:0')\n",
"0.23552647292613982 -53.5185506439209 1.0 -57.15602111816406\n",
"tensor([0.0587, 0.3539, 0.5875], device='cuda:0')\n",
"0.21728055641055108 -53.5498291015625 1.0 -57.239437103271484\n",
"tensor([0.0558, 0.3449, 0.5994], device='cuda:0')\n",
"0.20451952904462814 -53.5641007232666 1.0 -57.16626739501953\n",
"tensor([0.0536, 0.3534, 0.5930], device='cuda:0')\n",
"0.1960418738424778 -53.57102745056152 1.0 -57.36680603027344\n",
"tensor([0.0486, 0.3715, 0.5798], device='cuda:0')\n",
"0.20397837460041046 -53.57062450408935 1.0 -57.1917724609375\n",
"tensor([0.0527, 0.3559, 0.5914], device='cuda:0')\n",
"0.1794803101569414 -53.62214382171631 1.0 -57.260433197021484\n",
"tensor([0.0576, 0.3441, 0.5982], device='cuda:0')\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-33-12bae130d821>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mit\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m100000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mtheta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlogZ\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlogPF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlogPB\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msample_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdocs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mlogR\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlog_reward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtheta\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdocs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtopic_word\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-3-145ba1a263bc>\u001b[0m in \u001b[0;36msample_forward\u001b[0;34m(context, temperature, uniform_prob, max_steps)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_steps\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mpb_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpf_class_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpf_beta_parameters\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_flow\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpolicy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtheta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mset_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mremaining\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/data/anaconda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1108\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1111\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1112\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-2-858de4281a63>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, theta, set_mask, remaining, context)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtheta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mset_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mremaining\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mpb_logits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_classes\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muniform_pb\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1e9\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mset_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/data/anaconda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1108\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1111\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1112\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/data/anaconda/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 141\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 142\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/data/anaconda/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1108\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1111\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1112\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/data/anaconda/lib/python3.8/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"threshold = 1\n",
"for it in range(100000):\n",
"\n",
" theta, logZ, logPF, logPB = sample_forward(docs)\n",
" logR = log_reward(theta,docs,topic_word)\n",
" \n",
" optimizer.zero_grad()\n",
" \n",
" loss = ((logZ + logPF - logR - logPB)**2).mean()\n",
" loss.backward()\n",
" \n",
" optimizer.step()\n",
" \n",
" losses.append(loss.item())\n",
" log_rewards.append(logR.mean().item())\n",
" \n",
" update = loss.item() < threshold\n",
" if update:\n",
" optimizer_gen.zero_grad()\n",
" (-log_reward(theta,docs,topic_word).mean() - word_prior.log_prob(topic_word.softmax(1)).sum() / bs).backward()\n",
" topic_word.grad.nan_to_num_()\n",
" optimizer_gen.step()\n",
" updates.append(1 if update else 0)\n",
" \n",
" if it%100==0: \n",
" print(np.array(losses[-100:]).mean(), \n",
" np.array(log_rewards[-100:]).mean(), \n",
" np.array(updates[-100:]).mean(), \n",
" logZ.mean().item())\n",
" print(theta.mean(0))"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAADSCAYAAACxZoAXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3de5wU9Znv8c8zw3C/ugMrCgLGQQXjapyIe3SNG8mCrLfdXLwkUVc0N03Yk6wbXCNrMFnN8XWSkFfMSYwxxmy8EHNiMIdIokkOJxpHZlYEuQiIDhdBBh2ug8PM8Jw/qrqtabp7emaqp5vi+3695jXdVdXVT/c89cyvfvWrKnN3REQkWSpKHYCIiMRPxV1EJIFU3EVEEkjFXUQkgVTcRUQSSMVdRCSBVNyLwMweNLOv9fVrS8nM9pnZiaWOQyQXM1tlZheUOo6+ouJ+lDOziWbmZtavN+tx96HuvjGuuCQeZva6mU0vdRy9FeboSb1Zh7tPdfc/xhRS2VNxl17p7T+F3r5eeudo+P6P1hxNXHE3sy+b2VYz22tmr5jZheH0s83sz2a2y8y2mdl3zax/5HVuZp8zs/Xha+80s/eY2XNmtsfMFqaWN7MLzGyLmf2bme0MW0cfzxPTxWa2PHzv58zs9Mi8M83sv8L3fAwYmGc915nZs2Hsu81sberzhfOPM7NFZva2mW0wsxsj8842s/rws7xpZt8MZy0Nf+8Ku1b+Olz+ejNbY2bNZrbEzCZkfFc3mdl6YH1k2knh4xFm9pCZNZlZo5l9xcwqMj7Dt8zsLeCO/H9R6Skz+ylwAvBk+Lf918ie2mwz2wT8PpXPGa9Nt/jNrMLM5prZq2b2VrgtHJPjPfNuG13kxklm9n/D3N4Zbg+YWSpHXwo/xxXh9Hzb1ethLVgB7DezfhmfaYCZfdvM3gh/vm1mAzI+w5fNbDvw4zj+Hn3O3RPzA5wMbAaOC59PBN4TPj4LOAfoF05fA/xz5LUO/AoYDkwFWoFngBOBEcBq4Npw2QuAduCbwADgA8B+4ORw/oPA18LHZwI7gGlAJXAt8Hr4uv5AI/DfgSrgI0Bb6rVZPt914fumlr8C2A0cE85fCnyP4B/EGUAT8MFw3p+BT4aPhwLnRL4jB/pF3ucyYANwavh9fQV4LuO7+h1wDDAoMu2k8PFD4Xc5LFz/OmB2xmf4fLjuQaXOmyT/hLk2PfI89fd+CBgCDArzeUuu1wFzgOeBcWHe/gB4JMf7dbVt5MuNR4DbCBqdA4HzMnLupMjznNtVJP7lwPhIjkY/0/zwM40BRgPPAXdmfIZvhJ/hiMzRkgcQcyKfFP7BpwNVXSz7z8AvM5Ln3MjzBuDLkef/E/h2xh9/SGT+QuD28PGDvFvc/1cqaSLLvhIm/fnAG4BF5j1H/uKeufwLwCfDJO4AhkXm3QU8GD5eCnwVqM5Y50QOL+6/SW1w4fMKoAWYEPmuPpixHg+//0rgIDAlMu/TwB8jn2FTqXPlaPkhd3E/MTLtAvIX9zXAhZF5YwkaIf2yvF/ObaOA3HgIuA8Yl2W9mcU953YVif/6PJ/pVWBWZN4M4PXIZzgIDCz13683P4nqlnH3DQRF+w5gh5k9ambHAZjZZDP7tZltN7M9wH8A1RmreDPy+ECW50Mjz5vdfX/keSNwXJawJgBfCncdd5nZLoJCfFz4s9XDjIqsJ59sy6fW9ba7782Yd3z4eDYwGVhrZsvM7OI87zEBWBCJ923AIuuCYA8pm2qCvYro54jGke+10ne68zeYAPwykg9rCBoSf5lj+VzbRle58a8EefaCBSNbru8iplzbVUq+z3hcljiir21y93fyvL7sJaq4A7j7w+5+HsEf3wl2rSD4T78WqHH34cC/ESRST40ysyGR5ycQtKozbQa+7u4jIz+D3f0RYBtwvJlZxnryybb8G+HPMWY2LGPeVgB3X+/uVxHshn4DeDyMP9
tlQTcDn86IeZC7PxdZJtflRHcStOomRKal4+jitRK/XN91dPp+YHDqiZlVEnRVpGwGLsrIh4HuHv2bRuXaNvLmhrtvd/cb3f04ghb99yz3CJl821VXn50wnsw4otvvEZ+jiSruZnaymX0wPDDyDkFr+1A4exiwB9hnZqcAn43hLb9qZv3N7G+Ai4GfZ1nmh8BnzGyaBYaY2d+HRfjPBLuwXzCzKjP7R+DsLt5zTGT5jxL0iy92980EXTp3mdnA8ODSbOA/AczsE2Y22t0PAbvCdR0i6Jc/RHBsIeX7wK1mNjV87Yjwvbrk7h0Eu+FfN7NhFhyI/WIqDulzb9L5b5vNOmBgmJdVBMdYBkTmf5/g7zkBwMxGm9llXazzsG2jq9wws4+a2bjw9c0EBTa1/WZ+jnzbVSEeAb4SfpZqYB4Jy9FEFXeChLyboIWwnaAQ3hrO+xfgamAvQWI81sv32k6QgG8APwM+4+5rMxdy93rgRuC74fIbCPqdcfeDwD+Gz98mOED6v7t43zqghuAzfh34iLu/Fc67iqBP9Q3gl8C/u/vT4byZwCoz2wcsAK509wPu3hKu59lw9/Ycd/8lQev+0bAL62XgooK/meBg6X5gI/An4GHggW68XuJzF0ER22Vm/5JtAXffDXwOuJ+gFb0fiI6eWQAsAn5rZnsJDkROy/Oe+baNfLnxfqAuzNFFwBx/99yJO4CfhJ/jY/m2qwJ9DagHVgArgf8KpyWGde6+lUJYcJbbf7r7uK6Wjfl9rwNuCLudRMpOqbYNOVzSWu4iIoKKu4hIIqlbRkQkgdRyFxFJIBV3EZEEKtnVzqqrq33ixImlentJuIaGhp3uPrrrJeOn3JZiKjS3S1bcJ06cSH19faneXhLOzLq6jEPRKLelmArNbXXLiIgkkIq7iEgCqbiLiCSQinu5W7EQvnUa3DEy+L1iYakjEomHcruojsh7Ax41ViyEJ78AbQeC57s3B88BTv9Y6eIS6S3ldtGp5V7Onpn/bvKntB0IposcyZTbRafiXs52b+nedJEjhXK76FTcy9mIHFdNzTVd5Eih3C46FfdyduE8DtC/06SOykFw4bwSBSQSkwvn0eKdc5sq5XacVNzLWMOID/Hlgzew5VA1h9zYcqiabw+6SQec5Ih3+dLjmNvWObefPfXfldsx0miZMjb38ZdYf+g8Fh1898ZL4zsG8aUSxiTSWw2NzSzfspvldM7timXw87OaOWvCqBJGlxxquZex199uOWzaG7sO0NDYXIJoROKx4Ol1WacfcrjzyVV9HE1yqbiXqYfrNtHWcfiNVDocrnugTgVejljHDOmfc97+gx19GEmyqbiXqdt/tTLnvL2tHTlbPyLl7onlb+Sc92rTPh6u29SH0SRXQcXdzGaa2StmtsHM5maZf4KZ/cHMXjSzFWY2K/5Qjx4Njc10HMo+z4Azxo1gzvTJfRpTUim3+9Y/P/pi3vmHHO5ZsraPokm2Lou7mVUC9wIXAVOAq8xsSsZiXwEWuvuZwJXA9+IO9Ggy9xcrcs677IzjeOLm83TQKQbK7b6Xr9UOUFVpVA8doG7HGBTScj8b2ODuG939IPAocFnGMg4MDx+PAPL/BSWnhsZm1u/Yd9h0C38/sfwNzvvG75X88VBu96F8rfZUfnd0OOt37MvbwJHCFFLcjwc2R55vCadF3QF8wsy2AIuBz8cS3VEo12iB0cPePQi1pfkAN/xkmQp878WW22b2KTOrN7P6pqamYsR6RLt78Zq8rXYH+lUYVf2CkrR994Gcy0ph4jqgehXwoLuPA2YBPzWzw9atDSC/hsZmVm3bk3Xejr0HGT9qEJUGAyoraG5p00HVvlFQbrv7fe5e6+61o0eX5NatZe0HSzceNq3SOj+ef9lpjBk2gEqDj0+b0IfRJVMhxX0rMD7yfFw4LWo2sBDA3f8MDASqM1ekDSC/O59clXX4I8CowVV89oKTeO/xIxgzfADDBlQy87SxfRxh4sSW25
JfZlafMX4kN/7NiQyuqmBwVSV3Xv5ennp5G5ubD9Dh8Fj9Zu2Z9lIhxX0ZUGNmk8ysP8FBpUUZy2wCLgQws1MJNgA1zbvLLOes5pY2nnp5G8u37GZz8wH2tnbw1Mvb+jC4RFJul8j+1nZWb9tDS9shJv/lUBbWb2bbrgPUjBnK4Kpgz1QnNPVOl8Xd3duBm4ElwBqCkQOrzGy+mV0aLvYl4EYzewl4BLjO3bM3QSWrhsZm9r/TRlWlYcDgqs5/msvPOI450ydTM3oIg6sqqRk9RMMhe0m53XfOGDei0/ONTfuYedpYzq+pBjOWb97F+qb9jB0xkMnHBsevX23ar9Z7LxR0bRl3X0xwMCk6bV7k8Wrg3HhDO7rM/cUK1jftTz8/GHbPDK6qYEBVJWdP+gsAdu4/SEtbBzv3HyxJnEmj3C6uuxev4Yf/byPHDO1PhQXj2CE40/rHz77G9t0HGDm4PzVjhjKkf9DV+ONnX6PSYG9rOwueXsdDs6eV9kMcoXSGahl4uG7TYcMfKyssbL0bzS1t3P7ESq574AWaW9roV2GdDqg2NDZzzY90SQIpPz9YupEOh6a9B9OFHYJjSNt3B92Lm5sPsH33AW6/ZCpPvbyN9Tv20eEwbECl9k57QcW9xBoam7ntl50vNVBVYbS2H2JAVSVfuXgKlRa0dPa2tjNsQCWTqod0Okv1zl+vZun6ndz569Wl+AgiOWXrvxo5uB+tbR0M6t8vXYD2tnZww0+WMfO0sQwbEHQovGfMMJ2s1wsq7iV255OrOm0A/SqM2edNYtTgKm6ZcQonHzuMwf0rgWC42LEjBrF+xz6GD6pKJ/7+1vZOv0XKwcN1m6jKUmF2tbTT0naIHXtbOX38SM4YP5JhAyrTgwYevP5szq+p5vaLM08Wlu5QcS+xtzL6ztsPOc+seZMX5/0dAFf84M/sbQ2ulNfhMKR/JefXVDNn+uR0d0zKkPCfgEipNTQ285UnVtJ26N2zTysj1abSoGbMUD5WO57hA/tx66wp6bw+a8IoHpo9Ta32XtLNOkpsx553Dpu2fU8rEFxAqT3SUTlqcBW3XzIVCK6JvedAG8u37OaMcSPSG4ZIOZj7+Eud+tgBOg4F49v3v9PG9j2t/NO5k3jq5W0sXb+TlVt3c/+171dBj5Fa7iVWWXH4n+DCU8cAcMuMU9KtngGVlk7+BU+vY+n6nWAW7L5eMlUtHSkr0RvNfPr8Ezlj3AjOGD+S2y+ews79B9nb2s49S9YyZ/rkdJfMdQ+8kHVQgAYM9Ixa7iXU0NjM8SMHdhoCCfDMmjdpaGxmYf3mdH/8wQ7nk/fXAc41fz0RIL0LK1JOojeaMWD1tj3cfsnUdK7eMuMU7lq8hurwph3vGT2U5Vt25xz6mG7MgIZFdoOKewkteHod65v2dxr/C8FB0wVPr2P55l2dlm9pC/ref1a3iZVfndGXoYoULPN67EvX72TPgTaeuPk8Ghqbeerlbbxn9BCWb9nNgqfXcfslU4OzUc3SXYsP123iniVruaJ2PHsOtHHG+JHqduwmFfcSebhuEy9u2sXgqgpa2oI7cwyuqmDyscM5Z9Ix/KyukXGjBgWtm/AMvpRjhw8oVdgiXQpa5qvZ19qR3vPcuf8g1/yoLn2cqGbMUEYNrmLmaWM5a8Ionrj5vE7ruGfJWppb2rj/T6/Rfsg5v6Zae6ndpD73EmhobGber15mb2s7lRUV6f7In95wDk/cdC6P1W9mb2sH+1vbeeLm87j94mAkwX/8w3s5v6aauz/yV1n7IdU3KeXg6mkncOYJozoN8X1738FOx4lwp7mljYXLDr+lXkNjM9VD+jNsQD9uOG+SBgv0kFruJXDnk6toP+RUGtw661SunnYCDY3NLHh6HXOmT+aWGadwz5K13DLjFID00DAINhyAa35Ul97dHT6oijnTJ6tvUsrGnOmTeXHTLvaG514cP3IgQ8KTk+ZMn5w+4e6t/Qc5c/5vuWXGKencTnVXnl
9TzdxZp5bmAySAWu59rKGxmVfDA6iD+1dy8rHDgHcPGs19/KV0YT/52GGdWuLRlvmc6ZPTF11aun5n+h+DWjlSDs6aMIoHrz87vVd690f+iuGDqli+eVfQzx7uje5qOUhzS1u6n76hsZk977Snz8B+uG4TZ87/rW6a3QMq7n1s7i9WsLe1nQoLTrlOXR8mVZi372lNJ3uq4KeWiT5PteZTG0n05A9A3TNScqm+9CduOpezJoxiytjh9Kswpowdns7VW2dNSZ+NDaQHEqTOwE71veum2d2n4t7HUrcPG9ivslMrO5XsH592Av0qjCtqx6cL/szTxnLNj+rSl0jNPDs1c4x75j8FkWJK5eLDdZvyNip+VtdI+yHnZ3WN6WlXTzuBF+f9XbpLJnPv85YZp3Qq/lI49bn3sVtnTUl3u6QSOmr1tj20H3JWb9vDh6Yey54DbXzt16tpaetgzzvtPHFTcPXZVJ87HN6/ntow1D0jfSHVmHju1bfSZ1RnO+Zz7IhB7N2xj2NHDMq5rujeJwTFP3VM6pof1encjm5Qce9jqWTNJVWQZ542lht+sozmlrZ3Z0buEZGvgGduICLFFBw8bWZva0fey/Te/eHT08eGukuDBbpP3TJlJlWYn3p5G80tbQwbUEnNmKHBqdvhdWWiy+VqxWhYpPSVV7bvZV94cbuRg/vnzMmeHBNK5XG0S1IKo5Z7H4sOecy3exltmedbLtf61NKRvnLPkrXpMe27Wg52meO5cjP6utRyqZOeMpeVrqm494Fo0s79xQrW79jHtt3vdNpNzdwICu1aybWhqN9d+sLDdZtoCceyD+hXwa2zphyWk5nFPltuNjQ2H9YNuXT9Ts4YP7LToAL1uRdOxb0PRJM9NVpma3NLp2TuaaskVxFXv7v0hbsWr6Y1vEjYtEnHcPW0E9LnbkRb4NFiny03Fzy9juaWNkYNruqUy6linm8AgWSn4t4HogX4le17uWfJWqqHDmD9jn2HJXN3qYhLKaVGwFRYMAgADs/JXC31XK35VMtce6K9Y+7Z7nJYfLW1tV5fX1+S9y4Hqave5RoSKb1jZg3uXluK9z6acjvanXJ+TXWXDY1UUU/1pRfymmyvP5q7ZwrNbY2WKbJco1ZSo2GeenlbUd9HpFhShfaWGad0GsmSLxczbzSTryWebT06Qa9w6pYpsmh/Y+riXnOmTy5oV7Wn75PtIJZI3HIdzM83UqvQUWC51qPumcKpuBdRQ2NzpxsNpJI1eiVHID0KoDfDFzOTPrUu3ZtSiiVXoY3rBLts69ExpsIVVNzNbCawAKgE7nf3u7Ms8zHgDsCBl9z96hjjPKJk61eMHjTa8057uogDnVr20d/dke0g1sqtu2luact66zJRXvdWrkIbVwFWIe+dLou7mVUC9wIfArYAy8xskbuvjixTA9wKnOvuzWY2plgBHwlSrebo3WaynaCRbchXXMl81oRR3H/t+3t8unfSKa/Lk7oT41NIy/1sYIO7bwQws0eBy4DVkWVuBO5192YAd98Rd6BHknQL/UAb68PLlU44ZnCnM+2iRbxYrRO1fPJSXpchnVkdn0JGyxwPbI483xJOi5oMTDazZ83s+XB39zBm9ikzqzez+qampp5FfARIX2v9kqmMGlwVnKhUwOiAfDQaJnax5TUcPbndld7mqW44E5+4hkL2A2qAC4CrgB+a2cjMhdz9Pnevdffa0aNHx/TW5SvVNXJ+TTW3Xzwl74W+uqIhYCVRUF7D0ZXbhQx17GmednVBPClcId0yW4HxkefjwmlRW4A6d28DXjOzdQQbxbJYojyCxdU1oiFgsVNe91ChQx2ltAppuS8Dasxskpn1B64EFmUs8wRB6wYzqybYnd0YY5xHhOgdaS6/91ku/+6fYutGUYsmdsrrHsrsOom25HWrx/LRZXF393bgZmAJsAZY6O6rzGy+mV0aLrYEeMvMVgN/AG5x97eKFXS5SrVo7lmyluWbd7F8y+5edaOon7
14lNfdk62Apxoa2bpi1I1YegWNc3f3xcDijGnzIo8d+GL4c9SaedpYVm7dzRW143n+tbfBvVe7pxo5UFzK68J1tytG3TOlpzNUY5S6XszqbXvS9zrtDW0gUi66e9ZpoceaNK69eFTcYxR3MdY4dSkXxcpF7Z0Wj64KGaNCD3qqL12OdHHlsMa1F4+KewnoYJMc6eLK4d6OAlNDKTcV9xJQa0WONJlFtJAc7ovCq4ZSbupzLwH1pcuRppD7oHb1mmLQoIPcVNxFpEs9KaJ9UXjVUMpNxV1EutSTIqrCW1rqcxcRSSAVdxFJFI2gCai4i0iiaARNQH3uIpIoGkETUHEXkUTRgdyAumXKgPoIRSRuKu5lQH2EIhI3dcuUAfURikjcVNzLgPoIRSRu6pYREUkgFXcRkQRScRcRSSAVdxGRBFJxFxFJIBV3EZEEUnEXEUmggoq7mc00s1fMbIOZzc2z3IfNzM2sNr4QRYpHuS1J1WVxN7NK4F7gImAKcJWZTcmy3DBgDlAXd5AixaDcliQrpOV+NrDB3Te6+0HgUeCyLMvdCXwDeCfG+ESKSbktiVVIcT8e2Bx5viWclmZm7wPGu/v/iTE2kWJTbkti9fqAqplVAN8EvlTAsp8ys3ozq29qaurtW4sUlXJbjmSFFPetwPjI83HhtJRhwGnAH83sdeAcYFG2A0/ufp+717p77ejRo3setUg8lNuSWIUU92VAjZlNMrP+wJXAotRMd9/t7tXuPtHdJwLPA5e6e31RIhaJj3JbEqvL4u7u7cDNwBJgDbDQ3VeZ2Xwzu7TYAYoUi3Jbkqyg67m7+2Jgcca0eTmWvaD3YYn0DeW2JJXOUBURSSAVdxGRBFJxFxFJIBV3EZEEUnEXEUkgFXcRkQRScRcRSSAVdxGRBFJxFxFJIBV3EZEEUnEXEUkgFXcRkQRScRcRSSAVdxGRBFJxFxFJIBV3EZEEUnEXEUkgFXcRkQRScRcRSSAVdxGRBFJxFxFJIBV3EZEEUnEXEUkgFXcRkQRScRcRSaCCiruZzTSzV8xsg5nNzTL/i2a22sxWmNkzZjYh/lBF4qW8liTrsribWSVwL3ARMAW4ysymZCz2IlDr7qcDjwP/I+5AReKkvJakK6Tlfjawwd03uvtB4FHgsugC7v4Hd28Jnz4PjIs3TJHYKa8l0Qop7scDmyPPt4TTcpkN/CbbDDP7lJnVm1l9U1NT4VGKxC+2vAbltpSfWA+omtkngFrgnmzz3f0+d69199rRo0fH+dYiRdNVXoNyW8pPvwKW2QqMjzwfF07rxMymA7cBH3D31njCEyka5bUkWiEt92VAjZlNMrP+wJXAougCZnYm8APgUnffEX+YIrFTXkuidVnc3b0duBlYAqwBFrr7KjObb2aXhovdAwwFfm5my81sUY7ViZQF5bUkXSHdMrj7YmBxxrR5kcfTY45LpOiU15JkOkNVRCSBVNxFRBJIxV1EJIFU3EVEEkjFXUQkgVTcRUQSSMVdRCSBVNxFRBJIxV1EJIFU3EVEEkjFXUQkgVTcRUQSSMVdRCSBVNxFRBJIxV1EJIFU3EVEEkjFXUQkgVTcRUQSSMVdRCSBVNxFRBJIxV1EJIFU3EVEEkjFXUQkgVTcRUQSSMVdRCSBCiruZjbTzF4xsw1mNjfL/AFm9lg4v87MJvY4ohUL4VunwR0jg98rFvZ4VZJQMeaIclvKRsz50WVxN7NK4F7gImAKcJWZTclYbDbQ7O4nAd8CvtGjaFYshCe/ALs3Ax78fvIL2gjkXTHmiHJbykYR8qOQlvvZwAZ33+juB4FHgcsylrkM+En4+HHgQjOzbkfzzHxoO9B5WtuBYLoIxJ0jym0pD0XIj0KK+/HA5sjzLeG0rMu4ezuwG/iLzBWZ2afMrN7M6puamg5/p91bskeQa7ocfeLNEeW2lIci5EefHlB19/vcvdbda0ePHn34AiPGZX9hruly9CnTHFFuS68UIT8KKe5bgfGR5+PCaV
mXMbN+wAjgrW5Hc+E8qBrUeVrVoGC6CMSdI8ptKQ9FyI9CivsyoMbMJplZf+BKYFHGMouAa8PHHwF+7+7e7WhO/xhc8h0YMR6w4Pcl3wmmi0DcOaLclvJQhPywQvLUzGYB3wYqgQfc/etmNh+od/dFZjYQ+ClwJvA2cKW7b8y3ztraWq+vr+9x4CL5mFmDu9cWsJxyW44oheZ2v0JW5u6LgcUZ0+ZFHr8DfLS7QYqUmnJbkkpnqIqIJJCKu4hIAqm4i4gkUEEHVIvyxmZNQGOeRaqBnX0UTj6Ko7NyiQPyxzLB3bMMOC8+5Xa3KY7OuoqjoNwuWXHvipnVF3JEWHEcnXFAecXSHeUSt+JIdhzqlhERSSAVdxGRBCrn4n5fqQMIKY7OyiUOKK9YuqNc4lYcnSUqjrLtcxcRkZ4r55a7iIj0UJ8X997c1szMbg2nv2JmM4ocxxfNbLWZrTCzZ8xsQmReh5ktD38yLzRVjFiuM7OmyHveEJl3rZmtD3+uzXxtzHF8KxLDOjPbFZkX23diZg+Y2Q4zeznHfDOz74RxrjCz90XmxfZ9dJdyu9txKK87z483r929z34ILs70KnAi0B94CZiSsczngO+Hj68EHgsfTwmXHwBMCtdTWcQ4/hYYHD7+bCqO8Pm+Pv5OrgO+m+W1xwAbw9+jwsejihVHxvKfJ7jQVjG+k/OB9wEv55g/C/gNYMA5QF3c34dyW3l9pOd1X7fce3Nbs8uAR9291d1fAzaE6ytKHO7+B3dvCZ8+T3Ct72Io5DvJZQbwO3d/292bgd8BM/sojquAR3r4Xnm5+1KCKzDmchnwkAeeB0aa2Vji/T66S7ndzTjyUF7HkNd9Xdx7c1uzQl4bZxxRswn+o6YMtOCWas+b2eU9jKG7sXw43FV73MxSN5goyXcS7sZPAn4fmRznd9KVXLHG+X3EFVPWZY6C3FZed1+seV3QJX+PZmb2CaAW+EBk8gR332pmJwK/N7OV7v5qEcN4EnjE3VvN7NMErb8PFvH9unIl8Li7d0Sm9fV3Ir1UBrmtvC6ivm659+a2ZoW8Ns44MLPpwG3Ape7empru7lvD3xuBPxLcyKGnuozF3d+KvP/9wFnd+RxxxRFxJRm7rjF/J2VP2V0AAAFGSURBVF3JFWuc30dcMWVd5ijIbeV198Wb13EdLCjwgEI/goMBk3j34MbUjGVuovNBp4Xh46l0Pui0kZ4fdCokjjMJDsTUZEwfBQwIH1cD68lzgCamWMZGHv8D8Ly/e6DltTCmUeHjY4oVR7jcKcDrhOdIFOM7CdczkdwHnv6ezgeeXoj7+1Bu9+7vqLwufV4XPemzfIBZwLowuW4Lp80naEEADAR+TnBQ6QXgxMhrbwtf9wpwUZHjeBp4E1ge/iwKp/83YGWYJCuB2X3wndwFrArf8w/AKZHXXh9+VxuAfypmHOHzO4C7M14X63dC0HraBrQR9C/OBj4DfCacb8C9YZwrgdpifB/K7V7/HZXXJcxrnaEqIpJAOkNVRCSBVNxFRBJIxV1EJIFU3EVEEkjFXUQkgVTcRUQSSMVdRCSBVNxFRBLo/wPW8jM4mLRUVwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x216 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"with T.no_grad():\n",
" for _ in range(10):\n",
"\n",
" i = np.random.randint(bs)\n",
" doc_ = docs[i:i+1].repeat(bs,1)\n",
"\n",
" pt.figure(figsize=(6,3))\n",
" pt.subplot(121)\n",
" pt.title('sampled posterior')\n",
"\n",
" theta, logZ, logPF, logPB = sample_forward(doc_)\n",
" logR = log_reward(theta,doc_,topic_word)\n",
"\n",
" pt.scatter((theta[:,0]+theta[:,1]/2).cpu(), (theta[:,1]*np.sqrt(3)/2).cpu(),s=2)\n",
" pt.scatter([0,0.5,1],[0,np.sqrt(3)/2,0])\n",
"\n",
" pt.subplot(122)\n",
" pt.title('true posterior')\n",
"\n",
" # importance sampling\n",
" rands = T.distributions.Dirichlet(T.full((n_topics,), 1.).to(device)).sample((bs*100,))\n",
" logR = log_reward(rands,doc_.repeat(100,1),topic_word)\n",
" rands = rands[T.distributions.Categorical(logits=logR).sample((bs,))]\n",
" pt.scatter((rands[:,0]+rands[:,1]/2).cpu(), (rands[:,1]*np.sqrt(3)/2).cpu(),s=2)\n",
"\n",
" pt.scatter([0,0.5,1],[0,np.sqrt(3)/2,0])\n",
" pt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1080x216 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"61 43 33 29 96 40 2 6 99 77\n",
"1\n",
"33 2 29 6 99 96 43 47 88 77\n",
"2\n",
"1 58 57 65 5 77 86 50 32 52\n",
"\n",
"0\n",
"2 33 43 6 47 99 29 61 96 88\n",
"1\n",
"65 5 50 52 19 43 57 70 61 66\n",
"2\n",
"1 58 77 86 57 32 56 35 20 78\n",
"\n"
]
}
],
"source": [
"with T.no_grad():\n",
" pt.figure(figsize=(15,3), facecolor='white')\n",
"\n",
" pt.subplot(211)\n",
" pt.title('Inferred with GFN')\n",
" pt.imshow(topic_word.log_softmax(1).cpu().numpy(),vmax=-1,vmin=-5,cmap='Blues')\n",
" pt.xticks([]);pt.yticks([])\n",
" pt.xlabel('vocab');pt.ylabel('topic')\n",
"\n",
" pt.subplot(212)\n",
" pt.title('Ground truth')\n",
" pt.imshow(gt_topic_word.log_softmax(1).cpu().numpy(),vmax=-1,vmin=-5,cmap='Blues')\n",
" pt.xticks([]);pt.yticks([])\n",
" pt.xlabel('vocab');pt.ylabel('topic')\n",
" pt.show()\n",
" \n",
" for mat in [ topic_word.softmax(1), gt_topic_word.softmax(1) ]:\n",
" for i in range(n_topics):\n",
" print(i)\n",
" print(*mat[i].topk(10).indices.cpu().numpy())\n",
"# for v,n in zip(*mat[i].topk(10)):\n",
"# print(n.item(),v.item())\n",
"\n",
" print()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment