Created
January 27, 2020 16:27
evaluate.ipynb
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.6" | |
}, | |
"colab": { | |
"name": "evaluate.ipynb", | |
"provenance": [], | |
"include_colab_link": true | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/trycycle/96a784f29b2e77b41f6d97ce4a2b4d0f/evaluate.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kJ5G5uCF00FP", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import pickle\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"from sklearn.svm import SVC\n", | |
"from sklearn.model_selection import StratifiedKFold\n", | |
"from sklearn.model_selection import cross_validate" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{
"cell_type": "markdown",
"metadata": {
"id": "2KvNmQnr00FV",
"colab_type": "text"
},
"source": [
"## Preparing the evaluation function\n",
"Evaluate precision and other metrics with cross-validation."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ypn3tezt00FV",
"colab_type": "code",
"colab": {}
},
"source": [
"def evaluate_with_cross_valiation(model, X, y, n_splits=5):\n",
"    \"\"\"Evaluate precision, recall, accuracy and F1 with cross-validation.\n",
"\n",
"    Returns the mean over the folds of every entry reported by cross_validate.\n",
"    \"\"\"\n",
"    # Evaluation metrics (macro-averaged over the classes)\n",
"    score_funcs = {\n",
"        'Precision': 'precision_macro',\n",
"        'Recall': 'recall_macro',\n",
"        'Accuracy': 'accuracy',\n",
"        'F1': 'f1_macro'\n",
"    }\n",
"\n",
"    # Estimate generalization performance with stratified k-fold\n",
"    kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=30)\n",
"\n",
"    scores = cross_validate(model, X, y, cv=kf, scoring=score_funcs)\n",
"    # Average every entry (timings and test scores) over the folds\n",
"    result = {metric: np.mean(values) for metric, values in scores.items()}\n",
"    return result"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "em0nTopB00FY",
"colab_type": "text"
},
"source": [
"## Preparing the vectors\n",
"The X_* arrays are the explanatory variables.\n",
"* X_tfidf is a feature vector that weights the comments in a scene by their TF-IDF values\n",
"* X_sub is a feature vector whose dimensions are the number of comments in the scene and the proportions of comments ending in w, ending in 笑, ending in 草, and ending in any of w, 笑 or 草 (all three are Japanese net-slang laughter markers)\n",
"* X is the concatenation of X_tfidf and X_sub"
]
},
{
"cell_type": "code",
"metadata": {
"id": "GaXbdTZ200FZ",
"colab_type": "code",
"colab": {}
},
"source": [
"with open(\"data/tfidf-vector.dat\", \"rb\") as f:\n",
"    X_tfidf = np.array(pickle.load(f))\n",
"\n",
"with open(\"data/sub-vector.dat\", \"rb\") as f:\n",
"    X_sub = np.array(pickle.load(f))\n",
"\n",
"X = np.hstack([X_tfidf, X_sub])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "x0N1dotQ00Fb",
"colab_type": "text"
},
"source": [
"y is the target variable. Specifically, a scene is labeled 1 if at least half of the crowdsourcing workers answered that it was \"interesting\", and 0 otherwise."
]
},
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QbsnO0FJ00Fc", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"testset_df = pd.read_csv('data/complete_testset.tsv', sep='\\t')\n", | |
"y = testset_df['is_interesting'].to_list()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{
"cell_type": "markdown",
"metadata": {
"id": "EDPRvuTi00Ff",
"colab_type": "text"
},
"source": [
"## Training & evaluation\n",
"A dataset of 455 scenes from 91 videos (66 of which were judged \"interesting\") is used to train and evaluate an interesting-scene classifier.\n",
"* The model is an SVM with an RBF kernel.\n",
"* In terms of precision, \"TF-IDF + sentence-final laughter expressions\" performs best."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCH0Bdhi00Fg",
"colab_type": "text"
},
"source": [
"### 1. Feature vector weighting the comments in a scene by TF-IDF"
]
},
{
"cell_type": "code",
"metadata": {
"id": "EKyC76Cd00Fg",
"colab_type": "code",
"outputId": "df2b4e75-da12-42ce-aa74-f7a3525bb4d7",
"colab": {}
},
"source": [
"model = SVC(kernel='rbf', C=1e3, gamma=0.1)\n",
"\n",
"# 2-fold cross validation\n",
"# X_tfidf is a list of vectors, y is a list of ground-truth labels\n",
"evaluate_with_cross_valiation(model, X_tfidf, y, n_splits=2)"
],
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{'fit_time': 0.17990326881408691,\n", | |
" 'score_time': 0.6486226320266724,\n", | |
" 'test_Precision': 0.5359217999126256,\n", | |
" 'test_Recall': 0.5010357340254248,\n", | |
" 'test_Accuracy': 0.8351978514568359,\n", | |
" 'test_F1': 0.48036346569892197}" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 96 | |
} | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {
"id": "9SrDGTLP00Fl",
"colab_type": "text"
},
"source": [
"### 2. Feature vector based on laughter markers at the end of comments"
]
},
{
"cell_type": "code",
"metadata": {
"id": "iEH4yE6V00Fl",
"colab_type": "code",
"outputId": "907fb7c6-f8fa-49e7-9091-180fd47b9133",
"colab": {}
},
"source": [
"# 2-fold cross validation\n",
"# X_sub is a list of vectors, y is a list of ground-truth labels\n",
"evaluate_with_cross_valiation(model, X_sub, y, n_splits=2)"
],
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{'fit_time': 0.010805368423461914,\n", | |
" 'score_time': 0.0034672021865844727,\n", | |
" 'test_Precision': 0.5314755224894894,\n", | |
" 'test_Recall': 0.5158045562684739,\n", | |
" 'test_Accuracy': 0.8066504366643481,\n", | |
" 'test_F1': 0.5108522814833494}" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 97 | |
} | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {
"id": "TodrXNv-00Fo",
"colab_type": "text"
},
"source": [
"### 3. Feature vector combining features 1 and 2"
]
},
{
"cell_type": "code",
"metadata": {
"id": "CCYJNAyT00Fo",
"colab_type": "code",
"outputId": "d8870561-2e97-4b9c-9f3b-509a6e161c36",
"colab": {}
},
"source": [
"# 2-fold cross validation\n",
"# X is a list of vectors, y is a list of ground-truth labels\n",
"evaluate_with_cross_valiation(model, X, y, n_splits=2)"
],
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{'fit_time': 0.17981934547424316,\n", | |
" 'score_time': 0.5898785591125488,\n", | |
" 'test_Precision': 0.5843308080808081,\n", | |
" 'test_Recall': 0.5048884963833418,\n", | |
" 'test_Accuracy': 0.8417864595409228,\n", | |
" 'test_F1': 0.4832971228631523}" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 99 | |
} | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {
"id": "-LoBgxs700Fr",
"colab_type": "text"
},
"source": [
"### 4. An algorithm that answers \"interesting\" no matter what it is given"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Z1GvWvYZ00Fs",
"colab_type": "code",
"outputId": "c50492bd-085f-4323-a626-79f09796f915",
"colab": {}
},
"source": [
"# Accuracy: always answering 'interesting' is correct exactly on the positive scenes\n",
"sum(y) / len(y)"
],
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.14505494505494507" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 100 | |
} | |
] | |
} | |
] | |
} |
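Section 4 of the notebook reports only the accuracy of the always-"interesting" baseline, while sections 1 to 3 report macro-averaged precision, recall and F1 as well. The following is a minimal sketch, not part of the original notebook, of how the same four metrics could be computed for constant baselines. It assumes synthetic labels that reproduce only the class balance stated in the notebook (66 positive scenes out of 455), and it treats any ratio with a zero denominator as 0.0. The helper `macro_scores` and the variable names are hypothetical.

```python
def macro_scores(y_true, y_pred, labels=(0, 1)):
    """Return accuracy and macro-averaged precision, recall and F1.

    Assumption of this sketch: a ratio with a zero denominator counts as 0.0.
    """
    precisions, recalls, f1s = [], [], []
    pairs = list(zip(y_true, y_pred))
    for label in labels:
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n_labels = len(labels)
    return {
        'Precision': sum(precisions) / n_labels,
        'Recall': sum(recalls) / n_labels,
        'Accuracy': sum(t == p for t, p in pairs) / len(pairs),
        'F1': sum(f1s) / n_labels,
    }


# Synthetic labels with the class balance stated in the notebook (66 of 455 positive).
y_demo = [1] * 66 + [0] * 389

# Constant predictors: always "interesting" and never "interesting".
always_interesting = macro_scores(y_demo, [1] * len(y_demo))
never_interesting = macro_scores(y_demo, [0] * len(y_demo))

print('always interesting:', always_interesting)
print('never interesting: ', never_interesting)
```

Under these assumptions, the never-"interesting" predictor reaches an accuracy of 389/455, which may be a more demanding yardstick for the accuracy values in sections 1 to 3 than the always-"interesting" baseline.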