@yamanetoshi
Created October 21, 2018 04:29
Copy of movie-descriptions-tfhub.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Copy of movie-descriptions-tfhub.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"[View in Colaboratory](https://colab.research.google.com/gist/yamanetoshi/007442f101d13c0c0b68db2495eb9adb/copy-of-movie-descriptions-tfhub.ipynb)"
]
},
{
"metadata": {
"id": "feFbqZpP1soY",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Building a text classification model with TF Hub\n",
"\n",
"In this notebook, we'll walk you through building a model to predict the genres of a movie given its description. The emphasis here is not on accuracy but on how to use TF Hub layers in a text classification model.\n",
"\n",
"To start, import the necessary dependencies for this project."
]
},
{
"metadata": {
"id": "rOEllRxGQ_me",
"colab_type": "code",
"outputId": "8e806655-ba92-4d09-c2ed-28d9089af96f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"cell_type": "code",
"source": [
"import os\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import tensorflow_hub as hub\n",
"import json\n",
"import pickle\n",
"import urllib.request\n",
"\n",
"from sklearn.preprocessing import MultiLabelBinarizer\n",
"\n",
"print(tf.__version__)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"1.11.0\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "xvwN2Jkx2CdU",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## The dataset\n",
"\n",
"We need a lot of text inputs to train our model. For this model we'll use [this awesome movies dataset](https://www.kaggle.com/rounakbanik/the-movies-dataset) from Kaggle. To simplify things, I've made the `movies_metadata.csv` file available in a public Cloud Storage bucket so we can download it with `wget`. I've already preprocessed the dataset to limit the number of genres we'll use for our model, but first let's look at the original data to see what we're working with."
]
},
{
"metadata": {
"id": "YObfZBenyfMT",
"colab_type": "code",
"outputId": "35a002af-e5c9-40d1-cfd0-2b15b1912d95",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 235
}
},
"cell_type": "code",
"source": [
"# Download the data from GCS\n",
"!wget 'https://storage.googleapis.com/movies_data/movies_metadata.csv'"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"--2018-10-19 02:43:46-- https://storage.googleapis.com/movies_data/movies_metadata.csv\n",
"Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.13.128, 2607:f8b0:400c:c03::80\n",
"Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.13.128|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 34445126 (33M) [application/octet-stream]\n",
"Saving to: ‘movies_metadata.csv’\n",
"\n",
"movies_metadata.csv 100%[===================>] 32.85M 93.6MB/s in 0.4s \n",
"\n",
"2018-10-19 02:43:51 (93.6 MB/s) - ‘movies_metadata.csv’ saved [34445126/34445126]\n",
"\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "WFKB0Bw62xW-",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Next we'll load the CSV into a Pandas DataFrame and display the first 5 rows. For this model we only use 2 of these columns: `genres` and `overview`."
]
},
{
"metadata": {
"id": "NaZiKWtGyoQE",
"colab_type": "code",
"outputId": "fef3a276-ee7d-4dfa-b098-26d0658a5f0a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 729
}
},
"cell_type": "code",
"source": [
"data = pd.read_csv('movies_metadata.csv')\n",
"data.head()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/IPython/core/interactiveshell.py:2718: DtypeWarning: Columns (10) have mixed types. Specify dtype option on import or set low_memory=False.\n",
" interactivity=interactivity, compiler=compiler, result=result)\n"
],
"name": "stderr"
},
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>adult</th>\n",
" <th>belongs_to_collection</th>\n",
" <th>budget</th>\n",
" <th>genres</th>\n",
" <th>homepage</th>\n",
" <th>id</th>\n",
" <th>imdb_id</th>\n",
" <th>original_language</th>\n",
" <th>original_title</th>\n",
" <th>overview</th>\n",
" <th>...</th>\n",
" <th>release_date</th>\n",
" <th>revenue</th>\n",
" <th>runtime</th>\n",
" <th>spoken_languages</th>\n",
" <th>status</th>\n",
" <th>tagline</th>\n",
" <th>title</th>\n",
" <th>video</th>\n",
" <th>vote_average</th>\n",
" <th>vote_count</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>False</td>\n",
" <td>{'id': 10194, 'name': 'Toy Story Collection', ...</td>\n",
" <td>30000000</td>\n",
" <td>[{'id': 16, 'name': 'Animation'}, {'id': 35, '...</td>\n",
" <td>http://toystory.disney.com/toy-story</td>\n",
" <td>862</td>\n",
" <td>tt0114709</td>\n",
" <td>en</td>\n",
" <td>Toy Story</td>\n",
" <td>Led by Woody, Andy's toys live happily in his ...</td>\n",
" <td>...</td>\n",
" <td>1995-10-30</td>\n",
" <td>373554033.0</td>\n",
" <td>81.0</td>\n",
" <td>[{'iso_639_1': 'en', 'name': 'English'}]</td>\n",
" <td>Released</td>\n",
" <td>NaN</td>\n",
" <td>Toy Story</td>\n",
" <td>False</td>\n",
" <td>7.7</td>\n",
" <td>5415.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>False</td>\n",
" <td>NaN</td>\n",
" <td>65000000</td>\n",
" <td>[{'id': 12, 'name': 'Adventure'}, {'id': 14, '...</td>\n",
" <td>NaN</td>\n",
" <td>8844</td>\n",
" <td>tt0113497</td>\n",
" <td>en</td>\n",
" <td>Jumanji</td>\n",
" <td>When siblings Judy and Peter discover an encha...</td>\n",
" <td>...</td>\n",
" <td>1995-12-15</td>\n",
" <td>262797249.0</td>\n",
" <td>104.0</td>\n",
" <td>[{'iso_639_1': 'en', 'name': 'English'}, {'iso...</td>\n",
" <td>Released</td>\n",
" <td>Roll the dice and unleash the excitement!</td>\n",
" <td>Jumanji</td>\n",
" <td>False</td>\n",
" <td>6.9</td>\n",
" <td>2413.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>False</td>\n",
" <td>{'id': 119050, 'name': 'Grumpy Old Men Collect...</td>\n",
" <td>0</td>\n",
" <td>[{'id': 10749, 'name': 'Romance'}, {'id': 35, ...</td>\n",
" <td>NaN</td>\n",
" <td>15602</td>\n",
" <td>tt0113228</td>\n",
" <td>en</td>\n",
" <td>Grumpier Old Men</td>\n",
" <td>A family wedding reignites the ancient feud be...</td>\n",
" <td>...</td>\n",
" <td>1995-12-22</td>\n",
" <td>0.0</td>\n",
" <td>101.0</td>\n",
" <td>[{'iso_639_1': 'en', 'name': 'English'}]</td>\n",
" <td>Released</td>\n",
" <td>Still Yelling. Still Fighting. Still Ready for...</td>\n",
" <td>Grumpier Old Men</td>\n",
" <td>False</td>\n",
" <td>6.5</td>\n",
" <td>92.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>False</td>\n",
" <td>NaN</td>\n",
" <td>16000000</td>\n",
" <td>[{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam...</td>\n",
" <td>NaN</td>\n",
" <td>31357</td>\n",
" <td>tt0114885</td>\n",
" <td>en</td>\n",
" <td>Waiting to Exhale</td>\n",
" <td>Cheated on, mistreated and stepped on, the wom...</td>\n",
" <td>...</td>\n",
" <td>1995-12-22</td>\n",
" <td>81452156.0</td>\n",
" <td>127.0</td>\n",
" <td>[{'iso_639_1': 'en', 'name': 'English'}]</td>\n",
" <td>Released</td>\n",
" <td>Friends are the people who let you be yourself...</td>\n",
" <td>Waiting to Exhale</td>\n",
" <td>False</td>\n",
" <td>6.1</td>\n",
" <td>34.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>False</td>\n",
" <td>{'id': 96871, 'name': 'Father of the Bride Col...</td>\n",
" <td>0</td>\n",
" <td>[{'id': 35, 'name': 'Comedy'}]</td>\n",
" <td>NaN</td>\n",
" <td>11862</td>\n",
" <td>tt0113041</td>\n",
" <td>en</td>\n",
" <td>Father of the Bride Part II</td>\n",
" <td>Just when George Banks has recovered from his ...</td>\n",
" <td>...</td>\n",
" <td>1995-02-10</td>\n",
" <td>76578911.0</td>\n",
" <td>106.0</td>\n",
" <td>[{'iso_639_1': 'en', 'name': 'English'}]</td>\n",
" <td>Released</td>\n",
" <td>Just When His World Is Back To Normal... He's ...</td>\n",
" <td>Father of the Bride Part II</td>\n",
" <td>False</td>\n",
" <td>5.7</td>\n",
" <td>173.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 24 columns</p>\n",
"</div>"
],
"text/plain": [
" adult belongs_to_collection budget \\\n",
"0 False {'id': 10194, 'name': 'Toy Story Collection', ... 30000000 \n",
"1 False NaN 65000000 \n",
"2 False {'id': 119050, 'name': 'Grumpy Old Men Collect... 0 \n",
"3 False NaN 16000000 \n",
"4 False {'id': 96871, 'name': 'Father of the Bride Col... 0 \n",
"\n",
" genres \\\n",
"0 [{'id': 16, 'name': 'Animation'}, {'id': 35, '... \n",
"1 [{'id': 12, 'name': 'Adventure'}, {'id': 14, '... \n",
"2 [{'id': 10749, 'name': 'Romance'}, {'id': 35, ... \n",
"3 [{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam... \n",
"4 [{'id': 35, 'name': 'Comedy'}] \n",
"\n",
" homepage id imdb_id original_language \\\n",
"0 http://toystory.disney.com/toy-story 862 tt0114709 en \n",
"1 NaN 8844 tt0113497 en \n",
"2 NaN 15602 tt0113228 en \n",
"3 NaN 31357 tt0114885 en \n",
"4 NaN 11862 tt0113041 en \n",
"\n",
" original_title \\\n",
"0 Toy Story \n",
"1 Jumanji \n",
"2 Grumpier Old Men \n",
"3 Waiting to Exhale \n",
"4 Father of the Bride Part II \n",
"\n",
" overview ... release_date \\\n",
"0 Led by Woody, Andy's toys live happily in his ... ... 1995-10-30 \n",
"1 When siblings Judy and Peter discover an encha... ... 1995-12-15 \n",
"2 A family wedding reignites the ancient feud be... ... 1995-12-22 \n",
"3 Cheated on, mistreated and stepped on, the wom... ... 1995-12-22 \n",
"4 Just when George Banks has recovered from his ... ... 1995-02-10 \n",
"\n",
" revenue runtime spoken_languages \\\n",
"0 373554033.0 81.0 [{'iso_639_1': 'en', 'name': 'English'}] \n",
"1 262797249.0 104.0 [{'iso_639_1': 'en', 'name': 'English'}, {'iso... \n",
"2 0.0 101.0 [{'iso_639_1': 'en', 'name': 'English'}] \n",
"3 81452156.0 127.0 [{'iso_639_1': 'en', 'name': 'English'}] \n",
"4 76578911.0 106.0 [{'iso_639_1': 'en', 'name': 'English'}] \n",
"\n",
" status tagline \\\n",
"0 Released NaN \n",
"1 Released Roll the dice and unleash the excitement! \n",
"2 Released Still Yelling. Still Fighting. Still Ready for... \n",
"3 Released Friends are the people who let you be yourself... \n",
"4 Released Just When His World Is Back To Normal... He's ... \n",
"\n",
" title video vote_average vote_count \n",
"0 Toy Story False 7.7 5415.0 \n",
"1 Jumanji False 6.9 2413.0 \n",
"2 Grumpier Old Men False 6.5 92.0 \n",
"3 Waiting to Exhale False 6.1 34.0 \n",
"4 Father of the Bride Part II False 5.7 173.0 \n",
"\n",
"[5 rows x 24 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
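{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"The `genres` column stores each movie's genres as a string that looks like a Python list of dicts. The preprocessing behind the pickled files used below isn't included in this notebook. The next cell is therefore only a rough sketch, under our own assumptions, of how you might extract the genre names and keep the 9 most frequent ones. The helper names `parse_genre_names` and `top_genres` are hypothetical."
]
},
{
"metadata": {
"colab_type": "code"
},
"cell_type": "code",
"source": [
"import ast\n",
"from collections import Counter\n",
"\n",
"def parse_genre_names(raw):\n",
"    # Hypothetical helper: turn \"[{'id': 16, 'name': 'Animation'}, ...]\" into ['Animation', ...]\n",
"    try:\n",
"        return [g['name'] for g in ast.literal_eval(raw)]\n",
"    except (ValueError, SyntaxError, TypeError, KeyError):\n",
"        # Treat missing or malformed entries as having no genres\n",
"        return []\n",
"\n",
"def top_genres(genre_lists, top_n=9):\n",
"    # Hypothetical helper: count every genre occurrence and keep the top_n most common names\n",
"    counts = Counter(name for names in genre_lists for name in names)\n",
"    return {name for name, _ in counts.most_common(top_n)}\n",
"\n",
"print(parse_genre_names(\"[{'id': 16, 'name': 'Animation'}, {'id': 35, 'name': 'Comedy'}]\"))  # ['Animation', 'Comedy']\n",
"\n",
"# Assumed usage on the `data` DataFrame loaded above\n",
"genre_names = data['genres'].apply(parse_genre_names)\n",
"keep = top_genres(genre_names)\n",
"filtered_genres = genre_names.apply(lambda names: [n for n in names if n in keep])"
],
"execution_count": 0,
"outputs": []
},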
{
"metadata": {
"id": "MBLcNSE_7Icv",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Preparing the data for our model\n",
"\n",
"I've done some preprocessing to limit the dataset to the top 9 genres, and I've saved the results as public [Pickle](https://docs.python.org/3/library/pickle.html) files in GCS. Here we download those files. The resulting `descriptions` and `genres` variables are Pandas Series containing, respectively, all of the descriptions and all of the genres in our dataset."
]
},
{
"metadata": {
"id": "rzjJuKhir-PH",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"urllib.request.urlretrieve('https://storage.googleapis.com/bq-imports/descriptions.p', 'descriptions.p')\n",
"urllib.request.urlretrieve('https://storage.googleapis.com/bq-imports/genres.p', 'genres.p')\n",
"\n",
"with open('descriptions.p', 'rb') as f:\n",
"    descriptions = pickle.load(f)\n",
"with open('genres.p', 'rb') as f:\n",
"    genres = pickle.load(f)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ZUypuN818T_D",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"### Splitting our data\n",
"When we train our model, we'll use 80% of the data for training and set aside the remaining 20% to evaluate how well our model performs. The split simply takes the first 80% of rows in their stored order."
]
},
{
"metadata": {
"id": "_nticMcj1alW",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"train_size = int(len(descriptions) * .8)\n",
"\n",
"train_descriptions = descriptions[:train_size].astype('str')\n",
"train_genres = genres[:train_size]\n",
"\n",
"test_descriptions = descriptions[train_size:].astype('str')\n",
"test_genres = genres[train_size:]"
],
"execution_count": 0,
"outputs": []
},
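{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"The cell above takes the first 80% of rows in the order they were stored. If that order isn't random, one alternative (our own suggestion, not part of the original walkthrough) is to shuffle row positions with a fixed seed before splitting. `shuffled_split` is a hypothetical helper, and the sketch assumes `descriptions` and `genres` are aligned Pandas Series of the same length."
]
},
{
"metadata": {
"colab_type": "code"
},
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def shuffled_split(descriptions, genres, train_frac=0.8, seed=0):\n",
"    # Hypothetical helper: permute row positions reproducibly, then cut at train_frac\n",
"    order = np.random.RandomState(seed).permutation(len(descriptions))\n",
"    train_size = int(len(descriptions) * train_frac)\n",
"    train_idx, test_idx = order[:train_size], order[train_size:]\n",
"    # .iloc selects by position, so both Series stay aligned regardless of their index labels\n",
"    return (descriptions.iloc[train_idx].astype('str'), genres.iloc[train_idx],\n",
"            descriptions.iloc[test_idx].astype('str'), genres.iloc[test_idx])\n",
"\n",
"# Assumed usage; uncomment to replace the sequential split above\n",
"# train_descriptions, train_genres, test_descriptions, test_genres = shuffled_split(descriptions, genres)"
],
"execution_count": 0,
"outputs": []
},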
{
"metadata": {
"id": "FmZ9iqK88nSD",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"### Formatting our labels\n",
"When we train our model, we'll provide the labels (in this case, genres) associated with each movie. We can't pass the genres in as strings directly, so we'll transform them into multi-hot vectors. Since we have 9 genres, each movie gets a 9-element vector of 0s and 1s indicating which genres apply to its description."
]
},
{
"metadata": {
"id": "bouv0R-D7J45",
"colab_type": "code",
"outputId": "97c76a58-2262-404c-e6e8-cc1e314acbf5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"cell_type": "code",
"source": [
"encoder = MultiLabelBinarizer()\n",
"# Learn the genre vocabulary from the training labels, then encode both splits\n",
"train_encoded = encoder.fit_transform(train_genres)\n",
"test_encoded = encoder.transform(test_genres)\n",
"num_classes = len(encoder.classes_)\n",
"\n",
"# Print all possible genres and the labels for the first movie in our training dataset\n",
"print(encoder.classes_)\n",
"print(train_encoded[0])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"['Action' 'Adventure' 'Comedy' 'Crime' 'Documentary' 'Horror' 'Romance'\n",
" 'Science Fiction' 'Thriller']\n",
"[0 0 1 0 0 0 1 0 0]\n"
],
"name": "stdout"
}
]
},
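{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"In the output above, the 1s at zero-based positions 2 and 6 of the first movie's vector line up with `'Comedy'` and `'Romance'` in `encoder.classes_`. To see the multi-hot encoding on something smaller, here is a toy example with made-up genre lists. `MultiLabelBinarizer` sorts the class names alphabetically and creates one column per class."
]
},
{
"metadata": {
"colab_type": "code"
},
"cell_type": "code",
"source": [
"from sklearn.preprocessing import MultiLabelBinarizer\n",
"\n",
"# Made-up genre lists for three hypothetical movies\n",
"toy_genres = [['Comedy', 'Romance'], ['Action'], ['Action', 'Thriller']]\n",
"\n",
"toy_encoder = MultiLabelBinarizer()\n",
"toy_encoded = toy_encoder.fit_transform(toy_genres)\n",
"\n",
"print(toy_encoder.classes_)  # ['Action' 'Comedy' 'Romance' 'Thriller']\n",
"# One row per movie, one column per class; 1 means the movie has that genre\n",
"print(toy_encoded)\n",
"# [[0 1 1 0]\n",
"#  [1 0 0 0]\n",
"#  [1 0 0 1]]"
],
"execution_count": 0,
"outputs": []
},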
{
"metadata": {
"id": "Ir8ez0K_9sYA",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"### Create our TF Hub embedding layer\n",
"TF Hub provides a library of existing pre-trained model checkpoints for various kinds of models (images, text, and more). In this model we'll use the TF Hub `universal-sentence-encoder` module for pre-trained sentence embeddings. We only need one line of code to instantiate the module: `hub.text_embedding_column` wraps it as a feature column, so when we train our model, each movie description string is converted to an embedding before it reaches the rest of the network. Setting `trainable=False` keeps the module's pre-trained weights frozen."
]
},
{
"metadata": {
"id": "PWuNUXq7a-7p",
"colab_type": "code",
"outputId": "ec4b4f91-11a4-4e5e-e829-e7e55b06e25c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"cell_type": "code",
"source": [
"description_embeddings = hub.text_embedding_column(\"descriptions\", module_spec=\"https://tfhub.dev/google/universal-sentence-encoder/2\", trainable=False)\n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Using /tmp/tfhub_modules to cache modules.\n",
"INFO:tensorflow:Downloading TF-Hub Module 'https://tfhub.dev/google/universal-sentence-encoder/2'.\n",
"INFO:tensorflow:Downloaded TF-Hub Module 'https://tfhub.dev/google/universal-sentence-encoder/2'.\n"
],
"name": "stdout"
}
]
},
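{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"To get a feel for what the feature column feeds the network, you can call the same module directly. This sketch assumes the TF 1.x `hub.Module` API that matches the TensorFlow version used here. It builds a separate graph so it doesn't interfere with the Estimator, and the two example sentences are made up. Universal Sentence Encoder returns one 512-dimensional vector per input string."
]
},
{
"metadata": {
"colab_type": "code"
},
"cell_type": "code",
"source": [
"import tensorflow as tf\n",
"import tensorflow_hub as hub\n",
"\n",
"# Build in a fresh graph so nothing is added to the notebook's default graph\n",
"with tf.Graph().as_default():\n",
"    embed = hub.Module('https://tfhub.dev/google/universal-sentence-encoder/2')\n",
"    sample = embed(['A family wedding reignites an old feud.',\n",
"                    'Toys come to life whenever their owner leaves the room.'])\n",
"    with tf.Session() as sess:\n",
"        # The module needs both its variables and its lookup tables initialized\n",
"        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])\n",
"        vectors = sess.run(sample)\n",
"\n",
"print(vectors.shape)  # (2, 512)"
],
"execution_count": 0,
"outputs": []
},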
{
"metadata": {
"id": "9vscf4Fo-iI-",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Instantiating our DNNEstimator Model\n",
"The first parameter we pass to our `DNNEstimator` is called a head, and it defines the type of labels our model should expect. Since we want our model to output multiple labels, we'll use `multi_label_head` here. Then we'll convert our features and labels to NumPy arrays, wrap them in an input function, and instantiate our Estimator. `batch_size` and `num_epochs` are hyperparameters: experiment with different values to see what works best on your dataset."
]
},
{
"metadata": {
"id": "c0Vsmu9O21je",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"multi_label_head = tf.contrib.estimator.multi_label_head(\n",
" num_classes,\n",
" loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE\n",
")"
],
"execution_count": 0,
"outputs": []
},
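{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"As far as we can tell from TensorFlow's `multi_label_head`, the training loss works as follows. A sigmoid cross-entropy is computed independently for each class and averaged over the classes of each example. With `SUM_OVER_BATCH_SIZE`, that per-example loss is then averaged over the batch. The NumPy sketch below reproduces this arithmetic under those assumptions; `multi_label_loss` is a hypothetical helper. It also explains why the first logged training loss is close to ln 2 ≈ 0.693: an untrained network outputs logits near 0, which means a probability of 0.5 for every class."
]
},
{
"metadata": {
"colab_type": "code"
},
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def multi_label_loss(logits, labels):\n",
"    # Hypothetical helper. Per-class sigmoid cross-entropy in its numerically stable form:\n",
"    # max(x, 0) - x * z + log(1 + exp(-|x|)), with logits x and 0/1 labels z\n",
"    per_class = np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))\n",
"    # Average over classes for each example, then over the batch\n",
"    return per_class.mean(axis=1).mean()\n",
"\n",
"example_labels = np.array([[0, 0, 1, 0, 0, 0, 1, 0, 0]], dtype=np.float32)\n",
"print(multi_label_loss(np.zeros((1, 9)), example_labels))  # ~0.6931, i.e. ln(2)"
],
"execution_count": 0,
"outputs": []
},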
{
"metadata": {
"id": "8mTpWD_Q8GKe",
"colab_type": "code",
"outputId": "f6c75c1b-5331-4f38-d667-815eb7db2b0e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 210
}
},
"cell_type": "code",
"source": [
"features = {\n",
" \"descriptions\": np.array(train_descriptions).astype(str)\n",
"}\n",
"labels = np.array(train_encoded).astype(np.int32)\n",
"train_input_fn = tf.estimator.inputs.numpy_input_fn(features, labels, shuffle=True, batch_size=32, num_epochs=25)\n",
"estimator = tf.contrib.estimator.DNNEstimator(\n",
" head=multi_label_head,\n",
" hidden_units=[64,10],\n",
" feature_columns=[description_embeddings])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n",
"WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpceq0l35q\n",
"INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpceq0l35q', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
"graph_options {\n",
" rewrite_options {\n",
" meta_optimizer_iterations: ONE\n",
" }\n",
"}\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fb26ddcc6d8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "5ak1cZPZ_ZYM",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Training and serving our model\n",
"To train our model, we simply call `train()`, passing it the input function we defined above. Once our model is trained, we'll define an evaluation input function similar to the training one and call `evaluate()`. When this completes, we'll get a few metrics we can use to gauge how well our model performs."
]
},
{
"metadata": {
"id": "jmtvJ5o3Olcg",
"colab_type": "code",
"outputId": "f62aa09d-f1a7-49a7-eafb-e13564cad02f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 3573
}
},
"cell_type": "code",
"source": [
"estimator.train(input_fn=train_input_fn)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/estimator/inputs/queues/feeding_queue_runner.py:62: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"To construct input pipelines, use the `tf.data` module.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/estimator/inputs/queues/feeding_functions.py:500: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"To construct input pipelines, use the `tf.data` module.\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/monitored_session.py:804: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"To construct input pipelines, use the `tf.data` module.\n",
"INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpceq0l35q/model.ckpt.\n",
"INFO:tensorflow:loss = 0.6906395, step = 1\n",
"INFO:tensorflow:global_step/sec: 47.7169\n",
"INFO:tensorflow:loss = 0.5307451, step = 101 (2.097 sec)\n",
"INFO:tensorflow:global_step/sec: 55.9306\n",
"INFO:tensorflow:loss = 0.41936988, step = 201 (1.788 sec)\n",
"INFO:tensorflow:global_step/sec: 55.4976\n",
"INFO:tensorflow:loss = 0.4674911, step = 301 (1.802 sec)\n",
"INFO:tensorflow:global_step/sec: 56.4919\n",
"INFO:tensorflow:loss = 0.433739, step = 401 (1.770 sec)\n",
"INFO:tensorflow:global_step/sec: 56.133\n",
"INFO:tensorflow:loss = 0.44665545, step = 501 (1.781 sec)\n",
"INFO:tensorflow:global_step/sec: 56.8832\n",
"INFO:tensorflow:loss = 0.32581294, step = 601 (1.758 sec)\n",
"INFO:tensorflow:global_step/sec: 55.6628\n",
"INFO:tensorflow:loss = 0.47503218, step = 701 (1.797 sec)\n",
"INFO:tensorflow:global_step/sec: 56.9997\n",
"INFO:tensorflow:loss = 0.39237317, step = 801 (1.755 sec)\n",
"INFO:tensorflow:global_step/sec: 55.6784\n",
"INFO:tensorflow:loss = 0.4085172, step = 901 (1.797 sec)\n",
"INFO:tensorflow:global_step/sec: 56.0513\n",
"INFO:tensorflow:loss = 0.3785057, step = 1001 (1.783 sec)\n",
"INFO:tensorflow:global_step/sec: 55.7738\n",
"INFO:tensorflow:loss = 0.45325413, step = 1101 (1.793 sec)\n",
"INFO:tensorflow:global_step/sec: 56.5832\n",
"INFO:tensorflow:loss = 0.40296894, step = 1201 (1.767 sec)\n",
"INFO:tensorflow:global_step/sec: 56.0652\n",
"INFO:tensorflow:loss = 0.36876005, step = 1301 (1.784 sec)\n",
"INFO:tensorflow:global_step/sec: 55.6143\n",
"INFO:tensorflow:loss = 0.39673683, step = 1401 (1.798 sec)\n",
"INFO:tensorflow:global_step/sec: 56.2325\n",
"INFO:tensorflow:loss = 0.41657734, step = 1501 (1.778 sec)\n",
"INFO:tensorflow:global_step/sec: 56.8023\n",
"INFO:tensorflow:loss = 0.3529602, step = 1601 (1.761 sec)\n",
"INFO:tensorflow:global_step/sec: 55.8456\n",
"INFO:tensorflow:loss = 0.33297592, step = 1701 (1.796 sec)\n",
"INFO:tensorflow:global_step/sec: 55.4169\n",
"INFO:tensorflow:loss = 0.37074035, step = 1801 (1.802 sec)\n",
"INFO:tensorflow:global_step/sec: 56.7319\n",
"INFO:tensorflow:loss = 0.39014798, step = 1901 (1.760 sec)\n",
"INFO:tensorflow:global_step/sec: 55.687\n",
"INFO:tensorflow:loss = 0.3445291, step = 2001 (1.796 sec)\n",
"INFO:tensorflow:global_step/sec: 55.8369\n",
"INFO:tensorflow:loss = 0.3068616, step = 2101 (1.791 sec)\n",
"INFO:tensorflow:global_step/sec: 55.3735\n",
"INFO:tensorflow:loss = 0.37884736, step = 2201 (1.806 sec)\n",
"INFO:tensorflow:global_step/sec: 56.8518\n",
"INFO:tensorflow:loss = 0.27714652, step = 2301 (1.762 sec)\n",
"INFO:tensorflow:global_step/sec: 56.0077\n",
"INFO:tensorflow:loss = 0.3206575, step = 2401 (1.783 sec)\n",
"INFO:tensorflow:global_step/sec: 55.2645\n",
"INFO:tensorflow:loss = 0.27519703, step = 2501 (1.810 sec)\n",
"INFO:tensorflow:global_step/sec: 56.4367\n",
"INFO:tensorflow:loss = 0.3924746, step = 2601 (1.774 sec)\n",
"INFO:tensorflow:global_step/sec: 57.5702\n",
"INFO:tensorflow:loss = 0.34570682, step = 2701 (1.738 sec)\n",
"INFO:tensorflow:global_step/sec: 56.0012\n",
"INFO:tensorflow:loss = 0.3191098, step = 2801 (1.782 sec)\n",
"INFO:tensorflow:global_step/sec: 56.7857\n",
"INFO:tensorflow:loss = 0.31155926, step = 2901 (1.761 sec)\n",
"INFO:tensorflow:global_step/sec: 57.0461\n",
"INFO:tensorflow:loss = 0.31713024, step = 3001 (1.753 sec)\n",
"INFO:tensorflow:global_step/sec: 56.7644\n",
"INFO:tensorflow:loss = 0.28780913, step = 3101 (1.762 sec)\n",
"INFO:tensorflow:global_step/sec: 56.2077\n",
"INFO:tensorflow:loss = 0.22899172, step = 3201 (1.779 sec)\n",
"INFO:tensorflow:global_step/sec: 54.9932\n",
"INFO:tensorflow:loss = 0.3066067, step = 3301 (1.822 sec)\n",
"INFO:tensorflow:global_step/sec: 57.5482\n",
"INFO:tensorflow:loss = 0.25926715, step = 3401 (1.735 sec)\n",
"INFO:tensorflow:global_step/sec: 55.8937\n",
"INFO:tensorflow:loss = 0.28046352, step = 3501 (1.789 sec)\n",
"INFO:tensorflow:global_step/sec: 55.8569\n",
"INFO:tensorflow:loss = 0.2631961, step = 3601 (1.790 sec)\n",
"INFO:tensorflow:global_step/sec: 54.9162\n",
"INFO:tensorflow:loss = 0.41314426, step = 3701 (1.821 sec)\n",
"INFO:tensorflow:global_step/sec: 56.8567\n",
"INFO:tensorflow:loss = 0.32123646, step = 3801 (1.764 sec)\n",
"INFO:tensorflow:global_step/sec: 55.7648\n",
"INFO:tensorflow:loss = 0.33113036, step = 3901 (1.788 sec)\n",
"INFO:tensorflow:global_step/sec: 56.258\n",
"INFO:tensorflow:loss = 0.24928045, step = 4001 (1.782 sec)\n",
"INFO:tensorflow:global_step/sec: 56.1698\n",
"INFO:tensorflow:loss = 0.29253238, step = 4101 (1.780 sec)\n",
"INFO:tensorflow:global_step/sec: 56.3224\n",
"INFO:tensorflow:loss = 0.33170295, step = 4201 (1.771 sec)\n",
"INFO:tensorflow:global_step/sec: 56.0725\n",
"INFO:tensorflow:loss = 0.2857255, step = 4301 (1.784 sec)\n",
"INFO:tensorflow:global_step/sec: 55.2234\n",
"INFO:tensorflow:loss = 0.32770783, step = 4401 (1.811 sec)\n",
"INFO:tensorflow:global_step/sec: 56.9397\n",
"INFO:tensorflow:loss = 0.34567958, step = 4501 (1.756 sec)\n",
"INFO:tensorflow:global_step/sec: 56.588\n",
"INFO:tensorflow:loss = 0.22739732, step = 4601 (1.770 sec)\n",
"INFO:tensorflow:global_step/sec: 55.8973\n",
"INFO:tensorflow:loss = 0.23397678, step = 4701 (1.786 sec)\n",
"INFO:tensorflow:global_step/sec: 54.8537\n",
"INFO:tensorflow:loss = 0.28782395, step = 4801 (1.826 sec)\n",
"INFO:tensorflow:global_step/sec: 57.5266\n",
"INFO:tensorflow:loss = 0.34130412, step = 4901 (1.735 sec)\n",
"INFO:tensorflow:global_step/sec: 56.4719\n",
"INFO:tensorflow:loss = 0.2601814, step = 5001 (1.775 sec)\n",
"INFO:tensorflow:global_step/sec: 55.9204\n",
"INFO:tensorflow:loss = 0.23378342, step = 5101 (1.784 sec)\n",
"INFO:tensorflow:global_step/sec: 55.8136\n",
"INFO:tensorflow:loss = 0.25749898, step = 5201 (1.791 sec)\n",
"INFO:tensorflow:global_step/sec: 56.4817\n",
"INFO:tensorflow:loss = 0.2671554, step = 5301 (1.771 sec)\n",
"INFO:tensorflow:global_step/sec: 55.5648\n",
"INFO:tensorflow:loss = 0.34511262, step = 5401 (1.800 sec)\n",
"INFO:tensorflow:global_step/sec: 56.0035\n",
"INFO:tensorflow:loss = 0.28205428, step = 5501 (1.790 sec)\n",
"INFO:tensorflow:global_step/sec: 55.4771\n",
"INFO:tensorflow:loss = 0.32101482, step = 5601 (1.798 sec)\n",
"INFO:tensorflow:global_step/sec: 56.557\n",
"INFO:tensorflow:loss = 0.20183584, step = 5701 (1.768 sec)\n",
"INFO:tensorflow:global_step/sec: 56.2328\n",
"INFO:tensorflow:loss = 0.2667691, step = 5801 (1.779 sec)\n",
"INFO:tensorflow:global_step/sec: 55.831\n",
"INFO:tensorflow:loss = 0.29511333, step = 5901 (1.791 sec)\n",
"INFO:tensorflow:global_step/sec: 56.5377\n",
"INFO:tensorflow:loss = 0.2963462, step = 6001 (1.769 sec)\n",
"INFO:tensorflow:global_step/sec: 56.6478\n",
"INFO:tensorflow:loss = 0.2812236, step = 6101 (1.765 sec)\n",
"INFO:tensorflow:global_step/sec: 56.2698\n",
"INFO:tensorflow:loss = 0.1931965, step = 6201 (1.782 sec)\n",
"INFO:tensorflow:global_step/sec: 55.2199\n",
"INFO:tensorflow:loss = 0.3035059, step = 6301 (1.811 sec)\n",
"INFO:tensorflow:global_step/sec: 57.0946\n",
"INFO:tensorflow:loss = 0.22357331, step = 6401 (1.751 sec)\n",
"INFO:tensorflow:global_step/sec: 55.862\n",
"INFO:tensorflow:loss = 0.26666236, step = 6501 (1.786 sec)\n",
"INFO:tensorflow:global_step/sec: 56.1681\n",
"INFO:tensorflow:loss = 0.3552798, step = 6601 (1.785 sec)\n",
"INFO:tensorflow:global_step/sec: 55.0176\n",
"INFO:tensorflow:loss = 0.3320067, step = 6701 (1.812 sec)\n",
"INFO:tensorflow:global_step/sec: 56.4125\n",
"INFO:tensorflow:loss = 0.2941891, step = 6801 (1.773 sec)\n",
"INFO:tensorflow:global_step/sec: 55.7634\n",
"INFO:tensorflow:loss = 0.21167153, step = 6901 (1.797 sec)\n",
"INFO:tensorflow:global_step/sec: 55.9856\n",
"INFO:tensorflow:loss = 0.24919438, step = 7001 (1.783 sec)\n",
"INFO:tensorflow:global_step/sec: 55.8329\n",
"INFO:tensorflow:loss = 0.32960105, step = 7101 (1.794 sec)\n",
"INFO:tensorflow:global_step/sec: 56.8646\n",
"INFO:tensorflow:loss = 0.28020442, step = 7201 (1.756 sec)\n",
"INFO:tensorflow:global_step/sec: 55.5984\n",
"INFO:tensorflow:loss = 0.2793401, step = 7301 (1.799 sec)\n",
"INFO:tensorflow:global_step/sec: 53.6319\n",
"INFO:tensorflow:loss = 0.24506754, step = 7401 (1.865 sec)\n",
"INFO:tensorflow:global_step/sec: 51.0607\n",
"INFO:tensorflow:loss = 0.27951616, step = 7501 (1.958 sec)\n",
"INFO:tensorflow:global_step/sec: 50.6917\n",
"INFO:tensorflow:loss = 0.26933527, step = 7601 (1.973 sec)\n",
"INFO:tensorflow:global_step/sec: 50.3787\n",
"INFO:tensorflow:loss = 0.26389205, step = 7701 (1.985 sec)\n",
"INFO:tensorflow:global_step/sec: 49.2026\n",
"INFO:tensorflow:loss = 0.23955397, step = 7801 (2.032 sec)\n",
"INFO:tensorflow:global_step/sec: 51.8819\n",
"INFO:tensorflow:loss = 0.22410524, step = 7901 (1.935 sec)\n",
"INFO:tensorflow:global_step/sec: 50.6407\n",
"INFO:tensorflow:loss = 0.20797625, step = 8001 (1.967 sec)\n",
"INFO:tensorflow:global_step/sec: 50.0789\n",
"INFO:tensorflow:loss = 0.24727762, step = 8101 (1.997 sec)\n",
"INFO:tensorflow:global_step/sec: 50.1024\n",
"INFO:tensorflow:loss = 0.3366214, step = 8201 (1.995 sec)\n",
"INFO:tensorflow:global_step/sec: 56.7003\n",
"INFO:tensorflow:loss = 0.17994358, step = 8301 (1.764 sec)\n",
"INFO:tensorflow:global_step/sec: 55.6441\n",
"INFO:tensorflow:loss = 0.30137646, step = 8401 (1.798 sec)\n",
"INFO:tensorflow:global_step/sec: 56.0976\n",
"INFO:tensorflow:loss = 0.27949485, step = 8501 (1.782 sec)\n",
"INFO:tensorflow:global_step/sec: 55.5684\n",
"INFO:tensorflow:loss = 0.28376475, step = 8601 (1.802 sec)\n",
"INFO:tensorflow:global_step/sec: 56.2581\n",
"INFO:tensorflow:loss = 0.22538447, step = 8701 (1.779 sec)\n",
"INFO:tensorflow:global_step/sec: 55.945\n",
"INFO:tensorflow:loss = 0.25649977, step = 8801 (1.783 sec)\n",
"INFO:tensorflow:global_step/sec: 55.9041\n",
"INFO:tensorflow:loss = 0.29682606, step = 8901 (1.790 sec)\n",
"INFO:tensorflow:global_step/sec: 56.5568\n",
"INFO:tensorflow:loss = 0.3323989, step = 9001 (1.767 sec)\n",
"INFO:tensorflow:global_step/sec: 55.7399\n",
"INFO:tensorflow:loss = 0.221881, step = 9101 (1.794 sec)\n",
"INFO:tensorflow:global_step/sec: 56.3936\n",
"INFO:tensorflow:loss = 0.30686402, step = 9201 (1.773 sec)\n",
"INFO:tensorflow:global_step/sec: 55.4242\n",
"INFO:tensorflow:loss = 0.28778303, step = 9301 (1.804 sec)\n",
"INFO:tensorflow:Saving checkpoints for 9371 into /tmp/tmpceq0l35q/model.ckpt.\n",
"INFO:tensorflow:Loss for final step: 0.26044637.\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tensorflow.contrib.estimator.python.estimator.dnn.DNNEstimator at 0x7fb29a12e4a8>"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"metadata": {
"id": "dMgti0YmJO7F",
"colab_type": "code",
"outputId": "4c553356-cd06-49a8-e65f-97b64750bd7f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 360
}
},
"cell_type": "code",
"source": [
"# Define our eval input_fn and run eval\n",
"eval_input_fn = tf.estimator.inputs.numpy_input_fn({\"descriptions\": np.array(test_descriptions).astype(str)}, test_encoded.astype(np.int32), shuffle=False)\n",
"estimator.evaluate(input_fn=eval_input_fn)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"WARNING:tensorflow:Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n",
"WARNING:tensorflow:Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Starting evaluation at 2018-10-19-04:02:46\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Restoring parameters from /tmp/tmpceq0l35q/model.ckpt-9371\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Finished evaluation at 2018-10-19-04:02:59\n",
"INFO:tensorflow:Saving dict for global step 9371: auc = 0.9134479, auc_precision_recall = 0.7385525, average_loss = 0.2481405, global_step = 9371, loss = 0.24884044\n",
"INFO:tensorflow:Saving 'checkpoint_path' summary for global step 9371: /tmp/tmpceq0l35q/model.ckpt-9371\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'auc': 0.9134479,\n",
" 'auc_precision_recall': 0.7385525,\n",
" 'average_loss': 0.2481405,\n",
" 'global_step': 9371,\n",
" 'loss': 0.24884044}"
]
},
"metadata": {
"tags": []
},
"execution_count": 11
}
]
},
{
"metadata": {
"id": "mcPyCfmWABVO",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Generating predictions on new data\n",
"Now for the most fun part! Let's generate predictions on movie descriptions our model hasn't seen before. We'll define a list of 3 new descriptions (the comment after each one gives its actual genres) and create a `predict_input_fn` for them. Then, for each of the 3 movies, we'll display the 2 genres with the highest predicted probabilities. Each genre is scored independently, so the two percentages for a movie don't have to add up to 100%."
]
},
{
"metadata": {
"id": "ixlCKF6NEkTx",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Test our model on some raw description data\n",
"raw_test = [\n",
" \"An examination of our dietary choices and the food we put in our bodies. Based on Jonathan Safran Foer's memoir.\", # Documentary\n",
" \"After escaping an attack by what he claims was a 70-foot shark, Jonas Taylor must confront his fears to save those trapped in a sunken submersible.\", # Action, Adventure\n",
" \"A teenager tries to survive the last week of her disastrous eighth-grade year before leaving to start high school.\", # Comedy\n",
"]\n",
"\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "XHpMIWFsE4OB",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Generate predictions\n",
"predict_input_fn = tf.estimator.inputs.numpy_input_fn({\"descriptions\": np.array(raw_test).astype(str)}, shuffle=False)\n",
"results = estimator.predict(predict_input_fn)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "iMVzrHpPDvoy",
"colab_type": "code",
"outputId": "7241ee96-5c5d-404d-f50e-9062e54eede2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 306
}
},
"cell_type": "code",
"source": [
"# Display predictions\n",
"for movie_genres in results:\n",
" top_2 = movie_genres['probabilities'].argsort()[-2:][::-1]\n",
" for genre in top_2:\n",
" text_genre = encoder.classes_[genre]\n",
" print(text_genre + ': ' + str(round(movie_genres['probabilities'][genre] * 100, 2)) + '%')\n",
" print('')"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Restoring parameters from /tmp/tmpceq0l35q/model.ckpt-9371\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"Documentary: 97.67%\n",
"Comedy: 20.15%\n",
"\n",
"Horror: 78.77%\n",
"Science Fiction: 52.26%\n",
"\n",
"Comedy: 72.04%\n",
"Romance: 15.19%\n",
"\n"
],
"name": "stdout"
}
]
},
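{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"The display loop above treats each genre score independently. This is why the two percentages for a movie can add up to more than 100% (for example, Horror at 78.77% and Science Fiction at 52.26%). You would expect this if the estimator scores every genre with its own sigmoid, as a multi-label head does. Treat that as an assumption about the model configuration defined earlier in the notebook.\n",
"\n",
"The next cell is a minimal, self-contained sketch of the same top-k selection. Its genre names and logits are hypothetical values for illustration, not model output. To apply the sketch's `top_k_labels` helper to real predictions, call it with `movie_genres['probabilities']` and `encoder.classes_` inside a fresh loop over `estimator.predict(predict_input_fn)`, because the `results` generator above is already consumed."
]
},
{
"metadata": {
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def top_k_labels(probabilities, classes, k=2):\n",
"    # Same logic as the display loop above: argsort() sorts ascending, so take\n",
"    # the last k indices and reverse them to put the most confident label first.\n",
"    probabilities = np.asarray(probabilities)\n",
"    top = probabilities.argsort()[-k:][::-1]\n",
"    return [(classes[i], round(float(probabilities[i]) * 100, 2)) for i in top]\n",
"\n",
"# Hypothetical genre names and logits, for illustration only.\n",
"classes = ['Action', 'Comedy', 'Documentary', 'Horror']\n",
"logits = np.array([-1.0, 0.5, 2.0, -3.0])\n",
"\n",
"# Independent per-class sigmoids: each score lies in (0, 1), and the scores\n",
"# are not constrained to sum to 1, so two genres can both score above 50%.\n",
"probs = 1.0 / (1.0 + np.exp(-logits))\n",
"\n",
"print(top_k_labels(probs, classes))\n",
"print('Sum of all scores:', round(float(probs.sum()), 2))"
],
"execution_count": 0,
"outputs": []
},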
{
"metadata": {
"id": "CfZTfK-e7MJr",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}