@z-a-f
Created March 18, 2021 10:14
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "edfbxDDh2AEs"
},
"source": [
"## Generate Shakespearian Scripts using PyTorch\n",
"\n",
"This tutorial is adapted from the [albertlai431/Machine-Learning](https://github.com/albertlai431/Machine-Learning/tree/master/Text%20Generation) GitHub repository."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from contextlib import contextmanager\n",
"import copy\n",
"import itertools\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import sys\n",
"\n",
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"import torch.quantization as tq\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "KRQ6Fjra3Ruq"
},
"source": [
"### Download data\n",
"\n",
"Download *The Complete Works of William Shakespeare* as a single text file from [Project Gutenberg](https://www.gutenberg.org/)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 208
},
"colab_type": "code",
"id": "j8sIXh1DEDDd",
"outputId": "79f9cbb8-98ae-4c5e-c8d4-2ed31eae4d93"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2021-03-18 02:50:58-- http://www.gutenberg.org/files/100/100-0.txt\n",
"Resolving www.gutenberg.org (www.gutenberg.org)... 2610:28:3090:3000:0:bad:cafe:47, 152.19.134.47\n",
"Connecting to www.gutenberg.org (www.gutenberg.org)|2610:28:3090:3000:0:bad:cafe:47|:80... connected.\n",
"HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable\n",
"\n",
" The file is already fully retrieved; nothing to do.\n",
"\n"
]
}
],
"source": [
"!wget --show-progress --continue -O ./shakespeare.txt http://www.gutenberg.org/files/100/100-0.txt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The Project Gutenberg eBook of The Complete Works of William Shakespeare, by William Shakespeare\n",
"\n",
"This eBook is for the use of anyone anywhere in the United States and\n",
"most other parts of the world at no cost and with almost no restrictions\n",
"whatsoev\n"
]
}
],
"source": [
"with open('shakespeare.txt', 'r', encoding='utf-8') as f:\n",
"    text = f.read()\n",
"\n",
"# Show the first 250 characters\n",
"print(text[:250])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenize the data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 70 41 91 61 28 26 19 94 27 61 59 45 28 3 71 45 61 95\n",
" 32 61 19 100 28 61 89 94 94 20 28 94 72 28 41 91 61 28\n",
" 64 94 24 14 7 61 45 61 28 106 94 19 20 105 28 94 72 28\n",
" 106 77 7 7 77 52 24 28 97 91 52 20 61 105 14 61 52 19\n",
" 61 0 28 32 21 28 106 77 7 7 77 52 24 28 97 91 52 20\n",
" 61 105 14 61 52 19 61 50 50 41 91 77 105 28 61 89 94 94\n",
" 20 28 77 105 28 72 94 19 28 45 91 61 28 71 105 61 28 94\n",
" 72 28 52 95 21 94 95 61 28 52 95 21 22 91 61 19 61 28\n",
" 77 95 28 45 91 61 28 104 95 77 45 61 18 28 97 45 52 45\n",
" 61 105 28 52 95 18 50 24 94 105 45 28 94 45 91 61 19 28\n",
" 14 52 19 45 105 28 94 72 28 45 91 61 28 22 94 19 7 18\n",
" 28 52 45 28 95 94 28 59 94 105 45 28 52 95 18 28 22 77\n",
" 45 91 28 52 7 24 94 105 45 28 95 94 28 19 61 105 45 19\n",
" 77 59 45 77 94 95 105 50 22 91 52 45 105 94 61 98]\n"
]
}
],
"source": [
"# Encode the text: map each character to an integer and vice versa\n",
"\n",
"# We create two dictionaries:\n",
"# 1. int2char, which maps integers to characters\n",
"# 2. char2int, which maps characters to integers\n",
"chars = tuple(set(text))\n",
"int2char = dict(enumerate(chars))\n",
"char2int = {ch: ii for ii, ch in int2char.items()}\n",
"\n",
"# Encode the text\n",
"encoded = np.array([char2int[ch] for ch in text])\n",
"\n",
"# Show the first 250 encoded characters\n",
"print(encoded[:250])"
]
},
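{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `set(text)` has no guaranteed iteration order: string hashing is randomized per interpreter session, so `chars`, and with it every integer code, can differ between runs. That matters once a trained model is saved and reloaded in a new session. One option, sketched below under the assumption that you want a vocabulary that is reproducible across runs, is to sort the characters before enumerating them and to check that encoding followed by decoding returns the original text. `build_vocab`, `encode`, and `decode` are hypothetical helper names, not part of the original notebook.\n",
"\n",
"```python\n",
"def build_vocab(text):\n",
"    # sorted() makes the ordering independent of set iteration order\n",
"    vocab = tuple(sorted(set(text)))\n",
"    int2char = dict(enumerate(vocab))\n",
"    char2int = {ch: ii for ii, ch in int2char.items()}\n",
"    return vocab, int2char, char2int\n",
"\n",
"\n",
"def encode(text, char2int):\n",
"    # Map every character to its integer code\n",
"    return [char2int[ch] for ch in text]\n",
"\n",
"\n",
"def decode(codes, int2char):\n",
"    # Map integer codes back to characters and join them\n",
"    return ''.join(int2char[i] for i in codes)\n",
"\n",
"\n",
"sample_text = 'To be, or not to be'\n",
"vocab, i2c, c2i = build_vocab(sample_text)\n",
"codes = encode(sample_text, c2i)\n",
"print(vocab)\n",
"print(codes)\n",
"print(decode(codes, i2c))\n",
"```"
]
},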
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Utilities"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# One-hot encode an integer array\n",
"def one_hot_encode(arr, n_labels):\n",
"    \n",
"    # Initialize the encoded array\n",
"    one_hot = np.zeros((np.multiply(*arr.shape), n_labels), dtype=np.float32)\n",
"    \n",
"    # Fill the appropriate elements with ones\n",
"    one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1.\n",
"    \n",
"    # Finally reshape it to get back to the original array\n",
"    one_hot = one_hot.reshape((*arr.shape, n_labels))\n",
"    \n",
"    return one_hot\n",
"    \n",
"# Generate mini-batches for training\n",
"def get_batches(arr, batch_size, seq_length):\n",
"    '''Create a generator that returns batches of size\n",
"       batch_size x seq_length from arr.\n",
"       \n",
"       Arguments\n",
"       ---------\n",
"       arr: Array you want to make batches from\n",
"       batch_size: Batch size, the number of sequences per batch\n",
"       seq_length: Number of encoded chars in a sequence\n",
"    '''\n",
"    \n",
"    batch_size_total = batch_size * seq_length\n",
"    # total number of batches we can make\n",
"    n_batches = len(arr)//batch_size_total\n",
"    \n",
"    # Keep only enough characters to make full batches\n",
"    arr = arr[:n_batches * batch_size_total]\n",
"    # Reshape into batch_size rows\n",
"    arr = arr.reshape((batch_size, -1))\n",
"    \n",
"    # iterate through the array, one sequence at a time\n",
"    for n in range(0, arr.shape[1], seq_length):\n",
"        # The features\n",
"        x = arr[:, n:n+seq_length]\n",
"        # The targets, shifted by one\n",
"        y = np.zeros_like(x)\n",
"        try:\n",
"            y[:, :-1], y[:, -1] = x[:, 1:], arr[:, n+seq_length]\n",
"        except IndexError:\n",
"            # Last window: wrap around to the first column of each row\n",
"            y[:, :-1], y[:, -1] = x[:, 1:], arr[:, 0]\n",
"        yield x, y"
]
},
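{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the two utilities produce, here is a small self-contained sketch on a toy array of 12 codes with `batch_size=2` and `seq_length=3`. It repeats the reshaping step of `get_batches` by hand and, to stay standalone, one-hot encodes with `np.eye(n_labels)[x]`, which should give the same values as `one_hot_encode(x, n_labels)` for a 2-D integer array. The targets `y` are the inputs shifted one step to the right; for the last window of each row there is no next column, so `get_batches` falls back to `arr[:, 0]`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Toy encoded text: the codes 0..11 from a vocabulary of 12 symbols\n",
"arr = np.arange(12)\n",
"batch_size, seq_length = 2, 3\n",
"\n",
"# Keep only full batches, then reshape into batch_size rows (as get_batches does)\n",
"n_batches = len(arr) // (batch_size * seq_length)\n",
"rows = arr[:n_batches * batch_size * seq_length].reshape((batch_size, -1))\n",
"# rows is [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]\n",
"\n",
"# First batch: features are the first seq_length columns,\n",
"# targets are the same window shifted one character to the right\n",
"x = rows[:, 0:seq_length]\n",
"y = rows[:, 1:seq_length + 1]\n",
"\n",
"# One-hot encode by indexing rows of an identity matrix\n",
"one_hot = np.eye(12, dtype=np.float32)[x]\n",
"\n",
"print(x.tolist())     # [[0, 1, 2], [6, 7, 8]]\n",
"print(y.tolist())     # [[1, 2, 3], [7, 8, 9]]\n",
"print(one_hot.shape)  # (2, 3, 12)\n",
"```"
]
},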
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the Model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Declaring the model\n",
"class CharRNN(nn.Module):\n",
"    \n",
"    def __init__(self, tokens, n_hidden=256, n_layers=2, drop_prob=0.5):\n",
"        super().__init__()\n",
"        self.drop_prob = drop_prob\n",
"        self.n_layers = n_layers\n",
"        self.n_hidden = n_hidden\n",
"        \n",
"        # creating character dictionaries\n",
"        self.chars = tokens\n",
"        self.int2char = dict(enumerate(self.chars))\n",
"        self.char2int = {ch: ii for ii, ch in self.int2char.items()}\n",
"        \n",
"        # Define quantizers (one for the input, one each for the hidden and cell states)\n",
"        self.quant_x = tq.QuantStub()\n",
"        self.quant_h = nn.ModuleList([tq.QuantStub(), tq.QuantStub()])\n",
"        \n",
"        # Define the LSTM\n",
"        self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers,\n",
"                            dropout=drop_prob, batch_first=True)\n",
"\n",
"        # Define a dropout layer\n",
"        self.dropout = nn.Dropout(drop_prob)\n",
"\n",
"        # Define the final, fully-connected output layer\n",
"        self.fc = nn.Linear(n_hidden, len(self.chars))\n",
"        \n",
"        # Define the dequantizer\n",
"        self.dequant = tq.DeQuantStub()\n",
"    \n",
"    def forward(self, x, h):\n",
"        ''' Forward pass through the network.\n",
"            The inputs are `x` and the hidden/cell state tuple `h`. '''\n",
"        # Quantize the inputs\n",
"        qx = self.quant_x(x)\n",
"        qh = [self.quant_h[0](h[0]), self.quant_h[1](h[1])]\n",
"        \n",
"        # Get the outputs and the new hidden state from the LSTM\n",
"        r_output, h = self.lstm(qx, qh)\n",
"        \n",
"        # Pass through a dropout layer\n",
"        out = self.dropout(r_output)\n",
"        \n",
"        # Stack up LSTM outputs using view\n",
"        out = out.contiguous().view(-1, self.n_hidden)\n",
"        \n",
"        # Put the result through the fully-connected layer\n",
"        out = self.fc(out)\n",
"        \n",
"        out = self.dequant(out)\n",
"        h = self.dequant(h[0]), self.dequant(h[1])\n",
"        # return the final output and the hidden state\n",
"        return out, h\n",
"    \n",
"\n",
"def init_hidden(model, batch_size, device='cpu'):\n",
"    ''' Initializes hidden state '''\n",
"    # Create two new tensors with sizes n_layers x batch_size x n_hidden,\n",
"    # initialized to zero, for hidden state and cell state of LSTM\n",
"    hidden = (\n",
"        torch.zeros(model.n_layers, batch_size, model.n_hidden, device=device),\n",
"        torch.zeros(model.n_layers, batch_size, model.n_hidden, device=device),\n",
"    )\n",
"    return hidden"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CharRNN(\n",
" (quant_x): QuantStub()\n",
" (quant_h): ModuleList(\n",
" (0): QuantStub()\n",
" (1): QuantStub()\n",
" )\n",
" (lstm): LSTM(108, 512, num_layers=2, batch_first=True, dropout=0.5)\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" (fc): Linear(in_features=512, out_features=108, bias=True)\n",
" (dequant): DeQuantStub()\n",
")\n"
]
}
],
"source": [
"# Define and print the net\n",
"n_hidden=512\n",
"n_layers=2\n",
"\n",
"net = CharRNN(chars, n_hidden, n_layers)\n",
"print(net)"
]
},
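{
"cell_type": "markdown",
"metadata": {},
"source": [
"The printed model contains `LSTM(108, 512, num_layers=2)` followed by `Linear(512, 108)`. As a rough sketch of how large that is, and how much weight storage 8-bit quantization could save, the helpers below count parameters using the standard `nn.LSTM` layout: each layer stacks its four gates into `weight_ih_l{k}` of shape `(4*hidden_size, input)` and `weight_hh_l{k}` of shape `(4*hidden_size, hidden_size)`, plus two bias vectors of length `4*hidden_size`. The helper names are hypothetical; you can cross-check the total against `sum(p.numel() for p in net.parameters())`.\n",
"\n",
"```python\n",
"def lstm_param_count(input_size, hidden_size, num_layers):\n",
"    total = 0\n",
"    for layer in range(num_layers):\n",
"        # Layer 0 sees the one-hot input; deeper layers see the previous hidden state\n",
"        in_features = input_size if layer == 0 else hidden_size\n",
"        total += 4 * hidden_size * (in_features + hidden_size)  # weight_ih + weight_hh\n",
"        total += 2 * 4 * hidden_size                            # bias_ih + bias_hh\n",
"    return total\n",
"\n",
"\n",
"def linear_param_count(in_features, out_features):\n",
"    # Weight matrix plus bias vector\n",
"    return in_features * out_features + out_features\n",
"\n",
"\n",
"vocab_size, hidden, layers = 108, 512, 2  # values from the printed model\n",
"n_params = lstm_param_count(vocab_size, hidden, layers) + linear_param_count(hidden, vocab_size)\n",
"\n",
"# Rough weight storage: 4 bytes per float32 value vs. 1 byte per int8 value\n",
"# (ignores quantization parameters and any tensors kept in float)\n",
"print(n_params)\n",
"print(f'float32: {n_params * 4 / 2**20:.1f} MiB, int8: {n_params / 2**20:.1f} MiB')\n",
"```"
]
},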
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Declaring the train method\n",
"def train(net, data, epochs=10, batch_size=10, seq_length=50,\n",
"          lr=0.001, clip=5, val_frac=0.1, print_every=10, device='cpu'):\n",
"    ''' Training a network\n",
"    \n",
"        Arguments\n",
"        ---------\n",
"        \n",
"        net: CharRNN network\n",
"        data: text data to train the network\n",
"        epochs: Number of epochs to train\n",
"        batch_size: Number of mini-sequences per mini-batch, aka batch size\n",
"        seq_length: Number of character steps per mini-batch\n",
"        lr: learning rate\n",
"        clip: gradient clipping\n",
"        val_frac: Fraction of data to hold out for validation\n",
"        print_every: Number of steps for printing training and validation loss (currently unused)\n",
"        device: Device to work on\n",
"    \n",
"    '''\n",
"    spinner = itertools.cycle(['|', '/', '-', '\\\\']) # A \"busy\" spinner\n",
"    \n",
"    old_training = net.training\n",
"    net.train()\n",
"    \n",
"    opt = torch.optim.Adam(net.parameters(), lr=lr)\n",
"    criterion = nn.CrossEntropyLoss()\n",
"    \n",
"    # create training and validation data\n",
"    val_idx = int(len(data)*(1-val_frac))\n",
"    data, val_data = data[:val_idx], data[val_idx:]\n",
"    \n",
"    net.to(device)\n",
"    \n",
"    n_chars = len(net.chars)\n",
"    n_batches = len(data) // (batch_size * seq_length)\n",
"    dot_every = n_batches // 100 + 1\n",
"    \n",
"    best_model = {\n",
"        'epoch': -1,\n",
"        'train_loss': float('inf'),\n",
"        'val_loss': float('inf'),\n",
"        'model': None\n",
"    }\n",
"    \n",
"    history = {\n",
"        'epoch': [],\n",
"        'train_loss': [],\n",
"        'val_loss': [],\n",
"    }\n",
"\n",
"    for e in range(epochs):\n",
"        print(f'{e+1:>03}/{epochs:>03}', end='', flush=True)\n",
"        history['epoch'].append(e)\n",
"        # initialize hidden state\n",
"        h = init_hidden(net, batch_size, device)\n",
"        \n",
"        # Train Phase\n",
"        counter = 0\n",
"        running_loss = 0.0\n",
"        net.train()\n",
"        for x, y in get_batches(data, batch_size, seq_length):\n",
"            counter += 1\n",
"            \n",
"            # One-hot encode our data and make them Torch tensors\n",
"            x = one_hot_encode(x, n_chars)\n",
"            inputs, targets = torch.from_numpy(x), torch.from_numpy(y)\n",
"            inputs, targets = inputs.to(device), targets.to(device)\n",
"\n",
"            # Creating new variables for the hidden state, otherwise\n",
"            # we'd backprop through the entire training history\n",
"            h = tuple([each.data for each in h])\n",
"\n",
"            # zero accumulated gradients\n",
"            net.zero_grad()\n",
"            \n",
"            # get the output from the model\n",
"            output, h = net(inputs, h)\n",
"            \n",
"            # calculate the loss and perform backprop\n",
"            loss = criterion(output, targets.view(batch_size*seq_length).long())\n",
"            loss.backward()\n",
"            # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.\n",
"            nn.utils.clip_grad_norm_(net.parameters(), clip)\n",
"            opt.step()\n",
"            \n",
"            running_loss += loss.item()\n",
"            \n",
"            if counter % (10 * dot_every) == 0:\n",
"                print(':', end='', flush=True)\n",
"            elif counter % dot_every == 0:\n",
"                print('.', end='', flush=True)\n",
"        loss = running_loss / counter\n",
"        history['train_loss'].append(loss)\n",
"        \n",
"        print()\n",
"        # Eval Phase\n",
"        val_h = init_hidden(net, batch_size, device)\n",
"        val_losses = []\n",
"        running_loss = 0.0\n",
"        counter = 0\n",
"        net.eval()\n",
"        for x, y in get_batches(val_data, batch_size, seq_length):\n",
"            sys.stdout.write('\\r\\033[K')\n",
"            print(next(spinner) + ' Validating', end='', flush=True)\n",
"            # One-hot encode our data and make them Torch tensors\n",
"            x = one_hot_encode(x, n_chars)\n",
"            x, y = torch.from_numpy(x), torch.from_numpy(y)\n",
"\n",
"            # Detach the hidden state carried over from the previous batch\n",
"            # so the autograd graph does not keep growing across batches\n",
"            val_h = tuple([each.data for each in val_h])\n",
"\n",
"            inputs, targets = x.to(device), y.to(device)\n",
"\n",
"            output, val_h = net(inputs, val_h)\n",
"            val_loss = criterion(output, targets.view(batch_size*seq_length).long())\n",
"\n",
"            val_losses.append(val_loss.item())\n",
"        val_loss = np.mean(val_losses)\n",
"        history['val_loss'].append(val_loss)\n",
"        \n",
"        if val_loss <= best_model['val_loss']:\n",
"            best_model.update({\n",
"                'epoch': e,\n",
"                'train_loss': loss,\n",
"                'val_loss': val_loss,\n",
"                'model': copy.deepcopy(net)\n",
"            })\n",
"\n",
"        sys.stdout.write('\\r\\033[K')\n",
"        print(f\"\\r\\033[K\\tTrain Loss: {loss:.4f}; Val Loss: {val_loss:.4f}\")\n",
"        \n",
"    net.train(old_training)\n",
"    return best_model, history"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"001/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 3.0595; Val Loss: 2.3688\n",
"002/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 2.1082; Val Loss: 1.9720\n",
"003/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.8323; Val Loss: 1.8137\n",
"004/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.6878; Val Loss: 1.7265\n",
"005/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.5914; Val Loss: 1.6699\n",
"006/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.5226; Val Loss: 1.6317\n",
"007/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.4717; Val Loss: 1.6036\n",
"008/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.4326; Val Loss: 1.5848\n",
"009/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.4015; Val Loss: 1.5675\n",
"010/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.3763; Val Loss: 1.5584\n",
"011/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.3558; Val Loss: 1.5483\n",
"012/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.3378; Val Loss: 1.5427\n",
"013/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.3220; Val Loss: 1.5299\n",
"014/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.3089; Val Loss: 1.5227\n",
"015/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2966; Val Loss: 1.5197\n",
"016/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2860; Val Loss: 1.5135\n",
"017/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2760; Val Loss: 1.5106\n",
"018/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2673; Val Loss: 1.5089\n",
"019/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2591; Val Loss: 1.5074\n",
"020/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2517; Val Loss: 1.5037\n",
"021/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2446; Val Loss: 1.5053\n",
"022/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2381; Val Loss: 1.5049\n",
"023/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2324; Val Loss: 1.4982\n",
"024/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2269; Val Loss: 1.5017\n",
"025/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2220; Val Loss: 1.5013\n",
"026/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2168; Val Loss: 1.4991\n",
"027/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2120; Val Loss: 1.4999\n",
"028/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2074; Val Loss: 1.4988\n",
"029/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.2030; Val Loss: 1.4979\n",
"030/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1995; Val Loss: 1.5026\n",
"031/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1955; Val Loss: 1.5009\n",
"032/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1922; Val Loss: 1.5033\n",
"033/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1889; Val Loss: 1.5028\n",
"034/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1853; Val Loss: 1.5030\n",
"035/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1820; Val Loss: 1.5055\n",
"036/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1791; Val Loss: 1.5001\n",
"037/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1764; Val Loss: 1.5048\n",
"038/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1733; Val Loss: 1.5088\n",
"039/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1704; Val Loss: 1.5067\n",
"040/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1678; Val Loss: 1.5100\n",
"041/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1653; Val Loss: 1.5169\n",
"042/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1628; Val Loss: 1.5155\n",
"043/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1605; Val Loss: 1.5194\n",
"044/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1581; Val Loss: 1.5144\n",
"045/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1555; Val Loss: 1.5217\n",
"046/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1535; Val Loss: 1.5199\n",
"047/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1514; Val Loss: 1.5230\n",
"048/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1490; Val Loss: 1.5241\n",
"049/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1469; Val Loss: 1.5237\n",
"050/050.........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"\u001b[K\tTrain Loss: 1.1451; Val Loss: 1.5242\n"
]
}
],
"source": [
"# Declaring the hyperparameters\n",
"batch_size = 128\n",
"seq_length = 100\n",
"n_epochs = 50 # start smaller if you are just testing initial behavior\n",
"\n",
"# train the model\n",
"best_model, history = train(net, encoded, epochs=n_epochs, batch_size=batch_size, seq_length=seq_length,\n",
" lr=0.001, print_every=50, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(history['epoch'], history['train_loss'], label='Train')\n",
"plt.plot(history['epoch'], history['val_loss'], label='Val')\n",
"\n",
"min_loss_epoch = np.argmin(history['val_loss'])\n",
"min_loss = history['val_loss'][min_loss_epoch]\n",
"plt.scatter([min_loss_epoch], [min_loss], c='g', label='Min Loss')\n",
"plt.annotate(f'({min_loss_epoch}, {min_loss:.2f})', (min_loss_epoch, min_loss))\n",
"\n",
"plt.legend()\n",
"plt.grid()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Saving the model\n",
"model_name = r'rnn_best_model.pt'\n",
"if isinstance(best_model['model'], nn.Module):\n",
" net = best_model['model']\n",
" best_model['model'] = best_model['model'].state_dict()\n",
"with open(model_name, 'wb') as f:\n",
" torch.save(best_model, f)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Defining a method to generate the next character\n",
"def predict(net, char, h=None, top_k=None, device='cpu'):\n",
" ''' Given a character, predict the next character.\n",
" Returns the predicted character and the hidden state.\n",
" '''\n",
" # tensor inputs\n",
" x = np.array([[net.char2int[char]]])\n",
" x = one_hot_encode(x, len(net.chars))\n",
" inputs = torch.from_numpy(x)\n",
" \n",
" inputs = inputs.to(device)\n",
" \n",
" # detach hidden state from history\n",
" h = tuple([each.data for each in h])\n",
" # get the output of the model\n",
" out, h = net(inputs, h)\n",
"\n",
" # get the character probabilities\n",
" p = F.softmax(out, dim=1).data\n",
" p = p.cpu() # move to cpu\n",
" \n",
" # get top characters\n",
" if top_k is None:\n",
" top_ch = np.arange(len(net.chars))\n",
" else:\n",
" p, top_ch = p.topk(top_k)\n",
" top_ch = top_ch.numpy().squeeze()\n",
" \n",
" # select the likely next character with some element of randomness\n",
" p = p.numpy().squeeze()\n",
" char = np.random.choice(top_ch, p=p/p.sum())\n",
" \n",
" # return the encoded value of the predicted char and the hidden state\n",
" return net.int2char[char], h\n",
" \n",
"# Declaring a method to generate new text\n",
"def sample(net, size, prime='The', top_k=None, device='cpu'):\n",
" net = net.to(device)\n",
" \n",
" net.eval() # eval mode\n",
" \n",
" # First off, run through the prime characters\n",
" chars = [ch for ch in prime]\n",
" h = init_hidden(net, 1, device)\n",
" for ch in prime:\n",
" char, h = predict(net, ch, h, top_k=top_k, device=device)\n",
"\n",
" chars.append(char)\n",
" \n",
" # Now pass in the previous character and get a new one\n",
" for ii in range(size):\n",
" char, h = predict(net, chars[-1], h, top_k=top_k, device=device)\n",
" chars.append(char)\n",
"\n",
" return ''.join(chars)"
]
},
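{
"cell_type": "markdown",
"metadata": {},
"source": [
"`predict` uses top-k sampling: it applies a softmax to the model output, keeps the `top_k` most likely characters, renormalizes their probabilities, and draws one at random. The NumPy-only sketch below shows the same idea on made-up logits, independent of the model; `sample_top_k` is a hypothetical helper name.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def sample_top_k(logits, k, rng):\n",
"    # Softmax over the logits (subtracting the max keeps exp() from overflowing)\n",
"    p = np.exp(logits - logits.max())\n",
"    p = p / p.sum()\n",
"    # Keep the k most likely indices and renormalize their probabilities\n",
"    top = np.argsort(p)[::-1][:k]\n",
"    top_p = p[top] / p[top].sum()\n",
"    # Draw one index from the truncated distribution\n",
"    return int(rng.choice(top, p=top_p))\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"logits = np.array([2.0, 1.0, 0.5, -1.0, -3.0])\n",
"draws = [sample_top_k(logits, k=2, rng=rng) for _ in range(1000)]\n",
"print(sorted(set(draws)))  # only the two most likely indices can appear\n",
"```\n",
"\n",
"With `k=1` this reduces to greedy decoding (always the most likely character); a larger `k`, or `top_k=None` in `predict` (sampling from the full distribution), allows more varied text."
]
},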
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testary\n",
"\n",
"A TOMON.\n",
"Then stays not where to make his honour beat\n",
"A trooping thoughts of season to the world,\n",
"And hath been he is brings against the womb\n",
"As thou wert letter; and thou wert to saw\n",
"To the sea such a second conscience,\n",
"To hear thee said; and thou stand strives on me.\n",
"\n",
" Enter a Servant.\n",
"\n",
"SIR TOBY.\n",
"He shall have the worse the world.\n",
"\n",
"HAMLET.\n",
"And will the streets be taken, she is not so? Who came to the\n",
"minds of a sun, who will share make a good best and time? Alas, this\n",
"whole head and all as we will be a commonwealth of the minds and all\n",
"hers. This stale is a star there a song where you speak truly; and we are\n",
"not he hath shall be his. I’ll see you here welcome. His more\n",
"welp in them, I shall have some state as the children of a cold.\n",
"\n",
"SIR TOBY.\n",
"Why, they shall have my blessing will not then, but I’ll shall set it to\n",
"him. This will send them to the moon and the morning. This shall have\n",
"a couparine of the world.\n",
"\n",
"HAMLET.\n",
"I have not dead, and see thou art a chain that holding thy day\n"
]
}
],
"source": [
"# Generating new text\n",
"print(sample(net, 1000, prime='Test', top_k=5, device=device))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quantizing the model"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.quantization as tq"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CharRNN(\n",
" (quant_x): QuantStub()\n",
" (quant_h): ModuleList(\n",
" (0): QuantStub()\n",
" (1): QuantStub()\n",
" )\n",
" (lstm): LSTM(108, 512, num_layers=2, batch_first=True, dropout=0.5)\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" (fc): Linear(in_features=512, out_features=108, bias=True)\n",
" (dequant): DeQuantStub()\n",
")\n"
]
}
],
"source": [
"model_name = r'rnn_best_model.pt'\n",
"best_model = torch.load(model_name, map_location='cpu')\n",
"# Infer n_layers and n_hidden from the LSTM parameter names (e.g. 'lstm.weight_hh_l0')\n",
"n_layers = 0\n",
"n_hidden = 0\n",
"for key, tensor in best_model['model'].items():\n",
"    key = key.split('_')\n",
"    layer_type, key_type = key[0].split('.')\n",
"    if layer_type == 'fc':\n",
"        continue\n",
"    \n",
"    layer_idx = int(key[2][1])\n",
"    n_layers = max(n_layers, layer_idx)\n",
"    \n",
"    if key_type == 'weight' and key[1] == 'hh':\n",
"        n_hidden = tensor.shape[1]\n",
"n_layers += 1\n",
"\n",
"# `chars` must be the same vocabulary (in the same order) that was used for training\n",
"net = CharRNN(chars, n_hidden, n_layers).cpu()\n",
"net.load_state_dict(best_model['model'])\n",
"print(net)"
]
},
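{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above relies on the parameter names `nn.LSTM` registers (`weight_ih_l{k}`, `weight_hh_l{k}`, `bias_ih_l{k}`, `bias_hh_l{k}`). Because `int(key[2][1])` reads a single character, it would misread layer indices of 10 or more. A slightly more robust variant, sketched below under the same naming assumption, matches the hidden-to-hidden weights with a regular expression and reads the hidden size from their second dimension. To run without the checkpoint it uses a toy dictionary of shapes as a hypothetical stand-in for `best_model['model']`; with the real state dict you would read `tensor.shape` instead.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"# Toy stand-in for best_model['model']: parameter name -> tensor shape\n",
"# (shapes of a 2-layer nn.LSTM(108, 512) followed by nn.Linear(512, 108))\n",
"state_shapes = {\n",
"    'lstm.weight_ih_l0': (2048, 108),\n",
"    'lstm.weight_hh_l0': (2048, 512),\n",
"    'lstm.bias_ih_l0': (2048,),\n",
"    'lstm.bias_hh_l0': (2048,),\n",
"    'lstm.weight_ih_l1': (2048, 512),\n",
"    'lstm.weight_hh_l1': (2048, 512),\n",
"    'lstm.bias_ih_l1': (2048,),\n",
"    'lstm.bias_hh_l1': (2048,),\n",
"    'fc.weight': (108, 512),\n",
"    'fc.bias': (108,),\n",
"}\n",
"\n",
"# Matches e.g. lstm.weight_hh_l0 and lstm.weight_hh_l12, capturing the layer index\n",
"pattern = re.compile(r'^lstm\\.weight_hh_l(\\d+)$')\n",
"\n",
"layer_ids = []\n",
"hidden_sizes = set()\n",
"for name, shape in state_shapes.items():\n",
"    match = pattern.match(name)\n",
"    if match:\n",
"        layer_ids.append(int(match.group(1)))\n",
"        # weight_hh has shape (4 * hidden_size, hidden_size)\n",
"        hidden_sizes.add(shape[1])\n",
"\n",
"n_layers = max(layer_ids) + 1\n",
"n_hidden = hidden_sizes.pop()\n",
"print(n_layers, n_hidden)  # 2 512\n",
"```"
]
},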
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the LSTM is handled as a custom module, you need to set up a configuration that tells `prepare` and `convert` which class to swap in for it."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"custom_module_config = {\n",
"    'float_to_observed_custom_module_class': {\n",
"        # Swapped in by `prepare`: float nn.LSTM -> observed quantizable LSTM\n",
"        torch.nn.LSTM: torch.nn.quantizable.LSTM,\n",
"    },\n",
"    'observed_to_quantized_custom_module_class': {\n",
"        # Used by `convert` once calibration (observation) is complete\n",
"        torch.nn.quantizable.LSTM: torch.nn.quantizable.LSTM\n",
"    }\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The utility below temporarily switches the quantized engine (`fbgemm` targets x86 CPUs, `qnnpack` targets ARM/mobile) and restores the previous engine on exit."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"qengine = 'fbgemm'\n",
"\n",
"@contextmanager\n",
"def override_quantized_engine(qengine):\n",
"    previous = torch.backends.quantized.engine\n",
"    torch.backends.quantized.engine = qengine\n",
"    try:\n",
"        if qengine == 'qnnpack':\n",
"            torch._C._set_default_mobile_cpu_allocator()\n",
"        yield\n",
"    finally:\n",
"        if qengine == 'qnnpack':\n",
"            torch._C._unset_default_mobile_cpu_allocator()\n",
"        torch.backends.quantized.engine = previous"
]
},
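{
"cell_type": "markdown",
"metadata": {},
"source": [
"`override_quantized_engine` follows a save, set, restore pattern: the `finally` clause puts the previous engine back even if the body raises. The pure-Python sketch below demonstrates that behaviour with a hypothetical `settings` dictionary standing in for `torch.backends.quantized.engine`, so it runs without PyTorch. With the real helper you would wrap engine-specific work as `with override_quantized_engine(qengine): ...`.\n",
"\n",
"```python\n",
"from contextlib import contextmanager\n",
"\n",
"# Hypothetical global setting standing in for torch.backends.quantized.engine\n",
"settings = {'engine': 'fbgemm'}\n",
"\n",
"\n",
"@contextmanager\n",
"def override_setting(key, value):\n",
"    previous = settings[key]\n",
"    settings[key] = value\n",
"    try:\n",
"        yield\n",
"    finally:\n",
"        # Runs on normal exit and when the body raises\n",
"        settings[key] = previous\n",
"\n",
"\n",
"with override_setting('engine', 'qnnpack'):\n",
"    inside = settings['engine']\n",
"\n",
"try:\n",
"    with override_setting('engine', 'qnnpack'):\n",
"        raise RuntimeError('simulated failure')\n",
"except RuntimeError:\n",
"    pass\n",
"\n",
"print(inside, settings['engine'])  # qnnpack fbgemm\n",
"```"
]
},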
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CharRNN(\n",
" (quant_x): QuantStub(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (quant_h): ModuleList(\n",
" (0): QuantStub(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (1): QuantStub(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" )\n",
" (lstm): QuantizableLSTM(\n",
" (layers): ModuleList(\n",
" (0): _LSTMLayer(\n",
" (layer_fw): _LSTMSingleLayer(\n",
" (cell): QuantizableLSTMCell(\n",
" (igates): Linear(\n",
" in_features=108, out_features=2048, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (hgates): Linear(\n",
" in_features=512, out_features=2048, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (gates): FloatFunctional(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (fgate_cx): FloatFunctional(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (igate_cgate): FloatFunctional(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (fgate_cx_igate_cgate): FloatFunctional(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (ogate_cy): FloatFunctional(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (1): _LSTMLayer(\n",
" (layer_fw): _LSTMSingleLayer(\n",
" (cell): QuantizableLSTMCell(\n",
" (igates): Linear(\n",
" in_features=512, out_features=2048, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (hgates): Linear(\n",
" in_features=512, out_features=2048, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (gates): FloatFunctional(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (fgate_cx): FloatFunctional(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (igate_cgate): FloatFunctional(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (fgate_cx_igate_cgate): FloatFunctional(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (ogate_cy): FloatFunctional(\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" (fc): Linear(\n",
" in_features=512, out_features=108, bias=True\n",
" (activation_post_process): HistogramObserver()\n",
" )\n",
" (dequant): DeQuantStub()\n",
")\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/zafar/Git/pytorch-dev/pytorch/torch/quantization/observer.py:123: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.\n",
" reduce_range will be deprecated in a future release of PyTorch.\"\n",
"/home/zafar/Git/pytorch-dev/pytorch/torch/nn/quantizable/modules/rnn.py:320: UserWarning: dropout option for quantizable LSTM is ignored. If you are training, please, use nn.LSTM version followed by `prepare` step.\n",
" warnings.warn(\"dropout option for quantizable LSTM is ignored. \"\n"
]
}
],
"source": [
"net.qconfig = tq.get_default_qconfig(qengine)\n",
"net_prepared = tq.prepare(net, prepare_custom_config_dict=custom_module_config, inplace=False)\n",
"print(net_prepared)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calibrate"
]
},
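{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calibration feeds representative data through the prepared model so that every observer records the range of the activations it watches; `convert` later turns each recorded range into a `scale` and `zero_point`. For an affine `quint8` tensor with observed range [min, max] (widened to include 0) and code range [q_min, q_max], the parameters are `scale = (max - min) / (q_max - q_min)` and `zero_point = q_min - round(min / scale)`.\n",
"\n",
"The next cell is a minimal sketch of this step on one-hot inputs like the model's. It uses `MinMaxObserver` for simplicity, whereas the default qconfig attaches `HistogramObserver`, which in general chooses its range differently. It also assumes `reduce_range=True`, which the `reduce_range` deprecation warning above suggests is in effect. That setting restricts `quint8` codes to [0, 127], so a [0, 1] input gives `scale` = 1/127 ≈ 0.0079 and `zero_point` = 0, the same values printed for `quant_x` after conversion."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.quantization as tq\n",
"\n",
"# A small batch of one-hot rows, like the model's character inputs (values 0 or 1).\n",
"vocab_size = 108  # matches in_features=108 in the printed model\n",
"x_demo = torch.zeros(4, vocab_size)\n",
"x_demo[torch.arange(4), torch.tensor([0, 5, 17, 42])] = 1.0\n",
"\n",
"# MinMaxObserver is a simplification: the default qconfig uses HistogramObserver.\n",
"obs = tq.MinMaxObserver(dtype=torch.quint8, reduce_range=True)\n",
"obs(x_demo)  # records the running min/max, as each calibration batch does\n",
"obs_scale, obs_zp = obs.calculate_qparams()\n",
"\n",
"# The same affine formula by hand. The range is widened to include 0,\n",
"# and reduce_range=True limits quint8 codes to [0, 127].\n",
"qmin, qmax = 0, 127\n",
"lo = min(x_demo.min().item(), 0.0)\n",
"hi = max(x_demo.max().item(), 0.0)\n",
"manual_scale = (hi - lo) / (qmax - qmin)\n",
"manual_zp = qmin - round(lo / manual_scale)\n",
"\n",
"print(obs_scale.item(), obs_zp.item())  # about 0.00787 and 0\n",
"print(manual_scale, manual_zp)          # about 0.00787 and 0"
]
},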
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calibrating .........:.........:.........:.........:.........:.........:.........:.........:.........:.......\n",
"Calibration Loss: 1.09\n"
]
}
],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"net_calibrated = copy.deepcopy(net_prepared).eval().to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"# Hyperparameters\n",
"batch_size = 128\n",
"seq_length = 100\n",
"n_chars = len(net_calibrated.chars)\n",
"h = init_hidden(net, batch_size, device)\n",
"\n",
"# Split the data and calibrate on the training split only\n",
"val_frac = 0.1\n",
"val_idx = int(len(encoded)*(1-val_frac))\n",
"train_encoded, val_encoded = encoded[:val_idx], encoded[val_idx:]\n",
"\n",
"print('Calibrating', end=' ', flush=True)\n",
"dot_every = 4 # This is for logging and printing\n",
"counter = 0\n",
"running_loss = 0.0\n",
"with torch.no_grad():  # calibration only needs the forward pass\n",
"    for x, y in get_batches(train_encoded, batch_size, seq_length):\n",
"        counter += 1\n",
"\n",
"        x = one_hot_encode(x, n_chars)\n",
"        inputs, targets = torch.from_numpy(x), torch.from_numpy(y)\n",
"        inputs, targets = inputs.to(device), targets.to(device)\n",
"\n",
"        h = tuple([each.data.to(device) for each in h])\n",
"        output, h = net_calibrated(inputs, h)\n",
"\n",
"        loss = criterion(output, targets.view(batch_size*seq_length).long())\n",
"        running_loss += loss.item()\n",
"\n",
"        if counter % (10 * dot_every) == 0:\n",
"            print(':', end='', flush=True)\n",
"        elif counter % dot_every == 0:\n",
"            print('.', end='', flush=True)\n",
"print()\n",
"print(f'Calibration Loss: {running_loss / counter:.2f}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Convert"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CharRNN(\n",
" (quant_x): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)\n",
" (quant_h): ModuleList(\n",
" (0): Quantize(scale=tensor([0.0159]), zero_point=tensor([63]), dtype=torch.quint8)\n",
" (1): Quantize(scale=tensor([2.1360]), zero_point=tensor([85]), dtype=torch.quint8)\n",
" )\n",
" (lstm): QuantizedLSTM(\n",
" (layers): ModuleList(\n",
" (0): _LSTMLayer(\n",
" (layer_fw): _LSTMSingleLayer(\n",
" (cell): QuantizedLSTMCell(\n",
" (igates): QuantizedLinear(in_features=108, out_features=2048, scale=0.04310886934399605, zero_point=64, qscheme=torch.per_channel_affine)\n",
" (hgates): QuantizedLinear(in_features=512, out_features=2048, scale=0.10811269283294678, zero_point=64, qscheme=torch.per_channel_affine)\n",
" (gates): QFunctional(\n",
" scale=0.11146293580532074, zero_point=63\n",
" (activation_post_process): Identity()\n",
" )\n",
" (fgate_cx): QFunctional(\n",
" scale=0.03302617743611336, zero_point=60\n",
" (activation_post_process): Identity()\n",
" )\n",
" (igate_cgate): QFunctional(\n",
" scale=0.015441657043993473, zero_point=64\n",
" (activation_post_process): Identity()\n",
" )\n",
" (fgate_cx_igate_cgate): QFunctional(\n",
" scale=0.04679515212774277, zero_point=59\n",
" (activation_post_process): Identity()\n",
" )\n",
" (ogate_cy): QFunctional(\n",
" scale=0.014196028932929039, zero_point=70\n",
" (activation_post_process): Identity()\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (1): _LSTMLayer(\n",
" (layer_fw): _LSTMSingleLayer(\n",
" (cell): QuantizedLSTMCell(\n",
" (igates): QuantizedLinear(in_features=512, out_features=2048, scale=0.20001915097236633, zero_point=63, qscheme=torch.per_channel_affine)\n",
" (hgates): QuantizedLinear(in_features=512, out_features=2048, scale=0.23398037254810333, zero_point=69, qscheme=torch.per_channel_affine)\n",
" (gates): QFunctional(\n",
" scale=0.2735488712787628, zero_point=71\n",
" (activation_post_process): Identity()\n",
" )\n",
" (fgate_cx): QFunctional(\n",
" scale=2.2412078380584717, zero_point=84\n",
" (activation_post_process): Identity()\n",
" )\n",
" (igate_cgate): QFunctional(\n",
" scale=0.016878196969628334, zero_point=59\n",
" (activation_post_process): Identity()\n",
" )\n",
" (fgate_cx_igate_cgate): QFunctional(\n",
" scale=2.240091562271118, zero_point=85\n",
" (activation_post_process): Identity()\n",
" )\n",
" (ogate_cy): QFunctional(\n",
" scale=0.016556110233068466, zero_point=60\n",
" (activation_post_process): Identity()\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" (fc): QuantizedLinear(in_features=512, out_features=108, scale=0.4080132246017456, zero_point=74, qscheme=torch.per_channel_affine)\n",
" (dequant): DeQuantize()\n",
")\n"
]
}
],
"source": [
"net_calibrated = net_calibrated.cpu()\n",
"net_quantized = tq.convert(net_calibrated, convert_custom_config_dict=custom_module_config, inplace=False)\n",
"print(net_quantized)"
]
},
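{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once converted, a quantized code `q` stands for the real value `scale * (q - zero_point)`. The next cell applies this formula to the two `quant_h` entries printed above. It assumes the [0, 127] code range that `reduce_range=True` gives `quint8`.\n",
"\n",
"By this arithmetic, the first entry spans roughly [-1.00, 1.02] in steps of 0.0159. The second spans roughly [-181.56, 89.71] in steps of 2.136, so values near zero in that second tensor are resolved only coarsely."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Scales and zero points copied from the printed quant_h entries (rounded as displayed).\n",
"printed_qparams = {'quant_h[0]': (0.0159, 63), 'quant_h[1]': (2.1360, 85)}\n",
"qmin, qmax = 0, 127  # calibrated quint8 code range when reduce_range=True\n",
"\n",
"for name, (scale, zp) in printed_qparams.items():\n",
"    # A code q stands for the real value scale * (q - zero_point).\n",
"    lo = scale * (qmin - zp)\n",
"    hi = scale * (qmax - zp)\n",
"    print(f'{name}: step {scale:.4f}, range [{lo:.2f}, {hi:.2f}]')"
]
},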
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validating .........:\n",
"Validation Loss: 4.30\n"
]
}
],
"source": [
"print('Validating', end=' ', flush=True)\n",
"dot_every = 4 # This is for logging and printing\n",
"counter = 0\n",
"running_loss = 0.0\n",
"h = init_hidden(net, batch_size, 'cpu')\n",
"for x, y in get_batches(val_encoded, batch_size, seq_length):\n",
"    counter += 1\n",
"\n",
"    x = one_hot_encode(x, n_chars)\n",
"    inputs, targets = torch.from_numpy(x), torch.from_numpy(y)\n",
"\n",
"    h = tuple([each.data for each in h])\n",
"    output, h = net_quantized(inputs, h)\n",
"\n",
"    loss = criterion(output, targets.view(batch_size*seq_length).long())\n",
"    running_loss += loss.item()\n",
"\n",
"    if counter % (10 * dot_every) == 0:\n",
"        print(':', end='', flush=True)\n",
"    elif counter % dot_every == 0:\n",
"        print('.', end='', flush=True)\n",
"print()\n",
"print(f'Validation Loss: {running_loss / counter:.2f}')"
]
},
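{
"cell_type": "markdown",
"metadata": {},
"source": [
"The quantized model's validation loss (4.30) is far above the calibration loss (1.09), and the sample below consists mostly of the letters `T`, `t`, `n` and `a` and spaces. The two losses are not directly comparable, for two reasons:\n",
"\n",
"- they come from different data splits;\n",
"- they come from different models (the observed float model versus the converted one).\n",
"\n",
"One way to check whether quantization is the cause is to feed the float `net` and `net_quantized` the same batch and measure how far their outputs diverge.\n",
"\n",
"The next cell sketches a signal-to-quantization-noise ratio (SQNR) helper. `sqnr_db` is a name introduced here for illustration, not a PyTorch API. It is demonstrated on synthetic data with `torch.quantize_per_tensor`, and the commented lines show a hypothetical application to the notebook's models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def sqnr_db(reference, approx):\n",
"    # Signal-to-quantization-noise ratio in dB: higher means approx tracks reference more closely.\n",
"    signal = reference.pow(2).sum()\n",
"    noise = (reference - approx).pow(2).sum()\n",
"    return (10 * torch.log10(signal / noise)).item()\n",
"\n",
"# Demo on synthetic data: quantize a float tensor to quint8 and dequantize it again.\n",
"# A local generator avoids resetting the global RNG used by the sampling cell.\n",
"gen = torch.Generator().manual_seed(0)\n",
"ref = torch.randn(1000, generator=gen)\n",
"demo_scale = (ref.max() - ref.min()).item() / 255\n",
"demo_zp = int(round(-ref.min().item() / demo_scale))\n",
"approx = torch.quantize_per_tensor(ref, demo_scale, demo_zp, torch.quint8).dequantize()\n",
"print(f'SQNR: {sqnr_db(ref, approx):.1f} dB')\n",
"\n",
"# Hypothetical use on the notebook's models, feeding both the same batch and hidden state:\n",
"#   h0 = init_hidden(net, batch_size, 'cpu')\n",
"#   out_fp, _ = net.cpu().eval()(inputs, h0)\n",
"#   out_q, _ = net_quantized(inputs, h0)\n",
"#   print(f'SQNR: {sqnr_db(out_fp, out_q):.1f} dB')"
]
},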
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TestTtTt tttn tnTnTn tTaTnt nanaTTa att tt a Tat T tn n tnnant nTn t TtTntTt aaTa n aanTntaTtttT aT ta tttTnataaT t Tn nnt a nat a an tttat an nTTaTaaTtn T T tnn tna aa n TT nn atnn nnT n aa nn na tnt T Ttat nTnana tan nn ttttTTa TTantatn aTa tT nnTT TnTn aTat atnT tTa n taaa att aTTn n t T TTa nataanntTt a attt tnnn t Tat ntn a T ttnat aattnTt Tn nT Tat anTttan naa tatnTnttTnTt n TnnaTnaata nnn n tnatnna ttt Tn TaTTTann naaant taanatTTaaT Tt ttt taa TnTa TntaanTTnaaa naann atn nntTan Ttn nTnttanTn a a Tn t taaattaTTT na tnTTnna taa ta ntnaaTnn nttT nn Tn TanT tn ttn nnn naatnn a nantt ta ta taaT aTTa TTnT TanTanttT ntatT n TTa Ta nn aTnT t n aaaT t Tt aaT TTnaTttt tn nan tna a na nT nt nnat Tt att na T tTtntaTaat T nant anTaTatana ntaa tT aaT ata tTnnTnn TTat naaatnat TtatnTTanTntt t nT aT Ttnnannt nat n Ta T ttn anTntaaTna nn nnttt t na n a a aaanTt n a tTT t t tnnaTTTaa aa t TtaattTn Taan aTnn TTtTTn a \n"
]
}
],
"source": [
"# Generating new text\n",
"print(sample(net_quantized, 1000, prime='Test', top_k=5, device='cpu'))"
]
}
],
"metadata": {
"accelerator": "TPU",
"colab": {
"collapsed_sections": [
"N6ZDpd9XzFeN"
],
"name": "Predict Shakespeare with Cloud TPUs and Keras",
"provenance": [],
"version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 1
}