Created
June 25, 2015 14:57
-
-
Save bcho/7fc750da00c130713bd6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import sqlite3\n", | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"\n", | |
"conn = sqlite3.connect('./records.db')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 电影票房预测\n", | |
"\n", | |
"\n", | |
"## 绪论\n", | |
"\n", | |
"数据挖掘是一种为了从海量数据中提取有价值信息而产生的数据处理技术。数据挖掘通过结合传统的数据分析方法和处理大数据的复杂算法,从数据中提炼出辅助决策的信息。数据挖掘的定义可以分别从技术层和商业层来看:\n", | |
"\n", | |
"- 从技术层面来看,数据挖掘是一种从大量数据中提取有用信息的过程;\n", | |
"- 从商业层面来看,数据挖掘是一种商业信息处理技术\n", | |
"\n", | |
"\n", | |
"而数据挖掘与传统的数据分析方法也有着本质的区别:数据挖掘是在没有明确假设的前提下挖掘信息和发现知识。因而数据挖掘得到的信息具有先前未知、有效和实用三个特征。\n", | |
"\n", | |
"\n", | |
"下面将介绍如何应用数据挖掘技术从影片数据中提取票房相关的信息。\n", | |
"\n", | |
"\n", | |
"## 问题描述\n", | |
"\n", | |
"票房是电影受大众欢迎程度的一个主要衡量标准。而一部电影是否卖座,除了和电影本身剧情、拍摄技术有关之外,还会和演员导演阵容、影片上映时间等因素相关。本报告尝试通过对历史电影票房及相关信息进行分析,来预测新电影的票房高低。\n", | |
"\n", | |
"\n", | |
"## 问题建模\n", | |
"\n", | |
"在本问题中,数据集是一系列电影,数据对象是一条电影记录;对电影票房的预测转化为一个分类问题。\n", | |
"\n", | |
"为了更加方便量化影片票房属性,我们可以将票房收入划分为高、中、低三档:\n", | |
"\n", | |
"\n", | |
"| 影片名称 | 影片票房(国内、美元) | 票房收入水平 |\n", | |
"|:--------|:-------------:|:----------:|\n", | |
"| Avatar (阿凡达) | 425000000 | 高 |\n", | |
"| The Zero Theorem (零点定理) | 257706 | 中 |\n", | |
"| Video Games: The Movie (电子游戏大电影) | 23043 | 低 |\n", | |
"\n", | |
"\n", | |
"而对于影片对象的描述,可以采用如下属性:\n", | |
"\n", | |
"- 发行季度(`release_season`,枚举属性,可选值有:spring / summer / autumn / winter)\n", | |
"- 影片类型(`genre`, 枚举属性)\n", | |
"- 影片制作人员水平(`staffs`, 枚举属性,可选值有: first-class / professional / amateur)\n", | |
"\n", | |
"\n", | |
"接下来就可以通过上述属性来对影片票房水平进行分类预测。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 数据获取及标记\n", | |
"\n", | |
"\n", | |
"### 数据概况\n", | |
"\n", | |
"在本实验中,数据从以下几个站点获取:\n", | |
"\n", | |
"- the-numbers.com\n", | |
"- imdb.com\n", | |
"\n", | |
"the-numbers.com 是一个提供完善电影数据的网站,在上面可以获取到相关电影的影片、票房、演员信息。\n", | |
"\n", | |
"imdb.com 全称是 Internet movie database,它提供了最大最全的电影及周边信息。\n", | |
"\n", | |
"\n", | |
"本实验使用了一个简单爬虫来从上述网站抓取影片信息,代码见附录。\n", | |
"\n", | |
"抓取后的信息格式如下:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div style=\"max-height:1000px;max-width:1500px;overflow:auto;\">\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>id</th>\n", | |
" <th>title</th>\n", | |
" <th>genre</th>\n", | |
" <th>release_date</th>\n", | |
" <th>budget</th>\n", | |
" <th>box_office_domestic</th>\n", | |
" <th>box_office_foreign</th>\n", | |
" <th>url</th>\n", | |
" <th>imdb_profile_url</th>\n", | |
" <th>source</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1502</td>\n", | |
" <td>Avatar</td>\n", | |
" <td>Action</td>\n", | |
" <td>2009-12-18 00:00:00</td>\n", | |
" <td>425000000</td>\n", | |
" <td>760507648</td>\n", | |
" <td>0</td>\n", | |
" <td>http://www.the-numbers.com/movie/Avatar#tab=su...</td>\n", | |
" <td>http://www.imdb.com/title/tt0499549/</td>\n", | |
" <td>www.the-numbers.com</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" id title genre release_date budget box_office_domestic \\\n", | |
"0 1502 Avatar Action 2009-12-18 00:00:00 425000000 760507648 \n", | |
"\n", | |
" box_office_foreign url \\\n", | |
"0 0 http://www.the-numbers.com/movie/Avatar#tab=su... \n", | |
"\n", | |
" imdb_profile_url source \n", | |
"0 http://www.imdb.com/title/tt0499549/ www.the-numbers.com " | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# 影片信息\n", | |
"\n", | |
"pd.read_sql('select * from movie where title = \"Avatar\"', conn)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div style=\"max-height:1000px;max-width:1500px;overflow:auto;\">\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>id</th>\n", | |
" <th>name</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>Jonathan Parker</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>2</td>\n", | |
" <td>York Alec Shackleton</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" id name\n", | |
"0 1 Jonathan Parker\n", | |
"1 2 York Alec Shackleton" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# 职员信息\n", | |
"\n", | |
"pd.read_sql('select * from staff limit 2', conn)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div style=\"max-height:1000px;max-width:1500px;overflow:auto;\">\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>影片数量</th>\n", | |
" <th>职员数量</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>22292</td>\n", | |
" <td>10552</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" 影片数量 职员数量\n", | |
"0 22292 10552" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"movies_count = pd.read_sql('select count(*) as 影片数量 from movie', conn)\n", | |
"staffs_count = pd.read_sql('select count(*) as 职员数量 from staff', conn)\n", | |
"\n", | |
"movies_count['职员数量'] = staffs_count['职员数量']\n", | |
"movies_count" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 数据清洗\n", | |
"\n", | |
"为了去除因为经济环境等第三方因素影响,实验中采用以下的数据:\n", | |
"\n", | |
"- 使用 2005 ~ 2013 年的影片记录作为训练集\n", | |
"- 使用 2014 年的影片记录作为测试集\n", | |
"- 对于上述训练集和测试集,去除票房小于 10000 的影片记录" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div style=\"max-height:1000px;max-width:1500px;overflow:auto;\">\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>训练集影片数</th>\n", | |
" <th>测试集影片数</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>4570</td>\n", | |
" <td>518</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" 训练集影片数 测试集影片数\n", | |
"0 4570 518" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sample_movies_count = pd.read_sql('select count(*) as 训练集影片数 from movie where release_date >= \"2005-01-01\" and release_date <= \"2013-12-31\" and box_office_domestic > 10000', conn)\n", | |
"query_movies_count = pd.read_sql('select count(*) as 测试集影片数 from movie where release_date >= \"2014-01-01\" and release_date <= \"2014-12-31\" and box_office_domestic > 10000', conn)\n", | |
"\n", | |
"sample_movies_count['测试集影片数'] = query_movies_count['测试集影片数']\n", | |
"sample_movies_count" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 数据标记\n", | |
"\n", | |
"\n", | |
"上述原始数据中,其中影片类型(genre)的分类已由 the-numbers.com 完成。\n", | |
"\n", | |
"而发行季度则按照以下方式方式划分:\n", | |
"\n", | |
"- 如果 `release_date` 月份为 11, 12, 1,则该影片发行季节为冬季;\n", | |
"- 如果 `release_date` 月份为 2, 3, 4,则该影片发行季节为春季;\n", | |
"- 如果 `release_date` 月份为 5, 6 7,则该影片发行季节为夏季;\n", | |
"- 如果 `release_date` 月份为 8, 9, 10,则该影片发行季节为秋季。\n", | |
"\n", | |
"\n", | |
"影片制作人员水平按照以下规则划分:\n", | |
"\n", | |
"1. 预计算每位制作人员参与过的影片票房的平均值 `avg(s_i)`\n", | |
"1. 对 avg(s_i) 求三中值(low 33%, medium 66.7%, high 100%)\n", | |
"1. 计算每部电影制作人员的影片票房平均值之和的平均值: `avg(m_i) = sum(avg(s_i) for s_i in movie['staff']) / len(movie['staff'])`\n", | |
"1. 如果 `avg(m_i)`:\n", | |
" * 小于 low, 则标记影片制作人员水平为 amateur (业余)\n", | |
" * 大于 low, 小于 medium,则标记影片制作人员水平为 professional (专业)\n", | |
" * 大于 medium, 小于 high,则标记影片制作人员水平为 first-class (一流)\n", | |
" \n", | |
" \n", | |
"因为测量数据集中的电影都已经上映,所以可以预先计算对应的票房水平并在后面用在检验:\n", | |
"\n", | |
"1. 对测量的影片票房求三中值(low 33%, medium 66.7%, high 100%)\n", | |
"1. 对每部影片的票房 (box_office_domestic):\n", | |
" * 小于 low, 则标记票房水平为 low\n", | |
" * 大于 low, 小于 medium, 则标记票房水平为 medium\n", | |
" * 大于 medium, 小于 high, 则标记票房水平为 high\n", | |
" \n", | |
" \n", | |
"通过上述方法即可完成对影片属性的标记:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div style=\"max-height:1000px;max-width:1500px;overflow:auto;\">\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>id</th>\n", | |
" <th>movie_id</th>\n", | |
" <th>release_season</th>\n", | |
" <th>genre</th>\n", | |
" <th>staffs</th>\n", | |
" <th>gross_level</th>\n", | |
" <th>title</th>\n", | |
" <th>box_office_domestic</th>\n", | |
" <th>release_date</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>337</td>\n", | |
" <td>1502</td>\n", | |
" <td>winter</td>\n", | |
" <td>action</td>\n", | |
" <td>first-class</td>\n", | |
" <td>high</td>\n", | |
" <td>Avatar</td>\n", | |
" <td>760507648</td>\n", | |
" <td>2009-12-18 00:00:00</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" id movie_id release_season genre staffs gross_level title \\\n", | |
"0 337 1502 winter action first-class high Avatar \n", | |
"\n", | |
" box_office_domestic release_date \n", | |
"0 760507648 2009-12-18 00:00:00 " | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pd.read_sql('select movie_property.*, movie.title, movie.box_office_domestic, movie.release_date from movie_property join movie on movie.id = movie_id where movie_id = 1502', conn)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 预测算法实现\n", | |
"\n", | |
"\n", | |
"本实验中采用以下 3 种算法来进行分类:\n", | |
"\n", | |
"- id3\n", | |
"- k 近邻\n", | |
"- 朴素贝叶斯\n", | |
"\n", | |
"\n", | |
"对于每个算法的分类结果,采用以下方法判定准确度:\n", | |
"\n", | |
"1. 使用算法对测试集进行求值 (预测值)\n", | |
"2. 将测试集的预测值和真实值进行比较,得出准确率\n", | |
"\n", | |
"\n", | |
"下面详细介绍各个算法的实现。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# 获取样本集\n", | |
"def get_sample_movies():\n", | |
" return pd.read_sql('select * from movie_property where id < 4571', conn)\n", | |
" \n", | |
" \n", | |
"# 获取测试集\n", | |
"def get_query_movies():\n", | |
" return pd.read_sql('select * from movie_property where id > 4570', conn)\n", | |
" \n", | |
"\n", | |
"# 用作分类的属性\n", | |
"basic_properties = ['release_season', 'genre', 'staffs']\n", | |
"\n", | |
"gross_level = ['high', 'medium', 'low']\n", | |
"\n", | |
"\n", | |
"# 计算准确率\n", | |
"def describe_algo(df):\n", | |
" wa_ratio = df[df['gross_level'] != df['predicted_gross_level']].count() / df.count()\n", | |
" print('错误率: {0}'.format(wa_ratio['gross_level']))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## id3" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"错误率: 0.3938223938223938\n" | |
] | |
} | |
], | |
"source": [ | |
"#%%timeit -n 1\n", | |
"# 数据集\n", | |
"id3_samples = get_sample_movies()\n", | |
"id3_queries = get_query_movies()\n", | |
"\n", | |
"\n", | |
"\n", | |
"# 计算数据集的熵\n", | |
"def entropy(df, field):\n", | |
" ratios = df.groupby(field).count() / len(df)\n", | |
" ratios = ratios.apply(lambda ratio: - ratio * np.log2(ratio))\n", | |
" key = ratios.keys()[0]\n", | |
" return ratios[key].values.sum()\n", | |
"\n", | |
"\n", | |
"# 计算信息增益\n", | |
"def entropy_with_fields(df, field, target_field):\n", | |
"\n", | |
" def sub_field(g):\n", | |
" return entropy(g, target_field) * len(g) / len(df)\n", | |
"\n", | |
" return df.groupby(field).agg(sub_field).sum()[0]\n", | |
"\n", | |
"\n", | |
"def select_best_field(df, selectable_fields, target_field):\n", | |
" fields = [i for i in df if i in selectable_fields]\n", | |
" best_field, best_gain = None, 2 << 30\n", | |
" for field in fields:\n", | |
" field_gain = entropy_with_fields(df, field, target_field)\n", | |
" if field_gain < best_gain:\n", | |
" best_field, best_gain = field, field_gain\n", | |
" return best_field\n", | |
"\n", | |
"\n", | |
"def split_with_field(df, field):\n", | |
" keys = df.groupby(field).groups.keys()\n", | |
" return {k: df[df[field] == k] for k in keys}\n", | |
"\n", | |
"\n", | |
"# 决策树结点\n", | |
"class ParentNode(object):\n", | |
" \n", | |
" def __init__(self, field):\n", | |
" self.field = field\n", | |
" self.children = {}\n", | |
" \n", | |
" def add_child(self, field_value, child_node):\n", | |
" self.children[field_value] = child_node\n", | |
" \n", | |
" \n", | |
"class LeaveNode(object):\n", | |
" \n", | |
" def __init__(self, value):\n", | |
" self.value = value\n", | |
" \n", | |
" \n", | |
"def all_same(columns):\n", | |
" if len(columns) < 1:\n", | |
" return True\n", | |
" for value in columns:\n", | |
" if value != columns.values[0]:\n", | |
" return False\n", | |
" return True\n", | |
" \n", | |
" \n", | |
"def build_tree(df, selectable_fields, target_field):\n", | |
" if all_same(df[target_field]) or not selectable_fields:\n", | |
" # FIXME why can't filter into a same value?\n", | |
" best_value = df[target_field].value_counts().index[0]\n", | |
" return LeaveNode(best_value)\n", | |
" \n", | |
" field = select_best_field(df, selectable_fields, target_field)\n", | |
" node = ParentNode(field)\n", | |
" \n", | |
" sub_selectable_fields = [i for i in selectable_fields if i != field]\n", | |
" for field_value, sub_df in split_with_field(df, field).items():\n", | |
" child_node = build_tree(sub_df, sub_selectable_fields, target_field)\n", | |
" node.add_child(field_value, child_node)\n", | |
" \n", | |
" return node\n", | |
"\n", | |
"\n", | |
"def decide(root, query):\n", | |
" if isinstance(root, LeaveNode):\n", | |
" return root.value\n", | |
" field_value = query[root.field]\n", | |
" if field_value not in root.children:\n", | |
" return 'N/A'\n", | |
" child = root.children[query[root.field]]\n", | |
" return decide(child, query)\n", | |
"\n", | |
"\n", | |
"# 构造决策树\n", | |
"root = build_tree(id3_samples, basic_properties, 'gross_level')\n", | |
"\n", | |
"\n", | |
"# 计算预测值\n", | |
"id3_queries['predicted_gross_level'] = [decide(root, m) for i, m in id3_queries.iterrows()]\n", | |
"describe_algo(id3_queries)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## k 近邻\n", | |
"\n", | |
"在算法中, K = 5" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"错误率: 0.45366795366795365\n" | |
] | |
} | |
], | |
"source": [ | |
"#%%timeit -n 1\n", | |
"\n", | |
"K = 5\n", | |
"\n", | |
"# 数据集\n", | |
"knn_samples = get_sample_movies()\n", | |
"knn_queries = get_query_movies()\n", | |
"\n", | |
"\n", | |
"# 因为这里的属性都是类别属性,所以使用汉明距离\n", | |
"def distance(ma, mb):\n", | |
" dis = 0\n", | |
" for prop in basic_properties:\n", | |
" if ma[prop] != mb[prop]:\n", | |
" dis += 1\n", | |
" return dis\n", | |
"\n", | |
"\n", | |
"from collections import defaultdict\n", | |
"\n", | |
"\n", | |
"def knn(query, samples, k, target_property):\n", | |
" dis = distance\n", | |
" distances = [(dis(query, sample), sample[target_property])\n", | |
" for sample in samples]\n", | |
" distances = sorted(distances, key=lambda x: x[0])\n", | |
" \n", | |
" k = min(k, len(distances))\n", | |
" k_first = defaultdict(lambda: 0.0)\n", | |
" for _, value in distances[:k]:\n", | |
" k_first[value] += 1\n", | |
" prediction, max_count, chosen = {}, - 1 << 30, None,\n", | |
" for value, count in k_first.items():\n", | |
" if count > max_count:\n", | |
" max_count, chosen = count, value\n", | |
" prediction[value] = count / k\n", | |
" return chosen, prediction\n", | |
"\n", | |
"\n", | |
"knn_movies_samples = [m for _, m in knn_samples.iterrows()]\n", | |
"\n", | |
"\n", | |
"def knn_many(queries, k=None):\n", | |
" k = k or K\n", | |
" return [knn(m, knn_movies_samples, k, 'gross_level')[0] for _, m in queries.iterrows()]\n", | |
"\n", | |
"\n", | |
"knn_queries['predicted_gross_level'] = knn_many(knn_queries)\n", | |
"describe_algo(knn_queries)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 朴素贝叶斯" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"错误率: 0.39575289575289574\n" | |
] | |
} | |
], | |
"source": [ | |
"#%%timeit -n 1\n", | |
"\n", | |
"# 数据集\n", | |
"nb_samples = get_sample_movies()\n", | |
"nb_queries = get_query_movies()\n", | |
"\n", | |
"\n", | |
"def p_key(pairs):\n", | |
" return '_'.join('{0}-{1}'.format(k, v) for k, v in pairs)\n", | |
"\n", | |
"\n", | |
"# 先验概率\n", | |
"P = {}\n", | |
"\n", | |
"\n", | |
"gross_level_groups = nb_samples.groupby('gross_level').groups\n", | |
"\n", | |
"P[p_key([('gross_level', 'high')])] = len(gross_level_groups['high']) / len(nb_samples)\n", | |
"P[p_key([('gross_level', 'medium')])] = len(gross_level_groups['medium']) / len(nb_samples)\n", | |
"P[p_key([('gross_level', 'low')])] = len(gross_level_groups['low']) / len(nb_samples)\n", | |
"\n", | |
"\n", | |
"def calculate_properity(field):\n", | |
" total = len(nb_samples)\n", | |
" for key in nb_samples.groupby(field).groups.keys():\n", | |
" for level in gross_level:\n", | |
" pk = p_key([(field, key), ('gross_level', level)])\n", | |
" c = nb_samples[(nb_samples[field] == key) & (nb_samples['gross_level'] == level)]\n", | |
" P[pk] = len(c) / total\n", | |
" \n", | |
"\n", | |
"# 计算先验概率\n", | |
"for field in basic_properties:\n", | |
" calculate_properity(field)\n", | |
" \n", | |
" \n", | |
"def predict(m, priori_prob=None):\n", | |
" priori_prob = priori_prob or P\n", | |
" \n", | |
" prediction, max_level, max_prob = {}, None, 0.0\n", | |
" for level in gross_level:\n", | |
" p_level = 1\n", | |
" for field in basic_properties:\n", | |
" pk = p_key([(field, m[field]), ('gross_level', level)])\n", | |
" p_level *= priori_prob.get(pk, 0) # TODO use laplace?\n", | |
" prediction[level] = p_level\n", | |
" if p_level > max_prob:\n", | |
" max_level, max_prob = level, p_level\n", | |
" \n", | |
" return max_level, prediction\n", | |
"\n", | |
"\n", | |
"nb_queries['predicted_gross_level'] = [predict(m)[0] for _, m in nb_queries.iterrows()]\n", | |
"describe_algo(nb_queries)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 运行结果\n", | |
"\n", | |
"从上述运行结果可以了解到各个算法准确率、运行效率:\n", | |
"\n", | |
"\n", | |
"| 算法名称 | 问题规模(查询条数) | 准确率 | 运行时间 |\n", | |
"|:-------:|:-----:|:-------:|:-------:|\n", | |
"| id3 |518 | 61% | 1.63 秒 |\n", | |
"| k 近邻 | 518 | 56% | 112 秒 |\n", | |
"| 朴素贝叶斯 | 518 | 61% | 0.261 秒 |\n", | |
"\n", | |
"\n", | |
"不难发现 3 种算法中 k 近邻算法效率和准确率都是最低。而 id3 算法和朴素贝叶斯算法准确率相当,但朴素贝叶斯算法在执行效率上要比 id3 算法高出不少。\n", | |
"\n", | |
"- K 近邻算法执行效率低的原因主要是因为每次查询都需要遍历整个样本空间;而准确率方面,还可以通过修改 K 值来进行调优。\n", | |
"- id3 算法对样本集的计算(相对朴素贝叶斯)比较复杂,需要额外的数据结构来支持,但可以通过离线计算生成决策树来进一步减少查询时间。\n", | |
"- 朴素贝叶斯算法实现最简单,执行效率也高。但如果属性较多并且值分布不合理时将会导致预测结果偏差变大。\n", | |
"- 从 id3 算法的判定树中可以发现演职员水平对判定结果影响较大,这和日常理解一致。\n", | |
"- 数据集选用的属性较少,可以适当添加其他属性来增加维度,提高预测准确率。\n", | |
"\n", | |
"\n", | |
"\n", | |
"\n", | |
"下面是使用算法对最近上映的电影进行票房预测的结果:\n", | |
"\n", | |
"#### 影片信息:\n", | |
"\n", | |
"| 属性 | 属性值 |\n", | |
"|---------|--------------------|\n", | |
"| 影片名称 | 复仇者联盟 2:奥创时代 |\n", | |
"| 影片发行季度 | summer |\n", | |
"| 影片制作人员水平 | first-class |\n", | |
"| 影片类型 | action |\n", | |
"\n", | |
"#### 预测结果:\n", | |
"\n", | |
"| 算法 | 预测结果 |\n", | |
"|:----:|:------:|\n", | |
"| id3 | 高 |\n", | |
"| K 近邻 | 高 |\n", | |
"| 朴素贝叶斯 | 高 |" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 总结\n", | |
"\n", | |
"通过本实验,根据数据挖掘的基本原理,完成了数据收集、数据清洗、数据分析、算法实现和结果分析的完整步骤,加深了我对课程知识的理解。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'high'" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"decide(root, {'release_season': 'summer', 'genre': 'action', 'staffs':'first-class'})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'high'" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"knn({'release_season': 'summer', 'genre': 'action', 'staffs':'first-class'}, knn_movies_samples, K, 'gross_level')[0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'high'" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"predict({'release_season': 'summer', 'genre': 'action', 'staffs':'first-class'})[0]" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.4.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment