Skip to content

Instantly share code, notes, and snippets.

@secsilm
Last active June 4, 2019 08:28
Show Gist options
  • Save secsilm/cb6eb2f599d4b56478438e2e88e83917 to your computer and use it in GitHub Desktop.
Save secsilm/cb6eb2f599d4b56478438e2e88e83917 to your computer and use it in GitHub Desktop.
使用 TensorFlow Estimators 和 TensorFlow Hub 对酒店评论进行情绪分类
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "4cx-1AuMzYCX"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\secsi\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n"
]
}
],
"source": [
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import tensorflow as tf\n",
"import tensorflow_hub as hub\n",
"from sklearn.model_selection import train_test_split\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "H7USxnbI7xLu"
},
"outputs": [],
"source": [
"# Configure seaborn so that figures default to 10x8 inches\n",
"default_rc = {'figure.figsize': (10, 8)}\n",
"sns.set(rc=default_rc)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "WTDdU2f_2u3e"
},
"outputs": [],
"source": [
"def read_nonempty_lines(path):\n",
"    \"\"\"Return the non-empty lines of the UTF-8 text file at `path`.\"\"\"\n",
"    raw_lines = Path(path).read_text(encoding='utf8').split('\\n')\n",
"    return [line for line in raw_lines if line]\n",
"\n",
"\n",
"# One review per line; blank lines are dropped\n",
"pos = read_nonempty_lines('pos.txt')\n",
"neg = read_nonempty_lines('neg.txt')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"colab_type": "code",
"id": "maCqkZQl3FTe",
"outputId": "3e8df2f0-57b0-4773-967c-921d1f411e9a"
},
"outputs": [
{
"data": {
"text/plain": [
"['1 距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较为简单.',\n",
" '1 商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Peek at the first two raw positive lines\n",
"pos[0:2]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 72
},
"colab_type": "code",
"id": "yIOUD4tA3e8O",
"outputId": "756abd5d-c6d4-4685-a0bb-7b1e31d22e8c"
},
"outputs": [
{
"data": {
"text/plain": [
"['-1 标准间太差房间还不如3星的而且设施非常陈旧.建议酒店把老的标准间从新改善.',\n",
" '-1 服务态度极其差,前台接待好象没有受过培训,连基本的礼貌都不懂,竟然同时接待几个客人;大堂副理更差,跟客人辩解个没完,要总经理的电话投诉竟然都不敢给。要是没有作什么亏心事情,跟本不用这么怕。']"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Peek at the first two raw negative lines\n",
"neg[0:2]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"colab_type": "code",
"id": "TDeVsJbg3geG",
"outputId": "ffcae6cc-9bc2-4391-c359-cff2cb9c1133"
},
"outputs": [
{
"data": {
"text/plain": [
"'积极样本数:7000,消极样本数:3000'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of positive / negative samples loaded (shown as the cell value)\n",
"sample_counts = f'积极样本数:{len(pos)},消极样本数:{len(neg)}'\n",
"sample_counts"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 219
},
"colab_type": "code",
"id": "qUjjOpYR33K2",
"outputId": "5fd773c6-bab6-400b-dad8-30b099a5879e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(10000, 2)\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>comment</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>早餐太差,无论去多少人,那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>宾馆在小街道上,不大好找,但还好北京热心同胞很多~宾馆设施跟介绍的差不多,房间很小,确实挺小...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" comment label\n",
"0 距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较... 1\n",
"1 商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错! 1\n",
"2 早餐太差,无论去多少人,那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。 1\n",
"3 宾馆在小街道上,不大好找,但还好北京热心同胞很多~宾馆设施跟介绍的差不多,房间很小,确实挺小... 1\n",
"4 CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风 1"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Raw lines look like '1 <review>' (positive) and '-1 <review>' (negative).\n",
"# Slice off the whole label token for each class: a fixed x[1:] only removes\n",
"# the '-' of '-1' and leaves a stray '1' at the start of every negative review,\n",
"# which leaks the label into the text the model is trained on.\n",
"POS_PREFIX, NEG_PREFIX = '1', '-1'\n",
"\n",
"pos_text = [line[len(POS_PREFIX):].strip() for line in pos]\n",
"pos_label = [1] * len(pos_text)\n",
"neg_text = [line[len(NEG_PREFIX):].strip() for line in neg]\n",
"neg_label = [0] * len(neg_text)  # negatives are encoded as class 0\n",
"\n",
"dataset = pd.DataFrame({'comment': pos_text + neg_text, 'label': pos_label + neg_label}, columns=['comment', 'label'])\n",
"print(dataset.shape)\n",
"dataset.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 496
},
"colab_type": "code",
"id": "K1EUC3DV3_HQ",
"outputId": "7affdd86-4367-40c8-ca61-41e972b9f7ae"
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.axes._subplots.AxesSubplot at 0x1e2b9dcfc88>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAHNCAYAAAAt526PAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAF9VJREFUeJzt3F9s1Qf9//HXod2coyVAxAsyR0CHER1bakPUFIyLExNjNIoCM3qB0WTZnCSbwnDAyBgFza8XDp1Ts5spOvHf19+lg03GWGAhoq6ZM+rEOXRBwYx2czDO+d38xndI18LmO8eWx+PufM67Oe/PxTl59vM5baPVarUCAMB/1KR2LwAAMBGJLACAAiILAKCAyAIAKCCyAAAKdLZ7gX93+PCxdq/AODJt2sU5evTZdq8BTDA+WzhbM2Z0v+xzrmQxrnV2drR7BWAC8tnCf4LIAgAoILIAAAqILACAAiILAKCAyAIAKCCyAAAKiCwAgAIiCwCggMgCACggsgAACogsAIACIgsAoIDIAgAoILIAAAqILACAAiILAKBA51gDP/7xj/OTn/wkSfL888/nscceyz333JPbb789HR0d6evry/XXX59ms5lbb701jz/+eC688MJs3Lgxs2bNyoEDB86YBQCY6MaMrI985CP5yEc+kiTZsGFDPvrRj2b9+vW544478oY3vCGf/exnMzg4mKeeeirHjx/PvffemwMHDmTz5s258847R5x961vfWn5iAADtdNa3C3/zm9/k97//fT7wgQ/k+PHjufTSS9NoNNLX15eHH344+/fvz8KFC5MkV155ZR599NEMDQ2NOAsAMNGNeSXrRXfddVeuu+66DA0Npaur69TxyZMn58knnzzjeEdHx8vOjmbatIvT2dlxLudwXvjgjf/T7hUYR/7v//lQu1eAcW/GjO52r8A4d1aR9cwzz+SPf/xj3vGOd2RoaCjDw8OnnhseHs6UKVPyr3/967TjzWYzXV1dI86O5ujRZ8/1HIB/c/jwsXavAOPajBnd3kecldFi/KxuFz7yyCN517velSTp6urKBRdckD//+c9ptVrZvXt3ent709PTk127diVJDhw4kLlz577sLADARHdWV7KeeOKJXHLJJaceb9iwITfddFNOnjyZvr6+XHHFFbn88svz0EMPZdmyZWm1Wtm0adPLzgIATHSNVqvVavcSL+Xy7MhWbN7Z7hUYR+5efVW7V4Bxze1Cztarvl0IAMC5EVkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFCg82yG7rrrruzcuTMnTpzI8uXLs2DBgqxevTqNRiOXXXZZ1q9fn0mTJmXr1q154IEH0tnZmTVr1mT+/Pk5ePDgiLMAABPZmLWzd+/e/PKXv8z3vve93HPPPfnb3/6W/v7+rFy5Mtu2bUur1cqOHTsyODiYffv2Zfv27RkYGMiGDRuSZMRZAICJbswrWbt3787cuXNz3XXXZWhoKF/84hfzgx/8IAsWLEiSLFq0KA899FBmz56dvr6+NBqNzJw5MydPnsyRI0cyODh4xuzVV1/9sq83bdrF6ezs+A+dHpyfZszobvcKMO55H/FqjRlZR48ezaFDh/KNb3wjf/nLX3Lttdem1Wql0WgkSSZPnpxjx45laGgoU6dOPfVzLx4faXb013v21ZwPkOTw4d
HfZ8DoZszo9j7irIwW42NG1tSpUzNnzpxceOGFmTNnTl7zmtfkb3/726nnh4eHM2XKlHR1dWV4ePi0493d3ad9/+rFWQCAiW7M72S9/e1vz4MPPphWq5Wnn346zz33XN75zndm7969SZJdu3alt7c3PT092b17d5rNZg4dOpRms5np06dn3rx5Z8wCAEx0Y17Jes973pNHHnkkS5YsSavVyrp163LJJZdk7dq1GRgYyJw5c7J48eJ0dHSkt7c3S5cuTbPZzLp165Ikq1atOmMWAGCia7RarVa7l3gp98BHtmLzznavwDhy9+qr2r0CjGu+k8XZGu07Wf5hFQBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFOg8m6EPf/jD6e7uTpJccsklWbp0aW6//fZ0dHSkr68v119/fZrNZm699dY8/vjjufDCC7Nx48bMmjUrBw4cOGMWAGCiGzOynn/++STJPffcc+rYhz70odxxxx15wxvekM9+9rMZHBzMU089lePHj+fee+/NgQMHsnnz5tx5551Zv379GbNvfetb684IAOC/wJiR9dvf/jbPPfdcVqxYkRdeeCGf+9zncvz48Vx66aVJkr6+vjz88MM5fPhwFi5cmCS58sor8+ijj2ZoaGjE2dEia9q0i9PZ2fGfODc4b82Y0d3uFWDc8z7i1Rozsi666KJ8+tOfzsc+9rH86U9/ymc+85lMmTLl1POTJ0/Ok08+maGhoXR1dZ063tHRccaxF2dHc/Tos6/kPICXOHz4WLtXgHFtxoxu7yPOymgxPmZkzZ49O7NmzUqj0cjs2bPT3d2df/7zn6eeHx4ezpQpU/Kvf/0rw8PDp443m810dXWdduzFWQCAiW7Mvy784Q9/mM2bNydJnn766Tz33HO5+OKL8+c//zmtViu7d+9Ob29venp6smvXriTJgQMHMnfu3HR1deWCCy44YxYAYKIb80rWkiVLcvPNN2f58uVpNBrZtGlTJk2alJtuuiknT55MX19frrjiilx++eV56KGHsmzZsrRarWzatClJsmHDhjNmAQAmukar1Wq1e4mXcg98ZCs272z3Cowjd6++qt0rwLjmO1mcrdG+k+WfkQIAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQIGziqx//OMfefe7350//OEPOXjwYJYvX55rrrkm69evT7PZTJJs3bo1S5YsybJly/LrX/86SV52FgBgohszsk6cOJF169bloosuSpL09/dn5cqV2bZtW1qtVnbs2JHBwcHs27cv27dvz8DAQDZs2PCyswAA54MxI2vLli1ZtmxZXv/61ydJBgcHs2DBgiTJokWLsmfPnuzfvz99fX1pNBqZOXNmTp48mSNHjow4CwBwPugc7ckf//jHmT59ehYuXJhvfvObSZJWq5VGo5EkmTx5co4dO5
ahoaFMnTr11M+9eHyk2bFMm3ZxOjs7XvEJAcmMGd3tXgHGPe8jXq1RI+tHP/pRGo1GHn744Tz22GNZtWpVjhw5cur54eHhTJkyJV1dXRkeHj7teHd3dyZNmnTG7FiOHn32lZwH8BKHD4/9Cw3w8mbM6PY+4qyMFuOj3i787ne/m+985zu555578pa3vCVbtmzJokWLsnfv3iTJrl270tvbm56enuzevTvNZjOHDh1Ks9nM9OnTM2/evDNmAQDOB6NeyRrJqlWrsnbt2gwMDGTOnDlZvHhxOjo60tvbm6VLl6bZbGbdunUvOwsAcD5otFqtVruXeCmXZ0e2YvPOdq/AOHL36qvavQKMa24XcrZe8e1CAABeGZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFOgca+DkyZO55ZZb8sQTT6SjoyP9/f1ptVpZvXp1Go1GLrvssqxfvz6TJk3K1q1b88ADD6SzszNr1qzJ/Pnzc/DgwRFnAQAmsjFr5/7770+SfP/7388NN9yQ/v7+9Pf3Z+XKldm2bVtarVZ27NiRwcHB7Nu3L9u3b8/AwEA2bNiQJCPOAgBMdGNG1nvf+97cdtttSZJDhw7lda97XQYHB7NgwYIkyaJFi7Jnz57s378/fX19aTQamTlzZk6ePJkjR46MOAsAMNGNebswSTo7O7Nq1ar8/Oc/z1e/+tXcf//9aTQaSZLJkyfn2LFjGRoaytSpU0/9zIvHW63WGbOjmTbt4nR2drzS8wGSzJjR3e4VGCc+fu+17V6BceQHS+9s9wrjyllFVpJs2bIlN910Uz7+8Y/n+eefP3V8eHg4U6ZMSVdXV4aHh0873t3dfdr3r16cHc3Ro8+ey/7ACA4fHv2XGYBXwmfLmUb7pXbM24U//elPc9dddyVJXvva16bRaORtb3tb9u7dmyTZtWtXent709PTk927d6fZbObQoUNpNpuZPn165s2bd8YsAMBEN+aVrPe97325+eab84lPfCIvvPBC1qxZkze+8Y1Zu3ZtBgYGMmfOnCxevDgdHR3p7e3N0qVL02w2s27duiTJqlWrzpgFAJjoGq1Wq9XuJV7KpciRrdi8s90rMI7cvfqqdq/AOHHdzi+2ewXGka9d9eV2r/Bf51XdLgQA4NyJLACAAiILAKCAyAIAKCCyAAAKiCwAgAIiCwCggMgCACggsgAACogsAIACIgsAoIDIAgAoILIAAAqILACAAiILAKCAyAIAKCCyAAAKiCwAgAIiCwCggMgCACggsgAACogsAIACIgsAoIDIAgAoILIAAAqILACAAiILAKCAyAIAKCCyAAAKiCwAgAIiCwCggMgCACggsgAACogsAIACIgsAoIDIAgAoILIAAAqILACAAiILAKCAyAIAKCCyAAAKiCwAgAIiCwCggMgCACggsgAACogsAIACIgsAoIDIAgAoILIAAAqILACAAiILAKBA52hPnjhxImvWrMlTTz2V48eP59prr82b3vSmrF69Oo1GI5dddlnWr1+fSZMmZevWrXnggQfS2dmZNWvWZP78+Tl48OCIswAAE92oxfOzn/0sU6dOzbZt2/Ktb30rt912W/r7+7Ny5cps27YtrVYrO3bsyODgYPbt25ft27dnYGAgGzZsSJIRZwEAzgejRtb73//+fP7znz/1uKOjI4ODg1mwYEGSZNGiRdmzZ0/279+fvr6+NBqNzJw5MydPns
yRI0dGnAUAOB+Mertw8uTJSZKhoaHccMMNWblyZbZs2ZJGo3Hq+WPHjmVoaChTp0497eeOHTuWVqt1xuxYpk27OJ2dHa/4hIBkxozudq8ATEA+W87NqJGVJH/9619z3XXX5ZprrskHP/jBfOUrXzn13PDwcKZMmZKurq4MDw+fdry7u/u071+9ODuWo0efPddzAP7N4cNj/0IDcK58tpxptPAc9Xbh3//+96xYsSJf+MIXsmTJkiTJvHnzsnfv3iTJrl270tvbm56enuzevTvNZjOHDh1Ks9nM9OnTR5wFADgfjHol6xvf+EaeeeaZfP3rX8/Xv/71JMmXvvSlbNy4MQMDA5kzZ04WL16cjo6O9Pb2ZunSpWk2m1m3bl2SZNWqVVm7du1pswAA54NGq9VqtXuJl3IpcmQrNu9s9wqMI3evvqrdKzBOXLfzi+1egXHka1d9ud0r/Nd5xbcLAQB4ZUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUOCsIutXv/pVPvnJTyZJDh48mOXLl+eaa67J+vXr02w2kyRbt27NkiVLsmzZsvz6178edRYAYKIbM7K+9a1v5ZZbbsnzzz+fJOnv78/KlSuzbdu2tFqt7NixI4ODg9m3b1+2b9+egYGBbNiw4WVnAQDOB2NG1qWXXpo77rjj1OPBwcEsWLAgSbJo0aLs2bMn+/fvT19fXxqNRmbOnJmTJ0/myJEjI84CAJwPOscaWLx4cf7yl7+cetxqtdJoNJIkkydPzrFjxzI0NJSpU6eemnnx+EizY5k27eJ0dnac84kA/2vGjO52rwBMQD5bzs2YkfXvJk3634tfw8PDmTJlSrq6ujI8PHza8e7u7hFnx3L06LPnuhLwbw4fHvsXGoBz5bPlTKOF5zn/deG8efOyd+/eJMmuXbvS29ubnp6e7N69O81mM4cOHUqz2cz06dNHnAUAOB+c85WsVatWZe3atRkYGMicOXOyePHidHR0pLe3N0uXLk2z2cy6detedhYA4HzQaLVarXYv8VIuRY5sxead7V6BceTu1Ve1ewXGiet2frHdKzCOfO2qL7d7hf86/9HbhQAAjE1kAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAAZEFAFBAZAEAFBBZAAAFRBYAQAGRBQBQQGQBABQQWQAABUQWAEABkQUAUEBkAQAUEFkAAAVEFgBAgc7qF2g2m7n11lvz+OOP58ILL8zGjRsza9as6pcFAGir8itZ9913X44fP5577703N954YzZv3lz9kgAAbddotVqtyhfo7+/P/Pnz84EPfCBJsnDhwjz44IOVLwkA0HblV7KGhobS1dV16nFHR0deeOGF6pcFAGir8sjq6urK8PDwqcfNZjOdneVfBQMAaKvyyOrp6cmuXbuSJAcOHMjcuXOrXxIAoO3Kv5P14l8X/u53v0ur1cqmTZvyxje+sfIlAQDarjyyAADOR/4ZKQBAAZEFAFBAZAEAFBBZAPD/NZvNdq/ABO
IfVgFwXnvyySfT39+fRx99NJ2dnWk2m5k7d25uvvnmzJ49u93rMY7560IAzmuf+tSncuONN+aKK644dezAgQPZvHlzvv/977dxM8Y7V7IYNz75yU/mxIkTpx1rtVppNBo+CIFX7Pjx46cFVpJceeWVbdqGiURkMW7cdNNNueWWW/K1r30tHR0d7V4HmCDe/OY35+abb87ChQvT3d2d4eHh/OIXv8ib3/zmdq/GOOd2IePKt7/97cyaNStXX311u1cBJohWq5X77rsv+/fvz9DQULq6utLT05Orr746jUaj3esxjoksAIAC/oUDAEABkQUAUEBkAQAUEFkAAAX+H89/TXfeU2sEAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 720x576 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Bar chart of how many reviews carry each label\n",
"label_counts = dataset['label'].value_counts()\n",
"label_counts.plot.bar()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "4bK_bTB6773U"
},
"outputs": [],
"source": [
"# Hold out 30% of the reviews for testing, using a fixed random_state for the split\n",
"TEST_FRACTION = 0.3\n",
"SPLIT_SEED = 73\n",
"train_df, test_df = train_test_split(\n",
"    dataset, test_size=TEST_FRACTION, random_state=SPLIT_SEED)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 640
},
"colab_type": "code",
"id": "pT0EBLjV4AEO",
"outputId": "a13e7843-52d1-4aa8-fe21-ae272db5d688"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABIcAAAJICAYAAAD7H4TAAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzs3W2QnXV9//HPyS4Iye5OkhKsKZAmglPD7cQtUl3SEcTFVqpiIDcIlSBWisG0BQMxN8QEEqquo4kBjfVJMBUitjKjU+tEMBOTJk6YQFmBWgtBCDrBRJLdQG72nP+DDlvShAT7l5yz+b1ejzjX/g75XjPnylzzzvU7W6nVarUAAAAAUKRB9R4AAAAAgPoRhwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDN9R4AOLosWLAgP/nJT5IkP//5z/MHf/AHOe6445Ik99xzT/9/H8qqVauybt26zJo163WdFQDgaPK7uA97Wa1Wy9VXX50vfelLaWtre13mBRpHpVar1eo9BHB0uuCCC/LFL34xZ555Zr1HAQAoyv/vfdi+ffty+umn5yc/+Yk4BAXw5BBwxJxxxhm58MIL8/jjj+dzn/tcnnjiidxzzz3Zu3dvXnjhhVx77bWZMmVKvv3tb+f73/9+vvKVr+TKK6/MOeeck4ceeijPPfdc/uRP/iTz58/PoEF2xQIAvFY/+9nPctttt2XHjh3p6+vLRz7ykXzwgx9MT09Pbrnlljz99NMZNGhQzjzzzMybNy+33HJLkuSKK67I1772tbzxjW+s8xkArydxCDhi9u7dm3e961354he/mN7e3ixYsCBf/epXM2zYsGzatClXX311pkyZcsD7nn766Sxfvjy7du3Ke9/73mzYsCHnnXdeHc4AAGDg2bt3bz75yU+mq6srf/RHf5QdO3bk8ssvz6mnnpr/+I//yJ49e/Kd73wn+/bty5w5c/LMM89k4cKFuf/++/ONb3zDk0NQAHEIOKLa29uTJEOGDMldd92VH/3oR3nqqafy+OOPZ9euXQd9z7ve9a4MGjQoLS0tGTVqVF544YUjOTIAwID285//PL/4xS8yY8aM/mN79uzJY489lvPOOy9f/OIXc9VVV+Ud73hHrrnmmpx88snZt29fHScGjjRxCDiiBg8enCT55S9/mYkTJ+byyy/P2972tlx88cV54IEHDvqeV355YqVSia9KAwB47arVaoYOHZrvfOc7/ce2bt2atra2vOENb8gPfvCDrF+/Pv/2b/+Wv/zLv8xtt92Wd77znXWcGDjSfGkHUBePPvpohg8fnr/+679OR0dHfxjq6+ur82QAAEeXU089NYMGDcp3v/vdJMmzzz6b973vfXn88cezfPnyzJ49O+eff34+9alP5bzzzstPf/rTNDU1pVKpeIIICiEOAXXxzne+M2984xtz8cUX573vfW+ee+65DB8+PJs3b673aAAAR5Vjjz02d955Z/7xH/8xl1xyST760Y/m7/7u73L22Wfngx/8YF566aX8+Z//eS699NLs3r07V1xxRSqVSt7znvdk8uTJ+fnPf17vUwBeZ36VPQAAAEDBPDkEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFa673AP/b1q076z0CdTJs2OBs376r3mMAR5DrvlwjRrTWewT+F/dg5fJ3MZTHdV+mQ91/eXKIhtHc3FTvEYAjzHUPUH/+LobyuO7538QhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwZpfy6IPfOADaW1tTZKcdNJJmThxYm677bY0NTWlo6Mjn/jEJ1KtVnPrrbfmiSeeyLHHHpsFCxZk1KhR2bRp0wFrAQAAAGgMh41Du3fvTpIsX768/9j73//+LF68OCeffHI+9rGPpbu7O88++2z27N
mTe+65J5s2bcqiRYty5513Zu7cuQesPf3001+/MwIAAADgNTtsHHr88cfz4osvZurUqdm3b1+mTZuWPXv25JRTTkmSdHR0ZN26ddm6dWvOP//8JMk555yTRx99ND09PQddKw4BAAAANIbDxqHjjjsu11xzTS677LI89dRTufbaa9PW1tb/8yFDhuQXv/hFenp60tLS0n+8qanpgGMvrz2UYcMGp7m56f9yLhwFRoxorfcIwBHmugcAgPo6bBwaPXp0Ro0alUqlktGjR6e1tTW/+c1v+n/e29ubtra2vPTSS+nt7e0/Xq1W09LSst+xl9ceyvbtu/4v58FRYMSI1mzdurPeYwBHkOu+XKIgAEDjOOxvK/vWt76VRYsWJUl+9atf5cUXX8zgwYPz9NNPp1arZc2aNWlvb8+4ceOyevXqJMmmTZvylre8JS0tLTnmmGMOWAsAAABAYzjsk0MTJkzILbfcksmTJ6dSqeT222/PoEGDcuONN6avry8dHR05++yzc+aZZ+bHP/5xJk2alFqtlttvvz1JMm/evAPWAgAAANAYKrVarVbvIV7J9oJy2V4C5XHdl8u2ssbjWiyXv4uhPK77Mh3q/uuw28oAAAAAOHqJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBgzfUegP1NXfTDeo9AnXz95gvqPQIAFMn9V7ncfwH8N08OAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABSsud4DAACwv71792bmzJl59tlns2fPnlx33XU59dRTc/PNN6dSqeS0007L3LlzM2jQoCxZsiQPPvhgmpubM3PmzJx11lnZvHnzQdcCAByMuwQAgAZz//33Z+jQoVmxYkWWLVuW+fPnZ+HChZk+fXpWrFiRWq2WVatWpbu7Oxs2bMjKlSvT1dWVefPmJclB1wIAvBpxCACgwVx88cX55Cc/2f+6qakp3d3dOffcc5Mk48ePz9q1a7Nx48Z0dHSkUqlk5MiR6evry7Zt2w66FgDg1dhWBgDQYIYMGZIk6enpyQ033JDp06fnjjvuSKVS6f/5zp0709PTk6FDh+73vp07d6ZWqx2w9nCGDRuc5uam1+FsoHGNGNFa7xGgbnz+eSVxCACgAT333HO5/vrrM2XKlFxyySX57Gc/2/+z3t7etLW1paWlJb29vfsdb21t3e/7hV5eezjbt+/63Z4ADABbtx4+nMLRaMSIVp//Ah0qCNpWBgDQYJ5//vlMnTo1N910UyZMmJAkGTt2bNavX58kWb16ddrb2zNu3LisWbMm1Wo1W7ZsSbVazfDhww+6FgDg1XhyCACgwdx1113ZsWNHli5dmqVLlyZJPv3pT2fBggXp6urKmDFj0tnZmaamprS3t2fixImpVquZM2dOkmTGjBmZPXv2fmsBAF5NpVar1eo9xCuV/mjb1EU/rPcI1MnXb76g3iPAEeeR5nL5noPGU/K16P6rXO6/KJV7sDLZVgYAAADAQYlDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEA
AAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAAChYc70HAADg4B5++OF87nOfy/Lly/M3f/M3ef7555Mkzz77bM4+++x84QtfyMc//vH85je/yTHHHJM3vOEN+drXvpbNmzfn5ptvTqVSyWmnnZa5c+dm0CD/JggAHJw4BADQgJYtW5b7778/xx9/fJLkC1/4QpLkhRdeyFVXXZVbbrklSfL000/nu9/9biqVSv97Fy5cmOnTp+ftb3975syZk1WrVuWiiy468icBAAwI/gkJAKABnXLKKVm8ePEBxxcvXpwPf/jDOfHEE/P8889nx44d+fjHP57JkyfngQceSJJ0d3fn3HPPTZKMHz8+a9euPaKzAwADiyeHAAAaUGdnZ5555pn9jv3617/OunXr+p8a2rt3b6ZOnZqrrroqL7zwQiZPnpyzzjortVqt/0miIUOGZOfOnYf984YNG5zm5qbf/YlAAxsxorXeI0Dd+PzzSq8pDv3617/OpZdemq9//etpbm4+6B72JUuW5MEHH0xzc3NmzpyZs846y353AIDfoX/5l3/J+973vjQ1/XfEOeGEEzJp0qQ0Nzfn937v9/LWt741Tz755H73W729vWlrazvs/3v79l2v29zQqLZuPXw4haPRiBGtPv8FOlQQPGyp2bt3b+bMmZPjjjsuyf/sYV+xYkVqtVpWrVqV7u7ubNiwIStXrkxXV1fmzZv3qmsBAPi/WbduXcaPH9//eu3atZk+fXqS/45AP/vZzzJmzJiMHTs269evT5KsXr067e3tdZkXABgYDhuH7rjjjkyaNCknnnhikoPvYd+4cWM6OjpSqVQycuTI9PX1Zdu2bfa7AwD8Dj355JM5+eST+1//6Z/+aUaNGpXLL78811xzTf72b/82w4cPz4wZM7J48eJMnDgxe/fuTWdnZx2nBgAa3SG3lX3729/O8OHDc/755+erX/1qkhx0D3tPT0+GDh3a/76Xj9vvDq+dPb+UymcfXt1JJ52Ue++9t//1d7/73QPWfPrTnz7g2OjRo3P33Xe/rrMBAEePQ8ah++67L5VKJevWrctjjz2WGTNmZNu2bf0/f3kPe0tLS3p7e/c73traar87/Bbs+aVE9ruXSxQEAGgch9xW9o1vfCN33313li9fnre+9a254447Mn78+AP2sI8bNy5r1qxJtVrNli1bUq1WM3z4cPvdAQAAABrcb/2r7GfMmJHZs2enq6srY8aMSWdnZ5qamtLe3p6JEyemWq1mzpw5r7oWAAAAgMZRqdVqtXoP8Uqlby+YuuiH9R6BOvn6zRfUewQ44mwrK5dtZY2n5GvR/Ve53H9RKvdgZfr/+lX2AAAAABy9xCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAgAb18MMP58orr0ySdH
d35/zzz8+VV16ZK6+8Mt/73veSJEuWLMmECRMyadKkPPLII0mSzZs3Z/LkyZkyZUrmzp2barVat3MAABpfc70HAADgQMuWLcv999+f448/Pkny05/+NFdffXWmTp3av6a7uzsbNmzIypUr89xzz2XatGm57777snDhwkyfPj1vf/vbM2fOnKxatSoXXXRRvU4FAGhwnhwCAGhAp5xyShYvXtz/+tFHH82DDz6YK664IjNnzkxPT082btyYjo6OVCqVjBw5Mn19fdm2bVu6u7tz7rnnJknGjx+ftWvX1us0AIABwJNDAAANqLOzM88880z/67POOiuXXXZZzjjjjNx555358pe/nNbW1gwdOrR/zZAhQ7Jz587UarVUKpX9jh3OsGGD09zc9Ls/EWhgI0a01nsEqBuff15JHAIAGAAuuuiitLW19f/3/Pnzc+GFF6a3t7d/TW9vb1pbWzNo0KD9jr38vkPZvn3X735oaHBbtx4+nMLRaMSIVp//Ah0qCNpWBgAwAFxzzTX9Xzi9bt26nH766Rk3blzWrFmTarWaLVu2pFqtZvjw4Rk7dmzWr1+fJFm9enXa29vrOToA0OA8OQQAMADceuutmT9/fo455piccMIJmT9/flpaWtLe3p6JEyemWq1mzpw5SZIZM2Zk9uzZ6erqypgxY9LZ2Vnn6QGARlap1Wq1eg/xSqU/2jZ10Q/rPQJ18vWbL6j3CHDEeaS5XL7noPGUfC26/yqX+y9K5R6sTLaVAQAAAHBQ4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBgzYdb0NfXl1mzZuXJJ59MU1NTFi5cmFqtlptvvjmVSiWnnXZa5s6dm0GDBmXJkiV58MEH09zcnJkzZ+ass87K5s2bD7oWAAAAgPo7bKV54IEHkiTf/OY3c8MNN2ThwoVZuHBhpk+fnhUrVqRWq2XVqlXp7u7Ohg0bsnLlynR1dWXevHlJctC1AAAAADSGw8ahd7/73Zk/f36SZMuWLTnhhBPS3d2dc889N0kyfvz4rF27Nhs3bkxHR0cqlUpGjhyZvr6+bNu27aBrAQAAAGgMh91WliTNzc2ZMWNGfvCDH+RLX/pSHnjggVQqlSTJkCFDsnPnzvT09GTo0KH973n5eK1WO2DtoQwbNjjNzU3/1/OBAWvEiNZ6jwB14bMPAAD19ZriUJLccccdufHGG3P55Zdn9+7d/cd7e3vT1taWlpaW9Pb27ne8tbV1v+8XenntoWzfvuu3mR+OGlu3HjqcwtFoxIhWn/1CiYIAAI3jsNvK/vmf/zlf+cpXkiTHH398KpVKzjjjjKxfvz5Jsnr16rS3t2fcuHFZs2ZNqtVqtmzZkmq1muHDh2fs2LEHrAUAAACgMRz2yaH3vOc9ueWWW3LFFVdk3759mTlzZt785jdn9uzZ6erqypgxY9LZ2Zmmpqa0t7dn4sSJqVarmTNnTpJkxowZB6wFAAAAoDFUarVard5DvFLp2wumLvphvUegTr5+8wX1HgGOONvKymVbWeMp+Vp0/1Uu91+Uyj1YmQ51/3XYbWUAAAAAHL3EIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAU7LC/yh4AgPp4+OGH87nPfS7Lly/PY489lvnz56epqSnHHnts7rjjjp
xwwglZsGBBHnrooQwZMiRJsnTp0uzduzc33nhjXnrppZx44olZuHBhjj/++DqfDQDQqDw5BADQgJYtW5ZZs2Zl9+7dSZLbbrsts2fPzvLly3PRRRdl2bJlSZLu7u587Wtfy/Lly7N8+fK0trZm6dKled/73pcVK1Zk7Nixueeee+p5KgBAg/PkEABAAzrllFOyePHifOpTn0qSdHV15cQTT0yS9PX15Q1veEOq1Wo2b96cOXPm5Pnnn8+ECRMyYcKEbNy4MX/1V3+VJBk/fny6urrykY985JB/3rBhg9Pc3PS6nhM0mhEjWus9AtSNzz+vJA4BADSgzs7OPPPMM/2vXw5DDz30UO6+++584xvfyK5du/LhD384V199dfr6+nLVVVfljDPOSE9PT1pb//umf8iQIdm5c+dh/7zt23e9PicCDWzr1sNfG3A0GjGi1ee/QIcKguIQAMAA8b3vfS933nlnvvrVr2b48OH9Qejl7xM677zz8vjjj6elpSW9vb057rjj0tvbm7a2tjpPDgA0Mt85BAAwAHznO9/J3XffneXLl+fkk09Okjz11FOZMmVK+vr6snfv3jz00EM5/fTTM27cuPzoRz9KkqxevTpve9vb6jk6ANDgPDkEANDg+vr6ctttt+VNb3pTpk2bliT54z/+49xwww255JJLcvnll+eYY47J+9///px22mm57rrrMmPGjNx7770ZNmxYPv/5z9f5DACARiYOAQA0qJNOOin33ntvkmTDhg0HXXPttdfm2muv3e/YCSeckH/4h3943ecDAI4OtpUBAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgolDAAAAAAUThwAAAAAKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAgAb18MMP58orr0ySbN68OZMnT86UKVMyd+7cVKvVJMmSJUsyYcKETJo0KY888sgh1wIAHIw4BADQgJYtW5ZZs2Zl9+7dSZKFCxdm+vTpWbFiRWq1WlatWpXu7u5s2LAhK1euTFdXV+bNm/eqawEAXo04BADQgE455ZQsXry4/3V3d3fOPffcJMn48eOzdu3abNy4MR0dHalUKhk5cmT6+vqybdu2g64FAHg1zfUeAACAA3V2duaZZ57pf12r1VKpVJIkQ4YMyc6dO9PT05OhQ4f2r3n5+MHWHs6wYYPT3Nz0Oz4LaGwjRrTWewSoG59/XkkcAgAYAAYN+p8Hvnt7e9PW1paWlpb09vbud7y1tfWgaw9n+/Zdv9uBYQDYuvXw4RSORiNGtPr8F+hQQdC2MgCAAWDs2LFZv359kmT16tVpb2/PuHHjsmbNmlSr1WzZsiXVajXDhw8/6FoAgFfjySEAgAFgxowZmT17drq6ujJmzJh0dnamqakp7e3tmThxYqrVaubMmfOqawEAXk2lVqvV6j3EK5X+aNvURT+s9wjUyddvvqDeI8AR55Hmcvmeg8ZT8rXo/qtc7r8olXuwMtlWBgAAAMBBiUMAAAAABROHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFAwcQgAAACgYOIQAAAAQMHEIQAAAICCiUMAAAAABR
OHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFAwcQgAAACgYM31HgCgZNf/8FP1HoE6+fIFf1/vEQAAIIknhwAAAACKJg4BAAAAFEwcAgAAACiYOAQAAABQMHEIAAAAoGDiEAAAAEDBxCEAAACAgjXXewAAAAA4kq7/4afqPQJ18uUL/r7eIzSkQ8ahvXv3ZubMmXn22WezZ8+eXHfddTn11FNz8803p1Kp5LTTTsvcuXMzaNCgLFmyJA8++GCam5szc+bMnHXWWdm8efNB1wIAAADQGA5Zau6///4MHTo0K1asyLJlyzJ//vwsXLgw06dPz4oVK1Kr1bJq1ap0d3dnw4YNWblyZbq6ujJv3rwkOehaAAAAABrHIePQxRdfnE9+8pP9r5uamtLd3Z1zzz03STJ+/PisXbs2GzduTEdHRyqVSkaOHJm+vr5s27btoGsBAAAAaByH3FY2ZMiQJElPT09uuOGGTJ8+PXfccUcqlUr/z3fu3Jmenp4MHTp0v/ft3LkztVrtgLWHM2zY4DQ3N/2fTwgGqhEjWus9AnAEueYBAGgUh/1C6ueeey7XX399pkyZkksuuSSf/exn+3/W29ubtra2tLS0pLe3d7/jra2t+32/0MtrD2f79l2/7TnAUWHr1sPHU+DoUfo1L44BADSOQ24re/755zN16tTcdNNNmTBhQpJk7NixWb9+fZJk9erVaW9vz7hx47JmzZpUq9Vs2bIl1Wo1w4cPP+haAAAAABrHIZ8cuuuuu7Jjx44sXbo0S5cuTZJ8+tOfzoIFC9LV1ZUxY8aks7MzTU1NaW9vz8SJE1OtVjNnzpwkyYwZMzJ79uz91gIAAADQOA4Zh2bNmpVZs2YdcPzuu+8+4Ni0adMybdq0/Y6NHj36oGsBAAAAaAyH3FYGAAAAwNFNHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwZrrPQAAAK/Nt7/97fzTP/1TkmRrB5hKAAAOiUlEQVT37t157LHH8vnPfz5///d/nze96U1JkmnTpqW9vT233nprnnjiiRx77LFZsGBBRo0aVc/RAYAGJg4BAAwQl156aS699NIkybx58/KhD30o3d3duemmm9LZ2dm/7l//9V+zZ8+e3HPPPdm0aVMWLVqUO++8s15jAwANzrYyAIAB5t///d/zn//5n5k4cWK6u7tz3333ZcqUKVm0aFH27duXjRs35vzzz0+SnHPOOXn00UfrPDEA0Mg8OQQAMMB85StfyfXXX58keec735l3v/vdOemkkzJ37tx885vfTE9PT1paWvrXNzU1Zd++fWlufvVbv2HDBqe5uel1nx0ayYgRrfUeATjCXPcHJw4BAAwgO3bsyH/913/lvPPOS5J86EMfSltbW5LkwgsvzPe///20tramt7e3/z3VavWQYShJtm/f9foNDQ1q69ad9R4BOMJKvu4PFcZsKwMAGEB+8pOf5B3veEeSpFar5S/+4i/yy1/+Mkmybt26nH766Rk3blxWr16dJNm0aVPe8pa31G1eAKDxeXIIAGAAefLJJ3PSSSclSSqVShYsWJBPfOITOe644/LmN785l19+eZqamvLjH/84kyZNSq1Wy+23317nqQGARiYOAQAMIB/96Ef3e93R0ZGOjo4D1n3mM585UiMBAAOcbWUAAAAABROHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFAwcQgAAACgYOIQAAAAQMHEIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFAwcQgAAACgYOIQAAAAQMHEIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAUTBwCAA
AAKJg4BAAAAFAwcQgAAACgYOIQAAAAQMHEIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFAwcQgAAACgYOIQAAAAQMHEIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFAwcQgAAACgYOIQAAAAQMHEIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFAwcQgAAACgYOIQAAAAQMHEIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFCw5noPAADAa/eBD3wgra2tSZKTTjopEydOzG233ZampqZ0dHTkE5/4RKrVam699dY88cQTOfbYY7NgwYKMGjWqzpMDAI1KHAIAGCB2796dJFm+fHn/sfe///1ZvHhxTj755HzsYx9Ld3d3nn322ezZsyf33HNPNm3alEWLFuXOO++s19gAQIMThwAABojHH388L774YqZOnZp9+/Zl2rRp2bNnT0455ZQkSUdHR9atW5etW7fm/PPPT5Kcc845efTRR+s5NgDQ4MQhAIAB4rjjjss111yTyy67LE899VSuvfbatLW19f98yJAh+cUvfpGenp60tLT0H29qasq+ffvS3Pzqt37Dhg1Oc3PT6zo/NJoRI1rrPQJwhLnuD04cAgAYIEaPHp1Ro0alUqlk9OjRaW1tzW9+85v+n/f29qatrS0vvfRSent7+49Xq9VDhqEk2b591+s2NzSqrVt31nsE4Agr+bo/VBjz28oAAAaIb33rW1m0aFGS5Fe/+lVefPHFDB48OE8//XRqtVrWrFmT9vb2jBs3LqtXr06SbNq0KW95y1vqOTYA0OA8OQQAMEBMmDAht9xySyZPnpxKpZLbb789gwYNyo033pi+vr50dHTk7LPPzplnnpkf//jHmTRpUmq1Wm6//fZ6jw4ANDBxCABggDj22GPz+c9//oDj9957736vBw0alM985jNHaiwAYICzrQwAAACgYOIQAAAAQMHEIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFCw1xSHHn744Vx55ZVJks2bN2fy5MmZMmVK5s6dm2q1miRZsmRJJkyYkEmTJuWRRx455FoAAAAAGsNh49CyZcsya9as7N69O0mycOHCTJ8+PStWrEitVsuqVavS3d2dDRs2ZOXKlenq6sq8efNedS0AAAAAjeOwceiUU07J4sWL+193d3fn3HPPTZKMHz8+a9euzcaNG9PR0ZFKpZKRI0emr68v27ZtO+haAAAAABpH8+EWdHZ25plnnul/XavVUqlUkiRDhgzJzp0709PTk6FDh/avefn4wdYezrBhg9Pc3PRbnwgMdCNGtNZ7BOAIcs0DANAoDhuH/rdBg/7nYaPe3t60tbWlpaUlvb29+x1vbW096NrD2b591287EhwVtm49fDwFjh6lX/PiGABA4/itf1vZ2LFjs379+iTJ6tWr097ennHjxmXNmjWpVqvZsmVLqtVqhg8fftC1AAAAADSO3/rJoRkzZmT27Nnp6urKmDFj0tnZmaamprS3t2fixImpVquZM2fOq64FAAAAoHG8pjh00kkn5d57702SjB49OnffffcBa6ZNm5Zp06btd+zV1gIAAADQGH7rbWUAAAAAHD3EIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFAwcQgAAACgYOIQAAAAQMHEIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFAwcQgAAACgYOIQAAAAQMHEIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFAwcQgAAACgYOIQAAAAQMHEIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAUTBwCAAAAKJg4BAAAAFAwcQgAAACgYOIQAA
AAQMHEIQAAAICCiUMAAAAABROHAAAAAAomDgEAAAAUrLneAwAA8Nrs3bs3M2fOzLPPPps9e/bkuuuuy+///u/n4x//eP7wD/8wSTJ58uT82Z/9WZYsWZIHH3wwzc3NmTlzZs4666z6Dg8ANCxxCABggLj//vszdOjQfPazn8327dvzwQ9+MNdff32uvvrqTJ06tX9dd3d3NmzYkJUrV+a5557LtGnTct9999VxcgCgkYlDAAADxMUXX5zOzs7+101NTXn00Ufz5JNPZtWqVRk1alRmzpyZjRs3pqOjI5VKJSNHjkxfX1+2bduW4cOH13F6AKBRiUMAAAPEkCFDkiQ9PT254YYbMn369OzZsyeXXXZZzjjjjNx555358pe/nNbW1gwdOnS/9+3cufOQcWjYsMFpbm563c8BGsmIEa31HgE4wlz3BycOAQAMIM8991yuv/76TJkyJZdcckl27NiRtra2JMlFF12U+fPn58ILL0xvb2//e3p7e9Paeuib4e3bd72uc0Mj2rp1Z71HAI6wkq/7Q4Uxv60MAGCAeP755zN16tTcdNNNmTBhQpLkmmuuySOPPJIkWbduXU4//fSMGzcua9asSbVazZYtW1KtVm0pAwBelSeHAAAGiLvuuis7duzI0qVLs3Tp0iTJzTffnNtvvz3HHHNMTjjhhMyfPz8tLS1pb2/PxIkTU61WM2fOnDpPDgA0MnEIAGCAmDVrVmbNmnXA8W9+85sHHJs2bVqmTZt2JMYCAAY428oAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAAAomDgEAAAAUDBxCAAAAKBg4hAAAABAwcQhAAAAgIKJQwAAAAAFE4cAAAAACiYOAQAAABRMHAIAAID/194d6jQSRWEAPk0b1CDRDSFpJU19ZfXKVfASxTRBIAjUrCRreABYu7YGfJOKGrC8QmsKmVm1JCS7mxF7GcL9PjdjzlGTP3/uzUDGlEMAAAAAGVMOAQAAAGRMOQQAAACQMeUQAAAAQMaUQwAAAAAZUw4BAAAAZEw5BAAAAJAx5RAAAABAxpRDAAAAABlTDgEAAABkrJN6QFmWcXZ2Fg8PD7GzsxPn5+fR7XZTjwUAyJoMBgDUlfzk0Hw+j+12G7e3tzGZTGI2m6UeCQCQPRkMAKgreTm0WCxiNBpFRMRgMIjVapV6JABA9mQwAKCu5NfK1ut1FEXx+txut+Pl5SU6nT+P3tvbTb3Sh/bz25emVwDe0Y+v35teAfikZLD65C/IjwwGbyU/OVQURWw2m9fnsiz/GkoAAPg/ZDAAoK7k5dBwOIz7+/uIiFgul9Hr9VKPBADIngwGANTVqqqqSjng958yHh8fo6qquLi4iIODg5QjAQCyJ4MBAHUlL4cAAAAA+LiSXysDAAAA4ONSDgEAAABkTDkEAAAAkDHlEADvrizLplcAAMiK/MW/dJpeAIA8PD09xeXlZaxWq+h0OlGWZfR6vZhOp7G/v9/0egAAn478RV3+VgbAuzg+Po7JZBKHh4ev75bLZcxms7i5uWlwMwCAz0n+oi4nh2jM0dFRPD8/v3lXVVW0Wi0fKviEttvtm2ASETEYDBraBiBP8hfkRf6iLuUQjTk5OYnT09O4urqKdrvd9DpAYv1+P6bTaYxGo9jd3Y3NZhN3d3fR7/ebXg0gG/IX5EX+oi7XymjU9fV1dLvdGI/HTa8CJFZVVc
zn81gsFrFer6MoihgOhzEej6PVajW9HkA25C/Ih/xFXcohAAAAgIz5lT0AAABAxpRDAAAAABlTDgEAAABkTDkEAAAAkDHlEAAAAEDGfgFflNXtQxPLUAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 1440x720 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Side-by-side label counts for the train and test splits\n",
"fig, (ax_train, ax_test) = plt.subplots(1, 2, figsize=(20, 10))\n",
"train_df['label'].value_counts().plot.bar(ax=ax_train)\n",
"ax_train.set_title('Train')\n",
"test_df['label'].value_counts().plot.bar(ax=ax_test)\n",
"ax_test.set_title('Test');"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Hx9D6FjD8YWl"
},
"outputs": [],
"source": [
"make_input_fn = tf.estimator.inputs.pandas_input_fn\n",
"\n",
"# Training input function: shuffled, 50 epochs over the training set\n",
"train_input_fn = make_input_fn(\n",
"    x=train_df, y=train_df['label'], num_epochs=50, shuffle=True)\n",
"\n",
"# Input function for evaluating on the whole training set (no shuffling)\n",
"predict_train_input_fn = make_input_fn(\n",
"    x=train_df, y=train_df['label'], shuffle=False)\n",
"\n",
"# Input function for evaluating on the whole test set (no shuffling)\n",
"predict_test_input_fn = make_input_fn(\n",
"    x=test_df, y=test_df['label'], shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "wwbRQwmk921z"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using F:\\tfhubcache to cache modules.\n"
]
}
],
"source": [
"# 定义特征列,这里使用 TensorFlow Hub 来构建\n",
"embedded_text_feature_column = hub.text_embedding_column(\n",
" key=\"comment\", \n",
" module_spec=\"https://tfhub.dev/google/nnlm-zh-dim128/1\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 72
},
"colab_type": "code",
"id": "DPBWsepd972D",
"outputId": "0738a027-5f50-48af-bb7d-87fcb33e0871"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n",
"INFO:tensorflow:Using config: {'_model_dir': 'models', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x000001E2BA66E438>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n"
]
}
],
"source": [
"# 初始化内置的 DNNClassifier\n",
"estimator = tf.estimator.DNNClassifier(\n",
" hidden_units=[100, 100],\n",
" feature_columns=[embedded_text_feature_column],\n",
" n_classes=2,\n",
" optimizer=tf.train.AdagradOptimizer(learning_rate=0.003),\n",
" model_dir='models')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1179
},
"colab_type": "code",
"id": "3RsI8Kkb-SQL",
"outputId": "bc31cf37-e92c-4fcc-e6bb-4f8d3c04711e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Restoring parameters from models\\model.ckpt-2735\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Saving checkpoints for 2735 into models\\model.ckpt.\n",
"INFO:tensorflow:loss = 0.06739878, step = 2736\n",
"INFO:tensorflow:global_step/sec: 337.655\n",
"INFO:tensorflow:loss = 0.034682423, step = 2836 (0.311 sec)\n",
"INFO:tensorflow:global_step/sec: 485.717\n",
"INFO:tensorflow:loss = 0.04515608, step = 2936 (0.185 sec)\n",
"INFO:tensorflow:global_step/sec: 543.789\n",
"INFO:tensorflow:loss = 0.028866358, step = 3036 (0.184 sec)\n",
"INFO:tensorflow:global_step/sec: 546.763\n",
"INFO:tensorflow:loss = 0.04221748, step = 3136 (0.183 sec)\n",
"INFO:tensorflow:global_step/sec: 546.761\n",
"INFO:tensorflow:loss = 0.10310146, step = 3236 (0.183 sec)\n",
"INFO:tensorflow:global_step/sec: 543.79\n",
"INFO:tensorflow:loss = 0.046141706, step = 3336 (0.184 sec)\n",
"INFO:tensorflow:global_step/sec: 510.497\n",
"INFO:tensorflow:loss = 0.056765363, step = 3436 (0.277 sec)\n",
"INFO:tensorflow:global_step/sec: 363.79\n",
"INFO:tensorflow:loss = 0.038011882, step = 3536 (0.194 sec)\n",
"INFO:tensorflow:global_step/sec: 507.906\n",
"INFO:tensorflow:loss = 0.03202858, step = 3636 (0.197 sec)\n",
"INFO:tensorflow:global_step/sec: 495.333\n",
"INFO:tensorflow:loss = 0.039412722, step = 3736 (0.203 sec)\n",
"INFO:tensorflow:global_step/sec: 543.792\n",
"INFO:tensorflow:loss = 0.045118544, step = 3836 (0.183 sec)\n",
"INFO:tensorflow:global_step/sec: 488.084\n",
"INFO:tensorflow:loss = 0.03442596, step = 3936 (0.205 sec)\n",
"INFO:tensorflow:global_step/sec: 505.341\n",
"INFO:tensorflow:loss = 0.025132883, step = 4036 (0.198 sec)\n",
"INFO:tensorflow:global_step/sec: 549.765\n",
"INFO:tensorflow:loss = 0.023409069, step = 4136 (0.182 sec)\n",
"INFO:tensorflow:global_step/sec: 495.335\n",
"INFO:tensorflow:loss = 0.043221936, step = 4236 (0.203 sec)\n",
"INFO:tensorflow:global_step/sec: 471.968\n",
"INFO:tensorflow:loss = 0.03103593, step = 4336 (0.211 sec)\n",
"INFO:tensorflow:global_step/sec: 469.753\n",
"INFO:tensorflow:loss = 0.022563316, step = 4436 (0.213 sec)\n",
"INFO:tensorflow:global_step/sec: 546.761\n",
"INFO:tensorflow:loss = 0.021799732, step = 4536 (0.183 sec)\n",
"INFO:tensorflow:global_step/sec: 523.862\n",
"INFO:tensorflow:loss = 0.02162357, step = 4636 (0.191 sec)\n",
"INFO:tensorflow:global_step/sec: 523.861\n",
"INFO:tensorflow:loss = 0.018305503, step = 4736 (0.191 sec)\n",
"INFO:tensorflow:global_step/sec: 415.176\n",
"INFO:tensorflow:loss = 0.021797787, step = 4836 (0.241 sec)\n",
"INFO:tensorflow:global_step/sec: 483.368\n",
"INFO:tensorflow:loss = 0.039814856, step = 4936 (0.207 sec)\n",
"INFO:tensorflow:global_step/sec: 452.748\n",
"INFO:tensorflow:loss = 0.019658014, step = 5036 (0.221 sec)\n",
"INFO:tensorflow:global_step/sec: 505.341\n",
"INFO:tensorflow:loss = 0.019889731, step = 5136 (0.198 sec)\n",
"INFO:tensorflow:global_step/sec: 513.114\n",
"INFO:tensorflow:loss = 0.02168344, step = 5236 (0.196 sec)\n",
"INFO:tensorflow:global_step/sec: 490.478\n",
"INFO:tensorflow:loss = 0.022522777, step = 5336 (0.203 sec)\n",
"INFO:tensorflow:global_step/sec: 523.86\n",
"INFO:tensorflow:loss = 0.03371704, step = 5436 (0.191 sec)\n",
"INFO:tensorflow:Saving checkpoints for 5470 into models\\model.ckpt.\n",
"INFO:tensorflow:Loss for final step: 0.004834718.\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.estimator.canned.dnn.DNNClassifier at 0x1e2b9db8550>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 开始训练\n",
"estimator.train(input_fn=train_input_fn)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 523
},
"colab_type": "code",
"id": "gLgdyN4n-dUU",
"outputId": "30fa242d-7ff4-4343-9590-558a258c42fc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"WARNING:tensorflow:Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n",
"WARNING:tensorflow:Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Starting evaluation at 2018-08-23-05:13:11\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Restoring parameters from models\\model.ckpt-5470\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Finished evaluation at 2018-08-23-05:13:16\n",
"INFO:tensorflow:Saving dict for global step 5470: accuracy = 1.0, accuracy_baseline = 0.69757146, auc = 1.0, auc_precision_recall = 1.0, average_loss = 0.0001732409, global_step = 5470, label/mean = 0.69757146, loss = 0.022048842, precision = 1.0, prediction/mean = 0.6975434, recall = 1.0\n",
"INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5470: models\\model.ckpt-5470\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"WARNING:tensorflow:Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n",
"WARNING:tensorflow:Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Starting evaluation at 2018-08-23-05:13:18\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Restoring parameters from models\\model.ckpt-5470\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Finished evaluation at 2018-08-23-05:13:20\n",
"INFO:tensorflow:Saving dict for global step 5470: accuracy = 1.0, accuracy_baseline = 0.70566666, auc = 1.0, auc_precision_recall = 1.0, average_loss = 0.00036637738, global_step = 5470, label/mean = 0.70566666, loss = 0.045797173, precision = 1.0, prediction/mean = 0.7054589, recall = 1.0\n",
"INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5470: models\\model.ckpt-5470\n"
]
}
],
"source": [
"# 在训练集和测试集上进行测试\n",
"train_eval_result = estimator.evaluate(input_fn=predict_train_input_fn)\n",
"test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 208
},
"colab_type": "code",
"id": "eUS3ATyY_KR8",
"outputId": "4ba428be-3beb-4fec-ffa1-589e9c2d535b"
},
"outputs": [
{
"data": {
"text/plain": [
"{'accuracy': 1.0,\n",
" 'accuracy_baseline': 0.69757146,\n",
" 'auc': 1.0,\n",
" 'auc_precision_recall': 1.0,\n",
" 'average_loss': 0.0001732409,\n",
" 'label/mean': 0.69757146,\n",
" 'loss': 0.022048842,\n",
" 'precision': 1.0,\n",
" 'prediction/mean': 0.6975434,\n",
" 'recall': 1.0,\n",
" 'global_step': 5470}"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_eval_result"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 208
},
"colab_type": "code",
"id": "-ijy0go0_OKU",
"outputId": "320f52a0-afee-444d-8cff-9bfda12cd2c7"
},
"outputs": [
{
"data": {
"text/plain": [
"{'accuracy': 1.0,\n",
" 'accuracy_baseline': 0.70566666,\n",
" 'auc': 1.0,\n",
" 'auc_precision_recall': 1.0,\n",
" 'average_loss': 0.00036637738,\n",
" 'label/mean': 0.70566666,\n",
" 'loss': 0.045797173,\n",
" 'precision': 1.0,\n",
" 'prediction/mean': 0.7054589,\n",
" 'recall': 1.0,\n",
" 'global_step': 5470}"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_eval_result"
]
}
],
"metadata": {
"colab": {
"name": "tensorflowhub-share-ppt.ipynb",
"provenance": [],
"version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment