Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save teemukinnunen/acccf2e6d96559392b0dfdb3fefa329e to your computer and use it in GitHub Desktop.
Save teemukinnunen/acccf2e6d96559392b0dfdb3fefa329e to your computer and use it in GitHub Desktop.
AutoML AutoGluon Toxic Comments challenge notebook example
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "f610c619",
"metadata": {},
"source": [
"# Train a model using AutoGluon"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "broad-election",
"metadata": {},
"outputs": [],
"source": [
"from autogluon.tabular import TabularDataset\n",
"from autogluon.text import TextPredictor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import roc_auc_score, classification_report, confusion_matrix\n",
"\n",
"import pandas as pd\n",
"\n",
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"from autogluon.text import TextPredictor\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "structured-bishop",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"from autogluon.text import TextPredictor\n",
"\n",
"# Wraps one AutoGluon TextPredictor per label so that a multi-label problem is\n",
"# handled as a set of independent single-label problems.\n",
"class MultiLabelTextPredictor:\n",
"    \"\"\"Multi-label text classifier made of one ``TextPredictor`` per label.\n",
"\n",
"    Each label column gets its own predictor; all of them read the same text\n",
"    column as their only input.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        labels: list,\n",
"        problem_type: str = None,\n",
"        eval_metric: str = None,\n",
"        path: str = None,\n",
"        verbosity: int = 3,\n",
"        warn_if_exist: bool = True,\n",
"        text_column: str = \"comment_text\",\n",
"    ):\n",
"        \"\"\"Create one untrained ``TextPredictor`` for every entry in ``labels``.\n",
"\n",
"        :param labels: names of the label columns\n",
"        :param problem_type: forwarded unchanged to every ``TextPredictor``\n",
"        :param eval_metric: forwarded unchanged to every ``TextPredictor``\n",
"        :param path: root directory; the predictor of a label uses ``<path>/<label>``.\n",
"            ``None`` is forwarded as-is instead of being joined with the label.\n",
"        :param verbosity: forwarded unchanged to every ``TextPredictor``\n",
"        :param warn_if_exist: forwarded unchanged to every ``TextPredictor``\n",
"        :param text_column: name of the column that holds the input text\n",
"        \"\"\"\n",
"        self.labels = labels\n",
"        self.text_predictors = dict()\n",
"        self.path = path\n",
"        self.verbosity = verbosity\n",
"        self.warn_if_exist = warn_if_exist\n",
"        self.text_column = text_column\n",
"        # Not read anywhere in this class; kept so code that accesses the\n",
"        # attribute keeps working.\n",
"        self.samples_per_class = 500\n",
"\n",
"        for label in self.labels:\n",
"            # os.path.join(None, label) raises TypeError, so a sub-directory is\n",
"            # only built when a root path was actually given.\n",
"            label_path = os.path.join(path, label) if path is not None else None\n",
"            self.text_predictors[label] = TextPredictor(\n",
"                label=label,\n",
"                problem_type=problem_type,\n",
"                eval_metric=eval_metric,\n",
"                path=label_path,\n",
"                verbosity=verbosity,\n",
"                warn_if_exist=warn_if_exist,\n",
"            )\n",
"\n",
"    def fit(\n",
"        self,\n",
"        train_data: pd.DataFrame,\n",
"        tuning_data: pd.DataFrame = None,\n",
"        time_limit: int = None,\n",
"    ) -> None:\n",
"        \"\"\"Fit every per-label predictor on the text column plus its own label.\n",
"\n",
"        :param train_data: frame containing the text column and all label columns\n",
"        :param tuning_data: optional frame with the same columns; when given, the\n",
"            matching two-column slice is passed to each predictor as ``tuning_data``\n",
"        :param time_limit: passed unchanged to each predictor, so it applies to\n",
"            every label separately rather than to the whole run\n",
"        \"\"\"\n",
"        total = len(self.labels)\n",
"\n",
"        # start=1 so the progress line reads 1/N .. N/N instead of 0/N .. (N-1)/N.\n",
"        for i, label in enumerate(self.labels, start=1):\n",
"            print(f\"Training a text classifier for class: {label} ({i}/{total})\")\n",
"\n",
"            # Each predictor only sees the text and its own label; the other\n",
"            # label columns must not leak in as features.\n",
"            columns = [self.text_column, label]\n",
"\n",
"            self.text_predictors[label].fit(\n",
"                train_data=train_data[columns],\n",
"                tuning_data=None if tuning_data is None else tuning_data[columns],\n",
"                time_limit=time_limit,\n",
"            )\n",
"\n",
"    def predict(self, train_data: pd.DataFrame) -> np.ndarray:\n",
"        \"\"\"Predict every label for the rows of ``train_data``.\n",
"\n",
"        :param train_data: frame to predict on (despite the name it can be any\n",
"            frame); only the text column is read\n",
"        :return: float array of shape ``(n_rows, n_labels)`` with columns ordered\n",
"            like ``self.labels``\n",
"        \"\"\"\n",
"        y_pred = np.zeros((train_data.shape[0], len(self.labels)))\n",
"\n",
"        for i, label in enumerate(self.labels):\n",
"            y_pred[:, i] = self.text_predictors[label].predict(\n",
"                train_data[[self.text_column]]\n",
"            )\n",
"\n",
"        return y_pred\n",
"\n",
"    def predict_proba(self, data: pd.DataFrame) -> np.ndarray:\n",
"        \"\"\"Return one score per label for the rows of ``data``.\n",
"\n",
"        Useful for ranking metrics such as ROC AUC, which need scores rather\n",
"        than hard class predictions.\n",
"\n",
"        :param data: frame to score; only the text column is read\n",
"        :return: float array of shape ``(n_rows, n_labels)`` with columns ordered\n",
"            like ``self.labels``\n",
"        \"\"\"\n",
"        scores = np.zeros((data.shape[0], len(self.labels)))\n",
"\n",
"        for i, label in enumerate(self.labels):\n",
"            proba = np.asarray(\n",
"                self.text_predictors[label].predict_proba(data[[self.text_column]])\n",
"            )\n",
"            # NOTE(review): assumes a 2-D result lists the positive class in its\n",
"            # last column - confirm against the installed AutoGluon version.\n",
"            scores[:, i] = proba[:, -1] if proba.ndim == 2 else proba\n",
"\n",
"        return scores\n",
"\n",
"    def load(self, path: str = None) -> None:\n",
"        \"\"\"Replace the per-label predictors with ones loaded from disk.\n",
"\n",
"        :param path: root directory that holds one sub-directory per label;\n",
"            defaults to the ``path`` given to the constructor\n",
"        \"\"\"\n",
"        root = self.path if path is None else path\n",
"\n",
"        for label in self.labels:\n",
"            self.text_predictors[label] = TextPredictor.load(os.path.join(root, label))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "behavioral-abraham",
"metadata": {},
"outputs": [],
"source": [
"# Training comments; pandas reads the CSV straight out of the zip archive.\n",
"train_df = pd.read_csv(\n",
"    \"data/train.csv.zip\",\n",
"    compression=\"zip\",\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "internal-johnson",
"metadata": {},
"outputs": [],
"source": [
"# Quick look at the first rows of the raw frame.\n",
"train_df.head()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "french-halloween",
"metadata": {},
"outputs": [],
"source": [
"# Label columns to predict; one classifier is trained for each of them.\n",
"class_labels = [\n",
"    \"toxic\",\n",
"    \"severe_toxic\",\n",
"    \"obscene\",\n",
"    \"threat\",\n",
"    \"insult\",\n",
"    \"identity_hate\",\n",
"]\n",
"\n",
"# Root directory handed to the predictor as its storage path.\n",
"data_dir = \"toxic-multilabel\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "statutory-saturday",
"metadata": {},
"outputs": [],
"source": [
"# The id column is neither the text input nor a label, so it is dropped.\n",
"# errors=\"ignore\" keeps the cell re-runnable: without it a second run raises\n",
"# KeyError because the column is already gone.\n",
"train_df = train_df.drop(columns=[\"id\"], errors=\"ignore\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "italic-boundary",
"metadata": {},
"outputs": [],
"source": [
"# Fixed seed so the same rows land in each split on every run; without it the\n",
"# reported metrics are not reproducible.\n",
"SPLIT_SEED = 42\n",
"\n",
"# Hold out 20% for testing, then 10% of the remainder for validation.\n",
"train_df, test_df = train_test_split(train_df, test_size=0.2, random_state=SPLIT_SEED)\n",
"train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=SPLIT_SEED)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "turned-consensus",
"metadata": {},
"outputs": [],
"source": [
"# Wrap each split in AutoGluon's TabularDataset, keeping the same variable names.\n",
"train_df, val_df, test_df = (\n",
"    TabularDataset(split) for split in (train_df, val_df, test_df)\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "positive-bicycle",
"metadata": {},
"outputs": [],
"source": [
"# Remove previous runs. The directory comes from data_dir (interpolated by\n",
"# IPython) so this cell always cleans the location the predictor writes to.\n",
"!rm -rf \"{data_dir}\""
]
},
{
"cell_type": "markdown",
"id": "28ac2d7b",
"metadata": {},
"source": [
"## Train a MultiLabelTextPredictor"
]
},
{
"cell_type": "markdown",
"id": "0c9fd445",
"metadata": {},
"source": [
"### Init the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cloudy-cookie",
"metadata": {},
"outputs": [],
"source": [
"# One classifier per entry of class_labels, evaluated with ROC AUC and stored\n",
"# below data_dir.\n",
"predictor = MultiLabelTextPredictor(\n",
"    path=data_dir,\n",
"    labels=class_labels,\n",
"    eval_metric=\"roc_auc\",\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "909e1412",
"metadata": {},
"source": [
"### Train the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "intended-massage",
"metadata": {},
"outputs": [],
"source": [
"# Train one classifier per label; val_df is handed over as tuning data.\n",
"predictor.fit(train_data=train_df, tuning_data=val_df)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "focal-state",
"metadata": {},
"outputs": [],
"source": [
"# Replace the in-memory classifiers with the ones stored below data_dir.\n",
"# data_dir is used instead of a repeated literal so this path cannot drift\n",
"# from the one the predictor was created with.\n",
"predictor.load(path=data_dir)\n"
]
},
{
"cell_type": "markdown",
"id": "a8fd39f2",
"metadata": {},
"source": [
"## Evaluate the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "detected-plymouth",
"metadata": {},
"outputs": [],
"source": [
"# Hard class predictions for the held-out test split, one column per label.\n",
"y_test_pred = predictor.predict(test_df)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "antique-intake",
"metadata": {},
"outputs": [],
"source": [
"# ROC AUC ranks samples by score, so it needs class probabilities; feeding it\n",
"# the hard 0/1 predictions collapses every ROC curve to a single operating point.\n",
"def positive_class_scores(label):\n",
"    \"\"\"Score the test split with the classifier trained for ``label``.\"\"\"\n",
"    proba = np.asarray(\n",
"        predictor.text_predictors[label].predict_proba(test_df[[predictor.text_column]])\n",
"    )\n",
"    # NOTE(review): assumes a 2-D result lists the positive class in its last\n",
"    # column - confirm against the installed AutoGluon version.\n",
"    return proba[:, -1] if proba.ndim == 2 else proba\n",
"\n",
"\n",
"y_test_score = np.column_stack([positive_class_scores(label) for label in class_labels])\n",
"\n",
"print(roc_auc_score(test_df[class_labels], y_test_score))\n",
"# Precision / recall / F1 are defined on hard predictions, so these keep y_test_pred.\n",
"print(classification_report(test_df[class_labels], y_test_pred))\n"
]
},
{
"cell_type": "markdown",
"id": "modular-matthew",
"metadata": {},
"source": [
"## Predict real test samples\n",
"(samples whose true labels we don't know)"
]
},
{
"cell_type": "markdown",
"id": "df7e30df",
"metadata": {},
"source": [
"### Load data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "arranged-forward",
"metadata": {},
"outputs": [],
"source": [
"# Unlabelled competition test set, read straight out of the zip archive.\n",
"real_test_df = pd.read_csv(\n",
"    \"data/test.csv.zip\",\n",
"    compression=\"zip\",\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "hungarian-movie",
"metadata": {},
"outputs": [],
"source": [
"# One prediction column per label, ordered like class_labels.\n",
"predicted_toxic = predictor.predict(real_test_df)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "incorporate-duncan",
"metadata": {},
"outputs": [],
"source": [
"# Name the prediction columns after the labels and attach the sample ids.\n",
"predicted_toxic_df = pd.DataFrame(predicted_toxic, columns=class_labels).assign(\n",
"    id=real_test_df[\"id\"]\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "suited-landscape",
"metadata": {},
"outputs": [],
"source": [
"# Preview with the id column first, followed by the labels.\n",
"predicted_toxic_df[[\"id\", *class_labels]].head()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sunset-bandwidth",
"metadata": {},
"outputs": [],
"source": [
"# Write the submission file: id first, then one column per label, no index.\n",
"predicted_toxic_df[[\"id\", *class_labels]].to_csv(\n",
"    \"toxic-challenge-autogluon.csv\",\n",
"    index=False,\n",
")\n"
]
}
],
"metadata": {
"environment": {
"name": "rapids-gpu.0-18.m65",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/rapids-gpu.0-18:m65"
},
"kernelspec": {
"display_name": "Python [conda env:root] *",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment