Created
May 28, 2023 09:38
-
-
Save messefor/4b233d95137177d7a92960864137cef0 to your computer and use it in GitHub Desktop.
bert_sim_trial.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"private_outputs": true, | |
"provenance": [], | |
"gpuType": "T4" | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# BERT sentence-vector trial\n", | |
"\n", | |
"1. Explore the tokenizer and encoder of `cl-tohoku/bert-base-japanese-whole-word-masking`.\n", | |
"2. Embed every article of the ldcc-20140209 news corpus (mean-pooled and `[CLS]` vectors) and visualize the embeddings with t-SNE.\n", | |
"3. Search for similar articles with cosine similarity.\n", | |
"\n", | |
"Written for a Colab GPU runtime (T4)." | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "5NCwotIcqw7f" | |
}, | |
"outputs": [], | |
"source": [ | |
"# 4-1: install dependencies.\n", | |
"# %pip (unlike !pip) installs into the environment of the running kernel.\n", | |
"# NOTE: transformers==4.5.0 did not work here (the tokenizer could not be installed), so the latest release is used.\n", | |
"# TODO: pin the transformers version this notebook was last run with, for reproducibility.\n", | |
"%pip install -q transformers\n", | |
"%pip install -q fugashi==1.1.0 ipadic==1.0.0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 4-2: imports\n", | |
"from pprint import pprint\n", | |
"\n", | |
"import torch\n", | |
"from transformers import BertJapaneseTokenizer, BertModel" | |
], | |
"metadata": { | |
"id": "0ekbW1fxq7QK" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 4-3: load the Japanese tokenizer that matches the pretrained model\n", | |
"model_name = 'cl-tohoku/bert-base-japanese-whole-word-masking'\n", | |
"tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)" | |
], | |
"metadata": { | |
"id": "zv34ZUQWrLxT" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Split sentences into the tokens the model sees\n", | |
"print(tokenizer.tokenize('私は今日は機械学習モデルであるBERTの勉強をしている。'))\n", | |
"print(tokenizer.tokenize('パン屋でパンを買って食べよう。'))" | |
], | |
"metadata": { | |
"id": "0IuwZsmRse2I" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# encode() maps text to vocabulary ids (the next cell maps them back to tokens)\n", | |
"input_ids = tokenizer.encode('私は今日は機械学習モデルであるBERTの勉強をしている。')\n", | |
"print(input_ids)" | |
], | |
"metadata": { | |
"id": "0xtrwfxxslWb" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print(tokenizer.convert_ids_to_tokens(input_ids))" | |
], | |
"metadata": { | |
"id": "jSZf6ruTta8x" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def show_encoding(text, max_length):\n", | |
"    \"\"\"Encode `text` to exactly `max_length` ids and pretty-print the result.\n", | |
"\n", | |
"    Shorter texts are padded and longer ones truncated (padding='max_length',\n", | |
"    truncation=True). Prints the encoding dict, then the ids mapped back to tokens.\n", | |
"    \"\"\"\n", | |
"    encoding = tokenizer(\n", | |
"        text, max_length=max_length, padding='max_length', truncation=True\n", | |
"    )\n", | |
"    pprint(encoding)\n", | |
"    pprint(tokenizer.convert_ids_to_tokens(encoding['input_ids']))" | |
], | |
"metadata": {}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"text = '明日は自然言語処理の勉強をしよう。'\n", | |
"show_encoding(text, max_length=12)" | |
], | |
"metadata": { | |
"id": "OOSADh3gtSbh" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# A short max_length forces truncation\n", | |
"show_encoding(text, max_length=6)" | |
], | |
"metadata": { | |
"id": "ffXZywZXt5Ge" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# A long max_length pads the tail\n", | |
"show_encoding(text, max_length=20)" | |
], | |
"metadata": { | |
"id": "EKaEMNmEvLa4" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Batch encoding: padding='longest' pads every sentence to the longest one in the batch\n", | |
"text_list = ['私は今日も楽しく集中して勉強している', '明日は自然言語処理の勉強をしよう']\n", | |
"encoding = tokenizer(text_list, padding='longest', return_tensors='pt')\n", | |
"encoding" | |
], | |
"metadata": { | |
"id": "ZLr3xBO7vODx" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"encoding['input_ids']" | |
], | |
"metadata": { | |
"id": "2FEIt0Gfv8kx" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Fixed-length batch: every sentence padded/truncated to MAX_LENGTH ids\n", | |
"MAX_LENGTH = 32\n", | |
"encoding = tokenizer(\n", | |
"    text_list, max_length=MAX_LENGTH, padding='max_length', truncation=True, return_tensors='pt'\n", | |
")\n", | |
"encoding['input_ids']" | |
], | |
"metadata": { | |
"id": "4MwcS18-wLb5" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 4-14: load the pretrained BERT encoder; use the GPU when one is available\n", | |
"model_name = 'cl-tohoku/bert-base-japanese-whole-word-masking'\n", | |
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", | |
"bert = BertModel.from_pretrained(model_name).to(device)" | |
], | |
"metadata": { | |
"id": "44aToexFwjlS" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print(bert.config)" | |
], | |
"metadata": { | |
"id": "9zs1hkf3wzbd" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"MAX_LENGTH = 32\n", | |
"text_list = [\n", | |
"    '私は今日も楽しく集中して勉強している',\n", | |
"    '明日は自然言語処理の勉強をしよう',\n", | |
"    'アイスキャンディー食べたい',\n", | |
"]  # batch_size = 3\n", | |
"# The tokenizer can be recovered from the model's own name\n", | |
"tokenizer = BertJapaneseTokenizer.from_pretrained(bert.name_or_path)\n", | |
"encodings = tokenizer(\n", | |
"    text_list, max_length=MAX_LENGTH, padding='max_length', truncation=True, return_tensors='pt'\n", | |
")" | |
], | |
"metadata": { | |
"id": "fQSO8gqPzREQ" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Move every input tensor (input_ids, attention_mask, ...) to the model's device.\n", | |
"# Tensor.to(device) generalizes Tensor.cuda(), e.g. torch.tensor([1, 2, 3]).cuda()\n", | |
"encodings = {k: v.to(bert.device) for k, v in encodings.items()}" | |
], | |
"metadata": { | |
"id": "y_9AdyAm0BrA" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Inference only: torch.no_grad() skips building the autograd graph, which saves memory.\n", | |
"# https://pytorch.org/docs/stable/generated/torch.no_grad.html\n", | |
"with torch.no_grad():\n", | |
"    output = bert(**encodings)\n", | |
"last_hidden_state = output.last_hidden_state" | |
], | |
"metadata": { | |
"id": "EQ4ZCh9q015F" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)" | |
], | |
"metadata": { | |
"id": "s8yk_cgG1cqE" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"last_hidden_state" | |
], | |
"metadata": { | |
"id": "Hgoklk-t2brI" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# .numpy() only works on CPU tensors, so copy to host memory first\n", | |
"last_hidden_state_np = last_hidden_state.cpu().numpy()\n", | |
"last_hidden_state_np" | |
], | |
"metadata": { | |
"id": "NA9B37Dw2mJN" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 10-2 文章ベクトル" | |
], | |
"metadata": { | |
"id": "u0nyDfaG3fOL" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Download and extract the news corpus into ./text; skipped when ./text already exists,\n", | |
"# so the notebook can be re-run from a fresh runtime without re-downloading.\n", | |
"! [ -d text ] || (wget -q https://www.rondhuit.com/download/ldcc-20140209.tar.gz && tar -zxf ldcc-20140209.tar.gz)" | |
], | |
"metadata": { | |
"id": "5rtYpwex4ZtR" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Sanity check: one article file from the extracted corpus\n", | |
"! ls -la text/dokujo-tsushin/dokujo-tsushin-6763585.txt" | |
], | |
"metadata": { | |
"id": "WhUdfOur5NGq" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import glob\n", | |
"import random\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"from sklearn.decomposition import PCA\n", | |
"from sklearn.manifold import TSNE\n", | |
"from torch.utils.data import DataLoader\n", | |
"from tqdm.auto import tqdm  # progress bar that renders as a notebook widget\n", | |
"from transformers import BertJapaneseTokenizer, BertModel\n", | |
"\n", | |
"MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'\n", | |
"DOC_MAX_LENGTH = 256  # tokens per article fed to BERT; longer articles are truncated\n", | |
"DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'" | |
], | |
"metadata": { | |
"id": "0syE5I8e3e7f" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"bert = BertModel.from_pretrained(MODEL_NAME).to(DEVICE)" | |
], | |
"metadata": { | |
"id": "WzHbVkoY6Xpe" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"bert.config" | |
], | |
"metadata": { | |
"id": "SL5zErhU7rN8" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"tokenizer = BertJapaneseTokenizer.from_pretrained(bert.name_or_path)" | |
], | |
"metadata": { | |
"id": "-6Shsy1h7ZpI" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Corpus subdirectories under ./text; a category's position in this list is its numeric label\n", | |
"category_list = [\n", | |
"    'dokujo-tsushin',\n", | |
"    'it-life-hack',\n", | |
"    'kaden-channel',\n", | |
"    'livedoor-homme',\n", | |
"    'movie-enter',\n", | |
"    'peachy',\n", | |
"    'smax',\n", | |
"    'sports-watch',\n", | |
"    'topic-news',\n", | |
"]" | |
], | |
"metadata": { | |
"id": "BEEeisFz2pcz" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Every article is padded/truncated to DOC_MAX_LENGTH tokens and returned as PyTorch tensors\n", | |
"tokenizer_args = dict(\n", | |
"    max_length=DOC_MAX_LENGTH,\n", | |
"    padding='max_length',\n", | |
"    truncation=True,\n", | |
"    return_tensors='pt',\n", | |
")" | |
], | |
"metadata": { | |
"id": "6xK4o-lZ67vh" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Embed every article with two pooling strategies:\n", | |
"#   mean: average of the token vectors over non-padding positions (attention_mask == 1)\n", | |
"#   cls:  the vector at sequence position 0 ([CLS])\n", | |
"sentence_vectors_mean = []\n", | |
"sentence_vectors_cls = []\n", | |
"labels = []\n", | |
"\n", | |
"for label, category in enumerate(tqdm(category_list)):\n", | |
"    # sorted() makes the article order (and therefore the row order) reproducible across runs\n", | |
"    for txt_path in sorted(glob.glob(f'text/{category}/{category}*.txt')):\n", | |
"        with open(txt_path, encoding='utf-8') as f:\n", | |
"            lines = f.read().splitlines()\n", | |
"        # Skip the first three lines (header lines of each article file, presumably URL / date / title -- confirm)\n", | |
"        text = '\\n'.join(lines[3:])\n", | |
"        encoding = tokenizer(text, **tokenizer_args)\n", | |
"        encoding = {k: v.to(bert.device) for k, v in encoding.items()}\n", | |
"        attention_mask = encoding['attention_mask']\n", | |
"        with torch.no_grad():\n", | |
"            output = bert(**encoding)\n", | |
"        last_hidden_state = output.last_hidden_state\n", | |
"        # Masked mean over the sequence axis: padding positions contribute 0 and are not counted\n", | |
"        avg_hidden_state = (\n", | |
"            (last_hidden_state * attention_mask.unsqueeze(-1)).sum(1)\n", | |
"            / attention_mask.sum(1, keepdim=True)\n", | |
"        )\n", | |
"\n", | |
"        sentence_vectors_mean.append(avg_hidden_state[0].cpu().numpy())  # mean\n", | |
"        sentence_vectors_cls.append(last_hidden_state[0, 0, :].cpu().numpy())  # cls\n", | |
"        labels.append(label)" | |
], | |
"metadata": { | |
"id": "2JtZ6aO05lkw" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Stack the per-article vectors into 2-D arrays (one row per article)\n", | |
"sentence_vectors_mean = np.vstack(sentence_vectors_mean)\n", | |
"sentence_vectors_cls = np.vstack(sentence_vectors_cls)\n", | |
"labels = np.array(labels)" | |
], | |
"metadata": { | |
"id": "MzNM2lcj-0wH" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"sentence_vectors_mean.shape, sentence_vectors_cls.shape" | |
], | |
"metadata": { | |
"id": "NWgRnbQS6EtO" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def plot_tsne_by_category(points, point_labels, title):\n", | |
"    \"\"\"Plot 2-D points in one panel per entry of category_list.\n", | |
"\n", | |
"    Each panel shows all points in light grey and highlights, in black, the\n", | |
"    points whose label equals that category's index. Returns the figure.\n", | |
"    \"\"\"\n", | |
"    n_cols = 3\n", | |
"    n_rows = (len(category_list) + n_cols - 1) // n_cols  # ceil division\n", | |
"    fig = plt.figure(figsize=(10, 10))\n", | |
"    for label, category in enumerate(category_list):\n", | |
"        ax = fig.add_subplot(n_rows, n_cols, label + 1)\n", | |
"        ax.plot(points[:, 0], points[:, 1], 'o', markersize=1, color=[0.7, 0.7, 0.7])\n", | |
"        in_category = point_labels == label\n", | |
"        ax.plot(points[in_category, 0], points[in_category, 1], 'o', markersize=1, color='k')\n", | |
"        ax.set_title(category)\n", | |
"    fig.suptitle(title)\n", | |
"    fig.tight_layout()\n", | |
"    return fig" | |
], | |
"metadata": {}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 2-D t-SNE projection of the mean-pooled vectors; random_state makes the layout reproducible\n", | |
"tsne_mean = TSNE(n_components=2, random_state=0).fit_transform(sentence_vectors_mean)\n", | |
"tsne_mean.shape" | |
], | |
"metadata": { | |
"id": "hHdFZCCCJe-z" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"plot_tsne_by_category(tsne_mean, labels, 't-SNE of mean-pooled sentence vectors');" | |
], | |
"metadata": { | |
"id": "n7XsI-K6J0Ac" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Same projection for the [CLS] vectors\n", | |
"tsne_cls = TSNE(n_components=2, random_state=0).fit_transform(sentence_vectors_cls)\n", | |
"tsne_cls.shape" | |
], | |
"metadata": { | |
"id": "OstHyn-CNf_B" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"plot_tsne_by_category(tsne_cls, labels, 't-SNE of [CLS] sentence vectors');" | |
], | |
"metadata": { | |
"id": "P29iH6-aRs6x" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 類似度検索" | |
], | |
"metadata": { | |
"id": "NgemHh9KNilU" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Embedding used for the similarity search (swap in sentence_vectors_cls to compare)\n", | |
"sentence_vectors = sentence_vectors_mean" | |
], | |
"metadata": { | |
"id": "0HHb55adRw8c" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# L2-normalize each row so that a dot product between rows equals cosine similarity\n", | |
"norm = np.linalg.norm(sentence_vectors, axis=1, keepdims=True)\n", | |
"sentence_vectors_normalized = sentence_vectors / norm" | |
], | |
"metadata": { | |
"id": "3Hu2LtyiKs-K" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"sentence_vectors_normalized.shape" | |
], | |
"metadata": { | |
"id": "Lpg8eMv8VQmo" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Hold out N_QUERIES distinct random articles as queries; the rest form the searchable set.\n", | |
"# replace=False guarantees no query is drawn twice; the fixed seed makes the draw reproducible.\n", | |
"N_QUERIES = 10\n", | |
"rng = np.random.default_rng(0)\n", | |
"n_articles = sentence_vectors_normalized.shape[0]\n", | |
"query_nums = rng.choice(n_articles, size=N_QUERIES, replace=False)\n", | |
"query_vecs = sentence_vectors_normalized[query_nums, :]\n", | |
"\n", | |
"is_doc = np.ones(n_articles, dtype=bool)\n", | |
"is_doc[query_nums] = False\n", | |
"docs_vecs = sentence_vectors_normalized[is_doc, :]" | |
], | |
"metadata": { | |
"id": "YoslOQdTR0ma" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"query_vecs.shape, docs_vecs.shape" | |
], | |
"metadata": { | |
"id": "qiLjaW3LVJku" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Category of each query article\n", | |
"[category_list[i] for i in labels[query_nums]]" | |
], | |
"metadata": { | |
"id": "SdJeRYWaWK5U" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 類似度行列(クエリ x ドキュメント)" | |
], | |
"metadata": { | |
"id": "I34GrnTfVeCo" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Cosine similarity of every document (rows) to every query (columns): shape (n_docs, n_queries)\n", | |
"sim_matrix_qa = docs_vecs @ query_vecs.T" | |
], | |
"metadata": { | |
"id": "j9W10R3xVtcy" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Row indices INTO docs_vecs of the TOP_K most similar documents per query, best first\n", | |
"TOP_K = 10\n", | |
"result_top10 = np.argsort(-sim_matrix_qa, axis=0)[:TOP_K, :]\n", | |
"result_top10[:, 0]" | |
], | |
"metadata": { | |
"id": "gOiIQUozV555" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# result_top10 indexes rows of docs_vecs, which exclude the query articles, so the indices\n", | |
"# must be mapped back to corpus indices before looking up labels. docs_vecs keeps the\n", | |
"# remaining rows in ascending order, which is exactly what setdiff1d returns.\n", | |
"doc_ids = np.setdiff1d(np.arange(len(labels)), query_nums)\n", | |
"for i_query in range(result_top10.shape[1]):\n", | |
"    query_category = category_list[labels[query_nums[i_query]]]\n", | |
"    hit_categories = [category_list[labels[doc_ids[i]]] for i in result_top10[:, i_query]]\n", | |
"    print(query_category, '->', hit_categories)" | |
], | |
"metadata": { | |
"id": "U2e3KmBXWmD7" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"\n", | |
"## 類似度行列(全体)" | |
], | |
"metadata": { | |
"id": "RPUIb43OVXFm" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Cosine similarity between every pair of articles: shape (n_articles, n_articles)\n", | |
"sim_matrix = sentence_vectors_normalized @ sentence_vectors_normalized.T" | |
], | |
"metadata": { | |
"id": "lUSvUbqdNdJI" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Mask self-similarity (1.0, the row maximum) so argmax finds the nearest *other* article\n", | |
"np.fill_diagonal(sim_matrix, -1)\n", | |
"similar_news = sim_matrix.argmax(axis=1)" | |
], | |
"metadata": { | |
"id": "Kn5rlQ5bOWgc" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Ranked neighbours (most similar first) for the first few articles only; a full argsort of\n", | |
"# the whole square matrix would allocate a second matrix of the same size just for display.\n", | |
"print(sim_matrix.shape)\n", | |
"np.argsort(-sim_matrix[:5], axis=1)[:, :10]" | |
], | |
"metadata": { | |
"id": "nCd1PBMqOoNV" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"similar_news" | |
], | |
"metadata": { | |
"id": "CnuA4ieGOgQO" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# pcolor draws one patch per cell and is impractically slow at this size; imshow renders a\n", | |
"# single image. Articles are ordered by category, so blocks along the diagonal show\n", | |
"# within-category similarity. vmin=0 keeps the masked diagonal (-1) from squashing the colour range.\n", | |
"fig, ax = plt.subplots(figsize=(8, 8))\n", | |
"heatmap = ax.imshow(sim_matrix, cmap=plt.cm.Blues, vmin=0, interpolation='nearest')\n", | |
"fig.colorbar(heatmap, ax=ax)\n", | |
"ax.set_title('Cosine similarity between articles');" | |
], | |
"metadata": { | |
"id": "6HHgHPRCOhAN" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment