Created
November 25, 2020 17:47
-
-
Save NTT123/fc7bc21821bc0d73849e43c8b3296b71 to your computer and use it in GitHub Desktop.
Generate Novel Names.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"name": "Generate Novel Names.ipynb", | |
"provenance": [], | |
"collapsed_sections": [ | |
"Tf_Tk5mg7LxY", | |
"PdObT-3j9_gp" | |
], | |
"authorship_tag": "ABX9TyM1oeDRjwe13u3vM4Q7QxmH", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"6a1742e2c96347feb0011bdf636679b4": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "IntSliderModel", | |
"state": { | |
"_view_name": "IntSliderView", | |
"style": "IPY_MODEL_39d9edbfc62a49c7b6969e9d1c68cf20", | |
"_dom_classes": [], | |
"description": "nucleus prob.", | |
"step": 1, | |
"_model_name": "IntSliderModel", | |
"orientation": "horizontal", | |
"max": 100, | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": 80, | |
"_view_count": null, | |
"disabled": false, | |
"_view_module_version": "1.5.0", | |
"min": 10, | |
"continuous_update": true, | |
"readout_format": "d", | |
"description_tooltip": null, | |
"readout": true, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_195bcdd694be4315933f442ece2e5f75" | |
} | |
}, | |
"39d9edbfc62a49c7b6969e9d1c68cf20": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "SliderStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"handle_color": null, | |
"_model_name": "SliderStyleModel", | |
"description_width": "", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"195bcdd694be4315933f442ece2e5f75": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"38dc6ef373c1489d9ba577a4cfdca079": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ButtonModel", | |
"state": { | |
"_view_name": "ButtonView", | |
"style": "IPY_MODEL_356a241447fd4d4c857eecf556ff4e40", | |
"_dom_classes": [], | |
"description": "Generate!", | |
"_model_name": "ButtonModel", | |
"button_style": "", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"tooltip": "", | |
"_view_count": null, | |
"disabled": false, | |
"_view_module_version": "1.5.0", | |
"layout": "IPY_MODEL_40c1c4e5cf9f4b2a9df93bceca159210", | |
"_model_module": "@jupyter-widgets/controls", | |
"icon": "" | |
} | |
}, | |
"356a241447fd4d4c857eecf556ff4e40": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ButtonStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "ButtonStyleModel", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"button_color": null, | |
"font_weight": "", | |
"_view_module_version": "1.2.0", | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"40c1c4e5cf9f4b2a9df93bceca159210": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"ea20d7a3db2b41338cdf788e8b8c8353": { | |
"model_module": "@jupyter-widgets/output", | |
"model_name": "OutputModel", | |
"state": { | |
"_view_name": "OutputView", | |
"msg_id": "", | |
"_dom_classes": [], | |
"_model_name": "OutputModel", | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"metadata": { | |
"tags": [] | |
}, | |
"text": "Nguyên Thủy Tiên Tôn\n", | |
"stream": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"metadata": { | |
"tags": [] | |
}, | |
"text": "Đại Việt Chi Thế Giới Đại Đạo\n", | |
"stream": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"metadata": { | |
"tags": [] | |
}, | |
"text": "Tuyệt Thế Thần Đế\n", | |
"stream": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"metadata": { | |
"tags": [] | |
}, | |
"text": "Nhị Thứ Nguyên Chi Thiên Nhãn Thiên Đình\n", | |
"stream": "stdout" | |
} | |
], | |
"_view_module": "@jupyter-widgets/output", | |
"_model_module_version": "1.0.0", | |
"_view_count": null, | |
"_view_module_version": "1.0.0", | |
"layout": "IPY_MODEL_f10a02ce5c324eb6a7921bd2f3b34950", | |
"_model_module": "@jupyter-widgets/output" | |
} | |
}, | |
"f10a02ce5c324eb6a7921bd2f3b34950": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
} | |
} | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/NTT123/fc7bc21821bc0d73849e43c8b3296b71/generate-novel-names.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "NVwfamGOHFgp" | |
}, | |
"source": [ | |
"# Training" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "_5q1e_NPsAuP" | |
}, | |
"source": [ | |
"# Project task checklist, kept as a string assigned to the throwaway name `_`.\n", | |
"_=\"\"\" \n", | |
"\n", | |
"Tasks:\n", | |
"\n", | |
"- [x] Create a dataset of novel names.\n", | |
" + [x] download html pages.\n", | |
" + [x] parse html pages to get names.\n", | |
" + [x] clean up.\n", | |
" + [x] submit dataset to gist.\n", | |
"\n", | |
"- [x] Create RNN model to generate novel names.\n", | |
" + [x] data loader.\n", | |
" + [x] model.\n", | |
" + [x] training.\n", | |
"\n", | |
"- [x] Improve model.\n", | |
" + [x] use dropout.\n", | |
" + [x] use embedding layer.\n", | |
" + [x] use EMA.\n", | |
"\"\"\"" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "pspCXo9aEboV", | |
"outputId": "496c8c02-0784-4618-b78c-27752601b16e" | |
}, | |
"source": [ | |
"# Print the Tesla GPU line from nvidia-smi, then install torch_ema (imported by the training cell).\n", | |
"!nvidia-smi | grep Tesla\n", | |
"!pip install -Uqq git+https://github.com/fadel/pytorch_ema\n", | |
"# !pip install -Uqq wandb\n", | |
"# !wandb login" | |
], | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |\n", | |
" Building wheel for torch-ema (setup.py) ... \u001b[?25l\u001b[?25hdone\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Tf_Tk5mg7LxY" | |
}, | |
"source": [ | |
"## Dataset of novel names" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "I_bINHgon9OD" | |
}, | |
"source": [ | |
"# \"\"\"\n", | |
"# Crawl data from truyencv.\n", | |
"# \"\"\"\n", | |
"\n", | |
"# import os\n", | |
"# import tqdm\n", | |
"\n", | |
"# for i in tqdm.trange(1, 1029, desc='downloading'):\n", | |
"# cmd = f'curl https://truyencv.com/danh-sach//trang-{i} -o data/trang-{i}.html'\n", | |
"# os.system(cmd)\n", | |
"\n", | |
"# !zip -q -r novel_names_html.zip data\n", | |
"# !ls data/*.html | wc -l # expect: 1028" | |
], | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dkJNDAaPp75C" | |
}, | |
"source": [ | |
"# \"\"\"\n", | |
"# Parse html files to get list of novel names. Use lxml library.\n", | |
"\n", | |
"# Normalize names.\n", | |
"# \"\"\"\n", | |
"\n", | |
"# import lxml\n", | |
"# from lxml.html import parse\n", | |
"# import unicodedata\n", | |
"# from pprint import pprint\n", | |
"\n", | |
"\n", | |
"# alphabet = \" !%&'()*+,-./0123456789:?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxyz{}~ÀÁÂÉÊÍÐÒÓÔÕÚÝ\\\n", | |
"# àáâãèéêìíòóôõùúýāĂăĐđĩŌōũūƠơƯưẠạẢảẤấẦầẨẩẫẬậẮắằẳặẹẻẽẾếỀềểỄễỆệỈỉỊịọỏỐốỒồỔổỗỘộỚớờỞởỡợụỦủỨứừửữỰựỲỳỵỶỷỹ\"\n", | |
"\n", | |
"# # normalize name\n", | |
"# def normalize_name(name):\n", | |
"# replace_list = [('\\xa0', ' '), ('〖', '['), ('〗', ']'), ('–', '-'), ('—', '-'), \n", | |
"# ( '∶', ':'), ( '♥', ' '), ('\\ufeff', ''), ('’', \"'\"), ('·', '.')]\n", | |
"\n", | |
"# name = unicodedata.normalize('NFC', name)\n", | |
"# for a, b in replace_list:\n", | |
"# name = name.replace(a, b)\n", | |
" \n", | |
"# name = ''.join([c for c in name if c in alphabet])\n", | |
"# return name\n", | |
"\n", | |
"\n", | |
"# def parse_names():\n", | |
"# names = []\n", | |
"# for i in range(1, 1029):\n", | |
"# root = parse(f'data/trang-{i}.html').getroot()\n", | |
"\n", | |
"# for h2 in root.iter('h2'):\n", | |
"# if 'class' in h2.attrib:\n", | |
"# if 'title' in h2.attrib['class']:\n", | |
"# a = next(h2.iter('a')).attrib\n", | |
"# if 'title' in a:\n", | |
"# names.append(a['title'])\n", | |
"\n", | |
"# names = [name.strip() for name in names]\n", | |
"# names = sorted(list(set(names)))\n", | |
"\n", | |
"# return names\n", | |
"\n", | |
"\n", | |
"# names = parse_names()\n", | |
"# pprint(names[:10], compact=True)\n", | |
"# pprint(len(names))\n", | |
"\n", | |
"# names = [normalize_name(name) for name in names]\n", | |
"\n", | |
"# # print alphabet after normalization\n", | |
"# text = '\\n'.join(names)\n", | |
"# pprint(sorted(list(set(text))), compact=True)\n", | |
"\n", | |
"# # save to file\n", | |
"# with open('novel_names.txt', 'w') as f:\n", | |
"# f.write('\\n'.join(names))" | |
], | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "3usCMaWJ6021", | |
"outputId": "1f4c1320-b654-424e-af21-c81b52747f70" | |
}, | |
"source": [ | |
"!git clone https://gist.github.com/NTT123/72d2341a77e0d407eb0dfca0fba52d59 corpus" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"fatal: destination path 'corpus' already exists and is not an empty directory.\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "PdObT-3j9_gp" | |
}, | |
"source": [ | |
"## Dataloader" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 0 | |
}, | |
"id": "cE74Ntt0-B3i", | |
"outputId": "36c9de9a-d879-446f-f121-a12a6c733fad" | |
}, | |
"source": [ | |
"\"\"\"\n", | |
"Load data in batch infinitely.\n", | |
"\"\"\"\n", | |
"\n", | |
"from torch.utils.data import IterableDataset, DataLoader\n", | |
"import numpy as np\n", | |
"import random\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"class TruyenCVDataset(IterableDataset):\n", | |
" def __init__(self, filepath: str, batch_size: int=32, mode='train'):\n", | |
" self.filepath = filepath\n", | |
" self.PAD = 0\n", | |
" self.rng = random.Random(42)\n", | |
" self.batch_size = batch_size\n", | |
" self.mode = mode\n", | |
" self.load_text()\n", | |
"\n", | |
" def __iter__(self):\n", | |
" self.rng.shuffle(self.encoded_lines)\n", | |
" L = len(self.encoded_lines) * 9 // 10\n", | |
" data = self.encoded_lines[:L] if self.mode == 'train' else self.encoded_lines[L:]\n", | |
"\n", | |
" while True:\n", | |
" self.rng.shuffle(data)\n", | |
" for i in range(0, len(data) - self.batch_size, self.batch_size):\n", | |
" batch = np.array(data[i: i+self.batch_size])\n", | |
" mask = batch != self.PAD\n", | |
" yield batch, mask\n", | |
"\n", | |
" def load_text(self):\n", | |
" text = open(self.filepath, 'r').read()\n", | |
" self.alphabet = ['[PAD]', '[BEGIN]', '[END]'] + sorted(list(set(text)))\n", | |
" self.lines = text.strip().split('\\n')\n", | |
" self.max_len = max([len(line) for line in self.lines]) + 2 # included [BEGIN] and [END]\n", | |
" self.encoded_lines = [self.encode_text(t, self.max_len) for t in self.lines]\n", | |
"\n", | |
" def encode_text(self, text, max_len):\n", | |
" text = ['[BEGIN]'] + list(text) + ['[END]']\n", | |
" pad_len = max(0, max_len - len(text))\n", | |
" text = text + ['[PAD]'] * pad_len\n", | |
" return [ self.alphabet.index(c) for c in text if c in self.alphabet ]\n", | |
"\n", | |
" def decode_text(self, tokens):\n", | |
" text = ''.join([self.alphabet[i] for i in tokens])\n", | |
" text = text.replace('[PAD]', '_')\n", | |
" return text\n", | |
"\n", | |
"def test_dataset():\n", | |
" dataset = TruyenCVDataset('corpus/novel_names.txt')\n", | |
" print(dataset.decode_text(dataset.encoded_lines[0]))\n", | |
" print(dataset.decode_text(dataset.encoded_lines[1]))\n", | |
" print(dataset.decode_text(dataset.encoded_lines[2])) \n", | |
"\n", | |
"\n", | |
" train_iter = iter(TruyenCVDataset('corpus/novel_names.txt', 32, 'train'))\n", | |
" test_iter = iter(TruyenCVDataset('corpus/novel_names.txt', 32, 'test'))\n", | |
" batch, mask = next(test_iter)\n", | |
" plt.subplot(1, 2, 1)\n", | |
" plt.matshow(batch, fignum=0)\n", | |
" plt.subplot(1, 2, 2)\n", | |
" plt.matshow(mask, fignum=0)\n", | |
" plt.show()\n", | |
"test_dataset()" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[BEGIN](Fairy Tail) Nô lệ Quỷ Dữ[END]___________________________________________________\n", | |
"[BEGIN](Xuyên Không) Nam Chính À Tránh Xa Tôi Ra.[END]__________________________________\n", | |
"[BEGIN](ĐỒNG NHÂN THE HOBBIT) TA MUỐN VỀ NHÀ NGA[END]___________________________________\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAABhCAYAAADP5Pq1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAZL0lEQVR4nO2de3RV1bWHvxkSkgAhJBDCU0ClVLC8Fam0orYWtcVb21oftVRFrNVRbS2t2just7293mqxSh3KQxB8S0WKZeCDi9gWpVAUQRR5VawgECRE3nmu+8dcO8k5JOQ8k3PC/MY445y9z9pzr30ys/dvzz3XXOKcwzAMw0g/Mlq6A4ZhGEZs2AncMAwjTbETuGEYRppiJ3DDMIw0xU7ghmEYaYqdwA3DMNKUZjuBi8g4EdkoIltE5PYYtp8tIiUisr7eukIRWSIim/17QYS2eovIMhF5X0TeE5Fb4rSXIyKrRGStt/dffn0/EVnpj/k5EWkbxfG2EZE1IrIoAba2ici7IvKOiKyO81g7icjzIvKBiGwQkdFx2Brg+xS89ovIrbHaaynMt823G7DVPL7tnEv6C2gDbAVOBtoCa4GBUdr4MjAcWF9v3b3A7f7z7cDvIrTVHRjuP+cBm4CBcdgToIP/nAWsBM4C5gGX+/XTgBujON6fAk8Di/xyPLa2AV3C1sV6rHOBif5zW6BTrLYa8JFdQJ9E2Guul/m2+XZL+nZzOflo4JV6y3cAd8Rgp2+Yk28Eutdz3I0x9m8h8NVE2APaAW8Do4BPgcyGfoMmbPQClgLnAYv8P1FMtnz7hpw86mMF8oEPAYnXVgO2LwDeSOTftTle5tvm2xHYTppvN1cIpSfwcb3l7X5dvBQ753b6z7uA4mgNiEhfYBiqLGK2528L3wFKgCWoKitzzlX5JtEc8wPAz4Eav9w5DlsADnhVRN4SkUl+XSzH2g/YAzzmb4EfFZH2MdoK53LgmTj61lKYb5tvN0XSfLvVPMR0ekmLqi6AiHQA5gO3Ouf2x2PPOVftnBuKKowzgc9H05d6ffo6UOKceyuW7RthjHNuOHAhcJOIfLn+l1EcayZ6q/+Ic24YcAi9DYzFVi0+5jke+FP4d7HYa22Ybx+XE9q3m+sEvgPoXW+5l18XL7tFpDuAfy+JdEMRyUId/Cnn3Avx2gtwzpUBy9BbwU4ikum/ivSYzwbGi8g24Fn0VvPBGG0Ffdrh30uABeg/YSzHuh3Y7pxb6ZefR50+3t/tQuBt59xuvxz336EZMd823z4eSfXt5jqB/xPo7582t0VvKV5MgN0XgQn+8wQ03tckIiLALGCDc+7+BNgrEpFO/nMuGnPcgDr7t6Ox55y7wznXyznXF/2dXnPOXRWLLd+f9iKSF3xG43HrieFYnXO7gI9FZIBfdT7wfiy2wriCultMEmCvOTHfNt8+Hsn17XgC6FEG8i9Cn4hvBX4Zw/bPADuBSvRqeR0aP1sKbAb+DyiM0NYY9NZlHfCOf10Uh73BwBpvbz1wl19/MrAK2ILeQmVHecxjqXtSH5Mtv91a/3ov+O3jONahwGp/rH8GCmK15e21B/YC+fXWxWyvJV7m2+bbLeXb4o0ahmEYaUareYhpGIZxomEncMMwjDTFTuCGYRhpip3ADcMw0hQ7gRuGYaQpcZ3AJYYqbPWGu8ZNIm0l2p71reVtxYP5dvPYSrS9E6lvQOx54MRYhQ1YHW/ebTJsWd9Sw16i+xZjH8y3rW8p3zfn4itmdSawxTn3L+dcBTo09pI47BlGqmC+baQFmU03aZSGqrCNOt4GbSXb5dCODgW9HEB1jgDg9I3CggMA7NubB0BWyWH9vmMuANVZ2jCrsELf8wpoV9zbZe7RdvTPpKpMa8EH61ye37at37azblu5L7Qd7XPIzs6nx6BODuDghjYA1OTr9jWZun3X4jIASvZ00u18TbXMTw+FHGsO7egohQkbJZ
VIe621b0c5RIUrlwR0I2rf7lLYxp3UM5ORQ3IS8jsk0lai7G1a1w5ovf6TbHvJ8O14TuAR4eM+k0APYEzGxXw4/QsAVB7Wk+iAG98F4PBZQwA41FVPnlff8hIALw/XiovVQwYCUHDPvwHYNvNzABQ++7ba6306ZeeqzapcPdacUv29CuZpm+qBamP8gtcAeOLei+psCFR2Px2ApUtnATCun/7fHr4wtG/tCtT+qRdtBWDrolMA6DHF18KpqY7qdzLiZ6Vb2qz7q+/bJ/XM5MPVfZt1/y3F13oMbekunHA05tvxhFAiqsLmnJvhnBvpnBuZRXYcuzOMZiNq3y7q3KbZOmcYAfEo8NoqbKhzXw5cGcmGPR7XE/mBXrr7TVP0it5xk/4TdJuu5YJfKLkAgEMTdP2+wRqv6DhPC4Z1e1bb1VfHo69Xpb1ljO6r+ozTANjzg+EA3DP5UQDuuG8iAEVevQftyvqrgh+z7lIAOoT932aM3wvA9IHPAnDTAzcDkH3Easq0ImL2bcNoTmI+gTvnqkTkZuAV9Kn9bOfcewnrmWG0EObbRroQVwzcObcYWBztdln7K/WDr+E+bdxsAO54d2JIu3Y7jgDw8TdUseev9Q8ew9Su+MmYfvGzp5n9BZ0s5F9zVaVfddoqAFacqQ9G79nmS/H20bejf+kOwN4l+oCm51RV9dXva6x832Uaf3c+2FS1uAsAvx5/BgDFNW9GethGGhGrb7cmLNad+thITMMwjDQlrnrgfmqkA0A1UOWcG3m89h2l0I3K+ApV52o8uuxUVdTlPqOjxt8PtNVsQtoc1b4VzdU4dXgmyL5B+n3HrRm17cPblg7IDLH93cteB+D5J8cCdYo73HZ4LD38+8Lvbgdgz0J91hXE7Su+pFksn52ix1adLQ0fyzi1d+R6TUssHK/ZLJa9Ejsr3VL2u9JEpBFG7dsjh+S4Va/0Pl6TEx5T9LHTmG8nIo3wXOfcpwmwYxiphvm2kdIkPQ+8ITIqNZsk91I/z+cCjTMXzQlVz31/8QEAH5SrYg/yvdsP6Kft534IQMU5mld+28NPcnv+tUCdsoZQ5bzoY1XImWNKAdhVPkKXj4Qq5FfP0xh64QId6FO5sE1oH0vUbtZE/f/+6Y91/c+nhtorfig0Lzy438lduMq/N/ozGUZaYMq65Yg3Bu6AV0XkrVQpQmQYCcJ820h54lXgY5xzO0SkK7BERD5wzv2tfoPwkZgAH92karTyk0IA8n2c+NPvq9Lu8riq2X9XDAagupfa2jtfU0eqfJjYLQ1Vu38cNIRuZ6hiDvK+9w1StZ+3VffR/Vsf6bZDdBTnp0N12yd+OQWAnzx+bsgBjizSEdXLs7uGrM/ddRSAS/upkr5/kO6vyO+/dJAOwb9t0zqgLu+8y0xtb7HulCcq3z6pZ4vczKYEr3zyTsiyKfLmIy4F7pzb4d9LgAVoEaDwNjYS00g7ovVtG4lptAQxywYRaQ9kOOcO+M8XAL+OZNuej2mGxv4+WQD0vkozMLbNPyWkXY0vXpWp6eBU/aUzAMU+Dl0+tgMAh7uqnd3Xj6B4psa+23fRGDUuyBpRJb3LqWrv8dwWAIrWbgLghpJb1eZl2v7kCT4LZbAq9R/M0ZTgl6drvL4iX4/hiQcvBKB0qiaj57+nfQmyUqasuVz74UVJwd/zAdj4jOarB/2tuyPQY+o/YWNtm26Pal9qKnz+vKn3pBKPbxvHKvJoMPUeHfHc9xUDC0QksPO0c+7lhPTKMFoW820jLYgrDzxagjzwI5foKMamcqoPXDIMgLyFa4C60ZWFHYPMkCIgNHvlULHaPO+GfwDw0nOjAej1QMP53kFtE8vDTn8SmQceLemQB27qNn1pzLdtJKZhGEaa0mQIRURmA18HSpxzp/t1hcBzQF9gG3CZc25ftDuvzNMLyvhumqkxy1fwDM9GCUYt5i1T1Xw0W+PEbcLuHnI/OcKhrvrda9PPAqDXnFDl/VkfPeQDJ/vKhq
9qbRPJ0uyU8rGq+rdP0Jh2+39o5kyQ6RL0actvtV2Hf2c0+H2wv7JTdH8HB2vWSv4/c4C62HdGri4fGdVf2/XUGPrhYqHPtA0AVH+2Xw/Q7g4SSjJ9OxWJJzadbOzuIDYiUeBzgHFh624Hljrn+gNL/bJhpBtzMN820piIYuAi0hdYVE+lbATGOud2ikh34HXn3ICm7EQbAw+PeR95RXOxu0/T7yvP1lGV265RNX3qhHWmUk9gYomBJ8q3myMGbir1xCXRMfBi59xO/3kX+tS+QURkkoisFpHVlZTHuDvDaDZi8u09e004GM1P3MPHnHNORBqV8c65GcAMoHZCT5ehF5J9X9QJhqu98g6v893vAX3fcY4q715eeZderjHyaj8uqNNyfZesTHZfr3NYZl9UAsDRl3Xb8GqAu24YEbI+79v6P7t3SQ+grpbKzhv9aM/D2q6xu4Xd12u7867R7JfXHtMYfBDr3jNhuD9G/7v4S2dlez9Z8jSbSzPViMa3EzkBcWOkcgw7EuwOIvHEqsB3+9tL/HtJ4rpkGC2K+baRNsSqwF8EJgD/69+jqqknNX6m+BU6mrF0mB/F+H5Yd3y7u657CoDZD+jIxIINB4G6UYulI3T7ojmQt10/f7yrk9r0pro8qjVInFe4xVPDZtKZoW892Kbt/OpuDzTcrpiPQ9p1fUjbrX/IL/NmyPddZqzASAvi8u2WxBTuiUeTClxEngFWAANEZLuIXIc691dFZDPwFb9sGGmF+baR7jSpwJ1zVzTy1fkx77WRaGFVbsPrH96mFQL33ajx6e6PaFy5UwfNQkFUye+6YURtTHvAjWt1V8M0geDCd3XE5cvD9ZlUeJ72X2/9PQBf+8/bgLra47UjN/0Iz9ETdf3fn9KYd7epFrtOV5Li2y1IsmLkpuxTl0gU+GwRKRGR9fXW3S0iO0TkHf+6KLndNIzEY75tpDuRxMDnAA8Bj4et/4Nz7vex7HTXFToqMcjvLqytaRKa/53fQdvdd8qLANxa80OgbqRm0UodIFf0hv7/Hb5wSG1uefVizT7Zs1Dj5IHyDqr+ZfxIn03JYlX1V5yqKr/DlzQz5pktywAYd6ffl6+3snmNbp/p64h//V2dkWfW9IuBY2fgMVKaOSTYt1sjVu87dWlSgfsi9qXN0BfDaFbMt410J5488JtF5PvAauC2aOpF9Ltf1evm72l97y4qtGtzpYNa3MFcl7ffemmDdtxGnRMzUOx5y9rUKuXyj3Tbl2beC8A10zWseaSHBtoPPa/KvGtQJ7xGY9qBgr689xcB6IRmj9SG7Ve9q332E+ssmlGg2xGWrWKkMzH7djpgCrr1EOsJ/BHgN+h57TfAFODahho2NKVaTbbu9u4L5gNw/5bLtDNHGn66uTssJTCc3jO0AFRZ/7oBNcEAmmu+dSMABxce0vfFmSH7yvjqJ9q+5uNGD9Y4oYjJt1N5SjU7YbdeYhrI45zb7Zyrds7VADNpYLqpem1tSjUjbYjVt21KNaMliEk2iEj3evUivgmsP177xrj3MVXeNXm6fP4EDV8szdVh6MEw9VOPDARgp0Y1qPBlaEeu1Iegi2Zn17YPHlIG4ZjyAj8t22LdSXjZVzdSbT87fzpwbBphU5NIHD1fJ14+0Et/ytxLdwNwcHG3kGMIb7fvdC3AVfCeXkNtsuPUIFG+nUqkyhB8uxNIPJHUA38GGAt0EZHtwK+AsSIyFL3N3AbckMQ+GkZSMN820p0WnVKtPE9vOwvmhQ6aqZikiQHhard2UM6svwF1qYHVZ5wGQOmgXO6Z/CgAd9w3MWTbii/50rPf0Hh5/mZVvuGTCl/3pKYszv7C50P61FjpW0sbTB1sSrXkYMq55bEp1QzDMFoZTSpwEemNDnQoRm8rZzjnHoxl6qlAgdeM0Sv6D2dpFspvHvoecOxEDn8YOQ+AX069NuT7cFWMvy6VDquuLYgVbuuCUzYCsHmUFrsyxdz6iFaBJ9K3U0mBm2Jufc
SjwKvQXNiBwFnATSIyEJt6ykh/zLeNtCaSYlY7gZ3+8wER2QD0BC5BHwABzAVeB34RyU6PFmnxqclvfgeAjn59MEQ+GMgzdZhmqVSdE7p97ic6K0IwgXHAgFvW1qrzT27WfPC8ZX4ih9MOAPD6ZO1yrwdUoVefqVkoJcN1gM+BU1SZd9yi6j5Q8ocuVlVz9FoVYlXVeu1zSwuBuuyWzrMsmyRdSIZvpwLJzjoxhZ86RBUD9/MHDgNWEsXUU4aR6phvG+lIxFkoItIB+CvwW+fcCyJS5pzrVO/7fc65gga2qz8Sc8SYjIvZ9mvN887fom2CnOtADW/9jir0QAVbpocRCbFmoSTCt0/qmTniw9V9Y+67qVrjeMSVhSIiWcB84Cnn3At+dURTT9lITCOVSZRv20hMoyWIZCCPALOADc65++t9FfPUU0VrdRTigZ7q9EG8uo0vatV5jY8nP+3j1F/UwlQ7z9ZaKu0/0e+PdNULUnapLte0hUwdMFk7AtPUu9EYyfDtWIknbm3q/cQlkqH0ZwNXA++KSOBld6LOPc9PQ/URcFlyumgYScN820hrIslCWU5tpvUxxDT1VMm3NYukqkQzP4JYd0ZloLxDR15W3lWmGy5VBR6M3Mz2GSdPTlXxNKn/+bVZKFW5anPH5FEAHOqn+d8DbvFTrfmRl8HEyLU1U7Rrx9RMYbC27/gHrV64YaH2refU0HzzPkWapbLzVc0J7nmf3QGkKsnw7ZagpWud2B1Ay2EjMQ3DMNKUSGLgjY1Wuxu4Htjjm97pnFscyU6L5qvy/uxkVckVvhrhkYGqUqtz/ITBPge7eppe4bOLVBWPXqU53c8/qd2f1F/F0qGLh7JjvCrtz92gSjvIbCktywnZ9oXHdKfBBMnhozvPmqTKe4ufbKs6W9dvfUKV+JTJMwG4Y7/WXOl3ZWj+d0+svniqkwzfbs2Y0k49IomBB6PV3haRPOAtEVniv7O5A410xnzbSGviGYkZM233q0rNKdUIzpK7pgB1EwsHanjPD3Rk5thJvk74XM0fX3Gmqueic3QC4pc/1O/H9YM+hzRjJagHvm+wZrx03KShzjeHqfrvVqNToAVZ8Ll/VgWd65c3zwjtc8ZyjTN2Wa7LU6YP0mU/5ZqRfiTDt1MBU8onDvGMxASdO3CdiMwWkWMGOhhGumC+baQj8YzELAY+pW7uwO7OuWPmDmxoJCZnaG3uIAMkyAOoyfK52zMbqVNyslfTW/W6E54pcuCSYQybvAaA5U+ExtHDY9xB1kn3CzRWHdT3DtqHZ6k8eafeJfxkgL9LGKf2br7vOeDY+uFBrfMDJ+l+et9j2SjJJoEjMaP27XhHYrYGTPknj4SPxIx07kAbiWmkMonybRuJabQEMY/EjGfuwCM9VFFX+rkta3wvggqBtWq5WP8p9p+mmSXhtb5LL9c4dxDvRmDLGP3Y+exy/c7H0UdPVJX+xtwRofvaqPtqU6xqfuODQxrc14833ATAmFWqpFecqVkuD8l3ta8TtK8/+skCAP74yKUAtCtpvhmPjOhIhm8bRnMSz0jMK2zuQCPNMd820pp4RmLGnBebUaGqVPzkOFl+9OPeq1QtV7XzytzflYar4fKxmmlS7SMyV9/yEqBzZIbHuufdeR8AN31Oc8U7jdXMlaD2+NHC0LuAxqhqp/YmFWi2ygo/UC+8Nvl///0b2me/nfO/nGTpDqrP0L5XddC5Odu+6kd6Wmy82UmGb5/INDYi1GLjycNGYhqGYaQpkcTAc4C/Adm+/fPOuV+JSD/gWaAz8BZwtXOuIpKdVrbX60Z1jl/2yShlvVWFBoq7+zRV3Ee/MhiA7bdq/DqoRlg0V9Xrw+d9GYALlm9k+ROhav2H798MwP7LNO5eke9rnhxWG0EsPMg6mfX8IwB8643JIX3OPKx9O3vBzwDoeENYJsyjqsy7zGhYSQeR8CCfvG2DrYzmJBm+3ZoxJZ16RKLAy4HznHNDgKHAOBE5C/gdOl
rtVGAfcF3yumkYScF820hrIomBO+CgX8zyLwecB1zp188F7gYeiWSnOXs1+H20QK8fVbmqivPe0DhzwTxVxZVna7746zO17si4fqNC1gfZJ23b7gc09ztQxAFtylURT7pTs0P+NLQPUJdjHtRGWXHmJgCu/PFPtU9ddfvvvPMRAH98RNV/4fpg7kvL6053kuHbrQlT3KlPpHngbfxT+hJgCbAVKHPO+ceQbKeRIcgiMklEVovI6krKE9Fnw0gYifLtPXvtQm40P5GkEeKcqwaGikgnYAHw+Uh34JybAcwA6CiFDsBlquLO+g8t9laxTOVuUOc7PJPk5h2jQjt9MAhHaiQ5+xWd175ozlu1sezaWihj9KIx43++qftA91HeWbNADlaHDi7K3lsJwJGJWoM8UOxBXnnZqbrPTY8OAyD/HV22WenTk0T59sghOa0u4d+ySlKfqLJQnHNlwDJgNNBJRIILQC9gR4L7ZhjNhvm2kY40WQtFRIqASudcmYjkAq+iD3kmAPOdc8+KyDRgnXPu4SZs7QEOoXUmEkGXBNpKtD3rW/Pb6uOcK4q0sfl2SthKtL3W2reGfds5d9wXMBhYA6xDhxTf5defDKwCtgB/ArKbsuW3Wx1Ju+a2ZX1LDXuJ7lsT+zLftr6lbd+ccxFloaxDy2yGr/8XjRT5MYx0wHzbSHdsJKZhGEaa0hIn8BlNN2kRW4m2Z31reVvNTSr/Dta3lreXcN+OeEIHwzAMI7WwEIphGEaaYidwwzCMNMVO4IZhGGmKncANwzDSFDuBG4ZhpCn/D4+j+vBsiVdXAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 2 Axes>" | |
] | |
}, | |
"metadata": { | |
"tags": [], | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "xsAwdWc6-F3Q" | |
}, | |
"source": [ | |
"## Model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dfOkKwClCFOo" | |
}, | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"# from torch.optim.lr_scheduler import OneCycleLR\n", | |
"import tqdm\n", | |
"\n", | |
"class RNN(nn.Module):\n", | |
" def __init__(self, input_size, hidden_size, output_size, num_layers):\n", | |
" super(RNN, self).__init__()\n", | |
" self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)\n", | |
" self.input_size = input_size\n", | |
" self.embed = nn.Embedding(input_size, input_size)\n", | |
" self.embed_dropout = nn.Dropout(0.5)\n", | |
" self.output = nn.Linear(hidden_size, output_size)\n", | |
" self.dropout = nn.Dropout(0.5)\n", | |
" self.softmax = nn.LogSoftmax(dim=-1)\n", | |
"\n", | |
" def forward(self, inputs):\n", | |
" inputs = self.embed(inputs)\n", | |
" inputs = self.embed_dropout(inputs)\n", | |
" x, _ = self.lstm(inputs)\n", | |
" x = self.dropout(x)\n", | |
" x = self.output(x)\n", | |
" x = self.softmax(x)\n", | |
" return x" | |
], | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "cI0eE81q-HS4" | |
}, | |
"source": [ | |
"## Training loop" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "zGrxyu2NVMbZ" | |
}, | |
"source": [ | |
"from typing import Deque\n", | |
"# import wandb\n", | |
"from torch.nn.utils import clip_grad_norm_\n", | |
"from torch_ema import ExponentialMovingAverage\n", | |
"\n", | |
"\n", | |
"class MovingAverage:\n", | |
" def __init__(self, max_len): self.deque = Deque(maxlen=max_len)\n", | |
" def append(self, v): self.deque.append(v.detach())\n", | |
" def mean(self): return sum(self.deque).item() / len(self.deque)\n", | |
"\n", | |
"\n", | |
"def train(net, optimizer, lrs, train_iter, val_iter, num_steps, device):\n", | |
"\n", | |
" def prepare_batch(x):\n", | |
" data, mask = x\n", | |
" data = torch.from_numpy(data).long().to(device)\n", | |
" mask = torch.from_numpy(mask).long().to(device)\n", | |
" return data, mask\n", | |
"\n", | |
" def loss_fn(data_iter, mode='train'):\n", | |
" data, mask = prepare_batch(next(data_iter))\n", | |
" if mode != 'train': \n", | |
" net.eval()\n", | |
" with torch.no_grad():\n", | |
" logits = net(data[:, :-1])\n", | |
" net.train()\n", | |
" else:\n", | |
" logits = net(data[:, :-1])\n", | |
" \n", | |
" logits = torch.transpose(logits, 1, 2)\n", | |
" losses = torch.nn.functional.cross_entropy(logits, data[:, 1:], reduction='none')\n", | |
" mask = mask[:, 1:]\n", | |
" losses = losses * mask\n", | |
" return torch.sum(losses) / torch.sum(mask)\n", | |
"\n", | |
" train_loss = MovingAverage(1000)\n", | |
" val_loss = MovingAverage(500)\n", | |
"\n", | |
" trange =tqdm.notebook.trange(num_steps, desc='training')\n", | |
" ema = ExponentialMovingAverage(net.parameters(), decay=0.999)\n", | |
"\n", | |
" for i in trange:\n", | |
" loss = loss_fn(train_iter, 'train')\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
" train_loss.append(loss)\n", | |
" gn = clip_grad_norm_(net.parameters(), 10.)\n", | |
" optimizer.step()\n", | |
" ema.update(net.parameters())\n", | |
" if lrs is not None: lrs.step()\n", | |
"\n", | |
" if i % 10 == 0:\n", | |
" ema_bk = ExponentialMovingAverage(net.parameters(), decay=0.995)\n", | |
" ema.copy_to(net.parameters())\n", | |
" loss = loss_fn(val_iter, 'val')\n", | |
" ema_bk.copy_to(net.parameters())\n", | |
" val_loss.append(loss)\n", | |
" \n", | |
" if i % 5000 == 0:\n", | |
" lr = lrs.get_last_lr()[0] if lrs is not None else 1.0\n", | |
" print(f'step {i} train loss {train_loss.mean():.3f} val loss {val_loss.mean():.3f} lr {lr:.5f} grad norm {gn:.3f}') \n", | |
" # wandb.log({'train loss': train_loss.mean(), 'val loss': val_loss.mean(), 'lr': lr })\n", | |
"\n", | |
" ema.copy_to(net.parameters())" | |
], | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5Fh6F3VIMsVw" | |
}, | |
"source": [ | |
"batch_size=32\n", | |
"device = 'cuda'\n", | |
"lr = 4e-3\n", | |
"wd = 1e-1\n", | |
"num_layers=1\n", | |
"num_steps = 100_000\n", | |
"max_len=100\n", | |
"hidden_size = 512\n", | |
"train_dataset = TruyenCVDataset('corpus/novel_names.txt', batch_size, 'train')\n", | |
"val_dataset = TruyenCVDataset('corpus/novel_names.txt', batch_size, 'val')\n", | |
"vocab_size = len(train_dataset.alphabet)\n", | |
"net = RNN(vocab_size, hidden_size, vocab_size, num_layers).to(device)\n", | |
"optimizer = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=wd)\n", | |
"lrs = None \n", | |
"# lrs = OneCycleLR(optimizer, lr, total_steps=num_steps)\n", | |
"train_iter = iter(train_dataset)\n", | |
"val_iter = iter(val_dataset)\n", | |
"# wandb.init(project='NovelGen')" | |
], | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kGTS4zmrURAQ" | |
}, | |
"source": [ | |
"# Train (train() leaves the EMA weights in `net`), then save them.\n", | |
"train(net, optimizer, lrs, train_iter, val_iter, num_steps, device)\n", | |
"torch.save(net.state_dict(), 'novel_name_model.ckpt')" | |
], | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "cX8Zt19aDp29" | |
}, | |
"source": [ | |
"# Generate samples" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-Sv9qY98YMy3", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 149, | |
"referenced_widgets": [ | |
"6a1742e2c96347feb0011bdf636679b4", | |
"39d9edbfc62a49c7b6969e9d1c68cf20", | |
"195bcdd694be4315933f442ece2e5f75", | |
"38dc6ef373c1489d9ba577a4cfdca079", | |
"356a241447fd4d4c857eecf556ff4e40", | |
"40c1c4e5cf9f4b2a9df93bceca159210", | |
"ea20d7a3db2b41338cdf788e8b8c8353", | |
"f10a02ce5c324eb6a7921bd2f3b34950" | |
] | |
}, | |
"cellView": "form", | |
"outputId": "8bdb053a-5624-43bc-c68b-133d474a53d3" | |
}, | |
"source": [ | |
"#@title\n", | |
"from torch.distributions.categorical import Categorical\n", | |
"\n", | |
"def nucleus(probs, p=0.9):\n",
"    # Nucleus (top-p) filter over the last dimension: keep the most likely\n",
"    # tokens until their total mass reaches p and zero out the rest. The\n",
"    # result is not renormalised; Categorical(probs=...) normalises it.\n",
"    sorted_probs, order = torch.sort(probs, dim=-1, descending=True)\n",
"    # Mass of all strictly more likely tokens; a token survives while that\n",
"    # mass is still below p. Unlike thresholding with `probs > t`, tied\n",
"    # probabilities are broken by sort order instead of all being dropped.\n",
"    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs\n",
"    keep = mass_before < p\n",
"    keep[..., 0] = True  # never filter out every token\n",
"    kept_sorted = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))\n",
"    # Scatter the filtered values back to their original vocabulary positions.\n",
"    return torch.zeros_like(probs).scatter(-1, order, kept_sorted)\n",
"\n", | |
"\n", | |
"def sample(net, hidden_size, max_len, vocab_size, device, p=0.9, num_layers=1, bos_id=1, eos_id=2):\n",
"    \"\"\"Autoregressively sample one token sequence from the LSTM model.\n",
"\n",
"    Starts from `bos_id`, feeds every sampled token back in, and stops once\n",
"    `eos_id` is sampled or `max_len` steps have run. Returns a 1-D LongTensor\n",
"    that starts with `bos_id` and ends with `eos_id` only if it was sampled.\n",
"    `p` is the nucleus mass. `vocab_size` is unused; it is kept so existing\n",
"    calls keep working.\n",
"    \"\"\"\n",
"    hx = torch.zeros(num_layers, 1, hidden_size, device=device)\n",
"    cx = torch.zeros(num_layers, 1, hidden_size, device=device)\n",
"    state = (hx, cx)\n",
"    inp = torch.tensor([[bos_id]], dtype=torch.long, device=device)\n",
"    outs = [inp]\n",
"    was_training = net.training\n",
"    net.eval()\n",
"    try:\n",
"        with torch.no_grad():\n",
"            for _ in range(max_len):\n",
"                emb = net.embed(inp)\n",
"                out, state = net.lstm(emb, state)\n",
"                # presumably net.softmax is a LogSoftmax, since its output is\n",
"                # exponentiated to get probabilities -- TODO confirm in RNN\n",
"                log_probs = net.softmax(net.output(out))\n",
"                probs = nucleus(torch.exp(log_probs), p)\n",
"                token = Categorical(probs=probs).sample()\n",
"                outs.append(token)\n",
"                inp = token\n",
"                if token.item() == eos_id:\n",
"                    break\n",
"    finally:\n",
"        # Restore the caller's mode instead of forcing training mode.\n",
"        net.train(was_training)\n",
"    # reshape(-1) rather than squeeze(): stays 1-D even when only BOS exists.\n",
"    return torch.cat(outs, dim=0).reshape(-1)\n",
"\n", | |
"\n", | |
"# Rebuild the model with the training hyper-parameters and load saved weights.\n",
"# The commented gdown line presumably downloads a pretrained checkpoint;\n",
"# confirm it lands as novel_name_model.ckpt in the working directory.\n",
"# !gdown -q --id 1I2VyTpXoC7aZz2t9DpPvSaOOeCrD07oz\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"num_layers = 1\n",
"hidden_size = 512\n",
"# Must equal the alphabet size the checkpoint was trained with\n",
"# (len(train_dataset.alphabet) in the training cell).\n",
"vocab_size = 199\n",
"net = RNN(vocab_size, hidden_size, vocab_size, num_layers).to(device)\n",
"# Same relative path the training cell saves to; map_location lets a\n",
"# checkpoint saved from a GPU load on a CPU-only runtime.\n",
"net.load_state_dict(torch.load('novel_name_model.ckpt', map_location=device))\n",
"\n",
"import ipywidgets as widgets\n",
"# Slider value is a percentage; it is divided by 100 to get the nucleus mass p.\n",
"slider = widgets.IntSlider(description='nucleus prob.', min=10, max=100, value=80)\n",
"display(slider)\n",
"button = widgets.Button(description=\"Generate!\")\n",
"output = widgets.Output()\n",
"\n", | |
"def on_button_clicked(b):\n",
"    # Sample one name using the slider's nucleus mass and print it into the\n",
"    # output widget, without the BEGIN/END marker tokens.\n",
"    with output:\n",
"        s = sample(net, hidden_size, max_len, vocab_size, device, slider.value/100., num_layers)\n",
"        tokens = s.reshape(-1).tolist()[1:]  # drop the BEGIN token\n",
"        # Only strip the last token if it really is END (id 2): sampling stops\n",
"        # at max_len without emitting END, and an unconditional [1:-1] slice\n",
"        # would then cut off a real character.\n",
"        if tokens and tokens[-1] == 2:\n",
"            tokens = tokens[:-1]\n",
"        print(train_dataset.decode_text(tokens))\n",
"\n",
"button.on_click(on_button_clicked)\n",
"display(button, output)"
], | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "6a1742e2c96347feb0011bdf636679b4", | |
"version_minor": 0, | |
"version_major": 2 | |
}, | |
"text/plain": [ | |
"IntSlider(value=80, description='nucleus prob.', min=10)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "38dc6ef373c1489d9ba577a4cfdca079", | |
"version_minor": 0, | |
"version_major": 2 | |
}, | |
"text/plain": [ | |
"Button(description='Generate!', style=ButtonStyle())" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "ea20d7a3db2b41338cdf788e8b8c8353", | |
"version_minor": 0, | |
"version_major": 2 | |
}, | |
"text/plain": [ | |
"Output()" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "rwENPbvIb0k2", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "ca46f4dd-e175-44e1-9ade-2abd69b3c8ca" | |
}, | |
"source": [ | |
"# Print ten raw samples at nucleus mass 0.95; decode_text keeps the\n",
"# [BEGIN]/[END] markers in this output.\n",
"for i in range(10):\n",
"    s = train_dataset.decode_text(\n",
"        sample(net, hidden_size, max_len, vocab_size, device, 0.95, num_layers).tolist())\n",
"    print(s)"
], | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[BEGIN]Đô Thị Chi Truyền Thuyết[END]\n", | |
"[BEGIN]Chư Thiên Thế Giới Đế Quốc[END]\n", | |
"[BEGIN]Tà Tuyến Trọng Sinh[END]\n", | |
"[BEGIN]Họa Thượng[END]\n", | |
"[BEGIN]Chư Thiên Vạn Giới[END]\n", | |
"[BEGIN]Bất Hủ Thần Vương[END]\n", | |
"[BEGIN]Vô Hạn Saukeo[END]\n", | |
"[BEGIN]Tu Chân Nữ Nhi Diễn Động[END]\n", | |
"[BEGIN]Thiên Sủng Nữ Phụ Là Ngươi Giả[END]\n", | |
"[BEGIN]Tiên Viên Đại Thần[END]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "C_3gwjidhBV0" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment