Skip to content

Instantly share code, notes, and snippets.

Last active December 13, 2022 16:51
Show Gist options
  • Save madrugado/63c068b52a135c6fdbbb6fe17acbc0c8 to your computer and use it in GitHub Desktop.
Save madrugado/63c068b52a135c6fdbbb6fe17acbc0c8 to your computer and use it in GitHub Desktop.
Keras usage example, simple text classification
Display the source blob
Display the rendered blob
"cells": [
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook is based on the `reuters_mlp.py` example from Francois Chollet's Keras repository."
"cell_type": "markdown",
"metadata": {},
"source": [
"Train and evaluate a simple MLP on the 20 newsgroups topic classification task."
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
"source": [
"import numpy as np\n",
"import keras\n",
"from keras.models import Sequential, Model\n",
"from keras.layers import Dense, Dropout, Activation, Input\n",
"from keras.preprocessing.text import Tokenizer"
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
"outputs": [],
"source": [
"max_words = 1000\n",
"batch_size = 32\n",
"epochs = 5"
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
"outputs": [],
"source": [
"from sklearn.datasets import fetch_20newsgroups\n",
"newsgroups_train = fetch_20newsgroups(subset='train')\n",
"newsgroups_test = fetch_20newsgroups(subset='test')"
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"From: (where's my thing)\n",
"Subject: WHAT car is this!?\n",
"Organization: University of Maryland, College Park\n",
"Lines: 15\n",
" I was wondering if anyone out there could enlighten me on this car I saw\n",
"the other day. It was a 2-door sports car, looked to be from the late 60s/\n",
"early 70s. It was called a Bricklin. The doors were really small. In addition,\n",
"the front bumper was separate from the rest of the body. This is \n",
"all I know. If anyone can tellme a model name, engine specs, years\n",
"of production, where this car is made, history, or whatever info you\n",
"have on this funky looking car, please e-mail.\n",
"- IL\n",
" ---- brought to you by your neighborhood Lerxst ----\n",
"source": [
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Preparing the Tokenizer...\n"
"source": [
"print(\"Preparing the Tokenizer...\")\n",
"tokenizer = Tokenizer(num_words=max_words)\n",
"# Build the vocabulary before vectorizing: texts_to_matrix needs a fitted\n",
"# word index to map tokens to columns.\n",
"# NOTE(review): this call was missing from the extracted cell; fitting on the\n",
"# training texts is assumed - confirm against the original notebook.\n",
"tokenizer.fit_on_texts(newsgroups_train[\"data\"])"
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Vectorizing sequence data...\n",
"x_train shape: (11314, 1000)\n",
"x_test shape: (7532, 1000)\n"
"source": [
"print('Vectorizing sequence data...')\n",
"x_train = tokenizer.texts_to_matrix(newsgroups_train[\"data\"], mode='binary')\n",
"x_test = tokenizer.texts_to_matrix(newsgroups_test[\"data\"], mode='binary')\n",
"print('x_train shape:', x_train.shape)\n",
"print('x_test shape:', x_test.shape)"
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
"outputs": [
"data": {
"text/plain": [
"array([[ 0., 1., 1., ..., 0., 0., 0.],\n",
" [ 0., 1., 1., ..., 0., 0., 0.],\n",
" [ 0., 1., 1., ..., 1., 0., 0.],\n",
" ..., \n",
" [ 0., 1., 1., ..., 0., 0., 0.],\n",
" [ 0., 1., 1., ..., 0., 0., 0.],\n",
" [ 0., 0., 0., ..., 0., 0., 0.]])"
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"20 classes\n"
"source": [
"# Cast to a plain Python int: np.max returns a NumPy scalar, which otherwise\n",
"# ends up in Dense(units=...) and is serialized by model.to_yaml() as a\n",
"# python/object numpy tag instead of a simple integer.\n",
"num_classes = int(np.max(newsgroups_train[\"target\"]) + 1)\n",
"print(num_classes, 'classes')"
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Convert class vector to binary class matrix (for use with categorical_crossentropy)\n",
"y_train shape: (11314, 20)\n",
"y_test shape: (7532, 20)\n"
"source": [
"print('Convert class vector to binary class matrix '\n",
" '(for use with categorical_crossentropy)')\n",
"y_train = keras.utils.to_categorical(newsgroups_train[\"target\"], num_classes)\n",
"y_test = keras.utils.to_categorical(newsgroups_test[\"target\"], num_classes)\n",
"print('y_train shape:', y_train.shape)\n",
"print('y_test shape:', y_test.shape)"
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
"outputs": [
"data": {
"text/plain": [
"array([[ 0., 0., 0., ..., 0., 0., 0.],\n",
" [ 0., 0., 0., ..., 0., 0., 0.],\n",
" [ 0., 0., 0., ..., 0., 0., 0.],\n",
" ..., \n",
" [ 0., 0., 0., ..., 0., 0., 0.],\n",
" [ 0., 1., 0., ..., 0., 0., 0.],\n",
" [ 0., 0., 0., ..., 0., 0., 0.]])"
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Building model sequentially 1...\n"
"source": [
"print('Building model sequentially 1...')\n",
"# Same architecture as the list-style Sequential cell, built layer by layer.\n",
"# NOTE(review): every layer after the first Dense was missing from the\n",
"# extracted cell and is restored to mirror that cell - confirm.\n",
"model = Sequential()\n",
"model.add(Dense(512, input_shape=(max_words,)))\n",
"model.add(Activation('relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(num_classes))\n",
"model.add(Activation('softmax'))"
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Building model sequentially 2...\n"
"source": [
"print('Building model sequentially 2...')\n",
"model = Sequential([\n",
" Dense(512, input_shape=(max_words,)),\n",
" Activation('relu'),\n",
" Dropout(0.5),\n",
" Dense(num_classes),\n",
" Activation('softmax')\n",
" ])"
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
"outputs": [
"data": {
"text/plain": [
"[<keras.layers.core.Dense at 0x1123c4b00>,\n",
" <keras.layers.core.Activation at 0x1122db780>,\n",
" <keras.layers.core.Dropout at 0x1122db940>,\n",
" <keras.layers.core.Dense at 0x1122dbe10>,\n",
" <keras.layers.core.Activation at 0x112325390>]"
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"backend: tensorflow\n",
"class_name: Sequential\n",
"- class_name: Dense\n",
" config:\n",
" activation: linear\n",
" activity_regularizer: null\n",
" batch_input_shape: !!python/tuple [null, 1000]\n",
" bias_constraint: null\n",
" bias_initializer:\n",
" class_name: Zeros\n",
" config: {}\n",
" bias_regularizer: null\n",
" dtype: float32\n",
" kernel_constraint: null\n",
" kernel_initializer:\n",
" class_name: VarianceScaling\n",
" config: {distribution: uniform, mode: fan_avg, scale: 1.0, seed: null}\n",
" kernel_regularizer: null\n",
" name: dense_3\n",
" trainable: true\n",
" units: 512\n",
" use_bias: true\n",
"- class_name: Activation\n",
" config: {activation: relu, name: activation_3, trainable: true}\n",
"- class_name: Dropout\n",
" config: {name: dropout_2, rate: 0.5, trainable: true}\n",
"- class_name: Dense\n",
" config:\n",
" activation: linear\n",
" activity_regularizer: null\n",
" bias_constraint: null\n",
" bias_initializer:\n",
" class_name: Zeros\n",
" config: {}\n",
" bias_regularizer: null\n",
" kernel_constraint: null\n",
" kernel_initializer:\n",
" class_name: VarianceScaling\n",
" config: {distribution: uniform, mode: fan_avg, scale: 1.0, seed: null}\n",
" kernel_regularizer: null\n",
" name: dense_4\n",
" trainable: true\n",
" units: !!python/object/apply:numpy.core.multiarray.scalar\n",
" - !!python/object/apply:numpy.dtype\n",
" args: [i8, 0, 1]\n",
" state: !!python/tuple [3, <, null, null, null, -1, -1, 0]\n",
" - !!binary |\n",
" use_bias: true\n",
"- class_name: Activation\n",
" config: {activation: softmax, name: activation_4, trainable: true}\n",
"keras_version: 2.0.2\n",
"source": [
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Building model functionally...\n"
"source": [
"print('Building model functionally...')\n",
"a = Input(shape=(max_words,))\n",
"b = Dense(512)(a)\n",
"b = Activation('relu')(b)\n",
"b = Dropout(0.5)(b)\n",
"b = Dense(num_classes)(b)\n",
"b = Activation('softmax')(b)\n",
"model = Model(inputs=a, outputs=b)"
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
"outputs": [],
"source": [
"from keras.models import model_from_yaml\n",
"yaml_string = model.to_yaml()\n",
"model = model_from_yaml(yaml_string)"
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied (use --upgrade to upgrade): pydot-ng in /Users/madrugado/anaconda3/lib/python3.5/site-packages\n",
"Requirement already satisfied (use --upgrade to upgrade): pyparsing>=2.0.1 in /Users/madrugado/anaconda3/lib/python3.5/site-packages (from pydot-ng)\n",
"\u001b[33mYou are using pip version 8.1.2, however version 9.0.1 is available.\n",
"You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n"
"source": [
"! pip install pydot-ng"
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": false
"outputs": [],
"source": [
"from keras.utils import plot_model\n",
"plot_model(model, to_file='model.png', show_shapes=True)"
"cell_type": "code",
"execution_count": 28,
"metadata": {
"collapsed": false
"outputs": [
"data": {
"image/svg+xml": [
"<svg height=\"458pt\" viewBox=\"0.00 0.00 298.24 458.00\" width=\"298pt\" xmlns=\"\" xmlns:xlink=\"\">\n",
"<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 454)\">\n",
"<polygon fill=\"white\" points=\"-4,4 -4,-454 294.238,-454 294.238,4 -4,4\" stroke=\"none\"/>\n",
"<!-- 4600907424 -->\n",
"<g class=\"node\" id=\"node1\"><title>4600907424</title>\n",
"<polygon fill=\"none\" points=\"7.7793,-405.5 7.7793,-449.5 282.459,-449.5 282.459,-405.5 7.7793,-405.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"71.9604\" y=\"-423.3\">input_1: InputLayer</text>\n",
"<polyline fill=\"none\" points=\"136.142,-405.5 136.142,-449.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"163.976\" y=\"-434.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"136.142,-427.5 191.811,-427.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"163.976\" y=\"-412.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"191.811,-405.5 191.811,-449.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"237.135\" y=\"-434.3\">(None, 1000)</text>\n",
"<polyline fill=\"none\" points=\"191.811,-427.5 282.459,-427.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"237.135\" y=\"-412.3\">(None, 1000)</text>\n",
"<!-- 4600907312 -->\n",
"<g class=\"node\" id=\"node2\"><title>4600907312</title>\n",
"<polygon fill=\"none\" points=\"19.8345,-324.5 19.8345,-368.5 270.404,-368.5 270.404,-324.5 19.8345,-324.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"71.9604\" y=\"-342.3\">dense_5: Dense</text>\n",
"<polyline fill=\"none\" points=\"124.086,-324.5 124.086,-368.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"151.921\" y=\"-353.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"124.086,-346.5 179.755,-346.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"151.921\" y=\"-331.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"179.755,-324.5 179.755,-368.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"225.08\" y=\"-353.3\">(None, 1000)</text>\n",
"<polyline fill=\"none\" points=\"179.755,-346.5 270.404,-346.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"225.08\" y=\"-331.3\">(None, 512)</text>\n",
"<!-- 4600907424&#45;&gt;4600907312 -->\n",
"<g class=\"edge\" id=\"edge1\"><title>4600907424-&gt;4600907312</title>\n",
"<path d=\"M145.119,-405.329C145.119,-397.183 145.119,-387.699 145.119,-378.797\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"148.619,-378.729 145.119,-368.729 141.619,-378.729 148.619,-378.729\" stroke=\"black\"/>\n",
"<!-- 4601527376 -->\n",
"<g class=\"node\" id=\"node3\"><title>4601527376</title>\n",
"<polygon fill=\"none\" points=\"0,-243.5 0,-287.5 290.238,-287.5 290.238,-243.5 0,-243.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"75.4604\" y=\"-261.3\">activation_5: Activation</text>\n",
"<polyline fill=\"none\" points=\"150.921,-243.5 150.921,-287.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"178.755\" y=\"-272.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"150.921,-265.5 206.59,-265.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"178.755\" y=\"-250.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"206.59,-243.5 206.59,-287.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"248.414\" y=\"-272.3\">(None, 512)</text>\n",
"<polyline fill=\"none\" points=\"206.59,-265.5 290.238,-265.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"248.414\" y=\"-250.3\">(None, 512)</text>\n",
"<!-- 4600907312&#45;&gt;4601527376 -->\n",
"<g class=\"edge\" id=\"edge2\"><title>4600907312-&gt;4601527376</title>\n",
"<path d=\"M145.119,-324.329C145.119,-316.183 145.119,-306.699 145.119,-297.797\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"148.619,-297.729 145.119,-287.729 141.619,-297.729 148.619,-297.729\" stroke=\"black\"/>\n",
"<!-- 4601040736 -->\n",
"<g class=\"node\" id=\"node4\"><title>4601040736</title>\n",
"<polygon fill=\"none\" points=\"11.6587,-162.5 11.6587,-206.5 278.58,-206.5 278.58,-162.5 11.6587,-162.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"75.4604\" y=\"-180.3\">dropout_3: Dropout</text>\n",
"<polyline fill=\"none\" points=\"139.262,-162.5 139.262,-206.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"167.097\" y=\"-191.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"139.262,-184.5 194.931,-184.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"167.097\" y=\"-169.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"194.931,-162.5 194.931,-206.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"236.755\" y=\"-191.3\">(None, 512)</text>\n",
"<polyline fill=\"none\" points=\"194.931,-184.5 278.58,-184.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"236.755\" y=\"-169.3\">(None, 512)</text>\n",
"<!-- 4601527376&#45;&gt;4601040736 -->\n",
"<g class=\"edge\" id=\"edge3\"><title>4601527376-&gt;4601040736</title>\n",
"<path d=\"M145.119,-243.329C145.119,-235.183 145.119,-225.699 145.119,-216.797\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"148.619,-216.729 145.119,-206.729 141.619,-216.729 148.619,-216.729\" stroke=\"black\"/>\n",
"<!-- 4600579912 -->\n",
"<g class=\"node\" id=\"node5\"><title>4600579912</title>\n",
"<polygon fill=\"none\" points=\"23.3345,-81.5 23.3345,-125.5 266.904,-125.5 266.904,-81.5 23.3345,-81.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"75.4604\" y=\"-99.3\">dense_6: Dense</text>\n",
"<polyline fill=\"none\" points=\"127.586,-81.5 127.586,-125.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"155.421\" y=\"-110.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"127.586,-103.5 183.255,-103.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"155.421\" y=\"-88.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"183.255,-81.5 183.255,-125.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"225.08\" y=\"-110.3\">(None, 512)</text>\n",
"<polyline fill=\"none\" points=\"183.255,-103.5 266.904,-103.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"225.08\" y=\"-88.3\">(None, 20)</text>\n",
"<!-- 4601040736&#45;&gt;4600579912 -->\n",
"<g class=\"edge\" id=\"edge4\"><title>4601040736-&gt;4600579912</title>\n",
"<path d=\"M145.119,-162.329C145.119,-154.183 145.119,-144.699 145.119,-135.797\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"148.619,-135.729 145.119,-125.729 141.619,-135.729 148.619,-135.729\" stroke=\"black\"/>\n",
"<!-- 4601423128 -->\n",
"<g class=\"node\" id=\"node6\"><title>4601423128</title>\n",
"<polygon fill=\"none\" points=\"3.5,-0.5 3.5,-44.5 286.738,-44.5 286.738,-0.5 3.5,-0.5\" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"78.9604\" y=\"-18.3\">activation_6: Activation</text>\n",
"<polyline fill=\"none\" points=\"154.421,-0.5 154.421,-44.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"182.255\" y=\"-29.3\">input:</text>\n",
"<polyline fill=\"none\" points=\"154.421,-22.5 210.09,-22.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"182.255\" y=\"-7.3\">output:</text>\n",
"<polyline fill=\"none\" points=\"210.09,-0.5 210.09,-44.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"248.414\" y=\"-29.3\">(None, 20)</text>\n",
"<polyline fill=\"none\" points=\"210.09,-22.5 286.738,-22.5 \" stroke=\"black\"/>\n",
"<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"248.414\" y=\"-7.3\">(None, 20)</text>\n",
"<!-- 4600579912&#45;&gt;4601423128 -->\n",
"<g class=\"edge\" id=\"edge5\"><title>4600579912-&gt;4601423128</title>\n",
"<path d=\"M145.119,-81.3294C145.119,-73.1826 145.119,-63.6991 145.119,-54.7971\" fill=\"none\" stroke=\"black\"/>\n",
"<polygon fill=\"black\" points=\"148.619,-54.729 145.119,-44.729 141.619,-54.729 148.619,-54.729\" stroke=\"black\"/>\n",
"text/plain": [
"<IPython.core.display.SVG object>"
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
"source": [
"from IPython.display import SVG\n",
"from keras.utils.vis_utils import model_to_dot\n",
"SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))"
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": true
"outputs": [],
"source": [
"from keras.objectives import categorical_crossentropy\n",
"from keras import backend as K\n",
"epsilon = 1.0e-9\n",
"def custom_objective(y_true, y_pred):\n",
"    '''Yet another crossentropy.\n",
"\n",
"    Clips y_pred to [epsilon, 1 - epsilon], re-normalizes it along the last\n",
"    axis, and returns Keras' categorical cross-entropy of the result against\n",
"    y_true.\n",
"\n",
"    y_true: target tensor, passed through unchanged.\n",
"    y_pred: prediction tensor.\n",
"    '''\n",
"    y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)\n",
"    y_pred /= K.sum(y_pred, axis=-1, keepdims=True)\n",
"    # Keras loss functions take (y_true, y_pred). The arguments used to be\n",
"    # swapped, which took the log of the targets instead of the predictions.\n",
"    cce = categorical_crossentropy(y_true, y_pred)\n",
"    return cce"
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": false
"outputs": [],
"source": [
"# NOTE(review): the opening of this call was lost when the notebook was\n",
"# extracted; the built-in loss is assumed here - confirm against the original.\n",
"model.compile(loss='categorical_crossentropy',\n",
"              optimizer='adam',\n",
"              metrics=['accuracy'])"
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false
"outputs": [],
"source": [
"# NOTE(review): the opening of this call was lost when the notebook was\n",
"# extracted; custom_objective is assumed as the loss - confirm against the\n",
"# original.\n",
"model.compile(loss=custom_objective,\n",
"              optimizer='adam',\n",
"              metrics=['accuracy'])"
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 10182 samples, validate on 1132 samples\n",
"Epoch 1/5\n",
"10182/10182 [==============================] - 6s - loss: 11.2602 - acc: 0.3476 - val_loss: 8.1297 - val_acc: 0.5442\n",
"Epoch 2/5\n",
"10182/10182 [==============================] - 5s - loss: 7.1478 - acc: 0.5968 - val_loss: 6.9277 - val_acc: 0.5998\n",
"Epoch 3/5\n",
"10182/10182 [==============================] - 5s - loss: 5.6592 - acc: 0.6782 - val_loss: 5.8904 - val_acc: 0.6564\n",
"Epoch 4/5\n",
"10182/10182 [==============================] - 5s - loss: 4.8580 - acc: 0.7209 - val_loss: 5.7133 - val_acc: 0.6643\n",
"Epoch 5/5\n",
"10182/10182 [==============================] - 5s - loss: 4.5376 - acc: 0.7376 - val_loss: 5.5546 - val_acc: 0.6687\n"
"source": [
"# Train on the vectorized training set, holding out 10% for validation.\n",
"history = model.fit(x_train, y_train,\n",
"                    batch_size=batch_size,\n",
"                    epochs=epochs,\n",
"                    verbose=1,\n",
"                    validation_split=0.1)"
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 10182 samples, validate on 1132 samples\n",
"Epoch 1/5\n",
"10182/10182 [==============================] - 6s - loss: 5.0572 - acc: 0.6996 - val_loss: 5.9746 - val_acc: 0.6396\n",
"Epoch 2/5\n",
"10182/10182 [==============================] - 6s - loss: 4.8309 - acc: 0.7135 - val_loss: 5.9775 - val_acc: 0.6396\n"
"source": [
"from keras.callbacks import EarlyStopping\n",
"# Stop training once the validation loss stops improving.\n",
"early_stopping = EarlyStopping(monitor='val_loss')\n",
"history = model.fit(x_train, y_train,\n",
"                    batch_size=batch_size,\n",
"                    epochs=epochs,\n",
"                    verbose=1,\n",
"                    validation_split=0.1,\n",
"                    callbacks=[early_stopping])"
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 10182 samples, validate on 1132 samples\n",
"Epoch 1/5\n",
"10182/10182 [==============================] - 5s - loss: 4.6762 - acc: 0.7209 - val_loss: 5.8910 - val_acc: 0.6475\n",
"Epoch 2/5\n",
"10182/10182 [==============================] - 5s - loss: 4.5583 - acc: 0.7272 - val_loss: 5.8414 - val_acc: 0.6493\n",
"Epoch 3/5\n",
"10182/10182 [==============================] - 5s - loss: 4.4485 - acc: 0.7323 - val_loss: 5.9157 - val_acc: 0.6422\n",
"Epoch 4/5\n",
"10182/10182 [==============================] - 5s - loss: 4.3723 - acc: 0.7361 - val_loss: 5.9310 - val_acc: 0.6369\n",
"Epoch 5/5\n",
"10182/10182 [==============================] - 5s - loss: 4.3307 - acc: 0.7380 - val_loss: 5.7791 - val_acc: 0.6511\n"
"source": [
"from keras.callbacks import TensorBoard\n",
"# Write TensorBoard logs (graph and weight images) to ./logs during training.\n",
"tensorboard = TensorBoard(log_dir='./logs', write_graph=True, write_images=True)\n",
"history = model.fit(x_train, y_train,\n",
"                    batch_size=batch_size,\n",
"                    epochs=epochs,\n",
"                    verbose=1,\n",
"                    validation_split=0.1,\n",
"                    callbacks=[tensorboard])"
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": false
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"7328/7532 [============================>.] - ETA: 0s\n",
"Test score: 7.05705657565\n",
"Test accuracy: 0.572623473149\n"
"source": [
"score = model.evaluate(x_test, y_test,\n",
" batch_size=batch_size, verbose=1)\n",
"print('Test score:', score[0])\n",
"print('Test accuracy:', score[1])"
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
"outputs": [],
"source": []
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"nbformat": 4,
"nbformat_minor": 1
Copy link

bumsun commented Apr 20, 2017

text = "From: (where's my thing) Subject: WHAT car is this!? Nntp-Posting-Host: Organization: University of Maryland, College Park Lines: 15 I was wondering if anyone out there could enlighten me on this car I saw the other day. It was a 2-door sports car, looked to be from the late 60s/ early 70s. It was called a Bricklin. The doors were really small. In addition, the front bumper was separate from the rest of the body. This is  all I know. If anyone can tellme a model name, engine specs, years of production, where this car is made, history, or whatever info you have on this funky looking car, please e-mail.  Thanks, - IL ---- brought to you by your neighborhood Lerxst ----"
prediction = model.predict(np.array(tokenizer.texts_to_matrix(text, mode='binary')))

Я новичок в машинном обучении. Подскажите, пожалуйста, как сделать прогноз для данного сообщения? У меня выдаётся что-то странное, так как я, скорее всего, неправильно обрабатываю текст)

Copy link

feeeper commented May 15, 2017

@bumsun, похоже, что проблема в том, что tokenizer.texts_to_matrix принимает массив объектов. Если передаётся строка, то метод считает, что это массив символов и творит ерунду. У меня такой код работает, похоже, правильно:

prediction = model.predict(np.array(tokenizer.texts_to_matrix([text], mode='binary'))) # text заменил на [text]
print(prediction.shape) # (1, 20)
print(prediction) # массив из двадцати значений. i-ый элемент массива указывает вероятность того, что текст относится к i-ой категории.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment