@Orbifold
Last active March 26, 2024 06:30
Using GraphSage for link predictions. So much fun.
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Graph Link Prediction using GraphSAGE"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"This article is based on the paper [\"Inductive Representation Learning on Large Graphs\" by Hamilton, Ying and Leskovec](https://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf).\n",
"The [StellarGraph](https://www.stellargraph.io) implementation of the [GraphSAGE](http://snap.stanford.edu/graphsage/) algorithm is used to build a model that predicts citation links of [the Cora dataset](https://graphsandnetworks.com/the-cora-dataset/). \n",
"\n",
"The way link prediction is turned into a supervised learning task is rather clever. Pairs of nodes are embedded and a binary classifier is trained on them, where '1' means the two nodes are connected and '0' means they are not. It is a bit like embedding the adjacency matrix and finding a decision boundary between its ones and zeros. The entire model is trained end-to-end by minimizing a loss function of choice (here the binary cross-entropy between predicted link probabilities and true link labels, with true/false citation links labelled 1/0) through minibatch gradient updates of the model parameters (the Adam variant of stochastic gradient descent in this notebook), with minibatches of training links fed into the model."
],
"metadata": {}
},
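{
"cell_type": "markdown",
"source": [
"To make the objective concrete, here is a minimal NumPy sketch, independent of StellarGraph, of how a minibatch of candidate links can be scored and penalised. The inner-product score, the sigmoid and the toy embeddings and labels are illustrative assumptions, not values taken from the model below: each pair gets a link probability from its two node embeddings, and binary cross-entropy compares that probability with the 1/0 label."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def link_probability(h_u, h_v):\n",
"    # Score each node pair by the inner product of its embeddings, squashed into (0, 1)\n",
"    score = np.sum(h_u * h_v, axis=-1)\n",
"    return 1.0 / (1.0 + np.exp(-score))\n",
"\n",
"def binary_cross_entropy(y_true, y_prob, eps=1e-7):\n",
"    # Mean BCE over the minibatch; clipping avoids log(0)\n",
"    y_prob = np.clip(y_prob, eps, 1.0 - eps)\n",
"    return float(-np.mean(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob)))\n",
"\n",
"# Toy minibatch: 4 node pairs with hypothetical 3-dimensional embeddings\n",
"h_u = np.array([[1.0, 0.5, 0.0], [0.2, 0.1, 0.9], [1.0, 1.0, 1.0], [0.0, -1.0, 0.5]])\n",
"h_v = np.array([[0.9, 0.4, 0.1], [-0.5, 0.3, -1.0], [1.0, 0.8, 0.9], [0.3, 1.0, 0.0]])\n",
"labels = np.array([1, 0, 1, 0])  # 1 = true citation link, 0 = sampled non-link\n",
"\n",
"probs = link_probability(h_u, h_v)\n",
"loss = binary_cross_entropy(labels, probs)\n",
"print(probs.round(3), round(loss, 4))"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},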
{
"cell_type": "code",
"source": [
"import networkx as nx\n",
"import pandas as pd\n",
"import os\n",
"\n",
"import stellargraph as sg\n",
"from stellargraph.data import EdgeSplitter\n",
"from stellargraph.mapper import GraphSAGELinkGenerator\n",
"from stellargraph.layer import GraphSAGE, link_classification\n",
"\n",
"import tensorflow.keras as keras  # use the Keras bundled with TensorFlow, not the standalone keras package"
],
"outputs": [],
"execution_count": 31,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Dataset\n",
"\n",
"[The Cora dataset](https://graphsandnetworks.com/the-cora-dataset/) is the hello-world dataset of graph learning. We have described it in detail in [this article](https://graphsandnetworks.com/the-cora-dataset/) and will not repeat that here; the article also gives a direct link to download the data.\n",
"\n",
"The construction below recreates the steps outlined in the article.\n",
"\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"data_dir = os.path.expanduser(\"/Users/swa/Desktop/LargeFiles/Graphs/cora\")  # change to the folder holding cora.cites and cora.content\n",
"cora_location = os.path.expanduser(os.path.join(data_dir, \"cora.cites\"))\n",
"g_nx = nx.read_edgelist(path=cora_location)\n",
"\n",
"cora_data_location = os.path.expanduser(os.path.join(data_dir, \"cora.content\"))\n",
"node_attr = pd.read_csv(cora_data_location, sep='\\t', header=None)\n",
"values = { str(row.tolist()[0]): row.tolist()[-1] for _, row in node_attr.iterrows()}\n",
"nx.set_node_attributes(g_nx, values, 'subject')\n",
"\n",
"g_nx_ccs = (g_nx.subgraph(c).copy() for c in nx.connected_components(g_nx))\n",
"g_nx = max(g_nx_ccs, key=len)\n",
"print(\"Largest connected component: {} nodes, {} edges\".format(\n",
" g_nx.number_of_nodes(), g_nx.number_of_edges()))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Largest connected component: 2485 nodes, 5069 edges\n"
]
}
],
"execution_count": 3,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"The features of the nodes are taken into account in the model:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"feature_names = [\"w_{}\".format(ii) for ii in range(1433)]\n",
"column_names = feature_names + [\"subject\"]\n",
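"# cora.content has one more column than there are names (the paper id), so pandas uses that first column as the index\n",
"node_data = pd.read_csv(os.path.join(data_dir, \"cora.content\"), \n",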
" sep=\"\\t\", \n",
" header=None, \n",
" names=column_names)\n",
"node_data.drop(['subject'], axis=1, inplace=True)\n"
],
"outputs": [],
"execution_count": 4,
"metadata": {}
},
{
"cell_type": "code",
"source": [
"node_data.index = node_data.index.map(str)\n",
"node_data = node_data[node_data.index.isin(list(g_nx.nodes()))]"
],
"outputs": [],
"execution_count": 5,
"metadata": {}
},
{
"cell_type": "code",
"source": [
"node_data.head(2)"
],
"outputs": [
{
"output_type": "execute_result",
"execution_count": 7,
"data": {
"text/plain": [
" w_0 w_1 w_2 w_3 w_4 w_5 w_6 w_7 w_8 w_9 ... w_1423 \\\n",
"31336 0 0 0 0 0 0 0 0 0 0 ... 0 \n",
"1061127 0 0 0 0 0 0 0 0 0 0 ... 0 \n",
"\n",
" w_1424 w_1425 w_1426 w_1427 w_1428 w_1429 w_1430 w_1431 \\\n",
"31336 0 0 1 0 0 0 0 0 \n",
"1061127 0 1 0 0 0 0 0 0 \n",
"\n",
" w_1432 \n",
"31336 0 \n",
"1061127 0 \n",
"\n",
"[2 rows x 1433 columns]"
],
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>w_0</th>\n",
" <th>w_1</th>\n",
" <th>w_2</th>\n",
" <th>w_3</th>\n",
" <th>w_4</th>\n",
" <th>w_5</th>\n",
" <th>w_6</th>\n",
" <th>w_7</th>\n",
" <th>w_8</th>\n",
" <th>w_9</th>\n",
" <th>...</th>\n",
" <th>w_1423</th>\n",
" <th>w_1424</th>\n",
" <th>w_1425</th>\n",
" <th>w_1426</th>\n",
" <th>w_1427</th>\n",
" <th>w_1428</th>\n",
" <th>w_1429</th>\n",
" <th>w_1430</th>\n",
" <th>w_1431</th>\n",
" <th>w_1432</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>31336</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1061127</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2 rows × 1433 columns</p>\n",
"</div>"
]
},
"metadata": {}
}
],
"execution_count": 7,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "markdown",
"source": [
"Define the set of node features used by the model as the difference between all node feature columns and a user-defined list of attributes to ignore:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
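"attributes_to_ignore = [\"subject\"]  # already dropped above; listed here for completeness\n",
"feature_names = sorted(set(node_data.columns) - set(attributes_to_ignore))"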
],
"outputs": [],
"execution_count": 8,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"GraphSAGE requires numeric node features. All node features used here are already numeric (0/1 word-presence indicators), so no conversion is needed; the only categorical attribute, \"subject\", was dropped above and is not used as a feature."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"node_features = node_data[feature_names].values"
],
"outputs": [],
"execution_count": 12,
"metadata": {}
},
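{
"cell_type": "markdown",
"source": [
"For a dataset whose features do include a categorical column, one common conversion is one-hot encoding. A minimal pandas sketch, using a made-up toy frame with two of Cora's subject labels (this is not needed for the Cora features above):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"# Toy stand-in for a categorical 'subject' column\n",
"df = pd.DataFrame({'subject': ['Neural_Networks', 'Theory', 'Neural_Networks']})\n",
"# One 0/1 column per category, in sorted category order\n",
"one_hot = pd.get_dummies(df['subject'], dtype=int)\n",
"print(one_hot)"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},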
{
"cell_type": "code",
"source": [
"node_features.shape"
],
"outputs": [
{
"output_type": "execute_result",
"execution_count": 13,
"data": {
"text/plain": [
"(2485, 1433)"
]
},
"metadata": {}
}
],
"execution_count": 13,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Add node data to g_nx:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# g_nx.node is gone in recent networkx releases; g_nx.nodes is the supported accessor\n",
"for nid, f in zip(node_data.index, node_features):\n",
"    g_nx.nodes[nid]['label'] = \"paper\"\n",
"    g_nx.nodes[nid][\"feature\"] = f"
],
"outputs": [],
"execution_count": 17,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Splitting a graph\n",
"\n",
"Splitting graph-like data into train and test sets is not as straightforward as in classic (tabular) machine learning. If you take a subset of nodes, you also have to deal with edges that have one endpoint in the training set and the other in the test set: each edge should lie entirely within one of the two sets. This is tricky in general, but the StellarGraph framework does the split in one line of code. For link prediction the split also works differently: instead of partitioning the nodes, all nodes are kept in both the training and the test graph, and it is the edges that are randomly sampled. Each resulting graph has the same number of nodes as the input graph but fewer links, because the sampled links are removed from it and used as the positive examples for training/testing the link prediction classifier."
],
"metadata": {}
},
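{
"cell_type": "markdown",
"source": [
"The idea can be sketched in plain networkx. The helper `split_edges` below is hypothetical and deliberately simplified; it is not StellarGraph's `EdgeSplitter` implementation. It removes a fraction `p` of the edges, skipping any edge whose removal would disconnect the graph (in the spirit of `keep_connected=True`), and samples the same number of node pairs that are not edges of the original graph as negative examples. It runs on networkx's built-in karate club graph rather than on Cora."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import random\n",
"import networkx as nx\n",
"\n",
"def split_edges(G, p=0.1, seed=42):\n",
"    # Hypothetical helper: returns (reduced graph, sampled pairs, 1/0 labels)\n",
"    rng = random.Random(seed)\n",
"    G_reduced = G.copy()\n",
"    n_target = int(p * G.number_of_edges())\n",
"    candidates = list(G.edges())\n",
"    rng.shuffle(candidates)\n",
"    positives = []\n",
"    for u, v in candidates:\n",
"        if len(positives) == n_target:\n",
"            break\n",
"        attrs = dict(G_reduced.edges[u, v])\n",
"        G_reduced.remove_edge(u, v)\n",
"        if nx.is_connected(G_reduced):\n",
"            positives.append((u, v))  # removed edge becomes a positive example\n",
"        else:\n",
"            G_reduced.add_edge(u, v, **attrs)  # put back: removing it would disconnect the graph\n",
"    nodes = list(G.nodes())\n",
"    negatives = set()\n",
"    while len(negatives) < len(positives):\n",
"        u, v = rng.sample(nodes, 2)\n",
"        if not G.has_edge(u, v):\n",
"            negatives.add(frozenset((u, v)))  # frozenset ignores pair orientation\n",
"    pairs = positives + [tuple(pair) for pair in negatives]\n",
"    labels = [1] * len(positives) + [0] * len(negatives)\n",
"    return G_reduced, pairs, labels\n",
"\n",
"G = nx.karate_club_graph()\n",
"G_reduced, pairs, labels = split_edges(G, p=0.1)\n",
"print(G.number_of_edges(), G_reduced.number_of_edges(), sum(labels), len(labels) - sum(labels))"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},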
{
"cell_type": "markdown",
"source": [
"From the original graph G, extract a randomly sampled subset of test edges (true and false citation links) and the reduced graph G_test with the positive test edges removed. Define an edge splitter on the original graph `g_nx`:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"edge_splitter_test = EdgeSplitter(g_nx)"
],
"outputs": [],
"execution_count": 18,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Randomly sample a fraction p=0.1 of all positive links, and the same number of negative links, from G, and obtain the reduced graph G_test with the sampled links removed:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"G_test, edge_ids_test, edge_labels_test = edge_splitter_test.train_test_split(\n",
" p=0.1, method=\"global\", keep_connected=True\n",
")"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"** Sampled 506 positive and 506 negative edges. **\n"
]
}
],
"execution_count": 19,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "markdown",
"source": [
"The reduced graph G_test, together with the test ground truth set of links (edge_ids_test, edge_labels_test), will be used for testing the model.\n",
"\n",
"Now repeat this procedure to obtain the training data for the model. From the reduced graph G_test, extract a randomly sampled subset of train edges (true and false citation links) and the reduced graph G_train with the positive train edges removed:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"edge_splitter_train = EdgeSplitter(G_test)\n",
"G_train, edge_ids_train, edge_labels_train = edge_splitter_train.train_test_split(\n",
" p=0.1, method=\"global\", keep_connected=True\n",
")"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"** Sampled 456 positive and 456 negative edges. **\n"
]
}
],
"execution_count": 20,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Defining the GraphSAGE model"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Convert G_train and G_test to StellarGraph objects (undirected, as required by GraphSAGE) for ML:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"G_train = sg.StellarGraph(G_train, node_features=\"feature\")\n",
"G_test = sg.StellarGraph(G_test, node_features=\"feature\")"
],
"outputs": [],
"execution_count": 21,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Summary of G_train and G_test - note that they have the same set of nodes, only differing in their edge sets:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"print(G_train.info())"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"StellarGraph: Undirected multigraph\n",
" Nodes: 2485, Edges: 4107\n",
"\n",
" Node types:\n",
" paper: [2485]\n",
" Attributes: {'feature', 'subject'}\n",
" Edge types: paper-default->paper\n",
"\n",
" Edge types:\n",
" paper-default->paper: [4107]\n",
"\n"
]
}
],
"execution_count": 22,
"metadata": {}
},
{
"cell_type": "code",
"source": [
"print(G_test.info())"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"StellarGraph: Undirected multigraph\n",
" Nodes: 2485, Edges: 4563\n",
"\n",
" Node types:\n",
" paper: [2485]\n",
" Attributes: {'feature', 'subject'}\n",
" Edge types: paper-default->paper\n",
"\n",
" Edge types:\n",
" paper-default->paper: [4563]\n",
"\n"
]
}
],
"execution_count": 23,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Next, we create the link mappers for sampling and streaming training and testing data to the model. The link mappers essentially \"map\" pairs of nodes `(paper1, paper2)` to the input of GraphSAGE: they take minibatches of node pairs, sample 2-hop subgraphs with `(paper1, paper2)` head nodes extracted from those pairs, and feed them, together with the corresponding binary labels indicating whether those pairs represent true or false citation links, to the input layer of the GraphSAGE model, for SGD updates of the model parameters.\n",
"\n",
"Specify the minibatch size (number of node pairs per minibatch) and the number of epochs for training the model:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"batch_size = 20\n",
"epochs = 20"
],
"outputs": [],
"execution_count": 24,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Specify the sizes of 1- and 2-hop neighbour samples for GraphSAGE:\n",
"\n",
"Note that the length of the `num_samples` list defines the number of layers/iterations in the GraphSAGE model. In this example, we are defining a 2-layer GraphSAGE model."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"num_samples = [20, 10]"
],
"outputs": [],
"execution_count": 25,
"metadata": {}
},
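{
"cell_type": "markdown",
"source": [
"To see what `num_samples = [20, 10]` means in practice, here is a minimal sketch of GraphSAGE-style fixed-size neighbourhood sampling in plain Python. The helper `sample_neighbourhood` is hypothetical, not StellarGraph's sampler: it draws neighbours uniformly with replacement, 20 for the head node and then 10 for each of those, so every head node of a link contributes 1 + 20 + 200 = 221 nodes whose features are fed to the model. Padding isolated nodes with themselves is an assumption of this sketch only, and the karate club graph stands in for Cora."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import random\n",
"import networkx as nx\n",
"\n",
"def sample_neighbourhood(G, head, num_samples, seed=0):\n",
"    # Hypothetical helper: one list of sampled node ids per hop (hop 0 = the head node)\n",
"    rng = random.Random(seed)\n",
"    hops = [[head]]\n",
"    for k in num_samples:\n",
"        next_hop = []\n",
"        for node in hops[-1]:\n",
"            neighbours = list(G.neighbors(node))\n",
"            if neighbours:\n",
"                # k neighbours drawn uniformly with replacement\n",
"                next_hop.extend(rng.choice(neighbours) for _ in range(k))\n",
"            else:\n",
"                next_hop.extend([node] * k)  # sketch-only padding for isolated nodes\n",
"        hops.append(next_hop)\n",
"    return hops\n",
"\n",
"G = nx.karate_club_graph()\n",
"# A link (u, v) is handled by sampling a neighbourhood around each endpoint\n",
"u, v = 0, 33\n",
"hops_u, hops_v = (sample_neighbourhood(G, n, [20, 10]) for n in (u, v))\n",
"print([len(h) for h in hops_u], [len(h) for h in hops_v])"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},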
{
"cell_type": "code",
"source": [
"# shuffle the training links every epoch; keep the test links in a fixed order\n",
"train_gen = GraphSAGELinkGenerator(G_train, batch_size, num_samples).flow(edge_ids_train, edge_labels_train, shuffle=True)\n",
"test_gen = GraphSAGELinkGenerator(G_test, batch_size, num_samples).flow(edge_ids_test, edge_labels_test)"
],
"outputs": [],
"execution_count": 26,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Build the model: a 2-layer GraphSAGE model acting as node representation learner, with a link classification layer that combines the `(paper1, paper2)` node embeddings.\n",
"\n",
"GraphSAGE part of the model, with hidden layer sizes of 20 for both GraphSAGE layers, a bias term, and a dropout rate of 0.3 (dropout can be switched off with `dropout=0`; any rate 0 < dropout < 1 is allowed).\n",
"Note that the length of the `layer_sizes` list must be equal to the length of `num_samples`, as `len(num_samples)` defines the number of hops (layers) in the GraphSAGE model."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"layer_sizes = [20, 20]\n",
"assert len(layer_sizes) == len(num_samples)\n",
"\n",
"graphsage = GraphSAGE(\n",
" layer_sizes=layer_sizes, generator=train_gen, bias=True, dropout=0.3\n",
" )"
],
"outputs": [],
"execution_count": 32,
"metadata": {}
},
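{
"cell_type": "markdown",
"source": [
"Under the hood, each GraphSAGE layer aggregates the sampled neighbours and combines them with the node's own representation. Below is a minimal NumPy sketch of the mean-aggregator variant described by Hamilton et al. (concatenate the node's own vector with the mean of its neighbours' vectors, apply a linear map and a ReLU, then L2-normalise). The weights and features are random stand-ins, dropout is omitted, the last layer is left linear as a choice of the sketch, and StellarGraph's layer may differ in details; the shapes follow this notebook's `num_samples = [20, 10]`, `layer_sizes = [20, 20]` and 1433 input features."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def sage_layer(h_self, h_neigh, W, b, relu=True):\n",
"    # h_self: (n, d_in) target nodes; h_neigh: (n, s, d_in) their s sampled neighbours\n",
"    agg = h_neigh.mean(axis=1)                          # mean aggregator\n",
"    z = np.concatenate([h_self, agg], axis=1) @ W + b   # W: (2 * d_in, d_out)\n",
"    if relu:\n",
"        z = np.maximum(z, 0.0)\n",
"    norms = np.linalg.norm(z, axis=1, keepdims=True)\n",
"    return z / np.maximum(norms, 1e-12)                 # L2-normalise each row\n",
"\n",
"rng = np.random.default_rng(0)\n",
"d_in, sizes, samples = 1433, [20, 20], [20, 10]\n",
"x0 = rng.random((1, d_in))                          # head node features\n",
"x1 = rng.random((samples[0], d_in))                 # its 20 sampled 1-hop neighbours\n",
"x2 = rng.random((samples[0] * samples[1], d_in))    # 10 sampled neighbours of each of those\n",
"W1, b1 = 0.01 * rng.normal(size=(2 * d_in, sizes[0])), np.zeros(sizes[0])\n",
"W2, b2 = 0.1 * rng.normal(size=(2 * sizes[0], sizes[1])), np.zeros(sizes[1])\n",
"\n",
"# Layer 1 (shared weights) runs on the head node and on each 1-hop neighbour\n",
"h0 = sage_layer(x0, x1.reshape(1, samples[0], d_in), W1, b1)\n",
"h1 = sage_layer(x1, x2.reshape(samples[0], samples[1], d_in), W1, b1)\n",
"# Layer 2 turns those into the final 20-dimensional embedding of the head node\n",
"emb = sage_layer(h0, h1.reshape(1, samples[0], sizes[0]), W2, b2, relu=False)\n",
"print(emb.shape, float(np.linalg.norm(emb)))"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},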
{
"cell_type": "code",
"source": [
"x_inp, x_out = graphsage.build()"
],
"outputs": [],
"execution_count": 33,
"metadata": {
"scrolled": true
}
},
{
"cell_type": "markdown",
"source": [
"Final link classification layer that takes a pair of node embeddings produced by graphsage, applies a binary operator to them to produce the corresponding link embedding ('ip' for inner product; other options for the binary operator can be seen by running a cell with `?link_classification` in it), and passes it through a dense layer:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"prediction = link_classification(\n",
" output_dim=1, output_act=\"relu\", edge_embedding_method='ip'\n",
" )(x_out)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"link_classification: using 'ip' method to combine node embeddings into edge embeddings\n"
]
}
],
"execution_count": 34,
"metadata": {}
},
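{
"cell_type": "markdown",
"source": [
"Conceptually, the `'ip'` link head reduces a pair of node embeddings to a single number, their inner product, and maps it to a score. A minimal NumPy sketch, following the description above of a dense layer on top of the inner product, here a single weight `w` and bias `b` with hypothetical values. It uses a sigmoid rather than the `relu` set in the cell above, so that the output can be read directly as a link probability for binary cross-entropy; this is a choice of the sketch, not of the notebook."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def ip_link_head(emb_u, emb_v, w=1.0, b=0.0):\n",
"    # Edge embedding = inner product of the two node embeddings, shape (batch, 1)\n",
"    edge_emb = np.sum(emb_u * emb_v, axis=1, keepdims=True)\n",
"    # Single dense unit with a sigmoid activation\n",
"    return 1.0 / (1.0 + np.exp(-(w * edge_emb + b)))\n",
"\n",
"rng = np.random.default_rng(1)\n",
"emb_u = rng.normal(size=(5, 20))\n",
"emb_close = emb_u + 0.1 * rng.normal(size=(5, 20))  # embeddings pointing the same way\n",
"emb_far = -emb_u                                    # embeddings pointing the opposite way\n",
"p_close = ip_link_head(emb_u, emb_close)\n",
"p_far = ip_link_head(emb_u, emb_far)\n",
"print(p_close.ravel().round(3), p_far.ravel().round(3))"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},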
{
"cell_type": "markdown",
"source": [
"Stack the GraphSAGE and prediction layers into a Keras model, and specify the optimizer, loss and metrics:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"model = keras.Model(inputs=x_inp, outputs=prediction)\n",
"\n",
"model.compile(\n",
" optimizer=keras.optimizers.Adam(lr=1e-3),\n",
" loss=keras.losses.binary_crossentropy,\n",
" metrics=[\"acc\"],\n",
" )"
],
"outputs": [],
"execution_count": 35,
"metadata": {
"scrolled": true
}
},
{
"cell_type": "markdown",
"source": [
"Evaluate the initial (untrained) model on the train and test sets:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"init_train_metrics = model.evaluate_generator(train_gen)\n",
"init_test_metrics = model.evaluate_generator(test_gen)\n",
"\n",
"print(\"\\nTrain Set Metrics of the initial (untrained) model:\")\n",
"for name, val in zip(model.metrics_names, init_train_metrics):\n",
" print(\"\\t{}: {:0.4f}\".format(name, val))\n",
"\n",
"print(\"\\nTest Set Metrics of the initial (untrained) model:\")\n",
"for name, val in zip(model.metrics_names, init_test_metrics):\n",
" print(\"\\t{}: {:0.4f}\".format(name, val))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"Train Set Metrics of the initial (untrained) model:\n",
"\tloss: 0.6847\n",
"\tacc: 0.6316\n",
"\n",
"Test Set Metrics of the initial (untrained) model:\n",
"\tloss: 0.6795\n",
"\tacc: 0.6364\n"
]
}
],
"execution_count": 36,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"Let's go for it:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"history = model.fit_generator(\n",
" train_gen,\n",
" epochs=epochs,\n",
" validation_data=test_gen,\n",
" verbose=2\n",
" )"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"WARNING:tensorflow:From /Users/swa/conda/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.cast instead.\n",
"Epoch 1/20\n",
"51/51 [==============================] - 2s 47ms/step - loss: 0.6117 - acc: 0.6324\n",
" - 7s - loss: 0.7215 - acc: 0.6064 - val_loss: 0.6117 - val_acc: 0.6324\n",
"Epoch 2/20\n",
"51/51 [==============================] - 3s 53ms/step - loss: 0.5301 - acc: 0.7263\n",
" - 7s - loss: 0.5407 - acc: 0.7171 - val_loss: 0.5301 - val_acc: 0.7263\n",
"Epoch 3/20\n",
"51/51 [==============================] - 3s 53ms/step - loss: 0.4952 - acc: 0.7441\n",
" - 7s - loss: 0.4658 - acc: 0.7928 - val_loss: 0.4952 - val_acc: 0.7441\n",
"Epoch 4/20\n",
"51/51 [==============================] - 3s 56ms/step - loss: 0.4618 - acc: 0.7836\n",
" - 7s - loss: 0.3992 - acc: 0.8344 - val_loss: 0.4618 - val_acc: 0.7836\n",
"Epoch 5/20\n",
"51/51 [==============================] - 3s 53ms/step - loss: 0.5118 - acc: 0.7856\n",
" - 7s - loss: 0.3644 - acc: 0.8980 - val_loss: 0.5118 - val_acc: 0.7856\n",
"Epoch 6/20\n",
"51/51 [==============================] - 3s 57ms/step - loss: 0.4758 - acc: 0.7964\n",
" - 8s - loss: 0.3553 - acc: 0.8936 - val_loss: 0.4758 - val_acc: 0.7964\n",
"Epoch 7/20\n",
"51/51 [==============================] - 3s 57ms/step - loss: 0.4806 - acc: 0.8024\n",
" - 8s - loss: 0.2726 - acc: 0.9243 - val_loss: 0.4806 - val_acc: 0.8024\n",
"Epoch 8/20\n",
"51/51 [==============================] - 3s 53ms/step - loss: 0.5067 - acc: 0.8093\n",
" - 7s - loss: 0.2734 - acc: 0.9309 - val_loss: 0.5067 - val_acc: 0.8093\n",
"Epoch 9/20\n",
"51/51 [==============================] - 3s 55ms/step - loss: 0.5595 - acc: 0.8014\n",
" - 7s - loss: 0.2268 - acc: 0.9627 - val_loss: 0.5595 - val_acc: 0.8014\n",
"Epoch 10/20\n",
"51/51 [==============================] - 3s 52ms/step - loss: 0.5354 - acc: 0.7984\n",
" - 7s - loss: 0.2051 - acc: 0.9616 - val_loss: 0.5354 - val_acc: 0.7984\n",
"Epoch 11/20\n",
"51/51 [==============================] - 3s 50ms/step - loss: 0.5476 - acc: 0.8123\n",
" - 7s - loss: 0.1896 - acc: 0.9627 - val_loss: 0.5476 - val_acc: 0.8123\n",
"Epoch 12/20\n",
"51/51 [==============================] - 3s 54ms/step - loss: 0.5359 - acc: 0.8113\n",
" - 7s - loss: 0.1985 - acc: 0.9594 - val_loss: 0.5359 - val_acc: 0.8113\n",
"Epoch 13/20\n",
"51/51 [==============================] - 3s 54ms/step - loss: 0.5721 - acc: 0.8103\n",
" - 7s - loss: 0.1794 - acc: 0.9737 - val_loss: 0.5721 - val_acc: 0.8103\n",
"Epoch 14/20\n",
"51/51 [==============================] - 3s 53ms/step - loss: 0.5526 - acc: 0.8152\n",
" - 7s - loss: 0.1600 - acc: 0.9825 - val_loss: 0.5526 - val_acc: 0.8152\n",
"Epoch 15/20\n",
"51/51 [==============================] - 3s 53ms/step - loss: 0.6594 - acc: 0.8034\n",
" - 7s - loss: 0.1576 - acc: 0.9781 - val_loss: 0.6594 - val_acc: 0.8034\n",
"Epoch 16/20\n",
"51/51 [==============================] - 3s 54ms/step - loss: 0.5591 - acc: 0.8132\n",
" - 7s - loss: 0.1426 - acc: 0.9825 - val_loss: 0.5591 - val_acc: 0.8132\n",
"Epoch 17/20\n",
"51/51 [==============================] - 3s 52ms/step - loss: 0.5637 - acc: 0.7935\n",
" - 7s - loss: 0.1359 - acc: 0.9792 - val_loss: 0.5637 - val_acc: 0.7935\n",
"Epoch 18/20\n",
"51/51 [==============================] - 3s 53ms/step - loss: 0.6060 - acc: 0.8083\n",
" - 7s - loss: 0.1306 - acc: 0.9912 - val_loss: 0.6060 - val_acc: 0.8083\n",
"Epoch 19/20\n",
"51/51 [==============================] - 3s 53ms/step - loss: 0.5586 - acc: 0.7955\n",
" - 7s - loss: 0.1258 - acc: 0.9857 - val_loss: 0.5586 - val_acc: 0.7955\n",
"Epoch 20/20\n",
"51/51 [==============================] - 3s 51ms/step - loss: 0.6495 - acc: 0.7964\n",
" - 7s - loss: 0.1193 - acc: 0.9923 - val_loss: 0.6495 - val_acc: 0.7964\n"
]
}
],
"execution_count": 37,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"You can use TensorBoard for pretty dataviz, or simply plot the training history with matplotlib:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"def plot_history(history):\n",
" metrics = sorted(history.history.keys())\n",
" metrics = metrics[:len(metrics)//2]\n",
" \n",
" f,axs = plt.subplots(1, len(metrics), figsize=(12,4))\n",
"\n",
" for m,ax in zip(metrics,axs):\n",
" # summarize history for metric m\n",
" ax.plot(history.history[m])\n",
" ax.plot(history.history['val_' + m])\n",
" ax.set_title(m)\n",
" ax.set_ylabel(m)\n",
" ax.set_xlabel('epoch')\n",
" ax.legend(['train', 'test'], loc='upper right')\n",
"\n",
"plot_history(history) "
],
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 864x288 with 2 Axes>"
],
"image/png": [
"\n"
]
},
"metadata": {
"needs_background": "light"
}
}
],
"execution_count": 39,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"So, how well does our model perform?"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"train_metrics = model.evaluate_generator(train_gen)\n",
"test_metrics = model.evaluate_generator(test_gen)\n",
"\n",
"print(\"\\nTrain Set Metrics of the trained model:\")\n",
"for name, val in zip(model.metrics_names, train_metrics):\n",
" print(\"\\t{}: {:0.4f}\".format(name, val))\n",
"\n",
"print(\"\\nTest Set Metrics of the trained model:\")\n",
"for name, val in zip(model.metrics_names, test_metrics):\n",
" print(\"\\t{}: {:0.4f}\".format(name, val))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"Train Set Metrics of the trained model:\n",
"\tloss: 0.0549\n",
"\tacc: 0.9978\n",
"\n",
"Test Set Metrics of the trained model:\n",
"\tloss: 0.6798\n",
"\tacc: 0.7925\n"
]
}
],
"execution_count": 40,
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"There is room for improvement: the gap between train accuracy (0.9978) and test accuracy (0.7925) shows that the model overfits. But this article is first and foremost a conceptual invitation, not a road to accuracy paradise."
],
"metadata": {}
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"language": "python",
"display_name": "Python 3"
},
"language_info": {
"name": "python",
"version": "3.7.2",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernel_info": {
"name": "python3"
},
"nteract": {
"version": "0.15.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}