Last active
January 19, 2020 16:50
-
-
Save mrbkdad/f40e5558b5ae23bba0fff38a508f2648 to your computer and use it in GitHub Desktop.
RNN sample with LSTM blocks for long character sequences
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<torch._C.Generator at 0x7f7b3cd90be8>" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"from torch.autograd import Variable\n", | |
"torch.manual_seed(777)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
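"## Build RNN training pairs from a sentence\n", | |
"## sentence : source text; each seq_len-char window is an input and the same window shifted by one char is its target\n", | |
"## seq_len : window length, i.e. number of RNN time steps\n", | |
"## NOTE: char indices follow set() iteration order, which can differ between interpreter runs (str hash randomization)\n", | |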
"def sent2rnntensor(sentence,seq_len):\n", | |
" lookup = {w:i for i,w in enumerate(list(set(sentence)))}\n", | |
" X = []\n", | |
" Y = []\n", | |
" for i in range(0, len(sentence) - seq_len):\n", | |
" x_str = sentence[i:i+seq_len]\n", | |
" y_str = sentence[i+1:i+seq_len+1]\n", | |
" #print(i, x_str, '->', y_str)\n", | |
" x = [lookup[c] for c in x_str]\n", | |
" y = [lookup[c] for c in y_str]\n", | |
" X.append(x)\n", | |
" Y.append(y)\n", | |
" return X,Y,lookup\n", | |
"\n", | |
"## Index tensor -> one-hot matrix: one row per element, width = largest index present + 1\n", | |
"def onehot2d(dataset):\n", | |
" idx = dataset.long()\n", | |
" onehot_len = torch.max(idx)+1\n", | |
" onehot_data = torch.zeros(idx.numel(),onehot_len)\n", | |
" onehot_data.scatter_(1,idx.view(-1,1),1)\n", | |
" return onehot_data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'The core of extensible programming is defining functions.Python allows mandatory and optional arguments, keyword arguments,and even arbitrary argument lists. More about defining functions in Python 3.'" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sentence = (\"The core of extensible programming is defining functions.\"\n", | |
" \"Python allows mandatory and optional arguments, keyword arguments,\"\n", | |
" \"and even arbitrary argument lists. More about defining functions in Python 3.\")\n", | |
"sentence" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"## hyperparameters & transform the sentence into a dataset\n", | |
"learning_rate = 0.1\n", | |
"num_epoch = 500\n", | |
"sequence_length = 10 ## an arbitrary number\n", | |
"\n", | |
"dataX,dataY,char2idx=sent2rnntensor(sentence,sequence_length)\n", | |
"idx2char = {i:c for c,i in char2idx.items()} \n", | |
"\n", | |
"input_size = len(char2idx) ## one hot encoding vector size\n", | |
"hidden_size = len(char2idx) ## RNN output size\n", | |
"num_classes = len(char2idx) ## output size (RNN, softmax etc)\n", | |
"num_layers = 2 ## RNN layers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"batch_size = len(dataX)\n", | |
"x_data = torch.Tensor(dataX)\n", | |
"y_data = torch.LongTensor(dataY)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x_one_hot = onehot2d(x_data)\n", | |
"input_data = Variable(x_one_hot.view(x_data.size()[0],\n", | |
" x_data.size()[1],num_classes))\n", | |
"labels = Variable(y_data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"## Define LSTM class\n", | |
"class LSTM(nn.Module):\n", | |
" def __init__(self,classes,input_size,hidden_size,num_layers):\n", | |
" super(LSTM,self).__init__()\n", | |
" self.classes = classes\n", | |
" self.input_size = input_size\n", | |
" self.hidden_size = hidden_size\n", | |
" self.num_layers = num_layers\n", | |
" ## LSTM\n", | |
" self.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,\n", | |
" num_layers=num_layers,batch_first=True)\n", | |
" ## FC\n", | |
" self.fc = nn.Linear(hidden_size,classes)\n", | |
" def forward(self,x):\n", | |
" ## input : batch X sequence X input_size(if batch_first is True)\n", | |
" ## Initialize hidden and cell states\n", | |
" ## : num_layers X batch_size X hidden_size\n", | |
" h_0 = Variable(torch.zeros(self.num_layers,\n", | |
" x.size(0), self.hidden_size))\n", | |
" c_0 = Variable(torch.zeros(self.num_layers,\n", | |
" x.size(0), self.hidden_size))\n", | |
" out,_ = self.lstm(x,(h_0,c_0))\n", | |
" '''\n", | |
" A non-contiguous tensor is not laid out as a single dense block\n", | |
" of memory (e.g. a strided view). view() can only be used\n", | |
" on contiguous tensors, so .contiguous() (which copies only\n", | |
" when needed) is called before reshaping.\n", | |
" '''\n", | |
" out = out.contiguous().view(-1, self.hidden_size)\n", | |
" out = self.fc(out)\n", | |
" return out" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<class 'torch.autograd.variable.Variable'> <class 'torch.autograd.variable.Variable'>\n", | |
"torch.Size([3, 10, 30]) torch.Size([3, 10, 30])\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"'ssssssssss'" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,\n", | |
" num_layers=num_layers,batch_first=True)\n", | |
"h_0 = Variable(torch.zeros(num_layers,3, hidden_size))\n", | |
"c_0 = Variable(torch.zeros(num_layers,3, hidden_size))\n", | |
"out,_=lstm(input_data[:3],(h_0,c_0))\n", | |
"print(type(out),type(out.contiguous()))\n", | |
"print(out.size(),input_data[:3].size())\n", | |
"out = out.contiguous().view(-1,hidden_size)\n", | |
"fc = nn.Linear(hidden_size,num_classes)\n", | |
"out = fc(out)\n", | |
"_,idx=out.max(1)\n", | |
"idx = idx.data.numpy()\n", | |
"idx = idx.reshape(-1,sequence_length)\n", | |
"''.join([idx2char[i] for i in idx[-1]])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"## Model\n", | |
"lstm = LSTM(num_classes,input_size,hidden_size,num_layers)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"## loss function, optimizer\n", | |
"loss_fn = nn.CrossEntropyLoss()\n", | |
"optimizer = torch.optim.Adam(lstm.parameters(),lr=learning_rate)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"start >>>\n", | |
"epoch: 1, loss: 3.403\n", | |
"predicted sent: yyyyyyyyyy\n", | |
"epoch: 2, loss: 3.271\n", | |
"predicted sent: sioooooooo\n", | |
"epoch: 3, loss: 3.355\n", | |
"predicted sent: rlmrrrrrrr\n", | |
"epoch: 4, loss: 3.163\n", | |
"predicted sent: rrrrrrrrrr\n", | |
"epoch: 5, loss: 3.197\n", | |
"predicted sent: \n", | |
"epoch: 6, loss: 3.072\n", | |
"predicted sent: \n", | |
"epoch: 7, loss: 3.007\n", | |
"predicted sent: \n", | |
"epoch: 8, loss: 3.010\n", | |
"predicted sent: nnnnnnnnnn\n", | |
"epoch: 9, loss: 2.994\n", | |
"predicted sent: nnnnnnnnnn\n", | |
"epoch: 10, loss: 2.960\n", | |
"predicted sent: nnnnnnnnnn\n", | |
"epoch: 11, loss: 2.929\n", | |
"predicted sent: nnn nn \n", | |
"epoch: 12, loss: 2.896\n", | |
"predicted sent: n o \n", | |
"epoch: 13, loss: 2.845\n", | |
"predicted sent: o o o a \n", | |
"epoch: 14, loss: 2.787\n", | |
"predicted sent: o ooo a \n", | |
"epoch: 15, loss: 2.728\n", | |
"predicted sent: o ooo o \n", | |
"epoch: 16, loss: 2.661\n", | |
"predicted sent: o ooo aae\n", | |
"epoch: 17, loss: 2.594\n", | |
"predicted sent: o eooo aaa\n", | |
"epoch: 18, loss: 2.526\n", | |
"predicted sent: o eoao aae\n", | |
"epoch: 19, loss: 2.444\n", | |
"predicted sent: orenno aaa\n", | |
"epoch: 20, loss: 2.351\n", | |
"predicted sent: orenno aaa\n", | |
"epoch: 21, loss: 2.258\n", | |
"predicted sent: orennonara\n", | |
"epoch: 22, loss: 2.174\n", | |
"predicted sent: orennonara\n", | |
"epoch: 23, loss: 2.067\n", | |
"predicted sent: srmnnonara\n", | |
"epoch: 24, loss: 1.995\n", | |
"predicted sent: srmnnonaaa\n", | |
"epoch: 25, loss: 1.903\n", | |
"predicted sent: srmntonara\n", | |
"epoch: 26, loss: 1.800\n", | |
"predicted sent: srmntonaaa\n", | |
"epoch: 27, loss: 1.713\n", | |
"predicted sent: srmntoneaa\n", | |
"epoch: 28, loss: 1.624\n", | |
"predicted sent: srentoneaa\n", | |
"epoch: 29, loss: 1.546\n", | |
"predicted sent: slentonal \n", | |
"epoch: 30, loss: 1.441\n", | |
"predicted sent: slenton al\n", | |
"epoch: 31, loss: 1.360\n", | |
"predicted sent: saenton al\n", | |
"epoch: 32, loss: 1.266\n", | |
"predicted sent: saenton al\n", | |
"epoch: 33, loss: 1.184\n", | |
"predicted sent: saechon al\n", | |
"epoch: 34, loss: 1.115\n", | |
"predicted sent: safthon al\n", | |
"epoch: 35, loss: 1.038\n", | |
"predicted sent: dafthon al\n", | |
"epoch: 36, loss: 0.971\n", | |
"predicted sent: dafthon al\n", | |
"epoch: 37, loss: 0.905\n", | |
"predicted sent: dafthon al\n", | |
"epoch: 38, loss: 0.847\n", | |
"predicted sent: dafthon af\n", | |
"epoch: 39, loss: 0.799\n", | |
"predicted sent: daython au\n", | |
"epoch: 40, loss: 0.748\n", | |
"predicted sent: saython a.\n", | |
"epoch: 41, loss: 0.703\n", | |
"predicted sent: saython a.\n", | |
"epoch: 42, loss: 0.660\n", | |
"predicted sent: saython a.\n", | |
"epoch: 43, loss: 0.627\n", | |
"predicted sent: saython a.\n", | |
"epoch: 44, loss: 0.593\n", | |
"predicted sent: saython a.\n", | |
"epoch: 45, loss: 0.560\n", | |
"predicted sent: saython a.\n", | |
"epoch: 46, loss: 0.530\n", | |
"predicted sent: saython a.\n", | |
"epoch: 47, loss: 0.506\n", | |
"predicted sent: saython a.\n", | |
"epoch: 48, loss: 0.481\n", | |
"predicted sent: saython a.\n", | |
"epoch: 49, loss: 0.458\n", | |
"predicted sent: saython a.\n", | |
"epoch: 50, loss: 0.445\n", | |
"predicted sent: saython a.\n", | |
"epoch: 51, loss: 0.422\n", | |
"predicted sent: saython a.\n", | |
"epoch: 52, loss: 0.405\n", | |
"predicted sent: saython a.\n", | |
"epoch: 53, loss: 0.390\n", | |
"predicted sent: saython a.\n", | |
"epoch: 54, loss: 0.377\n", | |
"predicted sent: saython a.\n", | |
"epoch: 55, loss: 0.365\n", | |
"predicted sent: saython a.\n", | |
"epoch: 56, loss: 0.353\n", | |
"predicted sent: taython a.\n", | |
"epoch: 57, loss: 0.343\n", | |
"predicted sent: gaython a.\n", | |
"epoch: 58, loss: 0.334\n", | |
"predicted sent: gaython a.\n", | |
"epoch: 59, loss: 0.327\n", | |
"predicted sent: gaython a.\n", | |
"epoch: 60, loss: 0.318\n", | |
"predicted sent: gaython a.\n", | |
"epoch: 61, loss: 0.311\n", | |
"predicted sent: saython a.\n", | |
"epoch: 62, loss: 0.304\n", | |
"predicted sent: saython a.\n", | |
"epoch: 63, loss: 0.299\n", | |
"predicted sent: saython a.\n", | |
"epoch: 64, loss: 0.294\n", | |
"predicted sent: saython a.\n", | |
"epoch: 65, loss: 0.288\n", | |
"predicted sent: saython a.\n", | |
"epoch: 66, loss: 0.284\n", | |
"predicted sent: saython a.\n", | |
"epoch: 67, loss: 0.280\n", | |
"predicted sent: saython a.\n", | |
"epoch: 68, loss: 0.277\n", | |
"predicted sent: saython a.\n", | |
"epoch: 69, loss: 0.273\n", | |
"predicted sent: saython a.\n", | |
"epoch: 70, loss: 0.270\n", | |
"predicted sent: taython a.\n", | |
"epoch: 71, loss: 0.267\n", | |
"predicted sent: taython a.\n", | |
"epoch: 72, loss: 0.264\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 73, loss: 0.261\n", | |
"predicted sent: daython 3.\n", | |
"epoch: 74, loss: 0.260\n", | |
"predicted sent: daython 3.\n", | |
"epoch: 75, loss: 0.258\n", | |
"predicted sent: daython 3.\n", | |
"epoch: 76, loss: 0.255\n", | |
"predicted sent: daython 3.\n", | |
"epoch: 77, loss: 0.254\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 78, loss: 0.251\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 79, loss: 0.250\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 80, loss: 0.249\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 81, loss: 0.247\n", | |
"predicted sent: aython 3.\n", | |
"epoch: 82, loss: 0.246\n", | |
"predicted sent: aython 3.\n", | |
"epoch: 83, loss: 0.245\n", | |
"predicted sent: aython 3.\n", | |
"epoch: 84, loss: 0.244\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 85, loss: 0.243\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 86, loss: 0.242\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 87, loss: 0.241\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 88, loss: 0.240\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 89, loss: 0.239\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 90, loss: 0.239\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 91, loss: 0.238\n", | |
"predicted sent: daython 3.\n", | |
"epoch: 92, loss: 0.237\n", | |
"predicted sent: daython 3.\n", | |
"epoch: 93, loss: 0.237\n", | |
"predicted sent: daython 3.\n", | |
"epoch: 94, loss: 0.236\n", | |
"predicted sent: daython 3.\n", | |
"epoch: 95, loss: 0.236\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 96, loss: 0.236\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 97, loss: 0.236\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 98, loss: 0.235\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 99, loss: 0.234\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 100, loss: 0.233\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 101, loss: 0.233\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 102, loss: 0.233\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 103, loss: 0.232\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 104, loss: 0.232\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 105, loss: 0.232\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 106, loss: 0.231\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 107, loss: 0.231\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 108, loss: 0.231\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 109, loss: 0.230\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 110, loss: 0.230\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 111, loss: 0.230\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 112, loss: 0.230\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 113, loss: 0.229\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 114, loss: 0.229\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 115, loss: 0.229\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 116, loss: 0.228\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 117, loss: 0.228\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 118, loss: 0.228\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 119, loss: 0.228\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 120, loss: 0.227\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 121, loss: 0.227\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 122, loss: 0.227\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 123, loss: 0.227\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 124, loss: 0.227\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 125, loss: 0.227\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 126, loss: 0.226\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 127, loss: 0.226\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 128, loss: 0.226\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 129, loss: 0.226\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 130, loss: 0.226\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 131, loss: 0.226\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 132, loss: 0.226\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 133, loss: 0.226\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 134, loss: 0.225\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 135, loss: 0.225\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 136, loss: 0.225\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 137, loss: 0.225\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 138, loss: 0.225\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 139, loss: 0.225\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 140, loss: 0.225\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 141, loss: 0.225\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 142, loss: 0.225\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 143, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 144, loss: 0.224\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 145, loss: 0.224\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 146, loss: 0.224\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 147, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 148, loss: 0.224\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 149, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 150, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 151, loss: 0.224\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 152, loss: 0.224\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 153, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 154, loss: 0.224\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 155, loss: 0.223\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 156, loss: 0.223\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 157, loss: 0.223\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 158, loss: 0.223\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 159, loss: 0.223\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 160, loss: 0.223\n", | |
"predicted sent: saython 3.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch: 161, loss: 0.223\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 162, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 163, loss: 0.223\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 164, loss: 0.223\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 165, loss: 0.223\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 166, loss: 0.223\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 167, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 168, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 169, loss: 0.223\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 170, loss: 0.223\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 171, loss: 0.223\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 172, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 173, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 174, loss: 0.222\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 175, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 176, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 177, loss: 0.222\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 178, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 179, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 180, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 181, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 182, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 183, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 184, loss: 0.223\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 185, loss: 0.223\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 186, loss: 0.223\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 187, loss: 0.222\n", | |
"predicted sent: daython 3.\n", | |
"epoch: 188, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 189, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 190, loss: 0.222\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 191, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 192, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 193, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 194, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 195, loss: 0.222\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 196, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 197, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 198, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 199, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 200, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 201, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 202, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 203, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 204, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 205, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 206, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 207, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 208, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 209, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 210, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 211, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 212, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 213, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 214, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 215, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 216, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 217, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 218, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 219, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 220, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 221, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 222, loss: 0.222\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 223, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 224, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 225, loss: 0.221\n", | |
"predicted sent: daython 3.\n", | |
"epoch: 226, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 227, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 228, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 229, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 230, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 231, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 232, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 233, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 234, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 235, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 236, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 237, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 238, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 239, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 240, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 241, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 242, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 243, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 244, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 245, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 246, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 247, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 248, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 249, loss: 0.222\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 250, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 251, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 252, loss: 0.221\n", | |
"predicted sent: daython 3.\n", | |
"epoch: 253, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 254, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 255, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 256, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 257, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 258, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 259, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 260, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 261, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 262, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 263, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 264, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 265, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 266, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 267, loss: 0.220\n", | |
"predicted sent: daython 3.\n", | |
"epoch: 268, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 269, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 270, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 271, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 272, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 273, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 274, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 275, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 276, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 277, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 278, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 279, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 280, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 281, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 282, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 283, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 284, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 285, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 286, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 287, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 288, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 289, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 290, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 291, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 292, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 293, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 294, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 295, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 296, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 297, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 298, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 299, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 300, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 301, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 302, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 303, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 304, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 305, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 306, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 307, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 308, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 309, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 310, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 311, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 312, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 313, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 314, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 315, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 316, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 317, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 318, loss: 0.220\n", | |
"predicted sent: gaython 3.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch: 319, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 320, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 321, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 322, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 323, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 324, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 325, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 326, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 327, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 328, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 329, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 330, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 331, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 332, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 333, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 334, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 335, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 336, loss: 0.221\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 337, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 338, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 339, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 340, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 341, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 342, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 343, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 344, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 345, loss: 0.220\n", | |
"predicted sent: aython 3.\n", | |
"epoch: 346, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 347, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 348, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 349, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 350, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 351, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 352, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 353, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 354, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 355, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 356, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 357, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 358, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 359, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 360, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 361, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 362, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 363, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 364, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 365, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 366, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 367, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 368, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 369, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 370, loss: 0.221\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 371, loss: 0.221\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 372, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 373, loss: 0.220\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 374, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 375, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 376, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 377, loss: 0.220\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 378, loss: 0.220\n", | |
"predicted sent: aython 3.\n", | |
"epoch: 379, loss: 0.220\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 380, loss: 0.223\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 381, loss: 0.286\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 382, loss: 0.991\n", | |
"predicted sent: taytoo 3.\n", | |
"epoch: 383, loss: 1.644\n", | |
"predicted sent: sanshonnl.\n", | |
"epoch: 384, loss: 1.770\n", | |
"predicted sent: gaywton au\n", | |
"epoch: 385, loss: 1.685\n", | |
"predicted sent: gaythoraa.\n", | |
"epoch: 386, loss: 1.445\n", | |
"predicted sent: geyt onaa.\n", | |
"epoch: 387, loss: 1.135\n", | |
"predicted sent: cayttonaa.\n", | |
"epoch: 388, loss: 0.887\n", | |
"predicted sent: taython a.\n", | |
"epoch: 389, loss: 0.855\n", | |
"predicted sent: taython a,\n", | |
"epoch: 390, loss: 0.773\n", | |
"predicted sent: taython a,\n", | |
"epoch: 391, loss: 0.711\n", | |
"predicted sent: taython a.\n", | |
"epoch: 392, loss: 0.639\n", | |
"predicted sent: taython a.\n", | |
"epoch: 393, loss: 0.586\n", | |
"predicted sent: taython k \n", | |
"epoch: 394, loss: 0.526\n", | |
"predicted sent: taython a.\n", | |
"epoch: 395, loss: 0.489\n", | |
"predicted sent: daython a.\n", | |
"epoch: 396, loss: 0.458\n", | |
"predicted sent: daython a.\n", | |
"epoch: 397, loss: 0.429\n", | |
"predicted sent: daython a.\n", | |
"epoch: 398, loss: 0.412\n", | |
"predicted sent: daython a.\n", | |
"epoch: 399, loss: 0.393\n", | |
"predicted sent: daython a.\n", | |
"epoch: 400, loss: 0.377\n", | |
"predicted sent: gaython a.\n", | |
"epoch: 401, loss: 0.361\n", | |
"predicted sent: gaython a.\n", | |
"epoch: 402, loss: 0.345\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 403, loss: 0.333\n", | |
"predicted sent: gaython a.\n", | |
"epoch: 404, loss: 0.323\n", | |
"predicted sent: gaython a.\n", | |
"epoch: 405, loss: 0.314\n", | |
"predicted sent: saython a.\n", | |
"epoch: 406, loss: 0.305\n", | |
"predicted sent: saython a.\n", | |
"epoch: 407, loss: 0.298\n", | |
"predicted sent: saython a.\n", | |
"epoch: 408, loss: 0.291\n", | |
"predicted sent: saython a.\n", | |
"epoch: 409, loss: 0.284\n", | |
"predicted sent: saython a.\n", | |
"epoch: 410, loss: 0.278\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 411, loss: 0.273\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 412, loss: 0.269\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 413, loss: 0.265\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 414, loss: 0.262\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 415, loss: 0.259\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 416, loss: 0.256\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 417, loss: 0.253\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 418, loss: 0.251\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 419, loss: 0.248\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 420, loss: 0.246\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 421, loss: 0.245\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 422, loss: 0.243\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 423, loss: 0.242\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 424, loss: 0.240\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 425, loss: 0.239\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 426, loss: 0.238\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 427, loss: 0.237\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 428, loss: 0.236\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 429, loss: 0.235\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 430, loss: 0.234\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 431, loss: 0.234\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 432, loss: 0.233\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 433, loss: 0.232\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 434, loss: 0.232\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 435, loss: 0.231\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 436, loss: 0.231\n", | |
"predicted sent: gaython 3.\n", | |
"epoch: 437, loss: 0.230\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 438, loss: 0.230\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 439, loss: 0.229\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 440, loss: 0.229\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 441, loss: 0.229\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 442, loss: 0.229\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 443, loss: 0.228\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 444, loss: 0.228\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 445, loss: 0.228\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 446, loss: 0.227\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 447, loss: 0.227\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 448, loss: 0.227\n", | |
"predicted sent: taython 3.\n", | |
"epoch: 449, loss: 0.227\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 450, loss: 0.227\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 451, loss: 0.226\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 452, loss: 0.226\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 453, loss: 0.226\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 454, loss: 0.226\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 455, loss: 0.226\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 456, loss: 0.225\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 457, loss: 0.225\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 458, loss: 0.225\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 459, loss: 0.225\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 460, loss: 0.225\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 461, loss: 0.225\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 462, loss: 0.225\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 463, loss: 0.225\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 464, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 465, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 466, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 467, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 468, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 469, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 470, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 471, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 472, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 473, loss: 0.224\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 474, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 475, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 476, loss: 0.223\n", | |
"predicted sent: saython 3.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch: 477, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 478, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 479, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 480, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 481, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 482, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 483, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 484, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 485, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 486, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 487, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 488, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 489, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 490, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 491, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 492, loss: 0.223\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 493, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 494, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 495, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 496, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 497, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 498, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 499, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
"epoch: 500, loss: 0.222\n", | |
"predicted sent: saython 3.\n", | |
">>> end\n" | |
] | |
} | |
], | |
"source": [ | |
"# Train the model\n", | |
"print('start >>>')\n", | |
"for epoch in range(num_epoch):\n", | |
" outputs = lstm(input_data)\n", | |
" optimizer.zero_grad()\n", | |
" loss = loss_fn(outputs, labels.view(-1))\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" ## the predicted indices of the next character\n", | |
" _, idx = outputs.max(1)\n", | |
" idx = idx.data.numpy()\n", | |
" idx = idx.reshape(-1, sequence_length) # (num_sequences, sequence_length)\n", | |
" # display the last sequence\n", | |
" result_str = [idx2char[c] for c in idx[-1]]\n", | |
" # loss.item() extracts the Python scalar; indexing loss.data[0] fails on the 0-dim\n", | |
" # (mean-reduced) loss tensors returned by PyTorch >= 0.4\n", | |
" print(\"epoch: {}, loss: {:1.3f}\".format(epoch + 1, loss.item()))\n", | |
" print(\"predicted sent: \", ''.join(result_str))\n", | |
"\n", | |
"print('>>> end')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment