@nogawanogawa
Created November 22, 2023 04:24
OpenP5_generate_dataset_notebook.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"mount_file_id": "1zoFByp2H4shOe2TCGQ2F6uR7W51qEteP",
"authorship_tag": "ABX9TyOyLTAqfgsOFLR53WABvpD2",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/nogawanogawa/101b18d44531657dd7e7518b68c2f5d5/openp5_generate_dataset_notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"## generate dataset\n",
"\n",
"generate_dataset.shの実行に時間がかかるのでその部分だけ切り出したもの。\n",
"\n",
"### 事前準備\n",
"\n",
"[Google Drive link](https://drive.google.com/drive/folders/1W5i5ryetj_gkcOpG1aZfL5Y8Yk6RxwYE)からデータをダウンロードし、下記のディレクトリ階層のように配置する。\n",
"\n",
"<details>\n",
"\n",
"```\n",
"/content/drive/MyDrive/OpenP5# tree -L 1\n",
".\n",
"├── data\n",
"│ ├── Beauty\n",
"│ │ ├── item_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── item_random_indexing.txt\n",
"│ │ ├── item_sequential_indexing_original.txt\n",
"│ │ ├── user_indexing.txt\n",
"│ │ ├── user_sequence_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── user_sequence_random_indexing.txt\n",
"│ │ ├── user_sequence_sequential_indexing_original.txt\n",
"│ │ └── user_sequence.txt\n",
"│ ├── CDs\n",
"│ │ ├── item_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── item_random_indexing.txt\n",
"│ │ ├── item_sequential_indexing_original.txt\n",
"│ │ ├── user_indexing.txt\n",
"│ │ ├── user_sequence_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── user_sequence_random_indexing.txt\n",
"│ │ ├── user_sequence_sequential_indexing_original.txt\n",
"│ │ └── user_sequence.txt\n",
"│ ├── Clothing\n",
"│ │ ├── item_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── item_random_indexing.txt\n",
"│ │ ├── item_sequential_indexing_original.txt\n",
"│ │ ├── user_indexing.txt\n",
"│ │ ├── user_sequence_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── user_sequence_random_indexing.txt\n",
"│ │ ├── user_sequence_sequential_indexing_original.txt\n",
"│ │ └── user_sequence.txt\n",
"│ ├── Electronics\n",
"│ │ ├── item_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── item_random_indexing.txt\n",
"│ │ ├── item_sequential_indexing_original.txt\n",
"│ │ ├── user_indexing.txt\n",
"│ │ ├── user_sequence_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── user_sequence_random_indexing.txt\n",
"│ │ ├── user_sequence_sequential_indexing_original.txt\n",
"│ │ └── user_sequence.txt\n",
"│ ├── LastFM\n",
"│ │ ├── item_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── item_random_indexing.txt\n",
"│ │ ├── item_sequential_indexing_original.txt\n",
"│ │ ├── user_indexing.txt\n",
"│ │ ├── user_sequence_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── user_sequence_random_indexing.txt\n",
"│ │ ├── user_sequence_sequential_indexing_original.txt\n",
"│ │ └── user_sequence.txt\n",
"│ ├── ML100K\n",
"│ │ ├── item_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── item_random_indexing.txt\n",
"│ │ ├── item_sequential_indexing_original.txt\n",
"│ │ ├── user_indexing.txt\n",
"│ │ ├── user_sequence_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── user_sequence_random_indexing.txt\n",
"│ │ ├── user_sequence_sequential_indexing_original.txt\n",
"│ │ └── user_sequence.txt\n",
"│ ├── ML1M\n",
"│ │ ├── item_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── item_random_indexing.txt\n",
"│ │ ├── item_sequential_indexing_original.txt\n",
"│ │ ├── user_indexing.txt\n",
"│ │ ├── user_sequence_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── user_sequence_random_indexing.txt\n",
"│ │ ├── user_sequence_sequential_indexing_original.txt\n",
"│ │ └── user_sequence.txt\n",
"│ ├── Movies\n",
"│ │ ├── item_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── item_random_indexing.txt\n",
"│ │ ├── item_sequential_indexing_original.txt\n",
"│ │ ├── user_indexing.txt\n",
"│ │ ├── user_sequence_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── user_sequence_random_indexing.txt\n",
"│ │ ├── user_sequence_sequential_indexing_original.txt\n",
"│ │ └── user_sequence.txt\n",
"│ ├── Taobao\n",
"│ │ ├── item_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── item_random_indexing.txt\n",
"│ │ ├── item_sequential_indexing_original.txt\n",
"│ │ ├── user_indexing.txt\n",
"│ │ ├── user_sequence_collaborative_indexing_500_20_sequential.txt\n",
"│ │ ├── user_sequence_random_indexing.txt\n",
"│ │ ├── user_sequence_sequential_indexing_original.txt\n",
"│ │ └── user_sequence.txt\n",
"│ └── Yelp\n",
"│ ├── item_collaborative_indexing_500_20_sequential.txt\n",
"│ ├── item_random_indexing.txt\n",
"│ ├── item_sequential_indexing_original.txt\n",
"│ ├── user_indexing.txt\n",
"│ ├── user_sequence_collaborative_indexing_500_20_sequential.txt\n",
"│ ├── user_sequence_random_indexing.txt\n",
"│ ├── user_sequence_sequential_indexing_original.txt\n",
"│ └── user_sequence.txt\n",
"└── preprocessed_data\n",
"```\n",
"\n",
"</details>\n",
"\n",
"このdriveをこのnotebookでマウントする\n",
"\n",
"### 結果\n",
"\n",
"`data`ディレクトリの対象データセットディレクトリに前処理済みのデータセットが保存される"
],
"metadata": {
"id": "l8Gz_bXdQobK"
}
},
{
"cell_type": "code",
"source": [
"!git clone https://github.com/agiresearch/OpenP5.git"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2SIvJqfHcgQZ",
"outputId": "2f464e94-2dbc-444c-8792-8c786cf9077b"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into 'OpenP5'...\n",
"remote: Enumerating objects: 478, done.\u001b[K\n",
"remote: Counting objects: 100% (152/152), done.\u001b[K\n",
"remote: Compressing objects: 100% (123/123), done.\u001b[K\n",
"remote: Total 478 (delta 79), reused 16 (delta 13), pack-reused 326\u001b[K\n",
"Receiving objects: 100% (478/478), 11.88 MiB | 11.05 MiB/s, done.\n",
"Resolving deltas: 100% (257/257), done.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import argparse\n",
"import random\n",
"import argparse\n",
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.distributed as dist\n",
"from tqdm import tqdm\n",
"from collections import defaultdict\n",
"import logging\n",
"import re\n",
"import json\n",
"import pdb\n",
"\n",
"import numpy as np\n",
"from itertools import combinations\n",
"from sklearn.cluster import SpectralClustering\n",
"from scipy.sparse import csr_matrix\n",
"\n",
"import pickle\n",
"import inspect\n",
"import sys"
],
"metadata": {
"id": "907VCMRgUgUB"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# from https://github.com/agiresearch/OpenP5/blob/main/src/utils/utils.py\n",
"def ReadLineFromFile(path):\n",
" if not os.path.exists(path):\n",
" raise FileNotFoundError\n",
" lines = []\n",
" with open(path,'r') as fd:\n",
" for line in fd:\n",
" lines.append(line.rstrip('\\n'))\n",
" return lines\n",
"\n",
"def WriteDictToFile(path, write_dict):\n",
" with open(path, 'w') as out:\n",
" for user, items in write_dict.items():\n",
" if type(items) == list:\n",
" out.write(user + ' ' + ' '.join(items) + '\\n')\n",
" else:\n",
" out.write(user + ' ' + str(items) + '\\n')\n",
"\n",
"\n",
"def load_prompt_template(path, task_list):\n",
" \"\"\"\n",
" Load prompt template from the file. Keep training tasks only.\n",
" Input:\n",
" - path: The path for prompt template txt file.\n",
" - task_list: A list of required tasks.\n",
" Return:\n",
" - prompt_templates: a dictionary of prompt templates. e.g., {task: {'seen': {'0': {'Input': template_input, 'Output': template_output}}}}\n",
"\n",
" \"\"\"\n",
"\n",
" if not os.path.exists(path):\n",
" raise FileNotFoundError\n",
" prompt_info = ReadLineFromFile(path)\n",
" prompt_templates = dict()\n",
" for prompt in prompt_info:\n",
" t = [sens.strip() for sens in prompt.split(';')]\n",
" if t[0] not in task_list:\n",
" continue\n",
" if t[0] not in prompt_templates:\n",
" prompt_templates[t[0]] = dict()\n",
" if t[1] not in prompt_templates[t[0]]:\n",
" prompt_templates[t[0]][t[1]] = dict()\n",
" num = len(prompt_templates[t[0]][t[1]])\n",
" prompt_templates[t[0]][t[1]][str(num)] = dict()\n",
" prompt_templates[t[0]][t[1]][str(num)]['Input'] = t[2]\n",
" prompt_templates[t[0]][t[1]][str(num)]['Output'] = t[3]\n",
" return prompt_templates\n",
"\n",
"def get_info_from_prompt(prompt_templates):\n",
" \"\"\"\n",
" Extract the require information from the prompt templates.\n",
" Input:\n",
" - prompt_templates: a dictionary of prompt templates.\n",
" Output:\n",
" - info: a list of required information.\n",
" \"\"\"\n",
"\n",
" info = []\n",
" for task in prompt_templates:\n",
" for see in prompt_templates[task]:\n",
" for i in prompt_templates[task][see]:\n",
" info += re.findall(r'\\{.*?\\}', prompt_templates[task][see][i]['Input'])\n",
" info += re.findall(r'\\{.*?\\}', prompt_templates[task][see][i]['Output'])\n",
" info = [i[1:-1] for i in set(info)]\n",
" return info\n",
"\n",
"def check_task_prompt(prompt_templates, task_list):\n",
" \"\"\"\n",
" Check if all tasks have prompt templates. Raise Error if training tasks have no prompt.\n",
" Input:\n",
" - prompt_templates: A dictionary of prompt templates.\n",
" - task_list: A list of training tasks.\n",
" \"\"\"\n",
" for task in task_list:\n",
" assert task in prompt_templates, f\"No prompt for {task} task\""
],
"metadata": {
"id": "g50S71Y0VGPF"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# from https://github.com/agiresearch/OpenP5/blob/main/src/utils/indexing.py\n",
"\n",
"def sequential_indexing(data_path, dataset, user_sequence_dict, order):\n",
" \"\"\"\n",
" Use sequential indexing method to index the given user seuqnece dict.\n",
" \"\"\"\n",
" user_index_file = os.path.join(data_path, dataset, 'user_indexing.txt')\n",
" item_index_file = os.path.join(data_path, dataset, f'item_sequential_indexing_{order}.txt')\n",
" reindex_sequence_file = os.path.join(data_path, dataset, f'user_sequence_sequential_indexing_{order}.txt')\n",
"\n",
" if os.path.exists(reindex_sequence_file):\n",
" user_sequence = ReadLineFromFile(reindex_sequence_file)\n",
"\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
"\n",
" return construct_user_sequence_dict(user_sequence), item_map\n",
"\n",
" # For user index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
" if os.path.exists(user_index_file):\n",
" user_info = ReadLineFromFile(user_index_file)\n",
" user_map = get_dict_from_lines(user_info)\n",
" else:\n",
" user_map = generate_user_map(user_sequence_dict)\n",
" WriteDictToFile(user_index_file, user_map)\n",
"\n",
"\n",
" # For item index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
" if os.path.exists(item_index_file):\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
" else:\n",
" item_map = dict()\n",
" if order == 'original':\n",
" user_list = user_sequence_dict.keys()\n",
" elif order == 'short2long':\n",
" user_list = sorted(user_sequence_dict, key=lambda x: len(user_sequence_dict[x]), reverse=False)\n",
" elif order == 'long2short':\n",
" user_list = sorted(user_sequence_dict, key=lambda x: len(user_sequence_dict[x]), reverse=True)\n",
"\n",
" for user in user_list:\n",
" items = user_sequence_dict[user][:-2]\n",
" for item in items:\n",
" if item not in item_map:\n",
" item_map[item] = str(len(item_map) + 1001)\n",
" for user in user_list:\n",
" items = user_sequence_dict[user][-2:]\n",
" for item in items:\n",
" if item not in item_map:\n",
" item_map[item] = str(len(item_map) + 1001)\n",
" WriteDictToFile(item_index_file, item_map)\n",
"\n",
" reindex_user_sequence_dict = reindex(user_sequence_dict, user_map, item_map)\n",
" WriteDictToFile(reindex_sequence_file, reindex_user_sequence_dict)\n",
" return reindex_user_sequence_dict, item_map\n",
"\n",
"\n",
"\n",
"def random_indexing(data_path, dataset, user_sequence_dict):\n",
" \"\"\"\n",
" Use random indexing method to index the given user seuqnece dict.\n",
" \"\"\"\n",
" user_index_file = os.path.join(data_path, dataset, 'user_indexing.txt')\n",
" item_index_file = os.path.join(data_path, dataset, 'item_random_indexing.txt')\n",
" reindex_sequence_file = os.path.join(data_path, dataset, f'user_sequence_random_indexing.txt')\n",
"\n",
" if os.path.exists(reindex_sequence_file):\n",
" user_sequence = ReadLineFromFile(reindex_sequence_file)\n",
"\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
"\n",
" return construct_user_sequence_dict(user_sequence), item_map\n",
"\n",
" # For user index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
" if os.path.exists(user_index_file):\n",
" user_info = ReadLineFromFile(user_index_file)\n",
" user_map = get_dict_from_lines(user_info)\n",
" else:\n",
" user_map = generate_user_map(user_sequence_dict)\n",
" WriteDictToFile(user_index_file, user_map)\n",
"\n",
"\n",
" # For item index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
" if os.path.exists(item_index_file):\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
" else:\n",
" item_map = dict()\n",
" items = set()\n",
" for user in user_sequence_dict:\n",
" items.update(user_sequence_dict[user])\n",
" items = list(items)\n",
" random.shuffle(items)\n",
" for item in items:\n",
" if item not in item_map:\n",
" item_map[item] = str(len(item_map) + 1001)\n",
" WriteDictToFile(item_index_file, item_map)\n",
"\n",
" reindex_user_sequence_dict = reindex(user_sequence_dict, user_map, item_map)\n",
" WriteDictToFile(reindex_sequence_file, reindex_user_sequence_dict)\n",
" return reindex_user_sequence_dict, item_map\n",
"\n",
"def collaborative_indexing(data_path, dataset, user_sequence_dict, token_size, cluster_num, last_token, float32):\n",
" \"\"\"\n",
" Use collaborative indexing method to index the given user seuqnece dict.\n",
" \"\"\"\n",
" user_index_file = os.path.join(data_path, dataset, 'user_indexing.txt')\n",
" item_index_file = os.path.join(data_path, dataset, f'item_collaborative_indexing_{token_size}_{cluster_num}_{last_token}.txt')\n",
" reindex_sequence_file = os.path.join(data_path, dataset, f'user_sequence_collaborative_indexing_{token_size}_{cluster_num}_{last_token}.txt')\n",
"\n",
" if os.path.exists(reindex_sequence_file):\n",
" user_sequence = ReadLineFromFile(reindex_sequence_file)\n",
"\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
"\n",
" return construct_user_sequence_dict(user_sequence), item_map\n",
"\n",
" # For user index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
" if os.path.exists(user_index_file):\n",
" user_info = ReadLineFromFile(user_index_file)\n",
" user_map = get_dict_from_lines(user_info)\n",
" else:\n",
" user_map = generate_user_map(user_sequence_dict)\n",
" WriteDictToFile(user_index_file, user_map)\n",
"\n",
"\n",
" # For item index, load from txt file if already exists, otherwise generate from user sequence and save.\n",
" if os.path.exists(item_index_file):\n",
" item_info = ReadLineFromFile(item_index_file)\n",
" item_map = get_dict_from_lines(item_info)\n",
" else:\n",
" item_map = generate_collaborative_id(user_sequence_dict, token_size, cluster_num, last_token, float32)\n",
" WriteDictToFile(item_index_file, item_map)\n",
"\n",
" reindex_user_sequence_dict = reindex(user_sequence_dict, user_map, item_map)\n",
" WriteDictToFile(reindex_sequence_file, reindex_user_sequence_dict)\n",
" return reindex_user_sequence_dict, item_map\n",
"\n",
"def generate_collaborative_id(user_sequence_dict, token_size, cluster_num, last_token, float32):\n",
" \"\"\"\n",
" Generate collaborative index for items.\n",
" \"\"\"\n",
" # get the items in training data and all data.\n",
" all_items = set()\n",
" train_items = set()\n",
" for user in user_sequence_dict:\n",
" all_items.update(set(user_sequence_dict[user]))\n",
" train_items.update(set(user_sequence_dict[user][:-2]))\n",
"\n",
" # reindex all training items for calculating the adjacency matrix\n",
" item2id = dict()\n",
" id2item = dict()\n",
" for item in train_items:\n",
" item2id[item] = len(item2id)\n",
" id2item[len(id2item)] = item\n",
"\n",
"\n",
" # calculate the co-occurrence of items in the training data as an adjacency matrix\n",
" if float32 > 0:\n",
" adj_matrix = np.zeros((len(item2id), len(item2id)), dtype=np.float32)\n",
" else:\n",
" adj_matrix = np.zeros((len(item2id), len(item2id)))\n",
" for user in user_sequence_dict:\n",
" interactions = user_sequence_dict[user][:-2]\n",
" for pairs in combinations(interactions, 2):\n",
" adj_matrix[item2id[pairs[0]]][item2id[pairs[1]]] += 1\n",
" adj_matrix[item2id[pairs[1]]][item2id[pairs[0]]] += 1\n",
"\n",
"\n",
" # get the clustering results for the first layer\n",
" clustering = SpectralClustering(\n",
" n_clusters=cluster_num,\n",
" assign_labels=\"cluster_qr\",\n",
" random_state=0,\n",
" affinity=\"precomputed\",\n",
" ).fit(adj_matrix)\n",
" labels = clustering.labels_.tolist()\n",
"\n",
" # count the clustering results\n",
" grouping = defaultdict(list)\n",
" for i in range(len(labels)):\n",
" grouping[labels[i]].append((id2item[i],i))\n",
"\n",
" item_map = dict()\n",
" index_now = 0\n",
"\n",
" # add current clustering information into the item indexing results.\n",
" item_map, index_now = add_token_to_indexing(item_map, grouping, index_now, token_size)\n",
"\n",
" # add current clustering info into a queue for BFS\n",
" queue = []\n",
" for group in grouping:\n",
" queue.append(grouping[group])\n",
"\n",
" # apply BFS to further use spectral clustering for large groups (> token_size)\n",
" while queue:\n",
" group_items = queue.pop(0)\n",
"\n",
" # if current group is small enough, add the last token to item indexing\n",
" if len(group_items) <= token_size:\n",
" item_list = [items[0] for items in group_items]\n",
" if last_token == 'sequential':\n",
" item_map = add_last_token_to_indexing_sequential(item_map, item_list, token_size)\n",
" elif last_token == 'random':\n",
" item_map = add_last_token_to_indexing_random(item_map, item_list, token_size)\n",
" else:\n",
" # calculate the adjacency matrix for current group\n",
" if float32 > 0:\n",
" sub_adj_matrix = np.zeros((len(group_items), len(group_items)), dtype=np.float32)\n",
" else:\n",
" sub_adj_matrix = np.zeros((len(group_items), len(group_items)))\n",
" for i in range(len(group_items)):\n",
" for j in range(i+1, len(group_items)):\n",
" sub_adj_matrix[i][j] = adj_matrix[group_items[i][1]][group_items[j][1]]\n",
" sub_adj_matrix[j][i] = adj_matrix[group_items[j][1]][group_items[i][1]]\n",
"\n",
" # get the clustering results for current group\n",
" clustering = SpectralClustering(\n",
" n_clusters=cluster_num,\n",
" assign_labels=\"cluster_qr\",\n",
" random_state=0,\n",
" affinity=\"precomputed\",\n",
" ).fit(sub_adj_matrix)\n",
" labels = clustering.labels_.tolist()\n",
"\n",
" # count current clustering results\n",
" grouping = defaultdict(list)\n",
" for i in range(len(labels)):\n",
" grouping[labels[i]].append(group_items[i])\n",
"\n",
" # add current clustering information into the item indexing results.\n",
" item_map, index_now = add_token_to_indexing(item_map, grouping, index_now, token_size)\n",
"\n",
" # push current clustering info into the queue\n",
" for group in grouping:\n",
" queue.append(grouping[group])\n",
"\n",
" # if some items are not in the training data, assign an index for them\n",
" remaining_items = list(all_items - train_items)\n",
" if len(remaining_items) > 0:\n",
" if last_token == 'sequential':\n",
" item_map = add_last_token_to_indexing_sequential(item_map, remaining_items, token_size)\n",
" elif last_token == 'random':\n",
" item_map = add_last_token_to_indexing_random(item_map, remaining_items, token_size)\n",
"\n",
" return item_map\n",
"\n",
"\n",
"\n",
"def add_token_to_indexing(item_map, grouping, index_now, token_size):\n",
" for group in grouping:\n",
" index_now = index_now % token_size\n",
" for (item, idx) in grouping[group]:\n",
" if item not in item_map:\n",
" item_map[item] = ''\n",
" item_map[item] += f'<CI{index_now}>'\n",
" index_now += 1\n",
" return item_map, index_now\n",
"\n",
"def add_last_token_to_indexing_random(item_map, item_list, token_size):\n",
" last_tokens = random.sample([i for i in range(token_size)], len(item_list))\n",
" for i in range(len(item_list)):\n",
" item = item_list[i]\n",
" if item not in item_map:\n",
" item_map[item] = ''\n",
" item_map[item] += f'<CI{last_tokens[i]}>'\n",
" return item_map\n",
"\n",
"def add_last_token_to_indexing_sequential(item_map, item_list, token_size):\n",
" for i in range(len(item_list)):\n",
" item = item_list[i]\n",
" if item not in item_map:\n",
" item_map[item] = ''\n",
" item_map[item] += f'<CI{i}>'\n",
" return item_map\n",
"\n",
"\n",
"def get_dict_from_lines(lines):\n",
" \"\"\"\n",
" Used to get user or item map from lines loaded from txt file.\n",
" \"\"\"\n",
" index_map = dict()\n",
" for line in lines:\n",
" info = line.split(\" \")\n",
" index_map[info[0]] = info[1]\n",
" return index_map\n",
"\n",
"\n",
"\n",
"\n",
"def generate_user_map(user_sequence_dict):\n",
" \"\"\"\n",
" generate user map based on user sequence dict.\n",
" \"\"\"\n",
" user_map = dict()\n",
" for user in user_sequence_dict.keys():\n",
" user_map[user] = str(len(user_map) + 1)\n",
" return user_map\n",
"\n",
"\n",
"def reindex(user_sequence_dict, user_map, item_map):\n",
" \"\"\"\n",
" reindex the given user sequence dict by given user map and item map\n",
" \"\"\"\n",
" reindex_user_sequence_dict = dict()\n",
" for user in user_sequence_dict:\n",
" uid = user_map[user]\n",
" items = user_sequence_dict[user]\n",
" reindex_user_sequence_dict[uid] = [item_map[i] for i in items]\n",
"\n",
" return reindex_user_sequence_dict\n",
"\n",
"\n",
"def construct_user_sequence_dict(user_sequence):\n",
" \"\"\"\n",
" Convert a list of string to a user sequence dict. user as key, item list as value.\n",
" \"\"\"\n",
"\n",
" user_seq_dict = dict()\n",
" for line in user_sequence:\n",
" user_seq = line.split(\" \")\n",
" user_seq_dict[user_seq[0]] = user_seq[1:]\n",
" return user_seq_dict"
],
"metadata": {
"id": "y9Jwhy1rVmEc"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# from https://github.com/agiresearch/OpenP5/blob/main/src/generate_dataset.py\n",
"def generate_dataset(\n",
" data_path: str = '../data',\n",
" item_indexing: str = \"sequential\",\n",
" tasks_: str = 'sequential,straightforward',\n",
" dataset: str = \"Beauty\",\n",
" prompt_file: str = \"../prompt.txt\",\n",
" sequential_order: str = \"original\",\n",
" collaborative_token_size: int = 200,\n",
" collaborative_cluster: int = 20,\n",
" collaborative_last_token: str = \"sequential\",\n",
" collaborative_float32: int = 0,\n",
" max_his: int = 10,\n",
" his_prefix: int = 1,\n",
" his_sep:str = \" , \",\n",
" skip_empty_his: str = 1,\n",
"):\n",
" tasks = tasks_.split(',')\n",
"\n",
" file_data = dict()\n",
" file_data['data'] = []\n",
"\n",
" user_sequence = ReadLineFromFile(os.path.join(data_path, dataset, 'user_sequence.txt'))\n",
" user_sequence_dict = construct_user_sequence_dict(user_sequence)\n",
"\n",
" if item_indexing == 'sequential':\n",
" print(\"Reindex data with sequential indexing method\")\n",
" reindex_user_seq_dict, item_map = sequential_indexing(data_path, dataset, user_sequence_dict, sequential_order)\n",
" elif item_indexing == 'random':\n",
" print(\"Reindex data with random indexing method\")\n",
" reindex_user_seq_dict, item_map = random_indexing(data_path, dataset, user_sequence_dict)\n",
" elif item_indexing == 'collaborative':\n",
" print(f\"Reindex data with collaborative indexing method with token_size {collaborative_token_size} and {collaborative_cluster} cluster\")\n",
" reindex_user_seq_dict, item_map = collaborative_indexing(data_path, dataset, user_sequence_dict, \\\n",
" collaborative_token_size, collaborative_cluster, \\\n",
" collaborative_last_token, collaborative_float32)\n",
" else:\n",
" raise NotImplementedError\n",
"\n",
"\n",
" # get prompt\n",
" prompt = load_prompt_template(prompt_file, tasks)\n",
" info = get_info_from_prompt(prompt)\n",
" check_task_prompt(prompt, tasks)\n",
" print(f\"get prompt from {prompt_file}\")\n",
"\n",
"\n",
" # Load training data samples\n",
" training_data_samples = []\n",
" for user in reindex_user_seq_dict:\n",
" items = reindex_user_seq_dict[user][:-2]\n",
" for i in range(len(items)):\n",
" if i == 0:\n",
" if skip_empty_his > 0:\n",
" continue\n",
" one_sample = dict()\n",
" one_sample['dataset'] = dataset\n",
" one_sample['user_id'] = user\n",
" if his_prefix > 0:\n",
" one_sample['target'] = 'item_' + items[i]\n",
" else:\n",
" one_sample['target'] = items[i]\n",
" if 'history' in info:\n",
" history = items[:i]\n",
" if max_his > 0:\n",
" history = history[-max_his:]\n",
" if his_prefix > 0:\n",
" one_sample['history'] = his_sep.join([\"item_\" + item_idx for item_idx in history])\n",
" else:\n",
" one_sample['history'] = his_sep.join(history)\n",
" training_data_samples.append(one_sample)\n",
" print(\"load training data\")\n",
" print(f'there are {len(training_data_samples)} samples in training data.')\n",
"\n",
" # construct sentences\n",
" for i in range(len(training_data_samples)):\n",
" one_sample = training_data_samples[i]\n",
" for task in tasks:\n",
" datapoint = {}\n",
" datapoint['task'] = dataset + task\n",
" datapoint['data_id'] = i\n",
" for pid in prompt[task]['seen']:\n",
" datapoint['instruction'] = prompt[task]['seen'][pid]['Input']\n",
" datapoint['input'] = prompt[task]['seen'][pid]['Input'].format(**one_sample)\n",
" datapoint['output'] = prompt[task]['seen'][pid]['Output'].format(**one_sample)\n",
" file_data['data'].append(datapoint.copy())\n",
"\n",
" print(\"data constructed\")\n",
" print(f\"there are {len(file_data['data'])} prompts in training data.\")\n",
"\n",
"\n",
" # save the data to json file\n",
" output_path = f'{dataset}_{tasks_}_{item_indexing}_train.json'\n",
"\n",
" with open(os.path.join(data_path, dataset, output_path), 'w') as openfile:\n",
" json.dump(file_data, openfile)\n"
],
"metadata": {
"id": "eZrN7hcgUlHA"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# from https://github.com/agiresearch/OpenP5/blob/main/src/generate_dataset_eval.py\n",
"\n",
"def generate_dataset_eval(\n",
" data_path: str = '../data',\n",
" item_indexing: str = \"sequential\",\n",
" tasks_: str = 'sequential,straightforward',\n",
" dataset: str = \"Beauty\",\n",
" prompt_file: str = \"../prompt.txt\",\n",
" sequential_order: str = \"original\",\n",
" collaborative_token_size: int = 200,\n",
" collaborative_cluster: int = 20,\n",
" collaborative_last_token: str = \"sequential\",\n",
" collaborative_float32: int = 0,\n",
" max_his: int = 10,\n",
" his_prefix: int = 1,\n",
" his_sep:str = \" , \",\n",
" skip_empty_his: str = 1,\n",
" mode: str = 'validation',\n",
" prompt_str: str = 'seen:0',\n",
"):\n",
" tasks = tasks_.split(',')\n",
"\n",
" file_data = dict()\n",
" file_data['data'] = []\n",
"\n",
" user_sequence = ReadLineFromFile(os.path.join(data_path, dataset, 'user_sequence.txt'))\n",
" user_sequence_dict = construct_user_sequence_dict(user_sequence)\n",
"\n",
" if item_indexing == 'sequential':\n",
" print(\"Reindex data with sequential indexing method\")\n",
" reindex_user_seq_dict, item_map = sequential_indexing(data_path, dataset, user_sequence_dict, sequential_order)\n",
" elif item_indexing == 'random':\n",
" print(\"Reindex data with random indexing method\")\n",
" reindex_user_seq_dict, item_map = random_indexing(data_path, dataset, user_sequence_dict)\n",
" elif item_indexing == 'collaborative':\n",
" print(f\"Reindex data with collaborative indexing method with token_size {collaborative_token_size} and {collaborative_cluster} cluster\")\n",
" reindex_user_seq_dict, item_map = collaborative_indexing(data_path, dataset, user_sequence_dict, \\\n",
" collaborative_token_size, collaborative_cluster, \\\n",
" collaborative_last_token, collaborative_float32)\n",
" else:\n",
" raise NotImplementedError\n",
"\n",
"\n",
" # get prompt\n",
" prompt = load_prompt_template(prompt_file, tasks)\n",
" info = get_info_from_prompt(prompt)\n",
" check_task_prompt(prompt, tasks)\n",
" print(f\"get prompt from {prompt_file}\")\n",
"\n",
" # Load data samples\n",
" if mode == 'validation':\n",
" data_samples = load_validation(\n",
" dataset=dataset,\n",
" max_his=max_his,\n",
" his_prefix=his_prefix,\n",
" his_sep=his_sep,\n",
" reindex_user_seq_dict=reindex_user_seq_dict,\n",
" info=info\n",
" )\n",
" prompt_info = prompt_str.split(':')\n",
" output_path = f'{dataset}_{tasks_}_{item_indexing}_validation_{prompt_str}.json'\n",
" elif mode == 'test':\n",
" data_samples = load_test(\n",
" dataset,\n",
" max_his,\n",
" his_prefix,\n",
" his_sep,\n",
" reindex_user_seq_dict,\n",
" info\n",
" )\n",
" prompt_info = prompt_str.split(':')\n",
" output_path = f'{dataset}_{tasks_}_{item_indexing}_test_{prompt_str}.json'\n",
" else:\n",
" raise NotImplementedError\n",
" print(f'there are {len(data_samples)} samples in {mode} data.')\n",
" print(prompt_info)\n",
"\n",
" # construct sentences\n",
" for i in range(len(data_samples)):\n",
" one_sample = data_samples[i]\n",
" for task in tasks:\n",
" datapoint = {}\n",
" datapoint['task'] = dataset + task\n",
" datapoint['instruction'] = prompt[task][prompt_info[0]][prompt_info[1]]['Input']\n",
" datapoint['input'] = prompt[task][prompt_info[0]][prompt_info[1]]['Input'].format(**one_sample)\n",
" datapoint['output'] = prompt[task][prompt_info[0]][prompt_info[1]]['Output'].format(**one_sample)\n",
" file_data['data'].append(datapoint.copy())\n",
"\n",
" print(\"data constructed\")\n",
" print(f\"there are {len(file_data['data'])} prompts in {mode} data.\")\n",
"\n",
"\n",
" # save the data to json file\n",
"\n",
" with open(os.path.join(data_path, dataset, output_path), 'w') as openfile:\n",
" json.dump(file_data, openfile)\n",
"\n",
"\n",
"def load_test(\n",
" dataset,\n",
" max_his,\n",
" his_prefix,\n",
" his_sep,\n",
" reindex_user_seq_dict,\n",
" info\n",
" ):\n",
"\n",
" data_samples = []\n",
" for user in reindex_user_seq_dict:\n",
" items = reindex_user_seq_dict[user]\n",
" one_sample = dict()\n",
" one_sample['dataset'] = dataset\n",
" one_sample['user_id'] = user\n",
" if his_prefix > 0:\n",
" one_sample['target'] = 'item_' + items[-1]\n",
" else:\n",
" one_sample['target'] = items[-1]\n",
" if 'history' in info:\n",
" history = items[:-1]\n",
" if max_his > 0:\n",
" history = history[-max_his:]\n",
" if his_prefix > 0:\n",
" one_sample['history'] = his_sep.join([\"item_\" + item_idx for item_idx in history])\n",
" else:\n",
" one_sample['history'] = his_sep.join(history)\n",
" data_samples.append(one_sample)\n",
" return data_samples\n",
"\n",
"def load_validation(\n",
" dataset,\n",
" max_his,\n",
" his_prefix,\n",
" his_sep,\n",
" reindex_user_seq_dict,\n",
" info,\n",
" ):\n",
" data_samples = []\n",
" for user in reindex_user_seq_dict:\n",
" items = reindex_user_seq_dict[user]\n",
" one_sample = dict()\n",
" one_sample['dataset'] = dataset\n",
" one_sample['user_id'] = user\n",
" if his_prefix > 0:\n",
" one_sample['target'] = 'item_' + items[-2]\n",
" else:\n",
" one_sample['target'] = items[-2]\n",
" if 'history' in info:\n",
" history = items[:-2]\n",
" if max_his > 0:\n",
" history = history[-max_his:]\n",
" if his_prefix > 0:\n",
" one_sample['history'] = his_sep.join([\"item_\" + item_idx for item_idx in history])\n",
" else:\n",
" one_sample['history'] = his_sep.join(history)\n",
" data_samples.append(one_sample)\n",
" return data_samples\n"
],
"metadata": {
"id": "PkcepAgwUz3K"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"target_datasets = [\"Beauty\", \"ML100K\"] # 対象のデータセットを選択する [Beauty ML100K ML1M Yelp Electronics Movies CDs Clothing Taobao LastFM]\n",
"\n",
"for dataset in target_datasets:\n",
" for indexing in [\"random\", \"sequential\", \"collaborative\"]:\n",
" generate_dataset(\n",
" dataset = dataset,\n",
" data_path = '/content/drive/MyDrive/OpenP5/data/',\n",
" item_indexing = indexing,\n",
" tasks_ = 'sequential,straightforward',\n",
" prompt_file = 'OpenP5/prompt.txt',\n",
" )\n",
" generate_dataset_eval(\n",
" dataset = dataset,\n",
" data_path = '/content/drive/MyDrive/OpenP5/data/',\n",
" item_indexing = indexing,\n",
" tasks_ = 'sequential,straightforward',\n",
" prompt_file = 'OpenP5/prompt.txt',\n",
" mode=\"validation\",\n",
" prompt_str=str(\"seen:0\")\n",
" )\n",
" generate_dataset_eval(\n",
" dataset = dataset,\n",
" data_path = '/content/drive/MyDrive/OpenP5/data/',\n",
" item_indexing = indexing,\n",
" tasks_ = 'sequential,straightforward',\n",
" prompt_file = 'OpenP5/prompt.txt',\n",
" mode=\"test\",\n",
" prompt_str=str(\"seen:0\")\n",
" )\n",
" generate_dataset_eval(\n",
" dataset = dataset,\n",
" data_path = '/content/drive/MyDrive/OpenP5/data/',\n",
" item_indexing = indexing,\n",
" tasks_ = 'sequential,straightforward',\n",
" prompt_file = 'OpenP5/prompt.txt',\n",
" mode=\"test\",\n",
" prompt_str=str(\"unseen:0\")\n",
" )\n"
],
"metadata": {
"id": "cRsdy65NXmNc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "Z04m0wXqbemx"
},
"execution_count": null,
"outputs": []
}
]
}