AutoML autokeras Jigsaw toxic comments challenge example notebook
{
"cells": [
{
"cell_type": "markdown",
"id": "eee00391",
"metadata": {},
"source": [
"## Train a neural network using AutoKeras"
]
},
{
"cell_type": "markdown",
"id": "5bc589e8",
"metadata": {},
"source": [
"## Set paths and other variables"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5157a094",
"metadata": {},
"outputs": [],
"source": [
"train_input_file = \"data/train.csv.zip\"\n",
"BATCH_SIZE = 8  # Kept small: training runs out of memory quite easily"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "132cd403",
"metadata": {},
"outputs": [],
"source": [
"%env TF_GPU_ALLOCATOR=cuda_malloc_async"
]
},
{
"cell_type": "markdown",
"id": "38005729",
"metadata": {},
"source": [
"## Import libs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f57d721",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import roc_auc_score\n",
"\n",
"import autokeras as ak\n",
"import keras_tuner as kt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8d5f4e1",
"metadata": {},
"outputs": [],
"source": [
"tf.__version__"
]
},
{
"cell_type": "markdown",
"id": "a5334e0d",
"metadata": {},
"source": [
"## Load ground truth dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4caf0df6",
"metadata": {},
"outputs": [],
"source": [
"train_df = pd.read_csv(train_input_file, compression=\"zip\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8a82974",
"metadata": {},
"outputs": [],
"source": [
"train_df.columns\n"
]
},
{
"cell_type": "markdown",
"id": "c7208185",
"metadata": {},
"source": [
"### Split ground truth dataset into training, validation, and test sets"
]
},
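{
"cell_type": "code",
"execution_count": null,
"id": "3a85f6b4",
"metadata": {},
"outputs": [],
"source": [
"# Hold out 10% for testing, then 10% of the remainder for validation\n",
"# (roughly 81% train / 9% validation / 10% test)\n",
"train_df, test_df = train_test_split(train_df, test_size=0.1)\n",
"train_df, val_df = train_test_split(train_df, test_size=0.1)\n",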
"\n",
"train_df.shape, val_df.shape, test_df.shape\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64f50a37",
"metadata": {},
"outputs": [],
"source": [
"train_df[\n",
"    [\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"]\n",
"].values\n"
]
},
{
"cell_type": "markdown",
"id": "bc8ca14b",
"metadata": {},
"source": [
"### Convert pandas dataframes into tensorflow datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa394055",
"metadata": {},
"outputs": [],
"source": [
"# The six toxicity labels to predict\n",
"label_cols = [\n",
"    \"toxic\",\n",
"    \"severe_toxic\",\n",
"    \"obscene\",\n",
"    \"threat\",\n",
"    \"insult\",\n",
"    \"identity_hate\",\n",
"]\n",
"\n",
"\n",
"def make_dataset(df):\n",
"    \"\"\"Build a batched dataset of ((comment_text,), labels) pairs.\"\"\"\n",
"    return tf.data.Dataset.from_tensor_slices(\n",
"        ((df.comment_text.values,), df[label_cols].values)\n",
"    ).batch(BATCH_SIZE)\n",
"\n",
"\n",
"train_set = make_dataset(train_df)\n",
"val_set = make_dataset(val_df)\n"
]
},
{
"cell_type": "markdown",
"id": "35619283",
"metadata": {},
"source": [
"## Train AutoKeras AutoML model"
]
},
{
"cell_type": "markdown",
"id": "ac699a01",
"metadata": {},
"source": [
"### Init AutoKeras text classifier model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87c3e9a4",
"metadata": {},
"outputs": [],
"source": [
"clf = ak.TextClassifier(\n",
"    overwrite=False,  # set to True to discard an existing search project and start over\n",
"    multi_label=True,\n",
"    max_trials=10,\n",
"    metrics=[tf.keras.metrics.AUC()],\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "4ac14825",
"metadata": {},
"source": [
"### Define an early-stopping callback that halts training once validation loss stops improving"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e52bf2bd",
"metadata": {},
"outputs": [],
"source": [
"earlystop = tf.keras.callbacks.EarlyStopping(\n",
"    monitor=\"val_loss\",\n",
"    min_delta=0,\n",
"    patience=0,  # stop at the first epoch without improvement\n",
"    verbose=0,\n",
"    mode=\"auto\",\n",
"    restore_best_weights=True,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e9f5128",
"metadata": {},
"outputs": [],
"source": [
"%env TF_GPU_ALLOCATOR=cuda_malloc_async"
]
},
{
"cell_type": "markdown",
"id": "6d72f301",
"metadata": {},
"source": [
"### Start training a text classifier using AutoKeras AutoML"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51615dfe",
"metadata": {},
"outputs": [],
"source": [
"clf.fit(\n",
"    train_set,\n",
"    validation_data=val_set,\n",
"    epochs=10,\n",
"    batch_size=BATCH_SIZE,\n",
"    callbacks=[earlystop],\n",
"    verbose=1,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5624516a",
"metadata": {},
"outputs": [],
"source": [
"# Display the best model architecture\n",
"clf.export_model().summary()\n"
]
},
{
"cell_type": "markdown",
"id": "5dfa881a",
"metadata": {},
"source": [
"## Model evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c3ca28f",
"metadata": {},
"outputs": [],
"source": [
"model = clf.export_model()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11b0fb57",
"metadata": {},
"outputs": [],
"source": [
"y_test = test_df[\n",
"    [\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"]\n",
"].values\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa046090",
"metadata": {},
"outputs": [],
"source": [
"# Same structure as train_set and val_set; y_test comes from the previous cell\n",
"test_set = tf.data.Dataset.from_tensor_slices(\n",
"    ((test_df.comment_text.values,), y_test)\n",
").batch(BATCH_SIZE)\n"
]
},
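{
"cell_type": "code",
"execution_count": null,
"id": "4cc17a56",
"metadata": {},
"outputs": [],
"source": [
"# Predict in the same small batches used for training to limit memory use\n",
"predicted_y = model.predict(test_df.comment_text.values, batch_size=BATCH_SIZE)\n"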
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2caabf6",
"metadata": {},
"outputs": [],
"source": [
"# ROC AUC macro-averaged over the six label columns\n",
"roc_auc_score(y_test, predicted_y)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "759664d6",
"metadata": {},
"outputs": [],
"source": [
"model.evaluate(test_set)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b39fc06b",
"metadata": {},
"outputs": [],
"source": [
"model.evaluate(val_set)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04533504",
"metadata": {},
"outputs": [],
"source": [
"model.summary()\n"
]
},
{
"cell_type": "markdown",
"id": "b2a946c2",
"metadata": {},
"source": [
"## Predict unseen labels (for the Kaggle competition)"
]
},
{
"cell_type": "markdown",
"id": "bb956bcb",
"metadata": {},
"source": [
"### Load the actual test data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ed886b0",
"metadata": {},
"outputs": [],
"source": [
"real_test_df = pd.read_csv(\"data/test.csv.zip\", compression=\"zip\")"
]
},
{
"cell_type": "markdown",
"id": "219d2cc0",
"metadata": {},
"source": [
"### Predict unseen samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bede9fe7",
"metadata": {},
"outputs": [],
"source": [
"real_test_pred = model.predict(real_test_df.comment_text.values, batch_size=BATCH_SIZE)"
]
},
{
"cell_type": "markdown",
"id": "9a1b6aba",
"metadata": {},
"source": [
"### Combine predictions with sample IDs and store the result as a CSV file"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20992eff",
"metadata": {},
"outputs": [],
"source": [
"predictions_df = pd.DataFrame(\n",
"    real_test_pred,\n",
"    columns=[\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"],\n",
")\n",
"predictions_df[\"id\"] = real_test_df[\"id\"]\n",
"predictions_df = predictions_df[\n",
"    [\"id\", \"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\"]\n",
"]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e4d4f93",
"metadata": {},
"outputs": [],
"source": [
"# Predictions output looks like:\n",
"predictions_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bfb062a3",
"metadata": {},
"outputs": [],
"source": [
"# Save the predictions for submission to Kaggle\n",
"predictions_df.to_csv(\"data/autokeras_predictions.csv\", index=False)"
]
}
],
"metadata": {
"environment": {
"name": "tf2-gpu.2-5.m75",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-5:m75"
},
"kernelspec": {
"display_name": "autokeras",
"language": "python",
"name": "autokeras"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
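
The "Model evaluation" cells pass a six-column label matrix and a six-column prediction matrix to `roc_auc_score`. With scikit-learn's default `average="macro"`, that call scores each label column separately and averages the results. The following is a minimal pure-Python sketch of that computation, not scikit-learn's implementation; the function names `binary_roc_auc` and `macro_roc_auc` and the toy data are hypothetical and exist only for illustration.

```python
from typing import List, Sequence


def binary_roc_auc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """ROC AUC for a single label, via the rank-sum (Mann-Whitney U) identity."""
    pairs = sorted(zip(y_score, y_true))
    n = len(pairs)
    ranks = [0.0] * n
    i = 0
    while i < n:
        j = i
        # Tied scores share the average of the ranks they span.
        while j + 1 < n and pairs[j + 1][0] == pairs[i][0]:
            j += 1
        avg_rank = (i + j) / 2 + 1  # ranks are 1-based
        for k in range(i, j + 1):
            ranks[k] = avg_rank
        i = j + 1
    n_pos = sum(1 for _, label in pairs if label == 1)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("ROC AUC is undefined when only one class is present")
    pos_rank_sum = sum(r for r, (_, label) in zip(ranks, pairs) if label == 1)
    return (pos_rank_sum - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def macro_roc_auc(
    y_true: Sequence[Sequence[int]], y_score: Sequence[Sequence[float]]
) -> float:
    """Unweighted mean of the per-column ROC AUC values."""
    n_labels = len(y_true[0])
    per_label: List[float] = [
        binary_roc_auc([row[c] for row in y_true], [row[c] for row in y_score])
        for c in range(n_labels)
    ]
    return sum(per_label) / n_labels


# Toy example with two labels and four samples (made-up numbers).
toy_true = [[1, 0], [0, 1], [1, 0], [0, 1]]
toy_score = [[0.9, 0.2], [0.1, 0.8], [0.8, 0.4], [0.3, 0.3]]
print(macro_roc_auc(toy_true, toy_score))  # 0.875 (column AUCs are 1.0 and 0.75)
```

Because every column is scored on its own, a label with no positive (or no negative) rows in the held-out split has no defined AUC. With rare labels such as `threat`, it is worth checking that the random test split contains both classes for each label.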