@nogawanogawa
Created November 22, 2023 04:31

OpenP5_train_notebook.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/nogawanogawa/b751146ff57ced6a70deba45f578bcc6/openp5_train_notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dt6kfEGBkU7u",
"outputId": "a98382d4-123f-44ff-fa14-1923b7351154"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.35.2)\n",
"Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.15.0)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.3)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.0)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.0)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n",
"Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.5)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.7)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n",
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n",
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.15)\n",
"Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.6)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.3.2)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n",
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.66.1)\n",
"Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.6.2)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.23.5)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (23.2)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.1)\n",
"Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.1.0+cu118)\n",
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.35.2)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from peft) (4.66.1)\n",
"Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.24.1)\n",
"Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) (0.4.0)\n",
"Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->peft) (0.19.3)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.13.1)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (4.5.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2023.6.0)\n",
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.1.0)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2023.6.3)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2.31.0)\n",
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.15.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2023.7.22)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft) (1.3.0)\n"
]
}
],
"source": [
"!pip install transformers datasets\n",
"!pip install tqdm\n",
"!pip install peft"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u4M5mcq3kHK6"
},
"outputs": [],
"source": [
"import os\n",
"import transformers\n",
"import argparse\n",
"import torch\n",
"import logging\n",
"from torch.utils.data import ConcatDataset, DataLoader\n",
"from datasets import load_dataset, concatenate_datasets\n",
"import re\n",
"import sys\n",
"import random\n",
"\n",
"import numpy as np\n",
"import pickle\n",
"import inspect\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"from transformers import (\n",
" T5Config,\n",
" T5ForConditionalGeneration,\n",
" AutoTokenizer,\n",
" LlamaForCausalLM\n",
")\n",
"\n",
"from peft import ( # noqa: E402\n",
" LoraConfig,\n",
" get_peft_model,\n",
" get_peft_model_state_dict,\n",
" prepare_model_for_int8_training,\n",
" set_peft_model_state_dict,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "76btAfm1kT8v"
},
"outputs": [],
"source": [
"# from https://github.com/agiresearch/OpenP5/blob/main/src/TaskAlternateTrainer.py\n",
"\n",
"from torch.utils.data.distributed import DistributedSampler\n",
"from torch.utils.data.sampler import RandomSampler\n",
"from torch.utils.data import DataLoader\n",
"import inspect\n",
"import transformers\n",
"from transformers.utils import find_labels\n",
"\n",
"class MultitaskDataloader:\n",
" \"\"\"\n",
" Data loader that combines and samples from multiple single-task\n",
" data loaders.\n",
" \"\"\"\n",
" def __init__(self, model, dataset_dict, batch_size, collate_fn, accelerator=None):\n",
" self.model = model\n",
" self.batch_size = batch_size\n",
" self.collate_fn = collate_fn\n",
" self._signature_columns = None\n",
"\n",
" for task in dataset_dict:\n",
" dataset_dict[task] = self._remove_unused_columns(dataset_dict[task])\n",
"\n",
" self.dataloader_dict = dict()\n",
" for task in dataset_dict:\n",
" self.dataloader_dict[task] = DataLoader(dataset_dict[task], batch_size = self.batch_size, sampler = RandomSampler(dataset_dict[task]), collate_fn = self.collate_fn)\n",
" if accelerator is not None:\n",
" for task in self.dataloader_dict:\n",
" self.dataloader_dict[task] = accelerator.prepare(self.dataloader_dict[task])\n",
"\n",
" self.num_batches_dict = {task: len(dataloader) for task, dataloader in self.dataloader_dict.items()}\n",
" self.tasks_list = list(self.dataloader_dict.keys())\n",
" self.num_tasks = len(self.tasks_list)\n",
"\n",
" self.task_max_num_batch = max(list(self.num_batches_dict.values()))\n",
"\n",
" self._init()\n",
"\n",
"\n",
" def _init(self):\n",
" self.dataloader_iters = [iter(dataloader) for task, dataloader in self.dataloader_dict.items()]\n",
" self.batch_idx = 0\n",
"\n",
"\n",
" def _set_signature_columns_if_needed(self):\n",
" if self._signature_columns is None:\n",
" # Inspect model forward signature to keep only the arguments it accepts.\n",
" signature = inspect.signature(self.model.forward)\n",
" self._signature_columns = list(signature.parameters.keys())\n",
" # Labels may be named label or label_ids, the default data collator handles that.\n",
" self._signature_columns += list(set([\"label\", \"label_ids\"] + find_labels(self.model.__class__)))\n",
"\n",
" def _remove_unused_columns(self, dataset):\n",
" self._set_signature_columns_if_needed()\n",
" signature_columns = self._signature_columns\n",
"\n",
" ignored_columns = list(set(dataset.column_names) - set(signature_columns))\n",
" return dataset.remove_columns(ignored_columns)\n",
"\n",
" def __len__(self):\n",
" return self.num_tasks * self.task_max_num_batch\n",
"\n",
" def __iter__(self):\n",
" return self\n",
"\n",
" def __next__(self):\n",
" task_id = self.batch_idx % self.num_tasks\n",
" self.batch_idx += 1\n",
" # if self.batch_idx < self.num_tasks * self.task_max_num_batch:\n",
" try:\n",
" return next(self.dataloader_iters[task_id])\n",
" except StopIteration:\n",
" self.dataloader_iters[task_id] = iter(self.dataloader_dict[self.tasks_list[task_id]])\n",
" return next(self.dataloader_iters[task_id])\n",
" # self._init()\n",
"\n",
"\n",
"class TaskAlternateTrainer(transformers.Trainer):\n",
"\n",
" def get_train_dataloader(self):\n",
" train_dataset = self.train_dataset\n",
" data_collator = self.data_collator\n",
"\n",
" return MultitaskDataloader(self.model, train_dataset, self._train_batch_size, data_collator, self.accelerator)"
]
},
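{
"cell_type": "markdown",
"metadata": {},
"source": [
"Added note (not in the upstream OpenP5 repo): the next cell is a minimal sketch of the round-robin sampling that `MultitaskDataloader.__next__` implements, using plain Python lists in place of per-task `DataLoader`s. An exhausted \"loader\" is restarted on `StopIteration`, so batches keep alternating between tasks for `num_tasks * task_max_num_batch` steps."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy stand-ins for the per-task DataLoaders (made-up data, illustration only).\n",
"loaders = {\"sequential\": [1, 2, 3], \"straightforward\": [10, 20]}\n",
"iters = {t: iter(b) for t, b in loaders.items()}\n",
"tasks = list(loaders.keys())\n",
"task_max_num_batch = max(len(b) for b in loaders.values())\n",
"\n",
"order = []\n",
"for batch_idx in range(len(tasks) * task_max_num_batch):\n",
"    task = tasks[batch_idx % len(tasks)]  # alternate tasks, as __next__ does\n",
"    try:\n",
"        order.append((task, next(iters[task])))\n",
"    except StopIteration:  # restart the exhausted loader and keep going\n",
"        iters[task] = iter(loaders[task])\n",
"        order.append((task, next(iters[task])))\n",
"\n",
"print(order)  # e.g. [('sequential', 1), ('straightforward', 10), ('sequential', 2), ...]"
]
},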
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FqSRHxoLlGvE"
},
"outputs": [],
"source": [
"# from https://github.com/agiresearch/OpenP5/blob/main/src/utils/utils.py\n",
"\n",
"def set_seed(seed):\n",
" random.seed(seed)\n",
" np.random.seed(seed)\n",
" torch.manual_seed(seed)\n",
" torch.cuda.manual_seed_all(seed)\n",
" torch.backends.cudnn.benchmark = False\n",
" torch.backends.cudnn.deterministic = True\n",
" torch.backends.cudnn.enabled = False\n",
"\n",
"\n",
"\n",
"def random_initialization(model, tokenizer, backbone):\n",
" ids = []\n",
" for x in range(30000):\n",
" tokenized_ids = tokenizer.encode(str(x))\n",
" if 3 in tokenized_ids:\n",
" tokenized_ids.remove(3)\n",
" if 1 in tokenized_ids:\n",
" tokenized_ids.remove(1)\n",
" ids += tokenized_ids\n",
" ids = list(set(ids))\n",
"\n",
" # reinitialize the embedding in the backbone models\n",
" for index in ids:\n",
" if 't5' in backbone:\n",
" model.shared.weight.data[index] = nn.init.normal_(\n",
" model.shared.weight.data[index], 0, 1.0\n",
" )\n",
" elif 'llama' in backbone.lower():\n",
" model.model.embed_tokens.weight.data[index] = nn.init.normal_(\n",
" model.model.embed_tokens.weight.data[index], 0, 1.0\n",
" )\n",
"\n",
" return model\n",
"\n",
"def setup_logging(\n",
" log_name: str,\n",
" datasets: str,\n",
" log_dir: str,\n",
" model_dir: str,\n",
" checkpoint_dir: str,\n",
" logging_level: int,\n",
" sample_prompt: str,\n",
" his_prefix: str,\n",
" skip_empty_his: int,\n",
" max_his: int,\n",
" master_port: int,\n",
" tasks: str,\n",
" backbone: str,\n",
" item_indexing: str,\n",
" lr: float,\n",
" epochs: int,\n",
" batch_size: int,\n",
" sample_num: int,\n",
" prompt_file: str\n",
"):\n",
" log_name = log_name(\n",
" datasets,\n",
" sample_prompt,\n",
" his_prefix,\n",
" skip_empty_his,\n",
" max_his,\n",
" master_port,\n",
" tasks,\n",
" backbone,\n",
" item_indexing,\n",
" lr,\n",
" epochs,\n",
" batch_size,\n",
" sample_num,\n",
" prompt_file\n",
" )\n",
" if len(datasets.split(',')) > 1:\n",
" folder_name = 'SP5'\n",
" else:\n",
" folder_name = datasets\n",
" folder = os.path.join(log_dir, folder_name)\n",
" if not os.path.exists(folder):\n",
" os.makedirs(folder)\n",
" log_file = os.path.join(log_dir, folder_name, log_name + '.log')\n",
"\n",
" for handler in logging.root.handlers[:]:\n",
" logging.root.removeHandler(handler)\n",
" logging.basicConfig(filename=log_file, level=logging_level, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n",
" logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))\n",
"\n",
" return\n",
"\n",
"def log_name(\n",
" datasets,\n",
" sample_prompt,\n",
" his_prefix,\n",
" skip_empty_his,\n",
" max_his,\n",
" master_port,\n",
" tasks,\n",
" backbone,\n",
" item_indexing,\n",
" lr,\n",
" epochs,\n",
" batch_size,\n",
" sample_num,\n",
" prompt_file,\n",
"):\n",
" if len(datasets.split(',')) > 1:\n",
" folder_name = 'SP5'\n",
" else:\n",
" folder_name = datasets\n",
" params = [str(sample_prompt), str(his_prefix), str(skip_empty_his), str(max_his), str(master_port), folder_name, tasks, backbone, item_indexing, str(lr), str(epochs), str(batch_size), sample_num, prompt_file[3:-4]]\n",
" return '_'.join(params)\n",
"\n",
"def ReadLineFromFile(path):\n",
" if not os.path.exists(path):\n",
" raise FileNotFoundError\n",
" lines = []\n",
" with open(path,'r') as fd:\n",
" for line in fd:\n",
" lines.append(line.rstrip('\\n'))\n",
" return lines\n",
"\n",
"def WriteDictToFile(path, write_dict):\n",
" with open(path, 'w') as out:\n",
" for user, items in write_dict.items():\n",
" if type(items) == list:\n",
" out.write(user + ' ' + ' '.join(items) + '\\n')\n",
" else:\n",
" out.write(user + ' ' + str(items) + '\\n')\n",
"\n",
"\n",
"def load_prompt_template(path, task_list):\n",
" \"\"\"\n",
" Load prompt template from the file. Keep training tasks only.\n",
" Input:\n",
" - path: The path for prompt template txt file.\n",
" - task_list: A list of required tasks.\n",
" Return:\n",
" - prompt_templates: a dictionary of prompt templates. e.g., {task: {'seen': {'0': {'Input': template_input, 'Output': template_output}}}}\n",
"\n",
" \"\"\"\n",
"\n",
" if not os.path.exists(path):\n",
" raise FileNotFoundError\n",
" prompt_info = ReadLineFromFile(path)\n",
" prompt_templates = dict()\n",
" for prompt in prompt_info:\n",
" t = [sens.strip() for sens in prompt.split(';')]\n",
" if t[0] not in task_list:\n",
" continue\n",
" if t[0] not in prompt_templates:\n",
" prompt_templates[t[0]] = dict()\n",
" if t[1] not in prompt_templates[t[0]]:\n",
" prompt_templates[t[0]][t[1]] = dict()\n",
" num = len(prompt_templates[t[0]][t[1]])\n",
" prompt_templates[t[0]][t[1]][str(num)] = dict()\n",
" prompt_templates[t[0]][t[1]][str(num)]['Input'] = t[2]\n",
" prompt_templates[t[0]][t[1]][str(num)]['Output'] = t[3]\n",
" return prompt_templates\n",
"\n",
"def get_info_from_prompt(prompt_templates):\n",
" \"\"\"\n",
" Extract the require information from the prompt templates.\n",
" Input:\n",
" - prompt_templates: a dictionary of prompt templates.\n",
" Output:\n",
" - info: a list of required information.\n",
" \"\"\"\n",
"\n",
" info = []\n",
" for task in prompt_templates:\n",
" for see in prompt_templates[task]:\n",
" for i in prompt_templates[task][see]:\n",
" info += re.findall(r'\\{.*?\\}', prompt_templates[task][see][i]['Input'])\n",
" info += re.findall(r'\\{.*?\\}', prompt_templates[task][see][i]['Output'])\n",
" info = [i[1:-1] for i in set(info)]\n",
" return info\n",
"\n",
"def check_task_prompt(prompt_templates, task_list):\n",
" \"\"\"\n",
" Check if all tasks have prompt templates. Raise Error if training tasks have no prompt.\n",
" Input:\n",
" - prompt_templates: A dictionary of prompt templates.\n",
" - task_list: A list of training tasks.\n",
" \"\"\"\n",
" for task in task_list:\n",
" assert task in prompt_templates, f\"No prompt for {task} task\""
]
},
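{
"cell_type": "markdown",
"metadata": {},
"source": [
"Added example (not in the original gist): the prompt line below is invented for illustration; the real templates live in `prompt.txt` in the OpenP5 repo. It only shows the file format that `load_prompt_template` expects, i.e. `task;seen|unseen;Input;Output` per line, and how `get_info_from_prompt` lists the `{...}` fields."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical one-line prompt file, only to exercise the parsing helpers above.\n",
"sample_line = \"sequential;seen;User_{user_id} has purchased {history} . What should be recommended next ?;{target}\"\n",
"with open(\"/tmp/sample_prompt.txt\", \"w\") as f:\n",
"    f.write(sample_line + \"\\n\")\n",
"\n",
"templates = load_prompt_template(\"/tmp/sample_prompt.txt\", [\"sequential\"])\n",
"print(templates)                        # {'sequential': {'seen': {'0': {'Input': ..., 'Output': ...}}}}\n",
"print(get_info_from_prompt(templates))  # e.g. ['user_id', 'history', 'target'] (order not guaranteed)\n",
"check_task_prompt(templates, [\"sequential\"])  # raises AssertionError if a task has no template"
]
},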
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y8XKqvL1lL0p"
},
"outputs": [],
"source": [
"# from https://github.com/agiresearch/OpenP5/blob/main/src/utils/indexing.py\n",
"\n",
"def sequential_indexing(data_path, dataset, user_sequence_dict, order):\n",
" \"\"\"\n",
" Use sequential indexing method to index the given user seuqnece dict.\n",
" \"\"\"\n",
" user_index_file = os.path.join(data_path, dataset, 'user_indexing.txt')\n",
" item_index_file = os.path.join(data_path, dataset, f'item_sequential_indexing_{order}.txt')\n",
" reindex_sequence_file = os.path.join(data_path, dataset, f'user_sequence_sequential_indexing_{order}.txt')\n",
"\n",
" if os.path.exists(reindex_sequence_file):\n",
" user_sequence = ReadLineFromFile(reindex_sequence_file)\n",
"\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
"\n",
" return construct_user_sequence_dict(user_sequence), item_map\n",
"\n",
" # For user index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
" if os.path.exists(user_index_file):\n",
" user_info = ReadLineFromFile(user_index_file)\n",
" user_map = get_dict_from_lines(user_info)\n",
" else:\n",
" user_map = generate_user_map(user_sequence_dict)\n",
" WriteDictToFile(user_index_file, user_map)\n",
"\n",
"\n",
" # For item index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
" if os.path.exists(item_index_file):\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
" else:\n",
" item_map = dict()\n",
" if order == 'original':\n",
" user_list = user_sequence_dict.keys()\n",
" elif order == 'short2long':\n",
" user_list = sorted(user_sequence_dict, key=lambda x: len(user_sequence_dict[x]), reverse=False)\n",
" elif order == 'long2short':\n",
" user_list = sorted(user_sequence_dict, key=lambda x: len(user_sequence_dict[x]), reverse=True)\n",
"\n",
" for user in user_list:\n",
" items = user_sequence_dict[user][:-2]\n",
" for item in items:\n",
" if item not in item_map:\n",
" item_map[item] = str(len(item_map) + 1001)\n",
" for user in user_list:\n",
" items = user_sequence_dict[user][-2:]\n",
" for item in items:\n",
" if item not in item_map:\n",
" item_map[item] = str(len(item_map) + 1001)\n",
" WriteDictToFile(item_index_file, item_map)\n",
"\n",
" reindex_user_sequence_dict = reindex(user_sequence_dict, user_map, item_map)\n",
" WriteDictToFile(reindex_sequence_file, reindex_user_sequence_dict)\n",
" return reindex_user_sequence_dict, item_map\n",
"\n",
"\n",
"\n",
"def random_indexing(data_path, dataset, user_sequence_dict):\n",
" \"\"\"\n",
" Use random indexing method to index the given user seuqnece dict.\n",
" \"\"\"\n",
" user_index_file = os.path.join(data_path, dataset, 'user_indexing.txt')\n",
" item_index_file = os.path.join(data_path, dataset, 'item_random_indexing.txt')\n",
" reindex_sequence_file = os.path.join(data_path, dataset, f'user_sequence_random_indexing.txt')\n",
"\n",
" if os.path.exists(reindex_sequence_file):\n",
" user_sequence = ReadLineFromFile(reindex_sequence_file)\n",
"\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
"\n",
" return construct_user_sequence_dict(user_sequence), item_map\n",
"\n",
" # For user index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
" if os.path.exists(user_index_file):\n",
" user_info = ReadLineFromFile(user_index_file)\n",
" user_map = get_dict_from_lines(user_info)\n",
" else:\n",
" user_map = generate_user_map(user_sequence_dict)\n",
" WriteDictToFile(user_index_file, user_map)\n",
"\n",
"\n",
" # For item index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
" if os.path.exists(item_index_file):\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
" else:\n",
" item_map = dict()\n",
" items = set()\n",
" for user in user_sequence_dict:\n",
" items.update(user_sequence_dict[user])\n",
" items = list(items)\n",
" random.shuffle(items)\n",
" for item in items:\n",
" if item not in item_map:\n",
" item_map[item] = str(len(item_map) + 1001)\n",
" WriteDictToFile(item_index_file, item_map)\n",
"\n",
" reindex_user_sequence_dict = reindex(user_sequence_dict, user_map, item_map)\n",
" WriteDictToFile(reindex_sequence_file, reindex_user_sequence_dict)\n",
" return reindex_user_sequence_dict, item_map\n",
"\n",
"def collaborative_indexing(data_path, dataset, user_sequence_dict, token_size, cluster_num, last_token, float32):\n",
" \"\"\"\n",
" Use collaborative indexing method to index the given user seuqnece dict.\n",
" \"\"\"\n",
" user_index_file = os.path.join(data_path, dataset, 'user_indexing.txt')\n",
" item_index_file = os.path.join(data_path, dataset, f'item_collaborative_indexing_{token_size}_{cluster_num}_{last_token}.txt')\n",
" reindex_sequence_file = os.path.join(data_path, dataset, f'user_sequence_collaborative_indexing_{token_size}_{cluster_num}_{last_token}.txt')\n",
"\n",
" if os.path.exists(reindex_sequence_file):\n",
" user_sequence = ReadLineFromFile(reindex_sequence_file)\n",
"\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
"\n",
" return construct_user_sequence_dict(user_sequence), item_map\n",
"\n",
" # For user index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
" if os.path.exists(user_index_file):\n",
" user_info = ReadLineFromFile(user_index_file)\n",
" user_map = get_dict_from_lines(user_info)\n",
" else:\n",
" user_map = generate_user_map(user_sequence_dict)\n",
" WriteDictToFile(user_index_file, user_map)\n",
"\n",
"\n",
" # For item index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
" if os.path.exists(item_index_file):\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
" else:\n",
" item_map = generate_collaborative_id(user_sequence_dict, token_size, cluster_num, last_token, float32)\n",
" WriteDictToFile(item_index_file, item_map)\n",
"\n",
" reindex_user_sequence_dict = reindex(user_sequence_dict, user_map, item_map)\n",
" WriteDictToFile(reindex_sequence_file, reindex_user_sequence_dict)\n",
" return reindex_user_sequence_dict, item_map\n",
"\n",
"def generate_collaborative_id(user_sequence_dict, token_size, cluster_num, last_token, float32):\n",
" \"\"\"\n",
" Generate collaborative index for items.\n",
" \"\"\"\n",
" # get the items in training data and all data.\n",
" all_items = set()\n",
" train_items = set()\n",
" for user in user_sequence_dict:\n",
" all_items.update(set(user_sequence_dict[user]))\n",
" train_items.update(set(user_sequence_dict[user][:-2]))\n",
"\n",
" # reindex all training items for calculating the adjacency matrix\n",
" item2id = dict()\n",
" id2item = dict()\n",
" for item in train_items:\n",
" item2id[item] = len(item2id)\n",
" id2item[len(id2item)] = item\n",
"\n",
"\n",
" # calculate the co-occurrence of items in the training data as an adjacency matrix\n",
" if float32 > 0:\n",
" adj_matrix = np.zeros((len(item2id), len(item2id)), dtype=np.float32)\n",
" else:\n",
" adj_matrix = np.zeros((len(item2id), len(item2id)))\n",
" for user in user_sequence_dict:\n",
" interactions = user_sequence_dict[user][:-2]\n",
" for pairs in combinations(interactions, 2):\n",
" adj_matrix[item2id[pairs[0]]][item2id[pairs[1]]] += 1\n",
" adj_matrix[item2id[pairs[1]]][item2id[pairs[0]]] += 1\n",
"\n",
"\n",
" # get the clustering results for the first layer\n",
" clustering = SpectralClustering(\n",
" n_clusters=cluster_num,\n",
" assign_labels=\"cluster_qr\",\n",
" random_state=0,\n",
" affinity=\"precomputed\",\n",
" ).fit(adj_matrix)\n",
" labels = clustering.labels_.tolist()\n",
"\n",
" # count the clustering results\n",
" grouping = defaultdict(list)\n",
" for i in range(len(labels)):\n",
" grouping[labels[i]].append((id2item[i],i))\n",
"\n",
" item_map = dict()\n",
" index_now = 0\n",
"\n",
" # add current clustering information into the item indexing results.\n",
" item_map, index_now = add_token_to_indexing(item_map, grouping, index_now, token_size)\n",
"\n",
" # add current clustering info into a queue for BFS\n",
" queue = []\n",
" for group in grouping:\n",
" queue.append(grouping[group])\n",
"\n",
" # apply BFS to further use spectral clustering for large groups (> token_size)\n",
" while queue:\n",
" group_items = queue.pop(0)\n",
"\n",
" # if current group is small enough, add the last token to item indexing\n",
" if len(group_items) <= token_size:\n",
" item_list = [items[0] for items in group_items]\n",
" if last_token == 'sequential':\n",
" item_map = add_last_token_to_indexing_sequential(item_map, item_list, token_size)\n",
" elif last_token == 'random':\n",
" item_map = add_last_token_to_indexing_random(item_map, item_list, token_size)\n",
" else:\n",
" # calculate the adjacency matrix for current group\n",
" if float32 > 0:\n",
" sub_adj_matrix = np.zeros((len(group_items), len(group_items)), dtype=np.float32)\n",
" else:\n",
" sub_adj_matrix = np.zeros((len(group_items), len(group_items)))\n",
" for i in range(len(group_items)):\n",
" for j in range(i+1, len(group_items)):\n",
" sub_adj_matrix[i][j] = adj_matrix[group_items[i][1]][group_items[j][1]]\n",
" sub_adj_matrix[j][i] = adj_matrix[group_items[j][1]][group_items[i][1]]\n",
"\n",
" # get the clustering results for current group\n",
" clustering = SpectralClustering(\n",
" n_clusters=cluster_num,\n",
" assign_labels=\"cluster_qr\",\n",
" random_state=0,\n",
" affinity=\"precomputed\",\n",
" ).fit(sub_adj_matrix)\n",
" labels = clustering.labels_.tolist()\n",
"\n",
" # count current clustering results\n",
" grouping = defaultdict(list)\n",
" for i in range(len(labels)):\n",
" grouping[labels[i]].append(group_items[i])\n",
"\n",
" # add current clustering information into the item indexing results.\n",
" item_map, index_now = add_token_to_indexing(item_map, grouping, index_now, token_size)\n",
"\n",
" # push current clustering info into the queue\n",
" for group in grouping:\n",
" queue.append(grouping[group])\n",
"\n",
" # if some items are not in the training data, assign an index for them\n",
" remaining_items = list(all_items - train_items)\n",
" if len(remaining_items) > 0:\n",
" if last_token == 'sequential':\n",
" item_map = add_last_token_to_indexing_sequential(item_map, remaining_items, token_size)\n",
" elif last_token == 'random':\n",
" item_map = add_last_token_to_indexing_random(item_map, remaining_items, token_size)\n",
"\n",
" return item_map\n",
"\n",
"\n",
"\n",
"def add_token_to_indexing(item_map, grouping, index_now, token_size):\n",
" for group in grouping:\n",
" index_now = index_now % token_size\n",
" for (item, idx) in grouping[group]:\n",
" if item not in item_map:\n",
" item_map[item] = ''\n",
" item_map[item] += f'<CI{index_now}>'\n",
" index_now += 1\n",
" return item_map, index_now\n",
"\n",
"def add_last_token_to_indexing_random(item_map, item_list, token_size):\n",
" last_tokens = random.sample([i for i in range(token_size)], len(item_list))\n",
" for i in range(len(item_list)):\n",
" item = item_list[i]\n",
" if item not in item_map:\n",
" item_map[item] = ''\n",
" item_map[item] += f'<CI{last_tokens[i]}>'\n",
" return item_map\n",
"\n",
"def add_last_token_to_indexing_sequential(item_map, item_list, token_size):\n",
" for i in range(len(item_list)):\n",
" item = item_list[i]\n",
" if item not in item_map:\n",
" item_map[item] = ''\n",
" item_map[item] += f'<CI{i}>'\n",
" return item_map\n",
"\n",
"\n",
"def get_dict_from_lines(lines):\n",
" \"\"\"\n",
" Used to get user or item map from lines loaded from txt file.\n",
" \"\"\"\n",
" index_map = dict()\n",
" for line in lines:\n",
" info = line.split(\" \")\n",
" index_map[info[0]] = info[1]\n",
" return index_map\n",
"\n",
"\n",
"\n",
"\n",
"def generate_user_map(user_sequence_dict):\n",
" \"\"\"\n",
" generate user map based on user sequence dict.\n",
" \"\"\"\n",
" user_map = dict()\n",
" for user in user_sequence_dict.keys():\n",
" user_map[user] = str(len(user_map) + 1)\n",
" return user_map\n",
"\n",
"\n",
"def reindex(user_sequence_dict, user_map, item_map):\n",
" \"\"\"\n",
" reindex the given user sequence dict by given user map and item map\n",
" \"\"\"\n",
" reindex_user_sequence_dict = dict()\n",
" for user in user_sequence_dict:\n",
" uid = user_map[user]\n",
" items = user_sequence_dict[user]\n",
" reindex_user_sequence_dict[uid] = [item_map[i] for i in items]\n",
"\n",
" return reindex_user_sequence_dict\n",
"\n",
"\n",
"def construct_user_sequence_dict(user_sequence):\n",
" \"\"\"\n",
" Convert a list of string to a user sequence dict. user as key, item list as value.\n",
" \"\"\"\n",
"\n",
" user_seq_dict = dict()\n",
" for line in user_sequence:\n",
" user_seq = line.split(\" \")\n",
" user_seq_dict[user_seq[0]] = user_seq[1:]\n",
" return user_seq_dict"
]
},
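{
"cell_type": "markdown",
"metadata": {},
"source": [
"Added walkthrough with made-up data (not one of the OpenP5 datasets): `sequential_indexing` gives items ids starting at 1001, indexing the training part of each sequence (all but the last two items) before the held-out items. The cell below reproduces that loop in isolation, without reading or writing any index files."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy user sequences; the real notebook reads these from the dataset files on Drive.\n",
"toy_user_sequence_dict = {\n",
"    \"u_a\": [\"i1\", \"i2\", \"i3\", \"i4\"],\n",
"    \"u_b\": [\"i2\", \"i5\", \"i1\"],\n",
"}\n",
"\n",
"toy_user_map = generate_user_map(toy_user_sequence_dict)  # users -> '1', '2', ...\n",
"\n",
"toy_item_map = dict()\n",
"for items in toy_user_sequence_dict.values():  # training items (all but the last two) get ids first\n",
"    for item in items[:-2]:\n",
"        if item not in toy_item_map:\n",
"            toy_item_map[item] = str(len(toy_item_map) + 1001)\n",
"for items in toy_user_sequence_dict.values():  # then the held-out last two items per user\n",
"    for item in items[-2:]:\n",
"        if item not in toy_item_map:\n",
"            toy_item_map[item] = str(len(toy_item_map) + 1001)\n",
"\n",
"print(toy_user_map)   # {'u_a': '1', 'u_b': '2'}\n",
"print(toy_item_map)   # {'i1': '1001', 'i2': '1002', 'i3': '1003', 'i4': '1004', 'i5': '1005'}\n",
"print(reindex(toy_user_sequence_dict, toy_user_map, toy_item_map))"
]
},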
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WrtxYVChmVuW"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uNRMkMSakNbx"
},
"outputs": [],
"source": [
"# from https://github.com/agiresearch/OpenP5/blob/main/src/train.py\n",
"\n",
"def train(\n",
" seed: int = 2023,\n",
" model_dir: str = \"../model\",\n",
" checkpoint_dir: str = \"../checkpoint\",\n",
" model_name: str = \"model.pt\",\n",
" log_dir: str = \"../log\",\n",
" master_addr: str = \"localhost\",\n",
" master_port: str = \"12345\",\n",
" logging_level: int = logging.INFO,\n",
" data_path: str = \"../data\",\n",
" item_indexing: str = \"sequential\",\n",
" tasks: str = 'sequential,straightforward',\n",
" datasets_: str = \"Beauty\",\n",
" prompt_file: str = '../prompt.txt',\n",
" sequential_orderl: str = \"original\",\n",
" collaborative_token_size: int = 200,\n",
" collaborative_cluster: int = 20,\n",
" collaborative_last_token: str = \"sequential\",\n",
" collaborative_float32: int = 0,\n",
" max_his: int = 10,\n",
" his_prefix: int = 1,\n",
" his_sep:str = \" , \",\n",
" skip_empty_his: str = 1,\n",
" valid_prompt: str = 'seen:0',\n",
" valid_prompt_sample: int = 1,\n",
" valid_sample_num: str = '3,3',\n",
" test_prompt: str = 'seen:0',\n",
" sample_prompt: int = 0,\n",
" sample_num: str = '2,2',\n",
" cutoff: int = 1024,\n",
" batch_size: int = 32,\n",
" eval_batch_size: int = 32,\n",
" group_task_in_batch: int = 1,\n",
" task_alternating_optim: int = 0,\n",
" optim: str = \"adamw_torch\",\n",
" epochs: int = 10,\n",
" lr: float = 1e-3,\n",
" clip: float = 1,\n",
" logging_steps: int= 10,\n",
" warmup_steps: int = 100,\n",
" gradient_accumulation_steps: int = 16,\n",
" weight_decay: float = 0.01,\n",
" adam_eps: float = 1e-6,\n",
" dropout: float = 0.1,\n",
" alpha: float = 2,\n",
" backbone: str = 't5-small',\n",
" random_initialize: int = 1,\n",
" test_epoch: int = 1,\n",
" valid_select: int = 0,\n",
" lora: int = 0,\n",
" lora_r: int = 8,\n",
" lora_alpha: int = 16,\n",
" lora_dropout: float = 0.05,\n",
" lora_target_modules: str = 'q_proj,v_proj,embed_tokens',\n",
" metrics: str = 'hit@5,hit@10,ndcg@5,ndcg@10'\n",
"):\n",
"\n",
" # setup\n",
" setup_logging(\n",
" log_name,\n",
" datasets_,\n",
" log_dir,\n",
" model_dir,\n",
" checkpoint_dir,\n",
" logging_level,\n",
" sample_prompt,\n",
" his_prefix,\n",
" skip_empty_his,\n",
" max_his,\n",
" master_port,\n",
" tasks,\n",
" backbone,\n",
" item_indexing,\n",
" lr,\n",
" epochs,\n",
" batch_size,\n",
" sample_num,\n",
" prompt_file\n",
" )\n",
"\n",
" set_seed(seed)\n",
"\n",
" # determine whether distributed\n",
" device_map = \"auto\"\n",
" world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n",
" ddp = world_size != 1\n",
" if ddp:\n",
" device_map = {\"\": int(os.environ.get(\"LOCAL_RANK\") or 0)}\n",
" gradient_accumulation_steps = gradient_accumulation_steps // world_size\n",
"\n",
" # use wandb\n",
" wandb_project = \"\"\n",
" wandb_run_name = \"\"\n",
" wandb_watch = \"\" # options: false | gradients | all\n",
" wandb_log_model = \"\"\n",
" use_wandb = len(wandb_project) > 0 or (\n",
" \"WANDB_PROJECT\" in os.environ and len(os.environ[\"WANDB_PROJECT\"]) > 0\n",
" )\n",
" # Only overwrite environ if wandb param passed\n",
" if len(wandb_project) > 0:\n",
" os.environ[\"WANDB_PROJECT\"] = wandb_project\n",
" if len(wandb_watch) > 0:\n",
" os.environ[\"WANDB_WATCH\"] = wandb_watch\n",
" if len(wandb_log_model) > 0:\n",
" os.environ[\"WANDB_LOG_MODEL\"] = wandb_log_model\n",
"\n",
" # load model, tokenizer\n",
" if 't5' in backbone.lower():\n",
" config = T5Config.from_pretrained(backbone)\n",
" model = T5ForConditionalGeneration.from_pretrained(backbone, config=config)\n",
" tokenizer = AutoTokenizer.from_pretrained(backbone)\n",
" elif 'llama' in backbone.lower():\n",
" model = LlamaForCausalLM.from_pretrained(\n",
" 'meta-llama/' + backbone,\n",
" load_in_8bit=True,\n",
" torch_dtype=torch.float16\n",
" )\n",
" tokenizer = AutoTokenizer.from_pretrained('meta-llama/' + backbone)\n",
" else:\n",
" raise NotImplementError\n",
"\n",
"\n",
"\n",
"\n",
" datasets = datasets_.split(',')\n",
" if len(datasets) == 1:\n",
" dataset = datasets[0]\n",
" train_data_file = os.path.join(data_path, dataset, f'{dataset}_{tasks}_{item_indexing}_train.json')\n",
" valid_data_file = os.path.join(data_path, dataset, f'{dataset}_{tasks}_{item_indexing}_validation_{valid_prompt}.json')\n",
" train_data = load_dataset(\"json\", data_files=train_data_file, field='data')\n",
" valid_data = load_dataset(\"json\", data_files=valid_data_file, field='data')\n",
" else:\n",
" train_data_list, valid_data_list = [], []\n",
" for dataset in datasets:\n",
" train_data_file = os.path.join(data_path, dataset, f'{dataset}_{tasks}_{item_indexing}_train.json')\n",
" valid_data_file = os.path.join(data_path, dataset, f'{dataset}_{tasks}_{item_indexing}_validation_{valid_prompt}.json')\n",
" t_data = load_dataset(\"json\", data_files=train_data_file, field='data')\n",
" v_data = load_dataset(\"json\", data_files=valid_data_file, field='data')\n",
" train_data_list.append(t_data)\n",
" valid_data_list.append(v_data)\n",
" train_data = concatenate_datasets(train_data_list)\n",
" valid_data = concatenate_datasets(valid_data_list)\n",
"\n",
" def tokenize(prompt, add_eos_token=True):\n",
" # there's probably a way to do this with the tokenizer settings\n",
" # but again, gotta move fast\n",
" result = tokenizer(\n",
" prompt, truncation=True, max_length=cutoff, padding=False, return_tensors=None,\n",
" )\n",
" if (isinstance(result[\"input_ids\"][-1], int) and result[\"input_ids\"][-1] != tokenizer.eos_token_id\n",
" and len(result[\"input_ids\"]) < cutoff\n",
" and add_eos_token\n",
" ):\n",
" result[\"input_ids\"].append(tokenizer.eos_token_id)\n",
" result[\"attention_mask\"].append(1)\n",
" elif isinstance(result[\"input_ids\"][-1], list) and add_eos_token:\n",
" for i in range(len(result['input_ids'])):\n",
" if result[\"input_ids\"][i][-1] != tokenizer.eos_token_id and len(result[\"input_ids\"][i]) < cutoff:\n",
" result[\"input_ids\"][i].append(tokenizer.eos_token_id)\n",
" result[\"attention_mask\"][i].append(1)\n",
"\n",
" result[\"labels\"] = result[\"input_ids\"].copy()\n",
"\n",
" return result\n",
"\n",
" def generate_prompt(data_point):\n",
" return f'{data_point[\"input\"]} {data_point[\"output\"]}'\n",
"\n",
" def process_func(datapoint):\n",
" if 't5' in backbone.lower():\n",
" encoding = tokenize(datapoint['input'], add_eos_token=True)\n",
" labels = tokenize(datapoint['output'], add_eos_token=True)\n",
" encoding['labels'] = labels['input_ids'].copy()\n",
" elif 'llama' in backbone.lower():\n",
" user_prompt = generate_prompt({**datapoint, \"output\": \"\"})\n",
" encoding_input = tokenize(user_prompt, add_eos_token=False)\n",
" input_len = len(encoding_input[\"input_ids\"])\n",
" full_prompt = generate_prompt(datapoint)\n",
" encoding = tokenize(full_prompt)\n",
"\n",
" encoding[\"labels\"] = (\n",
" [-100] * input_len\n",
" + encoding[\"labels\"][input_len:]\n",
" )\n",
"\n",
" # return encoding\n",
" return {**datapoint, **encoding}\n",
"\n",
" # add token and resize embedding for collaborative indexing\n",
" if item_indexing == 'collaborative':\n",
" new_tokens = []\n",
" for dataset in datasets:\n",
" item_index_file = os.path.join(data_path, dataset, f'item_collaborative_indexing_{collaborative_token_size}_{collaborative_cluster}_{collaborative_last_token}.txt')\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
" for idx in list(item_map.values()):\n",
" new_token += re.findall(r'\\<.*?\\>', idx)\n",
" tokenizer.add_tokens(ds.new_token)\n",
" model.resize_token_embeddings(len(tokenizer))\n",
"\n",
" # no task alternating optimization if only one task in the data\n",
" if len(set(train_data['train']['task'])) == 1:\n",
" task_alternating_optim = 0\n",
"\n",
" if task_alternating_optim == 1:\n",
" TrainSet = dict()\n",
" for task in set(train_data['train']['task']):\n",
" TrainSet[task] = train_data['train'].filter(lambda example: example[\"task\"]==task)\n",
" for task in TrainSet:\n",
" TrainSet[task] = TrainSet[task].shuffle().map(process_func, batched=True)\n",
"\n",
" else:\n",
" TrainSet = train_data['train'].shuffle().map(process_func, batched=True)\n",
"\n",
" ValidSet = valid_data['train'].shuffle().map(process_func, batched=True)\n",
"\n",
"\n",
"\n",
" # randomly initialize number related tokens\n",
" if random_initialize == 1:\n",
" # logging.info(\"Random initialize number related tokens\")\n",
" random_initialization(model, tokenizer, backbone)\n",
"\n",
" # apply lora\n",
" if lora > 0:\n",
" model = prepare_model_for_int8_training(model)\n",
"\n",
" config = LoraConfig(\n",
" r=lora_r,\n",
" lora_alpha=lora_alpha,\n",
" target_modules=lora_target_modules.split(','),\n",
" lora_dropout=lora_dropout,\n",
" bias=\"none\",\n",
" task_type=\"CAUSAL_LM\",\n",
" )\n",
" model = get_peft_model(model, config)\n",
" model.print_trainable_parameters()\n",
"\n",
" tokenizer.pad_token_id = (\n",
" 0 # unk. we want this to be different from the eos token\n",
" )\n",
" tokenizer.padding_side = \"left\"\n",
"\n",
" # decide output dir\n",
" if len(datasets_.split(',')) > 1:\n",
" folder_name = 'SP5'\n",
" else:\n",
" folder_name = datasets_\n",
" output_dir = os.path.join(model_dir, folder_name, item_indexing, backbone)\n",
"\n",
" if task_alternating_optim == 1:\n",
" trainer = TaskAlternateTrainer(model=model,\n",
" train_dataset=TrainSet,\n",
" eval_dataset=ValidSet if valid_select > 0 else None,\n",
" args= transformers.TrainingArguments(\n",
" per_device_train_batch_size=batch_size,\n",
" gradient_accumulation_steps=gradient_accumulation_steps,\n",
" warmup_steps=warmup_steps,\n",
" num_train_epochs=epochs,\n",
" learning_rate=lr,\n",
" fp16=True,\n",
" logging_dir=log_dir,\n",
" logging_steps=logging_steps,\n",
" optim=optim,\n",
" evaluation_strategy=\"steps\" if valid_select > 0 else \"no\",\n",
" save_strategy=\"steps\",\n",
" eval_steps=200 if valid_select > 0 else None,\n",
" save_steps=200,\n",
" output_dir=output_dir,\n",
" save_total_limit=3,\n",
" load_best_model_at_end=True if valid_select > 0 else False,\n",
" ddp_find_unused_parameters=False if ddp else None,\n",
" group_by_length=False,\n",
" report_to=\"wandb\" if use_wandb else None,\n",
" run_name=wandb_run_name if use_wandb else None,\n",
" ),\n",
" data_collator=transformers.DataCollatorForSeq2Seq(\n",
" tokenizer, pad_to_multiple_of=8, return_tensors=\"pt\", padding=True\n",
" ),\n",
" )\n",
" else:\n",
" trainer = transformers.Trainer(\n",
" model=model,\n",
" train_dataset=TrainSet,\n",
" eval_dataset=ValidSet if valid_select > 0 else None,\n",
" args= transformers.TrainingArguments(\n",
" per_device_train_batch_size=batch_size,\n",
" gradient_accumulation_steps=gradient_accumulation_steps,\n",
" warmup_steps=warmup_steps,\n",
" num_train_epochs=epochs,\n",
" learning_rate=lr,\n",
" fp16=True,\n",
" logging_steps=logging_steps,\n",
" optim=optim,\n",
" evaluation_strategy=\"steps\" if valid_select > 0 else \"no\",\n",
" save_strategy=\"steps\",\n",
" eval_steps=200 if valid_select > 0 else None,\n",
" save_steps=200,\n",
" output_dir=output_dir,\n",
" save_total_limit=3,\n",
" load_best_model_at_end=True if valid_select > 0 else False,\n",
" ddp_find_unused_parameters=False if ddp else None,\n",
" group_by_length=False,\n",
" report_to=\"wandb\" if use_wandb else None,\n",
" run_name=wandb_run_name if use_wandb else None,\n",
" ),\n",
" data_collator=transformers.DataCollatorForSeq2Seq(\n",
" tokenizer, pad_to_multiple_of=8, return_tensors=\"pt\", padding=True\n",
" ),\n",
" )\n",
"\n",
" trainer.train()\n",
"\n",
" model.save_pretrained(output_dir)\n",
" tokenizer.save_pretrained(output_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0o3bmaaOsicZ"
},
"outputs": [],
"source": [
"train(\n",
" model_dir=\"/content/drive/MyDrive/OpenP5/models/\",\n",
" log_dir=\"/content/drive/MyDrive/OpenP5/logs/\",\n",
" master_port=1234,\n",
" item_indexing=\"sequential\",\n",
" tasks=\"sequential,straightforward\",\n",
" datasets_=\"Beauty\",\n",
" data_path = '/content/drive/MyDrive/OpenP5/data/',\n",
" epochs=10,\n",
" batch_size=512,\n",
" backbone=\"t5-small\",\n",
" cutoff=1024\n",
")"
]
},
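{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check, not in the original gist: `train()` saves the fine-tuned model to `model_dir/<dataset>/<item_indexing>/<backbone>`, which for the call above resolves to the path used below. Run this only after training has finished."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reload the weights and tokenizer written by the train() call above.\n",
"output_dir = \"/content/drive/MyDrive/OpenP5/models/Beauty/sequential/t5-small\"\n",
"reloaded_model = T5ForConditionalGeneration.from_pretrained(output_dir)\n",
"reloaded_tokenizer = AutoTokenizer.from_pretrained(output_dir)\n",
"print(reloaded_model.num_parameters())"
]
},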
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rVPlAgoItT0K"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CvD6VXbcv-me"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"machine_shape": "hm",
"provenance": [],
"gpuType": "A100",
"mount_file_id": "1xZL99KZEHlaop9JBPzG5ms2lZJFmfn_b",
"authorship_tag": "ABX9TyP+SKCMl6NJs+nraank8qWC",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}