Skip to content

Instantly share code, notes, and snippets.

@messefor
Created May 28, 2023 09:38
Show Gist options
  • Save messefor/4b233d95137177d7a92960864137cef0 to your computer and use it in GitHub Desktop.
Save messefor/4b233d95137177d7a92960864137cef0 to your computer and use it in GitHub Desktop.
bert_sim_trial.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5NCwotIcqw7f"
},
"outputs": [],
"source": [
"# 4-1\n",
"# !pip install transformers==4.5.0 # うまくいかない(tokenizerがはいらない)\n",
"! pip install transformers # OK\n",
"! pip install fugashi==1.1.0 ipadic==1.0.0 # OK"
]
},
{
"cell_type": "code",
"source": [
"# 4-2\n",
"from pprint import pprint\n",
"import torch\n",
"from transformers import BertJapaneseTokenizer, BertModel"
],
"metadata": {
"id": "0ekbW1fxq7QK"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"# 4-3\n",
"model_name = 'cl-tohoku/bert-base-japanese-whole-word-masking'\n",
"tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)"
],
"metadata": {
"id": "zv34ZUQWrLxT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(tokenizer.tokenize('私は今日は機械学習モデルであるBERTの勉強をしている。'))\n",
"print(tokenizer.tokenize('パン屋でパンを買って食べよう。'))"
],
"metadata": {
"id": "0IuwZsmRse2I"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"input_ids = tokenizer.encode('私は今日は機械学習モデルであるBERTの勉強をしている。')\n",
"print(input_ids)"
],
"metadata": {
"id": "0xtrwfxxslWb"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(tokenizer.convert_ids_to_tokens(input_ids))"
],
"metadata": {
"id": "jSZf6ruTta8x"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"text = '明日は自然言語処理の勉強をしよう。'\n",
"encoding = tokenizer(\n",
" text, max_length=12, padding='max_length', truncation=True\n",
")\n",
"pprint(encoding)\n",
"tokens = tokenizer.convert_ids_to_tokens(encoding['input_ids'])\n",
"pprint(tokens)"
],
"metadata": {
"id": "OOSADh3gtSbh"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"encoding = tokenizer(\n",
" text, max_length=6, padding='max_length', truncation=True\n",
")\n",
"pprint(encoding)\n",
"tokens = tokenizer.convert_ids_to_tokens(encoding['input_ids'])\n",
"pprint(tokens)"
],
"metadata": {
"id": "ffXZywZXt5Ge"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"encoding = tokenizer(\n",
" text, max_length=20, padding='max_length', truncation=True\n",
")\n",
"pprint(encoding)\n",
"tokens = tokenizer.convert_ids_to_tokens(encoding['input_ids'])\n",
"pprint(tokens)"
],
"metadata": {
"id": "EKaEMNmEvLa4"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"text_list = ['私は今日も楽しく集中して勉強している', '明日は自然言語処理の勉強をしよう']\n",
"encoding = tokenizer(text_list, padding='longest', return_tensors='pt')\n",
"encoding"
],
"metadata": {
"id": "ZLr3xBO7vODx"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"encoding['input_ids']"
],
"metadata": {
"id": "2FEIt0Gfv8kx"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"MAX_LENGTH = 32\n",
"encoding = tokenizer(text_list, max_length=MAX_LENGTH, padding='max_length', truncation=True, return_tensors='pt')\n",
"encoding['input_ids']"
],
"metadata": {
"id": "4MwcS18-wLb5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 4-14\n",
"model_name = 'cl-tohoku/bert-base-japanese-whole-word-masking'\n",
"bert = BertModel.from_pretrained(model_name)\n",
"\n",
"bert = bert.cuda()"
],
"metadata": {
"id": "44aToexFwjlS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "uFMIoDS10LrF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(bert.config)"
],
"metadata": {
"id": "9zs1hkf3wzbd"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"MAX_LENGTH = 32\n",
"text_list = ['私は今日も楽しく集中して勉強している', \n",
" '明日は自然言語処理の勉強をしよう', \n",
" 'アイスキャンディー食べたい'] # batch_size = 3\n",
"tokenizer = BertJapaneseTokenizer.from_pretrained(bert.name_or_path)\n",
"encodings = tokenizer(text_list, max_length=MAX_LENGTH, padding='max_length', truncation=True, return_tensors='pt')"
],
"metadata": {
"id": "fQSO8gqPzREQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"encodings = {k: v.cuda() for k, v in encodings.items()}"
],
"metadata": {
"id": "y_9AdyAm0BrA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# pytorch ではtensorをcuda() メソッドでdeviceに乗せることができるらしい\n",
"# torch.Tensor([1, 2, 3]).cuda()"
],
"metadata": {
"id": "xJXT69i70kdp"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"output = bert(**encodings)"
],
"metadata": {
"id": "EQ4ZCh9q015F"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"output.last_hidden_state.shape"
],
"metadata": {
"id": "s8yk_cgG1cqE"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# https://pytorch.org/docs/stable/generated/torch.no_grad.html\n",
"with torch.no_grad():\n",
" output = bert(**encodings)\n",
" last_hidden_state = output.last_hidden_state"
],
"metadata": {
"id": "zyunCjMX1dtP"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(last_hidden_state.shape)\n",
"last_hidden_state"
],
"metadata": {
"id": "Hgoklk-t2brI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"last_hidden_state = last_hidden_state.cpu()\n",
"last_hidden_state.numpy()"
],
"metadata": {
"id": "NA9B37Dw2mJN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 10-2 文章ベクトル"
],
"metadata": {
"id": "u0nyDfaG3fOL"
}
},
{
"cell_type": "code",
"source": [
"# ! wget https://www.rondhuit.com/download/ldcc-20140209.tar.gz\n",
"# ! tar -zxf ldcc-20140209.tar.gz"
],
"metadata": {
"id": "5rtYpwex4ZtR"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"! ls -la text/dokujo-tsushin/dokujo-tsushin-6763585.txt"
],
"metadata": {
"id": "WhUdfOur5NGq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import random\n",
"import glob\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"from sklearn.manifold import TSNE\n",
"from sklearn.decomposition import PCA\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from transformers import BertJapaneseTokenizer, BertModel\n",
"\n",
"MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'"
],
"metadata": {
"id": "0syE5I8e3e7f"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"bert = BertModel.from_pretrained(MODEL_NAME).cuda()"
],
"metadata": {
"id": "WzHbVkoY6Xpe"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"bert.config"
],
"metadata": {
"id": "SL5zErhU7rN8"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"tokenizer = BertJapaneseTokenizer.from_pretrained(bert.name_or_path)"
],
"metadata": {
"id": "-6Shsy1h7ZpI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"category_list = [\n",
" 'dokujo-tsushin'\n",
" , 'it-life-hack'\n",
" , 'kaden-channel'\n",
" , 'livedoor-homme'\n",
" , 'movie-enter'\n",
" , 'peachy'\n",
" , 'smax'\n",
" , 'sports-watch'\n",
" , 'topic-news'\n",
"]"
],
"metadata": {
"id": "BEEeisFz2pcz"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"tokenizer_args = dict(max_length=256,\n",
" padding='max_length',\n",
" truncation=True,\n",
" return_tensors='pt')"
],
"metadata": {
"id": "6xK4o-lZ67vh"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"sentence_vectors_mean = []\n",
"sentence_vectors_cls = []\n",
"labels = []\n",
"\n",
"for label, category in enumerate(tqdm(category_list)):\n",
"\n",
" for txt_path in glob.glob(f'text/{category}/{category}*.txt'):\n",
" # print(txt_path)\n",
" with open(txt_path) as f:\n",
" lines = f.read().splitlines()\n",
" text = '\\n'.join(lines[3:])\n",
" encoding = tokenizer(text, **tokenizer_args)\n",
" encoding = {k: v.cuda() for k, v in encoding.items()}\n",
" attention_mask = encoding['attention_mask']\n",
" with torch.no_grad():\n",
" output = bert(**encoding)\n",
" last_hidden_state = output.last_hidden_state\n",
" avg_hidden_state = \\\n",
" (last_hidden_state * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(1, keepdim=True)\n",
"\n",
" # print(last_hidden_state.shape)\n",
" # print(avg_hidden_state.shape)\n",
" # print(avg_hidden_state)\n",
"\n",
" sentence_vectors_mean.append(avg_hidden_state[0].cpu().numpy()) # mean\n",
" sentence_vectors_cls.append(last_hidden_state[0, 0, :].cpu().numpy()) # cls\n",
" labels.append(label)"
],
"metadata": {
"id": "2JtZ6aO05lkw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"sentence_vectors_mean = np.vstack(sentence_vectors_mean)\n",
"sentence_vectors_cls = np.vstack(sentence_vectors_cls)\n",
"labels = np.array(labels)"
],
"metadata": {
"id": "MzNM2lcj-0wH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(sentence_vectors_mean.shape)\n",
"print(sentence_vectors_cls.shape)\n"
],
"metadata": {
"id": "NWgRnbQS6EtO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"sentence_vectors_tsne = TSNE(n_components=2).fit_transform(sentence_vectors_mean)\n",
"sentence_vectors_tsne.shape"
],
"metadata": {
"id": "hHdFZCCCJe-z"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"fig = plt.figure(figsize=(10, 10))\n",
"for label in range(9):\n",
" ax = fig.add_subplot(3, 3, label + 1)\n",
" ax.plot(\n",
" sentence_vectors_tsne[:, 0]\n",
" , sentence_vectors_tsne[:, 1]\n",
" , 'o'\n",
" , markersize=1\n",
" , color=[0.7, 0.7, 0.7]\n",
" )\n",
" index = labels == label\n",
" ax.plot(\n",
" sentence_vectors_tsne[index, 0]\n",
" , sentence_vectors_tsne[index, 1]\n",
" , 'o'\n",
" , markersize=1\n",
" , color='k'\n",
" )\n",
" ax.set_title(category_list[label])\n",
"fig.tight_layout()"
],
"metadata": {
"id": "n7XsI-K6J0Ac"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"sentence_vectors_tsne = TSNE(n_components=2).fit_transform(sentence_vectors_cls)\n",
"sentence_vectors_tsne.shape"
],
"metadata": {
"id": "OstHyn-CNf_B"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"fig = plt.figure(figsize=(10, 10))\n",
"for label in range(9):\n",
" ax = fig.add_subplot(3, 3, label + 1)\n",
" ax.plot(\n",
" sentence_vectors_tsne[:, 0]\n",
" , sentence_vectors_tsne[:, 1]\n",
" , 'o'\n",
" , markersize=1\n",
" , color=[0.7, 0.7, 0.7]\n",
" )\n",
" index = labels == label\n",
" ax.plot(\n",
" sentence_vectors_tsne[index, 0]\n",
" , sentence_vectors_tsne[index, 1]\n",
" , 'o'\n",
" , markersize=1\n",
" , color='k'\n",
" )\n",
" ax.set_title(category_list[label])\n",
"fig.tight_layout()"
],
"metadata": {
"id": "P29iH6-aRs6x"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 類似度検索"
],
"metadata": {
"id": "NgemHh9KNilU"
}
},
{
"cell_type": "code",
"source": [
"sentence_vectors = sentence_vectors_mean"
],
"metadata": {
"id": "0HHb55adRw8c"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"norm = np.linalg.norm(sentence_vectors, axis=1, keepdims=True)\n",
"sentence_vectors_normalized = sentence_vectors / norm"
],
"metadata": {
"id": "3Hu2LtyiKs-K"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"sentence_vectors_normalized.shape"
],
"metadata": {
"id": "Lpg8eMv8VQmo"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"query_nums = np.random.randint(sentence_vectors_normalized.shape[0], size=10)\n",
"query_vecs = sentence_vectors_normalized[query_nums, :]\n",
"\n",
"index = np.ones(sentence_vectors_normalized.shape[0], dtype=bool)\n",
"index[query_nums] = False\n",
"docs_vecs = sentence_vectors_normalized[index, :]"
],
"metadata": {
"id": "YoslOQdTR0ma"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"query_vecs.shape\n",
"docs_vecs.shape"
],
"metadata": {
"id": "qiLjaW3LVJku"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"[category_list[i] for i in labels[query_nums]]"
],
"metadata": {
"id": "SdJeRYWaWK5U"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 類似度行列(クエリ x ドキュメント)"
],
"metadata": {
"id": "I34GrnTfVeCo"
}
},
{
"cell_type": "code",
"source": [
"sim_matrix_qa = docs_vecs.dot(query_vecs.T)"
],
"metadata": {
"id": "j9W10R3xVtcy"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"result_top10 = np.argsort(-1 * sim_matrix_qa, axis=0)[:10, :]\n",
"result_top10[: , 0]"
],
"metadata": {
"id": "gOiIQUozV555"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"for i_query in range(10):\n",
" print([category_list[labels[i]] for i in result_top10[: , i_query]])"
],
"metadata": {
"id": "U2e3KmBXWmD7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"\n",
"## 類似度行列(全体)"
],
"metadata": {
"id": "RPUIb43OVXFm"
}
},
{
"cell_type": "code",
"source": [
"sim_matrix = sentence_vectors_normalized.dot(sentence_vectors_normalized.T)"
],
"metadata": {
"id": "lUSvUbqdNdJI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"np.fill_diagonal(sim_matrix, -1)\n",
"similar_news = sim_matrix.argmax(axis=1)"
],
"metadata": {
"id": "Kn5rlQ5bOWgc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(sim_matrix.shape)\n",
"np.argsort(-1 * sim_matrix, axis=0)"
],
"metadata": {
"id": "nCd1PBMqOoNV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"similar_news"
],
"metadata": {
"id": "CnuA4ieGOgQO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"fig, ax = plt.subplots()\n",
"heatmap = ax.pcolor(sim_matrix, cmap=plt.cm.Blues)"
],
"metadata": {
"id": "6HHgHPRCOhAN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "EjXIwINbQJPK"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment