Skip to content

Instantly share code, notes, and snippets.

@djhoese
Last active December 12, 2019 23:44
Show Gist options
  • Save djhoese/5b6f5dcf4c11c9ef2d862179bd18dafd to your computer and use it in GitHub Desktop.
Save djhoese/5b6f5dcf4c11c9ef2d862179bd18dafd to your computer and use it in GitHub Desktop.
Work for the Kaggle TMDB Box Office Prediction competition
*.png
test.csv
train.csv
sample_submission.csv
tmdb-box-office-prediction.zip
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5"
},
"outputs": [],
"source": [
"import os\n",
"import urllib\n",
"import urllib.error\n",
"import urllib.request\n",
"from time import sleep\n",
"\n",
"import dask.array as da\n",
"import dask.array.image\n",
"import numpy as np # linear algebra\n",
"import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
"from skimage import data, segmentation, color\n",
"from skimage.color import rgb2hsv\n",
"from skimage.future import graph\n",
"from skimage.io import imread\n",
"from sklearn.linear_model import LinearRegression\n",
"\n",
"# Any results you write to the current directory are saved as output.\n",
"# URL prefix for poster images; each row's poster_path is appended to it.\n",
"base_img_url = \"http://image.tmdb.org/t/p/w185\"\n",
"# Directory that holds train.csv, test.csv and sample_submission.csv.\n",
"# base_input = os.path.join('/', 'kaggle', 'input', 'tmdb-box-office-prediction')\n",
"base_input = '.'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
"_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>belongs_to_collection</th>\n",
" <th>budget</th>\n",
" <th>genres</th>\n",
" <th>homepage</th>\n",
" <th>imdb_id</th>\n",
" <th>original_language</th>\n",
" <th>original_title</th>\n",
" <th>overview</th>\n",
" <th>popularity</th>\n",
" <th>...</th>\n",
" <th>release_date</th>\n",
" <th>runtime</th>\n",
" <th>spoken_languages</th>\n",
" <th>status</th>\n",
" <th>tagline</th>\n",
" <th>title</th>\n",
" <th>Keywords</th>\n",
" <th>cast</th>\n",
" <th>crew</th>\n",
" <th>revenue</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>[{'id': 313576, 'name': 'Hot Tub Time Machine ...</td>\n",
" <td>14000000</td>\n",
" <td>[{'id': 35, 'name': 'Comedy'}]</td>\n",
" <td>NaN</td>\n",
" <td>tt2637294</td>\n",
" <td>en</td>\n",
" <td>Hot Tub Time Machine 2</td>\n",
" <td>When Lou, who has become the \"father of the In...</td>\n",
" <td>6.575393</td>\n",
" <td>...</td>\n",
" <td>2/20/15</td>\n",
" <td>93.0</td>\n",
" <td>[{'iso_639_1': 'en', 'name': 'English'}]</td>\n",
" <td>Released</td>\n",
" <td>The Laws of Space and Time are About to be Vio...</td>\n",
" <td>Hot Tub Time Machine 2</td>\n",
" <td>[{'id': 4379, 'name': 'time travel'}, {'id': 9...</td>\n",
" <td>[{'cast_id': 4, 'character': 'Lou', 'credit_id...</td>\n",
" <td>[{'credit_id': '59ac067c92514107af02c8c8', 'de...</td>\n",
" <td>12314651</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>[{'id': 107674, 'name': 'The Princess Diaries ...</td>\n",
" <td>40000000</td>\n",
" <td>[{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam...</td>\n",
" <td>NaN</td>\n",
" <td>tt0368933</td>\n",
" <td>en</td>\n",
" <td>The Princess Diaries 2: Royal Engagement</td>\n",
" <td>Mia Thermopolis is now a college graduate and ...</td>\n",
" <td>8.248895</td>\n",
" <td>...</td>\n",
" <td>8/6/04</td>\n",
" <td>113.0</td>\n",
" <td>[{'iso_639_1': 'en', 'name': 'English'}]</td>\n",
" <td>Released</td>\n",
" <td>It can take a lifetime to find true love; she'...</td>\n",
" <td>The Princess Diaries 2: Royal Engagement</td>\n",
" <td>[{'id': 2505, 'name': 'coronation'}, {'id': 42...</td>\n",
" <td>[{'cast_id': 1, 'character': 'Mia Thermopolis'...</td>\n",
" <td>[{'credit_id': '52fe43fe9251416c7502563d', 'de...</td>\n",
" <td>95149435</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>NaN</td>\n",
" <td>3300000</td>\n",
" <td>[{'id': 18, 'name': 'Drama'}]</td>\n",
" <td>http://sonyclassics.com/whiplash/</td>\n",
" <td>tt2582802</td>\n",
" <td>en</td>\n",
" <td>Whiplash</td>\n",
" <td>Under the direction of a ruthless instructor, ...</td>\n",
" <td>64.299990</td>\n",
" <td>...</td>\n",
" <td>10/10/14</td>\n",
" <td>105.0</td>\n",
" <td>[{'iso_639_1': 'en', 'name': 'English'}]</td>\n",
" <td>Released</td>\n",
" <td>The road to greatness can take you to the edge.</td>\n",
" <td>Whiplash</td>\n",
" <td>[{'id': 1416, 'name': 'jazz'}, {'id': 1523, 'n...</td>\n",
" <td>[{'cast_id': 5, 'character': 'Andrew Neimann',...</td>\n",
" <td>[{'credit_id': '54d5356ec3a3683ba0000039', 'de...</td>\n",
" <td>13092000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>NaN</td>\n",
" <td>1200000</td>\n",
" <td>[{'id': 53, 'name': 'Thriller'}, {'id': 18, 'n...</td>\n",
" <td>http://kahaanithefilm.com/</td>\n",
" <td>tt1821480</td>\n",
" <td>hi</td>\n",
" <td>Kahaani</td>\n",
" <td>Vidya Bagchi (Vidya Balan) arrives in Kolkata ...</td>\n",
" <td>3.174936</td>\n",
" <td>...</td>\n",
" <td>3/9/12</td>\n",
" <td>122.0</td>\n",
" <td>[{'iso_639_1': 'en', 'name': 'English'}, {'iso...</td>\n",
" <td>Released</td>\n",
" <td>NaN</td>\n",
" <td>Kahaani</td>\n",
" <td>[{'id': 10092, 'name': 'mystery'}, {'id': 1054...</td>\n",
" <td>[{'cast_id': 1, 'character': 'Vidya Bagchi', '...</td>\n",
" <td>[{'credit_id': '52fe48779251416c9108d6eb', 'de...</td>\n",
" <td>16000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>[{'id': 28, 'name': 'Action'}, {'id': 53, 'nam...</td>\n",
" <td>NaN</td>\n",
" <td>tt1380152</td>\n",
" <td>ko</td>\n",
" <td>마린보이</td>\n",
" <td>Marine Boy is the story of a former national s...</td>\n",
" <td>1.148070</td>\n",
" <td>...</td>\n",
" <td>2/5/09</td>\n",
" <td>118.0</td>\n",
" <td>[{'iso_639_1': 'ko', 'name': '한국어/조선말'}]</td>\n",
" <td>Released</td>\n",
" <td>NaN</td>\n",
" <td>Marine Boy</td>\n",
" <td>NaN</td>\n",
" <td>[{'cast_id': 3, 'character': 'Chun-soo', 'cred...</td>\n",
" <td>[{'credit_id': '52fe464b9251416c75073b43', 'de...</td>\n",
" <td>3923970</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 23 columns</p>\n",
"</div>"
],
"text/plain": [
" id belongs_to_collection budget \\\n",
"0 1 [{'id': 313576, 'name': 'Hot Tub Time Machine ... 14000000 \n",
"1 2 [{'id': 107674, 'name': 'The Princess Diaries ... 40000000 \n",
"2 3 NaN 3300000 \n",
"3 4 NaN 1200000 \n",
"4 5 NaN 0 \n",
"\n",
" genres \\\n",
"0 [{'id': 35, 'name': 'Comedy'}] \n",
"1 [{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam... \n",
"2 [{'id': 18, 'name': 'Drama'}] \n",
"3 [{'id': 53, 'name': 'Thriller'}, {'id': 18, 'n... \n",
"4 [{'id': 28, 'name': 'Action'}, {'id': 53, 'nam... \n",
"\n",
" homepage imdb_id original_language \\\n",
"0 NaN tt2637294 en \n",
"1 NaN tt0368933 en \n",
"2 http://sonyclassics.com/whiplash/ tt2582802 en \n",
"3 http://kahaanithefilm.com/ tt1821480 hi \n",
"4 NaN tt1380152 ko \n",
"\n",
" original_title \\\n",
"0 Hot Tub Time Machine 2 \n",
"1 The Princess Diaries 2: Royal Engagement \n",
"2 Whiplash \n",
"3 Kahaani \n",
"4 마린보이 \n",
"\n",
" overview popularity ... \\\n",
"0 When Lou, who has become the \"father of the In... 6.575393 ... \n",
"1 Mia Thermopolis is now a college graduate and ... 8.248895 ... \n",
"2 Under the direction of a ruthless instructor, ... 64.299990 ... \n",
"3 Vidya Bagchi (Vidya Balan) arrives in Kolkata ... 3.174936 ... \n",
"4 Marine Boy is the story of a former national s... 1.148070 ... \n",
"\n",
" release_date runtime spoken_languages \\\n",
"0 2/20/15 93.0 [{'iso_639_1': 'en', 'name': 'English'}] \n",
"1 8/6/04 113.0 [{'iso_639_1': 'en', 'name': 'English'}] \n",
"2 10/10/14 105.0 [{'iso_639_1': 'en', 'name': 'English'}] \n",
"3 3/9/12 122.0 [{'iso_639_1': 'en', 'name': 'English'}, {'iso... \n",
"4 2/5/09 118.0 [{'iso_639_1': 'ko', 'name': '한국어/조선말'}] \n",
"\n",
" status tagline \\\n",
"0 Released The Laws of Space and Time are About to be Vio... \n",
"1 Released It can take a lifetime to find true love; she'... \n",
"2 Released The road to greatness can take you to the edge. \n",
"3 Released NaN \n",
"4 Released NaN \n",
"\n",
" title \\\n",
"0 Hot Tub Time Machine 2 \n",
"1 The Princess Diaries 2: Royal Engagement \n",
"2 Whiplash \n",
"3 Kahaani \n",
"4 Marine Boy \n",
"\n",
" Keywords \\\n",
"0 [{'id': 4379, 'name': 'time travel'}, {'id': 9... \n",
"1 [{'id': 2505, 'name': 'coronation'}, {'id': 42... \n",
"2 [{'id': 1416, 'name': 'jazz'}, {'id': 1523, 'n... \n",
"3 [{'id': 10092, 'name': 'mystery'}, {'id': 1054... \n",
"4 NaN \n",
"\n",
" cast \\\n",
"0 [{'cast_id': 4, 'character': 'Lou', 'credit_id... \n",
"1 [{'cast_id': 1, 'character': 'Mia Thermopolis'... \n",
"2 [{'cast_id': 5, 'character': 'Andrew Neimann',... \n",
"3 [{'cast_id': 1, 'character': 'Vidya Bagchi', '... \n",
"4 [{'cast_id': 3, 'character': 'Chun-soo', 'cred... \n",
"\n",
" crew revenue \n",
"0 [{'credit_id': '59ac067c92514107af02c8c8', 'de... 12314651 \n",
"1 [{'credit_id': '52fe43fe9251416c7502563d', 'de... 95149435 \n",
"2 [{'credit_id': '54d5356ec3a3683ba0000039', 'de... 13092000 \n",
"3 [{'credit_id': '52fe48779251416c9108d6eb', 'de... 16000000 \n",
"4 [{'credit_id': '52fe464b9251416c75073b43', 'de... 3923970 \n",
"\n",
"[5 rows x 23 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Read the three competition CSVs from base_input, then preview the training rows.\n",
"train_data, test_data, sample_submission = (\n",
"    pd.read_csv(os.path.join(base_input, csv_name))\n",
"    for csv_name in ('train.csv', 'test.csv', 'sample_submission.csv')\n",
")\n",
"train_data.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['id', 'belongs_to_collection', 'budget', 'genres', 'homepage',\n",
" 'imdb_id', 'original_language', 'original_title', 'overview',\n",
" 'popularity', 'poster_path', 'production_companies',\n",
" 'production_countries', 'release_date', 'runtime', 'spoken_languages',\n",
" 'status', 'tagline', 'title', 'Keywords', 'cast', 'crew', 'revenue'],\n",
" dtype='object')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data.columns"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>imdb_id</th>\n",
" <th>title</th>\n",
" <th>poster_path</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>tt2637294</td>\n",
" <td>Hot Tub Time Machine 2</td>\n",
" <td>/tQtWuwvMf0hCc2QR2tkolwl7c3c.jpg</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>tt0368933</td>\n",
" <td>The Princess Diaries 2: Royal Engagement</td>\n",
" <td>/w9Z7A0GHEhIp7etpj0vyKOeU1Wx.jpg</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>tt2582802</td>\n",
" <td>Whiplash</td>\n",
" <td>/lIv1QinFqz4dlp5U4lQ6HaiskOZ.jpg</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>tt1821480</td>\n",
" <td>Kahaani</td>\n",
" <td>/aTXRaPrWSinhcmCrcfJK17urp3F.jpg</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>tt1380152</td>\n",
" <td>Marine Boy</td>\n",
" <td>/m22s7zvkVFDU9ir56PiiqIEWFdT.jpg</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id imdb_id title \\\n",
"0 1 tt2637294 Hot Tub Time Machine 2 \n",
"1 2 tt0368933 The Princess Diaries 2: Royal Engagement \n",
"2 3 tt2582802 Whiplash \n",
"3 4 tt1821480 Kahaani \n",
"4 5 tt1380152 Marine Boy \n",
"\n",
" poster_path \n",
"0 /tQtWuwvMf0hCc2QR2tkolwl7c3c.jpg \n",
"1 /w9Z7A0GHEhIp7etpj0vyKOeU1Wx.jpg \n",
"2 /lIv1QinFqz4dlp5U4lQ6HaiskOZ.jpg \n",
"3 /aTXRaPrWSinhcmCrcfJK17urp3F.jpg \n",
"4 /m22s7zvkVFDU9ir56PiiqIEWFdT.jpg "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[['id', 'imdb_id', 'title', 'poster_path']].head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading train data posters...\n",
"Could now download: http://image.tmdb.org/t/p/w185nan\n",
"Downloading test data posters...\n",
"Could now download: http://image.tmdb.org/t/p/w185nan\n"
]
}
],
"source": [
"def download_posters(urls, local_paths, base_dir='posters', sleep_time=0.1):\n",
"    \"\"\"Download every poster URL to its local path, reusing files already on disk.\n",
"\n",
"    Args:\n",
"        urls: iterable of remote poster URLs.\n",
"        local_paths: iterable of destination paths, paired positionally with ``urls``.\n",
"        base_dir: unused; kept so existing callers keep working.\n",
"        sleep_time: seconds to pause after each download attempt.\n",
"\n",
"    Returns:\n",
"        A list with one entry per (url, local_path) pair: the local path when\n",
"        the file is available, or an empty string when the server answered\n",
"        with an HTTP error.\n",
"    \"\"\"\n",
"    downloaded_files = []\n",
"    for url, local_path in zip(urls, local_paths):\n",
"        if os.path.isfile(local_path):\n",
"            # Already fetched by an earlier run; no request, no pause.\n",
"            downloaded_files.append(local_path)\n",
"            continue\n",
"\n",
"        # Fetch into a scratch file and rename afterwards, so an interrupted\n",
"        # transfer never leaves a truncated image that a later run would\n",
"        # mistake for a finished download.\n",
"        partial_path = local_path + '.part'\n",
"        try:\n",
"            urllib.request.urlretrieve(url, partial_path)\n",
"            os.replace(partial_path, local_path)\n",
"            downloaded_files.append(local_path)\n",
"        except urllib.error.HTTPError:\n",
"            print(\"Could not download: \", url)\n",
"            downloaded_files.append(\"\")\n",
"        finally:\n",
"            # Only present if the transfer stopped before the rename.\n",
"            if os.path.isfile(partial_path):\n",
"                os.remove(partial_path)\n",
"        sleep(sleep_time)\n",
"    return downloaded_files\n",
"\n",
"os.makedirs('posters', exist_ok=True)\n",
"\n",
"\n",
"def fetch_posters(poster_paths):\n",
"    \"\"\"Download the posters named in ``poster_paths``; return one local path per row.\n",
"\n",
"    Rows without a poster path (NaN in the CSV) map straight to an empty\n",
"    string, so no request is made for a bogus '.../w185nan' URL.\n",
"    \"\"\"\n",
"    known = [p for p in poster_paths if isinstance(p, str)]\n",
"    # poster_path values start with '/', so plain concatenation yields\n",
"    # '<base_img_url>/<file>' and 'posters/<file>'.\n",
"    urls = [\"{}{}\".format(base_img_url, p) for p in known]\n",
"    targets = [\"{}{}\".format('posters', p) for p in known]\n",
"    local_by_path = dict(zip(known, download_posters(urls, targets)))\n",
"    return [local_by_path.get(p, \"\") for p in poster_paths]\n",
"\n",
"\n",
"print(\"Downloading train data posters...\")\n",
"train_posters = fetch_posters(train_data['poster_path'])\n",
"train_data['poster_url'] = train_posters\n",
"\n",
"print(\"Downloading test data posters...\")\n",
"test_posters = fetch_posters(test_data['poster_path'])\n",
"test_data['poster_url'] = test_posters"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/jpeg": "\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import Image\n",
"Image(train_data['poster_url'][0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Filters"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Keep only the rows whose poster is available. The explicit copy makes the\n",
"# result an independent frame, so adding columns to it later does not trigger\n",
"# pandas' SettingWithCopyWarning.\n",
"train_data = train_data[train_data['poster_url'] != ''].copy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The Math"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"# Region Adjacency Graph (RAG) and normalized cut\n",
"def get_normal_cut(img):\n",
"    \"\"\"Segment ``img`` with SLIC, then merge the segments with a normalized cut.\n",
"\n",
"    Returns a ``(slic_labels, cut_labels)`` tuple: the label image produced by\n",
"    ``segmentation.slic`` and the one produced by ``graph.cut_normalized`` on\n",
"    the mean-color region adjacency graph built from it.\n",
"    \"\"\"\n",
"    slic_labels = segmentation.slic(img, n_segments=100, compactness=30)\n",
"    similarity_rag = graph.rag_mean_color(img, slic_labels, mode='similarity')\n",
"    return slic_labels, graph.cut_normalized(slic_labels, similarity_rag)\n",
"\n",
"def get_unique_colors(filename):\n",
"    \"\"\"Count the distinct normalized-cut labels in the poster at ``filename``.\n",
"\n",
"    An empty ``filename`` marks a row without a poster and yields 0 without\n",
"    reading anything from disk.\n",
"    \"\"\"\n",
"    has_poster = filename != ''\n",
"    if has_poster:\n",
"        print(\"Reading the image...\")\n",
"        poster = imread(filename)\n",
"        print(\"Doing normal cut...\")\n",
"        _, cut_labels = get_normal_cut(poster)\n",
"        print(\"Doing unique...\")\n",
"        return np.unique(cut_labels).size\n",
"    # in the test data, if we don't have a poster\n",
"    return 0"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x576 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"\n",
"# Positional lookup: rows without a poster were filtered out earlier, so the\n",
"# index label 0 is not guaranteed to survive, whereas .iloc[0] always picks\n",
"# the first remaining row.\n",
"img = imread(train_data['poster_url'].iloc[0])\n",
"labels1, labels2 = get_normal_cut(img)\n",
"# Paint every segment with its average color so the two stages can be compared.\n",
"out1 = color.label2rgb(labels1, img, kind='avg')\n",
"out2 = color.label2rgb(labels2, img, kind='avg')\n",
"\n",
"fig, ax = plt.subplots(nrows=2, sharex=True, sharey=True, figsize=(6, 8))\n",
"ax[0].imshow(out1)\n",
"ax[0].set_title('SLIC segments')\n",
"ax[1].imshow(out2)\n",
"ax[1].set_title('After normalized cut')\n",
"\n",
"for a in ax:\n",
"    a.axis('off')\n",
"\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %%time\n",
"# results = []\n",
"# for idx, fn in enumerate(train_data['poster_url']):\n",
"# print(idx)\n",
"# results.append(get_unique_colors(fn))\n",
"# len(results)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table style=\"border: 2px solid white;\">\n",
"<tr>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3 style=\"text-align: left;\">Client</h3>\n",
"<ul style=\"text-align: left; list-style: none; margin: 0; padding: 0;\">\n",
" <li><b>Scheduler: </b>tcp://127.0.0.1:44629</li>\n",
" <li><b>Dashboard: </b><a href='http://127.0.0.1:8787/status' target='_blank'>http://127.0.0.1:8787/status</a>\n",
"</ul>\n",
"</td>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3 style=\"text-align: left;\">Cluster</h3>\n",
"<ul style=\"text-align: left; list-style:none; margin: 0; padding: 0;\">\n",
" <li><b>Workers: </b>4</li>\n",
" <li><b>Cores: </b>12</li>\n",
" <li><b>Memory: </b>67.15 GB</li>\n",
"</ul>\n",
"</td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"<Client: 'tcp://127.0.0.1:44629' processes=4 threads=12, memory=67.15 GB>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from dask.distributed import Client, LocalCluster\n",
"os.environ['OMP_NUM_THREADS'] = \"3\"\n",
"client = Client()\n",
"client"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"# Training data\n",
"# futures = client.map(get_unique_colors, train_data['poster_url'].tolist())\n",
"# train_results = client.gather(futures)\n",
"# np.save('train_num_colors.npy', np.array(train_results))\n",
"\n",
"# Test data\n",
"# futures = client.map(get_unique_colors, test_data['poster_url'].tolist())\n",
"# test_results = client.gather(futures)\n",
"# np.save('test_num_colors.npy', np.array(test_results))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load from cache\n",
"def load_or_compute_colors(cache_path, filenames):\n",
"    \"\"\"Return the per-poster label counts, computing and caching them on a miss.\n",
"\n",
"    When ``cache_path`` exists it is loaded as-is. Otherwise\n",
"    ``get_unique_colors`` is mapped over ``filenames`` on the dask client and\n",
"    the resulting array is saved to ``cache_path`` before being returned.\n",
"    \"\"\"\n",
"    if os.path.isfile(cache_path):\n",
"        return np.load(cache_path)\n",
"    # Cache miss: without this branch a fresh checkout fails on np.load.\n",
"    futures = client.map(get_unique_colors, list(filenames))\n",
"    counts = np.array(client.gather(futures))\n",
"    np.save(cache_path, counts)\n",
"    return counts\n",
"\n",
"\n",
"train_results = load_or_compute_colors('train_num_colors.npy', train_data['poster_url'])\n",
"test_results = load_or_compute_colors('test_num_colors.npy', test_data['poster_url'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# client.close()\n",
"# cluster.close()\n",
"# del client, cluster"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.567209996038492\n"
]
}
],
"source": [
"# The cached arrays are attached positionally, so they must line up with the frames.\n",
"assert len(train_results) == len(train_data), 'train color cache does not match the training rows'\n",
"assert len(test_results) == len(test_data), 'test color cache does not match the test rows'\n",
"\n",
"# assign() builds new frames instead of writing into train_data, which is a\n",
"# filtered selection at this point (in-place writes can raise\n",
"# SettingWithCopyWarning); it also keeps this cell safe to re-run.\n",
"train_data = train_data.assign(poster_colors=train_results)\n",
"test_data = test_data.assign(poster_colors=test_results)\n",
"\n",
"feature_columns = ['budget', 'poster_colors']\n",
"X = train_data[feature_columns]\n",
"Y = train_data['revenue']\n",
"lm = LinearRegression()\n",
"reg = lm.fit(X, Y)\n",
"# Score is computed on the same rows the model was fitted on (no hold-out split).\n",
"print(reg.score(X, Y))\n",
"predicted_revenue = reg.predict(test_data[feature_columns])"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 2409678.02719629, -478081.45831937, 1364883.70220721, ...,\n",
" 49078702.54167123, 14022324.70160154, 7813012.1899247 ])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predicted_revenue"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.6139207260503408"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same linear fit, with 'popularity' taking the place of 'poster_colors'.\n",
"Y = train_data['revenue']\n",
"X = train_data.loc[:, ['budget', 'popularity']]\n",
"lm = LinearRegression()\n",
"lm.fit(X, Y)\n",
"reg = lm  # fit() hands back the estimator itself\n",
"reg.score(X, Y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
name: box_office_kaggle
channels:
- conda-forge
- defaults
dependencies:
- python=3.7
- dask-image
- pandas
- scikit-image
- scikit-learn
- scipy
- dask
- matplotlib
- dask_labextension
- jupyterlab
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment