-
-
Save alexminnaar/746188692902fac3c36ed249760ee22e to your computer and use it in GitHub Desktop.
character_rnn_for_ner.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "character_rnn_for_ner.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"machine_shape": "hm", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/alexminnaar/746188692902fac3c36ed249760ee22e/character_rnn_for_ner.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9XNJRcmoIJQU", | |
"colab_type": "code", | |
"outputId": "067cdcb1-7b92-4e8b-b28d-2f3020710fa7", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 51 | |
} | |
}, | |
"source": [ | |
"\n", | |
"try:\n", | |
" # %tensorflow_version only exists in Colab.\n", | |
" %tensorflow_version 2.x\n", | |
"except Exception:\n", | |
" pass\n", | |
"import tensorflow as tf\n", | |
"\n", | |
"import numpy as np\n", | |
"import os\n", | |
"import time\n", | |
"\n", | |
"print(tf.__version__)" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"TensorFlow 2.x selected.\n", | |
"2.0.0\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "chLTt1VrIa2s", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"labels = set()\n", | |
"\n", | |
"def file2Examples(file_name):\n", | |
" '''\n", | |
" Read data files and return input/output pairs\n", | |
" '''\n", | |
" \n", | |
" examples=[]\n", | |
"\n", | |
" with open(file_name,\"r\") as f:\n", | |
"\n", | |
" next(f)\n", | |
" next(f)\n", | |
"\n", | |
" example = [[],[]]\n", | |
"\n", | |
" for line in f:\n", | |
"\n", | |
" input_output_split= line.split()\n", | |
"\n", | |
" if len(input_output_split)==4:\n", | |
" example[0].append(input_output_split[0])\n", | |
" example[1].append(input_output_split[-1])\n", | |
" labels.add(input_output_split[-1])\n", | |
"\n", | |
" elif len(input_output_split)==0:\n", | |
" examples.append(example)\n", | |
" example=[[],[]]\n", | |
" else:\n", | |
" example=[[],[]]\n", | |
"\n", | |
" f.close()\n", | |
" \n", | |
" return examples\n", | |
" \n", | |
"# Extract examples from train, validation, and test files which can be found at \n", | |
"# https://github.com/davidsbatista/NER-datasets/tree/master/CONLL2003\n", | |
"train_examples = file2Examples(\"train.txt\")\n", | |
"test_examples = file2Examples(\"test.txt\")\n", | |
"valid_examples = file2Examples(\"valid.txt\")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Of8tTTetJFAO", | |
"colab_type": "code", | |
"outputId": "b48a9a2f-b529-4873-ee10-f566eb5eb94d", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 71 | |
} | |
}, | |
"source": [ | |
" # create character vocab\n", | |
" all_text = \" \".join([\" \".join(x[0]) for x in train_examples+valid_examples+test_examples])\n", | |
" vocab = sorted(set(all_text))\n", | |
" \n", | |
" # create character/id and label/id mapping\n", | |
" char2idx = {u:i+1 for i, u in enumerate(vocab)}\n", | |
" idx2char = np.array(vocab)\n", | |
" label2idx = {u:i+1 for i, u in enumerate(labels)}\n", | |
" idx2label = np.array(labels)\n", | |
" \n", | |
" print(idx2label)\n", | |
" print(char2idx)" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"{'B-LOC', 'I-PER', 'I-MISC', 'I-ORG', 'O', 'B-MISC', 'B-PER', 'I-LOC', 'B-ORG'}\n", | |
"{' ': 1, '!': 2, '\"': 3, '#': 4, '$': 5, '%': 6, '&': 7, \"'\": 8, '(': 9, ')': 10, '*': 11, '+': 12, ',': 13, '-': 14, '.': 15, '/': 16, '0': 17, '1': 18, '2': 19, '3': 20, '4': 21, '5': 22, '6': 23, '7': 24, '8': 25, '9': 26, ':': 27, ';': 28, '=': 29, '?': 30, '@': 31, 'A': 32, 'B': 33, 'C': 34, 'D': 35, 'E': 36, 'F': 37, 'G': 38, 'H': 39, 'I': 40, 'J': 41, 'K': 42, 'L': 43, 'M': 44, 'N': 45, 'O': 46, 'P': 47, 'Q': 48, 'R': 49, 'S': 50, 'T': 51, 'U': 52, 'V': 53, 'W': 54, 'X': 55, 'Y': 56, 'Z': 57, '[': 58, ']': 59, '`': 60, 'a': 61, 'b': 62, 'c': 63, 'd': 64, 'e': 65, 'f': 66, 'g': 67, 'h': 68, 'i': 69, 'j': 70, 'k': 71, 'l': 72, 'm': 73, 'n': 74, 'o': 75, 'p': 76, 'q': 77, 'r': 78, 's': 79, 't': 80, 'u': 81, 'v': 82, 'w': 83, 'x': 84, 'y': 85, 'z': 86}\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "RHtLQrq4JhJe", | |
"colab_type": "code", | |
"outputId": "dfd18e10-a57f-4b4f-e2ff-db39f89eec80", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 68 | |
} | |
}, | |
"source": [ | |
" def split_char_labels(eg):\n", | |
" '''\n", | |
" For a given input/output example, break tokens into characters while keeping \n", | |
" the same label.\n", | |
" '''\n", | |
"\n", | |
" tokens = eg[0]\n", | |
" labels=eg[1]\n", | |
"\n", | |
" input_chars = []\n", | |
" output_char_labels = []\n", | |
"\n", | |
" for token,label in zip(tokens,labels):\n", | |
"\n", | |
" input_chars.extend([char for char in token])\n", | |
" input_chars.extend(' ')\n", | |
" output_char_labels.extend([label]*len(token))\n", | |
" output_char_labels.extend('O')\n", | |
"\n", | |
" return [[char2idx[x] for x in input_chars[:-1]],np.array([label2idx[x] for x in output_char_labels[:-1]])]\n", | |
" \n", | |
" train_formatted = [split_char_labels(eg) for eg in train_examples]\n", | |
" test_formatted = [split_char_labels(eg) for eg in test_examples]\n", | |
" valid_formatted = [split_char_labels(eg) for eg in valid_examples]\n", | |
" \n", | |
" print(len(train_formatted))\n", | |
" print(len(test_formatted))\n", | |
" print(len(valid_formatted))" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"14985\n", | |
"3682\n", | |
"3464\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LDRSAIObKBL8", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 289 | |
}, | |
"outputId": "eae715eb-e60b-4e56-9998-860e2f2c3ea2" | |
}, | |
"source": [ | |
" # training generator\n", | |
" def gen_train_series():\n", | |
"\n", | |
" for eg in train_formatted:\n", | |
" yield eg[0],eg[1]\n", | |
" \n", | |
" # validation generator\n", | |
" def gen_valid_series():\n", | |
" \n", | |
" for eg in valid_formatted:\n", | |
" yield eg[0],eg[1]\n", | |
" \n", | |
" # test generator\n", | |
" def gen_test_series():\n", | |
"\n", | |
" for eg in test_formatted:\n", | |
" yield eg[0],eg[1]\n", | |
" \n", | |
" # create Dataset objects for train, test and validation sets \n", | |
" series = tf.data.Dataset.from_generator(gen_train_series,output_types=(tf.int32, tf.int32),output_shapes = ((None, None)))\n", | |
" series_valid = tf.data.Dataset.from_generator(gen_valid_series,output_types=(tf.int32, tf.int32),output_shapes = ((None, None)))\n", | |
" series_test = tf.data.Dataset.from_generator(gen_test_series,output_types=(tf.int32, tf.int32),output_shapes = ((None, None)))\n", | |
"\n", | |
" BATCH_SIZE = 128\n", | |
" BUFFER_SIZE=1000\n", | |
" \n", | |
" # create padded batch series objects for train, test and validation sets\n", | |
" ds_series_batch = series.shuffle(BUFFER_SIZE).padded_batch(BATCH_SIZE, padded_shapes=([None], [None]), drop_remainder=True)\n", | |
" ds_series_batch_valid = series_valid.padded_batch(BATCH_SIZE, padded_shapes=([None], [None]), drop_remainder=True)\n", | |
" ds_series_batch_test = series_test.padded_batch(BATCH_SIZE, padded_shapes=([None], [None]), drop_remainder=True)\n", | |
" \n", | |
" # print example batches\n", | |
" for input_example_batch, target_example_batch in ds_series_batch_valid.take(1):\n", | |
" print(input_example_batch)\n", | |
" print(target_example_batch)" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"tf.Tensor(\n", | |
"[[34 49 40 ... 0 0 0]\n", | |
" [43 46 45 ... 0 0 0]\n", | |
" [54 65 79 ... 0 0 0]\n", | |
" ...\n", | |
" [ 3 1 36 ... 0 0 0]\n", | |
" [40 66 1 ... 0 0 0]\n", | |
" [35 81 78 ... 0 0 0]], shape=(128, 228), dtype=int32)\n", | |
"tf.Tensor(\n", | |
"[[5 5 5 ... 0 0 0]\n", | |
" [1 1 1 ... 0 0 0]\n", | |
" [6 6 6 ... 0 0 0]\n", | |
" ...\n", | |
" [5 5 5 ... 0 0 0]\n", | |
" [5 5 5 ... 0 0 0]\n", | |
" [7 7 7 ... 0 0 0]], shape=(128, 228), dtype=int32)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4msLXztSJtqo", | |
"colab_type": "code", | |
"outputId": "7711a8fa-6633-42f4-aa91-e2fa972f940f", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 255 | |
} | |
}, | |
"source": [ | |
" vocab_size = len(vocab)+1\n", | |
"\n", | |
" # The embedding dimension\n", | |
" embedding_dim = 256\n", | |
"\n", | |
" # Number of RNN units\n", | |
" rnn_units = 1024\n", | |
"\n", | |
" label_size = len(labels) \n", | |
" \n", | |
" # build LSTM model\n", | |
" def build_model(vocab_size,label_size, embedding_dim, rnn_units, batch_size):\n", | |
" model = tf.keras.Sequential([\n", | |
" tf.keras.layers.Embedding(vocab_size, embedding_dim,\n", | |
" batch_input_shape=[batch_size, None],mask_zero=True),\n", | |
" tf.keras.layers.LSTM(rnn_units,\n", | |
" return_sequences=True,\n", | |
" stateful=True,\n", | |
" recurrent_initializer='glorot_uniform'),\n", | |
" tf.keras.layers.Dense(label_size)\n", | |
" ])\n", | |
" return model\n", | |
"\n", | |
" model = build_model(\n", | |
" vocab_size = len(vocab)+1,\n", | |
" label_size=len(labels)+1,\n", | |
" embedding_dim=embedding_dim,\n", | |
" rnn_units=rnn_units,\n", | |
" batch_size=BATCH_SIZE)\n", | |
"\n", | |
" model.summary()" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Model: \"sequential\"\n", | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"embedding (Embedding) (128, None, 256) 22272 \n", | |
"_________________________________________________________________\n", | |
"lstm (LSTM) (128, None, 1024) 5246976 \n", | |
"_________________________________________________________________\n", | |
"dense (Dense) (128, None, 10) 10250 \n", | |
"=================================================================\n", | |
"Total params: 5,279,498\n", | |
"Trainable params: 5,279,498\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "S1vnxVFcK1Hk", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
" import os\n", | |
"\n", | |
" # define loss function\n", | |
" def loss(labels, logits):\n", | |
" return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)\n", | |
"\n", | |
" model.compile(optimizer='adam', loss=loss,metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n", | |
"\n", | |
" # Directory where the checkpoints will be saved\n", | |
" checkpoint_dir = './training_checkpoints'\n", | |
" # Name of the checkpoint files\n", | |
" checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")\n", | |
"\n", | |
" checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(\n", | |
" filepath=checkpoint_prefix,\n", | |
" save_weights_only=True)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2CQ2I9UDK9ng", | |
"colab_type": "code", | |
"outputId": "4d319e29-9e83-49ec-d66f-78efcd8a7812", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 717 | |
} | |
}, | |
"source": [ | |
" EPOCHS=20\n", | |
" \n", | |
" history = model.fit(ds_series_batch, epochs=EPOCHS, validation_data=ds_series_batch_valid,callbacks=[checkpoint_callback])" | |
], | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1/20\n", | |
"117/117 [==============================] - 67s 575ms/step - loss: 0.2180 - sparse_categorical_accuracy: 0.7980 - val_loss: 0.0000e+00 - val_sparse_categorical_accuracy: 0.0000e+00\n", | |
"Epoch 2/20\n", | |
"117/117 [==============================] - 57s 489ms/step - loss: 0.1282 - sparse_categorical_accuracy: 0.8415 - val_loss: 0.1121 - val_sparse_categorical_accuracy: 0.8583\n", | |
"Epoch 3/20\n", | |
"117/117 [==============================] - 57s 491ms/step - loss: 0.1007 - sparse_categorical_accuracy: 0.8672 - val_loss: 0.0985 - val_sparse_categorical_accuracy: 0.8778\n", | |
"Epoch 4/20\n", | |
"117/117 [==============================] - 57s 488ms/step - loss: 0.0894 - sparse_categorical_accuracy: 0.8822 - val_loss: 0.0919 - val_sparse_categorical_accuracy: 0.8868\n", | |
"Epoch 5/20\n", | |
"117/117 [==============================] - 57s 486ms/step - loss: 0.0841 - sparse_categorical_accuracy: 0.8904 - val_loss: 0.0857 - val_sparse_categorical_accuracy: 0.8921\n", | |
"Epoch 6/20\n", | |
"117/117 [==============================] - 57s 485ms/step - loss: 0.0781 - sparse_categorical_accuracy: 0.8967 - val_loss: 0.0833 - val_sparse_categorical_accuracy: 0.8966\n", | |
"Epoch 7/20\n", | |
"117/117 [==============================] - 57s 486ms/step - loss: 0.0743 - sparse_categorical_accuracy: 0.9022 - val_loss: 0.0807 - val_sparse_categorical_accuracy: 0.9003\n", | |
"Epoch 8/20\n", | |
"117/117 [==============================] - 57s 486ms/step - loss: 0.0718 - sparse_categorical_accuracy: 0.9061 - val_loss: 0.0766 - val_sparse_categorical_accuracy: 0.9047\n", | |
"Epoch 9/20\n", | |
"117/117 [==============================] - 57s 484ms/step - loss: 0.0683 - sparse_categorical_accuracy: 0.9103 - val_loss: 0.0753 - val_sparse_categorical_accuracy: 0.9079\n", | |
"Epoch 10/20\n", | |
"117/117 [==============================] - 57s 487ms/step - loss: 0.0653 - sparse_categorical_accuracy: 0.9142 - val_loss: 0.0729 - val_sparse_categorical_accuracy: 0.9105\n", | |
"Epoch 11/20\n", | |
"117/117 [==============================] - 57s 483ms/step - loss: 0.0627 - sparse_categorical_accuracy: 0.9177 - val_loss: 0.0708 - val_sparse_categorical_accuracy: 0.9131\n", | |
"Epoch 12/20\n", | |
"117/117 [==============================] - 57s 484ms/step - loss: 0.0582 - sparse_categorical_accuracy: 0.9235 - val_loss: 0.0696 - val_sparse_categorical_accuracy: 0.9158\n", | |
"Epoch 13/20\n", | |
"117/117 [==============================] - 57s 485ms/step - loss: 0.0547 - sparse_categorical_accuracy: 0.9283 - val_loss: 0.0681 - val_sparse_categorical_accuracy: 0.9169\n", | |
"Epoch 14/20\n", | |
"117/117 [==============================] - 57s 485ms/step - loss: 0.0517 - sparse_categorical_accuracy: 0.9328 - val_loss: 0.0672 - val_sparse_categorical_accuracy: 0.9191\n", | |
"Epoch 15/20\n", | |
"117/117 [==============================] - 57s 485ms/step - loss: 0.0483 - sparse_categorical_accuracy: 0.9371 - val_loss: 0.0653 - val_sparse_categorical_accuracy: 0.9220\n", | |
"Epoch 16/20\n", | |
"117/117 [==============================] - 57s 486ms/step - loss: 0.0440 - sparse_categorical_accuracy: 0.9428 - val_loss: 0.0657 - val_sparse_categorical_accuracy: 0.9237\n", | |
"Epoch 17/20\n", | |
"117/117 [==============================] - 57s 486ms/step - loss: 0.0405 - sparse_categorical_accuracy: 0.9471 - val_loss: 0.0670 - val_sparse_categorical_accuracy: 0.9226\n", | |
"Epoch 18/20\n", | |
"117/117 [==============================] - 57s 486ms/step - loss: 0.0375 - sparse_categorical_accuracy: 0.9515 - val_loss: 0.0650 - val_sparse_categorical_accuracy: 0.9252\n", | |
"Epoch 19/20\n", | |
"117/117 [==============================] - 57s 487ms/step - loss: 0.0342 - sparse_categorical_accuracy: 0.9550 - val_loss: 0.0672 - val_sparse_categorical_accuracy: 0.9257\n", | |
"Epoch 20/20\n", | |
"117/117 [==============================] - 57s 486ms/step - loss: 0.0308 - sparse_categorical_accuracy: 0.9600 - val_loss: 0.0658 - val_sparse_categorical_accuracy: 0.9298\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "gvTLV_8SgTlm", | |
"colab_type": "code", | |
"outputId": "5fea9a95-3e53-4919-d083-bfeda8387452", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 853 | |
} | |
}, | |
"source": [ | |
"from sklearn.metrics import classification_report, confusion_matrix\n", | |
"\n", | |
"preds = np.array([])\n", | |
"y_trues= np.array([])\n", | |
"\n", | |
"# iterate through test set, make predictions based on trained model\n", | |
"for input_example_batch, target_example_batch in ds_series_batch_test:\n", | |
"\n", | |
" pred=model.predict(input_example_batch)\n", | |
" pred_max=tf.argmax(tf.nn.softmax(pred),2).numpy().flatten()\n", | |
" y_true=target_example_batch.numpy().flatten()\n", | |
"\n", | |
" preds=np.concatenate([preds,pred_max])\n", | |
" y_trues=np.concatenate([y_trues,y_true])\n", | |
"\n", | |
"# remove padding from evaluation\n", | |
"remove_padding = [(p,y) for p,y in zip(preds,y_trues) if y!=0]\n", | |
"\n", | |
"r_p = [x[0] for x in remove_padding]\n", | |
"r_t = [x[1] for x in remove_padding]\n", | |
"\n", | |
"# print confusion matrix and classification report\n", | |
"print(confusion_matrix(r_p,r_t))\n", | |
"print(classification_report(r_p,r_t))\n" | |
], | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:5 out of the last 5 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:6 out of the last 6 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:7 out of the last 7 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:8 out of the last 8 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:9 out of the last 9 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:10 out of the last 10 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:11 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:11 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:10 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:10 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:10 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:10 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:10 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:10 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:10 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:10 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:10 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:11 out of the last 12 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:11 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:11 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:11 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:11 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"WARNING:tensorflow:11 out of the last 11 calls to <function _make_execution_function.<locals>.distributed_function at 0x7f528ed0a730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", | |
"[[ 7445 7 2 74 684 1032 833 0 2188]\n", | |
" [ 19 6265 62 431 145 25 36 78 69]\n", | |
" [ 19 85 434 217 276 25 9 53 24]\n", | |
" [ 58 563 174 3170 647 33 99 211 74]\n", | |
" [ 892 140 208 247 186948 742 950 82 2026]\n", | |
" [ 562 14 23 30 412 2061 167 21 660]\n", | |
" [ 699 22 16 52 680 353 6297 17 1163]\n", | |
" [ 6 170 75 406 196 8 18 906 8]\n", | |
" [ 749 11 16 106 532 558 654 4 3950]]\n", | |
" precision recall f1-score support\n", | |
"\n", | |
" 1.0 0.71 0.61 0.66 12265\n", | |
" 2.0 0.86 0.88 0.87 7130\n", | |
" 3.0 0.43 0.38 0.40 1142\n", | |
" 4.0 0.67 0.63 0.65 5029\n", | |
" 5.0 0.98 0.97 0.98 192235\n", | |
" 6.0 0.43 0.52 0.47 3950\n", | |
" 7.0 0.69 0.68 0.69 9299\n", | |
" 8.0 0.66 0.51 0.57 1793\n", | |
" 9.0 0.39 0.60 0.47 6580\n", | |
"\n", | |
" accuracy 0.91 239423\n", | |
" macro avg 0.65 0.64 0.64 239423\n", | |
"weighted avg 0.92 0.91 0.91 239423\n", | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment