

@hyoban
Created October 19, 2021 07:52
Student Grade Classification
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Student Grade Classification",
"provenance": [],
"toc_visible": true,
"authorship_tag": "ABX9TyNDT9c9JZD+l0wV/bhTLPUK",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/hyoban/fdbd75a521d27e09ed8e83870b522317/.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TXQegZc8gBFr"
},
"source": [
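"# Predicting Student Grade Levels from the Student Dataset\n",
"\n",
"## Data\n",
"\n",
"student.csv. Some of the fields are described on page 41 of the decision tree lecture slides. The last column, G3, is the label.\n",
"\n",
"## Requirements\n",
"\n",
"1. Use a decision tree to predict student grade levels, selecting some or all of the features. Analyze how the parameters affect the results, tune them, and visualize the decision tree\n",
"2. Use at least two common ensemble learning methods to predict student grade levels. Analyze how the parameters affect the results and tune them\n",
"3. Compare and analyze the results of the methods above and draw conclusions\n",
"\n",
"## Software Environment\n",
"\n",
"Jupyter Notebook\n",
"\n",
"## Submission Format\n",
"\n",
"An electronic lab report with written explanations, code, and screenshots of the results, submitted to the course's multi-mode teaching website\n",
"\n",
"Describe the process above in writing, with screenshots of the process and the results, covering feature selection, data processing, the tuning process, the trained models, the evaluation results, and the concluding analysis\n",
"\n",
"## Deadline\n",
"\n",
"Before November 10"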
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KhfTob_jg_YF"
},
"source": [
"## Import dependencies"
]
},
{
"cell_type": "code",
"metadata": {
"id": "nOTCYRmbMS31"
},
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.preprocessing import OneHotEncoder, LabelEncoder\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"from IPython.display import Image\n",
"from sklearn import tree\n",
"import pydotplus\n",
"\n",
"from sklearn.metrics import accuracy_score"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gqg6286AhH7Y"
},
"source": [
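"## Define data-processing functions\n",
"\n",
"Map student grades to levels by score range: below 5 is bad, 5 to 9 medium, 10 to 14 good, and 15 or above excellent"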
]
},
{
"cell_type": "code",
"metadata": {
"id": "3DRQOu0HMp47"
},
"source": [
"def grade_to_level(x):\n",
"    # Map a numeric grade to a level: <5 bad, 5-9 medium, 10-14 good, >=15 excellent\n",
"    x = int(x)\n",
"    if x < 5:\n",
"        return 'bad'\n",
"    elif x < 10:\n",
"        return 'medium'\n",
"    elif x < 15:\n",
"        return 'good'\n",
"    else:\n",
"        return 'excellent'\n",
"\n",
"def pedu_to_level(x):\n",
"    # Pedu can be fractional (e.g. 2.5), so compare it as a float instead of truncating with int()\n",
"    x = float(x)\n",
"    if x > 3:\n",
"        return 'high'\n",
"    elif x > 1.5:\n",
"        return 'medium'\n",
"    else:\n",
"        return 'low'"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "swm_HXWZhdHx"
},
"source": [
"## Read the data from the CSV file"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 439
},
"id": "jY2Er6vRacTb",
"outputId": "6088dafc-3adf-47a7-bcde-6afd697a5f34"
},
"source": [
"stu_grade = pd.read_csv('student_1.csv')\n",
"stu_grade"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>school</th>\n",
" <th>sex</th>\n",
" <th>address</th>\n",
" <th>Pstatus</th>\n",
" <th>Pedu</th>\n",
" <th>reason</th>\n",
" <th>guardian</th>\n",
" <th>traveltime</th>\n",
" <th>studytime</th>\n",
" <th>schoolsup</th>\n",
" <th>famsup</th>\n",
" <th>paid</th>\n",
" <th>activities</th>\n",
" <th>nursery</th>\n",
" <th>higher</th>\n",
" <th>internet</th>\n",
" <th>romantic</th>\n",
" <th>famrel</th>\n",
" <th>freetime</th>\n",
" <th>goout</th>\n",
" <th>Dalc</th>\n",
" <th>Walc</th>\n",
" <th>health</th>\n",
" <th>absences</th>\n",
" <th>G1</th>\n",
" <th>G2</th>\n",
" <th>G3</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>GP</td>\n",
" <td>F</td>\n",
" <td>U</td>\n",
" <td>A</td>\n",
" <td>4.0</td>\n",
" <td>course</td>\n",
" <td>mother</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" <td>5</td>\n",
" <td>6</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GP</td>\n",
" <td>F</td>\n",
" <td>U</td>\n",
" <td>T</td>\n",
" <td>1.0</td>\n",
" <td>course</td>\n",
" <td>father</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GP</td>\n",
" <td>F</td>\n",
" <td>U</td>\n",
" <td>T</td>\n",
" <td>1.0</td>\n",
" <td>other</td>\n",
" <td>mother</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>10</td>\n",
" <td>7</td>\n",
" <td>8</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GP</td>\n",
" <td>F</td>\n",
" <td>U</td>\n",
" <td>T</td>\n",
" <td>3.0</td>\n",
" <td>home</td>\n",
" <td>mother</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" <td>2</td>\n",
" <td>15</td>\n",
" <td>14</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>GP</td>\n",
" <td>F</td>\n",
" <td>U</td>\n",
" <td>T</td>\n",
" <td>3.0</td>\n",
" <td>home</td>\n",
" <td>father</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>4</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>6</td>\n",
" <td>10</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>390</th>\n",
" <td>MS</td>\n",
" <td>M</td>\n",
" <td>U</td>\n",
" <td>A</td>\n",
" <td>2.0</td>\n",
" <td>course</td>\n",
" <td>other</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>5</td>\n",
" <td>4</td>\n",
" <td>11</td>\n",
" <td>9</td>\n",
" <td>9</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>391</th>\n",
" <td>MS</td>\n",
" <td>M</td>\n",
" <td>U</td>\n",
" <td>T</td>\n",
" <td>2.0</td>\n",
" <td>course</td>\n",
" <td>mother</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>14</td>\n",
" <td>16</td>\n",
" <td>16</td>\n",
" </tr>\n",
" <tr>\n",
" <th>392</th>\n",
" <td>MS</td>\n",
" <td>M</td>\n",
" <td>R</td>\n",
" <td>T</td>\n",
" <td>1.0</td>\n",
" <td>course</td>\n",
" <td>other</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>10</td>\n",
" <td>8</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <th>393</th>\n",
" <td>MS</td>\n",
" <td>M</td>\n",
" <td>R</td>\n",
" <td>T</td>\n",
" <td>2.5</td>\n",
" <td>course</td>\n",
" <td>mother</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>11</td>\n",
" <td>12</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>394</th>\n",
" <td>MS</td>\n",
" <td>M</td>\n",
" <td>U</td>\n",
" <td>T</td>\n",
" <td>1.0</td>\n",
" <td>course</td>\n",
" <td>father</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" <td>5</td>\n",
" <td>8</td>\n",
" <td>9</td>\n",
" <td>9</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>395 rows × 27 columns</p>\n",
"</div>"
],
"text/plain": [
" school sex address Pstatus Pedu ... health absences G1 G2 G3\n",
"0 GP F U A 4.0 ... 3 6 5 6 6\n",
"1 GP F U T 1.0 ... 3 4 5 5 6\n",
"2 GP F U T 1.0 ... 3 10 7 8 10\n",
"3 GP F U T 3.0 ... 5 2 15 14 15\n",
"4 GP F U T 3.0 ... 5 4 6 10 10\n",
".. ... .. ... ... ... ... ... ... .. .. ..\n",
"390 MS M U A 2.0 ... 4 11 9 9 9\n",
"391 MS M U T 2.0 ... 2 3 14 16 16\n",
"392 MS M R T 1.0 ... 3 3 10 8 7\n",
"393 MS M R T 2.5 ... 5 0 11 12 10\n",
"394 MS M U T 1.0 ... 5 5 8 9 9\n",
"\n",
"[395 rows x 27 columns]"
]
},
"metadata": {},
"execution_count": 73
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FjXV1lMohi9e"
},
"source": [
"## Encode the non-numeric (string) features"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"id": "SakyYzSMMsor",
"outputId": "3dd3918a-469d-469b-b802-695f5015bce6"
},
"source": [
"# Select the features to use (by column position):\n",
"# address, Pstatus, Pedu, studytime, schoolsup, famsup, paid, higher, internet, G1, G2, G3\n",
"new_data = stu_grade.iloc[:, [2, 3, 4, 8, 9, 10, 11, 14, 15, 24, 25, 26]]\n",
"\n",
"# Bin the grades into levels\n",
"stu_data = new_data.copy()\n",
"for col in ['G1', 'G2', 'G3']:\n",
"    stu_data[col] = stu_data[col].map(grade_to_level)\n",
"\n",
"# Find the columns that still hold strings (object dtype) and need numeric encoding\n",
"str_columns = stu_data.dtypes[stu_data.dtypes == 'object'].index\n",
"\n",
"# Label-encode each string column (LabelEncoder sorts labels alphabetically,\n",
"# so the grade levels become bad=0, excellent=1, good=2, medium=3)\n",
"for col in str_columns:\n",
"    stu_data[col] = LabelEncoder().fit_transform(stu_data[col])\n",
"\n",
"# One-hot encode the non-grade categorical features; G1, G2 and G3 keep their label codes\n",
"stu_data = pd.get_dummies(stu_data, columns=str_columns.drop(['G1', 'G2', 'G3']))\n",
"\n",
"stu_data.head()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pedu</th>\n",
" <th>studytime</th>\n",
" <th>G1</th>\n",
" <th>G2</th>\n",
" <th>G3</th>\n",
" <th>address_0</th>\n",
" <th>address_1</th>\n",
" <th>Pstatus_0</th>\n",
" <th>Pstatus_1</th>\n",
" <th>schoolsup_0</th>\n",
" <th>schoolsup_1</th>\n",
" <th>famsup_0</th>\n",
" <th>famsup_1</th>\n",
" <th>paid_0</th>\n",
" <th>paid_1</th>\n",
" <th>higher_0</th>\n",
" <th>higher_1</th>\n",
" <th>internet_0</th>\n",
" <th>internet_1</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>4.0</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.0</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.0</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3.0</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.0</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pedu studytime G1 G2 ... higher_0 higher_1 internet_0 internet_1\n",
"0 4.0 2 3 3 ... 0 1 1 0\n",
"1 1.0 2 3 3 ... 0 1 0 1\n",
"2 1.0 2 3 3 ... 0 1 0 1\n",
"3 3.0 3 1 2 ... 0 1 0 1\n",
"4 3.0 2 3 2 ... 0 1 1 0\n",
"\n",
"[5 rows x 19 columns]"
]
},
"metadata": {},
"execution_count": 74
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NrqpDIFuh2bp"
},
"source": [
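"## Train the model\n",
"\n",
"Limit the maximum depth of the decision tree (max_depth=2) to reduce overfitting and keep the tree readable\n",
"\n",
"## Visualize the decision tree"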
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 436
},
"id": "PGlg7OzxM1ZU",
"outputId": "3e16f772-2a23-4362-8ad0-cda82fbd6caa"
},
"source": [
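"# Select the label by name: after get_dummies, G3 is no longer the last column,\n",
"# so iloc[:, :-1] would keep G3 as a feature (label leakage) and drop internet_1\n",
"X = stu_data.drop(columns='G3')\n",
"y = stu_data['G3']\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=5)\n",
"\n",
"# A shallow tree (max_depth=2) keeps the model simple and the plot readable\n",
"dt_model = DecisionTreeClassifier(random_state=5, max_depth=2)\n",
"dt_model.fit(X_train, y_train)\n",
"\n",
"# LabelEncoder sorted the grade levels alphabetically: 0=bad, 1=excellent, 2=good, 3=medium\n",
"level_names = ['bad', 'excellent', 'good', 'medium']\n",
"dot_data = tree.export_graphviz(dt_model,\n",
"                                out_file=None,\n",
"                                feature_names=list(X_train.columns),\n",
"                                class_names=[level_names[c] for c in dt_model.classes_],\n",
"                                filled=True,\n",
"                                rounded=True,\n",
"                                special_characters=True)\n",
"graph = pydotplus.graph_from_dot_data(dot_data)\n",
"Image(graph.create_png())"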
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"image/png": "\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {},
"execution_count": 75
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JEZGmQ5ziFBO"
},
"source": [
"## Evaluate the model (test-set accuracy)"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c3LTmTItN3DH",
"outputId": "1c8363a2-9c23-42aa-dc7c-7b3d1e2b9e0e"
},
"source": [
"y_pred = dt_model.predict(X_test)\n",
"accuracy_score(y_test, y_pred)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.8823529411764706"
]
},
"metadata": {},
"execution_count": 76
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DwrDSDeUiJtj"
},
"source": [
"## Select the best parameters with grid search"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "EQYGwHELOeNG",
"outputId": "f8768380-1e85-4bf3-d571-840bfdcd5608"
},
"source": [
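"# Candidate min_impurity_decrease values for each impurity criterion\n",
"entropy_thresholds = np.linspace(0, 1, 50)\n",
"gini_thresholds = np.linspace(0, 0.5, 50)\n",
"\n",
"# param_grid is a list of dicts: each dict's grid is searched separately,\n",
"# so max_depth and min_samples_split are tuned with the other parameters at their defaults\n",
"param_grid = [{'criterion': ['entropy'],\n",
"               'min_impurity_decrease': entropy_thresholds},\n",
"              {'criterion': ['gini'],\n",
"               'min_impurity_decrease': gini_thresholds},\n",
"              {'max_depth': range(2, 10)},\n",
"              {'min_samples_split': range(2, 30, 3)}]\n",
"\n",
"# Fix random_state so the search is reproducible\n",
"clf = GridSearchCV(DecisionTreeClassifier(random_state=5), param_grid, cv=5, return_train_score=True)\n",
"clf.fit(X_train, y_train)\n",
"clf.best_params_, clf.best_score_"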
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"({'criterion': 'entropy', 'min_impurity_decrease': 0.04081632653061224},\n",
" 0.854935064935065)"
]
},
"metadata": {},
"execution_count": 30
}
]
}
]
}
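The `grade_to_level` helper in the notebook bins one grade at a time, and the later `LabelEncoder` step numbers the resulting levels alphabetically (bad=0, excellent=1, good=2, medium=3), which loses their natural order. Below is a minimal vectorised sketch that assumes the grades are numeric values in a pandas Series. It uses `pandas.cut` with the same cut points and keeps the order as an ordered categorical. `grades_to_levels`, `GRADE_BINS` and `GRADE_LABELS` are illustrative names and are not part of the original notebook.

```python
import numpy as np
import pandas as pd

# Same cut points as grade_to_level: <5 bad, 5-9 medium, 10-14 good, 15 and above excellent
GRADE_BINS = [-np.inf, 5, 10, 15, np.inf]
GRADE_LABELS = ['bad', 'medium', 'good', 'excellent']

def grades_to_levels(grades: pd.Series) -> pd.Series:
    """Vectorised binning (hypothetical helper); returns an ordered categorical."""
    # right=False makes each bin [low, high), matching the x < 5 / x < 10 / x < 15 tests
    return pd.cut(grades, bins=GRADE_BINS, labels=GRADE_LABELS, right=False)

scores = pd.Series([0, 4, 5, 9, 10, 14, 15, 20], name='G3')
levels = grades_to_levels(scores)
print(levels.tolist())
# ['bad', 'bad', 'medium', 'medium', 'good', 'good', 'excellent', 'excellent']

# Ordinal codes follow bad < medium < good < excellent instead of alphabetical order
print(levels.cat.codes.tolist())
# [0, 0, 1, 1, 2, 2, 3, 3]
```

If an ordinal encoding is preferred, `.cat.codes` can replace the label-encoded G1, G2 and G3.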
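Choosing the label by position is fragile. `pandas.get_dummies` keeps the untouched columns in their original order and appends the new dummy columns at the end. After encoding, the last column of `stu_data` is therefore `internet_1`, as the `stu_data.head()` output shows, so `stu_data.iloc[:, :-1]` would leave G3 among the features. The sketch below reproduces this on a tiny hypothetical frame and then selects the label by name instead.

```python
import pandas as pd

# Tiny hypothetical frame shaped like stu_data before one-hot encoding
df = pd.DataFrame({
    'studytime': [2, 2, 3, 1],
    'G3': [3, 3, 1, 0],
    'address': ['U', 'U', 'R', 'U'],
})

encoded = pd.get_dummies(df, columns=['address'])
print(list(encoded.columns))
# ['studytime', 'G3', 'address_R', 'address_U']  (the label is NOT the last column)

# Select the label by name and drop it from the features explicitly
X = encoded.drop(columns='G3')
y = encoded['G3']
print(list(X.columns))
# ['studytime', 'address_R', 'address_U']
```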
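The notebook sets `max_depth=2` by hand. To see how depth affects accuracy, you can sweep it and compare mean cross-validated accuracy for each value. This sketch uses synthetic 4-class data from `make_classification` as a stand-in for the encoded student features. That is an assumption: to run it on the real data, pass the notebook's `X_train` and `y_train` instead.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic 4-class stand-in for the encoded student features (assumption, not the real data)
X, y = make_classification(n_samples=300, n_features=10, n_informative=4,
                           n_classes=4, n_clusters_per_class=1, random_state=0)

# Mean 5-fold cross-validated accuracy for each candidate depth
depth_scores = {}
for depth in range(1, 11):
    model = DecisionTreeClassifier(max_depth=depth, random_state=0)
    depth_scores[depth] = cross_val_score(model, X, y, cv=5).mean()

for depth, score in depth_scores.items():
    print(f'max_depth={depth:2d}  cv accuracy={score:.3f}')

best_depth = max(depth_scores, key=depth_scores.get)
print('best max_depth:', best_depth)
```

Shallow trees tend to underfit and very deep trees tend to overfit. The cross-validated score therefore often rises and then levels off or falls, but the exact curve depends on the data.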
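Rendering a tree with `export_graphviz` and `pydotplus` requires the Graphviz binaries. When they are not installed, scikit-learn's `export_text` prints the same tree as indented text rules, and `sklearn.tree.plot_tree` can draw it with matplotlib. The sketch below runs on synthetic data, and the feature names `f0` to `f3` are hypothetical.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic 3-class data (assumption) with four hypothetical feature names
X, y = make_classification(n_samples=200, n_features=4, n_informative=3, n_redundant=0,
                           n_classes=3, n_clusters_per_class=1, random_state=0)
feature_names = ['f0', 'f1', 'f2', 'f3']

model = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

# One line per node: split conditions, and the predicted class at each leaf
rules = export_text(model, feature_names=feature_names)
print(rules)
```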
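Accuracy alone does not show which grade levels the model confuses with each other, and that matters when some levels are much rarer than others. The sketch below adds a confusion matrix and a per-class report next to `accuracy_score`. It runs on synthetic, deliberately imbalanced 4-class data, which is an assumption and not the student data. It also passes `stratify=y` to `train_test_split` so that both splits keep the same class proportions; the original notebook does not do this.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic imbalanced 4-class data (assumption): class shares roughly 10/20/30/40 %
X, y = make_classification(n_samples=400, n_features=10, n_informative=4, n_classes=4,
                           n_clusters_per_class=1, weights=[0.1, 0.2, 0.3, 0.4],
                           random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
                                                    random_state=5, stratify=y)

model = DecisionTreeClassifier(max_depth=4, random_state=5).fit(X_train, y_train)
y_pred = model.predict(X_test)

print('accuracy:', accuracy_score(y_test, y_pred))
# Rows are true classes, columns are predicted classes
cm = confusion_matrix(y_test, y_pred)
print(cm)
# Precision, recall and F1 for each class
print(classification_report(y_test, y_pred, zero_division=0))
```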
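In the grid-search cell, `param_grid` is a list of dicts, so `GridSearchCV` explores each dict's grid independently. `max_depth` and `min_samples_split` are each tuned with every other parameter at its default and never in combination. A single dict makes the search cover the Cartesian product of all listed values instead. The sketch below runs on synthetic data, and the value ranges are illustrative.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Synthetic 4-class stand-in data (assumption)
X, y = make_classification(n_samples=300, n_features=10, n_informative=4,
                           n_classes=4, n_clusters_per_class=1, random_state=0)

# One dict: every combination is tried (2 * 6 * 3 = 36 candidates)
param_grid = {
    'criterion': ['gini', 'entropy'],
    'max_depth': list(range(2, 8)),
    'min_samples_split': [2, 10, 20],
}

search = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=5)
search.fit(X, y)
print(search.best_params_, round(search.best_score_, 3))
```

The number of candidates grows multiplicatively, and each candidate is fitted once per fold. Here that means 36 candidates × 5 folds, so keep each list short.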
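Requirement 2 asks for at least two ensemble methods, but the notebook stops before that step. The sketch below compares a bagging ensemble (`RandomForestClassifier`) and a boosting ensemble (`GradientBoostingClassifier`) against a single tree using 5-fold cross-validation. It runs on synthetic stand-in data with illustrative hyperparameters. On the real data, the same loop could run on `X_train` and `y_train`, and each model's key parameters (`n_estimators`, `max_depth`, `learning_rate`) could be tuned with `GridSearchCV` in the same way as the single tree.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic 4-class stand-in for the encoded student features (assumption)
X, y = make_classification(n_samples=300, n_features=10, n_informative=4,
                           n_classes=4, n_clusters_per_class=1, random_state=0)

models = {
    'decision tree': DecisionTreeClassifier(max_depth=4, random_state=0),
    # Bagging: many trees on bootstrap samples with random feature subsets, predictions averaged
    'random forest': RandomForestClassifier(n_estimators=100, max_depth=6, random_state=0),
    # Boosting: shallow trees added one at a time, each fitted to the current ensemble's loss gradient
    'gradient boosting': GradientBoostingClassifier(n_estimators=100, learning_rate=0.1,
                                                    max_depth=3, random_state=0),
}

cv_means = {name: cross_val_score(model, X, y, cv=5).mean() for name, model in models.items()}
for name, score in cv_means.items():
    print(f'{name:18s} mean cv accuracy = {score:.3f}')
```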