Skip to content

Instantly share code, notes, and snippets.

@tomohxx
Last active January 11, 2021 09:04
Show Gist options
  • Save tomohxx/286d73b383892ece9c9e152056b33de6 to your computer and use it in GitHub Desktop.
Save tomohxx/286d73b383892ece9c9e152056b33de6 to your computer and use it in GitHub Desktop.
牌の危険度推定(多変量RNN)
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 牌の危険度推定(多変量RNN)\n",
"\n",
"## 概要\n",
"- 立直者の捨て牌から牌の危険度を推定する\n",
"- マルチラベル分類(34種の牌それぞれについて待ち牌である確率を出力)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import math\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Seed TensorFlow's global RNG so weight initialisation is repeatable across runs\n",
"tf.random.set_seed(0)\n",
"\n",
"model = keras.Sequential()\n",
"# Input: sequences of any length with 35 features per step. A step whose\n",
"# features all equal -1.0 is treated as padding and skipped downstream.\n",
"model.add(layers.Masking(input_shape=(None, 35), mask_value=-1.0))\n",
"# Only the last hidden state is passed on (return_sequences=False).\n",
"model.add(layers.SimpleRNN(35, return_sequences=False))\n",
"# 34 independent sigmoid outputs: one probability per tile type (multi-label).\n",
"model.add(layers.Dense(34, activation=\"sigmoid\"))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"masking (Masking) (None, None, 35) 0 \n",
"_________________________________________________________________\n",
"simple_rnn (SimpleRNN) (None, 35) 2485 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 34) 1224 \n",
"=================================================================\n",
"Total params: 3,709\n",
"Trainable params: 3,709\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# Print each layer's output shape and parameter count\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Binary cross-entropy on each sigmoid output; RMSprop with its default settings\n",
"model.compile(\n",
"    loss=\"binary_crossentropy\",\n",
"    optimizer=\"rmsprop\",\n",
"    metrics=[\"binary_accuracy\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Headerless CSV: columns are addressed by position when building x and y\n",
"df = pd.read_csv(\"riichi_river.csv\", header=None)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(129932, 20, 35)\n",
"(129932, 34)\n"
]
}
],
"source": [
"# Each CSV row: column 1 = space-separated wait tiles,\n",
"# column 2 = space-separated (discard tile, tsumogiri flag) pairs.\n",
"MAX_TURNS = 20            # time steps per sample; shorter rivers are padded\n",
"N_TILES = 34              # tile types = number of output labels\n",
"N_FEATURES = N_TILES + 1  # one-hot discard tile + tsumogiri flag\n",
"PAD_VALUE = -1            # must match Masking(mask_value=-1.0) in the model\n",
"\n",
"x, y = [], []\n",
"\n",
"for row in df.itertuples(index=False, name=None):\n",
"    # Wait tiles (the positive labels)\n",
"    waits = {int(el) for el in row[1].split()}\n",
"    # The river alternates discard tile and tsumogiri flag (tsumogiri: 1, tedashi: 0)\n",
"    river = [int(el) for el in row[2].split()]\n",
"    discards = river[0::2]\n",
"    tsumos = river[1::2]\n",
"    if len(tsumos) != len(discards) or len(discards) > MAX_TURNS:\n",
"        raise ValueError(f\"malformed river (expected <= {MAX_TURNS} tile/flag pairs): {row[2]!r}\")\n",
"\n",
"    # One independent list per time step. `[[0] * 35] * 20` would repeat a\n",
"    # single list 20 times, so every discard would land in every time step.\n",
"    tmp_x = [[0] * N_FEATURES for _ in range(MAX_TURNS)]\n",
"\n",
"    for i, (tile, tsumo) in enumerate(zip(discards, tsumos)):\n",
"        tmp_x[i][tile] = 1\n",
"        tmp_x[i][-1] = tsumo\n",
"\n",
"    # Fill unused time steps with PAD_VALUE in every feature so Masking skips them\n",
"    for i in range(len(discards), MAX_TURNS):\n",
"        tmp_x[i] = [PAD_VALUE] * N_FEATURES\n",
"\n",
"    tmp_y = [1 if i in waits else 0 for i in range(N_TILES)]\n",
"\n",
"    x.append(tmp_x)\n",
"    y.append(tmp_y)\n",
"\n",
"x = np.array(x)\n",
"y = np.array(y)\n",
"\n",
"print(x.shape)\n",
"print(y.shape)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Hold out 20% of the samples for evaluation; random_state fixes the split\n",
"x_train, x_test, y_train, y_test = train_test_split(\n",
"    x, y, random_state=0, test_size=0.2\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.2344 - binary_accuracy: 0.9313\n",
"Epoch 2/20\n",
"1625/1625 [==============================] - 16s 10ms/step - loss: 0.1947 - binary_accuracy: 0.9470\n",
"Epoch 3/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1920 - binary_accuracy: 0.9470\n",
"Epoch 4/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1903 - binary_accuracy: 0.9471\n",
"Epoch 5/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1895 - binary_accuracy: 0.9470\n",
"Epoch 6/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1887 - binary_accuracy: 0.9471\n",
"Epoch 7/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1888 - binary_accuracy: 0.9470\n",
"Epoch 8/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1882 - binary_accuracy: 0.9471\n",
"Epoch 9/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1881 - binary_accuracy: 0.9471\n",
"Epoch 10/20\n",
"1625/1625 [==============================] - 16s 10ms/step - loss: 0.1878 - binary_accuracy: 0.9471\n",
"Epoch 11/20\n",
"1625/1625 [==============================] - 18s 11ms/step - loss: 0.1874 - binary_accuracy: 0.9471\n",
"Epoch 12/20\n",
"1625/1625 [==============================] - 18s 11ms/step - loss: 0.1872 - binary_accuracy: 0.9472\n",
"Epoch 13/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1870 - binary_accuracy: 0.9471\n",
"Epoch 14/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1873 - binary_accuracy: 0.9470\n",
"Epoch 15/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1873 - binary_accuracy: 0.9470\n",
"Epoch 16/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1870 - binary_accuracy: 0.9471\n",
"Epoch 17/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1867 - binary_accuracy: 0.9471\n",
"Epoch 18/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1868 - binary_accuracy: 0.9470\n",
"Epoch 19/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1867 - binary_accuracy: 0.9471\n",
"Epoch 20/20\n",
"1625/1625 [==============================] - 15s 9ms/step - loss: 0.1866 - binary_accuracy: 0.9470\n"
]
}
],
"source": [
"# Train on the training split only; per-epoch metrics are kept in `history`\n",
"history = model.fit(\n",
"    x_train,\n",
"    y_train,\n",
"    epochs=20,\n",
"    batch_size=64,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate on test data\n",
"407/407 [==============================] - 2s 4ms/step - loss: 0.1873 - binary_accuracy: 0.9471\n",
"test loss, test acc: [0.18734900653362274, 0.9471310973167419]\n"
]
}
],
"source": [
"# Evaluate the model on the held-out test split\n",
"print(\"Evaluate on test data\")\n",
"results = model.evaluate(x_test, y_test, batch_size=64)\n",
"print(\"test loss, test acc:\", results)\n",
"\n",
"# Binary accuracy counts every (sample, tile) label. If most labels are 0 it looks\n",
"# high even for a useless model, so compare it with predicting 0 for every tile.\n",
"print(\"all-zero baseline acc:\", 1 - y_test.mean())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Predict for sample 1; slicing x[1:2] keeps the batch axis (shape (1, 20, 35)).\n",
"# NOTE(review): sample 1 may belong to the training split.\n",
"predictions = model.predict(x[1:2])[0]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot the predicted danger (sigmoid output) for each of the 34 tiles\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Mark the actual wait tiles of the same sample\n",
"mark_point = [i for i in range(34) if y[1][i]]\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(range(34), predictions, marker=\"o\", markevery=mark_point)\n",
"ax.set(\n",
"    title=\"Predicted danger per tile (markers: actual waits)\",\n",
"    xlabel=\"tile index\",\n",
"    ylabel=\"predicted probability\",\n",
")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment