{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Untitled9.ipynb",
"version": "0.3.2",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/parulnith/7f8c174e6ac099e86f0495d3d9a4c01e/untitled9.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"metadata": {
"id": "cNnM2w-HCeb1",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Music genre classification notebook"
]
},
{
"metadata": {
"id": "2l3sppZMCydR",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Importing Libraries"
]
},
{
"metadata": {
"id": "Gt3fyg6dCNvX",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Feature extraction and data preprocessing\n",
"import librosa\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import os\n",
"from PIL import Image\n",
"import pathlib\n",
"import csv\n",
"\n",
"# Preprocessing\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
"\n",
"#Keras\n",
"import keras\n",
"\n",
"import warnings\n",
"warnings.filterwarnings('ignore')"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "DPe_ebYuDqr5",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Extracting music and features\n",
"\n",
"### Dataset\n",
"\n",
"We use the [GTZAN genre collection](http://marsyasweb.appspot.com/download/data_sets/) dataset for classification.\n",
"<br>\n",
"<br>\n",
"The dataset consists of 10 genres:\n",
" * Blues\n",
" * Classical\n",
" * Country\n",
" * Disco\n",
" * Hiphop\n",
" * Jazz\n",
" * Metal\n",
" * Pop\n",
" * Reggae\n",
" * Rock\n",
" \n",
"Each genre contains 100 songs, for a total of 1000 songs."
]
},
{
"metadata": {
"id": "AfBSVfRCD3PE",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Extracting the Spectrogram for every audio file"
]
},
{
"metadata": {
"id": "BHh3pTEVDdrT",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"cmap = plt.get_cmap('inferno')\n",
"\n",
"plt.figure(figsize=(10,10))\n",
"genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()\n",
"for g in genres:\n",
"    pathlib.Path(f'img_data/{g}').mkdir(parents=True, exist_ok=True)\n",
"    for filename in os.listdir(f'./MIR/genres/{g}'):\n",
"        songname = f'./MIR/genres/{g}/{filename}'\n",
"        y, sr = librosa.load(songname, mono=True, duration=5)\n",
"        plt.specgram(y, NFFT=2048, Fs=2, Fc=0, noverlap=128, cmap=cmap, sides='default', mode='default', scale='dB');\n",
"        plt.axis('off');\n",
"        plt.savefig(f'img_data/{g}/{filename[:-3].replace(\".\", \"\")}.png')\n",
"        plt.clf()"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "SszVgjYnFNX9",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
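"The first 5 seconds of every audio file are converted into a spectrogram image saved under `img_data/<genre>/`. The features below, however, are computed directly from the audio signal with librosa rather than from these images."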
]
},
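{
"metadata": {},
"cell_type": "markdown",
"source": [
"The `plt.specgram` call above cuts each signal into overlapping windows of `NFFT=2048` samples (`noverlap=128`, so consecutive windows start 1920 samples apart), takes the FFT of each window and, with `scale='dB'`, displays the power in decibels. A minimal NumPy sketch of that computation follows, assuming a Hann window (Matplotlib's default) and librosa's default sample rate of 22050 Hz. The helper `spectrogram_db` and the 440 Hz test tone are hypothetical, and Matplotlib's exact power scaling differs:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def spectrogram_db(y, n_fft=2048, noverlap=128):\n",
"    # Hypothetical helper: Hann-windowed frames -> power spectrum -> decibels.\n",
"    hop = n_fft - noverlap\n",
"    window = np.hanning(n_fft)\n",
"    n_frames = 1 + (len(y) - n_fft) // hop\n",
"    frames = np.stack([y[i * hop:i * hop + n_fft] * window for i in range(n_frames)])\n",
"    power = np.abs(np.fft.rfft(frames, axis=1)) ** 2\n",
"    # The small offset avoids log10(0); rows are frequency bins, columns are frames.\n",
"    return 10 * np.log10(power + 1e-10).T\n",
"\n",
"sr = 22050                      # librosa.load resamples to 22050 Hz by default\n",
"t = np.arange(sr) / sr          # one second of audio\n",
"y = np.sin(2 * np.pi * 440.0 * t)\n",
"S = spectrogram_db(y)\n",
"freqs = np.fft.rfftfreq(2048, d=1 / sr)\n",
"peak_hz = freqs[S[:, 0].argmax()]  # strongest bin of the first frame, within one bin of 440 Hz\n",
"print(S.shape, peak_hz)\n",
"```\n",
"\n",
"Each column of `S` is one time frame, which is what the saved images show with time on the horizontal axis and frequency on the vertical axis."
]
},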
{
"metadata": {
"id": "piwUwgP5Eef9",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Extracting features from the audio\n",
"\n",
"We will extract the following features, each averaged over time:\n",
"\n",
"* Mel-frequency cepstral coefficients (MFCCs), 20 in number\n",
"* Spectral centroid\n",
"* Spectral bandwidth\n",
"* Zero crossing rate\n",
"* Chroma frequencies\n",
"* Spectral roll-off\n",
"* Root-mean-square (RMS) energy"
]
},
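{
"metadata": {},
"cell_type": "markdown",
"source": [
"Two of these features are easy to state for a single frame: the zero crossing rate is the fraction of adjacent samples whose sign changes, and the spectral centroid is the magnitude-weighted mean frequency of the spectrum. A minimal NumPy sketch for one frame follows, assuming a Hann window; librosa computes these over many overlapping frames and its details (padding, normalization) differ. The function names and test tones are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def zero_crossing_rate(frame):\n",
"    # Fraction of adjacent sample pairs whose signs differ.\n",
"    signs = np.sign(frame)\n",
"    return float(np.mean(signs[1:] != signs[:-1]))\n",
"\n",
"def spectral_centroid(frame, sr):\n",
"    # Magnitude-weighted mean of the FFT bin frequencies (Hann-windowed frame).\n",
"    mag = np.abs(np.fft.rfft(frame * np.hanning(len(frame))))\n",
"    freqs = np.fft.rfftfreq(len(frame), d=1 / sr)\n",
"    return float(np.sum(freqs * mag) / np.sum(mag))\n",
"\n",
"sr = 22050\n",
"t = np.arange(2048) / sr\n",
"low = np.sin(2 * np.pi * 200.0 * t)    # low-pitched tone\n",
"high = np.sin(2 * np.pi * 4000.0 * t)  # high-pitched tone\n",
"print(zero_crossing_rate(low), zero_crossing_rate(high))\n",
"print(spectral_centroid(low, sr), spectral_centroid(high, sr))\n",
"```\n",
"\n",
"Higher-pitched or noisier frames raise both values."
]
},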
{
"metadata": {
"id": "__g8tX8pDeIL",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"header = 'filename chroma_stft rmse spectral_centroid spectral_bandwidth rolloff zero_crossing_rate'\n",
"for i in range(1, 21):\n",
" header += f' mfcc{i}'\n",
"header += ' label'\n",
"header = header.split()"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "TBlT448pEqR9",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
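"## Writing data to a CSV file\n",
"\n",
"For every song we compute each feature frame by frame, average it over time (one mean per MFCC coefficient) and write one row to `data.csv`."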
]
},
{
"metadata": {
"id": "ZsSQmB0PE3Iu",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Write the header row once, then append one row of averaged features per song\n",
"with open('data.csv', 'w', newline='') as file:\n",
"    writer = csv.writer(file)\n",
"    writer.writerow(header)\n",
"\n",
"genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()\n",
"for g in genres:\n",
"    for filename in os.listdir(f'./MIR/genres/{g}'):\n",
"        songname = f'./MIR/genres/{g}/{filename}'\n",
"        y, sr = librosa.load(songname, mono=True, duration=30)\n",
"        chroma_stft = librosa.feature.chroma_stft(y=y, sr=sr)\n",
"        # RMS energy (older librosa releases called this function feature.rmse)\n",
"        rmse = librosa.feature.rms(y=y)\n",
"        spec_cent = librosa.feature.spectral_centroid(y=y, sr=sr)\n",
"        spec_bw = librosa.feature.spectral_bandwidth(y=y, sr=sr)\n",
"        rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)\n",
"        zcr = librosa.feature.zero_crossing_rate(y)\n",
"        mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=20)\n",
"        # Average each feature over time; MFCCs get one mean per coefficient\n",
"        to_append = f'{filename} {np.mean(chroma_stft)} {np.mean(rmse)} {np.mean(spec_cent)} {np.mean(spec_bw)} {np.mean(rolloff)} {np.mean(zcr)}'\n",
"        for e in mfcc:\n",
"            to_append += f' {np.mean(e)}'\n",
"        to_append += f' {g}'\n",
"        with open('data.csv', 'a', newline='') as file:\n",
"            writer = csv.writer(file)\n",
"            writer.writerow(to_append.split())"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "0yfdo1cj6V7d",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"The data has been extracted into a [data.csv](https://github.com/parulnith/Music-Genre-Classification-with-Python/blob/master/data.csv) file."
]
},
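{
"metadata": {},
"cell_type": "markdown",
"source": [
"The extraction cell builds each row as a space-separated string and then calls `.split()`, which would shift every column if a filename ever contained a space. A minimal sketch of building the row as a list instead, assuming librosa-shaped feature arrays `(n_coefficients, n_frames)`; the random values, the filename `my song.au` and the 1293-frame length are hypothetical stand-ins for real audio features:\n",
"\n",
"```python\n",
"import csv\n",
"import io\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"# Hypothetical frame-level features for one song, shaped like librosa output.\n",
"mfcc = rng.normal(size=(20, 1293))\n",
"zcr = rng.uniform(0.0, 0.2, size=(1, 1293))\n",
"\n",
"# Scalar features collapse to one mean; MFCCs keep one mean per coefficient.\n",
"row = ['my song.au', float(np.mean(zcr))] + np.mean(mfcc, axis=1).tolist() + ['blues']\n",
"\n",
"buffer = io.StringIO()\n",
"csv.writer(buffer).writerow(row)  # each list element becomes exactly one CSV field\n",
"line = buffer.getvalue()\n",
"```\n",
"\n",
"The real rows have 28 fields: the filename, six averaged scalar features, 20 MFCC means and the label."
]
},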
{
"metadata": {
"id": "fgeCZSKQEp1A",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Analysing the Data in Pandas"
]
},
{
"metadata": {
"id": "Kr5_EdpD9dyh",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 253
},
"outputId": "81fd4a29-93fa-44f8-bf90-2f99981f761a"
},
"cell_type": "code",
"source": [
"data = pd.read_csv('data.csv')\n",
"data.head()"
],
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>filename</th>\n",
" <th>chroma_stft</th>\n",
" <th>rmse</th>\n",
" <th>spectral_centroid</th>\n",
" <th>spectral_bandwidth</th>\n",
" <th>rolloff</th>\n",
" <th>zero_crossing_rate</th>\n",
" <th>mfcc1</th>\n",
" <th>mfcc2</th>\n",
" <th>mfcc3</th>\n",
" <th>...</th>\n",
" <th>mfcc12</th>\n",
" <th>mfcc13</th>\n",
" <th>mfcc14</th>\n",
" <th>mfcc15</th>\n",
" <th>mfcc16</th>\n",
" <th>mfcc17</th>\n",
" <th>mfcc18</th>\n",
" <th>mfcc19</th>\n",
" <th>mfcc20</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>blues.00081.au</td>\n",
" <td>0.380260</td>\n",
" <td>0.248262</td>\n",
" <td>2116.942959</td>\n",
" <td>1956.611056</td>\n",
" <td>4196.107960</td>\n",
" <td>0.127272</td>\n",
" <td>-26.929785</td>\n",
" <td>107.334008</td>\n",
" <td>-46.809993</td>\n",
" <td>...</td>\n",
" <td>14.336612</td>\n",
" <td>-13.821769</td>\n",
" <td>7.562789</td>\n",
" <td>-6.181372</td>\n",
" <td>0.330165</td>\n",
" <td>-6.829571</td>\n",
" <td>0.965922</td>\n",
" <td>-7.570825</td>\n",
" <td>2.918987</td>\n",
" <td>blues</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>blues.00022.au</td>\n",
" <td>0.306451</td>\n",
" <td>0.113475</td>\n",
" <td>1156.070496</td>\n",
" <td>1497.668176</td>\n",
" <td>2170.053545</td>\n",
" <td>0.058613</td>\n",
" <td>-233.860772</td>\n",
" <td>136.170239</td>\n",
" <td>3.289490</td>\n",
" <td>...</td>\n",
" <td>-2.250578</td>\n",
" <td>3.959198</td>\n",
" <td>5.322555</td>\n",
" <td>0.812028</td>\n",
" <td>-1.107202</td>\n",
" <td>-4.556555</td>\n",
" <td>-2.436490</td>\n",
" <td>3.316913</td>\n",
" <td>-0.608485</td>\n",
" <td>blues</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>blues.00031.au</td>\n",
" <td>0.253487</td>\n",
" <td>0.151571</td>\n",
" <td>1331.073970</td>\n",
" <td>1973.643437</td>\n",
" <td>2900.174130</td>\n",
" <td>0.042967</td>\n",
" <td>-221.802549</td>\n",
" <td>110.843071</td>\n",
" <td>18.620984</td>\n",
" <td>...</td>\n",
" <td>-13.037723</td>\n",
" <td>-12.652228</td>\n",
" <td>-1.821905</td>\n",
" <td>-7.260097</td>\n",
" <td>-6.660252</td>\n",
" <td>-14.682694</td>\n",
" <td>-11.719264</td>\n",
" <td>-11.025216</td>\n",
" <td>-13.387260</td>\n",
" <td>blues</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>blues.00012.au</td>\n",
" <td>0.269320</td>\n",
" <td>0.119072</td>\n",
" <td>1361.045467</td>\n",
" <td>1567.804596</td>\n",
" <td>2739.625101</td>\n",
" <td>0.069124</td>\n",
" <td>-207.208080</td>\n",
" <td>132.799175</td>\n",
" <td>-15.438986</td>\n",
" <td>...</td>\n",
" <td>-0.613248</td>\n",
" <td>0.384877</td>\n",
" <td>2.605128</td>\n",
" <td>-5.188924</td>\n",
" <td>-9.527455</td>\n",
" <td>-9.244394</td>\n",
" <td>-2.848274</td>\n",
" <td>-1.418707</td>\n",
" <td>-5.932607</td>\n",
" <td>blues</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>blues.00056.au</td>\n",
" <td>0.391059</td>\n",
" <td>0.137728</td>\n",
" <td>1811.076084</td>\n",
" <td>2052.332563</td>\n",
" <td>3927.809582</td>\n",
" <td>0.075480</td>\n",
" <td>-145.434568</td>\n",
" <td>102.829023</td>\n",
" <td>-12.517677</td>\n",
" <td>...</td>\n",
" <td>7.457218</td>\n",
" <td>-10.470444</td>\n",
" <td>-2.360483</td>\n",
" <td>-6.783624</td>\n",
" <td>2.671134</td>\n",
" <td>-4.760879</td>\n",
" <td>-0.949005</td>\n",
" <td>0.024832</td>\n",
" <td>-2.005315</td>\n",
" <td>blues</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 28 columns</p>\n",
"</div>"
],
"text/plain": [
" filename chroma_stft rmse spectral_centroid \\\n",
"0 blues.00081.au 0.380260 0.248262 2116.942959 \n",
"1 blues.00022.au 0.306451 0.113475 1156.070496 \n",
"2 blues.00031.au 0.253487 0.151571 1331.073970 \n",
"3 blues.00012.au 0.269320 0.119072 1361.045467 \n",
"4 blues.00056.au 0.391059 0.137728 1811.076084 \n",
"\n",
" spectral_bandwidth rolloff zero_crossing_rate mfcc1 \\\n",
"0 1956.611056 4196.107960 0.127272 -26.929785 \n",
"1 1497.668176 2170.053545 0.058613 -233.860772 \n",
"2 1973.643437 2900.174130 0.042967 -221.802549 \n",
"3 1567.804596 2739.625101 0.069124 -207.208080 \n",
"4 2052.332563 3927.809582 0.075480 -145.434568 \n",
"\n",
" mfcc2 mfcc3 ... mfcc12 mfcc13 mfcc14 mfcc15 \\\n",
"0 107.334008 -46.809993 ... 14.336612 -13.821769 7.562789 -6.181372 \n",
"1 136.170239 3.289490 ... -2.250578 3.959198 5.322555 0.812028 \n",
"2 110.843071 18.620984 ... -13.037723 -12.652228 -1.821905 -7.260097 \n",
"3 132.799175 -15.438986 ... -0.613248 0.384877 2.605128 -5.188924 \n",
"4 102.829023 -12.517677 ... 7.457218 -10.470444 -2.360483 -6.783624 \n",
"\n",
" mfcc16 mfcc17 mfcc18 mfcc19 mfcc20 label \n",
"0 0.330165 -6.829571 0.965922 -7.570825 2.918987 blues \n",
"1 -1.107202 -4.556555 -2.436490 3.316913 -0.608485 blues \n",
"2 -6.660252 -14.682694 -11.719264 -11.025216 -13.387260 blues \n",
"3 -9.527455 -9.244394 -2.848274 -1.418707 -5.932607 blues \n",
"4 2.671134 -4.760879 -0.949005 0.024832 -2.005315 blues \n",
"\n",
"[5 rows x 28 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"metadata": {
"id": "iHrDHCaR9gKR",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "7d32943a-1ad5-4a59-c13a-beebeb36e4c2"
},
"cell_type": "code",
"source": [
"data.shape"
],
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(1000, 28)"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"metadata": {
"id": "veD5BgX49hZa",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Drop the filename column: it identifies the song but is not a feature\n",
"data = data.drop(['filename'],axis=1)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Nyr0aAAsGXjZ",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Encoding the Labels"
]
},
{
"metadata": {
"id": "frI5HH4q-1HS",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"genre_list = data.iloc[:, -1]\n",
"encoder = LabelEncoder()\n",
"y = encoder.fit_transform(genre_list)"
],
"execution_count": 0,
"outputs": []
},
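{
"metadata": {},
"cell_type": "markdown",
"source": [
"`LabelEncoder` sorts the distinct class names and replaces each label with its position in that sorted list (stored in `encoder.classes_`). A minimal NumPy sketch of the same mapping, using `np.unique` with `return_inverse=True` as a stand-in; the short `labels` array is hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"labels = np.array(['rock', 'blues', 'jazz', 'blues', 'pop'])\n",
"# np.unique returns the sorted distinct classes and, for every label,\n",
"# the index of its class in that sorted array.\n",
"classes, encoded = np.unique(labels, return_inverse=True)\n",
"print(classes)   # ['blues' 'jazz' 'pop' 'rock']\n",
"print(encoded)   # [3 0 1 0 2]\n",
"\n",
"# The ten GTZAN genre names are already in alphabetical order,\n",
"# so their positions are the encoded labels (blues=0, ..., rock=9).\n",
"genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()\n",
"assert list(np.unique(genres)) == genres\n",
"```"
]
},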
{
"metadata": {
"id": "_2n8a02zGfvP",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Scaling the Feature columns"
]
},
{
"metadata": {
"id": "uqcqn-nyAofk",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"scaler = StandardScaler()\n",
"X = scaler.fit_transform(np.array(data.iloc[:, :-1], dtype = float))"
],
"execution_count": 0,
"outputs": []
},
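{
"metadata": {},
"cell_type": "markdown",
"source": [
"`StandardScaler` rescales each column to zero mean and unit variance, z = (x - mean) / std, using the population standard deviation. A minimal NumPy sketch of that formula on a hypothetical 3x2 matrix:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical features on very different scales (compare zero_crossing_rate and rolloff).\n",
"X = np.array([[1.0, 200.0],\n",
"              [2.0, 400.0],\n",
"              [3.0, 600.0]])\n",
"mean = X.mean(axis=0)\n",
"std = X.std(axis=0)          # ddof=0, the population standard deviation\n",
"X_scaled = (X - mean) / std  # both columns become approximately [-1.2247, 0.0, 1.2247]\n",
"```\n",
"\n",
"One design choice worth noting: the scaling cell fits the scaler on all 1000 rows before the train/test split, so test-set statistics leak into the scaling. Splitting first, then calling `scaler.fit_transform(X_train)` and `scaler.transform(X_test)`, avoids that."
]
},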
{
"metadata": {
"id": "e3VZvbwpGo9R",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Dividing data into training and testing sets"
]
},
{
"metadata": {
"id": "F1GW3VvQA7Rj",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "upuczQ-KBHJ5",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "1431a28b-e8b6-4db2-e505-7e149e37c0d7"
},
"cell_type": "code",
"source": [
"len(y_train)"
],
"execution_count": 12,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"800"
]
},
"metadata": {
"tags": []
},
"execution_count": 12
}
]
},
{
"metadata": {
"id": "LtoE_FqqBzM8",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "76555a2b-2030-48e1-b52d-d71b4ebae38e"
},
"cell_type": "code",
"source": [
"len(y_test)"
],
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"200"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
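{
"metadata": {},
"cell_type": "markdown",
"source": [
"`train_test_split` shuffles the rows and, for a float `test_size`, reserves ceil(test_size x n_samples) of them for testing, which is why the 1000 songs become 800 training and 200 test samples. A minimal NumPy sketch of that index split; the helper `shuffle_split` and its `seed` argument are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def shuffle_split(n_samples, test_size=0.2, seed=0):\n",
"    # Shuffle row indices, then take the first ceil(test_size * n) as the test set.\n",
"    rng = np.random.default_rng(seed)\n",
"    idx = rng.permutation(n_samples)\n",
"    n_test = int(np.ceil(test_size * n_samples))\n",
"    return idx[n_test:], idx[:n_test]\n",
"\n",
"train_idx, test_idx = shuffle_split(1000)\n",
"print(len(train_idx), len(test_idx))  # 800 200\n",
"```\n",
"\n",
"Passing `random_state` to `train_test_split` makes the split reproducible, and `stratify=y` keeps each genre's share the same in both parts; the notebook uses neither."
]
},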
{
"metadata": {
"id": "ir9XaWgQB0lq",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 119
},
"outputId": "2ec90814-19d8-4f27-934a-1ce54406d4ea"
},
"cell_type": "code",
"source": [
"X_train[10]"
],
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([-0.9149113 , 0.18294103, -1.10587131, -1.3875197 , -1.14640873,\n",
" -0.97232926, -0.29174214, 1.20078936, -0.68458101, -0.55849017,\n",
" -1.27056582, -0.88176926, -0.74844069, -0.40970382, 0.49685952,\n",
" -1.12666045, 0.59501437, -0.39783853, 0.29327275, -0.72916871,\n",
" 0.63015786, -0.91149976, 0.7743942 , -0.64790051, 0.42229852,\n",
" -1.01449461])"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
{
"metadata": {
"id": "Vp2yc5FWG04e",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Classification with Keras\n",
"\n",
"## Building our Network"
]
},
{
"metadata": {
"id": "Qj3sc2uFEUMt",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"from keras import models\n",
"from keras import layers\n",
"\n",
"model = models.Sequential()\n",
"model.add(layers.Dense(256, activation='relu', input_shape=(X_train.shape[1],)))\n",
"\n",
"model.add(layers.Dense(128, activation='relu'))\n",
"\n",
"model.add(layers.Dense(64, activation='relu'))\n",
"\n",
"model.add(layers.Dense(10, activation='softmax'))"
],
"execution_count": 0,
"outputs": []
},
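{
"metadata": {},
"cell_type": "markdown",
"source": [
"Each `Dense` layer holds an (inputs x units) weight matrix plus one bias per unit. A short sketch that counts the parameters of this network, assuming the 26 input features shown in `X_train[10]`; `model.summary()` in Keras should report the same total:\n",
"\n",
"```python\n",
"def dense_params(n_in, units):\n",
"    # Weights (n_in * units) plus one bias per unit.\n",
"    return n_in * units + units\n",
"\n",
"layer_sizes = [26, 256, 128, 64, 10]  # input features, then the four Dense layers\n",
"counts = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]\n",
"print(counts)       # [6912, 32896, 8256, 650]\n",
"print(sum(counts))  # 48714\n",
"```\n",
"\n",
"That is about 61 parameters per training song, which is consistent with the overfitting noted after evaluation."
]
},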
{
"metadata": {
"id": "7yrsmpI6EjJ2",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"model.compile(optimizer='adam',\n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])"
],
"execution_count": 0,
"outputs": []
},
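{
"metadata": {},
"cell_type": "markdown",
"source": [
"`sparse_categorical_crossentropy` takes integer labels directly (as produced by `LabelEncoder`) and, for each sample, computes -log of the probability the softmax assigns to the true class, averaged over the batch. A minimal NumPy sketch, assuming probability inputs and a small clipping constant to avoid log(0); the helper name and toy values are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sparse_categorical_crossentropy(y_true, probs, eps=1e-7):\n",
"    # y_true: integer class ids; probs: one softmax row per sample.\n",
"    probs = np.clip(probs, eps, 1.0)\n",
"    picked = probs[np.arange(len(y_true)), y_true]\n",
"    return float(np.mean(-np.log(picked)))\n",
"\n",
"probs = np.array([[0.7, 0.2, 0.1],\n",
"                  [0.1, 0.1, 0.8]])\n",
"print(sparse_categorical_crossentropy(np.array([0, 2]), probs))  # about 0.2899\n",
"\n",
"# A model that guesses uniformly over 10 genres scores log(10), about 2.3026.\n",
"uniform = np.full((1, 10), 0.1)\n",
"print(sparse_categorical_crossentropy(np.array([3]), uniform))\n",
"```\n",
"\n",
"That uniform baseline is close to the first-epoch loss of 2.3074 in the validation run."
]
},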
{
"metadata": {
"id": "bP0hVm4aElS7",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 697
},
"outputId": "aacf234d-d0a9-4de4-91be-5fd45a33b279"
},
"cell_type": "code",
"source": [
"history = model.fit(X_train,\n",
" y_train,\n",
" epochs=20,\n",
" batch_size=128)\n",
" "
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"800/800 [==============================] - 1s 811us/step - loss: 2.1289 - acc: 0.2400\n",
"Epoch 2/20\n",
"800/800 [==============================] - 0s 39us/step - loss: 1.7940 - acc: 0.4088\n",
"Epoch 3/20\n",
"800/800 [==============================] - 0s 37us/step - loss: 1.5437 - acc: 0.4450\n",
"Epoch 4/20\n",
"800/800 [==============================] - 0s 38us/step - loss: 1.3584 - acc: 0.5413\n",
"Epoch 5/20\n",
"800/800 [==============================] - 0s 38us/step - loss: 1.2220 - acc: 0.5750\n",
"Epoch 6/20\n",
"800/800 [==============================] - 0s 41us/step - loss: 1.1187 - acc: 0.6288\n",
"Epoch 7/20\n",
"800/800 [==============================] - 0s 37us/step - loss: 1.0326 - acc: 0.6550\n",
"Epoch 8/20\n",
"800/800 [==============================] - 0s 44us/step - loss: 0.9631 - acc: 0.6713\n",
"Epoch 9/20\n",
"800/800 [==============================] - 0s 47us/step - loss: 0.9143 - acc: 0.6913\n",
"Epoch 10/20\n",
"800/800 [==============================] - 0s 37us/step - loss: 0.8630 - acc: 0.7125\n",
"Epoch 11/20\n",
"800/800 [==============================] - 0s 36us/step - loss: 0.8095 - acc: 0.7263\n",
"Epoch 12/20\n",
"800/800 [==============================] - 0s 37us/step - loss: 0.7728 - acc: 0.7700\n",
"Epoch 13/20\n",
"800/800 [==============================] - 0s 36us/step - loss: 0.7433 - acc: 0.7563\n",
"Epoch 14/20\n",
"800/800 [==============================] - 0s 45us/step - loss: 0.7066 - acc: 0.7825\n",
"Epoch 15/20\n",
"800/800 [==============================] - 0s 43us/step - loss: 0.6718 - acc: 0.7787\n",
"Epoch 16/20\n",
"800/800 [==============================] - 0s 36us/step - loss: 0.6601 - acc: 0.7913\n",
"Epoch 17/20\n",
"800/800 [==============================] - 0s 36us/step - loss: 0.6242 - acc: 0.7963\n",
"Epoch 18/20\n",
"800/800 [==============================] - 0s 44us/step - loss: 0.5994 - acc: 0.8038\n",
"Epoch 19/20\n",
"800/800 [==============================] - 0s 42us/step - loss: 0.5715 - acc: 0.8125\n",
"Epoch 20/20\n",
"800/800 [==============================] - 0s 39us/step - loss: 0.5437 - acc: 0.8250\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "0m1J0_wUFK4C",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "ffd3bf36-29ea-437a-987c-9aa600b9dae6"
},
"cell_type": "code",
"source": [
"test_loss, test_acc = model.evaluate(X_test,y_test)"
],
"execution_count": 20,
"outputs": [
{
"output_type": "stream",
"text": [
"200/200 [==============================] - 0s 244us/step\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "f6HrjXeUF0Ko",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "ea282dbd-6f9e-48c7-de2d-dc9afde8949e"
},
"cell_type": "code",
"source": [
"print('test_acc: ',test_acc)"
],
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"text": [
"test_acc: 0.68\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "3yQmP_f5Kq0w",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Test accuracy (0.68) is lower than the final training accuracy (0.825). This gap hints at overfitting."
]
},
{
"metadata": {
"id": "-U2qzRJoHV9O",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Validating our approach\n",
"Let's set apart 200 samples from our training data to use as a validation set:"
]
},
{
"metadata": {
"id": "xJNbvYZoF7ZT",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"x_val = X_train[:200]\n",
"partial_x_train = X_train[200:]\n",
"\n",
"y_val = y_train[:200]\n",
"partial_y_train = y_train[200:]"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "L1EkG59EHeEV",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Now let's build a slightly larger network and train it for 30 epochs, monitoring accuracy on the validation set:"
]
},
{
"metadata": {
"id": "Dp3G4P3aP4k2",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1071
},
"outputId": "25e1a389-1ac2-425b-bd5f-05736b6e9b96"
},
"cell_type": "code",
"source": [
"model = models.Sequential()\n",
"model.add(layers.Dense(512, activation='relu', input_shape=(X_train.shape[1],)))\n",
"model.add(layers.Dense(256, activation='relu'))\n",
"model.add(layers.Dense(128, activation='relu'))\n",
"model.add(layers.Dense(64, activation='relu'))\n",
"model.add(layers.Dense(10, activation='softmax'))\n",
"\n",
"model.compile(optimizer='adam',\n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])\n",
"\n",
"model.fit(partial_x_train,\n",
" partial_y_train,\n",
" epochs=30,\n",
" batch_size=512,\n",
" validation_data=(x_val, y_val))\n",
"results = model.evaluate(X_test, y_test)"
],
"execution_count": 37,
"outputs": [
{
"output_type": "stream",
"text": [
"Train on 600 samples, validate on 200 samples\n",
"Epoch 1/30\n",
"600/600 [==============================] - 1s 1ms/step - loss: 2.3074 - acc: 0.0950 - val_loss: 2.1857 - val_acc: 0.2850\n",
"Epoch 2/30\n",
"600/600 [==============================] - 0s 65us/step - loss: 2.1126 - acc: 0.3783 - val_loss: 2.0936 - val_acc: 0.2400\n",
"Epoch 3/30\n",
"600/600 [==============================] - 0s 59us/step - loss: 1.9535 - acc: 0.3633 - val_loss: 1.9966 - val_acc: 0.2600\n",
"Epoch 4/30\n",
"600/600 [==============================] - 0s 58us/step - loss: 1.8082 - acc: 0.3833 - val_loss: 1.8713 - val_acc: 0.3250\n",
"Epoch 5/30\n",
"600/600 [==============================] - 0s 59us/step - loss: 1.6663 - acc: 0.4083 - val_loss: 1.7302 - val_acc: 0.3450\n",
"Epoch 6/30\n",
"600/600 [==============================] - 0s 52us/step - loss: 1.5329 - acc: 0.4550 - val_loss: 1.6233 - val_acc: 0.3700\n",
"Epoch 7/30\n",
"600/600 [==============================] - 0s 62us/step - loss: 1.4236 - acc: 0.4850 - val_loss: 1.5402 - val_acc: 0.3950\n",
"Epoch 8/30\n",
"600/600 [==============================] - 0s 57us/step - loss: 1.3250 - acc: 0.5117 - val_loss: 1.4655 - val_acc: 0.3800\n",
"Epoch 9/30\n",
"600/600 [==============================] - 0s 52us/step - loss: 1.2338 - acc: 0.5633 - val_loss: 1.3927 - val_acc: 0.4650\n",
"Epoch 10/30\n",
"600/600 [==============================] - 0s 61us/step - loss: 1.1577 - acc: 0.5983 - val_loss: 1.3338 - val_acc: 0.5500\n",
"Epoch 11/30\n",
"600/600 [==============================] - 0s 64us/step - loss: 1.0981 - acc: 0.6317 - val_loss: 1.3111 - val_acc: 0.5550\n",
"Epoch 12/30\n",
"600/600 [==============================] - 0s 52us/step - loss: 1.0529 - acc: 0.6517 - val_loss: 1.2696 - val_acc: 0.5400\n",
"Epoch 13/30\n",
"600/600 [==============================] - 0s 52us/step - loss: 0.9994 - acc: 0.6567 - val_loss: 1.2480 - val_acc: 0.5400\n",
"Epoch 14/30\n",
"600/600 [==============================] - 0s 65us/step - loss: 0.9673 - acc: 0.6633 - val_loss: 1.2384 - val_acc: 0.5700\n",
"Epoch 15/30\n",
"600/600 [==============================] - 0s 58us/step - loss: 0.9286 - acc: 0.6633 - val_loss: 1.1953 - val_acc: 0.5800\n",
"Epoch 16/30\n",
"600/600 [==============================] - 0s 59us/step - loss: 0.8849 - acc: 0.6783 - val_loss: 1.2000 - val_acc: 0.5550\n",
"Epoch 17/30\n",
"600/600 [==============================] - 0s 61us/step - loss: 0.8621 - acc: 0.6850 - val_loss: 1.1743 - val_acc: 0.5850\n",
"Epoch 18/30\n",
"600/600 [==============================] - 0s 61us/step - loss: 0.8195 - acc: 0.7150 - val_loss: 1.1609 - val_acc: 0.5750\n",
"Epoch 19/30\n",
"600/600 [==============================] - 0s 62us/step - loss: 0.7976 - acc: 0.7283 - val_loss: 1.1238 - val_acc: 0.6150\n",
"Epoch 20/30\n",
"600/600 [==============================] - 0s 63us/step - loss: 0.7660 - acc: 0.7650 - val_loss: 1.1604 - val_acc: 0.5850\n",
"Epoch 21/30\n",
"600/600 [==============================] - 0s 65us/step - loss: 0.7465 - acc: 0.7650 - val_loss: 1.1888 - val_acc: 0.5700\n",
"Epoch 22/30\n",
"600/600 [==============================] - 0s 65us/step - loss: 0.7099 - acc: 0.7517 - val_loss: 1.1563 - val_acc: 0.6050\n",
"Epoch 23/30\n",
"600/600 [==============================] - 0s 68us/step - loss: 0.6857 - acc: 0.7683 - val_loss: 1.0900 - val_acc: 0.6200\n",
"Epoch 24/30\n",
"600/600 [==============================] - 0s 67us/step - loss: 0.6597 - acc: 0.7850 - val_loss: 1.0872 - val_acc: 0.6300\n",
"Epoch 25/30\n",
"600/600 [==============================] - 0s 67us/step - loss: 0.6377 - acc: 0.7967 - val_loss: 1.1148 - val_acc: 0.6200\n",
"Epoch 26/30\n",
"600/600 [==============================] - 0s 64us/step - loss: 0.6070 - acc: 0.8200 - val_loss: 1.1397 - val_acc: 0.6150\n",
"Epoch 27/30\n",
"600/600 [==============================] - 0s 66us/step - loss: 0.5991 - acc: 0.8167 - val_loss: 1.1255 - val_acc: 0.6300\n",
"Epoch 28/30\n",
"600/600 [==============================] - 0s 62us/step - loss: 0.5656 - acc: 0.8333 - val_loss: 1.0955 - val_acc: 0.6350\n",
"Epoch 29/30\n",
"600/600 [==============================] - 0s 66us/step - loss: 0.5513 - acc: 0.8300 - val_loss: 1.1030 - val_acc: 0.6050\n",
"Epoch 30/30\n",
"600/600 [==============================] - 0s 56us/step - loss: 0.5498 - acc: 0.8233 - val_loss: 1.0869 - val_acc: 0.6250\n",
"200/200 [==============================] - 0s 65us/step\n"
],
"name": "stdout"
}
]
},
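{
"metadata": {},
"cell_type": "markdown",
"source": [
"In the log above, validation accuracy peaks before the last epoch while training accuracy keeps rising. A short sketch that finds the best epoch from the logged `val_acc` values (copied from that output); with the `History` object returned by `model.fit`, the same list would be in `history.history['val_acc']` (newer Keras versions name the key `'val_accuracy'`):\n",
"\n",
"```python\n",
"# val_acc for epochs 1-30, copied from the training log above.\n",
"val_acc = [0.2850, 0.2400, 0.2600, 0.3250, 0.3450, 0.3700, 0.3950, 0.3800, 0.4650, 0.5500,\n",
"           0.5550, 0.5400, 0.5400, 0.5700, 0.5800, 0.5550, 0.5850, 0.5750, 0.6150, 0.5850,\n",
"           0.5700, 0.6050, 0.6200, 0.6300, 0.6200, 0.6150, 0.6300, 0.6350, 0.6050, 0.6250]\n",
"\n",
"best_epoch = max(range(len(val_acc)), key=lambda i: val_acc[i]) + 1  # epochs are 1-based\n",
"print(best_epoch, val_acc[best_epoch - 1])  # 28 0.635\n",
"```\n",
"\n",
"Keras's `keras.callbacks.EarlyStopping` callback can stop training automatically once a monitored validation metric stops improving."
]
},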
{
"metadata": {
"id": "Mvi9it1SI4aR",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "98b01ef2-3935-442b-82d6-45f56e036d39"
},
"cell_type": "code",
"source": [
"results"
],
"execution_count": 38,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[1.2261371064186095, 0.65]"
]
},
"metadata": {
"tags": []
},
"execution_count": 38
}
]
},
{
"metadata": {
"id": "r3hb8s1l4rBA",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Predictions on Test Data"
]
},
{
"metadata": {
"id": "gudBAhIXJIi2",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"predictions = model.predict(X_test)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Xb7bVPSwJQF0",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "aca09c75-1d21-4847-bdd9-a0521dc8d948"
},
"cell_type": "code",
"source": [
"predictions[0].shape"
],
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(10,)"
]
},
"metadata": {
"tags": []
},
"execution_count": 26
}
]
},
{
"metadata": {
"id": "llusRQV0JRy9",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "a856289d-883a-47cb-c0fb-ec148330a60a"
},
"cell_type": "code",
"source": [
"np.sum(predictions[0])"
],
"execution_count": 27,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1.0"
]
},
"metadata": {
"tags": []
},
"execution_count": 27
}
]
},
{
"metadata": {
"id": "0eoEuSZqJTdU",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "94c17d00-dd7f-40a1-84d2-78d1ebde6103"
},
"cell_type": "code",
"source": [
"np.argmax(predictions[0])"
],
"execution_count": 28,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"8"
]
},
"metadata": {
"tags": []
},
"execution_count": 28
}
]
},
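{
"metadata": {},
"cell_type": "markdown",
"source": [
"Each row of `predictions` is a softmax distribution over the 10 genres (hence the shape `(10,)` and the sum of 1.0), and `np.argmax` picks the most likely class index. Because `LabelEncoder` numbers the alphabetically sorted genres, index `8` corresponds to `reggae`; in this notebook, `encoder.inverse_transform([np.argmax(predictions[0])])` performs that lookup. A minimal NumPy sketch of the softmax-then-argmax step on hypothetical scores:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()\n",
"classes = np.array(sorted(genres))  # LabelEncoder's classes_ are the sorted labels\n",
"\n",
"# Hypothetical raw scores for one song; the last Dense layer applies softmax to such scores.\n",
"logits = np.array([0.1, -1.2, 0.3, 0.0, 0.5, -0.4, 0.2, 0.9, 2.5, 1.1])\n",
"probs = np.exp(logits - logits.max())  # subtract the max for numerical stability\n",
"probs /= probs.sum()                   # now non-negative and summing to 1\n",
"index = int(np.argmax(probs))\n",
"print(index, classes[index])           # 8 reggae\n",
"```"
]
},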
{
"metadata": {
"id": "Utgt1bXfJVRN",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}