OpenP5_train_notebook.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/nogawanogawa/b751146ff57ced6a70deba45f578bcc6/openp5_train_notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dt6kfEGBkU7u",
"outputId": "a98382d4-123f-44ff-fa14-1923b7351154"
},
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.35.2)\n", | |
"Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.15.0)\n", | |
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n", | |
"Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.3)\n", | |
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n", | |
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n", | |
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", | |
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n", | |
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.0)\n", | |
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.0)\n", | |
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n", | |
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n", | |
"Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.5)\n", | |
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.7)\n", | |
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n", | |
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", | |
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.15)\n", | |
"Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n", | |
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.6)\n", | |
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n", | |
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.3.2)\n", | |
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n", | |
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", | |
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n", | |
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n", | |
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", | |
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n", | |
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n", | |
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n", | |
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", | |
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n", | |
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n", | |
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.66.1)\n", | |
"Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.6.2)\n", | |
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.23.5)\n", | |
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (23.2)\n", | |
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", | |
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.1)\n", | |
"Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.1.0+cu118)\n", | |
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.35.2)\n", | |
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from peft) (4.66.1)\n", | |
"Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.24.1)\n", | |
"Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) (0.4.0)\n", | |
"Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.21.0->peft) (0.19.3)\n", | |
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.13.1)\n", | |
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (4.5.0)\n", | |
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.12)\n", | |
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.2.1)\n", | |
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n", | |
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2023.6.0)\n", | |
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.1.0)\n", | |
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2023.6.3)\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2.31.0)\n", | |
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.15.0)\n", | |
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n", | |
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (3.3.2)\n", | |
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (3.4)\n", | |
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2.0.7)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2023.7.22)\n", | |
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft) (1.3.0)\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install transformers datasets\n", | |
"!pip install tqdm\n", | |
"!pip install peft" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "u4M5mcq3kHK6" | |
}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import transformers\n", | |
"import argparse\n", | |
"import torch\n", | |
"import logging\n", | |
"from torch.utils.data import ConcatDataset, DataLoader\n", | |
"from datasets import load_dataset, concatenate_datasets\n", | |
"import re\n", | |
"import sys\n", | |
"import random\n", | |
"\n", | |
"import numpy as np\n", | |
"import pickle\n", | |
"import inspect\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"\n", | |
"from transformers import (\n", | |
" T5Config,\n", | |
" T5ForConditionalGeneration,\n", | |
" AutoTokenizer,\n", | |
" LlamaForCausalLM\n", | |
")\n", | |
"\n", | |
"from peft import ( # noqa: E402\n", | |
" LoraConfig,\n", | |
" get_peft_model,\n", | |
" get_peft_model_state_dict,\n", | |
" prepare_model_for_int8_training,\n", | |
" set_peft_model_state_dict,\n", | |
")\n" | |
] | |
}, | |
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "76btAfm1kT8v"
},
"outputs": [],
"source": [
"# from https://github.com/agiresearch/OpenP5/blob/main/src/TaskAlternateTrainer.py\n",
"\n",
"from torch.utils.data.distributed import DistributedSampler\n",
"from torch.utils.data.sampler import RandomSampler\n",
"from torch.utils.data import DataLoader\n",
"import inspect\n",
"import transformers\n",
"from transformers.utils import find_labels\n",
"\n",
"class MultitaskDataloader:\n",
"    \"\"\"\n",
"    Data loader that combines and samples from multiple single-task\n",
"    data loaders.\n",
"    \"\"\"\n",
"    def __init__(self, model, dataset_dict, batch_size, collate_fn, accelerator=None):\n",
"        self.model = model\n",
"        self.batch_size = batch_size\n",
"        self.collate_fn = collate_fn\n",
"        self._signature_columns = None\n",
"\n",
"        for task in dataset_dict:\n",
"            dataset_dict[task] = self._remove_unused_columns(dataset_dict[task])\n",
"\n",
"        self.dataloader_dict = dict()\n",
"        for task in dataset_dict:\n",
"            self.dataloader_dict[task] = DataLoader(dataset_dict[task], batch_size = self.batch_size, sampler = RandomSampler(dataset_dict[task]), collate_fn = self.collate_fn)\n",
"        if accelerator is not None:\n",
"            for task in self.dataloader_dict:\n",
"                self.dataloader_dict[task] = accelerator.prepare(self.dataloader_dict[task])\n",
"\n",
"        self.num_batches_dict = {task: len(dataloader) for task, dataloader in self.dataloader_dict.items()}\n",
"        self.tasks_list = list(self.dataloader_dict.keys())\n",
"        self.num_tasks = len(self.tasks_list)\n",
"\n",
"        self.task_max_num_batch = max(list(self.num_batches_dict.values()))\n",
"\n",
"        self._init()\n",
"\n",
"\n",
"    def _init(self):\n",
"        self.dataloader_iters = [iter(dataloader) for task, dataloader in self.dataloader_dict.items()]\n",
"        self.batch_idx = 0\n",
"\n",
"\n",
"    def _set_signature_columns_if_needed(self):\n",
"        if self._signature_columns is None:\n",
"            # Inspect model forward signature to keep only the arguments it accepts.\n",
"            signature = inspect.signature(self.model.forward)\n",
"            self._signature_columns = list(signature.parameters.keys())\n",
"            # Labels may be named label or label_ids, the default data collator handles that.\n",
"            self._signature_columns += list(set([\"label\", \"label_ids\"] + find_labels(self.model.__class__)))\n",
"\n",
"    def _remove_unused_columns(self, dataset):\n",
"        self._set_signature_columns_if_needed()\n",
"        signature_columns = self._signature_columns\n",
"\n",
"        ignored_columns = list(set(dataset.column_names) - set(signature_columns))\n",
"        return dataset.remove_columns(ignored_columns)\n",
"\n",
"    def __len__(self):\n",
"        return self.num_tasks * self.task_max_num_batch\n",
"\n",
"    def __iter__(self):\n",
"        return self\n",
"\n",
"    def __next__(self):\n",
"        task_id = self.batch_idx % self.num_tasks\n",
"        self.batch_idx += 1\n",
"        try:\n",
"            return next(self.dataloader_iters[task_id])\n",
"        except StopIteration:\n",
"            # restart the exhausted per-task iterator and keep alternating\n",
"            self.dataloader_iters[task_id] = iter(self.dataloader_dict[self.tasks_list[task_id]])\n",
"            return next(self.dataloader_iters[task_id])\n",
"\n",
"\n",
"class TaskAlternateTrainer(transformers.Trainer):\n",
"\n",
"    def get_train_dataloader(self):\n",
"        train_dataset = self.train_dataset\n",
"        data_collator = self.data_collator\n",
"\n",
"        return MultitaskDataloader(self.model, train_dataset, self._train_batch_size, data_collator, self.accelerator)"
]
},
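{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of `MultitaskDataloader`, added for illustration (not in the original notebook): it round-robins over per-task dataloaders and restarts any iterator that runs out, so one pass is `num_tasks * max_batches` steps. `ToyModel` and the toy datasets below are hypothetical stand-ins; columns not accepted by the model's `forward` (here the `task` column) are dropped by `_remove_unused_columns`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch on fabricated data; ToyModel and the toy datasets are made up for the demo.\n",
"from datasets import Dataset\n",
"\n",
"class ToyModel(nn.Module):\n",
"    def forward(self, input_ids=None, labels=None):\n",
"        return input_ids\n",
"\n",
"toy_sets = {\n",
"    'sequential': Dataset.from_dict({'input_ids': [[1], [2], [3]], 'labels': [[1], [2], [3]], 'task': ['sequential'] * 3}),\n",
"    'straightforward': Dataset.from_dict({'input_ids': [[4], [5]], 'labels': [[4], [5]], 'task': ['straightforward'] * 2}),\n",
"}\n",
"\n",
"loader = MultitaskDataloader(ToyModel(), toy_sets, batch_size=2, collate_fn=lambda batch: batch)\n",
"# __next__ restarts exhausted iterators instead of raising StopIteration, so bound the loop explicitly\n",
"for _, batch in zip(range(len(loader)), loader):\n",
"    print(batch)  # batches alternate between the two tasks; the 'task' column has been dropped\n"
]
},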
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FqSRHxoLlGvE"
},
"outputs": [],
"source": [
"# from https://github.com/agiresearch/OpenP5/blob/main/src/utils/utils.py\n",
"\n",
"def set_seed(seed):\n",
"    random.seed(seed)\n",
"    np.random.seed(seed)\n",
"    torch.manual_seed(seed)\n",
"    torch.cuda.manual_seed_all(seed)\n",
"    torch.backends.cudnn.benchmark = False\n",
"    torch.backends.cudnn.deterministic = True\n",
"    torch.backends.cudnn.enabled = False\n",
"\n",
"\n",
"def random_initialization(model, tokenizer, backbone):\n",
"    ids = []\n",
"    for x in range(30000):\n",
"        tokenized_ids = tokenizer.encode(str(x))\n",
"        if 3 in tokenized_ids:\n",
"            tokenized_ids.remove(3)\n",
"        if 1 in tokenized_ids:\n",
"            tokenized_ids.remove(1)\n",
"        ids += tokenized_ids\n",
"    ids = list(set(ids))\n",
"\n",
"    # reinitialize the embedding in the backbone models\n",
"    for index in ids:\n",
"        if 't5' in backbone:\n",
"            model.shared.weight.data[index] = nn.init.normal_(\n",
"                model.shared.weight.data[index], 0, 1.0\n",
"            )\n",
"        elif 'llama' in backbone.lower():\n",
"            model.model.embed_tokens.weight.data[index] = nn.init.normal_(\n",
"                model.model.embed_tokens.weight.data[index], 0, 1.0\n",
"            )\n",
"\n",
"    return model\n",
"\n",
"def setup_logging(\n",
"    log_name,  # the log_name function below, which builds the log file name\n",
"    datasets: str,\n",
"    log_dir: str,\n",
"    model_dir: str,\n",
"    checkpoint_dir: str,\n",
"    logging_level: int,\n",
"    sample_prompt: str,\n",
"    his_prefix: str,\n",
"    skip_empty_his: int,\n",
"    max_his: int,\n",
"    master_port: int,\n",
"    tasks: str,\n",
"    backbone: str,\n",
"    item_indexing: str,\n",
"    lr: float,\n",
"    epochs: int,\n",
"    batch_size: int,\n",
"    sample_num: int,\n",
"    prompt_file: str\n",
"):\n",
"    file_name = log_name(\n",
"        datasets,\n",
"        sample_prompt,\n",
"        his_prefix,\n",
"        skip_empty_his,\n",
"        max_his,\n",
"        master_port,\n",
"        tasks,\n",
"        backbone,\n",
"        item_indexing,\n",
"        lr,\n",
"        epochs,\n",
"        batch_size,\n",
"        sample_num,\n",
"        prompt_file\n",
"    )\n",
"    if len(datasets.split(',')) > 1:\n",
"        folder_name = 'SP5'\n",
"    else:\n",
"        folder_name = datasets\n",
"    folder = os.path.join(log_dir, folder_name)\n",
"    if not os.path.exists(folder):\n",
"        os.makedirs(folder)\n",
"    log_file = os.path.join(log_dir, folder_name, file_name + '.log')\n",
"\n",
"    for handler in logging.root.handlers[:]:\n",
"        logging.root.removeHandler(handler)\n",
"    logging.basicConfig(filename=log_file, level=logging_level, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n",
"    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))\n",
"\n",
"    return\n",
"\n",
"def log_name(\n",
"    datasets,\n",
"    sample_prompt,\n",
"    his_prefix,\n",
"    skip_empty_his,\n",
"    max_his,\n",
"    master_port,\n",
"    tasks,\n",
"    backbone,\n",
"    item_indexing,\n",
"    lr,\n",
"    epochs,\n",
"    batch_size,\n",
"    sample_num,\n",
"    prompt_file,\n",
"):\n",
"    if len(datasets.split(',')) > 1:\n",
"        folder_name = 'SP5'\n",
"    else:\n",
"        folder_name = datasets\n",
"    params = [str(sample_prompt), str(his_prefix), str(skip_empty_his), str(max_his), str(master_port), folder_name, tasks, backbone, item_indexing, str(lr), str(epochs), str(batch_size), sample_num, prompt_file[3:-4]]\n",
"    return '_'.join(params)\n",
"\n",
"def ReadLineFromFile(path):\n",
"    if not os.path.exists(path):\n",
"        raise FileNotFoundError\n",
"    lines = []\n",
"    with open(path, 'r') as fd:\n",
"        for line in fd:\n",
"            lines.append(line.rstrip('\\n'))\n",
"    return lines\n",
"\n",
"def WriteDictToFile(path, write_dict):\n",
"    with open(path, 'w') as out:\n",
"        for user, items in write_dict.items():\n",
"            if isinstance(items, list):\n",
"                out.write(user + ' ' + ' '.join(items) + '\\n')\n",
"            else:\n",
"                out.write(user + ' ' + str(items) + '\\n')\n",
"\n",
"\n",
"def load_prompt_template(path, task_list):\n",
"    \"\"\"\n",
"    Load prompt templates from the file. Keep training tasks only.\n",
"    Input:\n",
"    - path: The path of the prompt template txt file.\n",
"    - task_list: A list of required tasks.\n",
"    Return:\n",
"    - prompt_templates: a dictionary of prompt templates. e.g., {task: {'seen': {'0': {'Input': template_input, 'Output': template_output}}}}\n",
"    \"\"\"\n",
"\n",
"    if not os.path.exists(path):\n",
"        raise FileNotFoundError\n",
"    prompt_info = ReadLineFromFile(path)\n",
"    prompt_templates = dict()\n",
"    for prompt in prompt_info:\n",
"        t = [sens.strip() for sens in prompt.split(';')]\n",
"        if t[0] not in task_list:\n",
"            continue\n",
"        if t[0] not in prompt_templates:\n",
"            prompt_templates[t[0]] = dict()\n",
"        if t[1] not in prompt_templates[t[0]]:\n",
"            prompt_templates[t[0]][t[1]] = dict()\n",
"        num = len(prompt_templates[t[0]][t[1]])\n",
"        prompt_templates[t[0]][t[1]][str(num)] = dict()\n",
"        prompt_templates[t[0]][t[1]][str(num)]['Input'] = t[2]\n",
"        prompt_templates[t[0]][t[1]][str(num)]['Output'] = t[3]\n",
"    return prompt_templates\n",
"\n",
"def get_info_from_prompt(prompt_templates):\n",
"    \"\"\"\n",
"    Extract the required information from the prompt templates.\n",
"    Input:\n",
"    - prompt_templates: a dictionary of prompt templates.\n",
"    Output:\n",
"    - info: a list of required information.\n",
"    \"\"\"\n",
"\n",
"    info = []\n",
"    for task in prompt_templates:\n",
"        for see in prompt_templates[task]:\n",
"            for i in prompt_templates[task][see]:\n",
"                info += re.findall(r'\\{.*?\\}', prompt_templates[task][see][i]['Input'])\n",
"                info += re.findall(r'\\{.*?\\}', prompt_templates[task][see][i]['Output'])\n",
"    info = [i[1:-1] for i in set(info)]\n",
"    return info\n",
"\n",
"def check_task_prompt(prompt_templates, task_list):\n",
"    \"\"\"\n",
"    Check if all tasks have prompt templates. Raise an error if a training task has no prompt.\n",
"    Input:\n",
"    - prompt_templates: A dictionary of prompt templates.\n",
"    - task_list: A list of training tasks.\n",
"    \"\"\"\n",
"    for task in task_list:\n",
"        assert task in prompt_templates, f\"No prompt for {task} task\""
]
},
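{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small illustration, added for this writeup (not in the original notebook), of the prompt file format that `load_prompt_template` expects: one template per line with task, seen/unseen flag, input template, and output template separated by `;`. The template text and `/tmp` path below are made up; OpenP5 ships its own `prompt.txt`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical single-template prompt file, just to show the parsing.\n",
"demo_prompt_path = '/tmp/prompt_demo.txt'\n",
"with open(demo_prompt_path, 'w') as f:\n",
"    f.write('sequential;seen;User {user_id} has purchased {history} , predict the next item;{target}\\n')\n",
"\n",
"templates = load_prompt_template(demo_prompt_path, ['sequential'])\n",
"print(templates['sequential']['seen']['0'])\n",
"# get_info_from_prompt lists the placeholders the data pipeline must fill in\n",
"print(sorted(get_info_from_prompt(templates)))  # ['history', 'target', 'user_id']\n",
"check_task_prompt(templates, ['sequential'])\n"
]
},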
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y8XKqvL1lL0p"
},
"outputs": [],
"source": [
"# from https://github.com/agiresearch/OpenP5/blob/main/src/utils/indexing.py\n",
"\n",
"from itertools import combinations\n",
"from collections import defaultdict\n",
"from sklearn.cluster import SpectralClustering\n",
"\n",
"def sequential_indexing(data_path, dataset, user_sequence_dict, order):\n",
"    \"\"\"\n",
"    Use the sequential indexing method to index the given user sequence dict.\n",
"    \"\"\"\n",
"    user_index_file = os.path.join(data_path, dataset, 'user_indexing.txt')\n",
"    item_index_file = os.path.join(data_path, dataset, f'item_sequential_indexing_{order}.txt')\n",
"    reindex_sequence_file = os.path.join(data_path, dataset, f'user_sequence_sequential_indexing_{order}.txt')\n",
"\n",
"    if os.path.exists(reindex_sequence_file):\n",
"        user_sequence = ReadLineFromFile(reindex_sequence_file)\n",
"\n",
"        item_info = ReadLineFromFile(item_index_file)\n",
"        item_map = get_dict_from_lines(item_info)\n",
"\n",
"        return construct_user_sequence_dict(user_sequence), item_map\n",
"\n",
"    # For user index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
"    if os.path.exists(user_index_file):\n",
"        user_info = ReadLineFromFile(user_index_file)\n",
"        user_map = get_dict_from_lines(user_info)\n",
"    else:\n",
"        user_map = generate_user_map(user_sequence_dict)\n",
"        WriteDictToFile(user_index_file, user_map)\n",
"\n",
"    # For item index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
"    if os.path.exists(item_index_file):\n",
"        item_info = ReadLineFromFile(item_index_file)\n",
"        item_map = get_dict_from_lines(item_info)\n",
"    else:\n",
"        item_map = dict()\n",
"        if order == 'original':\n",
"            user_list = user_sequence_dict.keys()\n",
"        elif order == 'short2long':\n",
"            user_list = sorted(user_sequence_dict, key=lambda x: len(user_sequence_dict[x]), reverse=False)\n",
"        elif order == 'long2short':\n",
"            user_list = sorted(user_sequence_dict, key=lambda x: len(user_sequence_dict[x]), reverse=True)\n",
"\n",
"        for user in user_list:\n",
"            items = user_sequence_dict[user][:-2]\n",
"            for item in items:\n",
"                if item not in item_map:\n",
"                    item_map[item] = str(len(item_map) + 1001)\n",
"        for user in user_list:\n",
"            items = user_sequence_dict[user][-2:]\n",
"            for item in items:\n",
"                if item not in item_map:\n",
"                    item_map[item] = str(len(item_map) + 1001)\n",
"        WriteDictToFile(item_index_file, item_map)\n",
"\n",
"    reindex_user_sequence_dict = reindex(user_sequence_dict, user_map, item_map)\n",
"    WriteDictToFile(reindex_sequence_file, reindex_user_sequence_dict)\n",
"    return reindex_user_sequence_dict, item_map\n",
"\n",
"\n",
"def random_indexing(data_path, dataset, user_sequence_dict):\n",
"    \"\"\"\n",
"    Use the random indexing method to index the given user sequence dict.\n",
"    \"\"\"\n",
"    user_index_file = os.path.join(data_path, dataset, 'user_indexing.txt')\n",
"    item_index_file = os.path.join(data_path, dataset, 'item_random_indexing.txt')\n",
"    reindex_sequence_file = os.path.join(data_path, dataset, 'user_sequence_random_indexing.txt')\n",
"\n",
"    if os.path.exists(reindex_sequence_file):\n",
"        user_sequence = ReadLineFromFile(reindex_sequence_file)\n",
"\n",
"        item_info = ReadLineFromFile(item_index_file)\n",
"        item_map = get_dict_from_lines(item_info)\n",
"\n",
"        return construct_user_sequence_dict(user_sequence), item_map\n",
"\n",
"    # For user index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
"    if os.path.exists(user_index_file):\n",
"        user_info = ReadLineFromFile(user_index_file)\n",
"        user_map = get_dict_from_lines(user_info)\n",
"    else:\n",
"        user_map = generate_user_map(user_sequence_dict)\n",
"        WriteDictToFile(user_index_file, user_map)\n",
"\n",
"    # For item index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
"    if os.path.exists(item_index_file):\n",
"        item_info = ReadLineFromFile(item_index_file)\n",
"        item_map = get_dict_from_lines(item_info)\n",
"    else:\n",
"        item_map = dict()\n",
"        items = set()\n",
"        for user in user_sequence_dict:\n",
"            items.update(user_sequence_dict[user])\n",
"        items = list(items)\n",
"        random.shuffle(items)\n",
"        for item in items:\n",
"            if item not in item_map:\n",
"                item_map[item] = str(len(item_map) + 1001)\n",
"        WriteDictToFile(item_index_file, item_map)\n",
"\n",
"    reindex_user_sequence_dict = reindex(user_sequence_dict, user_map, item_map)\n",
"    WriteDictToFile(reindex_sequence_file, reindex_user_sequence_dict)\n",
"    return reindex_user_sequence_dict, item_map\n",
"\n",
"def collaborative_indexing(data_path, dataset, user_sequence_dict, token_size, cluster_num, last_token, float32):\n",
"    \"\"\"\n",
"    Use the collaborative indexing method to index the given user sequence dict.\n",
"    \"\"\"\n",
"    user_index_file = os.path.join(data_path, dataset, 'user_indexing.txt')\n",
"    item_index_file = os.path.join(data_path, dataset, f'item_collaborative_indexing_{token_size}_{cluster_num}_{last_token}.txt')\n",
"    reindex_sequence_file = os.path.join(data_path, dataset, f'user_sequence_collaborative_indexing_{token_size}_{cluster_num}_{last_token}.txt')\n",
"\n",
"    if os.path.exists(reindex_sequence_file):\n",
"        user_sequence = ReadLineFromFile(reindex_sequence_file)\n",
"\n",
"        item_info = ReadLineFromFile(item_index_file)\n",
"        item_map = get_dict_from_lines(item_info)\n",
"\n",
"        return construct_user_sequence_dict(user_sequence), item_map\n",
"\n",
"    # For user index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
"    if os.path.exists(user_index_file):\n",
"        user_info = ReadLineFromFile(user_index_file)\n",
"        user_map = get_dict_from_lines(user_info)\n",
"    else:\n",
"        user_map = generate_user_map(user_sequence_dict)\n",
"        WriteDictToFile(user_index_file, user_map)\n",
"\n",
"    # For item index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
"    if os.path.exists(item_index_file):\n",
"        item_info = ReadLineFromFile(item_index_file)\n",
"        item_map = get_dict_from_lines(item_info)\n",
"    else:\n",
"        item_map = generate_collaborative_id(user_sequence_dict, token_size, cluster_num, last_token, float32)\n",
"        WriteDictToFile(item_index_file, item_map)\n",
"\n",
"    reindex_user_sequence_dict = reindex(user_sequence_dict, user_map, item_map)\n",
"    WriteDictToFile(reindex_sequence_file, reindex_user_sequence_dict)\n",
"    return reindex_user_sequence_dict, item_map\n",
"\n",
"def generate_collaborative_id(user_sequence_dict, token_size, cluster_num, last_token, float32):\n",
"    \"\"\"\n",
"    Generate collaborative indices for items.\n",
"    \"\"\"\n",
"    # get the items in training data and all data.\n",
"    all_items = set()\n",
"    train_items = set()\n",
"    for user in user_sequence_dict:\n",
"        all_items.update(set(user_sequence_dict[user]))\n",
"        train_items.update(set(user_sequence_dict[user][:-2]))\n",
"\n",
"    # reindex all training items for calculating the adjacency matrix\n",
"    item2id = dict()\n",
"    id2item = dict()\n",
"    for item in train_items:\n",
"        item2id[item] = len(item2id)\n",
"        id2item[len(id2item)] = item\n",
"\n",
"    # calculate the co-occurrence of items in the training data as an adjacency matrix\n",
"    if float32 > 0:\n",
"        adj_matrix = np.zeros((len(item2id), len(item2id)), dtype=np.float32)\n",
"    else:\n",
"        adj_matrix = np.zeros((len(item2id), len(item2id)))\n",
"    for user in user_sequence_dict:\n",
"        interactions = user_sequence_dict[user][:-2]\n",
"        for pairs in combinations(interactions, 2):\n",
"            adj_matrix[item2id[pairs[0]]][item2id[pairs[1]]] += 1\n",
"            adj_matrix[item2id[pairs[1]]][item2id[pairs[0]]] += 1\n",
"\n",
"    # get the clustering results for the first layer\n",
"    clustering = SpectralClustering(\n",
"        n_clusters=cluster_num,\n",
"        assign_labels=\"cluster_qr\",\n",
"        random_state=0,\n",
"        affinity=\"precomputed\",\n",
"    ).fit(adj_matrix)\n",
"    labels = clustering.labels_.tolist()\n",
"\n",
"    # count the clustering results\n",
"    grouping = defaultdict(list)\n",
"    for i in range(len(labels)):\n",
"        grouping[labels[i]].append((id2item[i], i))\n",
"\n",
"    item_map = dict()\n",
"    index_now = 0\n",
"\n",
"    # add current clustering information into the item indexing results.\n",
"    item_map, index_now = add_token_to_indexing(item_map, grouping, index_now, token_size)\n",
"\n",
"    # add current clustering info into a queue for BFS\n",
"    queue = []\n",
"    for group in grouping:\n",
"        queue.append(grouping[group])\n",
"\n",
"    # apply BFS to further use spectral clustering for large groups (> token_size)\n",
"    while queue:\n",
"        group_items = queue.pop(0)\n",
"\n",
"        # if the current group is small enough, add the last token to item indexing\n",
"        if len(group_items) <= token_size:\n",
"            item_list = [items[0] for items in group_items]\n",
"            if last_token == 'sequential':\n",
"                item_map = add_last_token_to_indexing_sequential(item_map, item_list, token_size)\n",
"            elif last_token == 'random':\n",
"                item_map = add_last_token_to_indexing_random(item_map, item_list, token_size)\n",
"        else:\n",
"            # calculate the adjacency matrix for the current group\n",
"            if float32 > 0:\n",
"                sub_adj_matrix = np.zeros((len(group_items), len(group_items)), dtype=np.float32)\n",
"            else:\n",
"                sub_adj_matrix = np.zeros((len(group_items), len(group_items)))\n",
"            for i in range(len(group_items)):\n",
"                for j in range(i+1, len(group_items)):\n",
"                    sub_adj_matrix[i][j] = adj_matrix[group_items[i][1]][group_items[j][1]]\n",
"                    sub_adj_matrix[j][i] = adj_matrix[group_items[j][1]][group_items[i][1]]\n",
"\n",
"            # get the clustering results for the current group\n",
"            clustering = SpectralClustering(\n",
"                n_clusters=cluster_num,\n",
"                assign_labels=\"cluster_qr\",\n",
"                random_state=0,\n",
"                affinity=\"precomputed\",\n",
"            ).fit(sub_adj_matrix)\n",
"            labels = clustering.labels_.tolist()\n",
"\n",
"            # count current clustering results\n",
"            grouping = defaultdict(list)\n",
"            for i in range(len(labels)):\n",
"                grouping[labels[i]].append(group_items[i])\n",
"\n",
"            # add current clustering information into the item indexing results.\n",
"            item_map, index_now = add_token_to_indexing(item_map, grouping, index_now, token_size)\n",
"\n",
"            # push current clustering info into the queue\n",
"            for group in grouping:\n",
"                queue.append(grouping[group])\n",
"\n",
"    # if some items are not in the training data, assign an index for them\n",
"    remaining_items = list(all_items - train_items)\n",
"    if len(remaining_items) > 0:\n",
"        if last_token == 'sequential':\n",
"            item_map = add_last_token_to_indexing_sequential(item_map, remaining_items, token_size)\n",
"        elif last_token == 'random':\n",
"            item_map = add_last_token_to_indexing_random(item_map, remaining_items, token_size)\n",
"\n",
"    return item_map\n",
"\n",
"\n",
"def add_token_to_indexing(item_map, grouping, index_now, token_size):\n",
"    for group in grouping:\n",
"        index_now = index_now % token_size\n",
"        for (item, idx) in grouping[group]:\n",
"            if item not in item_map:\n",
"                item_map[item] = ''\n",
"            item_map[item] += f'<CI{index_now}>'\n",
"        index_now += 1\n",
"    return item_map, index_now\n",
"\n",
"def add_last_token_to_indexing_random(item_map, item_list, token_size):\n",
"    last_tokens = random.sample([i for i in range(token_size)], len(item_list))\n",
"    for i in range(len(item_list)):\n",
"        item = item_list[i]\n",
"        if item not in item_map:\n",
"            item_map[item] = ''\n",
"        item_map[item] += f'<CI{last_tokens[i]}>'\n",
"    return item_map\n",
"\n",
"def add_last_token_to_indexing_sequential(item_map, item_list, token_size):\n",
"    for i in range(len(item_list)):\n",
"        item = item_list[i]\n",
"        if item not in item_map:\n",
"            item_map[item] = ''\n",
"        item_map[item] += f'<CI{i}>'\n",
"    return item_map\n",
"\n",
"\n",
"def get_dict_from_lines(lines):\n",
"    \"\"\"\n",
"    Used to get the user or item map from lines loaded from a txt file.\n",
"    \"\"\"\n",
"    index_map = dict()\n",
"    for line in lines:\n",
"        info = line.split(\" \")\n",
"        index_map[info[0]] = info[1]\n",
"    return index_map\n",
"\n",
"\n",
"def generate_user_map(user_sequence_dict):\n",
"    \"\"\"\n",
"    Generate the user map based on the user sequence dict.\n",
"    \"\"\"\n",
"    user_map = dict()\n",
"    for user in user_sequence_dict.keys():\n",
"        user_map[user] = str(len(user_map) + 1)\n",
"    return user_map\n",
"\n",
"\n",
"def reindex(user_sequence_dict, user_map, item_map):\n",
"    \"\"\"\n",
"    Reindex the given user sequence dict using the given user map and item map.\n",
"    \"\"\"\n",
"    reindex_user_sequence_dict = dict()\n",
"    for user in user_sequence_dict:\n",
"        uid = user_map[user]\n",
"        items = user_sequence_dict[user]\n",
"        reindex_user_sequence_dict[uid] = [item_map[i] for i in items]\n",
"\n",
"    return reindex_user_sequence_dict\n",
"\n",
"\n",
"def construct_user_sequence_dict(user_sequence):\n",
"    \"\"\"\n",
"    Convert a list of strings to a user sequence dict: user as key, item list as value.\n",
"    \"\"\"\n",
"\n",
"    user_seq_dict = dict()\n",
"    for line in user_sequence:\n",
"        user_seq = line.split(\" \")\n",
"        user_seq_dict[user_seq[0]] = user_seq[1:]\n",
"    return user_seq_dict"
]
},
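{
"cell_type": "markdown",
"metadata": {},
"source": [
"A toy run of `sequential_indexing`, added for illustration (not in the original notebook): items are renumbered from 1001 in order of appearance, and each user's last two interactions (held out for validation/test) are only indexed after all training items. The `/tmp` paths and interactions below are made up."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch on fabricated data; writes its index files under /tmp/openp5_demo/Toy.\n",
"toy_sequences = {\n",
"    'userA': ['item1', 'item2', 'item3', 'item4'],\n",
"    'userB': ['item2', 'item3', 'item5', 'item6'],\n",
"}\n",
"os.makedirs('/tmp/openp5_demo/Toy', exist_ok=True)\n",
"reindexed, item_map = sequential_indexing('/tmp/openp5_demo', 'Toy', toy_sequences, 'original')\n",
"print(item_map)    # training items first: item1 -> 1001, item2 -> 1002, ..., then the held-out items\n",
"print(reindexed)   # user ids start at 1, item ids at 1001\n"
]
},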
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WrtxYVChmVuW"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uNRMkMSakNbx"
},
"outputs": [],
"source": [
"# from https://github.com/agiresearch/OpenP5/blob/main/src/train.py\n",
"\n",
"def train(\n",
"    seed: int = 2023,\n",
"    model_dir: str = \"../model\",\n",
"    checkpoint_dir: str = \"../checkpoint\",\n",
"    model_name: str = \"model.pt\",\n",
"    log_dir: str = \"../log\",\n",
"    master_addr: str = \"localhost\",\n",
"    master_port: str = \"12345\",\n",
"    logging_level: int = logging.INFO,\n",
"    data_path: str = \"../data\",\n",
"    item_indexing: str = \"sequential\",\n",
"    tasks: str = 'sequential,straightforward',\n",
"    datasets_: str = \"Beauty\",\n",
"    prompt_file: str = '../prompt.txt',\n",
"    sequential_order: str = \"original\",\n",
"    collaborative_token_size: int = 200,\n",
"    collaborative_cluster: int = 20,\n",
"    collaborative_last_token: str = \"sequential\",\n",
"    collaborative_float32: int = 0,\n",
"    max_his: int = 10,\n",
"    his_prefix: int = 1,\n",
"    his_sep: str = \" , \",\n",
"    skip_empty_his: int = 1,\n",
"    valid_prompt: str = 'seen:0',\n",
"    valid_prompt_sample: int = 1,\n",
"    valid_sample_num: str = '3,3',\n",
"    test_prompt: str = 'seen:0',\n",
"    sample_prompt: int = 0,\n",
"    sample_num: str = '2,2',\n",
"    cutoff: int = 1024,\n",
"    batch_size: int = 32,\n",
"    eval_batch_size: int = 32,\n",
"    group_task_in_batch: int = 1,\n",
"    task_alternating_optim: int = 0,\n",
"    optim: str = \"adamw_torch\",\n",
"    epochs: int = 10,\n",
"    lr: float = 1e-3,\n",
"    clip: float = 1,\n",
"    logging_steps: int = 10,\n",
"    warmup_steps: int = 100,\n",
"    gradient_accumulation_steps: int = 16,\n",
"    weight_decay: float = 0.01,\n",
"    adam_eps: float = 1e-6,\n",
"    dropout: float = 0.1,\n",
"    alpha: float = 2,\n",
"    backbone: str = 't5-small',\n",
"    random_initialize: int = 1,\n",
"    test_epoch: int = 1,\n",
"    valid_select: int = 0,\n",
"    lora: int = 0,\n",
"    lora_r: int = 8,\n",
"    lora_alpha: int = 16,\n",
"    lora_dropout: float = 0.05,\n",
"    lora_target_modules: str = 'q_proj,v_proj,embed_tokens',\n",
"    metrics: str = 'hit@5,hit@10,ndcg@5,ndcg@10'\n",
"):\n",
"\n",
"    # setup\n",
"    setup_logging(\n",
"        log_name,\n",
"        datasets_,\n",
"        log_dir,\n",
"        model_dir,\n",
"        checkpoint_dir,\n",
"        logging_level,\n",
"        sample_prompt,\n",
"        his_prefix,\n",
"        skip_empty_his,\n",
"        max_his,\n",
"        master_port,\n",
"        tasks,\n",
"        backbone,\n",
"        item_indexing,\n",
"        lr,\n",
"        epochs,\n",
"        batch_size,\n",
"        sample_num,\n",
"        prompt_file\n",
"    )\n",
"\n",
"    set_seed(seed)\n",
"\n",
"    # determine whether distributed\n",
"    device_map = \"auto\"\n",
"    world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n",
"    ddp = world_size != 1\n",
"    if ddp:\n",
"        device_map = {\"\": int(os.environ.get(\"LOCAL_RANK\") or 0)}\n",
"        gradient_accumulation_steps = gradient_accumulation_steps // world_size\n",
"\n",
"    # use wandb\n",
"    wandb_project = \"\"\n",
"    wandb_run_name = \"\"\n",
"    wandb_watch = \"\"  # options: false | gradients | all\n",
"    wandb_log_model = \"\"\n",
"    use_wandb = len(wandb_project) > 0 or (\n",
"        \"WANDB_PROJECT\" in os.environ and len(os.environ[\"WANDB_PROJECT\"]) > 0\n",
"    )\n",
"    # Only overwrite environ if wandb param passed\n",
"    if len(wandb_project) > 0:\n",
"        os.environ[\"WANDB_PROJECT\"] = wandb_project\n",
"    if len(wandb_watch) > 0:\n",
"        os.environ[\"WANDB_WATCH\"] = wandb_watch\n",
"    if len(wandb_log_model) > 0:\n",
"        os.environ[\"WANDB_LOG_MODEL\"] = wandb_log_model\n",
"\n",
"    # load model, tokenizer\n",
"    if 't5' in backbone.lower():\n",
"        config = T5Config.from_pretrained(backbone)\n",
"        model = T5ForConditionalGeneration.from_pretrained(backbone, config=config)\n",
"        tokenizer = AutoTokenizer.from_pretrained(backbone)\n",
"    elif 'llama' in backbone.lower():\n",
"        model = LlamaForCausalLM.from_pretrained(\n",
"            'meta-llama/' + backbone,\n",
"            load_in_8bit=True,\n",
"            torch_dtype=torch.float16\n",
"        )\n",
"        tokenizer = AutoTokenizer.from_pretrained('meta-llama/' + backbone)\n",
"    else:\n",
"        raise NotImplementedError\n",
"\n",
"    datasets = datasets_.split(',')\n",
"    if len(datasets) == 1:\n",
"        dataset = datasets[0]\n",
"        train_data_file = os.path.join(data_path, dataset, f'{dataset}_{tasks}_{item_indexing}_train.json')\n",
"        valid_data_file = os.path.join(data_path, dataset, f'{dataset}_{tasks}_{item_indexing}_validation_{valid_prompt}.json')\n",
"        train_data = load_dataset(\"json\", data_files=train_data_file, field='data')\n",
"        valid_data = load_dataset(\"json\", data_files=valid_data_file, field='data')\n",
"    else:\n",
"        train_data_list, valid_data_list = [], []\n",
"        for dataset in datasets:\n",
"            train_data_file = os.path.join(data_path, dataset, f'{dataset}_{tasks}_{item_indexing}_train.json')\n",
"            valid_data_file = os.path.join(data_path, dataset, f'{dataset}_{tasks}_{item_indexing}_validation_{valid_prompt}.json')\n",
"            t_data = load_dataset(\"json\", data_files=train_data_file, field='data')\n",
"            v_data = load_dataset(\"json\", data_files=valid_data_file, field='data')\n",
"            train_data_list.append(t_data)\n",
"            valid_data_list.append(v_data)\n",
"        # concatenate the 'train' splits so the downstream ['train'] access below still works\n",
"        train_data = {'train': concatenate_datasets([d['train'] for d in train_data_list])}\n",
"        valid_data = {'train': concatenate_datasets([d['train'] for d in valid_data_list])}\n",
"\n",
"    def tokenize(prompt, add_eos_token=True):\n",
"        # there's probably a way to do this with the tokenizer settings\n",
"        # but again, gotta move fast\n",
"        result = tokenizer(\n",
"            prompt, truncation=True, max_length=cutoff, padding=False, return_tensors=None,\n",
"        )\n",
"        if (isinstance(result[\"input_ids\"][-1], int) and result[\"input_ids\"][-1] != tokenizer.eos_token_id\n",
"            and len(result[\"input_ids\"]) < cutoff\n",
"            and add_eos_token\n",
"        ):\n",
"            result[\"input_ids\"].append(tokenizer.eos_token_id)\n",
"            result[\"attention_mask\"].append(1)\n",
"        elif isinstance(result[\"input_ids\"][-1], list) and add_eos_token:\n",
"            for i in range(len(result['input_ids'])):\n",
"                if result[\"input_ids\"][i][-1] != tokenizer.eos_token_id and len(result[\"input_ids\"][i]) < cutoff:\n",
"                    result[\"input_ids\"][i].append(tokenizer.eos_token_id)\n",
"                    result[\"attention_mask\"][i].append(1)\n",
"\n",
"        result[\"labels\"] = result[\"input_ids\"].copy()\n",
"\n",
"        return result\n",
"\n",
"    def generate_prompt(data_point):\n",
"        return f'{data_point[\"input\"]} {data_point[\"output\"]}'\n",
"\n",
"    def process_func(datapoint):\n",
"        if 't5' in backbone.lower():\n",
"            encoding = tokenize(datapoint['input'], add_eos_token=True)\n",
"            labels = tokenize(datapoint['output'], add_eos_token=True)\n",
"            encoding['labels'] = labels['input_ids'].copy()\n",
"        elif 'llama' in backbone.lower():\n",
"            user_prompt = generate_prompt({**datapoint, \"output\": \"\"})\n",
"            encoding_input = tokenize(user_prompt, add_eos_token=False)\n",
"            input_len = len(encoding_input[\"input_ids\"])\n",
"            full_prompt = generate_prompt(datapoint)\n",
"            encoding = tokenize(full_prompt)\n",
"\n",
"            # mask out the prompt tokens so the loss only covers the output\n",
"            encoding[\"labels\"] = (\n",
"                [-100] * input_len\n",
"                + encoding[\"labels\"][input_len:]\n",
"            )\n",
"\n",
"        return {**datapoint, **encoding}\n",
"\n",
"    # add tokens and resize the embedding for collaborative indexing\n",
"    if item_indexing == 'collaborative':\n",
"        new_tokens = []\n",
"        for dataset in datasets:\n",
"            item_index_file = os.path.join(data_path, dataset, f'item_collaborative_indexing_{collaborative_token_size}_{collaborative_cluster}_{collaborative_last_token}.txt')\n",
"            item_info = ReadLineFromFile(item_index_file)\n",
"            item_map = get_dict_from_lines(item_info)\n",
"            for idx in list(item_map.values()):\n",
"                new_tokens += re.findall(r'\\<.*?\\>', idx)\n",
"        tokenizer.add_tokens(new_tokens)\n",
"        model.resize_token_embeddings(len(tokenizer))\n",
"\n",
"    # no task alternating optimization if only one task in the data\n",
"    if len(set(train_data['train']['task'])) == 1:\n",
"        task_alternating_optim = 0\n",
"\n",
"    if task_alternating_optim == 1:\n",
"        TrainSet = dict()\n",
"        for task in set(train_data['train']['task']):\n",
"            TrainSet[task] = train_data['train'].filter(lambda example: example[\"task\"] == task)\n",
"        for task in TrainSet:\n",
"            TrainSet[task] = TrainSet[task].shuffle().map(process_func, batched=True)\n",
"    else:\n",
"        TrainSet = train_data['train'].shuffle().map(process_func, batched=True)\n",
"\n",
"    ValidSet = valid_data['train'].shuffle().map(process_func, batched=True)\n",
"\n",
"    # randomly initialize number related tokens\n",
"    if random_initialize == 1:\n",
"        # logging.info(\"Random initialize number related tokens\")\n",
"        model = random_initialization(model, tokenizer, backbone)\n",
"\n",
"    # apply lora\n",
"    if lora > 0:\n",
"        model = prepare_model_for_int8_training(model)\n",
"\n",
"        config = LoraConfig(\n",
"            r=lora_r,\n",
"            lora_alpha=lora_alpha,\n",
"            target_modules=lora_target_modules.split(','),\n",
"            lora_dropout=lora_dropout,\n",
"            bias=\"none\",\n",
"            task_type=\"CAUSAL_LM\",\n",
"        )\n",
"        model = get_peft_model(model, config)\n",
"        model.print_trainable_parameters()\n",
"\n",
"    tokenizer.pad_token_id = (\n",
"        0  # unk. we want this to be different from the eos token\n",
"    )\n",
"    tokenizer.padding_side = \"left\"\n",
"\n",
"    # decide output dir\n",
"    if len(datasets_.split(',')) > 1:\n",
"        folder_name = 'SP5'\n",
"    else:\n",
"        folder_name = datasets_\n",
"    output_dir = os.path.join(model_dir, folder_name, item_indexing, backbone)\n",
"\n",
"    # both trainers share the same arguments; only the dataloader construction differs\n",
"    training_args = transformers.TrainingArguments(\n",
"        per_device_train_batch_size=batch_size,\n",
"        gradient_accumulation_steps=gradient_accumulation_steps,\n",
"        warmup_steps=warmup_steps,\n",
"        num_train_epochs=epochs,\n",
"        learning_rate=lr,\n",
"        fp16=True,\n",
"        logging_dir=log_dir,\n",
"        logging_steps=logging_steps,\n",
"        optim=optim,\n",
"        evaluation_strategy=\"steps\" if valid_select > 0 else \"no\",\n",
"        save_strategy=\"steps\",\n",
"        eval_steps=200 if valid_select > 0 else None,\n",
"        save_steps=200,\n",
"        output_dir=output_dir,\n",
"        save_total_limit=3,\n",
"        load_best_model_at_end=True if valid_select > 0 else False,\n",
"        ddp_find_unused_parameters=False if ddp else None,\n",
"        group_by_length=False,\n",
"        report_to=\"wandb\" if use_wandb else None,\n",
"        run_name=wandb_run_name if use_wandb else None,\n",
"    )\n",
"    data_collator = transformers.DataCollatorForSeq2Seq(\n",
"        tokenizer, pad_to_multiple_of=8, return_tensors=\"pt\", padding=True\n",
"    )\n",
"\n",
"    trainer_cls = TaskAlternateTrainer if task_alternating_optim == 1 else transformers.Trainer\n",
"    trainer = trainer_cls(\n",
"        model=model,\n",
"        train_dataset=TrainSet,\n",
"        eval_dataset=ValidSet if valid_select > 0 else None,\n",
"        args=training_args,\n",
"        data_collator=data_collator,\n",
"    )\n",
"\n",
"    trainer.train()\n",
"\n",
"    model.save_pretrained(output_dir)\n",
"    tokenizer.save_pretrained(output_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0o3bmaaOsicZ"
},
"outputs": [],
"source": [
"train(\n",
"    model_dir=\"/content/drive/MyDrive/OpenP5/models/\",\n",
"    log_dir=\"/content/drive/MyDrive/OpenP5/logs/\",\n",
"    master_port=1234,\n",
"    item_indexing=\"sequential\",\n",
"    tasks=\"sequential,straightforward\",\n",
"    datasets_=\"Beauty\",\n",
"    data_path='/content/drive/MyDrive/OpenP5/data/',\n",
"    epochs=10,\n",
"    batch_size=512,\n",
"    backbone=\"t5-small\",\n",
"    cutoff=1024\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rVPlAgoItT0K"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CvD6VXbcv-me"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"machine_shape": "hm",
"provenance": [],
"gpuType": "A100",
"mount_file_id": "1xZL99KZEHlaop9JBPzG5ms2lZJFmfn_b",
"authorship_tag": "ABX9TyP+SKCMl6NJs+nraank8qWC",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
} |