Skip to content

Instantly share code, notes, and snippets.

@hanneshapke
Last active April 16, 2018 15:40
Show Gist options
  • Save hanneshapke/0a49aa27eb0c7ed11662f1fe37e18c9c to your computer and use it in GitHub Desktop.
Save hanneshapke/0a49aa27eb0c7ed11662f1fe37e18c9c to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualization of Convolutional Neural Network Layers\n",
"\n",
"The model used here to classify newsgroup articles is adapted from the [newsgroup-prediction example](https://github.com/tuhinsharma/newsgroup-prediction/)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# install required packages\n",
"# !pip install pydot graphviz"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"from matplotlib import pyplot as plt\n",
"\n",
"import numpy as np\n",
"from IPython.display import HTML as html_print"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n",
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from sklearn.datasets import fetch_20newsgroups\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"from keras import backend as K\n",
"from keras.preprocessing.text import Tokenizer, text_to_word_sequence\n",
"from keras.preprocessing.sequence import pad_sequences\n",
"from keras.models import Model\n",
"from keras.layers import Embedding, Dense, Conv1D, Dropout, MaxPooling1D, Flatten, Input\n",
"from keras.utils import np_utils"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Preprocessing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize Dataset"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# load each subset once instead of calling fetch_20newsgroups five times\n",
"train_data = fetch_20newsgroups(subset='train')\n",
"test_data = fetch_20newsgroups(subset='test')\n",
"\n",
"train_text = train_data.data\n",
"train_label = train_data.target\n",
"test_text = test_data.data\n",
"test_label = test_data.target\n",
"\n",
"# class names\n",
"labels = test_data.target_names"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize Constants"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"max_words_to_keep = 1000  # vocabulary size: Tokenizer num_words and Embedding input_dim\n",
"maxlen_text = 300  # tokens per document after padding / truncation\n",
"token_vec_size = 128  # embedding dimension\n",
"output_dim = len(np.unique(train_label))  # number of distinct classes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tokenize and pad each newsgroup post"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = Tokenizer(num_words=max_words_to_keep,\n",
" filters='!\"#$%&()*+,\\'-./:;<=>?@[\\\\]^_`{|}~\\t\\n\\\"',\n",
" lower=True,\n",
" split=\" \",\n",
" char_level=False)\n",
"tokenizer.fit_on_texts(train_text)\n",
"\n",
"sequences = tokenizer.texts_to_sequences(train_text)\n",
"train_X = pad_sequences(sequences=sequences, maxlen=maxlen_text, padding='post', truncating='post')\n",
"\n",
"sequences = tokenizer.texts_to_sequences(test_text)\n",
"test_X = pad_sequences(sequences=sequences, maxlen=maxlen_text, padding='post', truncating='post')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Handle the Label using LabelEncoder"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# fit the encoder on the training labels only, then one-hot encode both splits\n",
"encoder = LabelEncoder()\n",
"encoder.fit(train_label)\n",
"\n",
"encoded_train_Y = encoder.transform(train_label)\n",
"train_Y = np_utils.to_categorical(encoded_train_Y, num_classes=output_dim)\n",
"\n",
"encoded_test_Y = encoder.transform(test_label)\n",
"test_Y = np_utils.to_categorical(encoded_test_Y, num_classes=output_dim)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Let's print the shapes of the data and labels"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(11314, 300)\n",
"(11314, 20)\n",
"(7532, 300)\n",
"(7532, 20)\n"
]
}
],
"source": [
"print(train_X.shape)\n",
"print(train_Y.shape)\n",
"print(test_X.shape)\n",
"print(test_Y.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Convolutional Net Model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"sequence_input = Input(shape=(maxlen_text,))\n",
"\n",
"x = Embedding(name='embedding_layer', input_dim=max_words_to_keep, output_dim=token_vec_size, \n",
" input_length=maxlen_text)(sequence_input)\n",
"x = Dropout(.20)(x)\n",
"\n",
"x = Conv1D(64, 5, activation='relu', name='1-conv1d', padding='same')(x)\n",
"x = MaxPooling1D(pool_size=4)(x)\n",
"x = Dropout(.20)(x)\n",
"\n",
"x = Conv1D(64, 5, activation='relu', name='2-conv1d', padding='same')(x)\n",
"x = MaxPooling1D(pool_size=2)(x)\n",
"x = Dropout(.20)(x)\n",
"\n",
"x = Flatten()(x)\n",
"\n",
"output = Dense(units=output_dim,activation='softmax')(x)\n",
"cnn_model = Model(inputs=sequence_input, outputs=output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compile the CNN model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"cnn_model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Investigate the CNN model"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_1 (InputLayer) (None, 300) 0 \n",
"_________________________________________________________________\n",
"embedding_layer (Embedding) (None, 300, 128) 128000 \n",
"_________________________________________________________________\n",
"dropout_1 (Dropout) (None, 300, 128) 0 \n",
"_________________________________________________________________\n",
"1-conv1d (Conv1D) (None, 300, 64) 41024 \n",
"_________________________________________________________________\n",
"max_pooling1d_1 (MaxPooling1 (None, 75, 64) 0 \n",
"_________________________________________________________________\n",
"dropout_2 (Dropout) (None, 75, 64) 0 \n",
"_________________________________________________________________\n",
"2-conv1d (Conv1D) (None, 75, 64) 20544 \n",
"_________________________________________________________________\n",
"max_pooling1d_2 (MaxPooling1 (None, 37, 64) 0 \n",
"_________________________________________________________________\n",
"dropout_3 (Dropout) (None, 37, 64) 0 \n",
"_________________________________________________________________\n",
"flatten_1 (Flatten) (None, 2368) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 20) 47380 \n",
"=================================================================\n",
"Total params: 236,948\n",
"Trainable params: 236,948\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"cnn_model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<svg height=\"863pt\" viewBox=\"0.00 0.00 379.93 863.00\" width=\"380pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 859)\">\n",
"<title>G</title>\n",
"<polygon fill=\"#ffffff\" points=\"-4,4 -4,-859 375.9277,-859 375.9277,4 -4,4\" stroke=\"transparent\"/>\n",
"<!-- 4700649848 -->\n",
"<g class=\"node\" id=\"node1\">\n",
"<title>4700649848</title>\n",
"<polygon fill=\"none\" points=\"52.124,-810.5 52.124,-854.5 319.8037,-854.5 319.8037,-810.5 52.124,-810.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"116.3052\" y=\"-828.3\">input_1: InputLayer</text>\n",
"<polyline fill=\"none\" points=\"180.4863,-810.5 180.4863,-854.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"208.3208\" y=\"-839.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"180.4863,-832.5 236.1553,-832.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"208.3208\" y=\"-817.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"236.1553,-810.5 236.1553,-854.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"277.9795\" y=\"-839.3\">(None, 300)</text>\n",
"<polyline fill=\"none\" points=\"236.1553,-832.5 319.8037,-832.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"277.9795\" y=\"-817.3\">(None, 300)</text>\n",
"</g>\n",
"<!-- 4700649960 -->\n",
"<g class=\"node\" id=\"node2\">\n",
"<title>4700649960</title>\n",
"<polygon fill=\"none\" points=\"9.7446,-729.5 9.7446,-773.5 362.1831,-773.5 362.1831,-729.5 9.7446,-729.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102.3052\" y=\"-747.3\">embedding_layer: Embedding</text>\n",
"<polyline fill=\"none\" points=\"194.8657,-729.5 194.8657,-773.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"222.7002\" y=\"-758.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"194.8657,-751.5 250.5347,-751.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"222.7002\" y=\"-736.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"250.5347,-729.5 250.5347,-773.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"306.3589\" y=\"-758.3\">(None, 300)</text>\n",
"<polyline fill=\"none\" points=\"250.5347,-751.5 362.1831,-751.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"306.3589\" y=\"-736.3\">(None, 300, 128)</text>\n",
"</g>\n",
"<!-- 4700649848&#45;&gt;4700649960 -->\n",
"<g class=\"edge\" id=\"edge1\">\n",
"<title>4700649848-&gt;4700649960</title>\n",
"<path d=\"M185.9639,-810.3664C185.9639,-802.1516 185.9639,-792.6579 185.9639,-783.7252\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"189.464,-783.6068 185.9639,-773.6068 182.464,-783.6069 189.464,-783.6068\" stroke=\"#000000\"/>\n",
"</g>\n",
"<!-- 4699992416 -->\n",
"<g class=\"node\" id=\"node3\">\n",
"<title>4699992416</title>\n",
"<polygon fill=\"none\" points=\"38.5034,-648.5 38.5034,-692.5 333.4243,-692.5 333.4243,-648.5 38.5034,-648.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102.3052\" y=\"-666.3\">dropout_1: Dropout</text>\n",
"<polyline fill=\"none\" points=\"166.1069,-648.5 166.1069,-692.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"193.9414\" y=\"-677.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"166.1069,-670.5 221.7759,-670.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"193.9414\" y=\"-655.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"221.7759,-648.5 221.7759,-692.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"277.6001\" y=\"-677.3\">(None, 300, 128)</text>\n",
"<polyline fill=\"none\" points=\"221.7759,-670.5 333.4243,-670.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"277.6001\" y=\"-655.3\">(None, 300, 128)</text>\n",
"</g>\n",
"<!-- 4700649960&#45;&gt;4699992416 -->\n",
"<g class=\"edge\" id=\"edge2\">\n",
"<title>4700649960-&gt;4699992416</title>\n",
"<path d=\"M185.9639,-729.3664C185.9639,-721.1516 185.9639,-711.6579 185.9639,-702.7252\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"189.464,-702.6068 185.9639,-692.6068 182.464,-702.6069 189.464,-702.6068\" stroke=\"#000000\"/>\n",
"</g>\n",
"<!-- 4699993592 -->\n",
"<g class=\"node\" id=\"node4\">\n",
"<title>4699993592</title>\n",
"<polygon fill=\"none\" points=\"40.4482,-567.5 40.4482,-611.5 331.4795,-611.5 331.4795,-567.5 40.4482,-567.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"102.3052\" y=\"-585.3\">1-conv1d: Conv1D</text>\n",
"<polyline fill=\"none\" points=\"164.1621,-567.5 164.1621,-611.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191.9966\" y=\"-596.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"164.1621,-589.5 219.8311,-589.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191.9966\" y=\"-574.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"219.8311,-567.5 219.8311,-611.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"275.6553\" y=\"-596.3\">(None, 300, 128)</text>\n",
"<polyline fill=\"none\" points=\"219.8311,-589.5 331.4795,-589.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"275.6553\" y=\"-574.3\">(None, 300, 64)</text>\n",
"</g>\n",
"<!-- 4699992416&#45;&gt;4699993592 -->\n",
"<g class=\"edge\" id=\"edge3\">\n",
"<title>4699992416-&gt;4699993592</title>\n",
"<path d=\"M185.9639,-648.3664C185.9639,-640.1516 185.9639,-630.6579 185.9639,-621.7252\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"189.464,-621.6068 185.9639,-611.6068 182.464,-621.6069 189.464,-621.6068\" stroke=\"#000000\"/>\n",
"</g>\n",
"<!-- 4660095352 -->\n",
"<g class=\"node\" id=\"node5\">\n",
"<title>4660095352</title>\n",
"<polygon fill=\"none\" points=\"0,-486.5 0,-530.5 371.9277,-530.5 371.9277,-486.5 0,-486.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"105.8052\" y=\"-504.3\">max_pooling1d_1: MaxPooling1D</text>\n",
"<polyline fill=\"none\" points=\"211.6104,-486.5 211.6104,-530.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"239.4448\" y=\"-515.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"211.6104,-508.5 267.2793,-508.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"239.4448\" y=\"-493.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"267.2793,-486.5 267.2793,-530.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"319.6035\" y=\"-515.3\">(None, 300, 64)</text>\n",
"<polyline fill=\"none\" points=\"267.2793,-508.5 371.9277,-508.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"319.6035\" y=\"-493.3\">(None, 75, 64)</text>\n",
"</g>\n",
"<!-- 4699993592&#45;&gt;4660095352 -->\n",
"<g class=\"edge\" id=\"edge4\">\n",
"<title>4699993592-&gt;4660095352</title>\n",
"<path d=\"M185.9639,-567.3664C185.9639,-559.1516 185.9639,-549.6579 185.9639,-540.7252\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"189.464,-540.6068 185.9639,-530.6068 182.464,-540.6069 189.464,-540.6068\" stroke=\"#000000\"/>\n",
"</g>\n",
"<!-- 4700647720 -->\n",
"<g class=\"node\" id=\"node6\">\n",
"<title>4700647720</title>\n",
"<polygon fill=\"none\" points=\"45.5034,-405.5 45.5034,-449.5 326.4243,-449.5 326.4243,-405.5 45.5034,-405.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"109.3052\" y=\"-423.3\">dropout_2: Dropout</text>\n",
"<polyline fill=\"none\" points=\"173.1069,-405.5 173.1069,-449.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"200.9414\" y=\"-434.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"173.1069,-427.5 228.7759,-427.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"200.9414\" y=\"-412.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"228.7759,-405.5 228.7759,-449.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"277.6001\" y=\"-434.3\">(None, 75, 64)</text>\n",
"<polyline fill=\"none\" points=\"228.7759,-427.5 326.4243,-427.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"277.6001\" y=\"-412.3\">(None, 75, 64)</text>\n",
"</g>\n",
"<!-- 4660095352&#45;&gt;4700647720 -->\n",
"<g class=\"edge\" id=\"edge5\">\n",
"<title>4660095352-&gt;4700647720</title>\n",
"<path d=\"M185.9639,-486.3664C185.9639,-478.1516 185.9639,-468.6579 185.9639,-459.7252\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"189.464,-459.6068 185.9639,-449.6068 182.464,-459.6069 189.464,-459.6068\" stroke=\"#000000\"/>\n",
"</g>\n",
"<!-- 4466755344 -->\n",
"<g class=\"node\" id=\"node7\">\n",
"<title>4466755344</title>\n",
"<polygon fill=\"none\" points=\"47.4482,-324.5 47.4482,-368.5 324.4795,-368.5 324.4795,-324.5 47.4482,-324.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"109.3052\" y=\"-342.3\">2-conv1d: Conv1D</text>\n",
"<polyline fill=\"none\" points=\"171.1621,-324.5 171.1621,-368.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"198.9966\" y=\"-353.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"171.1621,-346.5 226.8311,-346.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"198.9966\" y=\"-331.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"226.8311,-324.5 226.8311,-368.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"275.6553\" y=\"-353.3\">(None, 75, 64)</text>\n",
"<polyline fill=\"none\" points=\"226.8311,-346.5 324.4795,-346.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"275.6553\" y=\"-331.3\">(None, 75, 64)</text>\n",
"</g>\n",
"<!-- 4700647720&#45;&gt;4466755344 -->\n",
"<g class=\"edge\" id=\"edge6\">\n",
"<title>4700647720-&gt;4466755344</title>\n",
"<path d=\"M185.9639,-405.3664C185.9639,-397.1516 185.9639,-387.6579 185.9639,-378.7252\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"189.464,-378.6068 185.9639,-368.6068 182.464,-378.6069 189.464,-378.6068\" stroke=\"#000000\"/>\n",
"</g>\n",
"<!-- 4701008448 -->\n",
"<g class=\"node\" id=\"node8\">\n",
"<title>4701008448</title>\n",
"<polygon fill=\"none\" points=\"3.5,-243.5 3.5,-287.5 368.4277,-287.5 368.4277,-243.5 3.5,-243.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"109.3052\" y=\"-261.3\">max_pooling1d_2: MaxPooling1D</text>\n",
"<polyline fill=\"none\" points=\"215.1104,-243.5 215.1104,-287.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"242.9448\" y=\"-272.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"215.1104,-265.5 270.7793,-265.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"242.9448\" y=\"-250.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"270.7793,-243.5 270.7793,-287.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"319.6035\" y=\"-272.3\">(None, 75, 64)</text>\n",
"<polyline fill=\"none\" points=\"270.7793,-265.5 368.4277,-265.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"319.6035\" y=\"-250.3\">(None, 37, 64)</text>\n",
"</g>\n",
"<!-- 4466755344&#45;&gt;4701008448 -->\n",
"<g class=\"edge\" id=\"edge7\">\n",
"<title>4466755344-&gt;4701008448</title>\n",
"<path d=\"M185.9639,-324.3664C185.9639,-316.1516 185.9639,-306.6579 185.9639,-297.7252\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"189.464,-297.6068 185.9639,-287.6068 182.464,-297.6069 189.464,-297.6068\" stroke=\"#000000\"/>\n",
"</g>\n",
"<!-- 4394721744 -->\n",
"<g class=\"node\" id=\"node9\">\n",
"<title>4394721744</title>\n",
"<polygon fill=\"none\" points=\"45.5034,-162.5 45.5034,-206.5 326.4243,-206.5 326.4243,-162.5 45.5034,-162.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"109.3052\" y=\"-180.3\">dropout_3: Dropout</text>\n",
"<polyline fill=\"none\" points=\"173.1069,-162.5 173.1069,-206.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"200.9414\" y=\"-191.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"173.1069,-184.5 228.7759,-184.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"200.9414\" y=\"-169.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"228.7759,-162.5 228.7759,-206.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"277.6001\" y=\"-191.3\">(None, 37, 64)</text>\n",
"<polyline fill=\"none\" points=\"228.7759,-184.5 326.4243,-184.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"277.6001\" y=\"-169.3\">(None, 37, 64)</text>\n",
"</g>\n",
"<!-- 4701008448&#45;&gt;4394721744 -->\n",
"<g class=\"edge\" id=\"edge8\">\n",
"<title>4701008448-&gt;4394721744</title>\n",
"<path d=\"M185.9639,-243.3664C185.9639,-235.1516 185.9639,-225.6579 185.9639,-216.7252\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"189.464,-216.6068 185.9639,-206.6068 182.464,-216.6069 189.464,-216.6068\" stroke=\"#000000\"/>\n",
"</g>\n",
"<!-- 4700616016 -->\n",
"<g class=\"node\" id=\"node10\">\n",
"<title>4700616016</title>\n",
"<polygon fill=\"none\" points=\"53.6724,-81.5 53.6724,-125.5 318.2554,-125.5 318.2554,-81.5 53.6724,-81.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"109.3052\" y=\"-99.3\">flatten_1: Flatten</text>\n",
"<polyline fill=\"none\" points=\"164.938,-81.5 164.938,-125.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192.7725\" y=\"-110.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"164.938,-103.5 220.6069,-103.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192.7725\" y=\"-88.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"220.6069,-81.5 220.6069,-125.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"269.4312\" y=\"-110.3\">(None, 37, 64)</text>\n",
"<polyline fill=\"none\" points=\"220.6069,-103.5 318.2554,-103.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"269.4312\" y=\"-88.3\">(None, 2368)</text>\n",
"</g>\n",
"<!-- 4394721744&#45;&gt;4700616016 -->\n",
"<g class=\"edge\" id=\"edge9\">\n",
"<title>4394721744-&gt;4700616016</title>\n",
"<path d=\"M185.9639,-162.3664C185.9639,-154.1516 185.9639,-144.6579 185.9639,-135.7252\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"189.464,-135.6068 185.9639,-125.6068 182.464,-135.6069 189.464,-135.6068\" stroke=\"#000000\"/>\n",
"</g>\n",
"<!-- 4699823016 -->\n",
"<g class=\"node\" id=\"node11\">\n",
"<title>4699823016</title>\n",
"<polygon fill=\"none\" points=\"60.6792,-.5 60.6792,-44.5 311.2485,-44.5 311.2485,-.5 60.6792,-.5\" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"112.8052\" y=\"-18.3\">dense_1: Dense</text>\n",
"<polyline fill=\"none\" points=\"164.9312,-.5 164.9312,-44.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192.7656\" y=\"-29.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"164.9312,-22.5 220.6001,-22.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192.7656\" y=\"-7.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"220.6001,-.5 220.6001,-44.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"265.9243\" y=\"-29.3\">(None, 2368)</text>\n",
"<polyline fill=\"none\" points=\"220.6001,-22.5 311.2485,-22.5 \" stroke=\"#000000\"/>\n",
"<text fill=\"#000000\" font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"265.9243\" y=\"-7.3\">(None, 20)</text>\n",
"</g>\n",
"<!-- 4700616016&#45;&gt;4699823016 -->\n",
"<g class=\"edge\" id=\"edge10\">\n",
"<title>4700616016-&gt;4699823016</title>\n",
"<path d=\"M185.9639,-81.3664C185.9639,-73.1516 185.9639,-63.6579 185.9639,-54.7252\" fill=\"none\" stroke=\"#000000\"/>\n",
"<polygon fill=\"#000000\" points=\"189.464,-54.6068 185.9639,-44.6068 182.464,-54.6069 189.464,-54.6068\" stroke=\"#000000\"/>\n",
"</g>\n",
"</g>\n",
"</svg>"
],
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import SVG\n",
"from keras.utils.vis_utils import model_to_dot\n",
"\n",
"SVG(model_to_dot(cnn_model, show_shapes=True).create(prog='dot', format='svg'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train the CNN model"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 10748 samples, validate on 566 samples\n",
"Epoch 1/20\n",
"10748/10748 [==============================] - 34s 3ms/step - loss: 2.2057 - acc: 0.2430 - val_loss: 2.1786 - val_acc: 0.2597\n",
"Epoch 2/20\n",
"10748/10748 [==============================] - 34s 3ms/step - loss: 1.7938 - acc: 0.3968 - val_loss: 1.7430 - val_acc: 0.4117\n",
"Epoch 3/20\n",
"10748/10748 [==============================] - 33s 3ms/step - loss: 1.3563 - acc: 0.5538 - val_loss: 1.4601 - val_acc: 0.5035\n",
"Epoch 4/20\n",
"10748/10748 [==============================] - 33s 3ms/step - loss: 1.0526 - acc: 0.6555 - val_loss: 1.2644 - val_acc: 0.5866\n",
"Epoch 5/20\n",
"10748/10748 [==============================] - 36s 3ms/step - loss: 0.8561 - acc: 0.7262 - val_loss: 1.1694 - val_acc: 0.6290\n",
"Epoch 6/20\n",
"10748/10748 [==============================] - 33s 3ms/step - loss: 0.7174 - acc: 0.7747 - val_loss: 1.1188 - val_acc: 0.6396\n",
"Epoch 7/20\n",
"10748/10748 [==============================] - 33s 3ms/step - loss: 0.6119 - acc: 0.8128 - val_loss: 1.1169 - val_acc: 0.6625\n",
"Epoch 8/20\n",
"10748/10748 [==============================] - 35s 3ms/step - loss: 0.5235 - acc: 0.8421 - val_loss: 1.1096 - val_acc: 0.6555\n",
"Epoch 9/20\n",
"10748/10748 [==============================] - 37s 3ms/step - loss: 0.4371 - acc: 0.8732 - val_loss: 1.1009 - val_acc: 0.6714\n",
"Epoch 10/20\n",
"10748/10748 [==============================] - 35s 3ms/step - loss: 0.3778 - acc: 0.8897 - val_loss: 1.1384 - val_acc: 0.6714\n",
"Epoch 11/20\n",
"10748/10748 [==============================] - 39s 4ms/step - loss: 0.3155 - acc: 0.9139 - val_loss: 1.0981 - val_acc: 0.6926\n",
"Epoch 12/20\n",
"10748/10748 [==============================] - 41s 4ms/step - loss: 0.2559 - acc: 0.9339 - val_loss: 1.2108 - val_acc: 0.6855\n",
"Epoch 13/20\n",
"10748/10748 [==============================] - 37s 3ms/step - loss: 0.2164 - acc: 0.9467 - val_loss: 1.2011 - val_acc: 0.6943\n",
"Epoch 14/20\n",
"10748/10748 [==============================] - 36s 3ms/step - loss: 0.1719 - acc: 0.9613 - val_loss: 1.2036 - val_acc: 0.6943\n",
"Epoch 15/20\n",
"10748/10748 [==============================] - 40s 4ms/step - loss: 0.1387 - acc: 0.9706 - val_loss: 1.2619 - val_acc: 0.6943\n",
"Epoch 16/20\n",
"10748/10748 [==============================] - 35s 3ms/step - loss: 0.1092 - acc: 0.9810 - val_loss: 1.2944 - val_acc: 0.7049\n",
"Epoch 17/20\n",
"10748/10748 [==============================] - 36s 3ms/step - loss: 0.0886 - acc: 0.9857 - val_loss: 1.3341 - val_acc: 0.6926\n",
"Epoch 18/20\n",
"10748/10748 [==============================] - 43s 4ms/step - loss: 0.0726 - acc: 0.9899 - val_loss: 1.4092 - val_acc: 0.7085\n",
"Epoch 19/20\n",
"10748/10748 [==============================] - 36s 3ms/step - loss: 0.0569 - acc: 0.9928 - val_loss: 1.4334 - val_acc: 0.7032\n",
"Epoch 20/20\n",
"10748/10748 [==============================] - 38s 4ms/step - loss: 0.0449 - acc: 0.9960 - val_loss: 1.4275 - val_acc: 0.7049\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x11c4767b8>"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cnn_model.fit(train_X, train_Y, batch_size=256,validation_split=0.05, epochs=20)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualizing the Trained Network Layers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set up Keras\n",
"The backend must be told explicitly that it is in the inference phase, so that layers such as Dropout behave as they do at test time. From the Keras docs: \"The learning phase flag is a bool tensor (0 = test, 1 = train) to be passed as input to any Keras function that uses a different behavior at train time and test time.\""
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"K.set_learning_phase(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define helper functions"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"import html\n",
"\n",
"\n",
"def cstr(s, color='black'):\n",
"    \"\"\"Return `s` wrapped in a <text> element styled with the CSS colour `color`.\"\"\"\n",
"    # escape &, < and > so raw post text cannot break the generated HTML\n",
"    return \"<text style=\\\"color:{}\\\">{}</text>\".format(color, html.escape(str(s), quote=False))\n",
"\n",
"\n",
"def color(hvalue, threshold, max=1, cdefault='black', colors=('red', 'yellow', 'green', 'cyan', 'blue')):\n",
"    \"\"\"Map a heat value to a colour name.\n",
"\n",
"    Values below `threshold` (and any value when `colors` is empty) get\n",
"    `cdefault`. Otherwise the range [threshold, max] is split into\n",
"    len(colors) equal bands and the hottest band maps to colors[0].\n",
"    \"\"\"\n",
"    if hvalue < threshold or not colors:\n",
"        return cdefault\n",
"    band_width = (max - threshold) / len(colors)\n",
"    for i, candidate in enumerate(colors):\n",
"        if hvalue > max - band_width * (i + 1):\n",
"            return candidate\n",
"    # reached when hvalue == threshold (or float rounding left the last bound\n",
"    # just above it); without this fallback the function would return None\n",
"    return colors[-1]\n",
"\n",
"\n",
"def get_conv_layer(model, layer_name):\n",
"    \"\"\"Return the layer called `layer_name` and the size of axis 1 of its output shape.\"\"\"\n",
"    conv_layer = model.get_layer(layer_name)\n",
"    output_dim = conv_layer.output_shape[1]\n",
"    return conv_layer, output_dim"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define heatmap functions"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"def get_heatmap(model, layer_name, matrix, y_labels):\n",
"    \"\"\"Compute a normalized, gradient-weighted activation map for one layer.\n",
"\n",
"    Args:\n",
"        model: trained Keras model.\n",
"        layer_name: name of the convolutional layer to inspect.\n",
"        matrix: input batch fed to the model (callers pass one reshaped sequence).\n",
"        y_labels: label vector; its argmax selects the output unit to explain.\n",
"\n",
"    Returns:\n",
"        1-D array with one value per position of the layer output, passed\n",
"        through `norm_heatmap`.\n",
"    \"\"\"\n",
"    # model output for the class picked by y_labels (the given label, which is\n",
"    # not necessarily the class the model predicts)\n",
"    class_output = model.get_output_at(0)[:, np.argmax(y_labels)]\n",
"    conv_layer, _ = get_conv_layer(model, layer_name)\n",
"    # symbolic gradient of that output w.r.t. the conv layer's activations\n",
"    grads = K.gradients(class_output, conv_layer.output)[0]\n",
"    # one weight per layer position: average over the batch axis (0) and the\n",
"    # filter axis (2) -- assumes the conv output is (batch, steps, filters).\n",
"    # NOTE(review): classic Grad-CAM averages over positions instead, giving one\n",
"    # weight per filter (axis=(0, 1)); confirm this per-position variant is intended.\n",
"    position_weights = K.mean(grads, axis=(0, 2))\n",
"    compute = K.function([model.get_input_at(0)],\n",
"                         [position_weights, conv_layer.output[0]])\n",
"    weight_values, activations = compute([matrix])\n",
"    # scale each position's activations by its weight, broadcasting over filters\n",
"    weighted = activations * weight_values[:, np.newaxis]\n",
"    # average over filters -> one value per position\n",
"    heatmap = np.mean(weighted, axis=-1)\n",
"    return norm_heatmap(heatmap)\n",
"\n",
"def norm_heatmap(heatmap):\n",
"    \"\"\"Clip negative values to zero and rescale the heatmap to [0, 1].\n",
"\n",
"    If no value is positive, an all-zero float array is returned instead of\n",
"    NaNs (0 / 0) plus a divide-by-zero warning; empty input is returned as is.\n",
"    \"\"\"\n",
"    # element-wise maximum: keep only positive contributions\n",
"    clipped = np.maximum(heatmap, 0)\n",
"    if clipped.size == 0:\n",
"        # np.max would raise on an empty array\n",
"        return clipped\n",
"    peak = np.max(clipped)\n",
"    if peak <= 0:\n",
"        return np.zeros_like(clipped, dtype=float)\n",
"    return clipped / peak\n",
"\n",
"def plot_heatmap(heatmap, height_ratio=0.05):\n",
"    \"\"\"Draw a 1-D heatmap as a horizontal band using `plt.matshow`.\n",
"\n",
"    The vector is stacked `len(heatmap) * height_ratio` times so the band's\n",
"    height stays proportional to its length.\n",
"    \"\"\"\n",
"    # int() truncates, so vectors shorter than 1 / height_ratio would yield zero\n",
"    # rows and matshow cannot draw the empty input; always draw at least one row.\n",
"    repeat_vector_n_times = max(1, int(heatmap.shape[0] * height_ratio))\n",
"    plt.matshow([heatmap] * repeat_vector_n_times)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Determine tokens of interest"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"def get_token_indices(model, layer_name, threshold, matrix, y_labels):\n",
"    \"\"\"Map heatmap positions above `threshold` back to input-token indices.\n",
"\n",
"    Layer-output position i is mapped to input index i * int(input_len / output_len)\n",
"    and widened by a window of 1 token on each side (2 tokens when the layer\n",
"    output is at least 1.5x shorter than the input).\n",
"\n",
"    Returns:\n",
"        dict {token_index: heatmap value}. Where windows overlap the larger value\n",
"        wins. Only indices inside [0, matrix.shape[1]) are included.\n",
"    \"\"\"\n",
"    heatmap = get_heatmap(model=model, layer_name=layer_name, matrix=matrix, y_labels=y_labels)\n",
"    _, output_dim = get_conv_layer(model, layer_name)\n",
"\n",
"    input_len = matrix.shape[1]\n",
"    # depending on the ratio between the input and layer output length, estimate\n",
"    # how many original tokens have contributed to each layer output position\n",
"    dim_ratio = input_len / output_dim\n",
"    window_size = 1 if dim_ratio < 1.5 else 2\n",
"    stride = int(dim_ratio)\n",
"\n",
"    indices = {}\n",
"    for i in np.where(heatmap > threshold)[0].tolist():\n",
"        scaled_index = i * stride\n",
"        # clip the window to real token positions; a negative key would address\n",
"        # tokens from the end of the list if it were ever used as an index\n",
"        start = max(0, scaled_index - window_size)\n",
"        stop = min(input_len, scaled_index + window_size + 1)\n",
"        for ind in range(start, stop):\n",
"            if ind not in indices or indices[ind] < heatmap[i]:\n",
"                indices[ind] = heatmap[i]\n",
"    return indices"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Put everything together"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def get_highlighted_tokens(tokens, matrix, model, layer_name, threshold, y_labels):\n",
"    \"\"\"Render `tokens` as HTML, coloring those the heatmap marks as relevant.\n",
"\n",
"    Tokens whose position appears in `get_token_indices(...)` are wrapped in a\n",
"    colored tag (color chosen by `color` from their heatmap value); all other\n",
"    tokens are emitted unchanged. Returns an IPython `HTML` object.\n",
"    \"\"\"\n",
"    scores = get_token_indices(model, layer_name, threshold, matrix, y_labels)\n",
"\n",
"    def render(position, token):\n",
"        if position not in scores:\n",
"            return token\n",
"        return cstr(token, color=color(scores[position], threshold=threshold))\n",
"\n",
"    ctokens = [render(pos, tok) for pos, tok in enumerate(tokens)]\n",
"    return html_print(cstr(' '.join(ctokens), color='black'))"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(75, 64)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA5wAAABGCAYAAACg/5sFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAACfdJREFUeJzt3WuMHXUZx/Hvr9vdlouKXBShKIhNtTFatEFUYhSiFiXgC2Mgaowh6RtMIMEY9I3RxBe+QX2hJgQQXyioKEgMUQmSeIkp1yqXilYiUgSKF+7S6+OLM+paWnum7Jzdmf1+ks2emfPPzJP9nZmzz5nLSVUhSZIkSdJcWzLfBUiSJEmShsmGU5IkSZLUCRtOSZIkSVInbDglSZIkSZ2w4ZQkSZIkdcKGU5IkSZLUiYk0nEnWJbkvyeYkF09inepGkiuSbE1y96x5hye5Mckfmt8vnc8adWCSHJfk5iT3JrknyQXNfPPtuSTLk9yS5DdNtp9r5p+QZEOzb/5Okpn5rlUHJslUkjuT/KiZNtsBSPKnJHcl2Zjktmae++QBSHJYkmuS/C7JpiRvNdv+S7Kq2V7//fNkkgsXe7adN5xJpoCvAmcAq4Fzk6zuer3qzJXAuj3mXQzcVFUrgZuaafXPTuCiqloNnAKc32yr5tt/24DTquqNwBpgXZJTgC8CX6qq1wD/AM6bxxr1wlwAbJo1bbbD8a6qWlNVa5tp98nD8BXgx1X1WuCNjLZfs+25qrqv2V7XAG8GngWuZZFnO4kjnCcDm6vq/qraDlwNnD2B9aoDVfVz4O97zD4b+Gbz+JvAByZalOZEVT1cVXc0j59i9OZ3LObbezXydDM53fwUcBpwTTPfbHsqyQrg/cBlzXQw2yFzn9xzSV4CvAO4HKCqtlfV45jt0JwO/LGqHmCRZzuJhvNY4MFZ01uaeRqOl1fVw83jR4CXz2cxeuGSHA+cBGzAfAehOeVyI7AVuBH4I/B4Ve1shrhv7q8vA58CdjfTR2C2Q1HAT5PcnmR9M899cv+dADwGfKM5Ff6yJIdgtkNzDnBV83hRZ+tNgzSnqqoYvUGqp5IcCnwfuLCqnpz9nPn2V1Xtak7xWcHozJPXznNJmgNJzgS2VtXt812LOnFqVb2J0WVJ5yd5x+wn3Sf31lLgTcDXq+ok4Bn2OMXSbPutuW7+LOB7ez63GLOdRMP5EHDcrOkVzTwNx6NJXgHQ/N46z/XoACWZZtRsfquqftDMNt8BaU7buhl4K3BYkqXNU+6b++ntwFlJ/sTokpXTGF0bZrYDUFUPNb+3MroO7GTcJw/BFmBLVW1opq9h1ICa7XCcAdxRVY8204s620k0nLcCK5s75s0wOrx8/QTWq8m5HvhY8/hjwA/nsRYdoOa6r8uBTVV1yaynzLfnkhyV5LDm8UHAuxldo3sz8MFmmNn2UFV9uqpWVNXxjN5ff1ZVH8Zsey/JIUle9O/HwHuAu3Gf3HtV9QjwYJJVzazTgXsx2yE5l/+eTguLPNuMjup2vJLkfYyuMZkCrqiqL3S+UnUiyVXAO4EjgUeBzwLXAd8FXgk8AHyoqva8sZAWuCSnAr8A7uK/14J9htF1nObbY0newOgmBVOMPmj8blV9PsmrGR0VOxy4E/hIVW2bv0r1QiR5J/DJqjrTbPuvyfDaZnIp8O2q+kKSI3Cf3HtJ1jC60dcMcD/wcZr9M2bba80HRH8GXl1VTzTzFvV2O5GGU5IkSZK0+HjTIEmSJElSJ2w4JUmSJEmdsOGUJEmSJHXChlOSJEmS1AkbTkmSJElSJybacCZZP8n1aXLMdrjMdrjMdrjMdrjMdrjMdrgWe7aTPsK5qP/YA2e2w2W2w2W2w2W2w2W2w2W2w7Wos/WUWkmSJElSJ1JVc77QmSyr5RzyvPk72MY0y+Z8fXNtx9HPr31fKu2W/fqjHms1/ve/PbjV+MxMtxpf23e0Gr/jxOV7nb/ryWeZevHza535S6vFs2v5VLvxLV9OM3/b1mp8LRv/77l7absXw5LHn201fufLxn9dAix9eler8Tz73F5nz9t2e+
hB7cY//c9u6mhkyfifz9Xu3e2WfdDet6t92ra93file9+utu/6JzNTz/87t90vaOHpy/ut/o/s/T1lRz3HdPayz2j7/9w+lr9PHfy/qP/Vm+3W105rc5btAvrbP8czbK9tYxW0tIsClnMIb8npXSx6Ih76+NvGHru7XX/ELed/rdX49x6zptX4pccc12r8zgcebDX+kUte12r8MZ9tNZynVr641fh/rGwXwPFX3t9q/PYTjx577LYjZ1ot+6Drbmk1/tFzxn9dAhz9yydaja8772k1vmu1pt1rP7/a2FElI0sOHr/h3/3MM+2WvarddsXmP7db/pGHtxrfdr/QuSUtd7S7W37YovmzgP55aq3j2rOs3T+nta3dB6qZbveeVTtaftDV9+22Tb4L6XUJvX9tat8W0t9+Q9009tixPrJPsi7JfUk2J7n4gCuTJEmSJC0a+204k0wBXwXOAFYD5yZZ3XVhkiRJkqR+G+cI58nA5qq6v6q2A1cDZ3dbliRJkiSp78ZpOI8FZl/Qs6WZJ0mSJEnSPs3ZTYOaLzRdD7CcdndWlSRJkiQNzzhHOB8CZt/6dEUz739U1aVVtbaq1vbils6SJEmSpE6N03DeCqxMckKSGeAc4Ppuy5IkSZIk9d1+T6mtqp1JPgH8BJgCrqiqhfWFfZIkSZKkBWesazir6gbgho5rkSRJkiQNyDin1EqSJEmS1JoNpyRJkiSpEzackiRJkqRO7LfhTHJFkq1J7p5EQZIkSZKkYRjnCOeVwLqO65AkSZIkDcx+G86q+jnw9wnUIkmSJEkakLG+FmUcSdYD6wGWc/BcLVaSJEmS1FNzdtOgqrq0qtZW1dppls3VYiVJkiRJPeVdaiVJkiRJnbDhlCRJkiR1YpyvRbkK+DWwKsmWJOd1X5YkSZIkqe/2e9Ogqjp3EoVIkiRJkobFU2olSZIkSZ2w4ZQkSZIkdcKGU5IkSZLUiXFuGnRckpuT3JvkniQXTKIwSZIkSVK/7femQcBO4KKquiPJi4Dbk9xYVfd2XJskSZIkqcf2e4Szqh6uqjuax08Bm4Bjuy5MkiRJktRv4xzh/I8kxwMnARv28tx6YD3Acg6eg9IkSZIkSX029k2DkhwKfB+4sKqe3PP5qrq0qtZW1dppls1ljZIkSZKkHhqr4UwyzajZ/FZV/aDbkiRJkiRJQzDOXWoDXA5sqqpLui9JkiRJkjQE4xzhfDvwUeC0JBubn/d1XJckSZIkqef2e9OgqvolkAnUIkmSJEkakFTV3C80eQx4YC9PHQn8dc5XqIXAbIfLbIfLbIfLbIfLbIfLbIdriNm+qqqOGmdgJw3nPleW3FZVaye2Qk2M2Q6X2Q6X2Q6X2Q6X2Q6X2Q7XYs927K9FkSRJkiSpDRtOSZIkSVInJt1wXjrh9WlyzHa4zHa4zHa4zHa4zHa4zHa4FnW2E72GU5IkSZK0eHhKrSRJkiSpEzackiRJkqRO2HBKkiRJkjphwylJkiRJ6oQNpyRJkiSpE/8CswIW+q9C4Y0AAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x11a8cd898>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"text_id = 6050\n",
"layer_name = '2-conv1d'\n",
"\n",
"sequence = train_X[text_id]\n",
"# a batch of one; -1 infers the sequence length instead of hard-coding 300\n",
"model_input = sequence.reshape(1, -1)\n",
"index_to_tokens = {i: l for l, i in tokenizer.word_index.items()}\n",
"tokens = [index_to_tokens[s] if s > 0 else 'pad' for s in sequence]\n",
"\n",
"# get_heatmap already returns a normalized heatmap; a second norm_heatmap call is redundant\n",
"heatmap = get_heatmap(cnn_model, layer_name, model_input, y_labels=train_Y[text_id])\n",
"plot_heatmap(heatmap)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classification for sci.space\n"
]
},
{
"data": {
"text/html": [
"<text style=\"color:black\">from head edu steve subject re drive in how organization ma usa lines 19 <text style=\"color:cyan\">in</text> <text style=\"color:cyan\">article</text> <text style=\"color:cyan\">toronto</text> <text style=\"color:cyan\">edu</text> <text style=\"color:cyan\">toronto</text> <text style=\"color:cyan\">edu</text> <text style=\"color:cyan\">writes</text> <text style=\"color:cyan\">the</text> <text style=\"color:red\">national</text> <text style=\"color:red\">air</text> <text style=\"color:red\">space</text> <text style=\"color:red\">has</text> <text style=\"color:red\">both</text> the and the however quite it s no longer on display <text style=\"color:yellow\">like</text> <text style=\"color:yellow\">most</text> <text style=\"color:yellow\">has</text> <text style=\"color:yellow\">much</text> <text style=\"color:yellow\">more</text> <text style=\"color:blue\">stuff</text> <text style=\"color:blue\">than</text> <text style=\"color:blue\">it</text> <text style=\"color:blue\">can</text> <text style=\"color:blue\">display</text> <text style=\"color:blue\">at</text> <text style=\"color:blue\">once</text> <text style=\"color:cyan\">and</text> <text style=\"color:cyan\">does</text> <text style=\"color:cyan\">the</text> <text style=\"color:cyan\">the</text> <text style=\"color:cyan\">are</text> open to the <text style=\"color:blue\">public</text> <text style=\"color:blue\">all</text> <text style=\"color:blue\">or</text> <text style=\"color:blue\">almost</text> <text style=\"color:blue\">all</text> <text style=\"color:blue\">still</text> <text style=\"color:blue\">in</text> <text style=\"color:blue\">the</text> <text style=\"color:blue\">are</text> available for but <text style=\"color:cyan\">i</text> <text style=\"color:cyan\">don</text> <text style=\"color:cyan\">t</text> <text style=\"color:cyan\">know</text> <text style=\"color:cyan\">about</text> <text style=\"color:blue\">at</text> <text style=\"color:blue\">least</text> <text 
style=\"color:blue\">it</text> <text style=\"color:blue\">might</text> be worth a try i m not sure if are necessary <text style=\"color:blue\">but</text> <text style=\"color:blue\">i</text> <text style=\"color:blue\">think</text> <text style=\"color:blue\">not</text> <text style=\"color:blue\">good</text> and let us know what you find steve phone bitnet ma usa internet edu league <text style=\"color:blue\">for</text> <text style=\"color:blue\">contact</text> <text style=\"color:blue\">uunet</text> <text style=\"color:blue\">net</text> <text style=\"color:blue\">pad</text> pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad pad</text>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# note: this is the ground-truth label from train_Y, not the model's prediction\n",
"print(\"Classification for {}\".format(labels[np.argmax(train_Y[text_id])]))\n",
"\n",
"get_highlighted_tokens(tokens=tokens,\n",
"                       matrix=sequence.reshape(1, -1),  # batch of one; length inferred\n",
"                       model=cnn_model,\n",
"                       layer_name=layer_name,\n",
"                       threshold=0.1,\n",
"                       y_labels=train_Y[text_id])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment