Skip to content

Instantly share code, notes, and snippets.

@NTT123
Created November 25, 2020 17:47
Show Gist options
  • Save NTT123/fc7bc21821bc0d73849e43c8b3296b71 to your computer and use it in GitHub Desktop.
Save NTT123/fc7bc21821bc0d73849e43c8b3296b71 to your computer and use it in GitHub Desktop.
Generate Novel Names.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "Generate Novel Names.ipynb",
"provenance": [],
"collapsed_sections": [
"Tf_Tk5mg7LxY",
"PdObT-3j9_gp"
],
"authorship_tag": "ABX9TyM1oeDRjwe13u3vM4Q7QxmH",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"6a1742e2c96347feb0011bdf636679b4": {
"model_module": "@jupyter-widgets/controls",
"model_name": "IntSliderModel",
"state": {
"_view_name": "IntSliderView",
"style": "IPY_MODEL_39d9edbfc62a49c7b6969e9d1c68cf20",
"_dom_classes": [],
"description": "nucleus prob.",
"step": 1,
"_model_name": "IntSliderModel",
"orientation": "horizontal",
"max": 100,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 80,
"_view_count": null,
"disabled": false,
"_view_module_version": "1.5.0",
"min": 10,
"continuous_update": true,
"readout_format": "d",
"description_tooltip": null,
"readout": true,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_195bcdd694be4315933f442ece2e5f75"
}
},
"39d9edbfc62a49c7b6969e9d1c68cf20": {
"model_module": "@jupyter-widgets/controls",
"model_name": "SliderStyleModel",
"state": {
"_view_name": "StyleView",
"handle_color": null,
"_model_name": "SliderStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"195bcdd694be4315933f442ece2e5f75": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"38dc6ef373c1489d9ba577a4cfdca079": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ButtonModel",
"state": {
"_view_name": "ButtonView",
"style": "IPY_MODEL_356a241447fd4d4c857eecf556ff4e40",
"_dom_classes": [],
"description": "Generate!",
"_model_name": "ButtonModel",
"button_style": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"tooltip": "",
"_view_count": null,
"disabled": false,
"_view_module_version": "1.5.0",
"layout": "IPY_MODEL_40c1c4e5cf9f4b2a9df93bceca159210",
"_model_module": "@jupyter-widgets/controls",
"icon": ""
}
},
"356a241447fd4d4c857eecf556ff4e40": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ButtonStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ButtonStyleModel",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"button_color": null,
"font_weight": "",
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"40c1c4e5cf9f4b2a9df93bceca159210": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"ea20d7a3db2b41338cdf788e8b8c8353": {
"model_module": "@jupyter-widgets/output",
"model_name": "OutputModel",
"state": {
"_view_name": "OutputView",
"msg_id": "",
"_dom_classes": [],
"_model_name": "OutputModel",
"outputs": [
{
"output_type": "stream",
"metadata": {
"tags": []
},
"text": "Nguyên Thủy Tiên Tôn\n",
"stream": "stdout"
},
{
"output_type": "stream",
"metadata": {
"tags": []
},
"text": "Đại Việt Chi Thế Giới Đại Đạo\n",
"stream": "stdout"
},
{
"output_type": "stream",
"metadata": {
"tags": []
},
"text": "Tuyệt Thế Thần Đế\n",
"stream": "stdout"
},
{
"output_type": "stream",
"metadata": {
"tags": []
},
"text": "Nhị Thứ Nguyên Chi Thiên Nhãn Thiên Đình\n",
"stream": "stdout"
}
],
"_view_module": "@jupyter-widgets/output",
"_model_module_version": "1.0.0",
"_view_count": null,
"_view_module_version": "1.0.0",
"layout": "IPY_MODEL_f10a02ce5c324eb6a7921bd2f3b34950",
"_model_module": "@jupyter-widgets/output"
}
},
"f10a02ce5c324eb6a7921bd2f3b34950": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/NTT123/fc7bc21821bc0d73849e43c8b3296b71/generate-novel-names.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NVwfamGOHFgp"
},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"metadata": {
"id": "_5q1e_NPsAuP"
},
"source": [
"_=\"\"\" \n",
"\n",
"Tasks:\n",
"\n",
"- [x] Create a dataset of novel names.\n",
" + [x] download html pages.\n",
" + [x] parse html pages to get names.\n",
" + [x] clean up.\n",
" + [x] submit dataset to gist.\n",
"\n",
"- [x] Create RNN model to generate novel names.\n",
" + [x] data loader.\n",
" + [x] model.\n",
" + [x] training.\n",
"\n",
"- [x] Improve model.\n",
" + [x] use dropout.\n",
" + [x] use embedding layer.\n",
" + [x] use EMA.\n",
"\"\"\""
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pspCXo9aEboV",
"outputId": "496c8c02-0784-4618-b78c-27752601b16e"
},
"source": [
"!nvidia-smi | grep Tesla\n",
"!pip install -Uqq git+https://github.com/fadel/pytorch_ema\n",
"# !pip install -Uqq wandb\n",
"# !wandb login"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |\n",
" Building wheel for torch-ema (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Tf_Tk5mg7LxY"
},
"source": [
"## Dataset of novel names"
]
},
{
"cell_type": "code",
"metadata": {
"id": "I_bINHgon9OD"
},
"source": [
"# \"\"\"\n",
"# Crawl data from truyencv.\n",
"# \"\"\"\n",
"\n",
"# import os\n",
"# import tqdm\n",
"\n",
"# for i in tqdm.trange(1, 1029, desc='downloading'):\n",
"# cmd = f'curl https://truyencv.com/danh-sach//trang-{i} -o data/trang-{i}.html'\n",
"# os.system(cmd)\n",
"\n",
"# !zip -q -r novel_names_html.zip data\n",
"# !ls data/*.html | wc -l # expect: 1028"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dkJNDAaPp75C"
},
"source": [
"# \"\"\"\n",
"# Parse the HTML files with the lxml library to get the list of novel names.\n",
"\n",
"# Normalize names.\n",
"# \"\"\"\n",
"\n",
"# import lxml\n",
"# from lxml.html import parse\n",
"# import unicodedata\n",
"# from pprint import pprint\n",
"\n",
"\n",
"# alphabet = \" !%&'()*+,-./0123456789:?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxyz{}~ÀÁÂÉÊÍÐÒÓÔÕÚÝ\\\n",
"# àáâãèéêìíòóôõùúýāĂăĐđĩŌōũūƠơƯưẠạẢảẤấẦầẨẩẫẬậẮắằẳặẹẻẽẾếỀềểỄễỆệỈỉỊịọỏỐốỒồỔổỗỘộỚớờỞởỡợụỦủỨứừửữỰựỲỳỵỶỷỹ\"\n",
"\n",
"# # normalize name\n",
"# def normalize_name(name):\n",
"# replace_list = [('\\xa0', ' '), ('〖', '['), ('〗', ']'), ('–', '-'), ('—', '-'), \n",
"# ( '∶', ':'), ( '♥', ' '), ('\\ufeff', ''), ('’', \"'\"), ('·', '.')]\n",
"\n",
"# name = unicodedata.normalize('NFC', name)\n",
"# for a, b in replace_list:\n",
"# name = name.replace(a, b)\n",
" \n",
"# name = ''.join([c for c in name if c in alphabet])\n",
"# return name\n",
"\n",
"\n",
"# def parse_names():\n",
"# names = []\n",
"# for i in range(1, 1029):\n",
"# root = parse(f'data/trang-{i}.html').getroot()\n",
"\n",
"# for h2 in root.iter('h2'):\n",
"# if 'class' in h2.attrib:\n",
"# if 'title' in h2.attrib['class']:\n",
"# a = next(h2.iter('a')).attrib\n",
"# if 'title' in a:\n",
"# names.append(a['title'])\n",
"\n",
"# names = [name.strip() for name in names]\n",
"# names = sorted(list(set(names)))\n",
"\n",
"# return names\n",
"\n",
"\n",
"# names = parse_names()\n",
"# pprint(names[:10], compact=True)\n",
"# pprint(len(names))\n",
"\n",
"# names = [normalize_name(name) for name in names]\n",
"\n",
"# # print alphabet after normalization\n",
"# text = '\\n'.join(names)\n",
"# pprint(sorted(list(set(text))), compact=True)\n",
"\n",
"# # save to file\n",
"# with open('novel_names.txt', 'w') as f:\n",
"# f.write('\\n'.join(names))"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3usCMaWJ6021",
"outputId": "1f4c1320-b654-424e-af21-c81b52747f70"
},
"source": [
"!git clone https://gist.github.com/NTT123/72d2341a77e0d407eb0dfca0fba52d59 corpus"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"fatal: destination path 'corpus' already exists and is not an empty directory.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PdObT-3j9_gp"
},
"source": [
"## Dataloader"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 0
},
"id": "cE74Ntt0-B3i",
"outputId": "36c9de9a-d879-446f-f121-a12a6c733fad"
},
"source": [
"\"\"\"\n",
"Load data in batches indefinitely.\n",
"\"\"\"\n",
"\n",
"from torch.utils.data import IterableDataset, DataLoader\n",
"import numpy as np\n",
"import random\n",
"import matplotlib.pyplot as plt\n",
"\n",
"class TruyenCVDataset(IterableDataset):\n",
" def __init__(self, filepath: str, batch_size: int=32, mode='train'):\n",
" self.filepath = filepath\n",
" self.PAD = 0\n",
" self.rng = random.Random(42)\n",
" self.batch_size = batch_size\n",
" self.mode = mode\n",
" self.load_text()\n",
"\n",
" def __iter__(self):\n",
" self.rng.shuffle(self.encoded_lines)\n",
" L = len(self.encoded_lines) * 9 // 10\n",
" data = self.encoded_lines[:L] if self.mode == 'train' else self.encoded_lines[L:]\n",
"\n",
" while True:\n",
" self.rng.shuffle(data)\n",
" for i in range(0, len(data) - self.batch_size, self.batch_size):\n",
" batch = np.array(data[i: i+self.batch_size])\n",
" mask = batch != self.PAD\n",
" yield batch, mask\n",
"\n",
" def load_text(self):\n",
" text = open(self.filepath, 'r').read()\n",
" self.alphabet = ['[PAD]', '[BEGIN]', '[END]'] + sorted(list(set(text)))\n",
" self.lines = text.strip().split('\\n')\n",
" self.max_len = max([len(line) for line in self.lines]) + 2 # included [BEGIN] and [END]\n",
" self.encoded_lines = [self.encode_text(t, self.max_len) for t in self.lines]\n",
"\n",
" def encode_text(self, text, max_len):\n",
" text = ['[BEGIN]'] + list(text) + ['[END]']\n",
" pad_len = max(0, max_len - len(text))\n",
" text = text + ['[PAD]'] * pad_len\n",
" return [ self.alphabet.index(c) for c in text if c in self.alphabet ]\n",
"\n",
" def decode_text(self, tokens):\n",
" text = ''.join([self.alphabet[i] for i in tokens])\n",
" text = text.replace('[PAD]', '_')\n",
" return text\n",
"\n",
"def test_dataset():\n",
" dataset = TruyenCVDataset('corpus/novel_names.txt')\n",
" print(dataset.decode_text(dataset.encoded_lines[0]))\n",
" print(dataset.decode_text(dataset.encoded_lines[1]))\n",
" print(dataset.decode_text(dataset.encoded_lines[2])) \n",
"\n",
"\n",
" train_iter = iter(TruyenCVDataset('corpus/novel_names.txt', 32, 'train'))\n",
" test_iter = iter(TruyenCVDataset('corpus/novel_names.txt', 32, 'test'))\n",
" batch, mask = next(test_iter)\n",
" plt.subplot(1, 2, 1)\n",
" plt.matshow(batch, fignum=0)\n",
" plt.subplot(1, 2, 2)\n",
" plt.matshow(mask, fignum=0)\n",
" plt.show()\n",
"test_dataset()"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"[BEGIN](Fairy Tail) Nô lệ Quỷ Dữ[END]___________________________________________________\n",
"[BEGIN](Xuyên Không) Nam Chính À Tránh Xa Tôi Ra.[END]__________________________________\n",
"[BEGIN](ĐỒNG NHÂN THE HOBBIT) TA MUỐN VỀ NHÀ NGA[END]___________________________________\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xsAwdWc6-F3Q"
},
"source": [
"## Model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "dfOkKwClCFOo"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"# from torch.optim.lr_scheduler import OneCycleLR\n",
"import tqdm\n",
"\n",
"class RNN(nn.Module):\n",
" def __init__(self, input_size, hidden_size, output_size, num_layers):\n",
" super(RNN, self).__init__()\n",
" self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)\n",
" self.input_size = input_size\n",
" self.embed = nn.Embedding(input_size, input_size)\n",
" self.embed_dropout = nn.Dropout(0.5)\n",
" self.output = nn.Linear(hidden_size, output_size)\n",
" self.dropout = nn.Dropout(0.5)\n",
" self.softmax = nn.LogSoftmax(dim=-1)\n",
"\n",
" def forward(self, inputs):\n",
" inputs = self.embed(inputs)\n",
" inputs = self.embed_dropout(inputs)\n",
" x, _ = self.lstm(inputs)\n",
" x = self.dropout(x)\n",
" x = self.output(x)\n",
" x = self.softmax(x)\n",
" return x"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "cI0eE81q-HS4"
},
"source": [
"## Training loop"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zGrxyu2NVMbZ"
},
"source": [
"from typing import Deque\n",
"# import wandb\n",
"from torch.nn.utils import clip_grad_norm_\n",
"from torch_ema import ExponentialMovingAverage\n",
"\n",
"\n",
"class MovingAverage:\n",
" def __init__(self, max_len): self.deque = Deque(maxlen=max_len)\n",
" def append(self, v): self.deque.append(v.detach())\n",
" def mean(self): return sum(self.deque).item() / len(self.deque)\n",
"\n",
"\n",
"def train(net, optimizer, lrs, train_iter, val_iter, num_steps, device):\n",
"\n",
" def prepare_batch(x):\n",
" data, mask = x\n",
" data = torch.from_numpy(data).long().to(device)\n",
" mask = torch.from_numpy(mask).long().to(device)\n",
" return data, mask\n",
"\n",
" def loss_fn(data_iter, mode='train'):\n",
" data, mask = prepare_batch(next(data_iter))\n",
" if mode != 'train': \n",
" net.eval()\n",
" with torch.no_grad():\n",
" logits = net(data[:, :-1])\n",
" net.train()\n",
" else:\n",
" logits = net(data[:, :-1])\n",
" \n",
" logits = torch.transpose(logits, 1, 2)\n",
" losses = torch.nn.functional.cross_entropy(logits, data[:, 1:], reduction='none')\n",
" mask = mask[:, 1:]\n",
" losses = losses * mask\n",
" return torch.sum(losses) / torch.sum(mask)\n",
"\n",
" train_loss = MovingAverage(1000)\n",
" val_loss = MovingAverage(500)\n",
"\n",
" trange =tqdm.notebook.trange(num_steps, desc='training')\n",
" ema = ExponentialMovingAverage(net.parameters(), decay=0.999)\n",
"\n",
" for i in trange:\n",
" loss = loss_fn(train_iter, 'train')\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" train_loss.append(loss)\n",
" gn = clip_grad_norm_(net.parameters(), 10.)\n",
" optimizer.step()\n",
" ema.update(net.parameters())\n",
" if lrs is not None: lrs.step()\n",
"\n",
" if i % 10 == 0:\n",
" ema_bk = ExponentialMovingAverage(net.parameters(), decay=0.995)\n",
" ema.copy_to(net.parameters())\n",
" loss = loss_fn(val_iter, 'val')\n",
" ema_bk.copy_to(net.parameters())\n",
" val_loss.append(loss)\n",
" \n",
" if i % 5000 == 0:\n",
" lr = lrs.get_last_lr()[0] if lrs is not None else 1.0\n",
" print(f'step {i} train loss {train_loss.mean():.3f} val loss {val_loss.mean():.3f} lr {lr:.5f} grad norm {gn:.3f}') \n",
" # wandb.log({'train loss': train_loss.mean(), 'val loss': val_loss.mean(), 'lr': lr })\n",
"\n",
" ema.copy_to(net.parameters())"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5Fh6F3VIMsVw"
},
"source": [
"batch_size=32\n",
"device = 'cuda'\n",
"lr = 4e-3\n",
"wd = 1e-1\n",
"num_layers=1\n",
"num_steps = 100_000\n",
"max_len=100\n",
"hidden_size = 512\n",
"train_dataset = TruyenCVDataset('corpus/novel_names.txt', batch_size, 'train')\n",
"val_dataset = TruyenCVDataset('corpus/novel_names.txt', batch_size, 'val')\n",
"vocab_size = len(train_dataset.alphabet)\n",
"net = RNN(vocab_size, hidden_size, vocab_size, num_layers).to(device)\n",
"optimizer = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=wd)\n",
"lrs = None \n",
"# lrs = OneCycleLR(optimizer, lr, total_steps=num_steps)\n",
"train_iter = iter(train_dataset)\n",
"val_iter = iter(val_dataset)\n",
"# wandb.init(project='NovelGen')"
],
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kGTS4zmrURAQ"
},
"source": [
"train(net, optimizer, lrs, train_iter, val_iter, num_steps, device)\n",
"torch.save(net.state_dict(), 'novel_name_model.ckpt')"
],
"execution_count": 14,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "cX8Zt19aDp29"
},
"source": [
"# Generate samples"
]
},
{
"cell_type": "code",
"metadata": {
"id": "-Sv9qY98YMy3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 149,
"referenced_widgets": [
"6a1742e2c96347feb0011bdf636679b4",
"39d9edbfc62a49c7b6969e9d1c68cf20",
"195bcdd694be4315933f442ece2e5f75",
"38dc6ef373c1489d9ba577a4cfdca079",
"356a241447fd4d4c857eecf556ff4e40",
"40c1c4e5cf9f4b2a9df93bceca159210",
"ea20d7a3db2b41338cdf788e8b8c8353",
"f10a02ce5c324eb6a7921bd2f3b34950"
]
},
"cellView": "form",
"outputId": "8bdb053a-5624-43bc-c68b-133d474a53d3"
},
"source": [
"#@title\n",
"from torch.distributions.categorical import Categorical\n",
"\n",
"def nucleus(probs, p=0.9):\n",
" sp, _ = torch.sort(probs, dim=-1)\n",
" csp = torch.cumsum(sp, dim=-1)\n",
" t, _ = torch.max(torch.where(csp <= (1. - p), sp, torch.zeros_like(sp)), dim=-1, keepdim=True)\n",
" return torch.where(probs > t, probs, torch.zeros_like(probs))\n",
"\n",
"\n",
"def sample(net, hidden_size, max_len, vocab_size, device, p=0.9, num_layers=1):\n",
" hx = torch.zeros(num_layers, 1, hidden_size, device=device)\n",
" cx = torch.zeros(num_layers, 1, hidden_size, device=device)\n",
" state = (hx, cx)\n",
" inp = torch.tensor([[1]], dtype=torch.long, device=device)\n",
" outs = [inp]\n",
" net.eval()\n",
" with torch.no_grad():\n",
" for i in range(max_len):\n",
" inp = net.embed(inp)\n",
" out, state = net.lstm(inp, state)\n",
" logits = net.softmax(net.output(out))\n",
" probs = torch.exp(logits)\n",
" probs = nucleus(probs, p)\n",
" token = Categorical(probs=probs).sample()\n",
" inp = token\n",
" outs.append(token)\n",
" if token.item() == 2: break\n",
" net.train()\n",
" outs = torch.cat(outs, dim=0)\n",
" return torch.squeeze(outs)\n",
"\n",
"\n",
"# !gdown -q --id 1I2VyTpXoC7aZz2t9DpPvSaOOeCrD07oz\n",
"device = 'cuda'\n",
"num_layers=1\n",
"hidden_size = 512\n",
"vocab_size = 199\n",
"net = RNN(vocab_size, hidden_size, vocab_size, num_layers).to(device)\n",
"net.load_state_dict(torch.load('/content/novel_name_model.ckpt')) \n",
"\n",
"import ipywidgets as widgets\n",
"slider = widgets.IntSlider(description='nucleus prob.', min=10, max=100, value=80)\n",
"display(slider)\n",
"button = widgets.Button(description=\"Generate!\")\n",
"output = widgets.Output()\n",
"\n",
"def on_button_clicked(b):\n",
" with output:\n",
" s = sample(net, hidden_size, max_len, vocab_size, device, slider.value/100., num_layers)\n",
" s = train_dataset.decode_text(s.tolist()[1:-1])\n",
" print(s)\n",
"\n",
"button.on_click(on_button_clicked)\n",
"display(button, output)"
],
"execution_count": 15,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6a1742e2c96347feb0011bdf636679b4",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"IntSlider(value=80, description='nucleus prob.', min=10)"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "38dc6ef373c1489d9ba577a4cfdca079",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Button(description='Generate!', style=ButtonStyle())"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ea20d7a3db2b41338cdf788e8b8c8353",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Output()"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "rwENPbvIb0k2",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ca46f4dd-e175-44e1-9ade-2abd69b3c8ca"
},
"source": [
"for i in range(10):\n",
" s = sample(net, hidden_size, max_len, vocab_size, device, 0.95, num_layers)\n",
" s = train_dataset.decode_text(s.tolist())\n",
" print(s)"
],
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"text": [
"[BEGIN]Đô Thị Chi Truyền Thuyết[END]\n",
"[BEGIN]Chư Thiên Thế Giới Đế Quốc[END]\n",
"[BEGIN]Tà Tuyến Trọng Sinh[END]\n",
"[BEGIN]Họa Thượng[END]\n",
"[BEGIN]Chư Thiên Vạn Giới[END]\n",
"[BEGIN]Bất Hủ Thần Vương[END]\n",
"[BEGIN]Vô Hạn Saukeo[END]\n",
"[BEGIN]Tu Chân Nữ Nhi Diễn Động[END]\n",
"[BEGIN]Thiên Sủng Nữ Phụ Là Ngươi Giả[END]\n",
"[BEGIN]Tiên Viên Đại Thần[END]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "C_3gwjidhBV0"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment