Last active
December 19, 2021 19:19
-
-
Save teemukinnunen/d519813deccbd15e6e526cb1ecd5afc0 to your computer and use it in GitHub Desktop.
AutoML autokeras Jigsaw toxic comments challenge example notebook
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "eee00391", | |
"metadata": {}, | |
"source": [ | |
"## Train a neural network using AutoKeras" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "5bc589e8", | |
"metadata": {}, | |
"source": [ | |
"## Set paths and other variables" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5157a094", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_input_file = \"data/train.csv.zip\"\n", | |
"BATCH_SIZE = 8  # Kept small: training runs out of memory quite easily :/" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "132cd403", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%env TF_GPU_ALLOCATOR=cuda_malloc_async" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "38005729", | |
"metadata": {}, | |
"source": [ | |
"## Import libs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "1f57d721", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import tensorflow as tf\n", | |
"\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"from sklearn.metrics import roc_auc_score\n", | |
"\n", | |
"import autokeras as ak\n", | |
"import keras_tuner as kt" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b8d5f4e1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"tf.__version__" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a5334e0d", | |
"metadata": {}, | |
"source": [ | |
"## Load ground truth dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4caf0df6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_df = pd.read_csv(train_input_file, compression=\"zip\")\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "d8a82974", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_df.columns\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c7208185", | |
"metadata": {}, | |
"source": [ | |
"### Split ground truth dataset into training, validation and test" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "3a85f6b4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_df, test_df = train_test_split(train_df, test_size=0.1)\n", | |
"train_df, val_df = train_test_split(train_df, test_size=0.1)\n", | |
"\n", | |
"train_df.shape, val_df.shape, test_df.shape\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "64f50a37", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_df[\n", | |
" [\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"]\n", | |
"].values\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bc8ca14b", | |
"metadata": {}, | |
"source": [ | |
"### Convert pandas dataframes into tensorflow datasets" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "fa394055", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_set = tf.data.Dataset.from_tensor_slices(\n", | |
" (\n", | |
" (train_df.comment_text.values,),\n", | |
" (\n", | |
" train_df[\n", | |
" [\n", | |
" \"toxic\",\n", | |
" \"severe_toxic\",\n", | |
" \"obscene\",\n", | |
" \"threat\",\n", | |
" \"insult\",\n", | |
" \"identity_hate\",\n", | |
" ]\n", | |
" ].values\n", | |
" ),\n", | |
" )\n", | |
").batch(BATCH_SIZE)\n", | |
"val_set = tf.data.Dataset.from_tensor_slices(\n", | |
" (\n", | |
" (val_df.comment_text.values,),\n", | |
" (\n", | |
" val_df[\n", | |
" [\n", | |
" \"toxic\",\n", | |
" \"severe_toxic\",\n", | |
" \"obscene\",\n", | |
" \"threat\",\n", | |
" \"insult\",\n", | |
" \"identity_hate\",\n", | |
" ]\n", | |
" ].values\n", | |
" ),\n", | |
" )\n", | |
").batch(BATCH_SIZE)\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "35619283", | |
"metadata": {}, | |
"source": [ | |
"## Train AutoKeras AutoML model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ac699a01", | |
"metadata": {}, | |
"source": [ | |
"### Init AutoKeras text classifier model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "87c3e9a4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"clf = ak.TextClassifier(\n", | |
" overwrite=False, # True,\n", | |
" multi_label=True,\n", | |
" max_trials=10,\n", | |
" metrics=[tf.keras.metrics.AUC()],\n", | |
")\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4ac14825", | |
"metadata": {}, | |
"source": [ | |
"### Define an early-stopping callback that halts training once the validation loss stops improving" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "e52bf2bd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"earlystop = tf.keras.callbacks.EarlyStopping(\n", | |
" monitor=\"val_loss\",\n", | |
" min_delta=0,\n", | |
" patience=0,\n", | |
" verbose=0,\n", | |
" mode=\"auto\",\n", | |
" restore_best_weights=True,\n", | |
")\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0e9f5128", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%env TF_GPU_ALLOCATOR=cuda_malloc_async" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6d72f301", | |
"metadata": {}, | |
"source": [ | |
"### Start training a text classifier using AutoKeras AutoML" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "51615dfe", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"clf.fit(\n", | |
" train_set,\n", | |
" validation_data=val_set,\n", | |
" epochs=10,\n", | |
" batch_size=BATCH_SIZE,\n", | |
" callbacks=[earlystop],\n", | |
" verbose=1,\n", | |
")\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5624516a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Display the best model architecture\n", | |
"clf.export_model().summary()\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "5dfa881a", | |
"metadata": {}, | |
"source": [ | |
"## Model evaluation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "2c3ca28f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = clf.export_model()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "11b0fb57", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"y_test = test_df[\n", | |
" [\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"]\n", | |
"].values\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "fa046090", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_set = tf.data.Dataset.from_tensor_slices(\n", | |
" (\n", | |
" (test_df.comment_text.values,),\n", | |
" (\n", | |
" test_df[\n", | |
" [\n", | |
" \"toxic\",\n", | |
" \"severe_toxic\",\n", | |
" \"obscene\",\n", | |
" \"threat\",\n", | |
" \"insult\",\n", | |
" \"identity_hate\",\n", | |
" ]\n", | |
" ].values,\n", | |
" ),\n", | |
" )\n", | |
").batch(BATCH_SIZE)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4cc17a56", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"predicted_y = model.predict(test_df.comment_text.values)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b2caabf6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"roc_auc_score(\n", | |
" test_df[\n", | |
" [\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"]\n", | |
" ].values,\n", | |
" predicted_y,\n", | |
")\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "759664d6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model.evaluate(test_set)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b39fc06b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model.evaluate(val_set)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "04533504", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model.summary()\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b2a946c2", | |
"metadata": {}, | |
"source": [ | |
"## Predict unseen labels (for the Kaggle competition)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bb956bcb", | |
"metadata": {}, | |
"source": [ | |
"### Load the actual test data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0ed886b0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"real_test_df = pd.read_csv(\"data/test.csv.zip\", compression=\"zip\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "219d2cc0", | |
"metadata": {}, | |
"source": [ | |
"### Predict unseen samples" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "bede9fe7", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"real_test_pred = model.predict(real_test_df.comment_text)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9a1b6aba", | |
"metadata": {}, | |
"source": [ | |
"### Combine predictions with sample ids so the result can be stored as a CSV file" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "20992eff", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"predictions_df = pd.DataFrame(\n", | |
" real_test_pred,\n", | |
" columns=[\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"],\n", | |
")\n", | |
"predictions_df[\"id\"] = real_test_df[\"id\"]\n", | |
"predictions_df = predictions_df[\n", | |
" [\"id\", \"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"]\n", | |
"]\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4e4d4f93", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Predictions output looks like:\n", | |
"predictions_df.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "bfb062a3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Store prediction to be submitted to Kaggle\n", | |
"predictions_df.to_csv(\"data/autokeras_predictions.csv\", index=False)" | |
] | |
} | |
], | |
"metadata": { | |
"environment": { | |
"name": "tf2-gpu.2-5.m75", | |
"type": "gcloud", | |
"uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-5:m75" | |
}, | |
"kernelspec": { | |
"display_name": "autokeras", | |
"language": "python", | |
"name": "autokeras" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment