Skip to content

Instantly share code, notes, and snippets.

@NTT123
Created November 25, 2020 17:47
Show Gist options
  • Save NTT123/fc7bc21821bc0d73849e43c8b3296b71 to your computer and use it in GitHub Desktop.
Save NTT123/fc7bc21821bc0d73849e43c8b3296b71 to your computer and use it in GitHub Desktop.
Generate Novel Names.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "Generate Novel Names.ipynb",
"provenance": [],
"collapsed_sections": [
"Tf_Tk5mg7LxY",
"PdObT-3j9_gp"
],
"authorship_tag": "ABX9TyM1oeDRjwe13u3vM4Q7QxmH",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"6a1742e2c96347feb0011bdf636679b4": {
"model_module": "@jupyter-widgets/controls",
"model_name": "IntSliderModel",
"state": {
"_view_name": "IntSliderView",
"style": "IPY_MODEL_39d9edbfc62a49c7b6969e9d1c68cf20",
"_dom_classes": [],
"description": "nucleus prob.",
"step": 1,
"_model_name": "IntSliderModel",
"orientation": "horizontal",
"max": 100,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 80,
"_view_count": null,
"disabled": false,
"_view_module_version": "1.5.0",
"min": 10,
"continuous_update": true,
"readout_format": "d",
"description_tooltip": null,
"readout": true,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_195bcdd694be4315933f442ece2e5f75"
}
},
"39d9edbfc62a49c7b6969e9d1c68cf20": {
"model_module": "@jupyter-widgets/controls",
"model_name": "SliderStyleModel",
"state": {
"_view_name": "StyleView",
"handle_color": null,
"_model_name": "SliderStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"195bcdd694be4315933f442ece2e5f75": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"38dc6ef373c1489d9ba577a4cfdca079": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ButtonModel",
"state": {
"_view_name": "ButtonView",
"style": "IPY_MODEL_356a241447fd4d4c857eecf556ff4e40",
"_dom_classes": [],
"description": "Generate!",
"_model_name": "ButtonModel",
"button_style": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"tooltip": "",
"_view_count": null,
"disabled": false,
"_view_module_version": "1.5.0",
"layout": "IPY_MODEL_40c1c4e5cf9f4b2a9df93bceca159210",
"_model_module": "@jupyter-widgets/controls",
"icon": ""
}
},
"356a241447fd4d4c857eecf556ff4e40": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ButtonStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ButtonStyleModel",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"button_color": null,
"font_weight": "",
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"40c1c4e5cf9f4b2a9df93bceca159210": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"ea20d7a3db2b41338cdf788e8b8c8353": {
"model_module": "@jupyter-widgets/output",
"model_name": "OutputModel",
"state": {
"_view_name": "OutputView",
"msg_id": "",
"_dom_classes": [],
"_model_name": "OutputModel",
"outputs": [
{
"output_type": "stream",
"metadata": {
"tags": []
},
"text": "Nguyên Thủy Tiên Tôn\n",
"stream": "stdout"
},
{
"output_type": "stream",
"metadata": {
"tags": []
},
"text": "Đại Việt Chi Thế Giới Đại Đạo\n",
"stream": "stdout"
},
{
"output_type": "stream",
"metadata": {
"tags": []
},
"text": "Tuyệt Thế Thần Đế\n",
"stream": "stdout"
},
{
"output_type": "stream",
"metadata": {
"tags": []
},
"text": "Nhị Thứ Nguyên Chi Thiên Nhãn Thiên Đình\n",
"stream": "stdout"
}
],
"_view_module": "@jupyter-widgets/output",
"_model_module_version": "1.0.0",
"_view_count": null,
"_view_module_version": "1.0.0",
"layout": "IPY_MODEL_f10a02ce5c324eb6a7921bd2f3b34950",
"_model_module": "@jupyter-widgets/output"
}
},
"f10a02ce5c324eb6a7921bd2f3b34950": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/NTT123/fc7bc21821bc0d73849e43c8b3296b71/generate-novel-names.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NVwfamGOHFgp"
},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"metadata": {
"id": "_5q1e_NPsAuP"
},
"source": [
"_=\"\"\" \n",
"\n",
"Tasks:\n",
"\n",
"- [x] Create a dataset of novel names.\n",
" + [x] download html pages.\n",
" + [x] parse html pages to get names.\n",
" + [x] clean up.\n",
" + [x] submit dataset to gist.\n",
"\n",
"- [x] Create RNN model to generate novel names.\n",
" + [x] data loader.\n",
" + [x] model.\n",
" + [x] training.\n",
"\n",
"- [x] Improve model.\n",
" + [x] use dropout.\n",
" + [x] use embedding layer.\n",
" + [x] use EMA.\n",
"\"\"\""
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pspCXo9aEboV",
"outputId": "496c8c02-0784-4618-b78c-27752601b16e"
},
"source": [
"!nvidia-smi | grep Tesla\n",
"!pip install -Uqq git+https://github.com/fadel/pytorch_ema\n",
"# !pip install -Uqq wandb\n",
"# !wandb login"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |\n",
" Building wheel for torch-ema (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Tf_Tk5mg7LxY"
},
"source": [
"### Dataset of novel names"
]
},
{
"cell_type": "code",
"metadata": {
"id": "I_bINHgon9OD"
},
"source": [
"# \"\"\"\n",
"# Crawl data from truyencv.\n",
"# \"\"\"\n",
"\n",
"# import os\n",
"# import tqdm\n",
"\n",
"# for i in tqdm.trange(1, 1029, desc='downloading'):\n",
"# cmd = f'curl https://truyencv.com/danh-sach//trang-{i} -o data/trang-{i}.html'\n",
"# os.system(cmd)\n",
"\n",
"# !zip -q -r novel_names_html.zip data\n",
"# !ls data/*.html | wc -l # expect: 1028"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dkJNDAaPp75C"
},
"source": [
"# \"\"\"\n",
"# Parse html files to get list of novel names. Use lxml library.\n",
"\n",
"# Normalize names.\n",
"# \"\"\"\n",
"\n",
"# import lxml\n",
"# from lxml.html import parse\n",
"# import unicodedata\n",
"# from pprint import pprint\n",
"\n",
"\n",
"# alphabet = \" !%&'()*+,-./0123456789:?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxyz{}~ÀÁÂÉÊÍÐÒÓÔÕÚÝ\\\n",
"# àáâãèéêìíòóôõùúýāĂăĐđĩŌōũūƠơƯưẠạẢảẤấẦầẨẩẫẬậẮắằẳặẹẻẽẾếỀềểỄễỆệỈỉỊịọỏỐốỒồỔổỗỘộỚớờỞởỡợụỦủỨứừửữỰựỲỳỵỶỷỹ\"\n",
"\n",
"# # normalize name\n",
"# def normalize_name(name):\n",
"# replace_list = [('\\xa0', ' '), ('〖', '['), ('〗', ']'), ('–', '-'), ('—', '-'), \n",
"# ( '∶', ':'), ( '♥', ' '), ('\\ufeff', ''), ('’', \"'\"), ('·', '.')]\n",
"\n",
"# name = unicodedata.normalize('NFC', name)\n",
"# for a, b in replace_list:\n",
"# name = name.replace(a, b)\n",
" \n",
"# name = ''.join([c for c in name if c in alphabet])\n",
"# return name\n",
"\n",
"\n",
"# def parse_names():\n",
"# names = []\n",
"# for i in range(1, 1029):\n",
"# root = parse(f'data/trang-{i}.html').getroot()\n",
"\n",
"# for h2 in root.iter('h2'):\n",
"# if 'class' in h2.attrib:\n",
"# if 'title' in h2.attrib['class']:\n",
"# a = next(h2.iter('a')).attrib\n",
"# if 'title' in a:\n",
"# names.append(a['title'])\n",
"\n",
"# names = [name.strip() for name in names]\n",
"# names = sorted(list(set(names)))\n",
"\n",
"# return names\n",
"\n",
"\n",
"# names = parse_names()\n",
"# pprint(names[:10], compact=True)\n",
"# pprint(len(names))\n",
"\n",
"# names = [normalize_name(name) for name in names]\n",
"\n",
"# # print alphabet after normalization\n",
"# text = '\\n'.join(names)\n",
"# pprint(sorted(list(set(text))), compact=True)\n",
"\n",
"# # save to file\n",
"# with open('novel_names.txt', 'w') as f:\n",
"# f.write('\\n'.join(names))"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3usCMaWJ6021",
"outputId": "1f4c1320-b654-424e-af21-c81b52747f70"
},
"source": [
"!git clone https://gist.github.com/NTT123/72d2341a77e0d407eb0dfca0fba52d59 corpus"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"fatal: destination path 'corpus' already exists and is not an empty directory.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PdObT-3j9_gp"
},
"source": [
"## Dataloader"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 0
},
"id": "cE74Ntt0-B3i",
"outputId": "36c9de9a-d879-446f-f121-a12a6c733fad"
},
"source": [
"\"\"\"\n",
"Load data in batch infinitely.\n",
"\"\"\"\n",
"\n",
"from torch.utils.data import IterableDataset, DataLoader\n",
"import numpy as np\n",
"import random\n",
"import matplotlib.pyplot as plt\n",
"\n",
"class TruyenCVDataset(IterableDataset):\n",
" def __init__(self, filepath: str, batch_size: int=32, mode='train'):\n",
" self.filepath = filepath\n",
" self.PAD = 0\n",
" self.rng = random.Random(42)\n",
" self.batch_size = batch_size\n",
" self.mode = mode\n",
" self.load_text()\n",
"\n",
" def __iter__(self):\n",
" self.rng.shuffle(self.encoded_lines)\n",
" L = len(self.encoded_lines) * 9 // 10\n",
" data = self.encoded_lines[:L] if self.mode == 'train' else self.encoded_lines[L:]\n",
"\n",
" while True:\n",
" self.rng.shuffle(data)\n",
" for i in range(0, len(data) - self.batch_size, self.batch_size):\n",
" batch = np.array(data[i: i+self.batch_size])\n",
" mask = batch != self.PAD\n",
" yield batch, mask\n",
"\n",
" def load_text(self):\n",
" text = open(self.filepath, 'r').read()\n",
" self.alphabet = ['[PAD]', '[BEGIN]', '[END]'] + sorted(list(set(text)))\n",
" self.lines = text.strip().split('\\n')\n",
" self.max_len = max([len(line) for line in self.lines]) + 2 # included [BEGIN] and [END]\n",
" self.encoded_lines = [self.encode_text(t, self.max_len) for t in self.lines]\n",
"\n",
" def encode_text(self, text, max_len):\n",
" text = ['[BEGIN]'] + list(text) + ['[END]']\n",
" pad_len = max(0, max_len - len(text))\n",
" text = text + ['[PAD]'] * pad_len\n",
" return [ self.alphabet.index(c) for c in text if c in self.alphabet ]\n",
"\n",
" def decode_text(self, tokens):\n",
" text = ''.join([self.alphabet[i] for i in tokens])\n",
" text = text.replace('[PAD]', '_')\n",
" return text\n",
"\n",
"def test_dataset():\n",
" dataset = TruyenCVDataset('corpus/novel_names.txt')\n",
" print(dataset.decode_text(dataset.encoded_lines[0]))\n",
" print(dataset.decode_text(dataset.encoded_lines[1]))\n",
" print(dataset.decode_text(dataset.encoded_lines[2])) \n",
"\n",
"\n",
" train_iter = iter(TruyenCVDataset('corpus/novel_names.txt', 32, 'train'))\n",
" test_iter = iter(TruyenCVDataset('corpus/novel_names.txt', 32, 'test'))\n",
" batch, mask = next(test_iter)\n",
" plt.subplot(1, 2, 1)\n",
" plt.matshow(batch, fignum=0)\n",
" plt.subplot(1, 2, 2)\n",
" plt.matshow(mask, fignum=0)\n",
" plt.show()\n",
"test_dataset()"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"[BEGIN](Fairy Tail) Nô lệ Quỷ Dữ[END]___________________________________________________\n",
"[BEGIN](Xuyên Không) Nam Chính À Tránh Xa Tôi Ra.[END]__________________________________\n",
"[BEGIN](ĐỒNG NHÂN THE HOBBIT) TA MUỐN VỀ NHÀ NGA[END]___________________________________\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAABhCAYAAADP5Pq1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAZL0lEQVR4nO2de3RV1bWHvxkSkgAhJBDCU0ClVLC8Fam0orYWtcVb21oftVRFrNVRbS2t2just7293mqxSh3KQxB8S0WKZeCDi9gWpVAUQRR5VawgECRE3nmu+8dcO8k5JOQ8k3PC/MY445y9z9pzr30ys/dvzz3XXOKcwzAMw0g/Mlq6A4ZhGEZs2AncMAwjTbETuGEYRppiJ3DDMIw0xU7ghmEYaYqdwA3DMNKUZjuBi8g4EdkoIltE5PYYtp8tIiUisr7eukIRWSIim/17QYS2eovIMhF5X0TeE5Fb4rSXIyKrRGStt/dffn0/EVnpj/k5EWkbxfG2EZE1IrIoAba2ici7IvKOiKyO81g7icjzIvKBiGwQkdFx2Brg+xS89ovIrbHaaynMt823G7DVPL7tnEv6C2gDbAVOBtoCa4GBUdr4MjAcWF9v3b3A7f7z7cDvIrTVHRjuP+cBm4CBcdgToIP/nAWsBM4C5gGX+/XTgBujON6fAk8Di/xyPLa2AV3C1sV6rHOBif5zW6BTrLYa8JFdQJ9E2Guul/m2+XZL+nZzOflo4JV6y3cAd8Rgp2+Yk28Eutdz3I0x9m8h8NVE2APaAW8Do4BPgcyGfoMmbPQClgLnAYv8P1FMtnz7hpw86mMF8oEPAYnXVgO2LwDeSOTftTle5tvm2xHYTppvN1cIpSfwcb3l7X5dvBQ753b6z7uA4mgNiEhfYBiqLGK2528L3wFKgCWoKitzzlX5JtEc8wPAz4Eav9w5DlsADnhVRN4SkUl+XSzH2g/YAzzmb4EfFZH2MdoK53LgmTj61lKYb5tvN0XSfLvVPMR0ekmLqi6AiHQA5gO3Ouf2x2PPOVftnBuKKowzgc9H05d6ffo6UOKceyuW7RthjHNuOHAhcJOIfLn+l1EcayZ6q/+Ic24YcAi9DYzFVi0+5jke+FP4d7HYa22Ybx+XE9q3m+sEvgPoXW+5l18XL7tFpDuAfy+JdEMRyUId/Cnn3Avx2gtwzpUBy9BbwU4ikum/ivSYzwbGi8g24Fn0VvPBGG0Ffdrh30uABeg/YSzHuh3Y7pxb6ZefR50+3t/tQuBt59xuvxz336EZMd823z4eSfXt5jqB/xPo7582t0VvKV5MgN0XgQn+8wQ03tckIiLALGCDc+7+BNgrEpFO/nMuGnPcgDr7t6Ox55y7wznXyznXF/2dXnPOXRWLLd+f9iKSF3xG43HrieFYnXO7gI9FZIBfdT7wfiy2wriCultMEmCvOTHfNt8+Hsn17XgC6FEG8i9Cn4hvBX4Zw/bPADuBSvRqeR0aP1sKbAb+DyiM0NYY9NZlHfCOf10Uh73BwBpvbz1wl19/MrAK2ILeQmVHecxjqXtSH5Mtv91a/3ov+O3jONahwGp/rH8GCmK15e21B/YC+fXWxWyvJV7m2+bbLeXb4o0ahmEYaUareYhpGIZxomEncMMwjDTFTuCGYRhpip3ADcMw0hQ7gRuGYaQpcZ3AJYYqbPWGu8ZNIm0l2p71reVtxYP5dvPYSrS9E6lvQOx54MRYhQ1YHW/ebTJsWd9Sw16i+xZjH8y3rW8p3zfn4itmdSawxTn3L+dcBTo09pI47BlGqmC+baQFmU03aZSGqrCNOt4GbSXb5dCODgW9HEB1jgDg9I3CggMA7NubB0BWyWH9vmMuANVZ2jCrsELf8wpoV9zbZe7RdvTPpKpMa8EH61ye37at37azblu5L7Qd7XPIzs6nx6BODuDghjYA1OTr9jWZun3X4jIASvZ00u18TbXMTw+FHGsO7egohQkbJZVIe621b0c5RIUrlwR0I2rf7lLYxp3UM5ORQ3IS8jsk0lai7G1a1w5ovf6TbHvJ8O14TuAR4eM+k0APYEzGxXw4/QsAVB7Wk+iAG98F4PBZQwA41FVPnlff8hIALw/XiovVQwYCUHDPvwHYNvNzABQ++7ba6306ZeeqzapcPdacUv29CuZpm+qBamP8gtcAeOLei+psCFR2Px2ApUtnATCun/7fHr4wtG/tCtT+qRdtBWDrolMA6DHF18KpqY7qdzLiZ6Vb2qz7q+/bJ/XM5MPVfZt1/y3F13oMbekunHA05tvxhFAiqsLmnJvhnBvpnBuZRXYcuzOMZiNq3y7q3KbZOmcYAfEo8NoqbKhzXw5cGcmGPR7XE/mBXrr7TVP0it5xk/4TdJuu5YJfKLkAgEMTdP2+wRqv6DhPC4Z1e1bb1VfHo69Xpb1ljO6r+ozTANjzg+EA3DP5UQDuuG8iAEVevQftyvqrgh+z7lIAOoT932aM3wvA9IHPAnDTAzcDkH3Easq0ImL2bcNoTmI+gTvnqkTkZuAV9Kn9bOfcewnrmWG0EObbRroQVwzcObcYWBztdln7K/WDr+E+bdxsAO54d2JIu3Y7jgDw8TdUseev9Q8ew9Su+MmYfvGzp5n9BZ0s5F9zVaVfddoqAFacqQ9G79nmS/H20bejf+kOwN4l+oCm51RV9dXva6x832Uaf3c+2FS1uAsAvx5/BgDFNW9GethGGhGrb7cmLNad+thITMMwjDQlrnrgfmqkA0A1UOWcG3m89h2l0I3K+ApV52o8uuxUVdTlPqOjxt8PtNVsQtoc1b4VzdU4dXgmyL5B+n3HrRm17cPblg7IDLH93cteB+D5J8cCdYo73HZ4LD38+8Lvbgdgz0J91hXE7Su+pFksn52ix1adLQ0fyzi1d+R6TUssHK/ZLJa9Ejsr3VL2u9JEpBFG7dsjh+S4Va/0Pl6TEx5T9LHTmG8nIo3wXOfcpwmwYxiphvm2kdIkPQ+8ITIqNZsk91I/z+cCjTMXzQlVz31/8QEAH5SrYg/yvdsP6Kft534IQMU5mld+28NPcnv+tUCdsoZQ5bzoY1XImWNKAdhVPkKXj4Qq5FfP0xh64QId6FO5sE1oH0vUbtZE/f/+6Y91/c+nhtorfig0Lzy438lduMq/N/ozGUZaYMq65Yg3Bu6AV0XkrVQpQmQYCcJ820h54lXgY5xzO0SkK7BERD5wzv2tfoPwkZgAH92karTyk0IA8n2c+NPvq9Lu8riq2X9XDAagupfa2jtfU0eqfJjYLQ1Vu38cNIRuZ6hiDvK+9w1StZ+3VffR/Vsf6bZDdBTnp0N12yd+OQWAnzx+bsgBjizSEdXLs7uGrM/ddRSAS/upkr5/kO6vyO+/dJAOwb9t0zqgLu+8y0xtb7HulCcq3z6pZ4vczKYEr3zyTsiyKfLmIy4F7pzb4d9LgAVoEaDwNjYS00g7ovVtG4lptAQxywYRaQ9kOOcO+M8XAL+OZNuej2mGxv4+WQD0vkozMLbNPyWkXY0vXpWp6eBU/aUzAMU+Dl0+tgMAh7uqnd3Xj6B4psa+23fRGDUuyBpRJb3LqWrv8dwWAIrWbgLghpJb1eZl2v7kCT4LZbAq9R/M0ZTgl6drvL4iX4/hiQcvBKB0qiaj57+nfQmyUqasuVz74UVJwd/zAdj4jOarB/2tuyPQY+o/YWNtm26Pal9qKnz+vKn3pBKPbxvHKvJoMPUeHfHc9xUDC0QksPO0c+7lhPTKMFoW820jLYgrDzxagjzwI5foKMamcqoPXDIMgLyFa4C60ZWFHYPMkCIgNHvlULHaPO+GfwDw0nOjAej1QMP53kFtE8vDTn8SmQceLemQB27qNn1pzLdtJKZhGEaa0mQIRURmA18HSpxzp/t1hcBzQF9gG3CZc25ftDuvzNMLyvhumqkxy1fwDM9GCUYt5i1T1Xw0W+PEbcLuHnI/OcKhrvrda9PPAqDXnFDl/VkfPeQDJ/vKhq9qbRPJ0uyU8rGq+rdP0Jh2+39o5kyQ6RL0actvtV2Hf2c0+H2wv7JTdH8HB2vWSv4/c4C62HdGri4fGdVf2/XUGPrhYqHPtA0AVH+2Xw/Q7g4SSjJ9OxWJJzadbOzuIDYiUeBzgHFh624Hljrn+gNL/bJhpBtzMN820piIYuAi0hdYVE+lbATGOud2ikh34HXn3ICm7EQbAw+PeR95RXOxu0/T7yvP1lGV265RNX3qhHWmUk9gYomBJ8q3myMGbir1xCXRMfBi59xO/3kX+tS+QURkkoisFpHVlZTHuDvDaDZi8u09e004GM1P3MPHnHNORBqV8c65GcAMoHZCT5ehF5J9X9QJhqu98g6v893vAX3fcY4q715eeZderjHyaj8uqNNyfZesTHZfr3NYZl9UAsDRl3Xb8GqAu24YEbI+79v6P7t3SQ+grpbKzhv9aM/D2q6xu4Xd12u7867R7JfXHtMYfBDr3jNhuD9G/7v4S2dlez9Z8jSbSzPViMa3EzkBcWOkcgw7EuwOIvHEqsB3+9tL/HtJ4rpkGC2K+baRNsSqwF8EJgD/69+jqqknNX6m+BU6mrF0mB/F+H5Yd3y7u657CoDZD+jIxIINB4G6UYulI3T7ojmQt10/f7yrk9r0pro8qjVInFe4xVPDZtKZoW892Kbt/OpuDzTcrpiPQ9p1fUjbrX/IL/NmyPddZqzASAvi8u2WxBTuiUeTClxEngFWAANEZLuIXIc691dFZDPwFb9sGGmF+baR7jSpwJ1zVzTy1fkx77WRaGFVbsPrH96mFQL33ajx6e6PaFy5UwfNQkFUye+6YURtTHvAjWt1V8M0geDCd3XE5cvD9ZlUeJ72X2/9PQBf+8/bgLra47UjN/0Iz9ETdf3fn9KYd7epFrtOV5Li2y1IsmLkpuxTl0gU+GwRKRGR9fXW3S0iO0TkHf+6KLndNIzEY75tpDuRxMDnAA8Bj4et/4Nz7vex7HTXFToqMcjvLqytaRKa/53fQdvdd8qLANxa80OgbqRm0UodIFf0hv7/Hb5wSG1uefVizT7Zs1Dj5IHyDqr+ZfxIn03JYlX1V5yqKr/DlzQz5pktywAYd6ffl6+3snmNbp/p64h//V2dkWfW9IuBY2fgMVKaOSTYt1sjVu87dWlSgfsi9qXN0BfDaFbMt410J5488JtF5PvAauC2aOpF9Ltf1evm72l97y4qtGtzpYNa3MFcl7ffemmDdtxGnRMzUOx5y9rUKuXyj3Tbl2beC8A10zWseaSHBtoPPa/KvGtQJ7xGY9qBgr689xcB6IRmj9SG7Ve9q332E+ssmlGg2xGWrWKkMzH7djpgCrr1EOsJ/BHgN+h57TfAFODahho2NKVaTbbu9u4L5gNw/5bLtDNHGn66uTssJTCc3jO0AFRZ/7oBNcEAmmu+dSMABxce0vfFmSH7yvjqJ9q+5uNGD9Y4oYjJt1N5SjU7YbdeYhrI45zb7Zyrds7VADNpYLqpem1tSjUjbYjVt21KNaMliEk2iEj3evUivgmsP177xrj3MVXeNXm6fP4EDV8szdVh6MEw9VOPDARgp0Y1qPBlaEeu1Iegi2Zn17YPHlIG4ZjyAj8t22LdSXjZVzdSbT87fzpwbBphU5NIHD1fJ14+0Et/ytxLdwNwcHG3kGMIb7fvdC3AVfCeXkNtsuPUIFG+nUqkyhB8uxNIPJHUA38GGAt0EZHtwK+AsSIyFL3N3AbckMQ+GkZSMN820p0WnVKtPE9vOwvmhQ6aqZikiQHhard2UM6svwF1qYHVZ5wGQOmgXO6Z/CgAd9w3MWTbii/50rPf0Hh5/mZVvuGTCl/3pKYszv7C50P61FjpW0sbTB1sSrXkYMq55bEp1QzDMFoZTSpwEemNDnQoRm8rZzjnHoxl6qlAgdeM0Sv6D2dpFspvHvoecOxEDn8YOQ+AX069NuT7cFWMvy6VDquuLYgVbuuCUzYCsHmUFrsyxdz6iFaBJ9K3U0mBm2JufcSjwKvQXNiBwFnATSIyEJt6ykh/zLeNtCaSYlY7gZ3+8wER2QD0BC5BHwABzAVeB34RyU6PFmnxqclvfgeAjn59MEQ+GMgzdZhmqVSdE7p97ic6K0IwgXHAgFvW1qrzT27WfPC8ZX4ih9MOAPD6ZO1yrwdUoVefqVkoJcN1gM+BU1SZd9yi6j5Q8ocuVlVz9FoVYlXVeu1zSwuBuuyWzrMsmyRdSIZvpwLJzjoxhZ86RBUD9/MHDgNWEsXUU4aR6phvG+lIxFkoItIB+CvwW+fcCyJS5pzrVO/7fc65gga2qz8Sc8SYjIvZ9mvN887fom2CnOtADW/9jir0QAVbpocRCbFmoSTCt0/qmTniw9V9Y+67qVrjeMSVhSIiWcB84Cnn3At+dURTT9lITCOVSZRv20hMoyWIZCCPALOADc65++t9FfPUU0VrdRTigZ7q9EG8uo0vatV5jY8nP+3j1F/UwlQ7z9ZaKu0/0e+PdNULUnapLte0hUwdMFk7AtPUu9EYyfDtWIknbm3q/cQlkqH0ZwNXA++KSOBld6LOPc9PQ/URcFlyumgYScN820hrIslCWU5tpvUxxDT1VMm3NYukqkQzP4JYd0ZloLxDR15W3lWmGy5VBR6M3Mz2GSdPTlXxNKn/+bVZKFW5anPH5FEAHOqn+d8DbvFTrfmRl8HEyLU1U7Rrx9RMYbC27/gHrV64YaH2refU0HzzPkWapbLzVc0J7nmf3QGkKsnw7ZagpWud2B1Ay2EjMQ3DMNKUSGLgjY1Wuxu4Htjjm97pnFscyU6L5qvy/uxkVckVvhrhkYGqUqtz/ITBPge7eppe4bOLVBWPXqU53c8/qd2f1F/F0qGLh7JjvCrtz92gSjvIbCktywnZ9oXHdKfBBMnhozvPmqTKe4ufbKs6W9dvfUKV+JTJMwG4Y7/WXOl3ZWj+d0+svniqkwzfbs2Y0k49IomBB6PV3haRPOAtEVniv7O5A410xnzbSGviGYkZM233q0rNKdUIzpK7pgB1EwsHanjPD3Rk5thJvk74XM0fX3Gmqueic3QC4pc/1O/H9YM+hzRjJagHvm+wZrx03KShzjeHqfrvVqNToAVZ8Ll/VgWd65c3zwjtc8ZyjTN2Wa7LU6YP0mU/5ZqRfiTDt1MBU8onDvGMxASdO3CdiMwWkWMGOhhGumC+baQj8YzELAY+pW7uwO7OuWPmDmxoJCZnaG3uIAMkyAOoyfK52zMbqVNyslfTW/W6E54pcuCSYQybvAaA5U+ExtHDY9xB1kn3CzRWHdT3DtqHZ6k8eafeJfxkgL9LGKf2br7vOeDY+uFBrfMDJ+l+et9j2SjJJoEjMaP27XhHYrYGTPknj4SPxIx07kAbiWmkMonybRuJabQEMY/EjGfuwCM9VFFX+rkta3wvggqBtWq5WP8p9p+mmSXhtb5LL9c4dxDvRmDLGP3Y+exy/c7H0UdPVJX+xtwRofvaqPtqU6xqfuODQxrc14833ATAmFWqpFecqVkuD8l3ta8TtK8/+skCAP74yKUAtCtpvhmPjOhIhm8bRnMSz0jMK2zuQCPNMd820pp4RmLGnBebUaGqVPzkOFl+9OPeq1QtV7XzytzflYar4fKxmmlS7SMyV9/yEqBzZIbHuufdeR8AN31Oc8U7jdXMlaD2+NHC0LuAxqhqp/YmFWi2ygo/UC+8Nvl///0b2me/nfO/nGTpDqrP0L5XddC5Odu+6kd6Wmy82UmGb5/INDYi1GLjycNGYhqGYaQpkcTAc4C/Adm+/fPOuV+JSD/gWaAz8BZwtXOuIpKdVrbX60Z1jl/2yShlvVWFBoq7+zRV3Ee/MhiA7bdq/DqoRlg0V9Xrw+d9GYALlm9k+ROhav2H798MwP7LNO5eke9rnhxWG0EsPMg6mfX8IwB8643JIX3OPKx9O3vBzwDoeENYJsyjqsy7zGhYSQeR8CCfvG2DrYzmJBm+3ZoxJZ16RKLAy4HznHNDgKHAOBE5C/gdOlrtVGAfcF3yumkYScF820hrIomBO+CgX8zyLwecB1zp188F7gYeiWSnOXs1+H20QK8fVbmqivPe0DhzwTxVxZVna7746zO17si4fqNC1gfZJ23b7gc09ztQxAFtylURT7pTs0P+NLQPUJdjHtRGWXHmJgCu/PFPtU9ddfvvvPMRAH98RNV/4fpg7kvL6053kuHbrQlT3KlPpHngbfxT+hJgCbAVKHPO+ceQbKeRIcgiMklEVovI6krKE9Fnw0gYifLtPXvtQm40P5GkEeKcqwaGikgnYAHw+Uh34JybAcwA6CiFDsBlquLO+g8t9laxTOVuUOc7PJPk5h2jQjt9MAhHaiQ5+xWd175ozlu1sezaWihj9KIx43++qftA91HeWbNADlaHDi7K3lsJwJGJWoM8UOxBXnnZqbrPTY8OAyD/HV22WenTk0T59sghOa0u4d+ySlKfqLJQnHNlwDJgNNBJRIILQC9gR4L7ZhjNhvm2kY40WQtFRIqASudcmYjkAq+iD3kmAPOdc8+KyDRgnXPu4SZs7QEOoXUmEkGXBNpKtD3rW/Pb6uOcK4q0sfl2SthKtL3W2reGfds5d9wXMBhYA6xDhxTf5defDKwCtgB/ArKbsuW3Wx1Ju+a2ZX1LDXuJ7lsT+zLftr6lbd+ccxFloaxDy2yGr/8XjRT5MYx0wHzbSHdsJKZhGEaa0hIn8BlNN2kRW4m2Z31reVvNTSr/Dta3lreXcN+OeEIHwzAMI7WwEIphGEaaYidwwzCMNMVO4IZhGGmKncANwzDSFDuBG4ZhpCn/D4+j+vBsiVdXAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xsAwdWc6-F3Q"
},
"source": [
"## Model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "dfOkKwClCFOo"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"# from torch.optim.lr_scheduler import OneCycleLR\n",
"import tqdm\n",
"\n",
"class RNN(nn.Module):\n",
" def __init__(self, input_size, hidden_size, output_size, num_layers):\n",
" super(RNN, self).__init__()\n",
" self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)\n",
" self.input_size = input_size\n",
" self.embed = nn.Embedding(input_size, input_size)\n",
" self.embed_dropout = nn.Dropout(0.5)\n",
" self.output = nn.Linear(hidden_size, output_size)\n",
" self.dropout = nn.Dropout(0.5)\n",
" self.softmax = nn.LogSoftmax(dim=-1)\n",
"\n",
" def forward(self, inputs):\n",
" inputs = self.embed(inputs)\n",
" inputs = self.embed_dropout(inputs)\n",
" x, _ = self.lstm(inputs)\n",
" x = self.dropout(x)\n",
" x = self.output(x)\n",
" x = self.softmax(x)\n",
" return x"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "cI0eE81q-HS4"
},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zGrxyu2NVMbZ"
},
"source": [
"from typing import Deque\n",
"# import wandb\n",
"from torch.nn.utils import clip_grad_norm_\n",
"from torch_ema import ExponentialMovingAverage\n",
"\n",
"\n",
"class MovingAverage:\n",
" def __init__(self, max_len): self.deque = Deque(maxlen=max_len)\n",
" def append(self, v): self.deque.append(v.detach())\n",
" def mean(self): return sum(self.deque).item() / len(self.deque)\n",
"\n",
"\n",
"def train(net, optimizer, lrs, train_iter, val_iter, num_steps, device):\n",
"\n",
" def prepare_batch(x):\n",
" data, mask = x\n",
" data = torch.from_numpy(data).long().to(device)\n",
" mask = torch.from_numpy(mask).long().to(device)\n",
" return data, mask\n",
"\n",
" def loss_fn(data_iter, mode='train'):\n",
" data, mask = prepare_batch(next(data_iter))\n",
" if mode != 'train': \n",
" net.eval()\n",
" with torch.no_grad():\n",
" logits = net(data[:, :-1])\n",
" net.train()\n",
" else:\n",
" logits = net(data[:, :-1])\n",
" \n",
" logits = torch.transpose(logits, 1, 2)\n",
" losses = torch.nn.functional.cross_entropy(logits, data[:, 1:], reduction='none')\n",
" mask = mask[:, 1:]\n",
" losses = losses * mask\n",
" return torch.sum(losses) / torch.sum(mask)\n",
"\n",
" train_loss = MovingAverage(1000)\n",
" val_loss = MovingAverage(500)\n",
"\n",
" trange =tqdm.notebook.trange(num_steps, desc='training')\n",
" ema = ExponentialMovingAverage(net.parameters(), decay=0.999)\n",
"\n",
" for i in trange:\n",
" loss = loss_fn(train_iter, 'train')\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" train_loss.append(loss)\n",
" gn = clip_grad_norm_(net.parameters(), 10.)\n",
" optimizer.step()\n",
" ema.update(net.parameters())\n",
" if lrs is not None: lrs.step()\n",
"\n",
" if i % 10 == 0:\n",
" ema_bk = ExponentialMovingAverage(net.parameters(), decay=0.995)\n",
" ema.copy_to(net.parameters())\n",
" loss = loss_fn(val_iter, 'val')\n",
" ema_bk.copy_to(net.parameters())\n",
" val_loss.append(loss)\n",
" \n",
" if i % 5000 == 0:\n",
" lr = lrs.get_last_lr()[0] if lrs is not None else 1.0\n",
" print(f'step {i} train loss {train_loss.mean():.3f} val loss {val_loss.mean():.3f} lr {lr:.5f} grad norm {gn:.3f}') \n",
" # wandb.log({'train loss': train_loss.mean(), 'val loss': val_loss.mean(), 'lr': lr })\n",
"\n",
" ema.copy_to(net.parameters())"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5Fh6F3VIMsVw"
},
"source": [
"batch_size=32\n",
"device = 'cuda'\n",
"lr = 4e-3\n",
"wd = 1e-1\n",
"num_layers=1\n",
"num_steps = 100_000\n",
"max_len=100\n",
"hidden_size = 512\n",
"train_dataset = TruyenCVDataset('corpus/novel_names.txt', batch_size, 'train')\n",
"val_dataset = TruyenCVDataset('corpus/novel_names.txt', batch_size, 'val')\n",
"vocab_size = len(train_dataset.alphabet)\n",
"net = RNN(vocab_size, hidden_size, vocab_size, num_layers).to(device)\n",
"optimizer = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=wd)\n",
"lrs = None \n",
"# lrs = OneCycleLR(optimizer, lr, total_steps=num_steps)\n",
"train_iter = iter(train_dataset)\n",
"val_iter = iter(val_dataset)\n",
"# wandb.init(project='NovelGen')"
],
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kGTS4zmrURAQ"
},
"source": [
"train(net, optimizer, lrs, train_iter, val_iter, num_steps, device)\n",
"torch.save(net.state_dict(), 'novel_name_model.ckpt')"
],
"execution_count": 14,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "cX8Zt19aDp29"
},
"source": [
"# Generate samples"
]
},
{
"cell_type": "code",
"metadata": {
"id": "-Sv9qY98YMy3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 149,
"referenced_widgets": [
"6a1742e2c96347feb0011bdf636679b4",
"39d9edbfc62a49c7b6969e9d1c68cf20",
"195bcdd694be4315933f442ece2e5f75",
"38dc6ef373c1489d9ba577a4cfdca079",
"356a241447fd4d4c857eecf556ff4e40",
"40c1c4e5cf9f4b2a9df93bceca159210",
"ea20d7a3db2b41338cdf788e8b8c8353",
"f10a02ce5c324eb6a7921bd2f3b34950"
]
},
"cellView": "form",
"outputId": "8bdb053a-5624-43bc-c68b-133d474a53d3"
},
"source": [
"#@title\n",
"from torch.distributions.categorical import Categorical\n",
"\n",
"def nucleus(probs, p=0.9):\n",
" sp, _ = torch.sort(probs, dim=-1)\n",
" csp = torch.cumsum(sp, dim=-1)\n",
" t, _ = torch.max(torch.where(csp <= (1. - p), sp, torch.zeros_like(sp)), dim=-1, keepdim=True)\n",
" return torch.where(probs > t, probs, torch.zeros_like(probs))\n",
"\n",
"\n",
"def sample(net, hidden_size, max_len, vocab_size, device, p=0.9, num_layers=1):\n",
" hx = torch.zeros(num_layers, 1, hidden_size, device=device)\n",
" cx = torch.zeros(num_layers, 1, hidden_size, device=device)\n",
" state = (hx, cx)\n",
" inp = torch.tensor([[1]], dtype=torch.long, device=device)\n",
" outs = [inp]\n",
" net.eval()\n",
" with torch.no_grad():\n",
" for i in range(max_len):\n",
" inp = net.embed(inp)\n",
" out, state = net.lstm(inp, state)\n",
" logits = net.softmax(net.output(out))\n",
" probs = torch.exp(logits)\n",
" probs = nucleus(probs, p)\n",
" token = Categorical(probs=probs).sample()\n",
" inp = token\n",
" outs.append(token)\n",
" if token.item() == 2: break\n",
" net.train()\n",
" outs = torch.cat(outs, dim=0)\n",
" return torch.squeeze(outs)\n",
"\n",
"\n",
"# !gdown -q --id 1I2VyTpXoC7aZz2t9DpPvSaOOeCrD07oz\n",
"device = 'cuda'\n",
"num_layers=1\n",
"hidden_size = 512\n",
"vocab_size = 199\n",
"net = RNN(vocab_size, hidden_size, vocab_size, num_layers).to(device)\n",
"net.load_state_dict(torch.load('/content/novel_name_model.ckpt')) \n",
"\n",
"import ipywidgets as widgets\n",
"slider = widgets.IntSlider(description='nucleus prob.', min=10, max=100, value=80)\n",
"display(slider)\n",
"button = widgets.Button(description=\"Generate!\")\n",
"output = widgets.Output()\n",
"\n",
"def on_button_clicked(b):\n",
" with output:\n",
" s = sample(net, hidden_size, max_len, vocab_size, device, slider.value/100., num_layers)\n",
" s = train_dataset.decode_text(s.tolist()[1:-1])\n",
" print(s)\n",
"\n",
"button.on_click(on_button_clicked)\n",
"display(button, output)"
],
"execution_count": 15,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6a1742e2c96347feb0011bdf636679b4",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"IntSlider(value=80, description='nucleus prob.', min=10)"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "38dc6ef373c1489d9ba577a4cfdca079",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Button(description='Generate!', style=ButtonStyle())"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ea20d7a3db2b41338cdf788e8b8c8353",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"Output()"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "rwENPbvIb0k2",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ca46f4dd-e175-44e1-9ade-2abd69b3c8ca"
},
"source": [
"for i in range(10):\n",
" s = sample(net, hidden_size, max_len, vocab_size, device, 0.95, num_layers)\n",
" s = train_dataset.decode_text(s.tolist())\n",
" print(s)"
],
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"text": [
"[BEGIN]Đô Thị Chi Truyền Thuyết[END]\n",
"[BEGIN]Chư Thiên Thế Giới Đế Quốc[END]\n",
"[BEGIN]Tà Tuyến Trọng Sinh[END]\n",
"[BEGIN]Họa Thượng[END]\n",
"[BEGIN]Chư Thiên Vạn Giới[END]\n",
"[BEGIN]Bất Hủ Thần Vương[END]\n",
"[BEGIN]Vô Hạn Saukeo[END]\n",
"[BEGIN]Tu Chân Nữ Nhi Diễn Động[END]\n",
"[BEGIN]Thiên Sủng Nữ Phụ Là Ngươi Giả[END]\n",
"[BEGIN]Tiên Viên Đại Thần[END]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "C_3gwjidhBV0"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment