Skip to content

Instantly share code, notes, and snippets.

@shitchell
Created September 5, 2021 07:49
Show Gist options
  • Save shitchell/1c595dd86c7e137468a0021ca98b992f to your computer and use it in GitHub Desktop.
Save shitchell/1c595dd86c7e137468a0021ca98b992f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"custom.ipynb","provenance":[],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.9.4"},"widgets":{"application/vnd.jupyter.widget-state+json":{"f3abf27eca92417ebee3be2b17a92886":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_view_name":"HBoxView","_dom_classes":[],"_model_name":"HBoxModel","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.5.0","box_style":"","layout":"IPY_MODEL_ff60fd902e264150851e2967aa621921","_model_module":"@jupyter-widgets/controls","children":["IPY_MODEL_208c7fc6000e448c99dcb8102c36e99b","IPY_MODEL_a8db00f55129417cba25f9a20f725ad1","IPY_MODEL_e690c237997a4165a53e3824a3b29c87"]}},"ff60fd902e264150851e2967aa621921":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0"
,"grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"208c7fc6000e448c99dcb8102c36e99b":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_485f6855d9ee469784c6e53c7cc1931e","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":"Epoch: 0%","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_48118e589ae34fc5b7253721900c08e9"}},"a8db00f55129417cba25f9a20f725ad1":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_view_name":"ProgressView","style":"IPY_MODEL_213d6f0c38584284b872a93adb610f40","_dom_classes":[],"description":"","_model_name":"FloatProgressModel","bar_style":"danger","max":3,"_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":0,"_view_count":null,"_view_module_version":"1.5.0","orientation":"horizontal","min":0,"description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_53801379ebd341c195139e1aa34429ab"}},"e690c237997a4165a53e3824a3b29c87":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_7bee543070914122b5c19bf1627613cd","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":" 0/3 [02:37&lt;?, 
?it/s]","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_cfe50464710945b891949ee40d982548"}},"485f6855d9ee469784c6e53c7cc1931e":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"48118e589ae34fc5b7253721900c08e9":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"213d6f0c38584284b872a93adb610f40":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"ProgressStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","bar_color"
:null,"_model_module":"@jupyter-widgets/controls"}},"53801379ebd341c195139e1aa34429ab":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"7bee543070914122b5c19bf1627613cd":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"cfe50464710945b891949ee40d982548":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"
overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"2922bf1c2a92405f844d9117996a0af4":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_view_name":"HBoxView","_dom_classes":[],"_model_name":"HBoxModel","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.5.0","box_style":"","layout":"IPY_MODEL_32162f89f7f94cf1ac13854c2ad57b7a","_model_module":"@jupyter-widgets/controls","children":["IPY_MODEL_1b46d9dff0d24d00b445a299e404793f","IPY_MODEL_8bcf4d5e77064772b894c596a615e6d4","IPY_MODEL_4571b318e9134972b7a85e3edfe99867"]}},"32162f89f7f94cf1ac13854c2ad57b7a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"orde
r":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"1b46d9dff0d24d00b445a299e404793f":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_6c2957d5708e457c8095c0dc4c5c600c","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":"Iteration: 12%","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_86fd76831f9b4221be39073eaa78b08c"}},"8bcf4d5e77064772b894c596a615e6d4":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_view_name":"ProgressView","style":"IPY_MODEL_7be09deda8fb4d059bd071b86b50a11a","_dom_classes":[],"description":"","_model_name":"FloatProgressModel","bar_style":"danger","max":1605,"_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":187,"_view_count":null,"_view_module_version":"1.5.0","orientation":"horizontal","min":0,"description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_fb7fbc2fffc34c548b6ed9593255cc08"}},"4571b318e9134972b7a85e3edfe99867":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_view_name":"HTMLView","style":"IPY_MODEL_98b7b71540094b3c8552be1f37205892","_dom_classes":[],"description":"","_model_name":"HTMLModel","placeholder":"​","_view_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","value":" 187/1605 [02:37&lt;21:20, 
1.11it/s]","_view_count":null,"_view_module_version":"1.5.0","description_tooltip":null,"_model_module":"@jupyter-widgets/controls","layout":"IPY_MODEL_e5f7f980f2284389bec9216c22255e77"}},"6c2957d5708e457c8095c0dc4c5c600c":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"86fd76831f9b4221be39073eaa78b08c":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"7be09deda8fb4d059bd071b86b50a11a":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"ProgressStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","bar_col
or":null,"_model_module":"@jupyter-widgets/controls"}},"fb7fbc2fffc34c548b6ed9593255cc08":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":null,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}},"98b7b71540094b3c8552be1f37205892":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_view_name":"StyleView","_model_name":"DescriptionStyleModel","description_width":"","_view_module":"@jupyter-widgets/base","_model_module_version":"1.5.0","_view_count":null,"_view_module_version":"1.2.0","_model_module":"@jupyter-widgets/controls"}},"e5f7f980f2284389bec9216c22255e77":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_view_name":"LayoutView","grid_template_rows":null,"right":null,"justify_content":null,"_view_module":"@jupyter-widgets/base","overflow":null,"_model_module_version":"1.2.0","_view_count":null,"flex_flow":null,"width":null,"min_width":null,"border":null,"align_items":null,"bottom":null,"_model_module":"@jupyter-widgets/base","top":null,"grid_column":nul
l,"overflow_y":null,"overflow_x":null,"grid_auto_flow":null,"grid_area":null,"grid_template_columns":null,"flex":null,"_model_name":"LayoutModel","justify_items":null,"grid_row":null,"max_height":null,"align_content":null,"visibility":null,"align_self":null,"height":null,"min_height":null,"padding":null,"grid_auto_rows":null,"grid_gap":null,"max_width":null,"order":null,"_view_module_version":"1.2.0","grid_template_areas":null,"object_position":null,"object_fit":null,"grid_auto_columns":null,"margin":null,"display":null,"left":null}}}}},"cells":[{"cell_type":"markdown","metadata":{"id":"VTze-VbeU1c0"},"source":["# Fine-tune a DialoGPT model\n","\n","Adapted from the notebook in [this Medium post](https://towardsdatascience.com/make-your-own-rick-sanchez-bot-with-transformers-and-dialogpt-fine-tuning-f85e6d1f4e30?gi=e4a72d1510f0)."]},{"cell_type":"markdown","metadata":{"id":"Y17kuzFNUSrZ"},"source":["## Setup"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"GBfltjGHT6KG","executionInfo":{"status":"ok","timestamp":1630823995900,"user_tz":0,"elapsed":600,"user":{"displayName":"Elite Gaming","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiozIJvY6jRtEga3f1mlOFUILdzQRH1YdPwNodh0ew=s64","userId":"14163393543101114291"}},"outputId":"9cd8fb26-d8e9-4184-f935-80bb81612456"},"source":["from google.colab import drive\n","\n","drive.mount('/content/drive/')"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount(\"/content/drive/\", force_remount=True).\n"]}]},{"cell_type":"code","metadata":{"id":"T8fgmjaqUErq"},"source":["!pip -q install transformers"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"EtCreyG8UG1s"},"source":["import os\n","os.chdir(\"/content/drive/My Drive\")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"dnv5kT-mLsB-"},"source":["# all the 
imports\n","\n","import glob\n","import logging\n","import os\n","import pickle\n","import random\n","import re\n","import shutil\n","from typing import Dict, List, Tuple\n","\n","import numpy as np\n","import pandas as pd\n","\n","from sklearn.model_selection import train_test_split\n","\n","from torch.nn.utils.rnn import pad_sequence\n","from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler\n","from torch.utils.data.distributed import DistributedSampler\n","from tqdm.notebook import tqdm, trange\n","\n","from pathlib import Path\n","import torch\n","torch.cuda.empty_cache()\n","from transformers import (\n"," MODEL_WITH_LM_HEAD_MAPPING,\n"," WEIGHTS_NAME,\n"," AdamW,\n"," AutoConfig,\n"," PreTrainedModel,\n"," PreTrainedTokenizer,\n"," get_linear_schedule_with_warmup,\n",")\n","\n","\n","try:\n"," from torch.utils.tensorboard import SummaryWriter\n","except ImportError:\n"," from tensorboardX import SummaryWriter"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BmrbGB8aUmBm"},"source":["## Get Data from Kaggle"]},{"cell_type":"code","metadata":{"id":"RXdJTSVwWGHj"},"source":["data = pd.read_csv('mcu.csv')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"FaCBUa5j-LQI"},"source":["data.rename(columns={\"character\": \"name\"}, inplace=True)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"h6kGx-9eG7qA","executionInfo":{"status":"ok","timestamp":1630824002800,"user_tz":0,"elapsed":0,"user":{"displayName":"Elite Gaming","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiozIJvY6jRtEga3f1mlOFUILdzQRH1YdPwNodh0ew=s64","userId":"14163393543101114291"}},"outputId":"3f673a6c-56cd-42e6-adfa-e5c1f811f75f"},"source":["sum(data.name==\"TONY 
STARK\")"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["1788"]},"metadata":{},"execution_count":7}]},{"cell_type":"code","metadata":{"id":"PG8v6--qWUwj"},"source":["CHARACTER_NAME = 'TONY STARK'"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"GZUcEMd2WLDT"},"source":["contexted = []\n","\n","# context window of size 7\n","n = 7\n","\n","for i in data[data.name == CHARACTER_NAME].index:\n"," if i < n:\n"," continue\n"," row = []\n"," prev = i - 1 - n # we additionally substract 1, so row will contain current responce and 7 previous responces \n"," for j in range(i, prev, -1):\n"," row.append(data.line[j])\n"," contexted.append(row)\n","\n","columns = ['response', 'context'] \n","columns = columns + ['context/' + str(i) for i in range(n - 1)]\n","\n","df = pd.DataFrame.from_records(contexted, columns=columns)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":533},"id":"4T5OlNZHUxij","executionInfo":{"status":"ok","timestamp":1630824003900,"user_tz":0,"elapsed":200,"user":{"displayName":"Elite Gaming","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiozIJvY6jRtEga3f1mlOFUILdzQRH1YdPwNodh0ew=s64","userId":"14163393543101114291"}},"outputId":"dd6c4a5c-1c34-4498-e2ab-375930721c53"},"source":["df.sample(6)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["<div>\n","<style scoped>\n"," .dataframe tbody tr th:only-of-type {\n"," vertical-align: middle;\n"," }\n","\n"," .dataframe tbody tr th {\n"," vertical-align: top;\n"," }\n","\n"," .dataframe thead th {\n"," text-align: right;\n"," }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n"," <thead>\n"," <tr style=\"text-align: right;\">\n"," <th></th>\n"," <th>response</th>\n"," <th>context</th>\n"," <th>context/0</th>\n"," <th>context/1</th>\n"," <th>context/2</th>\n"," <th>context/3</th>\n"," <th>context/4</th>\n"," 
<th>context/5</th>\n"," </tr>\n"," </thead>\n"," <tbody>\n"," <tr>\n"," <th>522</th>\n"," <td>Tell me.</td>\n"," <td>Who?</td>\n"," <td>Where’s Vanko?</td>\n"," <td>What?</td>\n"," <td>What is he?</td>\n"," <td>Yeah.</td>\n"," <td>Vanko’s alive?</td>\n"," <td>Listen, I think he’s working with Vanko.</td>\n"," </tr>\n"," <tr>\n"," <th>559</th>\n"," <td>I got this. Rhodes. I got an idea. You want t...</td>\n"," <td>Yeah.</td>\n"," <td>With the what? Hammer tech?</td>\n"," <td>This ain’t gonna be good. I got something spe...</td>\n"," <td>Good to be back.</td>\n"," <td>Head up. You got one more drone incoming. This...</td>\n"," <td>Okay.</td>\n"," <td>No, I’m gonna stay until to park is clear.</td>\n"," </tr>\n"," <tr>\n"," <th>24</th>\n"," <td>Maybe. Tell me your plans.</td>\n"," <td>The MIT commencement. Yes or no?</td>\n"," <td>You’re rushing me. What, you have plans tonig...</td>\n"," <td>Focus. I need to leave on time today.</td>\n"," <td>Five? I’ll need a bit longer than that --</td>\n"," <td>Should I tell Mr. Stark you were satisfied wit...</td>\n"," <td>And a coffee, hon. Black. One Splenda.</td>\n"," <td>Cab’s waiting outside.</td>\n"," </tr>\n"," <tr>\n"," <th>430</th>\n"," <td>You could drink that water.</td>\n"," <td>It’s not sexy.</td>\n"," <td>I know. It has a filtration system.</td>\n"," <td>You just peed in the suit.</td>\n"," <td>Come on, you know you want to.</td>\n"," <td>You’re not going to be happy about this.</td>\n"," <td>Give me another smooch</td>\n"," <td>It’s time to go to bed. It’s time.</td>\n"," </tr>\n"," <tr>\n"," <th>914</th>\n"," <td>Including Chad Davis?</td>\n"," <td>Yeah.</td>\n"," <td>Six people died, right?</td>\n"," <td>I guess this guy named Chad Davis, used to liv...</td>\n"," <td>I don't know, later. Hey kid, give me a little...</td>\n"," <td>What about The Avengers, can you talk about them?</td>\n"," <td>Maybe never, relax about it.</td>\n"," <td>She's six! Anyway, it's limited edition. 
When ...</td>\n"," </tr>\n"," <tr>\n"," <th>1323</th>\n"," <td>Me too. Where's yours?</td>\n"," <td>I have an idea.</td>\n"," <td>No.</td>\n"," <td>No. You really think he'd be on our side?</td>\n"," <td>Oh, yeah. It'd be great if we had a Hulk right...</td>\n"," <td>We're seriously understaffed.</td>\n"," <td>Always. 36 hours, jeez.</td>\n"," <td>You alright?</td>\n"," </tr>\n"," </tbody>\n","</table>\n","</div>"],"text/plain":[" response ... context/5\n","522 Tell me. ... Listen, I think he’s working with Vanko.\n","559 I got this. Rhodes. I got an idea. You want t... ... No, I’m gonna stay until to park is clear.\n","24 Maybe. Tell me your plans. ... Cab’s waiting outside.\n","430 You could drink that water. ... It’s time to go to bed. It’s time.\n","914 Including Chad Davis? ... She's six! Anyway, it's limited edition. When ...\n","1323 Me too. Where's yours? ... You alright?\n","\n","[6 rows x 8 columns]"]},"metadata":{},"execution_count":10}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":467},"id":"NGy0MxMQVIAP","executionInfo":{"status":"ok","timestamp":1630824003900,"user_tz":0,"elapsed":200,"user":{"displayName":"Elite Gaming","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiozIJvY6jRtEga3f1mlOFUILdzQRH1YdPwNodh0ew=s64","userId":"14163393543101114291"}},"outputId":"2cd4a6f8-e15a-431e-b12b-e846f4b78977"},"source":["trn_df, val_df = train_test_split(df, test_size=0.1)\n","trn_df.head()"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["<div>\n","<style scoped>\n"," .dataframe tbody tr th:only-of-type {\n"," vertical-align: middle;\n"," }\n","\n"," .dataframe tbody tr th {\n"," vertical-align: top;\n"," }\n","\n"," .dataframe thead th {\n"," text-align: right;\n"," }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n"," <thead>\n"," <tr style=\"text-align: right;\">\n"," <th></th>\n"," <th>response</th>\n"," <th>context</th>\n"," <th>context/0</th>\n"," 
<th>context/1</th>\n"," <th>context/2</th>\n"," <th>context/3</th>\n"," <th>context/4</th>\n"," <th>context/5</th>\n"," </tr>\n"," </thead>\n"," <tbody>\n"," <tr>\n"," <th>547</th>\n"," <td>I’m the one who put you in this position. Forg...</td>\n"," <td>No. I should have trusted you more.</td>\n"," <td>Don’t be.</td>\n"," <td>Yeah, thanks. Tony, look, I’m sorry, okay?</td>\n"," <td>You okay?</td>\n"," <td>Oh, man. You can have your suit back.</td>\n"," <td>Rhodes? Snap out of it buddy. I need you. They...</td>\n"," <td>Oh please.</td>\n"," </tr>\n"," <tr>\n"," <th>1048</th>\n"," <td>The President. Now.</td>\n"," <td>Let us out!</td>\n"," <td>Is anyone there?</td>\n"," <td>Was that Rhodes?</td>\n"," <td>Image coming through now, sir.</td>\n"," <td>Get me eyes on it now.</td>\n"," <td>Sir, Air Force One has been compromised. Inter...</td>\n"," <td>Whoa! Cool your boots, sir. That's not how the...</td>\n"," </tr>\n"," <tr>\n"," <th>85</th>\n"," <td>You’ll see. C’mon --</td>\n"," <td>See this. Huh. Huh. Tony, thought we were me...</td>\n"," <td>No -- to the office. I’ve been in captivity...</td>\n"," <td>We’re due at the hospital.</td>\n"," <td>Wouldn’t dream of it, Sir. Where to, Mr. Stark?</td>\n"," <td>You do something new with your hair?</td>\n"," <td>Good to see you again, Sir.</td>\n"," <td>Tears of joy. I hate job hunting.</td>\n"," </tr>\n"," <tr>\n"," <th>1059</th>\n"," <td>Yeah, but I missed the president.</td>\n"," <td>Oh, thank God.</td>\n"," <td>I think they all made it.</td>\n"," <td>Give me some good news, man.</td>\n"," <td>Nice work, guys! Excellent. Good team effort a...</td>\n"," <td>We made it!</td>\n"," <td>He's a chunky monkey, let's get him. Hello.</td>\n"," <td>1, 000 feet. 400 feet. 200 feet, sir.</td>\n"," </tr>\n"," <tr>\n"," <th>1173</th>\n"," <td>Come on. Use your words, buddy.</td>\n"," <td>Woah, woah, woah! It's going around.</td>\n"," <td>No, Ultron could've assimilated Jarvis. 
This i...</td>\n"," <td>JARVIS was the first line of defense. He would...</td>\n"," <td>This is insane.</td>\n"," <td>Yes there was.</td>\n"," <td>But there wasn't anyone else in the building.</td>\n"," <td>He also said he killed somebody.</td>\n"," </tr>\n"," </tbody>\n","</table>\n","</div>"],"text/plain":[" response ... context/5\n","547 I’m the one who put you in this position. Forg... ... Oh please.\n","1048 The President. Now. ... Whoa! Cool your boots, sir. That's not how the...\n","85 You’ll see. C’mon -- ... Tears of joy. I hate job hunting.\n","1059 Yeah, but I missed the president. ... 1, 000 feet. 400 feet. 200 feet, sir.\n","1173 Come on. Use your words, buddy. ... He also said he killed somebody.\n","\n","[5 rows x 8 columns]"]},"metadata":{},"execution_count":11}]},{"cell_type":"code","metadata":{"id":"aEeJQlAKWtiJ"},"source":["# create dataset suitable for our model\n","def construct_conv(row, tokenizer, eos = True):\n"," flatten = lambda l: [item for sublist in l for item in sublist]\n"," conv = list(reversed([tokenizer.encode(x) + [tokenizer.eos_token_id] for x in row]))\n"," conv = flatten(conv)\n"," return conv\n","\n","class ConversationDataset(Dataset):\n"," def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):\n","\n"," block_size = block_size - (tokenizer.model_max_length - tokenizer.max_len_single_sentence)\n","\n"," directory = args.cache_dir\n"," cached_features_file = os.path.join(\n"," directory, args.model_type + \"_cached_lm_\" + str(block_size)\n"," )\n","\n"," if os.path.exists(cached_features_file) and not args.overwrite_cache:\n"," logger.info(\"Loading features from cached file %s\", cached_features_file)\n"," with open(cached_features_file, \"rb\") as handle:\n"," self.examples = pickle.load(handle)\n"," else:\n"," logger.info(\"Creating features from dataset file at %s\", directory)\n","\n"," self.examples = []\n"," for _, row in df.iterrows():\n"," conv = construct_conv(row, tokenizer)\n"," 
self.examples.append(conv)\n","\n"," logger.info(\"Saving features into cached file %s\", cached_features_file)\n"," with open(cached_features_file, \"wb\") as handle:\n"," pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)\n","\n"," def __len__(self):\n"," return len(self.examples)\n","\n"," def __getitem__(self, item):\n"," return torch.tensor(self.examples[item], dtype=torch.long)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"-3iHwoKlWyrs"},"source":["# Cacheing and storing of data/checkpoints\n","\n","def load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False):\n"," return ConversationDataset(tokenizer, args, df_val if evaluate else df_trn)\n","\n","\n","def set_seed(args):\n"," random.seed(args.seed)\n"," np.random.seed(args.seed)\n"," torch.manual_seed(args.seed)\n"," if args.n_gpu > 0:\n"," torch.cuda.manual_seed_all(args.seed)\n","\n","\n","def _sorted_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> List[str]:\n"," ordering_and_checkpoint_path = []\n","\n"," glob_checkpoints = glob.glob(os.path.join(args.output_dir, \"{}-*\".format(checkpoint_prefix)))\n","\n"," for path in glob_checkpoints:\n"," if use_mtime:\n"," ordering_and_checkpoint_path.append((os.path.getmtime(path), path))\n"," else:\n"," regex_match = re.match(\".*{}-([0-9]+)\".format(checkpoint_prefix), path)\n"," if regex_match and regex_match.groups():\n"," ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n","\n"," checkpoints_sorted = sorted(ordering_and_checkpoint_path)\n"," checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]\n"," return checkpoints_sorted\n","\n","\n","def _rotate_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> None:\n"," if not args.save_total_limit:\n"," return\n"," if args.save_total_limit <= 0:\n"," return\n","\n"," # Check if we should delete older checkpoint(s)\n"," checkpoints_sorted = 
_sorted_checkpoints(args, checkpoint_prefix, use_mtime)\n"," if len(checkpoints_sorted) <= args.save_total_limit:\n"," return\n","\n"," number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)\n"," checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]\n"," for checkpoint in checkpoints_to_be_deleted:\n"," logger.info(\"Deleting older checkpoint [{}] due to args.save_total_limit\".format(checkpoint))\n"," shutil.rmtree(checkpoint)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EEDdTJTqUwZJ"},"source":["## Build Model"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"r2cE0fY5UHpz","executionInfo":{"status":"ok","timestamp":1630824007100,"user_tz":0,"elapsed":3400,"user":{"displayName":"Elite Gaming","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiozIJvY6jRtEga3f1mlOFUILdzQRH1YdPwNodh0ew=s64","userId":"14163393543101114291"}},"outputId":"6bd07153-3f03-469f-beda-c18358165f9f"},"source":["from transformers import AutoModelWithLMHead, AutoModelForCausalLM, AutoTokenizer\n","import torch\n","\n","tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-small\")\n","model = AutoModelWithLMHead.from_pretrained(\"microsoft/DialoGPT-small\")"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.7/dist-packages/transformers/models/auto/modeling_auto.py:592: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. 
Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n"," FutureWarning,\n"]}]},{"cell_type":"code","metadata":{"id":"ra2vsRp-UMXo"},"source":[
"\"\"\"\n",
"Fine-tune DialoGPT (a GPT-2 model) on conversation data with a causal language\n",
"modeling (CLM) loss. The code is adapted from a generic language-modeling\n",
"fine-tuning script (GPT, GPT-2, BERT, RoBERTa); only the GPT-2 CLM path is used here.\n",
"\"\"\"\n",
"\n",
"# Configs\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())\n",
"MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"2OnASqJjUNJa"},"source":[
"# Args to allow for easy conversion of python script to notebook\n",
"class Args():\n",
"    def __init__(self):\n",
"        self.output_dir = 'output-medium'\n",
"        self.model_type = 'gpt2'\n",
"        self.model_name_or_path = 'microsoft/DialoGPT-medium'\n",
"        self.config_name = 'microsoft/DialoGPT-medium'\n",
"        self.tokenizer_name = 'microsoft/DialoGPT-medium'\n",
"        self.cache_dir = 'cached'\n",
"        self.block_size = 512\n",
"        self.do_train = True\n",
"        self.do_eval = True\n",
"        self.evaluate_during_training = False\n",
"        # If training runs out of GPU memory, lower the per-GPU batch size (and/or\n",
"        # block_size) and raise gradient_accumulation_steps to compensate.\n",
"        self.per_gpu_train_batch_size = 4\n",
"        self.per_gpu_eval_batch_size = 4\n",
"        self.gradient_accumulation_steps = 1\n",
"        self.learning_rate = 5e-5\n",
"        self.weight_decay = 0.0\n",
"        self.adam_epsilon = 1e-8\n",
"        self.max_grad_norm = 1.0\n",
"        self.num_train_epochs = 8\n",
"        self.max_steps = -1\n",
"        self.warmup_steps = 0\n",
"        self.logging_steps = 1000\n",
"        self.save_steps = 3500\n",
"        self.save_total_limit = None  # None keeps every checkpoint\n",
"        self.eval_all_checkpoints = False\n",
"        self.no_cuda = False  # True forces CPU even when a GPU is available\n",
"        self.overwrite_output_dir = True\n",
"        self.overwrite_cache = True\n",
"        self.should_continue = False\n",
"        self.seed = 42\n",
"        self.local_rank = -1\n",
"        self.fp16 = False\n",
"        self.fp16_opt_level = 'O1'\n",
"\n",
"args = Args()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"9Q1dTFXxW9NE"},"source":["## Train and Evaluate"]},{"cell_type":"code","metadata":{"id":"PaarIDZrW81h"},"source":[
"def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, df_trn=None, df_val=None) -> Tuple[int, float]:\n",
"    \"\"\"Fine-tune `model` on `train_dataset`.\n",
"\n",
"    Returns (global_step, average training loss per optimization step).\n",
"    df_trn / df_val are optional and only used when args.evaluate_during_training\n",
"    is set: they are forwarded to `evaluate`. Without df_val that periodic\n",
"    evaluation is skipped.\n",
"    \"\"\"\n",
"    if args.local_rank in [-1, 0]:\n",
"        tb_writer = SummaryWriter()\n",
"\n",
"    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)\n",
"\n",
"    def collate(examples: List[torch.Tensor]):\n",
"        if tokenizer._pad_token is None:\n",
"            return pad_sequence(examples, batch_first=True)\n",
"        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
"\n",
"    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)\n",
"    train_dataloader = DataLoader(\n",
"        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate, drop_last=True\n",
"    )\n",
"\n",
"    if args.max_steps > 0:\n",
"        t_total = args.max_steps\n",
"        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1\n",
"    else:\n",
"        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs\n",
"\n",
"    model = model.module if hasattr(model, \"module\") else model  # Take care of distributed/parallel training\n",
"    model.resize_token_embeddings(len(tokenizer))\n",
"\n",
"    # Prepare optimizer and schedule (linear warmup and decay)\n",
"    no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
"    optimizer_grouped_parameters = [\n",
"        {\n",
"            \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
"            \"weight_decay\": args.weight_decay,\n",
"        },\n",
"        {\"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], \"weight_decay\": 0.0},\n",
"    ]\n",
"    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)\n",
"    scheduler = get_linear_schedule_with_warmup(\n",
"        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total\n",
"    )\n",
"\n",
"    # Check if saved optimizer or scheduler states exist\n",
"    if (\n",
"        args.model_name_or_path\n",
"        and os.path.isfile(os.path.join(args.model_name_or_path, \"optimizer.pt\"))\n",
"        and os.path.isfile(os.path.join(args.model_name_or_path, \"scheduler.pt\"))\n",
"    ):\n",
"        # Load in optimizer and scheduler states\n",
"        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"optimizer.pt\")))\n",
"        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"scheduler.pt\")))\n",
"\n",
"    if args.fp16:\n",
"        try:\n",
"            from apex import amp\n",
"        except ImportError:\n",
"            raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.\")\n",
"        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)\n",
"\n",
"    # multi-gpu training (should be after apex fp16 initialization)\n",
"    if args.n_gpu > 1:\n",
"        model = torch.nn.DataParallel(model)\n",
"\n",
"    # Distributed training (should be after apex fp16 initialization)\n",
"    if args.local_rank != -1:\n",
"        model = torch.nn.parallel.DistributedDataParallel(\n",
"            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True\n",
"        )\n",
"\n",
"    # Train!\n",
"    logger.info(\"***** Running training *****\")\n",
"    logger.info(\" Num examples = %d\", len(train_dataset))\n",
"    logger.info(\" Num Epochs = %d\", args.num_train_epochs)\n",
"    logger.info(\" Instantaneous batch size per GPU = %d\", args.per_gpu_train_batch_size)\n",
"    logger.info(\n",
"        \" Total train batch size (w. parallel, distributed & accumulation) = %d\",\n",
"        args.train_batch_size\n",
"        * args.gradient_accumulation_steps\n",
"        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),\n",
"    )\n",
"    logger.info(\" Gradient Accumulation steps = %d\", args.gradient_accumulation_steps)\n",
"    logger.info(\" Total optimization steps = %d\", t_total)\n",
"\n",
"    global_step = 0\n",
"    epochs_trained = 0\n",
"    steps_trained_in_current_epoch = 0\n",
"    # Check if continuing training from a checkpoint\n",
"    if args.model_name_or_path and os.path.exists(args.model_name_or_path):\n",
"        try:\n",
"            # set global_step to global_step of last saved checkpoint from model path\n",
"            checkpoint_suffix = args.model_name_or_path.split(\"-\")[-1].split(\"/\")[0]\n",
"            global_step = int(checkpoint_suffix)\n",
"            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)\n",
"            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)\n",
"\n",
"            logger.info(\" Continuing training from checkpoint, will skip to saved global_step\")\n",
"            logger.info(\" Continuing training from epoch %d\", epochs_trained)\n",
"            logger.info(\" Continuing training from global step %d\", global_step)\n",
"            logger.info(\" Will skip the first %d steps in the first epoch\", steps_trained_in_current_epoch)\n",
"        except ValueError:\n",
"            logger.info(\" Starting fine-tuning.\")\n",
"\n",
"    tr_loss, logging_loss = 0.0, 0.0\n",
"\n",
"    model.zero_grad()\n",
"    train_iterator = trange(\n",
"        epochs_trained, int(args.num_train_epochs), desc=\"Epoch\", disable=args.local_rank not in [-1, 0]\n",
"    )\n",
"    set_seed(args)  # Added here for reproducibility\n",
"    for _ in train_iterator:\n",
"        epoch_iterator = tqdm(train_dataloader, desc=\"Iteration\", disable=args.local_rank not in [-1, 0])\n",
"        for step, batch in enumerate(epoch_iterator):\n",
"\n",
"            # Skip past any already trained steps if resuming training\n",
"            if steps_trained_in_current_epoch > 0:\n",
"                steps_trained_in_current_epoch -= 1\n",
"                continue\n",
"\n",
"            inputs, labels = (batch, batch)\n",
"            # Skip over-long batches; 1024 is presumably the model's position limit\n",
"            # (GPT-2 n_positions) -- confirm against model.config if the model changes.\n",
"            if inputs.shape[1] > 1024:\n",
"                continue\n",
"            inputs = inputs.to(args.device)\n",
"            labels = labels.to(args.device)\n",
"            model.train()\n",
"            outputs = model(inputs, labels=labels)\n",
"            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)\n",
"\n",
"            if args.n_gpu > 1:\n",
"                loss = loss.mean()  # mean() to average on multi-gpu parallel training\n",
"            if args.gradient_accumulation_steps > 1:\n",
"                loss = loss / args.gradient_accumulation_steps\n",
"\n",
"            if args.fp16:\n",
"                with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
"                    scaled_loss.backward()\n",
"                torch.cuda.empty_cache()\n",
"            else:\n",
"                loss.backward()\n",
"\n",
"            tr_loss += loss.item()\n",
"            if (step + 1) % args.gradient_accumulation_steps == 0:\n",
"                if args.fp16:\n",
"                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)\n",
"                else:\n",
"                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n",
"                torch.cuda.empty_cache()\n",
"                # Apply the accumulated gradients BEFORE stepping the LR schedule. Without\n",
"                # optimizer.step() the weights never change (and PyTorch warns that\n",
"                # lr_scheduler.step() was called before optimizer.step()).\n",
"                optimizer.step()\n",
"                scheduler.step()  # Update learning rate schedule\n",
"                model.zero_grad()\n",
"                global_step += 1\n",
"\n",
"                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:\n",
"                    # Log metrics\n",
"                    if (\n",
"                        args.local_rank == -1 and args.evaluate_during_training and df_val is not None\n",
"                    ):  # Only evaluate when single GPU otherwise metrics may not average well\n",
"                        results = evaluate(args, model, tokenizer, df_trn, df_val)\n",
"                        for key, value in results.items():\n",
"                            tb_writer.add_scalar(\"eval_{}\".format(key), value, global_step)\n",
"                    tb_writer.add_scalar(\"lr\", scheduler.get_last_lr()[0], global_step)\n",
"                    tb_writer.add_scalar(\"loss\", (tr_loss - logging_loss) / args.logging_steps, global_step)\n",
"                    logging_loss = tr_loss\n",
"\n",
"                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:\n",
"                    checkpoint_prefix = \"checkpoint\"\n",
"                    # Save model checkpoint\n",
"                    output_dir = os.path.join(args.output_dir, \"{}-{}\".format(checkpoint_prefix, global_step))\n",
"                    os.makedirs(output_dir, exist_ok=True)\n",
"                    model_to_save = (\n",
"                        model.module if hasattr(model, \"module\") else model\n",
"                    )  # Take care of distributed/parallel training\n",
"                    model_to_save.save_pretrained(output_dir)\n",
"                    tokenizer.save_pretrained(output_dir)\n",
"\n",
"                    torch.save(args, os.path.join(output_dir, \"training_args.bin\"))\n",
"                    logger.info(\"Saving model checkpoint to %s\", output_dir)\n",
"\n",
"                    _rotate_checkpoints(args, checkpoint_prefix)\n",
"\n",
"                    torch.save(optimizer.state_dict(), os.path.join(output_dir, \"optimizer.pt\"))\n",
"                    torch.save(scheduler.state_dict(), os.path.join(output_dir, \"scheduler.pt\"))\n",
"                    logger.info(\"Saving optimizer and scheduler states to %s\", output_dir)\n",
"\n",
"            if args.max_steps > 0 and global_step > args.max_steps:\n",
"                epoch_iterator.close()\n",
"                break\n",
"        if args.max_steps > 0 and global_step > args.max_steps:\n",
"            train_iterator.close()\n",
"            break\n",
"\n",
"    if args.local_rank in [-1, 0]:\n",
"        tb_writer.close()\n",
"\n",
"    # max(..., 1) avoids ZeroDivisionError when no optimization step was taken\n",
"    return global_step, tr_loss / max(global_step, 1)\n",
"\n",
"# Evaluation of some model\n",
"\n",
"def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, df_trn, df_val, prefix=\"\") -> Dict:\n",
"    \"\"\"Compute perplexity on the evaluation dataset built from df_val.\n",
"\n",
"    Returns {\"perplexity\": float} and writes it to\n",
"    os.path.join(args.output_dir, prefix, \"eval_results.txt\").\n",
"    Raises ValueError if the evaluation set yields no full batch.\n",
"    \"\"\"\n",
"    eval_output_dir = args.output_dir\n",
"\n",
"    eval_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=True)\n",
"    os.makedirs(eval_output_dir, exist_ok=True)\n",
"    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)\n",
"\n",
"    def collate(examples: List[torch.Tensor]):\n",
"        if tokenizer._pad_token is None:\n",
"            return pad_sequence(examples, batch_first=True)\n",
"        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
"\n",
"    eval_sampler = SequentialSampler(eval_dataset)\n",
"    eval_dataloader = DataLoader(\n",
"        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate, drop_last=True\n",
"    )\n",
"\n",
"    # multi-gpu evaluate; train() may already pass a DataParallel model, so don't wrap it twice\n",
"    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):\n",
"        model = torch.nn.DataParallel(model)\n",
"\n",
"    # Eval!\n",
"    logger.info(\"***** Running evaluation {} *****\".format(prefix))\n",
"    logger.info(\" Num examples = %d\", len(eval_dataset))\n",
"    logger.info(\" Batch size = %d\", args.eval_batch_size)\n",
"    eval_loss = 0.0\n",
"    nb_eval_steps = 0\n",
"    model.eval()\n",
"\n",
"    for batch in tqdm(eval_dataloader, desc=\"Evaluating\"):\n",
"        inputs, labels = (batch, batch)\n",
"        inputs = inputs.to(args.device)\n",
"        labels = labels.to(args.device)\n",
"        torch.cuda.empty_cache()\n",
"        with torch.no_grad():\n",
"            outputs = model(inputs, labels=labels)\n",
"            lm_loss = outputs[0]\n",
"            torch.cuda.empty_cache()\n",
"            eval_loss += lm_loss.mean().item()\n",
"        nb_eval_steps += 1\n",
"\n",
"    if nb_eval_steps == 0:\n",
"        # drop_last=True discards a trailing partial batch, so a set smaller than one batch yields nothing\n",
"        raise ValueError(\n",
"            \"No evaluation batches: {} examples with eval_batch_size={} and drop_last=True.\".format(\n",
"                len(eval_dataset), args.eval_batch_size\n",
"            )\n",
"        )\n",
"\n",
"    eval_loss = eval_loss / nb_eval_steps\n",
"    # .item() stores a plain float (not a tensor) in the results dict and file\n",
"    perplexity = torch.exp(torch.tensor(eval_loss)).item()\n",
"\n",
"    result = {\"perplexity\": perplexity}\n",
"\n",
"    output_eval_file = os.path.join(eval_output_dir, prefix, \"eval_results.txt\")\n",
"    # prefix may name a sub-directory (e.g. a checkpoint) that does not exist yet\n",
"    os.makedirs(os.path.dirname(output_eval_file), exist_ok=True)\n",
"    with open(output_eval_file, \"w\") as writer:\n",
"        logger.info(\"***** Eval results {} *****\".format(prefix))\n",
"        for key in sorted(result.keys()):\n",
"            logger.info(\" %s = %s\", key, str(result[key]))\n",
"            writer.write(\"%s = %s\\n\" % (key, str(result[key])))\n",
"\n",
"    return result"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"SCnGAJWbXD9C"},"source":[
"# Main runner\n",
"\n",
"def main(df_trn, df_val):\n",
"    \"\"\"Train and/or evaluate a model as configured by Args().\n",
"\n",
"    df_trn / df_val are handed to load_and_cache_examples to build the training and\n",
"    evaluation datasets. Returns a dict of evaluation metrics (empty if do_eval is off).\n",
"    \"\"\"\n",
"    args = Args()\n",
"    torch.cuda.empty_cache()\n",
"    if args.should_continue:\n",
"        sorted_checkpoints = _sorted_checkpoints(args)\n",
"        if len(sorted_checkpoints) == 0:\n",
"            raise ValueError(\"Used --should_continue but no checkpoint was found in --output_dir.\")\n",
"        else:\n",
"            args.model_name_or_path = sorted_checkpoints[-1]\n",
"\n",
"    if (\n",
"        os.path.exists(args.output_dir)\n",
"        and os.listdir(args.output_dir)\n",
"        and args.do_train\n",
"        and not args.overwrite_output_dir\n",
"        and not args.should_continue\n",
"    ):\n",
"        raise ValueError(\n",
"            \"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.\".format(\n",
"                args.output_dir\n",
"            )\n",
"        )\n",
"\n",
"    # Setup CUDA, GPU & distributed training. Fall back to CPU when CUDA is\n",
"    # unavailable or args.no_cuda is set, instead of always requesting \"cuda\".\n",
"    use_cuda = torch.cuda.is_available() and not args.no_cuda\n",
"    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
"    args.n_gpu = torch.cuda.device_count() if use_cuda else 0\n",
"    args.device = device\n",
"    torch.cuda.empty_cache()\n",
"    # Setup logging\n",
"    logging.basicConfig(\n",
"        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
"        datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
"        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,\n",
"    )\n",
"    logger.warning(\n",
"        \"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s\",\n",
"        args.local_rank,\n",
"        device,\n",
"        args.n_gpu,\n",
"        bool(args.local_rank != -1),\n",
"        args.fp16,\n",
"    )\n",
"\n",
"    # Set seed\n",
"    set_seed(args)\n",
"    torch.cuda.empty_cache()\n",
"    config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)\n",
"    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)\n",
"    model = AutoModelForCausalLM.from_pretrained(\n",
"        args.model_name_or_path,\n",
"        from_tf=False,\n",
"        config=config,\n",
"        cache_dir=args.cache_dir,\n",
"    )\n",
"    torch.cuda.empty_cache()\n",
"    model.to(args.device)\n",
"    torch.cuda.empty_cache()\n",
"    # vars() logs the actual settings instead of the default object repr\n",
"    logger.info(\"Training/evaluation parameters %s\", vars(args))\n",
"\n",
"    # Training\n",
"    if args.do_train:\n",
"        train_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False)\n",
"\n",
"        global_step, tr_loss = train(args, train_dataset, model, tokenizer, df_trn=df_trn, df_val=df_val)\n",
"        logger.info(\" global_step = %s, average loss = %s\", global_step, tr_loss)\n",
"        torch.cuda.empty_cache()\n",
"    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()\n",
"    if args.do_train:\n",
"        # Create output directory if needed\n",
"        os.makedirs(args.output_dir, exist_ok=True)\n",
"\n",
"        logger.info(\"Saving model checkpoint to %s\", args.output_dir)\n",
"        # Save a trained model, configuration and tokenizer using `save_pretrained()`.\n",
"        # They can then be reloaded using `from_pretrained()`\n",
"        model_to_save = (\n",
"            model.module if hasattr(model, \"module\") else model\n",
"        )  # Take care of distributed/parallel training\n",
"        model_to_save.save_pretrained(args.output_dir)\n",
"        tokenizer.save_pretrained(args.output_dir)\n",
"        torch.cuda.empty_cache()\n",
"        # Good practice: save your training arguments together with the trained model\n",
"        torch.save(args, os.path.join(args.output_dir, \"training_args.bin\"))\n",
"\n",
"        # Load a trained model and vocabulary that you have fine-tuned\n",
"        model = AutoModelForCausalLM.from_pretrained(args.output_dir)\n",
"        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)\n",
"        model.to(args.device)\n",
"\n",
"    # Evaluation\n",
"    results = {}\n",
"    if args.do_eval and args.local_rank in [-1, 0]:\n",
"        checkpoints = [args.output_dir]\n",
"        if args.eval_all_checkpoints:\n",
"            checkpoints = list(\n",
"                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + \"/**/\" + WEIGHTS_NAME, recursive=True))\n",
"            )\n",
"            logging.getLogger(\"transformers.modeling_utils\").setLevel(logging.WARN)  # Reduce logging\n",
"        logger.info(\"Evaluate the following checkpoints: %s\", checkpoints)\n",
"        for checkpoint in checkpoints:\n",
"            global_step = checkpoint.split(\"-\")[-1] if len(checkpoints) > 1 else \"\"\n",
"            prefix = checkpoint.split(\"/\")[-1] if checkpoint.find(\"checkpoint\") != -1 else \"\"\n",
"\n",
"            model = AutoModelForCausalLM.from_pretrained(checkpoint)\n",
"            model.to(args.device)\n",
"            result = evaluate(args, model, tokenizer, df_trn, df_val, prefix=prefix)\n",
"            result = {k + \"_{}\".format(global_step): v for k, v in result.items()}\n",
"            results.update(result)\n",
"\n",
"    return results"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7NWvkdR-XHeB"},"source":["## Run the 
Main Function"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":710,"referenced_widgets":["f3abf27eca92417ebee3be2b17a92886","ff60fd902e264150851e2967aa621921","208c7fc6000e448c99dcb8102c36e99b","a8db00f55129417cba25f9a20f725ad1","e690c237997a4165a53e3824a3b29c87","485f6855d9ee469784c6e53c7cc1931e","48118e589ae34fc5b7253721900c08e9","213d6f0c38584284b872a93adb610f40","53801379ebd341c195139e1aa34429ab","7bee543070914122b5c19bf1627613cd","cfe50464710945b891949ee40d982548","2922bf1c2a92405f844d9117996a0af4","32162f89f7f94cf1ac13854c2ad57b7a","1b46d9dff0d24d00b445a299e404793f","8bcf4d5e77064772b894c596a615e6d4","4571b318e9134972b7a85e3edfe99867","6c2957d5708e457c8095c0dc4c5c600c","86fd76831f9b4221be39073eaa78b08c","7be09deda8fb4d059bd071b86b50a11a","fb7fbc2fffc34c548b6ed9593255cc08","98b7b71540094b3c8552be1f37205892","e5f7f980f2284389bec9216c22255e77"]},"id":"e61zo2JtXGNX","executionInfo":{"status":"error","timestamp":1630824185200,"user_tz":0,"elapsed":177400,"user":{"displayName":"Elite Gaming","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiozIJvY6jRtEga3f1mlOFUILdzQRH1YdPwNodh0ew=s64","userId":"14163393543101114291"}},"outputId":"2a096ccd-36cd-4bef-cc4b-7ee595b0b6c4"},"source":["torch.cuda.empty_cache()\n","main(trn_df, val_df)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["09/05/2021 06:40:07 - WARNING - __main__ - Process rank: -1, device: cuda, n_gpu: 1, distributed training: False, 16-bits training: False\n","/usr/local/lib/python3.7/dist-packages/transformers/models/auto/modeling_auto.py:592: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. 
Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n"," FutureWarning,\n","09/05/2021 06:40:23 - INFO - __main__ - Training/evaluation parameters <__main__.Args object at 0x7f5fb7a2a290>\n","09/05/2021 06:40:23 - INFO - __main__ - Creating features from dataset file at cached\n","09/05/2021 06:40:26 - INFO - __main__ - Saving features into cached file cached/gpt2_cached_lm_512\n","09/05/2021 06:40:26 - INFO - __main__ - ***** Running training *****\n","09/05/2021 06:40:26 - INFO - __main__ - Num examples = 1605\n","09/05/2021 06:40:26 - INFO - __main__ - Num Epochs = 3\n","09/05/2021 06:40:26 - INFO - __main__ - Instantaneous batch size per GPU = 1\n","09/05/2021 06:40:26 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 1\n","09/05/2021 06:40:26 - INFO - __main__ - Gradient Accumulation steps = 1\n","09/05/2021 06:40:26 - INFO - __main__ - Total optimization steps = 4815\n"]},{"output_type":"display_data","data":{"application/vnd.jupyter.widget-view+json":{"model_id":"f3abf27eca92417ebee3be2b17a92886","version_minor":0,"version_major":2},"text/plain":["Epoch: 0%| | 0/3 [00:00<?, ?it/s]"]},"metadata":{}},{"output_type":"display_data","data":{"application/vnd.jupyter.widget-view+json":{"model_id":"2922bf1c2a92405f844d9117996a0af4","version_minor":0,"version_major":2},"text/plain":["Iteration: 0%| | 0/1605 [00:00<?, ?it/s]"]},"metadata":{}},{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.7/dist-packages/torch/optim/lr_scheduler.py:134: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. 
See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n"," \"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\", UserWarning)\n"]},{"output_type":"error","ename":"RuntimeError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)","\u001b[0;32m<ipython-input-19-0337229e2a48>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mempty_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mmain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrn_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;32m<ipython-input-18-4e273e7155ec>\u001b[0m in \u001b[0;36mmain\u001b[0;34m(df_trn, df_val)\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0mtrain_dataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_and_cache_examples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdf_trn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdf_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 66\u001b[0;31m \u001b[0mglobal_step\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mtokenizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 67\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\" global_step = %s, average loss = %s\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mglobal_step\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mempty_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m<ipython-input-17-c6170f9e4b03>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(args, train_dataset, model, tokenizer)\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mempty_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 137\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 138\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0mtr_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 253\u001b[0m 
\u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m inputs=inputs)\n\u001b[0;32m--> 255\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 256\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 147\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 148\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. 
Tried to allocate 246.00 MiB (GPU 0; 11.17 GiB total capacity; 10.20 GiB already allocated; 229.81 MiB free; 10.49 GiB reserved in total by PyTorch)"]}]},{"cell_type":"markdown","metadata":{"id":"YRpQ_n2zXQj-"},"source":["## Load the Trained Model"]},{"cell_type":"code","metadata":{"id":"HGw3qgfaXQHX"},"source":[
"# Load the fine-tuned model and tokenizer that main() saved to args.output_dir\n",
"tokenizer = AutoTokenizer.from_pretrained(args.output_dir)\n",
"model = AutoModelForCausalLM.from_pretrained(args.output_dir)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"RGjTen7cfT_Q"},"source":[""],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"lAWsiAvNXbxd"},"source":[
"# Let's chat for 4 lines\n",
"for step in range(4):\n",
"    # encode the new user input, add the eos_token and return a tensor in Pytorch\n",
"    new_user_input_ids = tokenizer.encode(input(\">> User:\") + tokenizer.eos_token, return_tensors='pt')\n",
"\n",
"    # append the new user input tokens to the chat history\n",
"    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids\n",
"\n",
"    # generate a response; max_length bounds the whole sequence (chat history + reply) at 200 tokens,\n",
"    # so a long history leaves little room for the reply\n",
"    chat_history_ids = model.generate(\n",
"        bot_input_ids, max_length=200,\n",
"        pad_token_id=tokenizer.eos_token_id,\n",
"        no_repeat_ngram_size=3,\n",
"        do_sample=True,\n",
"        top_k=100,\n",
"        top_p=0.7,\n",
"        temperature=0.8\n",
"    )\n",
"\n",
"    # pretty print last output tokens from bot\n",
"    print(\"JoshuaBot: {}\".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"D_XfXTCrZKmO"},"source":["## All Done!"]},{"cell_type":"code","metadata":{"id":"_tIwK7G8ZLrd"},"source":[""],"execution_count":null,"outputs":[]}]}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment