Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save prasadwrites/aecad8430c1b71695d81c408a667b976 to your computer and use it in GitHub Desktop.
Save prasadwrites/aecad8430c1b71695d81c408a667b976 to your computer and use it in GitHub Desktop.
Lab_3_fine_tune_model_to_detoxify_summaries.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/prasadwrites/aecad8430c1b71695d81c408a667b976/lab_3_fine_tune_model_to_detoxify_summaries.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"id": "d7f10b19-c061-42b2-8b42-e3cbafa3b1da",
"metadata": {
"id": "d7f10b19-c061-42b2-8b42-e3cbafa3b1da"
},
"source": [
"# Fine-Tune FLAN-T5 with Reinforcement Learning (PPO) and PEFT to Generate Less-Toxic Summaries"
]
},
{
"cell_type": "markdown",
"id": "36ef668a-9c51-489b-be47-a07a09ef2289",
"metadata": {
"id": "36ef668a-9c51-489b-be47-a07a09ef2289"
},
"source": [
"In this notebook, you will fine-tune a FLAN-T5 model to generate less toxic content with Meta AI's hate speech reward model. The reward model is a binary classifier that predicts either \"not hate\" or \"hate\" for the given text. You will use Proximal Policy Optimization (PPO) to fine-tune and reduce the model's toxicity."
]
},
{
"cell_type": "markdown",
"id": "ed5003e2-a642-416b-bd3f-93fb339c3a7d",
"metadata": {
"tags": [],
"id": "ed5003e2-a642-416b-bd3f-93fb339c3a7d"
},
"source": [
"# Table of Contents"
]
},
{
"cell_type": "markdown",
"id": "6791f449-d1da-461b-9eb7-c6dc6b37b15c",
"metadata": {
"tags": [],
"id": "6791f449-d1da-461b-9eb7-c6dc6b37b15c"
},
"source": [
"- [ 1 - Set up Kernel and Required Dependencies](#1)\n",
"- [ 2 - Load FLAN-T5 Model, Prepare Reward Model and Toxicity Evaluator](#2)\n",
" - [ 2.1 - Load Data and FLAN-T5 Model Fine-Tuned with Summarization Instruction](#2.1)\n",
" - [ 2.2 - Prepare Reward Model](#2.2)\n",
" - [ 2.3 - Evaluate Toxicity](#2.3)\n",
"- [ 3 - Perform Fine-Tuning to Detoxify the Summaries](#3)\n",
" - [ 3.1 - Initialize `PPOTrainer`](#3.1)\n",
" - [ 3.2 - Fine-Tune the Model](#3.2)\n",
" - [ 3.3 - Evaluate the Model Quantitatively](#3.3)\n",
" - [ 3.4 - Evaluate the Model Qualitatively](#3.4)"
]
},
{
"cell_type": "markdown",
"id": "89f973f1-f095-4915-86d0-bc16380da22d",
"metadata": {
"tags": [],
"id": "89f973f1-f095-4915-86d0-bc16380da22d"
},
"source": [
"<a name='1'></a>\n",
"## 1 - Set up Kernel and Required Dependencies"
]
},
{
"cell_type": "markdown",
"id": "ba0923b1-daaa-437c-a604-455d839b9877",
"metadata": {
"tags": [],
"id": "ba0923b1-daaa-437c-a604-455d839b9877"
},
"source": [
"First, check that the correct kernel is chosen.\n",
"\n",
"<img src=\"images/kernel_set_up.png\" width=\"300\"/>\n",
"\n",
"You can click on that (top right of the screen) to see and check the details of the image, kernel, and instance type.\n",
"\n",
"<img src=\"images/w3_kernel_and_instance_type.png\" width=\"600\"/>"
]
},
{
"cell_type": "markdown",
"id": "5e38b387-d92d-4361-8112-37c908f25ff1",
"metadata": {
"tags": [],
"id": "5e38b387-d92d-4361-8112-37c908f25ff1"
},
"source": [
"Now install the required packages to use PyTorch and Hugging Face transformers and datasets.\n",
"\n",
"<img src=\"data:image/svg+xml;base64,Cjxzdmcgd2lkdGg9IjgwMCIgaGVpZ2h0PSIxMjUiIHZpZXdCb3g9IjAgMCA4MDAgMTI1IiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogICAgPGRlZnM+CiAgICAgICAgPGxpbmVhckdyYWRpZW50IGlkPSJmYWRlR3JhZGllbnQiIHgxPSIwIiB4Mj0iMSI+CiAgICAgICAgICAgIDxzdG9wIG9mZnNldD0iMCUiIHN0b3AtY29sb3I9IiNGMEYwRjAiLz4KICAgICAgICAgICAgPHN0b3Agb2Zmc2V0PSIxMDAlIiBzdG9wLWNvbG9yPSIjRjBGMEYwIiBzdG9wLW9wYWNpdHk9IjAiLz4KICAgICAgICA8L2xpbmVhckdyYWRpZW50PgogICAgICAgIDxtYXNrIGlkPSJmYWRlTWFzayI+CiAgICAgICAgICAgIDxyZWN0IHg9IjAiIHk9IjAiIHdpZHRoPSI3NTAiIGhlaWdodD0iMTI1IiBmaWxsPSJ3aGl0ZSIvPgogICAgICAgICAgICA8cmVjdCB4PSI3NTAiIHk9IjAiIHdpZHRoPSI1MCIgaGVpZ2h0PSIxMjUiIGZpbGw9InVybCgjZmFkZUdyYWRpZW50KSIvPgogICAgICAgIDwvbWFzaz4KICAgIDwvZGVmcz4KICAgIDxwYXRoIGQ9Ik0zLDUwIEE1MCw1MCAwIDAgMSA1MywzIEw3OTcsMyBMNzk3LDk3IEw5Nyw5NyBMNTAsMTE1IEwzLDk3IFoiIGZpbGw9IiNGMEYwRjAiIHN0cm9rZT0iI0UwRTBFMCIgc3Ryb2tlLXdpZHRoPSIxIiBtYXNrPSJ1cmwoI2ZhZGVNYXNrKSIvPgogICAgPGNpcmNsZSBjeD0iNTAiIGN5PSI1MCIgcj0iMzAiIGZpbGw9IiM1N2M0ZjgiIHN0cm9rZT0iIzU3YzRmOCIgc3Ryb2tlLXdpZHRoPSIxIi8+CiAgICA8Y2lyY2xlIGN4PSI1MCIgY3k9IjUwIiByPSIyNSIgZmlsbD0iI0YwRjBGMCIvPgogICAgPGxpbmUgeDE9IjUwIiB5MT0iNTAiIHgyPSI1MCIgeTI9IjMwIiBzdHJva2U9IiM1N2M0ZjgiIHN0cm9rZS13aWR0aD0iMyIgc3Ryb2tlLWxpbmVjYXA9InJvdW5kIi8+CiAgICA8bGluZSB4MT0iNTAiIHkxPSI1MCIgeDI9IjY1IiB5Mj0iNTAiIHN0cm9rZT0iIzU3YzRmOCIgc3Ryb2tlLXdpZHRoPSIzIiBzdHJva2UtbGluZWNhcD0icm91bmQiLz4KICAgIDx0ZXh0IHg9IjEwMCIgeT0iMzQiIGZvbnQtZmFtaWx5PSJBcmlhbCwgc2Fucy1zZXJpZiIgZm9udC1zaXplPSIxNCIgZmlsbD0iIzMzMzMzMyI+VGhlIG5leHQgY2VsbCBtYXkgdGFrZSBhIGZldyBtaW51dGVzIHRvIHJ1bi4gUGxlYXNlIGJlIHBhdGllbnQuPC90ZXh0PgogICAgPHRleHQgeD0iMTAwIiB5PSI1NiIgZm9udC1mYW1pbHk9IkFyaWFsLCBzYW5zLXNlcmlmIiBmb250LXNpemU9IjE0IiBmaWxsPSIjMzMzMzMzIj5JZ25vcmUgdGhlIHdhcm5pbmdzIGFuZCBlcnJvcnMsIGFsb25nIHdpdGggdGhlIG5vdGUgYWJvdXQgcmVzdGFydGluZyB0aGUga2VybmVsIGF0IHRoZSBlbmQuPC90ZXh0Pgo8L3N2Zz4K\" alt=\"Time alert open medium\"/>"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9d24e86-f76f-4a44-90ef-0777752075a8",
"metadata": {
"tags": [],
"id": "f9d24e86-f76f-4a44-90ef-0777752075a8",
"outputId": "a7f0bf48-50f4-45ef-8f1e-88ff79ae0749"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pip in /opt/conda/lib/python3.7/site-packages (23.1.2)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"pytest-astropy 0.8.0 requires pytest-cov>=2.0, which is not installed.\n",
"pytest-astropy 0.8.0 requires pytest-filter-subpackage>=0.1, which is not installed.\n",
"spyder 4.0.1 requires pyqt5<5.13; python_version >= \"3\", which is not installed.\n",
"spyder 4.0.1 requires pyqtwebengine<5.13; python_version >= \"3\", which is not installed.\n",
"sagemaker 2.165.0 requires importlib-metadata<5.0,>=1.4.0, but you have importlib-metadata 6.6.0 which is incompatible.\n",
"sparkmagic 0.20.4 requires nest-asyncio==1.5.5, but you have nest-asyncio 1.5.6 which is incompatible.\n",
"spyder 4.0.1 requires jedi==0.14.1, but you have jedi 0.18.2 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n",
"Collecting git+https://github.com/lvwerra/trl.git@25fa1bd\n",
" Cloning https://github.com/lvwerra/trl.git (to revision 25fa1bd) to /tmp/pip-req-build-ck6je_dp\n",
" Running command git clone --filter=blob:none --quiet https://github.com/lvwerra/trl.git /tmp/pip-req-build-ck6je_dp\n",
"\u001b[33m WARNING: Did not find branch or tag '25fa1bd', assuming revision or ref.\u001b[0m\u001b[33m\n",
"\u001b[0m Running command git checkout -q 25fa1bd\n",
" Resolved https://github.com/lvwerra/trl.git to commit 25fa1bd\n",
" Preparing metadata (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25hRequirement already satisfied: torch>=1.4.0 in /opt/conda/lib/python3.7/site-packages (from trl==0.4.2.dev0) (1.13.1)\n",
"Requirement already satisfied: transformers>=4.18.0 in /opt/conda/lib/python3.7/site-packages (from trl==0.4.2.dev0) (4.27.2)\n",
"Requirement already satisfied: numpy>=1.18.2 in /opt/conda/lib/python3.7/site-packages (from trl==0.4.2.dev0) (1.21.6)\n",
"Requirement already satisfied: accelerate in /opt/conda/lib/python3.7/site-packages (from trl==0.4.2.dev0) (0.20.3)\n",
"Requirement already satisfied: datasets in /opt/conda/lib/python3.7/site-packages (from trl==0.4.2.dev0) (2.11.0)\n",
"Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->trl==0.4.2.dev0) (4.6.3)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->trl==0.4.2.dev0) (11.7.99)\n",
"Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->trl==0.4.2.dev0) (8.5.0.96)\n",
"Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->trl==0.4.2.dev0) (11.10.3.66)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /opt/conda/lib/python3.7/site-packages (from torch>=1.4.0->trl==0.4.2.dev0) (11.7.99)\n",
"Requirement already satisfied: setuptools in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.4.0->trl==0.4.2.dev0) (65.5.1)\n",
"Requirement already satisfied: wheel in /opt/conda/lib/python3.7/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.4.0->trl==0.4.2.dev0) (0.40.0)\n",
"Requirement already satisfied: filelock in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (3.0.12)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (0.16.2)\n",
"Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (6.0)\n",
"Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (2023.6.3)\n",
"Requirement already satisfied: requests in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (2.31.0)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (0.13.3)\n",
"Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (4.65.0)\n",
"Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.7/site-packages (from transformers>=4.18.0->trl==0.4.2.dev0) (6.6.0)\n",
"Requirement already satisfied: psutil in /opt/conda/lib/python3.7/site-packages (from accelerate->trl==0.4.2.dev0) (5.6.7)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (12.0.1)\n",
"Requirement already satisfied: dill<0.3.7,>=0.3.0 in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (0.3.6)\n",
"Requirement already satisfied: pandas in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (1.3.5)\n",
"Requirement already satisfied: xxhash in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (3.2.0)\n",
"Requirement already satisfied: multiprocess in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (0.70.14)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (2023.1.0)\n",
"Requirement already satisfied: aiohttp in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (3.8.4)\n",
"Requirement already satisfied: responses<0.19 in /opt/conda/lib/python3.7/site-packages (from datasets->trl==0.4.2.dev0) (0.18.0)\n",
"Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (2.0.4)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (4.0.2)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (1.3.3)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (1.3.1)\n",
"Requirement already satisfied: asynctest==0.13.0 in /opt/conda/lib/python3.7/site-packages (from aiohttp->datasets->trl==0.4.2.dev0) (0.13.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests->transformers>=4.18.0->trl==0.4.2.dev0) (2.8)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests->transformers>=4.18.0->trl==0.4.2.dev0) (2.0.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests->transformers>=4.18.0->trl==0.4.2.dev0) (2023.5.7)\n",
"Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata->transformers>=4.18.0->trl==0.4.2.dev0) (2.2.0)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas->datasets->trl==0.4.2.dev0) (2.8.2)\n",
"Requirement already satisfied: pytz>=2017.3 in /opt/conda/lib/python3.7/site-packages (from pandas->datasets->trl==0.4.2.dev0) (2019.3)\n",
"Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas->datasets->trl==0.4.2.dev0) (1.14.0)\n",
"Building wheels for collected packages: trl\n",
" Building wheel for trl (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25h Created wheel for trl: filename=trl-0.4.2.dev0-py3-none-any.whl size=67536 sha256=28810e37b984f27ecc11ad077bd048b373eca035a3c0d35ebd22e80339429858\n",
" Stored in directory: /tmp/pip-ephem-wheel-cache-l3bxbtz3/wheels/41/26/75/08a45cee1a1bba06c4f340451483cdfe150f4c8dad3876fb2e\n",
"Successfully built trl\n",
"Installing collected packages: trl\n",
"Successfully installed trl-0.4.2.dev0\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install --upgrade pip\n",
"%pip install --disable-pip-version-check \\\n",
" torch==1.13.1 \\\n",
" torchdata==0.5.1 --quiet\n",
"\n",
"%pip install \\\n",
" transformers==4.27.2 \\\n",
" datasets==2.11.0 \\\n",
" evaluate==0.4.0 \\\n",
" rouge_score==0.1.2 \\\n",
" peft==0.3.0 --quiet\n",
"\n",
"# Installing the Reinforcement Learning library directly from github.\n",
"%pip install git+https://github.com/lvwerra/trl.git@25fa1bd"
]
},
{
"cell_type": "markdown",
"id": "b8f3c076-d9d2-40e3-b005-9dd66b5a163a",
"metadata": {
"tags": [],
"id": "b8f3c076-d9d2-40e3-b005-9dd66b5a163a"
},
"source": [
"<img src=\"data:image/svg+xml;base64,Cjxzdmcgd2lkdGg9IjgwMCIgaGVpZ2h0PSI1MCIgdmlld0JveD0iMCAwIDgwMCA1MCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KICAgIDxkZWZzPgogICAgICAgIDxsaW5lYXJHcmFkaWVudCBpZD0iZmFkZUdyYWRpZW50IiB4MT0iMCIgeDI9IjEiPgogICAgICAgICAgICA8c3RvcCBvZmZzZXQ9IjAlIiBzdG9wLWNvbG9yPSIjRjBGMEYwIi8+CiAgICAgICAgICAgIDxzdG9wIG9mZnNldD0iMTAwJSIgc3RvcC1jb2xvcj0iI0YwRjBGMCIgc3RvcC1vcGFjaXR5PSIwIi8+CiAgICAgICAgPC9saW5lYXJHcmFkaWVudD4KICAgICAgICA8bWFzayBpZD0iZmFkZU1hc2siPgogICAgICAgICAgICA8cmVjdCB4PSIwIiB5PSIwIiB3aWR0aD0iNzUwIiBoZWlnaHQ9IjUwIiBmaWxsPSJ3aGl0ZSIvPgogICAgICAgICAgICA8cmVjdCB4PSI3NTAiIHk9IjAiIHdpZHRoPSI1MCIgaGVpZ2h0PSI1MCIgZmlsbD0idXJsKCNmYWRlR3JhZGllbnQpIi8+CiAgICAgICAgPC9tYXNrPgogICAgPC9kZWZzPgogICAgPHBhdGggZD0iTTI1LDUwIFEwLDUwIDAsMjUgTDUwLDMgTDk3LDI1IEw3OTcsMjUgTDc5Nyw1MCBMMjUsNTAgWiIgZmlsbD0iI0YwRjBGMCIgc3Ryb2tlPSIjRTBFMEUwIiBzdHJva2Utd2lkdGg9IjEiIG1hc2s9InVybCgjZmFkZU1hc2spIi8+Cjwvc3ZnPgo=\" alt=\"Time alert close\"/>"
]
},
{
"cell_type": "markdown",
"id": "74bfd06e-c747-43e0-b86c-0398628e1c32",
"metadata": {
"tags": [],
"id": "74bfd06e-c747-43e0-b86c-0398628e1c32"
},
"source": [
"Import the necessary components. Some of them are new for this week, they will be discussed later in the notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8c20bed-6a30-4847-a507-02969ecb4465",
"metadata": {
"tags": [],
"id": "d8c20bed-6a30-4847-a507-02969ecb4465"
},
"outputs": [],
"source": [
"from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig\n",
"from datasets import load_dataset\n",
"from peft import PeftModel, PeftConfig, LoraConfig, TaskType\n",
"\n",
"# trl: Transformer Reinforcement Learning library\n",
"from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead\n",
"from trl import create_reference_model\n",
"from trl.core import LengthSampler\n",
"\n",
"import torch\n",
"import evaluate\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# tqdm library makes the loops show a smart progress meter.\n",
"from tqdm import tqdm\n",
"tqdm.pandas()"
]
},
{
"cell_type": "markdown",
"id": "b76eea84-8e3a-4487-9692-613977e6c8e3",
"metadata": {
"id": "b76eea84-8e3a-4487-9692-613977e6c8e3"
},
"source": [
"<a name='2'></a>\n",
"## 2 - Load FLAN-T5 Model, Prepare Reward Model and Toxicity Evaluator"
]
},
{
"cell_type": "markdown",
"id": "4a5f97d4-ea5f-4072-b5d6-785d1d833ed4",
"metadata": {
"tags": [],
"id": "4a5f97d4-ea5f-4072-b5d6-785d1d833ed4"
},
"source": [
"<a name='2.1'></a>\n",
"### 2.1 - Load Data and FLAN-T5 Model Fine-Tuned with Summarization Instruction"
]
},
{
"cell_type": "markdown",
"id": "90dc0211-4032-4967-946d-3a538829d5c9",
"metadata": {
"tags": [],
"id": "90dc0211-4032-4967-946d-3a538829d5c9"
},
"source": [
"You will keep working with the same Hugging Face dataset [DialogSum](https://huggingface.co/datasets/knkarthick/dialogsum) and the pre-trained model [FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b058b52b-ec4d-4426-8d71-91e898f727f6",
"metadata": {
"tags": [],
"colab": {
"referenced_widgets": [
"e8449a73aeb04a93bfc97a9797a3af2d",
"181f11de8ffb4bc995b0a748b7851123",
"fae0da12e37a44f68cd69dd48adbf5cc",
"e374fd3775414ec5a672e8eeb36e9f43",
"c6ff6918ca0f4db39ce346b095b7aafb",
"96589648779343198bb5013a755c7cf2",
"",
"5eedf19e8c1a446692dd7bd233a743e5"
]
},
"id": "b058b52b-ec4d-4426-8d71-91e898f727f6",
"outputId": "e95c0ed8-6734-4922-9d50-5136e13305d6"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e8449a73aeb04a93bfc97a9797a3af2d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading readme: 0%| | 0.00/4.56k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading and preparing dataset csv/knkarthick--dialogsum to /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-391706c81424fc80/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "181f11de8ffb4bc995b0a748b7851123",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data files: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fae0da12e37a44f68cd69dd48adbf5cc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/11.3M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e374fd3775414ec5a672e8eeb36e9f43",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/1.35M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c6ff6918ca0f4db39ce346b095b7aafb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/442k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "96589648779343198bb5013a755c7cf2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Extracting data files: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating validation split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-391706c81424fc80/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5eedf19e8c1a446692dd7bd233a743e5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['id', 'dialogue', 'summary', 'topic'],\n",
" num_rows: 12460\n",
" })\n",
" test: Dataset({\n",
" features: ['id', 'dialogue', 'summary', 'topic'],\n",
" num_rows: 1500\n",
" })\n",
" validation: Dataset({\n",
" features: ['id', 'dialogue', 'summary', 'topic'],\n",
" num_rows: 500\n",
" })\n",
"})"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_name=\"google/flan-t5-base\"\n",
"huggingface_dataset_name = \"knkarthick/dialogsum\"\n",
"\n",
"dataset_original = load_dataset(huggingface_dataset_name)\n",
"\n",
"dataset_original"
]
},
{
"cell_type": "markdown",
"id": "668d30d6-6f81-4e52-a81a-3057163ddb0e",
"metadata": {
"id": "668d30d6-6f81-4e52-a81a-3057163ddb0e"
},
"source": [
"The next step will be to preprocess the dataset. You will take only a part of it, then filter the dialogues of a particular length (just to make those examples long enough and, at the same time, easy to read). Then wrap each dialogue with the instruction and tokenize the prompts. Save the token ids in the field `input_ids` and decoded version of the prompts in the field `query`.\n",
"\n",
"You could do that all step by step in the cell below, but it is a good habit to organize that all in a function `build_dataset`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51469abe-4d72-4093-a6c6-8e04e19f09eb",
"metadata": {
"tags": [],
"colab": {
"referenced_widgets": [
"",
"42592e5f880d47189a3602169da2747d",
"5674d181379348b19f1746efa1c1cb2c",
"4322b8b546b640f0bc2fda5181114bdc",
"e595be88b52d436f9a4222e56dfacd24"
]
},
"id": "51469abe-4d72-4093-a6c6-8e04e19f09eb",
"outputId": "f26f590b-cd29-4180-aa0e-c7ec8b8bfb86"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset csv (/root/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-391706c81424fc80/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Filter: 0%| | 0/12460 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "42592e5f880d47189a3602169da2747d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)okenizer_config.json: 0%| | 0.00/2.54k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5674d181379348b19f1746efa1c1cb2c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading spiece.model: 0%| | 0.00/792k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4322b8b546b640f0bc2fda5181114bdc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)/main/tokenizer.json: 0%| | 0.00/2.42M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e595be88b52d436f9a4222e56dfacd24",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)cial_tokens_map.json: 0%| | 0.00/2.20k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/10022 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],\n",
" num_rows: 8017\n",
" })\n",
" test: Dataset({\n",
" features: ['id', 'dialogue', 'summary', 'topic', 'input_ids', 'query'],\n",
" num_rows: 2005\n",
" })\n",
"})\n"
]
}
],
"source": [
"def build_dataset(model_name,\n",
" dataset_name,\n",
" input_min_text_length,\n",
" input_max_text_length):\n",
"\n",
" \"\"\"\n",
" Preprocess the dataset and split it into train and test parts.\n",
"\n",
" Parameters:\n",
" - model_name (str): Tokenizer model name.\n",
" - dataset_name (str): Name of the dataset to load.\n",
" - input_min_text_length (int): Minimum length of the dialogues.\n",
" - input_max_text_length (int): Maximum length of the dialogues.\n",
"\n",
" Returns:\n",
" - dataset_splits (datasets.dataset_dict.DatasetDict): Preprocessed dataset containing train and test parts.\n",
" \"\"\"\n",
"\n",
" # load dataset (only \"train\" part will be enough for this lab).\n",
" dataset = load_dataset(dataset_name, split=\"train\")\n",
"\n",
" # Filter the dialogues of length between input_min_text_length and input_max_text_length characters.\n",
" dataset = dataset.filter(lambda x: len(x[\"dialogue\"]) > input_min_text_length and len(x[\"dialogue\"]) <= input_max_text_length, batched=False)\n",
"\n",
" # Prepare tokenizer. Setting device_map=\"auto\" allows to switch between GPU and CPU automatically.\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name, device_map=\"auto\")\n",
"\n",
" def tokenize(sample):\n",
"\n",
" # Wrap each dialogue with the instruction.\n",
" prompt = f\"\"\"\n",
"Summarize the following conversation.\n",
"\n",
"{sample[\"dialogue\"]}\n",
"\n",
"Summary:\n",
"\"\"\"\n",
" sample[\"input_ids\"] = tokenizer.encode(prompt)\n",
"\n",
" # This must be called \"query\", which is a requirement of our PPO library.\n",
" sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n",
" return sample\n",
"\n",
" # Tokenize each dialogue.\n",
" dataset = dataset.map(tokenize, batched=False)\n",
" dataset.set_format(type=\"torch\")\n",
"\n",
" # Split the dataset into train and test parts.\n",
" dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)\n",
"\n",
" return dataset_splits\n",
"\n",
"dataset = build_dataset(model_name=model_name,\n",
" dataset_name=huggingface_dataset_name,\n",
" input_min_text_length=200,\n",
" input_max_text_length=1000)\n",
"\n",
"print(dataset)"
]
},
{
"cell_type": "markdown",
"id": "7d03155e-649b-45bb-a5a0-94edd682c069",
"metadata": {
"tags": [],
"id": "7d03155e-649b-45bb-a5a0-94edd682c069"
},
"source": [
"In the previous lab, you fine-tuned the PEFT model with summarization instructions. The training in the notebook was done on a subset of data. Then you downloaded the checkpoint of the fully trained PEFT model from S3.\n",
"\n",
"Let's load the same model checkpoint here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1d44a53-ea1f-4fa5-89e7-d46e37d19935",
"metadata": {
"tags": [],
"id": "e1d44a53-ea1f-4fa5-89e7-d46e37d19935",
"outputId": "30cc6b30-27e1-4469-ab9b-07826c0b62f5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/adapter_config.json to peft-dialogue-summary-checkpoint-from-s3/adapter_config.json\n",
"download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/special_tokens_map.json to peft-dialogue-summary-checkpoint-from-s3/special_tokens_map.json\n",
"download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/tokenizer_config.json to peft-dialogue-summary-checkpoint-from-s3/tokenizer_config.json\n",
"download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/tokenizer.json to peft-dialogue-summary-checkpoint-from-s3/tokenizer.json\n",
"download: s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/adapter_model.bin to peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin\n"
]
}
],
"source": [
"!aws s3 cp --recursive s3://dlai-generative-ai/models/peft-dialogue-summary-checkpoint/ ./peft-dialogue-summary-checkpoint-from-s3/"
]
},
{
"cell_type": "markdown",
"id": "dec8bea4-addd-4b29-b3af-6db6ea2baeb7",
"metadata": {
"tags": [],
"id": "dec8bea4-addd-4b29-b3af-6db6ea2baeb7"
},
"source": [
"List the model item and check its size (it's less than 15 Mb):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4288240d-764b-4c49-8df7-b30b9277adbd",
"metadata": {
"tags": [],
"id": "4288240d-764b-4c49-8df7-b30b9277adbd",
"outputId": "79c3c026-9c7c-4966-ad7c-10babeeebedd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-rw-r--r-- 1 root root 14M May 15 11:18 ./peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin\n"
]
}
],
"source": [
"!ls -alh ./peft-dialogue-summary-checkpoint-from-s3/adapter_model.bin"
]
},
{
"cell_type": "markdown",
"id": "f4226923-67c0-4ea6-8e47-030136b2f191",
"metadata": {
"id": "f4226923-67c0-4ea6-8e47-030136b2f191"
},
"source": [
"Prepare a function to pull out the number of model parameters (it is the same as in the previous lab):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f06806-a194-4c14-b64d-e31afd7b658c",
"metadata": {
"tags": [],
"id": "a1f06806-a194-4c14-b64d-e31afd7b658c"
},
"outputs": [],
"source": [
"def print_number_of_trainable_model_parameters(model):\n",
" trainable_model_params = 0\n",
" all_model_params = 0\n",
" for _, param in model.named_parameters():\n",
" all_model_params += param.numel()\n",
" if param.requires_grad:\n",
" trainable_model_params += param.numel()\n",
" return f\"\\ntrainable model parameters: {trainable_model_params}\\nall model parameters: {all_model_params}\\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%\""
]
},
{
"cell_type": "markdown",
"id": "21e06a57-fb80-4f8c-a967-4c7c42a7bfda",
"metadata": {
"tags": [],
"id": "21e06a57-fb80-4f8c-a967-4c7c42a7bfda"
},
"source": [
"Add the adapter to the original FLAN-T5 model. In the previous lab you were adding the fully trained adapter only for inferences, so there was no need to pass LoRA configurations doing that. Now you need to pass them to the constructed PEFT model, also putting `is_trainable=True`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1a94b14-b375-45e7-9e49-a7f2c341b4ff",
"metadata": {
"tags": [],
"colab": {
"referenced_widgets": [
"01c75d45acb1400395caca202f885675",
"196b5f2140af42428db8ed4378384897",
"a6617f05c30441579ad3e0bb8e7aa0f4"
]
},
"id": "a1a94b14-b375-45e7-9e49-a7f2c341b4ff",
"outputId": "8aa27841-06ea-472a-f381-6be72076d970"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "01c75d45acb1400395caca202f885675",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)lve/main/config.json: 0%| | 0.00/1.40k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "196b5f2140af42428db8ed4378384897",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading pytorch_model.bin: 0%| | 0.00/990M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a6617f05c30441579ad3e0bb8e7aa0f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)neration_config.json: 0%| | 0.00/147 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"PEFT model parameters to be updated:\n",
"\n",
"trainable model parameters: 3538944\n",
"all model parameters: 251116800\n",
"percentage of trainable model parameters: 1.41%\n",
"\n"
]
}
],
"source": [
"lora_config = LoraConfig(\n",
" r=32, # Rank\n",
" lora_alpha=32,\n",
" target_modules=[\"q\", \"v\"],\n",
" lora_dropout=0.05,\n",
" bias=\"none\",\n",
" task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5\n",
")\n",
"\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name,\n",
" torch_dtype=torch.bfloat16)\n",
"\n",
"peft_model = PeftModel.from_pretrained(model,\n",
" './peft-dialogue-summary-checkpoint-from-s3/',\n",
" lora_config=lora_config,\n",
" torch_dtype=torch.bfloat16,\n",
" device_map=\"auto\",\n",
" is_trainable=True)\n",
"\n",
"print(f'PEFT model parameters to be updated:\\n{print_number_of_trainable_model_parameters(peft_model)}\\n')\n"
]
},
{
"cell_type": "markdown",
"id": "a950ae8a-76b9-4951-9c78-9ac7a6349e17",
"metadata": {
"id": "a950ae8a-76b9-4951-9c78-9ac7a6349e17"
},
"source": [
"In this lab, you are preparing to fine-tune the LLM using Reinforcement Learning (RL). RL will be briefly discussed in the next section of this lab, but at this stage, you just need to prepare the Proximal Policy Optimization (PPO) model passing the instruct-fine-tuned PEFT model to it. PPO will be used to optimize the RL policy against the reward model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e86bab0-6dee-4dff-a754-b584ed962723",
"metadata": {
"tags": [],
"id": "1e86bab0-6dee-4dff-a754-b584ed962723",
"outputId": "1a97a27b-dff8-48ac-d721-1de9741dd8e8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PPO model parameters to be updated (ValueHead + 769 params):\n",
"\n",
"trainable model parameters: 3539713\n",
"all model parameters: 251117569\n",
"percentage of trainable model parameters: 1.41%\n",
"\n",
"ValueHead(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (summary): Linear(in_features=768, out_features=1, bias=True)\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n",
")\n"
]
}
],
"source": [
"ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,\n",
" torch_dtype=torch.bfloat16,\n",
" is_trainable=True)\n",
"\n",
"print(f'PPO model parameters to be updated (ValueHead + 769 params):\\n{print_number_of_trainable_model_parameters(ppo_model)}\\n')\n",
"print(ppo_model.v_head)"
]
},
{
"cell_type": "markdown",
"id": "2913ef05-737e-4cdf-9bac-467ee6cf9f76",
"metadata": {
"id": "2913ef05-737e-4cdf-9bac-467ee6cf9f76"
},
"source": [
"During PPO, only a few parameters will be updated. Specifically, the parameters of the `ValueHead`. More information about this class of models can be found in the [documentation](https://huggingface.co/docs/trl/main/en/models#trl.create_reference_model). The number of trainable parameters can be computed as $(n+1)*m$, where $n$ is the number of input units (here $n=768$) and $m$ is the number of output units (you have $m=1$). The $+1$ term in the equation takes into account the bias term."
]
},
{
"cell_type": "markdown",
"id": "76c7e2df-0c75-4bd0-bf8d-1f1545ef864e",
"metadata": {
"id": "76c7e2df-0c75-4bd0-bf8d-1f1545ef864e"
},
"source": [
"Now create a frozen copy of the PPO which will not be fine-tuned - a reference model. The reference model will represent the LLM before detoxification. None of the parameters of the reference model will be updated during PPO training. This is on purpose."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18a9b30a-ad14-4189-8088-d4de447fe247",
"metadata": {
"tags": [],
"id": "18a9b30a-ad14-4189-8088-d4de447fe247",
"outputId": "d3f2ab01-2c21-44bb-e1e4-07134a8b9dae"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reference model parameters to be updated:\n",
"\n",
"trainable model parameters: 0\n",
"all model parameters: 251117569\n",
"percentage of trainable model parameters: 0.00%\n",
"\n"
]
}
],
"source": [
"ref_model = create_reference_model(ppo_model)\n",
"\n",
"print(f'Reference model parameters to be updated:\\n{print_number_of_trainable_model_parameters(ref_model)}\\n')"
]
},
{
"cell_type": "markdown",
"id": "e3a14848-e83d-4fdd-bc68-eb770c5951d7",
"metadata": {
"tags": [],
"id": "e3a14848-e83d-4fdd-bc68-eb770c5951d7"
},
"source": [
"Everything is set. It is time to prepare the reward model!"
]
},
{
"cell_type": "markdown",
"id": "4bfdf1f7-3509-4adc-812a-2b22bd330137",
"metadata": {
"id": "4bfdf1f7-3509-4adc-812a-2b22bd330137"
},
"source": [
"<a name='2.2'></a>\n",
"### 2.2 - Prepare Reward Model\n",
"\n",
"**Reinforcement Learning (RL)** is one type of machine learning where agents take actions in an environment aimed at maximizing their cumulative rewards. The agent's behavior is defined by the **policy**. And the goal of reinforcement learning is for the agent to learn an optimal, or nearly-optimal, policy that maximizes the **reward function**.\n",
"\n",
"In the [previous section](#2.1) the original policy is based on the instruct PEFT model - this is the LLM before detoxification. Then you could ask human labelers to give feedback on the outputs' toxicity. However, it can be expensive to use them for the entire fine-tuning process. A practical way to avoid that is to use a reward model encouraging the agent to detoxify the dialogue summaries. The intuitive approach would be to do some form of sentiment analysis across two classes (`nothate` and `hate`) and give a higher reward if there is higher a chance of getting class `nothate` as an output.\n",
"\n",
"For example, we can mention that having human labelers for the entire finetuning process can be expensive. A practical way to avoid that is to use a reward model.\n",
"\n",
"use feedback generated by a model\n",
"\n",
"You will use [Meta AI's RoBERTa-based hate speech model](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target) for the reward model. This model will output **logits** and then predict probabilities across two classes: `nothate` and `hate`. The logits of the output `nothate` will be taken as a positive reward. Then, the model will be fine-tuned with PPO using those reward values.\n",
"\n",
"Create the instance of the required model class for the RoBERTa model. You also need to load a tokenizer to test the model. Notice that the model label `0` will correspond to the class `nothate` and label `1` to the class `hate`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f038a9f-e04b-49bc-8923-8ef3816919ee",
"metadata": {
"tags": [],
"colab": {
"referenced_widgets": [
"a769c4b6fe84430fadc2600fbdf86936",
"23fceb3843e94c29b45d566f3e0fef72",
"84ed9870c1cb4120ae96447a5d077f07",
"b1677ecbc50a4eacaeff9d644c1d0a19",
"2a0b3c2ad92746279c411a5f950c9e70",
"21c7bee6ef824ff4a94186bccd76a12c"
]
},
"id": "7f038a9f-e04b-49bc-8923-8ef3816919ee",
"outputId": "e35c12ce-123d-4e84-a5e4-0dd16d662577"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a769c4b6fe84430fadc2600fbdf86936",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)okenizer_config.json: 0%| | 0.00/1.11k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "23fceb3843e94c29b45d566f3e0fef72",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)olve/main/vocab.json: 0%| | 0.00/899k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "84ed9870c1cb4120ae96447a5d077f07",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)olve/main/merges.txt: 0%| | 0.00/456k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b1677ecbc50a4eacaeff9d644c1d0a19",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)cial_tokens_map.json: 0%| | 0.00/239 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2a0b3c2ad92746279c411a5f950c9e70",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading (…)lve/main/config.json: 0%| | 0.00/816 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "21c7bee6ef824ff4a94186bccd76a12c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading pytorch_model.bin: 0%| | 0.00/499M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{0: 'nothate', 1: 'hate'}\n"
]
}
],
"source": [
"toxicity_model_name = \"facebook/roberta-hate-speech-dynabench-r4-target\"\n",
"toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map=\"auto\")\n",
"toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map=\"auto\")\n",
"print(toxicity_model.config.id2label)"
]
},
{
"cell_type": "markdown",
"id": "79d68799-a6e8-42d7-8d61-002e47210c18",
"metadata": {
"tags": [],
"id": "79d68799-a6e8-42d7-8d61-002e47210c18"
},
"source": [
"Take some non-toxic text, tokenize it, and pass it to the model. Print the output logits, probabilities, and the corresponding reward that will be used for fine-tuning."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f4e6a05-2398-4ca7-a176-d5a1ff27fe39",
"metadata": {
"tags": [],
"id": "8f4e6a05-2398-4ca7-a176-d5a1ff27fe39",
"outputId": "15565d59-867e-4524-a3de-fe54f35d7ca4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"logits [not hate, hate]: [3.114100694656372, -2.4896175861358643]\n",
"probabilities [not hate, hate]: [0.9963293671607971, 0.003670616541057825]\n",
"reward (high): [3.114100694656372]\n"
]
}
],
"source": [
"non_toxic_text = \"#Person 1# tells Tommy that he didn't like the movie.\"\n",
"\n",
"toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors=\"pt\").input_ids\n",
"\n",
"logits = toxicity_model(input_ids=toxicity_input_ids).logits\n",
"print(f'logits [not hate, hate]: {logits.tolist()[0]}')\n",
"\n",
"# Print the probabilities for [not hate, hate]\n",
"probabilities = logits.softmax(dim=-1).tolist()[0]\n",
"print(f'probabilities [not hate, hate]: {probabilities}')\n",
"\n",
"# get the logits for \"not hate\" - this is the reward!\n",
"not_hate_index = 0\n",
"nothate_reward = (logits[:, not_hate_index]).tolist()\n",
"print(f'reward (high): {nothate_reward}')"
]
},
{
"cell_type": "markdown",
"id": "63f729c5-98c3-4745-96e8-3484670215db",
"metadata": {
"id": "63f729c5-98c3-4745-96e8-3484670215db"
},
"source": [
"Let's show a toxic comment. This will have a low reward because it is more toxic."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ffc81e4-b220-4f4e-95ec-98ac418612d1",
"metadata": {
"tags": [],
"id": "0ffc81e4-b220-4f4e-95ec-98ac418612d1",
"outputId": "878021ea-1fef-4b2e-e4ae-314a24c1ba97"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"logits [not hate, hate]: [-0.6921188831329346, 0.3722729980945587]\n",
"probabilities [not hate, hate]: [0.25647106766700745, 0.7435289621353149]\n",
"reward (low): [-0.6921188831329346]\n"
]
}
],
"source": [
"toxic_text = \"#Person 1# tells Tommy that the movie was terrible, dumb and stupid.\"\n",
"\n",
"toxicity_input_ids = toxicity_tokenizer(toxic_text, return_tensors=\"pt\").input_ids\n",
"\n",
"logits = toxicity_model(toxicity_input_ids).logits\n",
"print(f'logits [not hate, hate]: {logits.tolist()[0]}')\n",
"\n",
"# Print the probabilities for [not hate, hate]\n",
"probabilities = logits.softmax(dim=-1).tolist()[0]\n",
"print(f'probabilities [not hate, hate]: {probabilities}')\n",
"\n",
"# Get the logits for \"not hate\" - this is the reward!\n",
"nothate_reward = (logits[:, not_hate_index]).tolist()\n",
"print(f'reward (low): {nothate_reward}')"
]
},
{
"cell_type": "markdown",
"id": "bc6e656e-fefe-4623-8bbb-472a8cf1c3c5",
"metadata": {
"id": "bc6e656e-fefe-4623-8bbb-472a8cf1c3c5"
},
"source": [
"Setup Hugging Face inference pipeline to simplify the code for the toxicity reward model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "73aab8c6-33eb-4bf4-a816-3866fc3460af",
"metadata": {
"tags": [],
"id": "73aab8c6-33eb-4bf4-a816-3866fc3460af",
"outputId": "0c63c848-91b6-4e65-cb3d-b7d01a80e1ca"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reward model output:\n",
"For non-toxic text\n",
"[{'label': 'nothate', 'score': 3.114100694656372}, {'label': 'hate', 'score': -2.4896175861358643}]\n",
"[{'label': 'nothate', 'score': 0.9963293671607971}, {'label': 'hate', 'score': 0.003670616541057825}]\n",
"For toxic text\n",
"[{'label': 'hate', 'score': 0.3722729980945587}, {'label': 'nothate', 'score': -0.6921188831329346}]\n",
"[{'label': 'hate', 'score': 0.7435289621353149}, {'label': 'nothate', 'score': 0.25647106766700745}]\n"
]
}
],
"source": [
"device = 0 if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"sentiment_pipe = pipeline(\"sentiment-analysis\",\n",
" model=toxicity_model_name,\n",
" device=device)\n",
"reward_logits_kwargs = {\n",
" \"top_k\": None, # Return all scores.\n",
" \"function_to_apply\": \"none\", # Set to \"none\" to retrieve raw logits.\n",
" \"batch_size\": 16\n",
"}\n",
"\n",
"reward_probabilities_kwargs = {\n",
" \"top_k\": None, # Return all scores.\n",
" \"function_to_apply\": \"softmax\", # Set to \"softmax\" to apply softmax and retrieve probabilities.\n",
" \"batch_size\": 16\n",
"}\n",
"\n",
"print(\"Reward model output:\")\n",
"print(\"For non-toxic text\")\n",
"print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))\n",
"print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))\n",
"print(\"For toxic text\")\n",
"print(sentiment_pipe(toxic_text, **reward_logits_kwargs))\n",
"print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))"
]
},
{
"cell_type": "markdown",
"id": "21302d74-59d8-451f-b287-e86245bf3324",
"metadata": {
"id": "21302d74-59d8-451f-b287-e86245bf3324"
},
"source": [
"The outputs are the logits for both `nothate` (positive) and `hate` (negative) classes. But PPO will be using logits only of the `nothate` class as the positive reward signal used to help detoxify the LLM outputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36ff3925-70c4-495c-ae26-68e2fc36296b",
"metadata": {
"tags": [],
"id": "36ff3925-70c4-495c-ae26-68e2fc36296b",
"outputId": "3c07e1cb-63e8-448a-df27-891ec920b420"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'label': 'nothate', 'score': 3.114100694656372}, {'label': 'hate', 'score': -2.4896175861358643}]\n",
"[{'label': 'nothate', 'score': 0.9963293671607971}, {'label': 'hate', 'score': 0.003670616541057825}]\n"
]
}
],
"source": [
"print(sentiment_pipe(non_toxic_text, **reward_logits_kwargs))\n",
"print(sentiment_pipe(non_toxic_text, **reward_probabilities_kwargs))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d11618b-5887-489a-b390-2139e364987f",
"metadata": {
"tags": [],
"id": "8d11618b-5887-489a-b390-2139e364987f",
"outputId": "2b82797a-7175-42c3-9eb3-3c996828ffcd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'label': 'hate', 'score': 0.3722729980945587}, {'label': 'nothate', 'score': -0.6921188831329346}]\n",
"[{'label': 'hate', 'score': 0.7435289621353149}, {'label': 'nothate', 'score': 0.25647106766700745}]\n"
]
}
],
"source": [
"print(sentiment_pipe(toxic_text, **reward_logits_kwargs))\n",
"print(sentiment_pipe(toxic_text, **reward_probabilities_kwargs))"
]
},
{
"cell_type": "markdown",
"id": "56513033-9bb1-41d5-81e2-54d1249c5c89",
"metadata": {
"tags": [],
"id": "56513033-9bb1-41d5-81e2-54d1249c5c89"
},
"source": [
"<a name='2.3'></a>\n",
"### 2.3 - Evaluate Toxicity\n",
"\n",
"To evaluate the model before and after fine-tuning/detoxification you need to set up the [toxicity evaluation metric](https://huggingface.co/spaces/evaluate-measurement/toxicity). The **toxicity score** is a decimal value between 0 and 1 where 1 is the highest toxicity."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7de8e99d-60ea-48a2-bdf5-817f80b48979",
"metadata": {
"tags": [],
"colab": {
"referenced_widgets": [
"a11bf1d2e6f04ce6bdd1917175d6e4bf"
]
},
"id": "7de8e99d-60ea-48a2-bdf5-817f80b48979",
"outputId": "d508d2da-616f-4b73-9461-2136e964e3ff"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a11bf1d2e6f04ce6bdd1917175d6e4bf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading builder script: 0%| | 0.00/6.08k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"toxicity_evaluator = evaluate.load(\"toxicity\",\n",
" toxicity_model_name,\n",
" module_type=\"measurement\",\n",
" toxic_label=\"hate\")"
]
},
{
"cell_type": "markdown",
"id": "840fbc47-c5c2-469a-b5f2-6407e8f0bfde",
"metadata": {
"tags": [],
"id": "840fbc47-c5c2-469a-b5f2-6407e8f0bfde"
},
"source": [
"Try to calculate toxicity for the same sentences as in section [2.2](#2.2). It's no surprise that the toxicity scores are the probabilities of `hate` class returned directly from the reward model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5298f91c-30d1-4d17-95a7-553952ac97b5",
"metadata": {
"tags": [],
"id": "5298f91c-30d1-4d17-95a7-553952ac97b5",
"outputId": "68cdee3d-3aac-40ed-c787-f5cf01c7e4d3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Toxicity score for non-toxic text:\n",
"[0.003670616541057825]\n",
"\n",
"Toxicity score for toxic text:\n",
"[0.7435289621353149]\n"
]
}
],
"source": [
"toxicity_score = toxicity_evaluator.compute(predictions=[\n",
" non_toxic_text\n",
"])\n",
"\n",
"print(\"Toxicity score for non-toxic text:\")\n",
"print(toxicity_score[\"toxicity\"])\n",
"\n",
"toxicity_score = toxicity_evaluator.compute(predictions=[\n",
" toxic_text\n",
"])\n",
"\n",
"print(\"\\nToxicity score for toxic text:\")\n",
"print(toxicity_score[\"toxicity\"])"
]
},
{
"cell_type": "markdown",
"id": "7d3e835b-14b9-4646-b1c9-c975ef3ea944",
"metadata": {
"tags": [],
"id": "7d3e835b-14b9-4646-b1c9-c975ef3ea944"
},
"source": [
"This evaluator can be used to compute the toxicity of the dialogues prepared in section [2.1](#2.1). You will need to pass the test dataset (`dataset[\"test\"]`), the same tokenizer which was used in that section, the frozen PEFT model prepared in section [2.2](#2.2), and the toxicity evaluator. It is convenient to wrap the required steps in the function `evaluate_toxicity`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "316ab128-33ff-4a1e-8936-47bfa29d48a3",
"metadata": {
"tags": [],
"id": "316ab128-33ff-4a1e-8936-47bfa29d48a3"
},
"outputs": [],
"source": [
"def evaluate_toxicity(model,\n",
" toxicity_evaluator,\n",
" tokenizer,\n",
" dataset,\n",
" num_samples):\n",
"\n",
" \"\"\"\n",
" Preprocess the dataset and split it into train and test parts.\n",
"\n",
" Parameters:\n",
" - model (trl model): Model to be evaluated.\n",
" - toxicity_evaluator (evaluate_modules toxicity metrics): Toxicity evaluator.\n",
" - tokenizer (transformers tokenizer): Tokenizer to be used.\n",
" - dataset (dataset): Input dataset for the evaluation.\n",
" - num_samples (int): Maximum number of samples for the evaluation.\n",
"\n",
" Returns:\n",
" tuple: A tuple containing two numpy.float64 values:\n",
" - mean (numpy.float64): Mean of the samples toxicity.\n",
" - std (numpy.float64): Standard deviation of the samples toxicity.\n",
" \"\"\"\n",
"\n",
" max_new_tokens=100\n",
"\n",
" toxicities = []\n",
" input_texts = []\n",
" for i, sample in tqdm(enumerate(dataset)):\n",
" input_text = sample[\"query\"]\n",
"\n",
" if i > num_samples:\n",
" break\n",
"\n",
" input_ids = tokenizer(input_text, return_tensors=\"pt\", padding=True).input_ids\n",
"\n",
" generation_config = GenerationConfig(max_new_tokens=max_new_tokens,\n",
" tok_k=0.0,\n",
" top_p=1.0,\n",
" do_sample=True)\n",
"\n",
" response_token_ids = model.generate(input_ids=input_ids,\n",
" generation_config=generation_config)\n",
"\n",
" generated_text = tokenizer.decode(response_token_ids[0], skip_special_tokens=True)\n",
"\n",
" toxicity_score = toxicity_evaluator.compute(predictions=[(input_text + \" \" + generated_text)])\n",
"\n",
" toxicities.extend(toxicity_score[\"toxicity\"])\n",
"\n",
" # Compute mean & std using np.\n",
" mean = np.mean(toxicities)\n",
" std = np.std(toxicities)\n",
"\n",
" return mean, std"
]
},
{
"cell_type": "markdown",
"id": "aed269c3-dbd7-4d45-bc44-c6ab6d4ae141",
"metadata": {
"tags": [],
"id": "aed269c3-dbd7-4d45-bc44-c6ab6d4ae141"
},
"source": [
"And now perform the calculation of the model toxicity before fine-tuning/detoxification:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c11ede15-dc1a-4a7e-a60d-b9cadfc7d876",
"metadata": {
"tags": [],
"id": "c11ede15-dc1a-4a7e-a60d-b9cadfc7d876",
"outputId": "a765f490-7462-4a33-b25b-30bb1dc8177e"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"11it [00:21, 1.99s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"toxicity [mean, std] before detox: [0.0259042768844996, 0.029664746374071654]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_name, device_map=\"auto\")\n",
"\n",
"mean_before_detoxification, std_before_detoxification = evaluate_toxicity(model=ref_model,\n",
" toxicity_evaluator=toxicity_evaluator,\n",
" tokenizer=tokenizer,\n",
" dataset=dataset[\"test\"],\n",
" num_samples=10)\n",
"\n",
"print(f'toxicity [mean, std] before detox: [{mean_before_detoxification}, {std_before_detoxification}]')"
]
},
{
"cell_type": "markdown",
"id": "1ba81c90-1ac8-4403-ac1a-d4c75c6df4f0",
"metadata": {
"id": "1ba81c90-1ac8-4403-ac1a-d4c75c6df4f0"
},
"source": [
"<a name='3'></a>\n",
"## 3 - Perform Fine-Tuning to Detoxify the Summaries\n",
"Optimize a RL policy against the reward model using Proximal Policy Optimization (PPO)."
]
},
{
"cell_type": "markdown",
"id": "5516e318-8fce-4ca7-bf19-b7baf5255480",
"metadata": {
"id": "5516e318-8fce-4ca7-bf19-b7baf5255480"
},
"source": [
"<a name='3.1'></a>\n",
"### 3.1 - Initialize `PPOTrainer`\n",
"\n",
"For the `PPOTrainer` initialization, you will need a collator. Here it will be a function transforming the dictionaries in a particular way. You can define and test it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b7be1c0-382a-4fe2-8174-470f3e333e84",
"metadata": {
"tags": [],
"id": "8b7be1c0-382a-4fe2-8174-470f3e333e84",
"outputId": "2460b9d4-e745-41f3-962e-6b37326b3f7c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collator input: [{'key1': 'value1', 'key2': 'value2', 'key3': 'value3'}]\n",
"Collator output: {'key1': ['value1'], 'key2': ['value2'], 'key3': ['value3']}\n"
]
}
],
"source": [
"def collator(data):\n",
" return dict((key, [d[key] for d in data]) for key in data[0])\n",
"\n",
"test_data = [{\"key1\": \"value1\", \"key2\": \"value2\", \"key3\": \"value3\"}]\n",
"print(f'Collator input: {test_data}')\n",
"print(f'Collator output: {collator(test_data)}')"
]
},
{
"cell_type": "markdown",
"id": "080c2e92-4988-4944-8353-0e1bb2048072",
"metadata": {
"id": "080c2e92-4988-4944-8353-0e1bb2048072"
},
"source": [
"Set up the configuration parameters. Load the `ppo_model` and the tokenizer. You will also load a frozen version of the model `ref_model`. The first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This works as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original LLM."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "494e09a1-9024-4f38-91eb-d73cdc3239e6",
"metadata": {
"tags": [],
"id": "494e09a1-9024-4f38-91eb-d73cdc3239e6"
},
"outputs": [],
"source": [
"learning_rate=1.41e-5\n",
"max_ppo_epochs=1\n",
"mini_batch_size=4\n",
"batch_size=16\n",
"\n",
"config = PPOConfig(\n",
" model_name=model_name,\n",
" learning_rate=learning_rate,\n",
" ppo_epochs=max_ppo_epochs,\n",
" mini_batch_size=mini_batch_size,\n",
" batch_size=batch_size\n",
")\n",
"\n",
"ppo_trainer = PPOTrainer(config=config,\n",
" model=ppo_model,\n",
" ref_model=ref_model,\n",
" tokenizer=tokenizer,\n",
" dataset=dataset[\"train\"],\n",
" data_collator=collator)"
]
},
{
"cell_type": "markdown",
"id": "7ad77d2c-3800-4e15-bb38-3851d94ad374",
"metadata": {
"id": "7ad77d2c-3800-4e15-bb38-3851d94ad374"
},
"source": [
"<a name='3.2'></a>\n",
"### 3.2 - Fine-Tune the Model"
]
},
{
"cell_type": "markdown",
"id": "0cac21fb-fea5-4e80-a741-87f35ae72c62",
"metadata": {
"id": "0cac21fb-fea5-4e80-a741-87f35ae72c62"
},
"source": [
"The fine-tuning loop consists of the following main steps:\n",
"1. Get the query responses from the policy LLM (PEFT model).\n",
"2. Get sentiments for query/responses from hate speech RoBERTa model.\n",
"3. Optimize policy with PPO using the (query, response, reward) triplet.\n",
"\n",
"The operation is running if you see the following metrics appearing:\n",
"* `objective/kl`: minimize kl divergence,\n",
"* `ppo/returns/mean`: maximize mean returns,\n",
"* `ppo/policy/advantages_mean`: maximize advantages."
]
},
{
"cell_type": "markdown",
"id": "01536b7e-2f0f-4986-a97c-6ecfabf518d4",
"metadata": {
"tags": [],
"id": "01536b7e-2f0f-4986-a97c-6ecfabf518d4"
},
"source": [
"<img src=\"data:image/svg+xml;base64,Cjxzdmcgd2lkdGg9IjgwMCIgaGVpZ2h0PSIxMjUiIHZpZXdCb3g9IjAgMCA4MDAgMTI1IiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogICAgPGRlZnM+CiAgICAgICAgPGxpbmVhckdyYWRpZW50IGlkPSJmYWRlR3JhZGllbnQiIHgxPSIwIiB4Mj0iMSI+CiAgICAgICAgICAgIDxzdG9wIG9mZnNldD0iMCUiIHN0b3AtY29sb3I9IiNGMEYwRjAiLz4KICAgICAgICAgICAgPHN0b3Agb2Zmc2V0PSIxMDAlIiBzdG9wLWNvbG9yPSIjRjBGMEYwIiBzdG9wLW9wYWNpdHk9IjAiLz4KICAgICAgICA8L2xpbmVhckdyYWRpZW50PgogICAgICAgIDxtYXNrIGlkPSJmYWRlTWFzayI+CiAgICAgICAgICAgIDxyZWN0IHg9IjAiIHk9IjAiIHdpZHRoPSI3NTAiIGhlaWdodD0iMTI1IiBmaWxsPSJ3aGl0ZSIvPgogICAgICAgICAgICA8cmVjdCB4PSI3NTAiIHk9IjAiIHdpZHRoPSI1MCIgaGVpZ2h0PSIxMjUiIGZpbGw9InVybCgjZmFkZUdyYWRpZW50KSIvPgogICAgICAgIDwvbWFzaz4KICAgIDwvZGVmcz4KICAgIDxwYXRoIGQ9Ik0zLDUwIEE1MCw1MCAwIDAgMSA1MywzIEw3OTcsMyBMNzk3LDk3IEw5Nyw5NyBMNTAsMTE1IEwzLDk3IFoiIGZpbGw9IiNGMEYwRjAiIHN0cm9rZT0iI0UwRTBFMCIgc3Ryb2tlLXdpZHRoPSIxIiBtYXNrPSJ1cmwoI2ZhZGVNYXNrKSIvPgogICAgPGNpcmNsZSBjeD0iNTAiIGN5PSI1MCIgcj0iMzAiIGZpbGw9IiM1N2M0ZjgiIHN0cm9rZT0iIzU3YzRmOCIgc3Ryb2tlLXdpZHRoPSIxIi8+CiAgICA8Y2lyY2xlIGN4PSI1MCIgY3k9IjUwIiByPSIyNSIgZmlsbD0iI0YwRjBGMCIvPgogICAgPGxpbmUgeDE9IjUwIiB5MT0iNTAiIHgyPSI1MCIgeTI9IjMwIiBzdHJva2U9IiM1N2M0ZjgiIHN0cm9rZS13aWR0aD0iMyIgc3Ryb2tlLWxpbmVjYXA9InJvdW5kIi8+CiAgICA8bGluZSB4MT0iNTAiIHkxPSI1MCIgeDI9IjY1IiB5Mj0iNTAiIHN0cm9rZT0iIzU3YzRmOCIgc3Ryb2tlLXdpZHRoPSIzIiBzdHJva2UtbGluZWNhcD0icm91bmQiLz4KICAgIDx0ZXh0IHg9IjEwMCIgeT0iMzQiIGZvbnQtZmFtaWx5PSJBcmlhbCwgc2Fucy1zZXJpZiIgZm9udC1zaXplPSIxNCIgZmlsbD0iIzMzMzMzMyI+VGhlIG5leHQgY2VsbCBtYXkgdGFrZSAyMC0zMCBtaW51dGVzIHRvIHJ1bi48L3RleHQ+Cjwvc3ZnPgo=\" alt=\"Time alert open medium\"/>"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4bc55397-92b8-4f61-9ec2-c8b39d5f8962",
"metadata": {
"tags": [],
"id": "4bc55397-92b8-4f61-9ec2-c8b39d5f8962",
"outputId": "c22717b2-a6d7-4c0f-896c-becdb94ebe6f"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"0it [00:00, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
"1it [01:42, 102.24s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"objective/kl: 29.314075469970703\n",
"ppo/returns/mean: -0.6373031735420227\n",
"ppo/policy/advantages_mean: -1.151681416899919e-08\n",
"---------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2it [03:21, 100.24s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"objective/kl: 31.827880859375\n",
"ppo/returns/mean: -0.7150875329971313\n",
"ppo/policy/advantages_mean: 1.0533127259293451e-08\n",
"---------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"3it [04:53, 96.45s/it] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"objective/kl: 30.49118995666504\n",
"ppo/returns/mean: -0.7426870465278625\n",
"ppo/policy/advantages_mean: 7.208275309977807e-09\n",
"---------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"4it [06:13, 90.08s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"objective/kl: 26.34555435180664\n",
"ppo/returns/mean: -0.4404515027999878\n",
"ppo/policy/advantages_mean: 2.934298137802216e-09\n",
"---------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"5it [07:40, 89.19s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"objective/kl: 25.24860954284668\n",
"ppo/returns/mean: -0.3359465003013611\n",
"ppo/policy/advantages_mean: -2.3621626876746404e-09\n",
"---------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"6it [09:25, 94.43s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"objective/kl: 29.67958641052246\n",
"ppo/returns/mean: -0.7676105499267578\n",
"ppo/policy/advantages_mean: 1.5880885939623113e-08\n",
"---------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"7it [10:53, 92.31s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"objective/kl: 28.498334884643555\n",
"ppo/returns/mean: -0.5901660919189453\n",
"ppo/policy/advantages_mean: -4.031784683888873e-09\n",
"---------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"8it [12:23, 91.52s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"objective/kl: 26.876846313476562\n",
"ppo/returns/mean: -0.5814594030380249\n",
"ppo/policy/advantages_mean: -7.104872246088689e-10\n",
"---------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"9it [13:59, 93.00s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"objective/kl: 30.190963745117188\n",
"ppo/returns/mean: -0.7083455920219421\n",
"ppo/policy/advantages_mean: 2.5395754832402417e-08\n",
"---------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"10it [15:29, 92.99s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"objective/kl: 26.26000213623047\n",
"ppo/returns/mean: -0.376063734292984\n",
"ppo/policy/advantages_mean: 3.0003806106293496e-09\n",
"---------------------------------------------------------------------------------------------------\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"output_min_length = 100\n",
"output_max_length = 400\n",
"output_length_sampler = LengthSampler(output_min_length, output_max_length)\n",
"\n",
"generation_kwargs = {\n",
" \"min_length\": 5,\n",
" \"top_k\": 0.0,\n",
" \"top_p\": 1.0,\n",
" \"do_sample\": True\n",
"}\n",
"\n",
"reward_kwargs = {\n",
" \"top_k\": None, # Return all scores.\n",
" \"function_to_apply\": \"none\", # You want the raw logits without softmax.\n",
" \"batch_size\": 16\n",
"}\n",
"\n",
"max_ppo_steps = 10\n",
"\n",
"for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):\n",
" # Break when you reach max_steps.\n",
" if step >= max_ppo_steps:\n",
" break\n",
"\n",
" prompt_tensors = batch[\"input_ids\"]\n",
"\n",
" # Get response from FLAN-T5/PEFT LLM.\n",
" summary_tensors = []\n",
"\n",
" for prompt_tensor in prompt_tensors:\n",
" max_new_tokens = output_length_sampler()\n",
"\n",
" generation_kwargs[\"max_new_tokens\"] = max_new_tokens\n",
" summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)\n",
"\n",
" summary_tensors.append(summary.squeeze()[-max_new_tokens:])\n",
"\n",
" # This needs to be called \"response\".\n",
" batch[\"response\"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]\n",
"\n",
" # Compute reward outputs.\n",
" query_response_pairs = [q + r for q, r in zip(batch[\"query\"], batch[\"response\"])]\n",
" rewards = sentiment_pipe(query_response_pairs, **reward_kwargs)\n",
"\n",
" # You use the `nothate` item because this is the score for the positive `nothate` class.\n",
" reward_tensors = [torch.tensor(reward[not_hate_index][\"score\"]) for reward in rewards]\n",
"\n",
" # Run PPO step.\n",
" stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)\n",
" ppo_trainer.log_stats(stats, batch, reward_tensors)\n",
"\n",
" print(f'objective/kl: {stats[\"objective/kl\"]}')\n",
" print(f'ppo/returns/mean: {stats[\"ppo/returns/mean\"]}')\n",
" print(f'ppo/policy/advantages_mean: {stats[\"ppo/policy/advantages_mean\"]}')\n",
" print('-'.join('' for x in range(100)))"
]
},
{
"cell_type": "markdown",
"id": "5b648cb7-89e2-40b8-9507-9c07bdfd9ebf",
"metadata": {
"id": "5b648cb7-89e2-40b8-9507-9c07bdfd9ebf"
},
"source": [
"<img src=\"data:image/svg+xml;base64,Cjxzdmcgd2lkdGg9IjgwMCIgaGVpZ2h0PSI1MCIgdmlld0JveD0iMCAwIDgwMCA1MCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KICAgIDxkZWZzPgogICAgICAgIDxsaW5lYXJHcmFkaWVudCBpZD0iZmFkZUdyYWRpZW50IiB4MT0iMCIgeDI9IjEiPgogICAgICAgICAgICA8c3RvcCBvZmZzZXQ9IjAlIiBzdG9wLWNvbG9yPSIjRjBGMEYwIi8+CiAgICAgICAgICAgIDxzdG9wIG9mZnNldD0iMTAwJSIgc3RvcC1jb2xvcj0iI0YwRjBGMCIgc3RvcC1vcGFjaXR5PSIwIi8+CiAgICAgICAgPC9saW5lYXJHcmFkaWVudD4KICAgICAgICA8bWFzayBpZD0iZmFkZU1hc2siPgogICAgICAgICAgICA8cmVjdCB4PSIwIiB5PSIwIiB3aWR0aD0iNzUwIiBoZWlnaHQ9IjUwIiBmaWxsPSJ3aGl0ZSIvPgogICAgICAgICAgICA8cmVjdCB4PSI3NTAiIHk9IjAiIHdpZHRoPSI1MCIgaGVpZ2h0PSI1MCIgZmlsbD0idXJsKCNmYWRlR3JhZGllbnQpIi8+CiAgICAgICAgPC9tYXNrPgogICAgPC9kZWZzPgogICAgPHBhdGggZD0iTTI1LDUwIFEwLDUwIDAsMjUgTDUwLDMgTDk3LDI1IEw3OTcsMjUgTDc5Nyw1MCBMMjUsNTAgWiIgZmlsbD0iI0YwRjBGMCIgc3Ryb2tlPSIjRTBFMEUwIiBzdHJva2Utd2lkdGg9IjEiIG1hc2s9InVybCgjZmFkZU1hc2spIi8+Cjwvc3ZnPgo=\" alt=\"Time alert close\"/>"
]
},
{
"cell_type": "markdown",
"id": "7903f5df-a9de-41eb-b239-38bc367b5654",
"metadata": {
"id": "7903f5df-a9de-41eb-b239-38bc367b5654"
},
"source": [
"<a name='3.3'></a>\n",
"### 3.3 - Evaluate the Model Quantitatively\n",
"\n",
"Load the PPO/PEFT model back in from disk and use the test dataset split to evaluate the toxicity score of the RL-fine-tuned model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b093d43-6197-4cc0-b933-29030479a7d0",
"metadata": {
"tags": [],
"id": "3b093d43-6197-4cc0-b933-29030479a7d0",
"outputId": "32b2c550-15c0-4c2d-b05a-3080f84a043c"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"11it [00:22, 2.04s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"toxicity [mean, std] after detox: [0.025521794767965646, 0.041763468012105946]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"mean_after_detoxification, std_after_detoxification = evaluate_toxicity(model=ppo_model,\n",
" toxicity_evaluator=toxicity_evaluator,\n",
" tokenizer=tokenizer,\n",
" dataset=dataset[\"test\"],\n",
" num_samples=10)\n",
"print(f'toxicity [mean, std] after detox: [{mean_after_detoxification}, {std_after_detoxification}]')"
]
},
{
"cell_type": "markdown",
"id": "f42895cc-7bbf-45e1-a7c7-78ee29cd8009",
"metadata": {
"tags": [],
"id": "f42895cc-7bbf-45e1-a7c7-78ee29cd8009"
},
"source": [
"And compare the toxicity scores of the reference model (before detoxification) and fine-tuned model (after detoxification)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77cc3af2-6600-4673-874b-917c05247ae3",
"metadata": {
"tags": [],
"id": "77cc3af2-6600-4673-874b-917c05247ae3",
"outputId": "438be964-0d04-4e2e-de03-4d5ce92fe445"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Percentage improvement of toxicity score after detoxification:\n",
"mean: 1.48%\n",
"std: -40.78%\n"
]
}
],
"source": [
"mean_improvement = (mean_before_detoxification - mean_after_detoxification) / mean_before_detoxification\n",
"std_improvement = (std_before_detoxification - std_after_detoxification) / std_before_detoxification\n",
"\n",
"print(f'Percentage improvement of toxicity score after detoxification:')\n",
"print(f'mean: {mean_improvement*100:.2f}%')\n",
"print(f'std: {std_improvement*100:.2f}%')"
]
},
{
"cell_type": "markdown",
"id": "66030581-b6f7-41d7-a7e6-2466226833be",
"metadata": {
"id": "66030581-b6f7-41d7-a7e6-2466226833be"
},
"source": [
"<a name='3.4'></a>\n",
"### 3.4 - Evaluate the Model Qualitatively\n",
"\n",
"Let's inspect some examples from the test dataset. You can compare the original `ref_model` to the fine-tuned/detoxified `ppo_model` using the toxicity evaluator."
]
},
{
"cell_type": "markdown",
"id": "12fdc491-5437-41dd-980b-0d04304292dd",
"metadata": {
"id": "12fdc491-5437-41dd-980b-0d04304292dd"
},
"source": [
"<img src=\"data:image/svg+xml;base64,Cjxzdmcgd2lkdGg9IjgwMCIgaGVpZ2h0PSIxMjUiIHZpZXdCb3g9IjAgMCA4MDAgMTI1IiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogICAgPGRlZnM+CiAgICAgICAgPGxpbmVhckdyYWRpZW50IGlkPSJmYWRlR3JhZGllbnQiIHgxPSIwIiB4Mj0iMSI+CiAgICAgICAgICAgIDxzdG9wIG9mZnNldD0iMCUiIHN0b3AtY29sb3I9IiNGMEYwRjAiLz4KICAgICAgICAgICAgPHN0b3Agb2Zmc2V0PSIxMDAlIiBzdG9wLWNvbG9yPSIjRjBGMEYwIiBzdG9wLW9wYWNpdHk9IjAiLz4KICAgICAgICA8L2xpbmVhckdyYWRpZW50PgogICAgICAgIDxtYXNrIGlkPSJmYWRlTWFzayI+CiAgICAgICAgICAgIDxyZWN0IHg9IjAiIHk9IjAiIHdpZHRoPSI3NTAiIGhlaWdodD0iMTI1IiBmaWxsPSJ3aGl0ZSIvPgogICAgICAgICAgICA8cmVjdCB4PSI3NTAiIHk9IjAiIHdpZHRoPSI1MCIgaGVpZ2h0PSIxMjUiIGZpbGw9InVybCgjZmFkZUdyYWRpZW50KSIvPgogICAgICAgIDwvbWFzaz4KICAgIDwvZGVmcz4KICAgIDxwYXRoIGQ9Ik0zLDUwIEE1MCw1MCAwIDAgMSA1MywzIEw3OTcsMyBMNzk3LDk3IEw5Nyw5NyBMNTAsMTE1IEwzLDk3IFoiIGZpbGw9IiNGMEYwRjAiIHN0cm9rZT0iI0UwRTBFMCIgc3Ryb2tlLXdpZHRoPSIxIiBtYXNrPSJ1cmwoI2ZhZGVNYXNrKSIvPgogICAgPGNpcmNsZSBjeD0iNTAiIGN5PSI1MCIgcj0iMzAiIGZpbGw9IiM1N2M0ZjgiIHN0cm9rZT0iIzU3YzRmOCIgc3Ryb2tlLXdpZHRoPSIxIi8+CiAgICA8Y2lyY2xlIGN4PSI1MCIgY3k9IjUwIiByPSIyNSIgZmlsbD0iI0YwRjBGMCIvPgogICAgPGxpbmUgeDE9IjUwIiB5MT0iNTAiIHgyPSI1MCIgeTI9IjMwIiBzdHJva2U9IiM1N2M0ZjgiIHN0cm9rZS13aWR0aD0iMyIgc3Ryb2tlLWxpbmVjYXA9InJvdW5kIi8+CiAgICA8bGluZSB4MT0iNTAiIHkxPSI1MCIgeDI9IjY1IiB5Mj0iNTAiIHN0cm9rZT0iIzU3YzRmOCIgc3Ryb2tlLXdpZHRoPSIzIiBzdHJva2UtbGluZWNhcD0icm91bmQiLz4KICAgIDx0ZXh0IHg9IjEwMCIgeT0iMzQiIGZvbnQtZmFtaWx5PSJBcmlhbCwgc2Fucy1zZXJpZiIgZm9udC1zaXplPSIxNCIgZmlsbD0iIzMzMzMzMyI+VGhlIG5leHQgY2VsbCBtYXkgdGFrZSAyLTMgbWludXRlcyB0byBydW4uPC90ZXh0Pgo8L3N2Zz4K\" alt=\"Time alert open medium\"/>\n",
"​"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22cc8313-20ae-4d32-855e-9b2866fa3085",
"metadata": {
"tags": [],
"id": "22cc8313-20ae-4d32-855e-9b2866fa3085",
"outputId": "b6bd368a-ba91-4393-8d7f-a472c0a7d247"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 20/20 [01:31<00:00, 4.57s/it]\n"
]
}
],
"source": [
"batch_size = 20\n",
"compare_results = {}\n",
"\n",
"df_batch = dataset[\"test\"][0:batch_size]\n",
"\n",
"compare_results[\"query\"] = df_batch[\"query\"]\n",
"prompt_tensors = df_batch[\"input_ids\"]\n",
"\n",
"summary_tensors_ref = []\n",
"summary_tensors = []\n",
"\n",
"# Get response from ppo and base model.\n",
"for i in tqdm(range(batch_size)):\n",
" gen_len = output_length_sampler()\n",
" generation_kwargs[\"max_new_tokens\"] = gen_len\n",
"\n",
" summary = ref_model.generate(\n",
" input_ids=torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device),\n",
" **generation_kwargs\n",
" ).squeeze()[-gen_len:]\n",
" summary_tensors_ref.append(summary)\n",
"\n",
" summary = ppo_model.generate(\n",
" input_ids=torch.as_tensor(prompt_tensors[i]).unsqueeze(dim=0).to(device),\n",
" **generation_kwargs\n",
" ).squeeze()[-gen_len:]\n",
" summary_tensors.append(summary)\n",
"\n",
"# Decode responses.\n",
"compare_results[\"response_before\"] = [tokenizer.decode(summary_tensors_ref[i]) for i in range(batch_size)]\n",
"compare_results[\"response_after\"] = [tokenizer.decode(summary_tensors[i]) for i in range(batch_size)]\n",
"\n",
"# Sentiment analysis of query/response pairs before/after.\n",
"texts_before = [d + s for d, s in zip(compare_results[\"query\"], compare_results[\"response_before\"])]\n",
"rewards_before = sentiment_pipe(texts_before, **reward_kwargs)\n",
"compare_results[\"reward_before\"] = [reward[not_hate_index][\"score\"] for reward in rewards_before]\n",
"\n",
"texts_after = [d + s for d, s in zip(compare_results[\"query\"], compare_results[\"response_after\"])]\n",
"rewards_after = sentiment_pipe(texts_after, **reward_kwargs)\n",
"compare_results[\"reward_after\"] = [reward[not_hate_index][\"score\"] for reward in rewards_after]"
]
},
{
"cell_type": "markdown",
"id": "13a3853f-be22-4a95-95d8-a4b61eb2468f",
"metadata": {
"tags": [],
"id": "13a3853f-be22-4a95-95d8-a4b61eb2468f"
},
"source": [
"<img src=\"data:image/svg+xml;base64,Cjxzdmcgd2lkdGg9IjgwMCIgaGVpZ2h0PSI1MCIgdmlld0JveD0iMCAwIDgwMCA1MCIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KICAgIDxkZWZzPgogICAgICAgIDxsaW5lYXJHcmFkaWVudCBpZD0iZmFkZUdyYWRpZW50IiB4MT0iMCIgeDI9IjEiPgogICAgICAgICAgICA8c3RvcCBvZmZzZXQ9IjAlIiBzdG9wLWNvbG9yPSIjRjBGMEYwIi8+CiAgICAgICAgICAgIDxzdG9wIG9mZnNldD0iMTAwJSIgc3RvcC1jb2xvcj0iI0YwRjBGMCIgc3RvcC1vcGFjaXR5PSIwIi8+CiAgICAgICAgPC9saW5lYXJHcmFkaWVudD4KICAgICAgICA8bWFzayBpZD0iZmFkZU1hc2siPgogICAgICAgICAgICA8cmVjdCB4PSIwIiB5PSIwIiB3aWR0aD0iNzUwIiBoZWlnaHQ9IjUwIiBmaWxsPSJ3aGl0ZSIvPgogICAgICAgICAgICA8cmVjdCB4PSI3NTAiIHk9IjAiIHdpZHRoPSI1MCIgaGVpZ2h0PSI1MCIgZmlsbD0idXJsKCNmYWRlR3JhZGllbnQpIi8+CiAgICAgICAgPC9tYXNrPgogICAgPC9kZWZzPgogICAgPHBhdGggZD0iTTI1LDUwIFEwLDUwIDAsMjUgTDUwLDMgTDk3LDI1IEw3OTcsMjUgTDc5Nyw1MCBMMjUsNTAgWiIgZmlsbD0iI0YwRjBGMCIgc3Ryb2tlPSIjRTBFMEUwIiBzdHJva2Utd2lkdGg9IjEiIG1hc2s9InVybCgjZmFkZU1hc2spIi8+Cjwvc3ZnPgo=\" alt=\"Time alert close\"/>"
]
},
{
"cell_type": "markdown",
"id": "65b70892-4c22-4bed-9d1e-9da3f4f0c97f",
"metadata": {
"tags": [],
"id": "65b70892-4c22-4bed-9d1e-9da3f4f0c97f"
},
"source": [
"Store and review the results in a DataFrame"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e06fff0f-9dba-4517-9424-a5ebd81e8f49",
"metadata": {
"tags": [],
"id": "e06fff0f-9dba-4517-9424-a5ebd81e8f49",
"outputId": "abac12bd-4e55-48a8-9997-19698987b55f"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>query</th>\n",
" <th>response_before</th>\n",
" <th>response_after</th>\n",
" <th>reward_before</th>\n",
" <th>reward_after</th>\n",
" <th>reward_diff</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Summarize the following conversation. #Person1#: Judy, what is everybody talking about? #Person2#: Haven't you heard? Richard was fired by our manager. #Person1#: You're kidding. It can't be true. #Person2#: Believe it or not. Everybody is talking about it in the company. #Person1#: Really? I'm surprised. #Person2#: Me too. Summary: &lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; Judy and #Person2# realize that Richard has been fired by her manager. Judy is surprised.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; Judy and Judy are surprised to hear Richard's failure.&lt;/s&gt;</td>\n",
" <td>1.251855</td>\n",
" <td>2.196642</td>\n",
" <td>0.944787</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Summarize the following conversation. #Person1#: Let's take a coffee break, shall we? #Person2#: I wish I could, but I can't. #Person1#: What keeps you so busy? You've been sitting there for hours. You've got to walk around. You just can't stay on the computer forever. #Person2#: Well, I am up to my neck in work. I've got to finish this report. Sarah needs it by noon. I don't want to be scolded if I can't finish my work by the deadline. #Person1#: I understand that, but you'd feel better if ...</td>\n",
" <td>&lt;pad&gt; @mmb cannot take a break because she can't sit when she's sitting there. #Person1# wonders she could take a break but she doesn't want to.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person2# is up to #Person2#'s neck in work, but #Person1# allows him a coffee break. But he needs to finish his report by the deadline and gives her a lot of reason for bad behavior.&lt;/s&gt;</td>\n",
" <td>0.516179</td>\n",
" <td>1.374489</td>\n",
" <td>0.858310</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Summarize the following conversation. #Person1#: How much are you asking for this? #Person2#: I'm offering them to you at 150 yuan a piece. Is that all right? #Person1#: Is tax already included in their price? #Person2#: Yes. Our price can't be matched. #Person1#: Would you consider a volume discount? #Person2#: If you buy 1, 000 or more, you'll get a 10 % discount. #Person1#: I'll accept your offer. Summary: &lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person2# has offered them to #Person1# at 150 yuan and they cannot be matched. They offer a volume discount, however, based on the advertising aren't matched. #Person1# accepts.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1# arrives at a giveaway table and phone the store to buy colorful books made by #Person2# for 150 yuan. The payment is only for the small books.&lt;/s&gt;</td>\n",
" <td>2.462125</td>\n",
" <td>2.818621</td>\n",
" <td>0.356496</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Summarize the following conversation. #Person1#: Oh, my God! What's this? #Person2#: What? #Person1#: Look! This window is open. #Person2#: Did you open it before we left? #Person1#: Are you kidding? It's winter. Why would I open it? #Person2#: I don't know. Wait. Is this yours? #Person1#: No! Oh, my God! Someone has broken into the house. #Person2#: It looks that way. That's probably why the door wasn't locked when we came in. #Person1#: I locked it when I left though. #Person2#: Yes, but t...</td>\n",
" <td>&lt;pad&gt; Allen tells #Person2# there is a robber breaking into the house because he broke in. His dog didn't fall in the door and he broke into it. Allen says the robber is locked by the window and he left through the door. #Person2# does not think there will be anyone.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; Allen opens the door and leaves the two women alone. Allen tells #Person1# that the robber broke in and left through the door spouts the thief's valuables. They still have to go upstairs since the door isn't locked.&lt;/s&gt;</td>\n",
" <td>1.619587</td>\n",
" <td>1.787470</td>\n",
" <td>0.167882</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Summarize the following conversation. #Person1#: Where shall I register, please? #Person2#: Here. Do you have a registration card? #Person1#: Yes. Here you are. #Person2#: Please register your information here and pay for it. And I'll make a medical record for you. #Person1#: OK. How much do I need to pay for the registration? #Person2#: Please pay ten yuan for the registration. #Person1#: Here is my money. #Person2#: This is your registration card. Please don't lose it and bring it whenever...</td>\n",
" <td>&lt;pad&gt; #Person1# is asked by #Person2# to register #Person1#'s information exchange with #Person2# in court. #Person2# happily agrees to make the medical record and then tells #Person1# how to go to the clinic.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1# will register for the medical records to get to the counseling room. #Person2# gives the same information as #Person1# to #Person1# and helps #Person1# to walk to the pharmacy and get to the consulting room.&lt;/s&gt;</td>\n",
" <td>1.497437</td>\n",
" <td>1.640947</td>\n",
" <td>0.143510</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>Summarize the following conversation. #Person1#: I would like to order some internet today. #Person2#: What kind would you like? #Person1#: What kind of internet is there? #Person2#: You can get DEL or dial-up. #Person1#: Which of those two is best? #Person2#: I would recommend DEL. #Person1#: So that one better? #Person2#: It's better because it doesn't tie up the phone. #Person1#: What do you mean by that? #Person2#: DEL isn't connected through your phone line, but dial-up is. #Person1#: S...</td>\n",
" <td>&lt;pad&gt; #Person1# wants to order an internet today. #Person2# recommends dial-up because it won't tie up the phone. However, #Person1# doesn't use the phone if #Person1# is on the internet.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1# wants to order some DEL or dial-up internet. #Person2# recommends DEL because the DEL doesn't tie up the phone. #Person1# can do both.&lt;/s&gt;</td>\n",
" <td>2.290805</td>\n",
" <td>2.405071</td>\n",
" <td>0.114266</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>Summarize the following conversation. #Person1#: I'd like to have this cashed, please. #Person2#: Please put you name and address here. May I see your passport? #Person1#: Yes. #Person2#: How would you like it? #Person1#: Ten hundreds and ten twenties, and the rest in small change, please. #Person2#: OK. Here you are. Summary: &lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1# wants to cash a bank account at a bank with $10K and $50000 in small change. #Person1# pays for that in cash with #Person1#'s name and address.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1# wants to have this cashed 10 hundreds and 10 twenties. #Person1# can't be seen. #Person2# arrives.&lt;/s&gt;</td>\n",
" <td>1.690671</td>\n",
" <td>1.762732</td>\n",
" <td>0.072061</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>Summarize the following conversation. #Person1#: What can I do for you, madam? #Person2#: I'd like to buy a toy car for my son. #Person1#: How about this one? #Person2#: It looks nice. How much is it? #Person1#: They're three hundred dollars. #Person2#: Oh, I'm afraid it's too expensive. Can you show me something cheaper? #Person1#: OK, This one is one hundred and twenty. It's the cheapest here. #Person2#: OK, I'll take it. Here's the money. #Person1#: Thank you very much. Summary: &lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1# gives #Person2# a car by buying one hundred and twenty for #Person2#'s son. #Person2# agrees and offers the money.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person2# wants a toy car but #Person1# suggests a $1000 toy car. #Person2# takes the one at a cheap price.&lt;/s&gt;</td>\n",
" <td>1.315235</td>\n",
" <td>1.380598</td>\n",
" <td>0.065364</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>Summarize the following conversation. #Person1#: It smells like an ashtray in here! #Person2#: Hi honey! What's wrong? Why do you have that look on your face? #Person1#: What's wrong? I thought we agreed that you were gonna quit smoking. #Person2#: No! I said I was going to cut down which is very different. You can't just expect me to go cold turkey overnight! #Person1#: Look, there are other ways to quit. You can try the nicotine patch, or nicotine chewing gum. We spend a fortune on cigaret...</td>\n",
" <td>&lt;pad&gt; #Person2# cannot stop smoking because #Person1# tells her to try other ways and quit smoking. #Person1# feels her identity belongs to #Person2# and wants a divorce but #Person2# still feels the urge to smoke.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1# is curious about the ashtray in the room. They agree that they should continue to smoke cigarettes and try different ways to quit. #Person1# wants a divorce.&lt;/s&gt;</td>\n",
" <td>1.282832</td>\n",
" <td>1.320850</td>\n",
" <td>0.038018</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>Summarize the following conversation. #Person1#: Hello. I want to reconfirm our flight to London. #Person2#: Yes, sir. Did you call the airline? #Person1#: Yes, I did. But I couldn't communicate with them in English. They speak only Spanish. So I need your help. #Person2#: Certainly, sir. What is the flight number and when are you leaving? #Person1#: We are taking IB 385 to London tomorrow at 1 p. m. #Person2#: Oh, I see, sir. We have the airline office inside the hotel. They have an English...</td>\n",
" <td>&lt;pad&gt; #Person1# is checking for a flight to London. #Person2# tells #Person1# the flight number and the time and tells #Person1# to dial 35.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1# asks #Person2# to reconfirm their flight and address to London on IB 385. #Person2# provides his date, flight number and the_flight number.&lt;/s&gt;</td>\n",
" <td>1.791396</td>\n",
" <td>1.806620</td>\n",
" <td>0.015224</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>Summarize the following conversation. #Person1#: Mom, I just finished my paper. Can you proofread it before I hand it in? #Person2#: Sure, let's take a look. Sweetie, this is terrific. Your ideas are so original. #Person1#: Thanks. #Person2#: I can tell you worked hard on it. #Person1#: I really did! I started thinking about what I wanted to say three weeks ago. #Person2#: Well, it was definitely worth all the time. #Person1#: Let's just hope my teacher agrees. Summary: &lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1# is finishing #Person1#'s paper and said it was a great idea. #Person2# tells her the original ideas and highlights the work she did on the paper.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1#'s mom realizes she's finished the paper on her mom's recommendation and is proud of her work.&lt;/s&gt;</td>\n",
" <td>2.689511</td>\n",
" <td>2.694573</td>\n",
" <td>0.005062</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>Summarize the following conversation. #Person1#: Could you help me, Sir? My flight got in 15 minutes ago. Everyone else has picked up the luggage but mine hasn't come through. #Person2#: I'm sorry, Madam, I'll go and find out if there is any more to come. Summary: &lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; All others have picked up #Person1#'s luggage but #Person1#'s flight did not arrive. #Person1# and #Person2# will ask for help. See something else in the meantime.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1#'s flight got in 15 minutes but #Person1#'s flight hasn't been ready. #Person2# will go to find out when it arrives.&lt;/s&gt;</td>\n",
" <td>2.655438</td>\n",
" <td>2.632524</td>\n",
" <td>-0.022914</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>Summarize the following conversation. #Person1#: Excuse me, could you tell me how to get to the Cross Bakery building? #Person2#: The Cross Bakery building? Oh sure. You're actually walking in the opposite direction. #Person1#: Oh, you're kidding! I thought I was heading east. #Person2#: No, east is the other direction. To get to the Bakery, you need to turn around and go three blocks to Broadway. When you get to the intersection of Broadway and Elm, you hang a left. Go straight down that st...</td>\n",
" <td>&lt;pad&gt; #Person1# requests #Person2# to show #Person1# the way to the Cross Bakery building and express all of her confidence in the directions.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1# is on a new way to cross the Cross Bakery building exiting Broadway, which requires stopping at the intersection of Broadway and Elm. #Person2# shows #Person1# the way, and she agrees.&lt;/s&gt;</td>\n",
" <td>2.493639</td>\n",
" <td>2.415073</td>\n",
" <td>-0.078566</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>Summarize the following conversation. #Person1#: So how did you like the restaurant? #Person2#: Actually, it could have been better. #Person1#: What didn't you like about it? #Person2#: It is a new restaurant. I don't think they have their act together yet. #Person1#: What did you think about the food? #Person2#: I felt that the food was pretty mediocre. #Person1#: The service wasn't that great, either. #Person2#: I agree. The service was not good. #Person1#: Do you think that you want to tr...</td>\n",
" <td>&lt;pad&gt; If #Person2# likes the restaurant, #Person1# hasn't tried it. #Person2# is disappointed. #Person2# thinks the staff is not good. #Person2# is not going to try this restaurant again.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person2# asked for while #Person1# is not sure what constructive will of this restaurant, and #Person1# can tell #Person2# that it's a new restaurant. They both feel the service was not good and they've had enough of the restaurant.&lt;/s&gt;</td>\n",
" <td>2.119996</td>\n",
" <td>1.956108</td>\n",
" <td>-0.163888</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>Summarize the following conversation. #Person1#: Hello? #Person2#: Hello? #Person1#: Can I speak to Li Hong, please? #Person2#: Speaking. #Person1#: Hi, Li Hong. This is Alice. #Person2#: Hi, Alice. How are you? #Person1#: Not bad. Li Hong, I am sorry that I can't go to see Mrs. Brown with you tomorrow morning. My mother is ill. I must take care of her. #Person2#: I'm sorry to hear that. You'd better stay at home. After all, we can visit Mrs. Brown later #Person1#: OK. Bye - bye. #Person2#: ...</td>\n",
" <td>&lt;pad&gt; Alice can't come to see Mrs. Brown tomorrow because her mother is ill. Li Hong tells her to stay at home, and she can visit Mrs. Brown later.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; Li Hong apologises for Alice whose mother is ill. Alice agrees to stay at home if she does not have a suitable date with Li Hong tomorrow morning.&lt;/s&gt;</td>\n",
" <td>1.979845</td>\n",
" <td>1.728329</td>\n",
" <td>-0.251516</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>Summarize the following conversation. #Person1#: Here is the final draft of our contract. I'm glad that we have reached an agreement on almost every term in our trade. #Person2#: Yes, it seems to me we have come quite a long way. However, let me take a close look at the final draft. #Person1#: Do you have some points to bring up? #Person2#: Well, everything we've discussed seems to be here. #Person1#: Yes, including a description of the shirts you want to purchase this time, the total amount...</td>\n",
" <td>&lt;pad&gt; #Person2# understands the final drafts of the contract. #Person1# wants to sign the contract now. #Person2# requests #Person1#'s help with verification checks on details of the contract. #Person2# specifies #Person1# as the standard for others.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1#'s staff made a final draft of their contract. #Person2# doesn't want to sign because the final draft doesn't give results. #Person1# gives #Person2# another job to check to make sure the average is done.&lt;/s&gt;</td>\n",
" <td>3.136004</td>\n",
" <td>2.764242</td>\n",
" <td>-0.371763</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>Summarize the following conversation. #Person1#: Amanda, how do you like this peaked cap? #Person2#: Didn't you say you want to buy a top hat? #Person1#: But I think this one fits me Well. Why don't you try on the sombrero in black? #Person2#: I don't like caps at all. Summary: &lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; Amanda wears a peaked cap she likes and drinks a sombrero in black. Amanda wants to buy a top hat but Amanda likes caps.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; Amanda likes the peaked cap, but she doesn't like caps at all.&lt;/s&gt;</td>\n",
" <td>1.310969</td>\n",
" <td>0.881978</td>\n",
" <td>-0.428990</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>Summarize the following conversation. #Person1#: I'm forming a music band. #Person2#: Do you already know how to play an instrument? #Person1#: Uh... Yeah! I'Ve told you a thousand times that I'm learning to play the drums. Now that I know how to play well, I would like to form a rock band. #Person2#: Aside from yourself, who are the other members of the band? #Person1#: We have a guy who plays guitar, and another who plays bass. Although we still haven't found anyone to be our singer. You t...</td>\n",
" <td>&lt;pad&gt; #Person1# wants to form a rock band in a concert because #Person1# wants to play a rock instrument and shows some musical talent to #Person2#. Takes #Person1# to audition here. Plus, #Person2# doesn't have space for the amplifiers, microphones, or even the drums, so #Person1# has to practice a little.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1# is forming a music band, helped by the musician, said she's a singer and invites #Person1# to audition This weekend.&lt;/s&gt;</td>\n",
" <td>3.065439</td>\n",
" <td>2.505563</td>\n",
" <td>-0.559876</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>Summarize the following conversation. #Person1#: Could you help me figure out how to look for a job? #Person2#: We have lots of options, what type of job do you need? #Person1#: I want to work in an office. #Person2#: Do you want to work part-time or full-time? #Person1#: I want to work full-time. #Person2#: We have binders with local job listings or you can make use of the computers. OK? #Person1#: I am confused a bit but I am sure that I can figure it out. #Person2#: If you make an appoint...</td>\n",
" <td>&lt;pad&gt; #Person2# helps #Person1# find a job at the job center. #Person1# wants to work full-time and #Person1# thinks it's convenient to go. They discuss the kind of job and offer advice as well. They convince #Person1# to visit a job center.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person2# shows #Person1# how to look for a job and gives her instructions to open a job center. #Person1# is confused and asks to make an appointment with a job counselor but #Person1# falls in love.&lt;/s&gt;</td>\n",
" <td>2.529984</td>\n",
" <td>1.955189</td>\n",
" <td>-0.574795</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>Summarize the following conversation. #Person1#: Today more and more families have personal computers. People have wider range of choice to communicate with the outside world. #Person2#: Right. With the establishment of Internet and a lot of web companies, people are getting more and more dependent on the web. #Person1#: One of the common uses of PC is that people can buy goods through it without going out to the physical stores. #Person2#: Can you tell me how it is done? #Person1#: If a cus...</td>\n",
" <td>&lt;pad&gt; #Person1# and #Person2# discuss the common uses of personal computers. They have other uses too such as easy to use, and common use of PC is to buy goods online through it. The ship's almost completely free and the delivery to the home is perfect and free of charge.&lt;/s&gt;</td>\n",
" <td>&lt;pad&gt; #Person1# tells #Person2# that people will be more able to communicate with the outside world through PC, even though computers are becoming more common. It was explained how customers can place orders online to have products show on the colour screen.&lt;/s&gt;</td>\n",
" <td>3.720674</td>\n",
" <td>2.764764</td>\n",
" <td>-0.955910</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" query \\\n",
"0 Summarize the following conversation. #Person1#: Judy, what is everybody talking about? #Person2#: Haven't you heard? Richard was fired by our manager. #Person1#: You're kidding. It can't be true. #Person2#: Believe it or not. Everybody is talking about it in the company. #Person1#: Really? I'm surprised. #Person2#: Me too. Summary: </s> \n",
"1 Summarize the following conversation. #Person1#: Let's take a coffee break, shall we? #Person2#: I wish I could, but I can't. #Person1#: What keeps you so busy? You've been sitting there for hours. You've got to walk around. You just can't stay on the computer forever. #Person2#: Well, I am up to my neck in work. I've got to finish this report. Sarah needs it by noon. I don't want to be scolded if I can't finish my work by the deadline. #Person1#: I understand that, but you'd feel better if ... \n",
"2 Summarize the following conversation. #Person1#: How much are you asking for this? #Person2#: I'm offering them to you at 150 yuan a piece. Is that all right? #Person1#: Is tax already included in their price? #Person2#: Yes. Our price can't be matched. #Person1#: Would you consider a volume discount? #Person2#: If you buy 1, 000 or more, you'll get a 10 % discount. #Person1#: I'll accept your offer. Summary: </s> \n",
"3 Summarize the following conversation. #Person1#: Oh, my God! What's this? #Person2#: What? #Person1#: Look! This window is open. #Person2#: Did you open it before we left? #Person1#: Are you kidding? It's winter. Why would I open it? #Person2#: I don't know. Wait. Is this yours? #Person1#: No! Oh, my God! Someone has broken into the house. #Person2#: It looks that way. That's probably why the door wasn't locked when we came in. #Person1#: I locked it when I left though. #Person2#: Yes, but t... \n",
"4 Summarize the following conversation. #Person1#: Where shall I register, please? #Person2#: Here. Do you have a registration card? #Person1#: Yes. Here you are. #Person2#: Please register your information here and pay for it. And I'll make a medical record for you. #Person1#: OK. How much do I need to pay for the registration? #Person2#: Please pay ten yuan for the registration. #Person1#: Here is my money. #Person2#: This is your registration card. Please don't lose it and bring it whenever... \n",
"5 Summarize the following conversation. #Person1#: I would like to order some internet today. #Person2#: What kind would you like? #Person1#: What kind of internet is there? #Person2#: You can get DEL or dial-up. #Person1#: Which of those two is best? #Person2#: I would recommend DEL. #Person1#: So that one better? #Person2#: It's better because it doesn't tie up the phone. #Person1#: What do you mean by that? #Person2#: DEL isn't connected through your phone line, but dial-up is. #Person1#: S... \n",
"6 Summarize the following conversation. #Person1#: I'd like to have this cashed, please. #Person2#: Please put you name and address here. May I see your passport? #Person1#: Yes. #Person2#: How would you like it? #Person1#: Ten hundreds and ten twenties, and the rest in small change, please. #Person2#: OK. Here you are. Summary: </s> \n",
"7 Summarize the following conversation. #Person1#: What can I do for you, madam? #Person2#: I'd like to buy a toy car for my son. #Person1#: How about this one? #Person2#: It looks nice. How much is it? #Person1#: They're three hundred dollars. #Person2#: Oh, I'm afraid it's too expensive. Can you show me something cheaper? #Person1#: OK, This one is one hundred and twenty. It's the cheapest here. #Person2#: OK, I'll take it. Here's the money. #Person1#: Thank you very much. Summary: </s> \n",
"8 Summarize the following conversation. #Person1#: It smells like an ashtray in here! #Person2#: Hi honey! What's wrong? Why do you have that look on your face? #Person1#: What's wrong? I thought we agreed that you were gonna quit smoking. #Person2#: No! I said I was going to cut down which is very different. You can't just expect me to go cold turkey overnight! #Person1#: Look, there are other ways to quit. You can try the nicotine patch, or nicotine chewing gum. We spend a fortune on cigaret... \n",
"9 Summarize the following conversation. #Person1#: Hello. I want to reconfirm our flight to London. #Person2#: Yes, sir. Did you call the airline? #Person1#: Yes, I did. But I couldn't communicate with them in English. They speak only Spanish. So I need your help. #Person2#: Certainly, sir. What is the flight number and when are you leaving? #Person1#: We are taking IB 385 to London tomorrow at 1 p. m. #Person2#: Oh, I see, sir. We have the airline office inside the hotel. They have an English... \n",
"10 Summarize the following conversation. #Person1#: Mom, I just finished my paper. Can you proofread it before I hand it in? #Person2#: Sure, let's take a look. Sweetie, this is terrific. Your ideas are so original. #Person1#: Thanks. #Person2#: I can tell you worked hard on it. #Person1#: I really did! I started thinking about what I wanted to say three weeks ago. #Person2#: Well, it was definitely worth all the time. #Person1#: Let's just hope my teacher agrees. Summary: </s> \n",
"11 Summarize the following conversation. #Person1#: Could you help me, Sir? My flight got in 15 minutes ago. Everyone else has picked up the luggage but mine hasn't come through. #Person2#: I'm sorry, Madam, I'll go and find out if there is any more to come. Summary: </s> \n",
"12 Summarize the following conversation. #Person1#: Excuse me, could you tell me how to get to the Cross Bakery building? #Person2#: The Cross Bakery building? Oh sure. You're actually walking in the opposite direction. #Person1#: Oh, you're kidding! I thought I was heading east. #Person2#: No, east is the other direction. To get to the Bakery, you need to turn around and go three blocks to Broadway. When you get to the intersection of Broadway and Elm, you hang a left. Go straight down that st... \n",
"13 Summarize the following conversation. #Person1#: So how did you like the restaurant? #Person2#: Actually, it could have been better. #Person1#: What didn't you like about it? #Person2#: It is a new restaurant. I don't think they have their act together yet. #Person1#: What did you think about the food? #Person2#: I felt that the food was pretty mediocre. #Person1#: The service wasn't that great, either. #Person2#: I agree. The service was not good. #Person1#: Do you think that you want to tr... \n",
"14 Summarize the following conversation. #Person1#: Hello? #Person2#: Hello? #Person1#: Can I speak to Li Hong, please? #Person2#: Speaking. #Person1#: Hi, Li Hong. This is Alice. #Person2#: Hi, Alice. How are you? #Person1#: Not bad. Li Hong, I am sorry that I can't go to see Mrs. Brown with you tomorrow morning. My mother is ill. I must take care of her. #Person2#: I'm sorry to hear that. You'd better stay at home. After all, we can visit Mrs. Brown later #Person1#: OK. Bye - bye. #Person2#: ... \n",
"15 Summarize the following conversation. #Person1#: Here is the final draft of our contract. I'm glad that we have reached an agreement on almost every term in our trade. #Person2#: Yes, it seems to me we have come quite a long way. However, let me take a close look at the final draft. #Person1#: Do you have some points to bring up? #Person2#: Well, everything we've discussed seems to be here. #Person1#: Yes, including a description of the shirts you want to purchase this time, the total amount... \n",
"16 Summarize the following conversation. #Person1#: Amanda, how do you like this peaked cap? #Person2#: Didn't you say you want to buy a top hat? #Person1#: But I think this one fits me Well. Why don't you try on the sombrero in black? #Person2#: I don't like caps at all. Summary: </s> \n",
"17 Summarize the following conversation. #Person1#: I'm forming a music band. #Person2#: Do you already know how to play an instrument? #Person1#: Uh... Yeah! I'Ve told you a thousand times that I'm learning to play the drums. Now that I know how to play well, I would like to form a rock band. #Person2#: Aside from yourself, who are the other members of the band? #Person1#: We have a guy who plays guitar, and another who plays bass. Although we still haven't found anyone to be our singer. You t... \n",
"18 Summarize the following conversation. #Person1#: Could you help me figure out how to look for a job? #Person2#: We have lots of options, what type of job do you need? #Person1#: I want to work in an office. #Person2#: Do you want to work part-time or full-time? #Person1#: I want to work full-time. #Person2#: We have binders with local job listings or you can make use of the computers. OK? #Person1#: I am confused a bit but I am sure that I can figure it out. #Person2#: If you make an appoint... \n",
"19 Summarize the following conversation. #Person1#: Today more and more families have personal computers. People have wider range of choice to communicate with the outside world. #Person2#: Right. With the establishment of Internet and a lot of web companies, people are getting more and more dependent on the web. #Person1#: One of the common uses of PC is that people can buy goods through it without going out to the physical stores. #Person2#: Can you tell me how it is done? #Person1#: If a cus... \n",
"\n",
" response_before \\\n",
"0 <pad> Judy and #Person2# realize that Richard has been fired by her manager. Judy is surprised.</s> \n",
"1 <pad> @mmb cannot take a break because she can't sit when she's sitting there. #Person1# wonders she could take a break but she doesn't want to.</s> \n",
"2 <pad> #Person2# has offered them to #Person1# at 150 yuan and they cannot be matched. They offer a volume discount, however, based on the advertising aren't matched. #Person1# accepts.</s> \n",
"3 <pad> Allen tells #Person2# there is a robber breaking into the house because he broke in. His dog didn't fall in the door and he broke into it. Allen says the robber is locked by the window and he left through the door. #Person2# does not think there will be anyone.</s> \n",
"4 <pad> #Person1# is asked by #Person2# to register #Person1#'s information exchange with #Person2# in court. #Person2# happily agrees to make the medical record and then tells #Person1# how to go to the clinic.</s> \n",
"5 <pad> #Person1# wants to order an internet today. #Person2# recommends dial-up because it won't tie up the phone. However, #Person1# doesn't use the phone if #Person1# is on the internet.</s> \n",
"6 <pad> #Person1# wants to cash a bank account at a bank with $10K and $50000 in small change. #Person1# pays for that in cash with #Person1#'s name and address.</s> \n",
"7 <pad> #Person1# gives #Person2# a car by buying one hundred and twenty for #Person2#'s son. #Person2# agrees and offers the money.</s> \n",
"8 <pad> #Person2# cannot stop smoking because #Person1# tells her to try other ways and quit smoking. #Person1# feels her identity belongs to #Person2# and wants a divorce but #Person2# still feels the urge to smoke.</s> \n",
"9 <pad> #Person1# is checking for a flight to London. #Person2# tells #Person1# the flight number and the time and tells #Person1# to dial 35.</s> \n",
"10 <pad> #Person1# is finishing #Person1#'s paper and said it was a great idea. #Person2# tells her the original ideas and highlights the work she did on the paper.</s> \n",
"11 <pad> All others have picked up #Person1#'s luggage but #Person1#'s flight did not arrive. #Person1# and #Person2# will ask for help. See something else in the meantime.</s> \n",
"12 <pad> #Person1# requests #Person2# to show #Person1# the way to the Cross Bakery building and express all of her confidence in the directions.</s> \n",
"13 <pad> If #Person2# likes the restaurant, #Person1# hasn't tried it. #Person2# is disappointed. #Person2# thinks the staff is not good. #Person2# is not going to try this restaurant again.</s> \n",
"14 <pad> Alice can't come to see Mrs. Brown tomorrow because her mother is ill. Li Hong tells her to stay at home, and she can visit Mrs. Brown later.</s> \n",
"15 <pad> #Person2# understands the final drafts of the contract. #Person1# wants to sign the contract now. #Person2# requests #Person1#'s help with verification checks on details of the contract. #Person2# specifies #Person1# as the standard for others.</s> \n",
"16 <pad> Amanda wears a peaked cap she likes and drinks a sombrero in black. Amanda wants to buy a top hat but Amanda likes caps.</s> \n",
"17 <pad> #Person1# wants to form a rock band in a concert because #Person1# wants to play a rock instrument and shows some musical talent to #Person2#. Takes #Person1# to audition here. Plus, #Person2# doesn't have space for the amplifiers, microphones, or even the drums, so #Person1# has to practice a little.</s> \n",
"18 <pad> #Person2# helps #Person1# find a job at the job center. #Person1# wants to work full-time and #Person1# thinks it's convenient to go. They discuss the kind of job and offer advice as well. They convince #Person1# to visit a job center.</s> \n",
"19 <pad> #Person1# and #Person2# discuss the common uses of personal computers. They have other uses too such as easy to use, and common use of PC is to buy goods online through it. The ship's almost completely free and the delivery to the home is perfect and free of charge.</s> \n",
"\n",
" response_after \\\n",
"0 <pad> Judy and Judy are surprised to hear Richard's failure.</s> \n",
"1 <pad> #Person2# is up to #Person2#'s neck in work, but #Person1# allows him a coffee break. But he needs to finish his report by the deadline and gives her a lot of reason for bad behavior.</s> \n",
"2 <pad> #Person1# arrives at a giveaway table and phone the store to buy colorful books made by #Person2# for 150 yuan. The payment is only for the small books.</s> \n",
"3 <pad> Allen opens the door and leaves the two women alone. Allen tells #Person1# that the robber broke in and left through the door spouts the thief's valuables. They still have to go upstairs since the door isn't locked.</s> \n",
"4 <pad> #Person1# will register for the medical records to get to the counseling room. #Person2# gives the same information as #Person1# to #Person1# and helps #Person1# to walk to the pharmacy and get to the consulting room.</s> \n",
"5 <pad> #Person1# wants to order some DEL or dial-up internet. #Person2# recommends DEL because the DEL doesn't tie up the phone. #Person1# can do both.</s> \n",
"6 <pad> #Person1# wants to have this cashed 10 hundreds and 10 twenties. #Person1# can't be seen. #Person2# arrives.</s> \n",
"7 <pad> #Person2# wants a toy car but #Person1# suggests a $1000 toy car. #Person2# takes the one at a cheap price.</s> \n",
"8 <pad> #Person1# is curious about the ashtray in the room. They agree that they should continue to smoke cigarettes and try different ways to quit. #Person1# wants a divorce.</s> \n",
"9 <pad> #Person1# asks #Person2# to reconfirm their flight and address to London on IB 385. #Person2# provides his date, flight number and the_flight number.</s> \n",
"10 <pad> #Person1#'s mom realizes she's finished the paper on her mom's recommendation and is proud of her work.</s> \n",
"11 <pad> #Person1#'s flight got in 15 minutes but #Person1#'s flight hasn't been ready. #Person2# will go to find out when it arrives.</s> \n",
"12 <pad> #Person1# is on a new way to cross the Cross Bakery building exiting Broadway, which requires stopping at the intersection of Broadway and Elm. #Person2# shows #Person1# the way, and she agrees.</s> \n",
"13 <pad> #Person2# asked for while #Person1# is not sure what constructive will of this restaurant, and #Person1# can tell #Person2# that it's a new restaurant. They both feel the service was not good and they've had enough of the restaurant.</s> \n",
"14 <pad> Li Hong apologises for Alice whose mother is ill. Alice agrees to stay at home if she does not have a suitable date with Li Hong tomorrow morning.</s> \n",
"15 <pad> #Person1#'s staff made a final draft of their contract. #Person2# doesn't want to sign because the final draft doesn't give results. #Person1# gives #Person2# another job to check to make sure the average is done.</s> \n",
"16 <pad> Amanda likes the peaked cap, but she doesn't like caps at all.</s> \n",
"17 <pad> #Person1# is forming a music band, helped by the musician, said she's a singer and invites #Person1# to audition This weekend.</s> \n",
"18 <pad> #Person2# shows #Person1# how to look for a job and gives her instructions to open a job center. #Person1# is confused and asks to make an appointment with a job counselor but #Person1# falls in love.</s> \n",
"19 <pad> #Person1# tells #Person2# that people will be more able to communicate with the outside world through PC, even though computers are becoming more common. It was explained how customers can place orders online to have products show on the colour screen.</s> \n",
"\n",
" reward_before reward_after reward_diff \n",
"0 1.251855 2.196642 0.944787 \n",
"1 0.516179 1.374489 0.858310 \n",
"2 2.462125 2.818621 0.356496 \n",
"3 1.619587 1.787470 0.167882 \n",
"4 1.497437 1.640947 0.143510 \n",
"5 2.290805 2.405071 0.114266 \n",
"6 1.690671 1.762732 0.072061 \n",
"7 1.315235 1.380598 0.065364 \n",
"8 1.282832 1.320850 0.038018 \n",
"9 1.791396 1.806620 0.015224 \n",
"10 2.689511 2.694573 0.005062 \n",
"11 2.655438 2.632524 -0.022914 \n",
"12 2.493639 2.415073 -0.078566 \n",
"13 2.119996 1.956108 -0.163888 \n",
"14 1.979845 1.728329 -0.251516 \n",
"15 3.136004 2.764242 -0.371763 \n",
"16 1.310969 0.881978 -0.428990 \n",
"17 3.065439 2.505563 -0.559876 \n",
"18 2.529984 1.955189 -0.574795 \n",
"19 3.720674 2.764764 -0.955910 "
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.set_option('display.max_colwidth', 500)\n",
"df_compare_results = pd.DataFrame(compare_results)\n",
"df_compare_results[\"reward_diff\"] = df_compare_results['reward_after'] - df_compare_results['reward_before']\n",
"df_compare_results_sorted = df_compare_results.sort_values(by=['reward_diff'], ascending=False).reset_index(drop=True)\n",
"df_compare_results_sorted"
]
},
{
"cell_type": "markdown",
"id": "e7fb2477-f719-48de-b169-0607d355a8f6",
"metadata": {
"id": "e7fb2477-f719-48de-b169-0607d355a8f6"
},
"source": [
"Looking at the reward mean/median of the generated sequences you can observe a significant difference!"
]
}
],
"metadata": {
"availableInstances": [
{
"_defaultOrder": 0,
"_isFastLaunch": true,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 4,
"name": "ml.t3.medium",
"vcpuNum": 2
},
{
"_defaultOrder": 1,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 8,
"name": "ml.t3.large",
"vcpuNum": 2
},
{
"_defaultOrder": 2,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.t3.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 3,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.t3.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 4,
"_isFastLaunch": true,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 8,
"name": "ml.m5.large",
"vcpuNum": 2
},
{
"_defaultOrder": 5,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.m5.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 6,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.m5.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 7,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 64,
"name": "ml.m5.4xlarge",
"vcpuNum": 16
},
{
"_defaultOrder": 8,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 128,
"name": "ml.m5.8xlarge",
"vcpuNum": 32
},
{
"_defaultOrder": 9,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 192,
"name": "ml.m5.12xlarge",
"vcpuNum": 48
},
{
"_defaultOrder": 10,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 256,
"name": "ml.m5.16xlarge",
"vcpuNum": 64
},
{
"_defaultOrder": 11,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 384,
"name": "ml.m5.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 12,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 8,
"name": "ml.m5d.large",
"vcpuNum": 2
},
{
"_defaultOrder": 13,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.m5d.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 14,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.m5d.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 15,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 64,
"name": "ml.m5d.4xlarge",
"vcpuNum": 16
},
{
"_defaultOrder": 16,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 128,
"name": "ml.m5d.8xlarge",
"vcpuNum": 32
},
{
"_defaultOrder": 17,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 192,
"name": "ml.m5d.12xlarge",
"vcpuNum": 48
},
{
"_defaultOrder": 18,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 256,
"name": "ml.m5d.16xlarge",
"vcpuNum": 64
},
{
"_defaultOrder": 19,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 384,
"name": "ml.m5d.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 20,
"_isFastLaunch": false,
"category": "General purpose",
"gpuNum": 0,
"hideHardwareSpecs": true,
"memoryGiB": 0,
"name": "ml.geospatial.interactive",
"supportedImageNames": [
"sagemaker-geospatial-v1-0"
],
"vcpuNum": 0
},
{
"_defaultOrder": 21,
"_isFastLaunch": true,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 4,
"name": "ml.c5.large",
"vcpuNum": 2
},
{
"_defaultOrder": 22,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 8,
"name": "ml.c5.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 23,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.c5.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 24,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.c5.4xlarge",
"vcpuNum": 16
},
{
"_defaultOrder": 25,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 72,
"name": "ml.c5.9xlarge",
"vcpuNum": 36
},
{
"_defaultOrder": 26,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 96,
"name": "ml.c5.12xlarge",
"vcpuNum": 48
},
{
"_defaultOrder": 27,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 144,
"name": "ml.c5.18xlarge",
"vcpuNum": 72
},
{
"_defaultOrder": 28,
"_isFastLaunch": false,
"category": "Compute optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 192,
"name": "ml.c5.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 29,
"_isFastLaunch": true,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.g4dn.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 30,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.g4dn.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 31,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 64,
"name": "ml.g4dn.4xlarge",
"vcpuNum": 16
},
{
"_defaultOrder": 32,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 128,
"name": "ml.g4dn.8xlarge",
"vcpuNum": 32
},
{
"_defaultOrder": 33,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 4,
"hideHardwareSpecs": false,
"memoryGiB": 192,
"name": "ml.g4dn.12xlarge",
"vcpuNum": 48
},
{
"_defaultOrder": 34,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 256,
"name": "ml.g4dn.16xlarge",
"vcpuNum": 64
},
{
"_defaultOrder": 35,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 61,
"name": "ml.p3.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 36,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 4,
"hideHardwareSpecs": false,
"memoryGiB": 244,
"name": "ml.p3.8xlarge",
"vcpuNum": 32
},
{
"_defaultOrder": 37,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 8,
"hideHardwareSpecs": false,
"memoryGiB": 488,
"name": "ml.p3.16xlarge",
"vcpuNum": 64
},
{
"_defaultOrder": 38,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 8,
"hideHardwareSpecs": false,
"memoryGiB": 768,
"name": "ml.p3dn.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 39,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.r5.large",
"vcpuNum": 2
},
{
"_defaultOrder": 40,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.r5.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 41,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 64,
"name": "ml.r5.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 42,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 128,
"name": "ml.r5.4xlarge",
"vcpuNum": 16
},
{
"_defaultOrder": 43,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 256,
"name": "ml.r5.8xlarge",
"vcpuNum": 32
},
{
"_defaultOrder": 44,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 384,
"name": "ml.r5.12xlarge",
"vcpuNum": 48
},
{
"_defaultOrder": 45,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 512,
"name": "ml.r5.16xlarge",
"vcpuNum": 64
},
{
"_defaultOrder": 46,
"_isFastLaunch": false,
"category": "Memory Optimized",
"gpuNum": 0,
"hideHardwareSpecs": false,
"memoryGiB": 768,
"name": "ml.r5.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 47,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 16,
"name": "ml.g5.xlarge",
"vcpuNum": 4
},
{
"_defaultOrder": 48,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 32,
"name": "ml.g5.2xlarge",
"vcpuNum": 8
},
{
"_defaultOrder": 49,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 64,
"name": "ml.g5.4xlarge",
"vcpuNum": 16
},
{
"_defaultOrder": 50,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 128,
"name": "ml.g5.8xlarge",
"vcpuNum": 32
},
{
"_defaultOrder": 51,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 1,
"hideHardwareSpecs": false,
"memoryGiB": 256,
"name": "ml.g5.16xlarge",
"vcpuNum": 64
},
{
"_defaultOrder": 52,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 4,
"hideHardwareSpecs": false,
"memoryGiB": 192,
"name": "ml.g5.12xlarge",
"vcpuNum": 48
},
{
"_defaultOrder": 53,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 4,
"hideHardwareSpecs": false,
"memoryGiB": 384,
"name": "ml.g5.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 54,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 8,
"hideHardwareSpecs": false,
"memoryGiB": 768,
"name": "ml.g5.48xlarge",
"vcpuNum": 192
},
{
"_defaultOrder": 55,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 8,
"hideHardwareSpecs": false,
"memoryGiB": 1152,
"name": "ml.p4d.24xlarge",
"vcpuNum": 96
},
{
"_defaultOrder": 56,
"_isFastLaunch": false,
"category": "Accelerated computing",
"gpuNum": 8,
"hideHardwareSpecs": false,
"memoryGiB": 1152,
"name": "ml.p4de.24xlarge",
"vcpuNum": 96
}
],
"instance_type": "ml.m5.2xlarge",
"kernelspec": {
"display_name": "Python 3 (Data Science)",
"language": "python",
"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment