Skip to content

Instantly share code, notes, and snippets.

@mrbkdad
Last active January 19, 2020 16:50
Show Gist options
  • Save mrbkdad/f40e5558b5ae23bba0fff38a508f2648 to your computer and use it in GitHub Desktop.
Save mrbkdad/f40e5558b5ae23bba0fff38a508f2648 to your computer and use it in GitHub Desktop.
RNN sample with LSTM blocks for long character sequences
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f7b3cd90be8>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.autograd import Variable\n",
"torch.manual_seed(777)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"## sentence to RNN dataset\n",
"## sentence : target dataset\n",
"## seq_len : sequence number of RNN cell\n",
"def sent2rnntensor(sentence,seq_len):\n",
" lookup = {w:i for i,w in enumerate(list(set(sentence)))}\n",
" X = []\n",
" Y = []\n",
" for i in range(0, len(sentence) - seq_len):\n",
" x_str = sentence[i:i+seq_len]\n",
" y_str = sentence[i+1:i+seq_len+1]\n",
" #print(i, x_str, '->', y_str)\n",
" x = [lookup[c] for c in x_str]\n",
" y = [lookup[c] for c in y_str]\n",
" X.append(x)\n",
" Y.append(y)\n",
" return X,Y,lookup\n",
"\n",
"## Tensor Dataset to onehot encoding Dataset\n",
"def onehot2d(dataset):\n",
" idx = dataset.long()\n",
" onehot_len = torch.max(idx)+1\n",
" onehot_data = torch.zeros(idx.numel(),onehot_len)\n",
" onehot_data.scatter_(1,idx.view(-1,1),1)\n",
" return onehot_data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The core of extensible programming is defining functions.Python allows mandatory and optional arguments, keyword arguments,and even arbitrary argument lists. More about defining functions in Python 3.'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sentence = (\"The core of extensible programming is defining functions.\"\n",
" \"Python allows mandatory and optional arguments, keyword arguments,\"\n",
" \"and even arbitrary argument lists. More about defining functions in Python 3.\")\n",
"sentence"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"## hyper parameters and transform sentence to dataset\n",
"learning_rate = 0.1\n",
"num_epoch = 500\n",
"sequence_length = 10 ## an arbitrary number\n",
"\n",
"dataX,dataY,char2idx=sent2rnntensor(sentence,sequence_length)\n",
"idx2char = {i:c for c,i in char2idx.items()} \n",
"\n",
"input_size = len(char2idx) ## one hot encoding vector size\n",
"hidden_size = len(char2idx) ## RNN output size\n",
"num_classes = len(char2idx) ## output size (RNN, softmax etc)\n",
"num_layers = 2 ## RNN layers"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"batch_size = len(dataX)\n",
"x_data = torch.Tensor(dataX)\n",
"y_data = torch.LongTensor(dataY)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"x_one_hot = onehot2d(x_data)\n",
"input_data = Variable(x_one_hot.view(x_data.size()[0],\n",
" x_data.size()[1],num_classes))\n",
"labels = Variable(y_data)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"## Define LSTM class\n",
"class LSTM(nn.Module):\n",
" def __init__(self,classes,input_size,hidden_size,num_layers):\n",
" super(LSTM,self).__init__()\n",
" self.classes = classes\n",
" self.input_size = input_size\n",
" self.hidden_size = hidden_size\n",
" self.num_layers = num_layers\n",
" ## LSTM\n",
" self.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,\n",
" num_layers=num_layers,batch_first=True)\n",
" ## FC\n",
" self.fc = nn.Linear(hidden_size,classes)\n",
" def forward(self,x):\n",
" ## input : batch X sequence X input_size(if batch_first is True)\n",
" ## Initialize hidden and cell states\n",
" ## : num_layers X batch_size X hidden_size\n",
" h_0 = Variable(torch.zeros(self.num_layers,\n",
" x.size(0), self.hidden_size))\n",
" c_0 = Variable(torch.zeros(self.num_layers,\n",
" x.size(0), self.hidden_size))\n",
" out,_ = self.lstm(x,(h_0,c_0))\n",
" '''\n",
" Contiguous means that your tensor occupies a single\n",
" unbroken block of memory (no holes). view can only be\n",
" used with contiguous tensors, so if you need to use it\n",
" here, just call .contiguous()\n",
" '''\n",
" out = out.contiguous().view(-1, self.hidden_size)\n",
" out = self.fc(out)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'torch.autograd.variable.Variable'> <class 'torch.autograd.variable.Variable'>\n",
"torch.Size([3, 10, 30]) torch.Size([3, 10, 30])\n"
]
},
{
"data": {
"text/plain": [
"'ssssssssss'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,\n",
" num_layers=num_layers,batch_first=True)\n",
"h_0 = Variable(torch.zeros(num_layers,3, hidden_size))\n",
"c_0 = Variable(torch.zeros(num_layers,3, hidden_size))\n",
"out,_=lstm(input_data[:3],(h_0,c_0))\n",
"print(type(out),type(out.contiguous()))\n",
"print(out.size(),input_data[:3].size())\n",
"out = out.contiguous().view(-1,hidden_size)\n",
"fc = nn.Linear(hidden_size,num_classes)\n",
"out = fc(out)\n",
"_,idx=out.max(1)\n",
"idx = idx.data.numpy()\n",
"idx = idx.reshape(-1,sequence_length)\n",
"''.join([idx2char[i] for i in idx[-1]])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"## Model\n",
"lstm = LSTM(num_classes,input_size,hidden_size,num_layers)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"## loss function, optimizer\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(lstm.parameters(),lr=learning_rate)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"start >>>\n",
"epoch: 1, loss: 3.403\n",
"predicted sent: yyyyyyyyyy\n",
"epoch: 2, loss: 3.271\n",
"predicted sent: sioooooooo\n",
"epoch: 3, loss: 3.355\n",
"predicted sent: rlmrrrrrrr\n",
"epoch: 4, loss: 3.163\n",
"predicted sent: rrrrrrrrrr\n",
"epoch: 5, loss: 3.197\n",
"predicted sent: \n",
"epoch: 6, loss: 3.072\n",
"predicted sent: \n",
"epoch: 7, loss: 3.007\n",
"predicted sent: \n",
"epoch: 8, loss: 3.010\n",
"predicted sent: nnnnnnnnnn\n",
"epoch: 9, loss: 2.994\n",
"predicted sent: nnnnnnnnnn\n",
"epoch: 10, loss: 2.960\n",
"predicted sent: nnnnnnnnnn\n",
"epoch: 11, loss: 2.929\n",
"predicted sent: nnn nn \n",
"epoch: 12, loss: 2.896\n",
"predicted sent: n o \n",
"epoch: 13, loss: 2.845\n",
"predicted sent: o o o a \n",
"epoch: 14, loss: 2.787\n",
"predicted sent: o ooo a \n",
"epoch: 15, loss: 2.728\n",
"predicted sent: o ooo o \n",
"epoch: 16, loss: 2.661\n",
"predicted sent: o ooo aae\n",
"epoch: 17, loss: 2.594\n",
"predicted sent: o eooo aaa\n",
"epoch: 18, loss: 2.526\n",
"predicted sent: o eoao aae\n",
"epoch: 19, loss: 2.444\n",
"predicted sent: orenno aaa\n",
"epoch: 20, loss: 2.351\n",
"predicted sent: orenno aaa\n",
"epoch: 21, loss: 2.258\n",
"predicted sent: orennonara\n",
"epoch: 22, loss: 2.174\n",
"predicted sent: orennonara\n",
"epoch: 23, loss: 2.067\n",
"predicted sent: srmnnonara\n",
"epoch: 24, loss: 1.995\n",
"predicted sent: srmnnonaaa\n",
"epoch: 25, loss: 1.903\n",
"predicted sent: srmntonara\n",
"epoch: 26, loss: 1.800\n",
"predicted sent: srmntonaaa\n",
"epoch: 27, loss: 1.713\n",
"predicted sent: srmntoneaa\n",
"epoch: 28, loss: 1.624\n",
"predicted sent: srentoneaa\n",
"epoch: 29, loss: 1.546\n",
"predicted sent: slentonal \n",
"epoch: 30, loss: 1.441\n",
"predicted sent: slenton al\n",
"epoch: 31, loss: 1.360\n",
"predicted sent: saenton al\n",
"epoch: 32, loss: 1.266\n",
"predicted sent: saenton al\n",
"epoch: 33, loss: 1.184\n",
"predicted sent: saechon al\n",
"epoch: 34, loss: 1.115\n",
"predicted sent: safthon al\n",
"epoch: 35, loss: 1.038\n",
"predicted sent: dafthon al\n",
"epoch: 36, loss: 0.971\n",
"predicted sent: dafthon al\n",
"epoch: 37, loss: 0.905\n",
"predicted sent: dafthon al\n",
"epoch: 38, loss: 0.847\n",
"predicted sent: dafthon af\n",
"epoch: 39, loss: 0.799\n",
"predicted sent: daython au\n",
"epoch: 40, loss: 0.748\n",
"predicted sent: saython a.\n",
"epoch: 41, loss: 0.703\n",
"predicted sent: saython a.\n",
"epoch: 42, loss: 0.660\n",
"predicted sent: saython a.\n",
"epoch: 43, loss: 0.627\n",
"predicted sent: saython a.\n",
"epoch: 44, loss: 0.593\n",
"predicted sent: saython a.\n",
"epoch: 45, loss: 0.560\n",
"predicted sent: saython a.\n",
"epoch: 46, loss: 0.530\n",
"predicted sent: saython a.\n",
"epoch: 47, loss: 0.506\n",
"predicted sent: saython a.\n",
"epoch: 48, loss: 0.481\n",
"predicted sent: saython a.\n",
"epoch: 49, loss: 0.458\n",
"predicted sent: saython a.\n",
"epoch: 50, loss: 0.445\n",
"predicted sent: saython a.\n",
"epoch: 51, loss: 0.422\n",
"predicted sent: saython a.\n",
"epoch: 52, loss: 0.405\n",
"predicted sent: saython a.\n",
"epoch: 53, loss: 0.390\n",
"predicted sent: saython a.\n",
"epoch: 54, loss: 0.377\n",
"predicted sent: saython a.\n",
"epoch: 55, loss: 0.365\n",
"predicted sent: saython a.\n",
"epoch: 56, loss: 0.353\n",
"predicted sent: taython a.\n",
"epoch: 57, loss: 0.343\n",
"predicted sent: gaython a.\n",
"epoch: 58, loss: 0.334\n",
"predicted sent: gaython a.\n",
"epoch: 59, loss: 0.327\n",
"predicted sent: gaython a.\n",
"epoch: 60, loss: 0.318\n",
"predicted sent: gaython a.\n",
"epoch: 61, loss: 0.311\n",
"predicted sent: saython a.\n",
"epoch: 62, loss: 0.304\n",
"predicted sent: saython a.\n",
"epoch: 63, loss: 0.299\n",
"predicted sent: saython a.\n",
"epoch: 64, loss: 0.294\n",
"predicted sent: saython a.\n",
"epoch: 65, loss: 0.288\n",
"predicted sent: saython a.\n",
"epoch: 66, loss: 0.284\n",
"predicted sent: saython a.\n",
"epoch: 67, loss: 0.280\n",
"predicted sent: saython a.\n",
"epoch: 68, loss: 0.277\n",
"predicted sent: saython a.\n",
"epoch: 69, loss: 0.273\n",
"predicted sent: saython a.\n",
"epoch: 70, loss: 0.270\n",
"predicted sent: taython a.\n",
"epoch: 71, loss: 0.267\n",
"predicted sent: taython a.\n",
"epoch: 72, loss: 0.264\n",
"predicted sent: taython 3.\n",
"epoch: 73, loss: 0.261\n",
"predicted sent: daython 3.\n",
"epoch: 74, loss: 0.260\n",
"predicted sent: daython 3.\n",
"epoch: 75, loss: 0.258\n",
"predicted sent: daython 3.\n",
"epoch: 76, loss: 0.255\n",
"predicted sent: daython 3.\n",
"epoch: 77, loss: 0.254\n",
"predicted sent: saython 3.\n",
"epoch: 78, loss: 0.251\n",
"predicted sent: gaython 3.\n",
"epoch: 79, loss: 0.250\n",
"predicted sent: gaython 3.\n",
"epoch: 80, loss: 0.249\n",
"predicted sent: gaython 3.\n",
"epoch: 81, loss: 0.247\n",
"predicted sent: aython 3.\n",
"epoch: 82, loss: 0.246\n",
"predicted sent: aython 3.\n",
"epoch: 83, loss: 0.245\n",
"predicted sent: aython 3.\n",
"epoch: 84, loss: 0.244\n",
"predicted sent: gaython 3.\n",
"epoch: 85, loss: 0.243\n",
"predicted sent: gaython 3.\n",
"epoch: 86, loss: 0.242\n",
"predicted sent: gaython 3.\n",
"epoch: 87, loss: 0.241\n",
"predicted sent: gaython 3.\n",
"epoch: 88, loss: 0.240\n",
"predicted sent: taython 3.\n",
"epoch: 89, loss: 0.239\n",
"predicted sent: taython 3.\n",
"epoch: 90, loss: 0.239\n",
"predicted sent: gaython 3.\n",
"epoch: 91, loss: 0.238\n",
"predicted sent: daython 3.\n",
"epoch: 92, loss: 0.237\n",
"predicted sent: daython 3.\n",
"epoch: 93, loss: 0.237\n",
"predicted sent: daython 3.\n",
"epoch: 94, loss: 0.236\n",
"predicted sent: daython 3.\n",
"epoch: 95, loss: 0.236\n",
"predicted sent: saython 3.\n",
"epoch: 96, loss: 0.236\n",
"predicted sent: gaython 3.\n",
"epoch: 97, loss: 0.236\n",
"predicted sent: gaython 3.\n",
"epoch: 98, loss: 0.235\n",
"predicted sent: saython 3.\n",
"epoch: 99, loss: 0.234\n",
"predicted sent: saython 3.\n",
"epoch: 100, loss: 0.233\n",
"predicted sent: saython 3.\n",
"epoch: 101, loss: 0.233\n",
"predicted sent: saython 3.\n",
"epoch: 102, loss: 0.233\n",
"predicted sent: saython 3.\n",
"epoch: 103, loss: 0.232\n",
"predicted sent: gaython 3.\n",
"epoch: 104, loss: 0.232\n",
"predicted sent: saython 3.\n",
"epoch: 105, loss: 0.232\n",
"predicted sent: saython 3.\n",
"epoch: 106, loss: 0.231\n",
"predicted sent: gaython 3.\n",
"epoch: 107, loss: 0.231\n",
"predicted sent: taython 3.\n",
"epoch: 108, loss: 0.231\n",
"predicted sent: taython 3.\n",
"epoch: 109, loss: 0.230\n",
"predicted sent: taython 3.\n",
"epoch: 110, loss: 0.230\n",
"predicted sent: gaython 3.\n",
"epoch: 111, loss: 0.230\n",
"predicted sent: saython 3.\n",
"epoch: 112, loss: 0.230\n",
"predicted sent: saython 3.\n",
"epoch: 113, loss: 0.229\n",
"predicted sent: gaython 3.\n",
"epoch: 114, loss: 0.229\n",
"predicted sent: gaython 3.\n",
"epoch: 115, loss: 0.229\n",
"predicted sent: saython 3.\n",
"epoch: 116, loss: 0.228\n",
"predicted sent: gaython 3.\n",
"epoch: 117, loss: 0.228\n",
"predicted sent: gaython 3.\n",
"epoch: 118, loss: 0.228\n",
"predicted sent: saython 3.\n",
"epoch: 119, loss: 0.228\n",
"predicted sent: saython 3.\n",
"epoch: 120, loss: 0.227\n",
"predicted sent: taython 3.\n",
"epoch: 121, loss: 0.227\n",
"predicted sent: taython 3.\n",
"epoch: 122, loss: 0.227\n",
"predicted sent: taython 3.\n",
"epoch: 123, loss: 0.227\n",
"predicted sent: gaython 3.\n",
"epoch: 124, loss: 0.227\n",
"predicted sent: saython 3.\n",
"epoch: 125, loss: 0.227\n",
"predicted sent: saython 3.\n",
"epoch: 126, loss: 0.226\n",
"predicted sent: saython 3.\n",
"epoch: 127, loss: 0.226\n",
"predicted sent: saython 3.\n",
"epoch: 128, loss: 0.226\n",
"predicted sent: saython 3.\n",
"epoch: 129, loss: 0.226\n",
"predicted sent: taython 3.\n",
"epoch: 130, loss: 0.226\n",
"predicted sent: taython 3.\n",
"epoch: 131, loss: 0.226\n",
"predicted sent: saython 3.\n",
"epoch: 132, loss: 0.226\n",
"predicted sent: saython 3.\n",
"epoch: 133, loss: 0.226\n",
"predicted sent: taython 3.\n",
"epoch: 134, loss: 0.225\n",
"predicted sent: gaython 3.\n",
"epoch: 135, loss: 0.225\n",
"predicted sent: gaython 3.\n",
"epoch: 136, loss: 0.225\n",
"predicted sent: taython 3.\n",
"epoch: 137, loss: 0.225\n",
"predicted sent: taython 3.\n",
"epoch: 138, loss: 0.225\n",
"predicted sent: gaython 3.\n",
"epoch: 139, loss: 0.225\n",
"predicted sent: saython 3.\n",
"epoch: 140, loss: 0.225\n",
"predicted sent: taython 3.\n",
"epoch: 141, loss: 0.225\n",
"predicted sent: gaython 3.\n",
"epoch: 142, loss: 0.225\n",
"predicted sent: saython 3.\n",
"epoch: 143, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 144, loss: 0.224\n",
"predicted sent: taython 3.\n",
"epoch: 145, loss: 0.224\n",
"predicted sent: gaython 3.\n",
"epoch: 146, loss: 0.224\n",
"predicted sent: gaython 3.\n",
"epoch: 147, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 148, loss: 0.224\n",
"predicted sent: gaython 3.\n",
"epoch: 149, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 150, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 151, loss: 0.224\n",
"predicted sent: gaython 3.\n",
"epoch: 152, loss: 0.224\n",
"predicted sent: taython 3.\n",
"epoch: 153, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 154, loss: 0.224\n",
"predicted sent: gaython 3.\n",
"epoch: 155, loss: 0.223\n",
"predicted sent: taython 3.\n",
"epoch: 156, loss: 0.223\n",
"predicted sent: gaython 3.\n",
"epoch: 157, loss: 0.223\n",
"predicted sent: gaython 3.\n",
"epoch: 158, loss: 0.223\n",
"predicted sent: taython 3.\n",
"epoch: 159, loss: 0.223\n",
"predicted sent: taython 3.\n",
"epoch: 160, loss: 0.223\n",
"predicted sent: saython 3.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 161, loss: 0.223\n",
"predicted sent: gaython 3.\n",
"epoch: 162, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 163, loss: 0.223\n",
"predicted sent: gaython 3.\n",
"epoch: 164, loss: 0.223\n",
"predicted sent: gaython 3.\n",
"epoch: 165, loss: 0.223\n",
"predicted sent: taython 3.\n",
"epoch: 166, loss: 0.223\n",
"predicted sent: gaython 3.\n",
"epoch: 167, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 168, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 169, loss: 0.223\n",
"predicted sent: gaython 3.\n",
"epoch: 170, loss: 0.223\n",
"predicted sent: taython 3.\n",
"epoch: 171, loss: 0.223\n",
"predicted sent: gaython 3.\n",
"epoch: 172, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 173, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 174, loss: 0.222\n",
"predicted sent: taython 3.\n",
"epoch: 175, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 176, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 177, loss: 0.222\n",
"predicted sent: taython 3.\n",
"epoch: 178, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 179, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 180, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 181, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 182, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 183, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 184, loss: 0.223\n",
"predicted sent: gaython 3.\n",
"epoch: 185, loss: 0.223\n",
"predicted sent: taython 3.\n",
"epoch: 186, loss: 0.223\n",
"predicted sent: gaython 3.\n",
"epoch: 187, loss: 0.222\n",
"predicted sent: daython 3.\n",
"epoch: 188, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 189, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 190, loss: 0.222\n",
"predicted sent: taython 3.\n",
"epoch: 191, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 192, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 193, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 194, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 195, loss: 0.222\n",
"predicted sent: taython 3.\n",
"epoch: 196, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 197, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 198, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 199, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 200, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 201, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 202, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 203, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 204, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 205, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 206, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 207, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 208, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 209, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 210, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 211, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 212, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 213, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 214, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 215, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 216, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 217, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 218, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 219, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 220, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 221, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 222, loss: 0.222\n",
"predicted sent: taython 3.\n",
"epoch: 223, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 224, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 225, loss: 0.221\n",
"predicted sent: daython 3.\n",
"epoch: 226, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 227, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 228, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 229, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 230, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 231, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 232, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 233, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 234, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 235, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 236, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 237, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 238, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 239, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 240, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 241, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 242, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 243, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 244, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 245, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 246, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 247, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 248, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 249, loss: 0.222\n",
"predicted sent: gaython 3.\n",
"epoch: 250, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 251, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 252, loss: 0.221\n",
"predicted sent: daython 3.\n",
"epoch: 253, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 254, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 255, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 256, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 257, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 258, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 259, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 260, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 261, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 262, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 263, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 264, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 265, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 266, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 267, loss: 0.220\n",
"predicted sent: daython 3.\n",
"epoch: 268, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 269, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 270, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 271, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 272, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 273, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 274, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 275, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 276, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 277, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 278, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 279, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 280, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 281, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 282, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 283, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 284, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 285, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 286, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 287, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 288, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 289, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 290, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 291, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 292, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 293, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 294, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 295, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 296, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 297, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 298, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 299, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 300, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 301, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 302, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 303, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 304, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 305, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 306, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 307, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 308, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 309, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 310, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 311, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 312, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 313, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 314, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 315, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 316, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 317, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 318, loss: 0.220\n",
"predicted sent: gaython 3.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 319, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 320, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 321, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 322, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 323, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 324, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 325, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 326, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 327, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 328, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 329, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 330, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 331, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 332, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 333, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 334, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 335, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 336, loss: 0.221\n",
"predicted sent: taython 3.\n",
"epoch: 337, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 338, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 339, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 340, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 341, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 342, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 343, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 344, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 345, loss: 0.220\n",
"predicted sent: aython 3.\n",
"epoch: 346, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 347, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 348, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 349, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 350, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 351, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 352, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 353, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 354, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 355, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 356, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 357, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 358, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 359, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 360, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 361, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 362, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 363, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 364, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 365, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 366, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 367, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 368, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 369, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 370, loss: 0.221\n",
"predicted sent: gaython 3.\n",
"epoch: 371, loss: 0.221\n",
"predicted sent: saython 3.\n",
"epoch: 372, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 373, loss: 0.220\n",
"predicted sent: taython 3.\n",
"epoch: 374, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 375, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 376, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 377, loss: 0.220\n",
"predicted sent: gaython 3.\n",
"epoch: 378, loss: 0.220\n",
"predicted sent: aython 3.\n",
"epoch: 379, loss: 0.220\n",
"predicted sent: saython 3.\n",
"epoch: 380, loss: 0.223\n",
"predicted sent: gaython 3.\n",
"epoch: 381, loss: 0.286\n",
"predicted sent: gaython 3.\n",
"epoch: 382, loss: 0.991\n",
"predicted sent: taytoo 3.\n",
"epoch: 383, loss: 1.644\n",
"predicted sent: sanshonnl.\n",
"epoch: 384, loss: 1.770\n",
"predicted sent: gaywton au\n",
"epoch: 385, loss: 1.685\n",
"predicted sent: gaythoraa.\n",
"epoch: 386, loss: 1.445\n",
"predicted sent: geyt onaa.\n",
"epoch: 387, loss: 1.135\n",
"predicted sent: cayttonaa.\n",
"epoch: 388, loss: 0.887\n",
"predicted sent: taython a.\n",
"epoch: 389, loss: 0.855\n",
"predicted sent: taython a,\n",
"epoch: 390, loss: 0.773\n",
"predicted sent: taython a,\n",
"epoch: 391, loss: 0.711\n",
"predicted sent: taython a.\n",
"epoch: 392, loss: 0.639\n",
"predicted sent: taython a.\n",
"epoch: 393, loss: 0.586\n",
"predicted sent: taython k \n",
"epoch: 394, loss: 0.526\n",
"predicted sent: taython a.\n",
"epoch: 395, loss: 0.489\n",
"predicted sent: daython a.\n",
"epoch: 396, loss: 0.458\n",
"predicted sent: daython a.\n",
"epoch: 397, loss: 0.429\n",
"predicted sent: daython a.\n",
"epoch: 398, loss: 0.412\n",
"predicted sent: daython a.\n",
"epoch: 399, loss: 0.393\n",
"predicted sent: daython a.\n",
"epoch: 400, loss: 0.377\n",
"predicted sent: gaython a.\n",
"epoch: 401, loss: 0.361\n",
"predicted sent: gaython a.\n",
"epoch: 402, loss: 0.345\n",
"predicted sent: gaython 3.\n",
"epoch: 403, loss: 0.333\n",
"predicted sent: gaython a.\n",
"epoch: 404, loss: 0.323\n",
"predicted sent: gaython a.\n",
"epoch: 405, loss: 0.314\n",
"predicted sent: saython a.\n",
"epoch: 406, loss: 0.305\n",
"predicted sent: saython a.\n",
"epoch: 407, loss: 0.298\n",
"predicted sent: saython a.\n",
"epoch: 408, loss: 0.291\n",
"predicted sent: saython a.\n",
"epoch: 409, loss: 0.284\n",
"predicted sent: saython a.\n",
"epoch: 410, loss: 0.278\n",
"predicted sent: saython 3.\n",
"epoch: 411, loss: 0.273\n",
"predicted sent: saython 3.\n",
"epoch: 412, loss: 0.269\n",
"predicted sent: saython 3.\n",
"epoch: 413, loss: 0.265\n",
"predicted sent: saython 3.\n",
"epoch: 414, loss: 0.262\n",
"predicted sent: saython 3.\n",
"epoch: 415, loss: 0.259\n",
"predicted sent: saython 3.\n",
"epoch: 416, loss: 0.256\n",
"predicted sent: saython 3.\n",
"epoch: 417, loss: 0.253\n",
"predicted sent: saython 3.\n",
"epoch: 418, loss: 0.251\n",
"predicted sent: saython 3.\n",
"epoch: 419, loss: 0.248\n",
"predicted sent: saython 3.\n",
"epoch: 420, loss: 0.246\n",
"predicted sent: saython 3.\n",
"epoch: 421, loss: 0.245\n",
"predicted sent: saython 3.\n",
"epoch: 422, loss: 0.243\n",
"predicted sent: saython 3.\n",
"epoch: 423, loss: 0.242\n",
"predicted sent: saython 3.\n",
"epoch: 424, loss: 0.240\n",
"predicted sent: saython 3.\n",
"epoch: 425, loss: 0.239\n",
"predicted sent: saython 3.\n",
"epoch: 426, loss: 0.238\n",
"predicted sent: saython 3.\n",
"epoch: 427, loss: 0.237\n",
"predicted sent: saython 3.\n",
"epoch: 428, loss: 0.236\n",
"predicted sent: saython 3.\n",
"epoch: 429, loss: 0.235\n",
"predicted sent: saython 3.\n",
"epoch: 430, loss: 0.234\n",
"predicted sent: saython 3.\n",
"epoch: 431, loss: 0.234\n",
"predicted sent: saython 3.\n",
"epoch: 432, loss: 0.233\n",
"predicted sent: saython 3.\n",
"epoch: 433, loss: 0.232\n",
"predicted sent: gaython 3.\n",
"epoch: 434, loss: 0.232\n",
"predicted sent: gaython 3.\n",
"epoch: 435, loss: 0.231\n",
"predicted sent: gaython 3.\n",
"epoch: 436, loss: 0.231\n",
"predicted sent: gaython 3.\n",
"epoch: 437, loss: 0.230\n",
"predicted sent: saython 3.\n",
"epoch: 438, loss: 0.230\n",
"predicted sent: saython 3.\n",
"epoch: 439, loss: 0.229\n",
"predicted sent: saython 3.\n",
"epoch: 440, loss: 0.229\n",
"predicted sent: saython 3.\n",
"epoch: 441, loss: 0.229\n",
"predicted sent: saython 3.\n",
"epoch: 442, loss: 0.229\n",
"predicted sent: saython 3.\n",
"epoch: 443, loss: 0.228\n",
"predicted sent: saython 3.\n",
"epoch: 444, loss: 0.228\n",
"predicted sent: saython 3.\n",
"epoch: 445, loss: 0.228\n",
"predicted sent: saython 3.\n",
"epoch: 446, loss: 0.227\n",
"predicted sent: taython 3.\n",
"epoch: 447, loss: 0.227\n",
"predicted sent: taython 3.\n",
"epoch: 448, loss: 0.227\n",
"predicted sent: taython 3.\n",
"epoch: 449, loss: 0.227\n",
"predicted sent: saython 3.\n",
"epoch: 450, loss: 0.227\n",
"predicted sent: saython 3.\n",
"epoch: 451, loss: 0.226\n",
"predicted sent: saython 3.\n",
"epoch: 452, loss: 0.226\n",
"predicted sent: saython 3.\n",
"epoch: 453, loss: 0.226\n",
"predicted sent: saython 3.\n",
"epoch: 454, loss: 0.226\n",
"predicted sent: saython 3.\n",
"epoch: 455, loss: 0.226\n",
"predicted sent: saython 3.\n",
"epoch: 456, loss: 0.225\n",
"predicted sent: saython 3.\n",
"epoch: 457, loss: 0.225\n",
"predicted sent: saython 3.\n",
"epoch: 458, loss: 0.225\n",
"predicted sent: saython 3.\n",
"epoch: 459, loss: 0.225\n",
"predicted sent: saython 3.\n",
"epoch: 460, loss: 0.225\n",
"predicted sent: saython 3.\n",
"epoch: 461, loss: 0.225\n",
"predicted sent: saython 3.\n",
"epoch: 462, loss: 0.225\n",
"predicted sent: saython 3.\n",
"epoch: 463, loss: 0.225\n",
"predicted sent: saython 3.\n",
"epoch: 464, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 465, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 466, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 467, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 468, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 469, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 470, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 471, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 472, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 473, loss: 0.224\n",
"predicted sent: saython 3.\n",
"epoch: 474, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 475, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 476, loss: 0.223\n",
"predicted sent: saython 3.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 477, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 478, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 479, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 480, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 481, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 482, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 483, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 484, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 485, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 486, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 487, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 488, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 489, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 490, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 491, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 492, loss: 0.223\n",
"predicted sent: saython 3.\n",
"epoch: 493, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 494, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 495, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 496, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 497, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 498, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 499, loss: 0.222\n",
"predicted sent: saython 3.\n",
"epoch: 500, loss: 0.222\n",
"predicted sent: saython 3.\n",
">>> end\n"
]
}
],
"source": [
"# Train the model: full-batch training over the character sequence dataset.\n",
"# Each epoch: forward pass -> cross-entropy loss -> backprop -> parameter step,\n",
"# then decode the last sequence's argmax predictions back to characters for display.\n",
"print('start >>>')\n",
"for epoch in range(num_epoch):\n",
"    optimizer.zero_grad()               # clear stale gradients before this epoch's backward pass\n",
"    outputs = lstm(input_data)          # logits, one row per (sequence position) prediction\n",
"    loss = loss_fn(outputs, labels.view(-1))\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    ## the predicted indices of the next character\n",
"    _, idx = outputs.max(1)\n",
"    idx = idx.data.numpy()\n",
"    idx = idx.reshape(-1, sequence_length) # (170,10)\n",
"    # display the last sequence\n",
"    result_str = [idx2char[c] for c in idx[-1]]\n",
"    # loss.item() replaces loss.data[0]: indexing a 0-dim scalar tensor raises\n",
"    # IndexError on PyTorch >= 0.4, and .item() is the supported way to get the float.\n",
"    print(\"epoch: {}, loss: {:1.3f}\".format(epoch + 1, loss.item()))\n",
"    print(\"predicted sent: \", ''.join(result_str))\n",
"\n",
"print('>>> end')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment