Variational Recurrent Network tiny Shakespeare
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Variational Recurrent Network (VRNN)\n",
"\n",
"Implementation based on Chung's *A Recurrent Latent Variable Model for Sequential Data* [arXiv:1506.02216v6].\n",
"\n",
"### 1. Network design\n",
"\n",
"There are three types of layers: input (x), hidden(h) and latent(z). We can compare VRNN sided by side with RNN to see how it works in generation phase.\n",
"\n",
"- RNN: $h_o + x_o -> h_1 + x_1 -> h_2 + x_2 -> ...$\n",
"- VRNN: with $ h_o \\left\\{\n",
"\\begin{array}{ll}\n",
" h_o -> z_1 \\\\\n",
" z_1 + h_o -> x_1\\\\\n",
" z_1 + x_1 + h_o -> h_1 \\\\\n",
"\\end{array} \n",
"\\right .$ \n",
"with $ h_1 \\left\\{\n",
"\\begin{array}{ll}\n",
" h_1 -> z_2 \\\\\n",
" z_2 + h_1 -> x_2\\\\\n",
" z_2 + x_2 + h_1 -> h_2 \\\\\n",
"\\end{array} \n",
"\\right .$\n",
"\n",
"It is clearer to see how it works in the code blocks below. This loop is used to generate new text when the network is properly trained. x is wanted output, h is deterministic hidden state, and z is latent state (stochastic hidden state). Both h and z are changing with repect to time.\n",
"\n",
"### 2. Training\n",
"\n",
"The VRNN above contains three components, a latent layer genreator $h_o -> z_1$, a decoder net to get $x_1$, and a recurrent net to get $h_1$ for the next cycle.\n",
"\n",
"The training objective is to make sure $x_0$ is realistic. To do that, an encoder layer is added to transform $x_1 + h_0 -> z_1$. Then the decoder should transform $z_1 + h_o -> x_1$ correctly. This implies a cross-entropy loss in the \"tiny shakespear\" or MSE in image reconstruction.\n",
"\n",
"Another loose end is $h_o -> z_1$. Statistically, $x_1 + h_0 -> z_1$ should be the same as $h_o -> z_1$, if $x_1$ is sampled randomly. This constraint is formularize as a KL divergence between the two.\n",
"\n",
">#### KL Divergence of Multivariate Normal Distribution\n",
">![](https://wikimedia.org/api/rest_v1/media/math/render/svg/8dad333d8c5fc46358036ced5ab8e5d22bae708c)\n",
"\n",
"Now putting everything together for one training cycle.\n",
"\n",
"$\\left\\{\n",
"\\begin{array}{ll}\n",
" h_o -> z_{1,prior} \\\\\n",
" x_1 + h_o -> z_{1,infer}\\\\\n",
" z_1 <- sampling N(z_{1,infer})\\\\\n",
" z_1 + h_o -> x_{1,reconstruct}\\\\\n",
" z_1 + x_1 + h_o -> h_1 \\\\\n",
"\\end{array} \n",
"\\right . $\n",
"=>\n",
"$\n",
"\\left\\{\n",
"\\begin{array}{ll}\n",
" loss\\_latent = DL(z_{1,infer} | z_{1,prior}) \\\\\n",
" loss\\_reconstruct = x_1 - x_{1,reconstruct} \\\\\n",
"\\end{array} \n",
"\\right .\n",
"$\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"\"\\x02$;P\\x17u\\\\F<{e\\x0f\\x03G\\x1a`\\x16Be~-D\\rV\\x121\\x00\\x10\\x1a5\\x1c!\\x10\\x0f'kmv!p`n=.\\x0e_?\\x01C\\x08r\\x0cM3E]d\\x05\\x1aD*qt\\x08\\x13?xJ7\\x1e\\x0bN\\x121\\x01&F& #CJ\\x08/GOq\\x03\\x1bVQy+~\\x128O9vf\""
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"from torch import nn, optim\n",
"from torch.autograd import Variable\n",
"\n",
"class VRNNCell(nn.Module):\n",
" def __init__(self):\n",
" super(VRNNCell,self).__init__()\n",
" self.phi_x = nn.Sequential(nn.Embedding(128,64), nn.Linear(64,64), nn.ELU())\n",
" self.encoder = nn.Linear(128,64*2) # output hyperparameters\n",
" self.phi_z = nn.Sequential(nn.Linear(64,64), nn.ELU())\n",
" self.decoder = nn.Linear(128,128) # logits\n",
" self.prior = nn.Linear(64,64*2) # output hyperparameters\n",
" self.rnn = nn.GRUCell(128,64)\n",
" def forward(self, x, hidden):\n",
" x = self.phi_x(x)\n",
" # 1. h => z\n",
" z_prior = self.prior(hidden)\n",
" # 2. x + h => z\n",
" z_infer = self.encoder(torch.cat([x,hidden], dim=1))\n",
" # sampling\n",
" z = Variable(torch.randn(x.size(0),64))*z_infer[:,64:].exp()+z_infer[:,:64]\n",
" z = self.phi_z(z)\n",
" # 3. h + z => x\n",
" x_out = self.decoder(torch.cat([hidden, z], dim=1))\n",
" # 4. x + z => h\n",
" hidden_next = self.rnn(torch.cat([x,z], dim=1),hidden)\n",
" return x_out, hidden_next, z_prior, z_infer\n",
" def calculate_loss(self, x, hidden):\n",
" x_out, hidden_next, z_prior, z_infer = self.forward(x, hidden)\n",
" # 1. logistic regression loss\n",
" loss1 = nn.functional.cross_entropy(x_out, x) \n",
" # 2. KL Divergence between Multivariate Gaussian\n",
" mu_infer, log_sigma_infer = z_infer[:,:64], z_infer[:,64:]\n",
" mu_prior, log_sigma_prior = z_prior[:,:64], z_prior[:,64:]\n",
" loss2 = (2*(log_sigma_infer-log_sigma_prior)).exp() \\\n",
" + ((mu_infer-mu_prior)/log_sigma_prior.exp())**2 \\\n",
" - 2*(log_sigma_infer-log_sigma_prior) - 1\n",
" loss2 = 0.5*loss2.sum(dim=1).mean()\n",
" return loss1, loss2, hidden_next\n",
" def generate(self, hidden=None, temperature=None):\n",
" if hidden is None:\n",
" hidden=Variable(torch.zeros(1,64))\n",
" if temperature is None:\n",
" temperature = 0.8\n",
" # 1. h => z\n",
" z_prior = self.prior(hidden)\n",
" # sampling\n",
" z = Variable(torch.randn(z_prior.size(0),64))*z_prior[:,64:].exp()+z_prior[:,:64]\n",
" z = self.phi_z(z)\n",
" # 2. h + z => x\n",
" x_out = self.decoder(torch.cat([hidden, z], dim=1))\n",
" # sampling\n",
" x_sample = x = x_out.div(temperature).exp().multinomial(1).squeeze()\n",
" x = self.phi_x(x)\n",
" # 3. x + z => h\n",
" hidden_next = self.rnn(torch.cat([x,z], dim=1),hidden)\n",
" return x_sample, hidden_next\n",
" def generate_text(self, hidden=None,temperature=None, n=100):\n",
" res = []\n",
" hidden = None\n",
" for _ in range(n):\n",
" x_sample, hidden = self.generate(hidden,temperature)\n",
" res.append(chr(x_sample.data[0]))\n",
" return \"\".join(res)\n",
" \n",
"# Test\n",
"net = VRNNCell()\n",
"x = Variable(torch.LongTensor([12,13,14]))\n",
"hidden = Variable(torch.rand(3,64))\n",
"output, hidden_next, z_infer, z_prior = net(x, hidden)\n",
"loss1, loss2, _ = net.calculate_loss(x, hidden)\n",
"loss1, loss2\n",
"hidden = Variable(torch.zeros(1,64))\n",
"net.generate_text()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download tiny shakspear text"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-----SAMPLE----\n",
"\n",
"First Citizen:\n",
"Before we proceed any further, hear me speak.\n",
"\n",
"All:\n",
"Speak, speak.\n",
"\n",
"First Citizen:\n",
"You\n"
]
}
],
"source": [
"from six.moves.urllib import request\n",
"url = \"https://raw.githubusercontent.com/jcjohnson/torch-rnn/master/data/tiny-shakespeare.txt\"\n",
"text = request.urlopen(url).read().decode()\n",
"\n",
"print('-----SAMPLE----\\n')\n",
"print(text[:100])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### A convinient function to sample text"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def batch_generator(seq_size=300, batch_size=64):\n",
" cap = len(text) - seq_size*batch_size\n",
" while True:\n",
" idx = np.random.randint(0, cap, batch_size)\n",
" res = []\n",
" for _ in range(seq_size):\n",
" batch = torch.LongTensor([ord(text[i]) for i in idx])\n",
" res.append(batch)\n",
" idx += 1\n",
" yield res\n",
"\n",
"g = batch_generator()\n",
"batch = next(g)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Training"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
">> epoch 0, loss 3769.2000, decoder loss 1464.9570, latent loss 2304.2424\n",
"]]\u0011ZA}~.m\\#\u000b",
"\u00061Z5\u0011r>UK!{nxm\t/\u0000d9TB/A \u0018\u0011wst\u0007Z\"y;FB9Ev\u000b",
"h;\u0011\n",
"\n",
">> epoch 100, loss 1007.5905, decoder loss 984.4189, latent loss 23.1717\n",
" r\f",
"t cos\n",
" u Btutrs,lh treotiu ri rihrsrathtgc apr kr heoeeewoset nal'e niai uochatyoe dec te.es\n",
"\n",
">> epoch 200, loss 930.7681, decoder loss 911.5832, latent loss 19.1846\n",
"scd.ORSo plm py ld e moencu arsh iiae eyuio li nntnir twr,tt h le cewee un nmo lhcri r we utocee e\n",
"\n",
">> epoch 300, loss 795.6132, decoder loss 778.8964, latent loss 16.7170\n",
"zUxitrel so bo sat om hho he tos otans thtos\n",
"Oere, I\n",
"AI unnt fey.C Yon arl nilil mithe,\n",
"\n",
"BNL:\n",
"TYor, \n",
"\n",
">> epoch 400, loss 739.1819, decoder loss 724.5921, latent loss 14.5900\n",
":\n",
"\n",
"ASESSE:\n",
"Tive theaveress aw thal he care hees the anl ouues spid at erucerte ho ere es the wat sde\n",
"\n",
">> epoch 500, loss 705.3881, decoder loss 693.6972, latent loss 11.6912\n",
"pzit wivy thes ciun keit usmend sere tout, the farod, ser theure hece dicle om it ay, pamer sald dy \n",
"\n",
">> epoch 600, loss 674.5148, decoder loss 663.9881, latent loss 10.5273\n",
"}INNY KPARASNET:\n",
"Yes ar,\n",
"wrele enge at thout, hor omos siw me leds.\n",
"\n",
"IOG ISS:\n",
"Thalet wes?\n",
"\n",
"RELa:\n",
"And\n",
"\n",
">> epoch 700, loss 657.9904, decoder loss 648.1220, latent loss 9.8681\n",
"unss prvish cay. Gow ches peant shy ander in.\n",
"Aift a vuture I the mand hone thasp will se, gand for'\n",
"\n",
">> epoch 800, loss 648.0872, decoder loss 639.5717, latent loss 8.5153\n",
"anwat hom the prawy I tha, you proll wime unges and sen be and gout and thar the sith:\n",
"And huck'd or\n",
"\n",
">> epoch 900, loss 636.0385, decoder loss 629.0430, latent loss 6.9951\n",
"UFY LESMUSY:\n",
"What your I his are,\n",
"On borlatee in whave ir hin cand a you and and, ereture all by me \n",
"\n",
">> epoch 1000, loss 633.3281, decoder loss 626.9877, latent loss 6.3404\n",
"\f",
"T:\n",
"Dho fill ow tath wim it dever not hat nom to I that hatith eardant then I tinipt, darro!\n",
"\n",
"Cir ce\n",
"\n",
">> epoch 1100, loss 612.0241, decoder loss 606.6413, latent loss 5.3829\n",
"\u0015s, in to fardee\n",
"Any breans lestie, haw, your doth to'ln or youse cowiss in\n",
"An' for the counds uthin\n",
"\n",
">> epoch 1200, loss 605.0844, decoder loss 599.7728, latent loss 5.3119\n",
"[;?\n",
"\n",
"COLITRI:\n",
"But frountous of you the the yet, misk's: shout as cood that cany prood\n",
"The must updin\n",
"\n",
">> epoch 1300, loss 603.7875, decoder loss 599.5039, latent loss 4.2838\n",
"\u0017L'ly crowntre af at to of me the hers and the be love\n",
"A oun lite, but ther he ing, the sintel,\n",
"And \n",
"\n",
">> epoch 1400, loss 585.8950, decoder loss 582.1081, latent loss 3.7871\n",
"S\u001d",
"ol fauries to princoundst, with blad,\n",
"I't thrnowt be to fey thear accure pracumed from thou they s\n",
"\n",
">> epoch 1500, loss 588.9115, decoder loss 585.4120, latent loss 3.4996\n",
"rtsllef thele for a det and this, there heep but not ware, prorbattres?\n",
"\n",
"YENGGUR:\n",
"Of! ghom will ceak\n",
"\n",
">> epoch 1600, loss 580.6710, decoder loss 577.3580, latent loss 3.3128\n",
"UTZ:\n",
"\n",
"MARKEO:\n",
"By gake, that not and a have shalting,\n",
"For love suld, Tore that winhor she come bet be\n",
"\n",
">> epoch 1700, loss 581.2310, decoder loss 577.8647, latent loss 3.3665\n",
"2or If more and the shall this thou here,\n",
"Then she weel, us you the own sit swertelt of the amerers,\n",
"\n",
">> epoch 1800, loss 582.2238, decoder loss 576.5284, latent loss 5.6954\n",
" XMIVV:\n",
"Not, sear so preant on of stithis a veak: be my, of his pull our the deppent,\n",
"Tene to the de\n",
"\n",
">> epoch 1900, loss 570.4400, decoder loss 567.7066, latent loss 2.7332\n",
"Y VINIO:\n",
"Kance thight a bave to the enther a manty comismor you stall to 'tat dausher but his of nea\n",
"\n"
]
}
],
"source": [
"net = VRNNCell()\n",
"max_epoch = 2000\n",
"optimizer = optim.Adam(net.parameters(), lr=0.001)\n",
"g = batch_generator()\n",
"\n",
"hidden = Variable(torch.zeros(64,64)) #batch_size x hidden_size\n",
"for epoch in range(max_epoch):\n",
" batch = next(g)\n",
" loss_seq = 0\n",
" loss1_seq, loss2_seq = 0, 0\n",
" optimizer.zero_grad()\n",
" for x in batch:\n",
" loss1, loss2, hidden = net.calculate_loss(Variable(x),hidden)\n",
" loss1_seq += loss1.data[0]\n",
" loss2_seq += loss2.data[0]\n",
" loss_seq = loss_seq + loss1+loss2\n",
" loss_seq.backward()\n",
" optimizer.step()\n",
" hidden.detach_()\n",
" if epoch%100==0:\n",
" print('>> epoch {}, loss {:12.4f}, decoder loss {:12.4f}, latent loss {:12.4f}'.format(epoch, loss_seq.data[0], loss1_seq, loss2_seq))\n",
" print(net.generate_text())\n",
" print()\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"`omak, wha lating\n",
"To thing matheds now:\n",
"Your, fich's mad pother you with thouss the deedh! goust I, hest, seably the were thee co, preatt goor his mat start pean the poose not 'ere, as and for that I great a cring wer.\n",
"\n",
"KINO KINGBRAV:\n",
"Bese retuble not whirs,\n",
"With my heake! who at his yeoth.\n",
"\n",
"Sist starl'd sullancen'd and bece breour there things.\n",
"Sconte to ctret.\n",
"\n",
"PRINGER:\n",
"OL RUMERTE RIRI IP LARIENIZ:\n",
"Beiolt, you to Mripching a will inting,\n",
"And the me thou read onaidion\n",
"And king a's for old somee thee for speak eim'p calf\n",
"The live eavert stish\n",
"Tis conhal of my wairggred most swexferous frome.\n",
"\n",
"VINGER:\n",
"Not you lay my disge,\n",
"We not: the rueselly with it hightens my, will an my foochorr me\n",
"but hash proied our nir is how, woul malay with lethantolt and is inge:\n",
"Had thy monk-tich hap,\n",
"Thimbrisuegetreve, like tous accounce; the were on and trust thoy if peeccon.\n",
"\n",
"COMEON:\n",
"Yet a peave. Preathed that in soned; what shave nongle.\n",
"\n",
"RICHENRIUS:\n",
"Forther,\n",
"And that the be thy chill with wogen thighter\n"
]
}
],
"source": [
"sample = net.generate_text(n=1000, temperature=1)\n",
"print(sample)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Comments\n",
"\n",
"- Denifinitely train longer to get better results. \n",
"- Keep in mind the rnn kernel only has 1 layer, with 64 neurons.\n",
"- Seems no need to tune temperature here. temperature = 0.8 generates a lot of obscure spelling. temperature = 1 works fine."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda env:tensorflow]",
"language": "python",
"name": "conda-env-tensorflow-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}