@darthgera123
Created March 27, 2020 11:20
Baseline for SPCRT
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Baseline Submission for the SPCRT Challenge"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn import metrics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"train_data = pd.read_csv('aicrowd_educational_spcrt/data/public/train.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inspect the data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
" .dataframe thead tr:only-child th {\n",
" text-align: right;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: left;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>number_of_elements</th>\n",
" <th>mean_atomic_mass</th>\n",
" <th>wtd_mean_atomic_mass</th>\n",
" <th>gmean_atomic_mass</th>\n",
" <th>wtd_gmean_atomic_mass</th>\n",
" <th>entropy_atomic_mass</th>\n",
" <th>wtd_entropy_atomic_mass</th>\n",
" <th>range_atomic_mass</th>\n",
" <th>wtd_range_atomic_mass</th>\n",
" <th>std_atomic_mass</th>\n",
" <th>...</th>\n",
" <th>wtd_mean_Valence</th>\n",
" <th>gmean_Valence</th>\n",
" <th>wtd_gmean_Valence</th>\n",
" <th>entropy_Valence</th>\n",
" <th>wtd_entropy_Valence</th>\n",
" <th>range_Valence</th>\n",
" <th>wtd_range_Valence</th>\n",
" <th>std_Valence</th>\n",
" <th>wtd_std_Valence</th>\n",
" <th>critical_temp</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3</td>\n",
" <td>86.299100</td>\n",
" <td>65.789610</td>\n",
" <td>64.984139</td>\n",
" <td>49.765400</td>\n",
" <td>0.836621</td>\n",
" <td>1.013759</td>\n",
" <td>146.88130</td>\n",
" <td>20.950610</td>\n",
" <td>63.713516</td>\n",
" <td>...</td>\n",
" <td>3.500000</td>\n",
" <td>3.301927</td>\n",
" <td>3.464102</td>\n",
" <td>1.088900</td>\n",
" <td>0.971342</td>\n",
" <td>1</td>\n",
" <td>1.400000</td>\n",
" <td>0.471405</td>\n",
" <td>0.500000</td>\n",
" <td>4.50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>72.952854</td>\n",
" <td>56.414763</td>\n",
" <td>59.186241</td>\n",
" <td>35.639703</td>\n",
" <td>1.445795</td>\n",
" <td>1.041520</td>\n",
" <td>122.90607</td>\n",
" <td>35.383159</td>\n",
" <td>40.250192</td>\n",
" <td>...</td>\n",
" <td>2.257143</td>\n",
" <td>2.168944</td>\n",
" <td>2.219783</td>\n",
" <td>1.594167</td>\n",
" <td>1.087480</td>\n",
" <td>1</td>\n",
" <td>1.131429</td>\n",
" <td>0.400000</td>\n",
" <td>0.437059</td>\n",
" <td>7.60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>6</td>\n",
" <td>82.318112</td>\n",
" <td>99.033554</td>\n",
" <td>53.069787</td>\n",
" <td>71.259834</td>\n",
" <td>1.427749</td>\n",
" <td>1.324091</td>\n",
" <td>192.98100</td>\n",
" <td>40.196140</td>\n",
" <td>70.933858</td>\n",
" <td>...</td>\n",
" <td>4.300000</td>\n",
" <td>3.203101</td>\n",
" <td>3.772087</td>\n",
" <td>1.647214</td>\n",
" <td>1.510613</td>\n",
" <td>5</td>\n",
" <td>1.580000</td>\n",
" <td>1.950783</td>\n",
" <td>1.791647</td>\n",
" <td>3.01</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>57.444449</td>\n",
" <td>60.476650</td>\n",
" <td>56.067907</td>\n",
" <td>58.936797</td>\n",
" <td>1.362775</td>\n",
" <td>1.128041</td>\n",
" <td>34.84360</td>\n",
" <td>27.021980</td>\n",
" <td>12.367487</td>\n",
" <td>...</td>\n",
" <td>3.650000</td>\n",
" <td>3.309751</td>\n",
" <td>3.442623</td>\n",
" <td>1.333736</td>\n",
" <td>1.089489</td>\n",
" <td>3</td>\n",
" <td>1.800000</td>\n",
" <td>1.118034</td>\n",
" <td>1.194780</td>\n",
" <td>14.10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4</td>\n",
" <td>76.517718</td>\n",
" <td>56.808817</td>\n",
" <td>59.310096</td>\n",
" <td>35.773432</td>\n",
" <td>1.197273</td>\n",
" <td>0.981880</td>\n",
" <td>122.90607</td>\n",
" <td>34.833160</td>\n",
" <td>44.289459</td>\n",
" <td>...</td>\n",
" <td>2.264286</td>\n",
" <td>2.213364</td>\n",
" <td>2.226222</td>\n",
" <td>1.368922</td>\n",
" <td>1.048834</td>\n",
" <td>1</td>\n",
" <td>1.100000</td>\n",
" <td>0.433013</td>\n",
" <td>0.440952</td>\n",
" <td>36.80</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 82 columns</p>\n",
"</div>"
],
"text/plain": [
" number_of_elements mean_atomic_mass wtd_mean_atomic_mass \\\n",
"0 3 86.299100 65.789610 \n",
"1 5 72.952854 56.414763 \n",
"2 6 82.318112 99.033554 \n",
"3 4 57.444449 60.476650 \n",
"4 4 76.517718 56.808817 \n",
"\n",
" gmean_atomic_mass wtd_gmean_atomic_mass entropy_atomic_mass \\\n",
"0 64.984139 49.765400 0.836621 \n",
"1 59.186241 35.639703 1.445795 \n",
"2 53.069787 71.259834 1.427749 \n",
"3 56.067907 58.936797 1.362775 \n",
"4 59.310096 35.773432 1.197273 \n",
"\n",
" wtd_entropy_atomic_mass range_atomic_mass wtd_range_atomic_mass \\\n",
"0 1.013759 146.88130 20.950610 \n",
"1 1.041520 122.90607 35.383159 \n",
"2 1.324091 192.98100 40.196140 \n",
"3 1.128041 34.84360 27.021980 \n",
"4 0.981880 122.90607 34.833160 \n",
"\n",
" std_atomic_mass ... wtd_mean_Valence gmean_Valence \\\n",
"0 63.713516 ... 3.500000 3.301927 \n",
"1 40.250192 ... 2.257143 2.168944 \n",
"2 70.933858 ... 4.300000 3.203101 \n",
"3 12.367487 ... 3.650000 3.309751 \n",
"4 44.289459 ... 2.264286 2.213364 \n",
"\n",
" wtd_gmean_Valence entropy_Valence wtd_entropy_Valence range_Valence \\\n",
"0 3.464102 1.088900 0.971342 1 \n",
"1 2.219783 1.594167 1.087480 1 \n",
"2 3.772087 1.647214 1.510613 5 \n",
"3 3.442623 1.333736 1.089489 3 \n",
"4 2.226222 1.368922 1.048834 1 \n",
"\n",
" wtd_range_Valence std_Valence wtd_std_Valence critical_temp \n",
"0 1.400000 0.471405 0.500000 4.50 \n",
"1 1.131429 0.400000 0.437059 7.60 \n",
"2 1.580000 1.950783 1.791647 3.01 \n",
"3 1.800000 1.118034 1.194780 14.10 \n",
"4 1.100000 0.433013 0.440952 36.80 \n",
"\n",
"[5 rows x 82 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data.head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
" .dataframe thead tr:only-child th {\n",
" text-align: right;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: left;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>number_of_elements</th>\n",
" <th>mean_atomic_mass</th>\n",
" <th>wtd_mean_atomic_mass</th>\n",
" <th>gmean_atomic_mass</th>\n",
" <th>wtd_gmean_atomic_mass</th>\n",
" <th>entropy_atomic_mass</th>\n",
" <th>wtd_entropy_atomic_mass</th>\n",
" <th>range_atomic_mass</th>\n",
" <th>wtd_range_atomic_mass</th>\n",
" <th>std_atomic_mass</th>\n",
" <th>...</th>\n",
" <th>wtd_mean_Valence</th>\n",
" <th>gmean_Valence</th>\n",
" <th>wtd_gmean_Valence</th>\n",
" <th>entropy_Valence</th>\n",
" <th>wtd_entropy_Valence</th>\n",
" <th>range_Valence</th>\n",
" <th>wtd_range_Valence</th>\n",
" <th>std_Valence</th>\n",
" <th>wtd_std_Valence</th>\n",
" <th>critical_temp</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>...</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" <td>18073.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>4.116527</td>\n",
" <td>87.495853</td>\n",
" <td>72.915281</td>\n",
" <td>71.193951</td>\n",
" <td>58.444208</td>\n",
" <td>1.165612</td>\n",
" <td>1.064409</td>\n",
" <td>115.732133</td>\n",
" <td>33.213727</td>\n",
" <td>44.442844</td>\n",
" <td>...</td>\n",
" <td>3.152312</td>\n",
" <td>3.056546</td>\n",
" <td>3.054714</td>\n",
" <td>1.296028</td>\n",
" <td>1.054028</td>\n",
" <td>2.044708</td>\n",
" <td>1.481685</td>\n",
" <td>0.841078</td>\n",
" <td>0.676041</td>\n",
" <td>34.492796</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>1.439625</td>\n",
" <td>29.586564</td>\n",
" <td>33.320437</td>\n",
" <td>30.920472</td>\n",
" <td>36.470563</td>\n",
" <td>0.365019</td>\n",
" <td>0.401233</td>\n",
" <td>54.718595</td>\n",
" <td>26.886071</td>\n",
" <td>20.068666</td>\n",
" <td>...</td>\n",
" <td>1.189356</td>\n",
" <td>1.043451</td>\n",
" <td>1.172383</td>\n",
" <td>0.392761</td>\n",
" <td>0.380274</td>\n",
" <td>1.242861</td>\n",
" <td>0.976455</td>\n",
" <td>0.485247</td>\n",
" <td>0.455984</td>\n",
" <td>34.307997</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>1.000000</td>\n",
" <td>6.941000</td>\n",
" <td>6.941000</td>\n",
" <td>5.685033</td>\n",
" <td>3.193745</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>...</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000210</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>3.000000</td>\n",
" <td>72.451240</td>\n",
" <td>52.177725</td>\n",
" <td>58.001648</td>\n",
" <td>35.258590</td>\n",
" <td>0.969858</td>\n",
" <td>0.777619</td>\n",
" <td>78.353150</td>\n",
" <td>16.830450</td>\n",
" <td>32.890369</td>\n",
" <td>...</td>\n",
" <td>2.118056</td>\n",
" <td>2.279705</td>\n",
" <td>2.092115</td>\n",
" <td>1.060857</td>\n",
" <td>0.778998</td>\n",
" <td>1.000000</td>\n",
" <td>0.920286</td>\n",
" <td>0.471405</td>\n",
" <td>0.308515</td>\n",
" <td>5.400000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>4.000000</td>\n",
" <td>84.841880</td>\n",
" <td>60.786693</td>\n",
" <td>66.361592</td>\n",
" <td>39.898482</td>\n",
" <td>1.199541</td>\n",
" <td>1.146366</td>\n",
" <td>122.906070</td>\n",
" <td>26.658401</td>\n",
" <td>45.129500</td>\n",
" <td>...</td>\n",
" <td>2.618182</td>\n",
" <td>2.615321</td>\n",
" <td>2.433589</td>\n",
" <td>1.368922</td>\n",
" <td>1.165410</td>\n",
" <td>2.000000</td>\n",
" <td>1.062667</td>\n",
" <td>0.800000</td>\n",
" <td>0.500000</td>\n",
" <td>20.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>5.000000</td>\n",
" <td>100.351275</td>\n",
" <td>85.994130</td>\n",
" <td>78.019689</td>\n",
" <td>73.097796</td>\n",
" <td>1.444537</td>\n",
" <td>1.360442</td>\n",
" <td>155.006000</td>\n",
" <td>38.360375</td>\n",
" <td>59.663892</td>\n",
" <td>...</td>\n",
" <td>4.030000</td>\n",
" <td>3.741657</td>\n",
" <td>3.920517</td>\n",
" <td>1.589027</td>\n",
" <td>1.331926</td>\n",
" <td>3.000000</td>\n",
" <td>1.920000</td>\n",
" <td>1.200000</td>\n",
" <td>1.021023</td>\n",
" <td>63.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>9.000000</td>\n",
" <td>208.980400</td>\n",
" <td>208.980400</td>\n",
" <td>208.980400</td>\n",
" <td>208.980400</td>\n",
" <td>1.983797</td>\n",
" <td>1.958203</td>\n",
" <td>207.972460</td>\n",
" <td>205.589910</td>\n",
" <td>101.019700</td>\n",
" <td>...</td>\n",
" <td>7.000000</td>\n",
" <td>7.000000</td>\n",
" <td>7.000000</td>\n",
" <td>2.141963</td>\n",
" <td>1.949739</td>\n",
" <td>6.000000</td>\n",
" <td>6.992200</td>\n",
" <td>3.000000</td>\n",
" <td>3.000000</td>\n",
" <td>185.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>8 rows × 82 columns</p>\n",
"</div>"
],
"text/plain": [
" number_of_elements mean_atomic_mass wtd_mean_atomic_mass \\\n",
"count 18073.000000 18073.000000 18073.000000 \n",
"mean 4.116527 87.495853 72.915281 \n",
"std 1.439625 29.586564 33.320437 \n",
"min 1.000000 6.941000 6.941000 \n",
"25% 3.000000 72.451240 52.177725 \n",
"50% 4.000000 84.841880 60.786693 \n",
"75% 5.000000 100.351275 85.994130 \n",
"max 9.000000 208.980400 208.980400 \n",
"\n",
" gmean_atomic_mass wtd_gmean_atomic_mass entropy_atomic_mass \\\n",
"count 18073.000000 18073.000000 18073.000000 \n",
"mean 71.193951 58.444208 1.165612 \n",
"std 30.920472 36.470563 0.365019 \n",
"min 5.685033 3.193745 0.000000 \n",
"25% 58.001648 35.258590 0.969858 \n",
"50% 66.361592 39.898482 1.199541 \n",
"75% 78.019689 73.097796 1.444537 \n",
"max 208.980400 208.980400 1.983797 \n",
"\n",
" wtd_entropy_atomic_mass range_atomic_mass wtd_range_atomic_mass \\\n",
"count 18073.000000 18073.000000 18073.000000 \n",
"mean 1.064409 115.732133 33.213727 \n",
"std 0.401233 54.718595 26.886071 \n",
"min 0.000000 0.000000 0.000000 \n",
"25% 0.777619 78.353150 16.830450 \n",
"50% 1.146366 122.906070 26.658401 \n",
"75% 1.360442 155.006000 38.360375 \n",
"max 1.958203 207.972460 205.589910 \n",
"\n",
" std_atomic_mass ... wtd_mean_Valence gmean_Valence \\\n",
"count 18073.000000 ... 18073.000000 18073.000000 \n",
"mean 44.442844 ... 3.152312 3.056546 \n",
"std 20.068666 ... 1.189356 1.043451 \n",
"min 0.000000 ... 1.000000 1.000000 \n",
"25% 32.890369 ... 2.118056 2.279705 \n",
"50% 45.129500 ... 2.618182 2.615321 \n",
"75% 59.663892 ... 4.030000 3.741657 \n",
"max 101.019700 ... 7.000000 7.000000 \n",
"\n",
" wtd_gmean_Valence entropy_Valence wtd_entropy_Valence range_Valence \\\n",
"count 18073.000000 18073.000000 18073.000000 18073.000000 \n",
"mean 3.054714 1.296028 1.054028 2.044708 \n",
"std 1.172383 0.392761 0.380274 1.242861 \n",
"min 1.000000 0.000000 0.000000 0.000000 \n",
"25% 2.092115 1.060857 0.778998 1.000000 \n",
"50% 2.433589 1.368922 1.165410 2.000000 \n",
"75% 3.920517 1.589027 1.331926 3.000000 \n",
"max 7.000000 2.141963 1.949739 6.000000 \n",
"\n",
" wtd_range_Valence std_Valence wtd_std_Valence critical_temp \n",
"count 18073.000000 18073.000000 18073.000000 18073.000000 \n",
"mean 1.481685 0.841078 0.676041 34.492796 \n",
"std 0.976455 0.485247 0.455984 34.307997 \n",
"min 0.000000 0.000000 0.000000 0.000210 \n",
"25% 0.920286 0.471405 0.308515 5.400000 \n",
"50% 1.062667 0.800000 0.500000 20.000000 \n",
"75% 1.920000 1.200000 1.021023 63.000000 \n",
"max 6.992200 3.000000 3.000000 185.000000 \n",
"\n",
"[8 rows x 82 columns]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data.describe()"
]
},
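{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal cleaning check (a sketch; nothing is removed here). The `describe()` output above truncates the display to a subset of the 82 columns, so the cell below counts missing values across every column. It also counts exact duplicate rows, since duplicates could fall on both sides of the train/validation split and make the validation error look optimistic."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Total number of missing values across all columns\n",
"print('Missing values:', int(train_data.isnull().sum().sum()))\n",
"# Number of rows that exactly repeat an earlier row\n",
"print('Duplicate rows:', int(train_data.duplicated().sum()))"
]
},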
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Split Data into Training and Validation Sets"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
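"X = train_data.drop(columns='critical_temp')  # every column except the target\n",
"y = train_data['critical_temp']\n",
"# Hold out 20% of the rows for validation; random_state makes the split reproducible\n",
"X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)"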
]
},
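{
"cell_type": "markdown",
"metadata": {},
"source": [
"A single 80/20 split gives one estimate of the validation error. As a sketch, 5-fold cross-validation on the same `X` and `y` shows how much that estimate moves between splits. The fold count and `random_state` are arbitrary choices, and the `'neg_root_mean_squared_error'` scorer assumes scikit-learn 0.22 or later. scikit-learn negates error scorers so that higher is always better, which is why the sign is flipped back."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.model_selection import KFold, cross_val_score\n",
"\n",
"# Shuffle before splitting into folds, like train_test_split does by default\n",
"cv = KFold(n_splits=5, shuffle=True, random_state=42)\n",
"scores = cross_val_score(LinearRegression(), X, y, cv=cv, scoring='neg_root_mean_squared_error')\n",
"rmse_per_fold = -scores  # undo scikit-learn's sign flip\n",
"print('RMSE per fold:', np.round(rmse_per_fold, 2))\n",
"print('Mean RMSE: %.2f (std %.2f)' % (rmse_per_fold.mean(), rmse_per_fold.std()))"
]
},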
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the Regressor and Train"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None, normalize=False)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"regressor = LinearRegression()\n",
"regressor.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Inspect the fitted coefficients"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
" .dataframe thead tr:only-child th {\n",
" text-align: right;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: left;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Coefficient</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>number_of_elements</th>\n",
" <td>-4.202422</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean_atomic_mass</th>\n",
" <td>0.833105</td>\n",
" </tr>\n",
" <tr>\n",
" <th>wtd_mean_atomic_mass</th>\n",
" <td>-0.881193</td>\n",
" </tr>\n",
" <tr>\n",
" <th>gmean_atomic_mass</th>\n",
" <td>-0.510610</td>\n",
" </tr>\n",
" <tr>\n",
" <th>wtd_gmean_atomic_mass</th>\n",
" <td>0.642180</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Coefficient\n",
"number_of_elements -4.202422\n",
"mean_atomic_mass 0.833105\n",
"wtd_mean_atomic_mass -0.881193\n",
"gmean_atomic_mass -0.510610\n",
"wtd_gmean_atomic_mass 0.642180"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One fitted coefficient per feature; head() shows the first five in column order\n",
"coeff_df = pd.DataFrame(regressor.coef_, X.columns, columns=['Coefficient'])\n",
"coeff_df.head()"
]
},
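{
"cell_type": "markdown",
"metadata": {},
"source": [
"`coeff_df.head()` shows the first five features in column order, not the most influential ones, and raw coefficients depend on each feature's units. The sketch below assumes standardized coefficients are a reasonable proxy for impact. It refits on features scaled to zero mean and unit variance, then sorts by absolute coefficient. Many features are different aggregates of the same elemental property and are likely strongly correlated, so individual coefficients can be unstable and should be read with caution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Scale each feature to zero mean and unit variance before fitting, so every\n",
"# coefficient is the change in critical_temp per standard deviation of that feature\n",
"scaled_model = make_pipeline(StandardScaler(), LinearRegression())\n",
"scaled_model.fit(X_train, y_train)\n",
"\n",
"std_coef = pd.Series(scaled_model.named_steps['linearregression'].coef_, index=X_train.columns)\n",
"# Order by absolute size while keeping the sign\n",
"ranked = std_coef.reindex(std_coef.abs().sort_values(ascending=False).index)\n",
"ranked.head(10)"
]
},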
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predict on the validation set"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y_pred = regressor.predict(X_val)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
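"# Actual vs. predicted critical_temp for each validation row\n",
"df = pd.DataFrame({'Actual': y_val, 'Predicted': y_pred})\n",
"df1 = df.head(25)  # first 25 rows, for a quick visual comparison"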
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate the Performance"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean Absolute Error: 13.42086725495139\n",
"Mean Squared Error: 323.28465055058496\n",
"Root Mean Squared Error: 17.98011820179681\n"
]
}
],
"source": [
"print('Mean Absolute Error:', metrics.mean_absolute_error(y_val, y_pred))\n",
"print('Mean Squared Error:', metrics.mean_squared_error(y_val, y_pred))\n",
"print('Root Mean Squared Error:', np.sqrt(metrics.mean_squared_error(y_val, y_pred)))"
]
},
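{
"cell_type": "markdown",
"metadata": {},
"source": [
"All three errors above are computed from the residuals (actual minus predicted). The NumPy sketch below recomputes them from their definitions as a cross-check of the `sklearn.metrics` calls. It also adds R², the fraction of the validation target's variance that the model explains. `regression_errors` is a helper introduced here, not part of scikit-learn."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def regression_errors(y_true, y_hat):\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_hat = np.asarray(y_hat, dtype=float)\n",
"    r = y_true - y_hat  # residuals: actual minus predicted\n",
"    mae = np.mean(np.abs(r))  # mean absolute error\n",
"    mse = np.mean(r ** 2)  # mean squared error\n",
"    rmse = np.sqrt(mse)  # same units as critical_temp\n",
"    # R^2 = 1 - residual sum of squares / total sum of squares\n",
"    r2 = 1 - np.sum(r ** 2) / np.sum((y_true - y_true.mean()) ** 2)\n",
"    return mae, mse, rmse, r2\n",
"\n",
"mae, mse, rmse, r2 = regression_errors(y_val, y_pred)\n",
"print('MAE: %.4f  MSE: %.4f  RMSE: %.4f  R^2: %.4f' % (mae, mse, rmse, r2))"
]
},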
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Test Set"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"test_data = pd.read_csv('aicrowd_educational_spcrt/data/public/test.csv')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
" .dataframe thead tr:only-child th {\n",
" text-align: right;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: left;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>number_of_elements</th>\n",
" <th>mean_atomic_mass</th>\n",
" <th>wtd_mean_atomic_mass</th>\n",
" <th>gmean_atomic_mass</th>\n",
" <th>wtd_gmean_atomic_mass</th>\n",
" <th>entropy_atomic_mass</th>\n",
" <th>wtd_entropy_atomic_mass</th>\n",
" <th>range_atomic_mass</th>\n",
" <th>wtd_range_atomic_mass</th>\n",
" <th>std_atomic_mass</th>\n",
" <th>...</th>\n",
" <th>mean_Valence</th>\n",
" <th>wtd_mean_Valence</th>\n",
" <th>gmean_Valence</th>\n",
" <th>wtd_gmean_Valence</th>\n",
" <th>entropy_Valence</th>\n",
" <th>wtd_entropy_Valence</th>\n",
" <th>range_Valence</th>\n",
" <th>wtd_range_Valence</th>\n",
" <th>std_Valence</th>\n",
" <th>wtd_std_Valence</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2</td>\n",
" <td>82.768190</td>\n",
" <td>87.837285</td>\n",
" <td>82.144935</td>\n",
" <td>87.360109</td>\n",
" <td>0.685627</td>\n",
" <td>0.509575</td>\n",
" <td>20.27638</td>\n",
" <td>51.522285</td>\n",
" <td>10.138190</td>\n",
" <td>...</td>\n",
" <td>4.50</td>\n",
" <td>4.750000</td>\n",
" <td>4.472136</td>\n",
" <td>4.728708</td>\n",
" <td>0.686962</td>\n",
" <td>0.514653</td>\n",
" <td>1</td>\n",
" <td>2.750000</td>\n",
" <td>0.500000</td>\n",
" <td>0.433013</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4</td>\n",
" <td>76.444563</td>\n",
" <td>81.456750</td>\n",
" <td>59.356672</td>\n",
" <td>68.229617</td>\n",
" <td>1.199541</td>\n",
" <td>1.108189</td>\n",
" <td>121.32760</td>\n",
" <td>36.950657</td>\n",
" <td>43.823354</td>\n",
" <td>...</td>\n",
" <td>2.25</td>\n",
" <td>2.142857</td>\n",
" <td>2.213364</td>\n",
" <td>2.119268</td>\n",
" <td>1.368922</td>\n",
" <td>1.309526</td>\n",
" <td>1</td>\n",
" <td>0.571429</td>\n",
" <td>0.433013</td>\n",
" <td>0.349927</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5</td>\n",
" <td>88.936744</td>\n",
" <td>51.090431</td>\n",
" <td>70.358975</td>\n",
" <td>34.783991</td>\n",
" <td>1.445824</td>\n",
" <td>1.525092</td>\n",
" <td>122.90607</td>\n",
" <td>10.438667</td>\n",
" <td>46.482335</td>\n",
" <td>...</td>\n",
" <td>2.40</td>\n",
" <td>2.114679</td>\n",
" <td>2.352158</td>\n",
" <td>2.095193</td>\n",
" <td>1.589027</td>\n",
" <td>1.314189</td>\n",
" <td>1</td>\n",
" <td>0.967890</td>\n",
" <td>0.489898</td>\n",
" <td>0.318634</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>76.517718</td>\n",
" <td>56.149432</td>\n",
" <td>59.310096</td>\n",
" <td>35.562124</td>\n",
" <td>1.197273</td>\n",
" <td>1.042132</td>\n",
" <td>122.90607</td>\n",
" <td>31.920690</td>\n",
" <td>44.289459</td>\n",
" <td>...</td>\n",
" <td>2.25</td>\n",
" <td>2.251429</td>\n",
" <td>2.213364</td>\n",
" <td>2.214646</td>\n",
" <td>1.368922</td>\n",
" <td>1.078855</td>\n",
" <td>1</td>\n",
" <td>1.074286</td>\n",
" <td>0.433013</td>\n",
" <td>0.433834</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3</td>\n",
" <td>104.608490</td>\n",
" <td>89.558979</td>\n",
" <td>101.719818</td>\n",
" <td>88.481210</td>\n",
" <td>1.070258</td>\n",
" <td>0.944284</td>\n",
" <td>59.94547</td>\n",
" <td>33.541423</td>\n",
" <td>25.225148</td>\n",
" <td>...</td>\n",
" <td>5.00</td>\n",
" <td>5.811245</td>\n",
" <td>4.762203</td>\n",
" <td>5.743954</td>\n",
" <td>1.054920</td>\n",
" <td>0.803990</td>\n",
" <td>3</td>\n",
" <td>3.024096</td>\n",
" <td>1.414214</td>\n",
" <td>0.728448</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 81 columns</p>\n",
"</div>"
],
"text/plain": [
" number_of_elements mean_atomic_mass wtd_mean_atomic_mass \\\n",
"0 2 82.768190 87.837285 \n",
"1 4 76.444563 81.456750 \n",
"2 5 88.936744 51.090431 \n",
"3 4 76.517718 56.149432 \n",
"4 3 104.608490 89.558979 \n",
"\n",
" gmean_atomic_mass wtd_gmean_atomic_mass entropy_atomic_mass \\\n",
"0 82.144935 87.360109 0.685627 \n",
"1 59.356672 68.229617 1.199541 \n",
"2 70.358975 34.783991 1.445824 \n",
"3 59.310096 35.562124 1.197273 \n",
"4 101.719818 88.481210 1.070258 \n",
"\n",
" wtd_entropy_atomic_mass range_atomic_mass wtd_range_atomic_mass \\\n",
"0 0.509575 20.27638 51.522285 \n",
"1 1.108189 121.32760 36.950657 \n",
"2 1.525092 122.90607 10.438667 \n",
"3 1.042132 122.90607 31.920690 \n",
"4 0.944284 59.94547 33.541423 \n",
"\n",
" std_atomic_mass ... mean_Valence wtd_mean_Valence \\\n",
"0 10.138190 ... 4.50 4.750000 \n",
"1 43.823354 ... 2.25 2.142857 \n",
"2 46.482335 ... 2.40 2.114679 \n",
"3 44.289459 ... 2.25 2.251429 \n",
"4 25.225148 ... 5.00 5.811245 \n",
"\n",
" gmean_Valence wtd_gmean_Valence entropy_Valence wtd_entropy_Valence \\\n",
"0 4.472136 4.728708 0.686962 0.514653 \n",
"1 2.213364 2.119268 1.368922 1.309526 \n",
"2 2.352158 2.095193 1.589027 1.314189 \n",
"3 2.213364 2.214646 1.368922 1.078855 \n",
"4 4.762203 5.743954 1.054920 0.803990 \n",
"\n",
" range_Valence wtd_range_Valence std_Valence wtd_std_Valence \n",
"0 1 2.750000 0.500000 0.433013 \n",
"1 1 0.571429 0.433013 0.349927 \n",
"2 1 0.967890 0.489898 0.318634 \n",
"3 1 1.074286 0.433013 0.433834 \n",
"4 3 3.024096 1.414214 0.728448 \n",
"\n",
"[5 rows x 81 columns]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_data.head()"
]
},
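{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model was fitted on the columns of `X`, so `test_data` should contain the same 81 features in the same order before `predict` is called. A linear model applies its coefficients by position, and recent scikit-learn versions also compare feature names at prediction time. A minimal check:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the test columns with the training features\n",
"missing = [c for c in X.columns if c not in test_data.columns]\n",
"extra = [c for c in test_data.columns if c not in X.columns]\n",
"print('Missing from test:', missing)\n",
"print('Unexpected in test:', extra)\n",
"print('Same order:', list(test_data.columns) == list(X.columns))"
]
},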
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predict on the test set"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y_test = regressor.predict(test_data)  # predicted critical_temp for each test row"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save the predictions in the submission format"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame(y_test, columns=['critical_temp'])\n",
"# One prediction per row, written without the DataFrame index\n",
"df.to_csv('aicrowd_educational_spcrt/data/public/submission.csv', index=False)"
]
},
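{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below sanity-checks the written file. It assumes the layout produced above (a single `critical_temp` column, one row per test row, no index) is what the challenge expects; the challenge page is the authority on the format. Linear regression is unbounded, so it can also predict negative temperatures. These are not meaningful for this target (its training minimum is 0.00021), so the cell counts them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"submission = pd.read_csv('aicrowd_educational_spcrt/data/public/submission.csv')\n",
"# One column, one prediction per test row\n",
"assert list(submission.columns) == ['critical_temp']\n",
"assert len(submission) == len(test_data)\n",
"# Count predictions below zero, which are not meaningful for critical_temp\n",
"print('Negative predictions:', int((submission['critical_temp'] < 0).sum()))"
]
},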
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To participate in the challenge, visit the [SPCRT challenge page on AIcrowd](https://www.aicrowd.com/challenges/spcrt-superconductor-critical-temperature)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}