Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ElisonSherton/cd21b00e19806054485a656549bf9265 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "20eb6a2a-df6b-496c-b5ca-61922e1b1717",
"metadata": {},
"outputs": [],
"source": [
"from fastai.tabular.all import *\n",
"from sklearn.datasets import make_multilabel_classification"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6501e60f-db0e-4d61-bc88-7f7b9e03e7ff",
"metadata": {},
"outputs": [],
"source": [
"# Generate a synthetic multi-label dataset: 100 samples x 4 numeric features -> 5 binary targets\n",
"# (n_classes is spelled out; 5 is also sklearn's default). A fixed random_state makes re-runs identical.\n",
"X, y = make_multilabel_classification(n_samples = 100, n_features = 4, n_classes = 5, random_state = 42)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "69f79e1a-e840-4d94-8fa0-5af1ca225710",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>feat_1</th>\n",
" <th>feat_2</th>\n",
" <th>feat_3</th>\n",
" <th>feat_4</th>\n",
" <th>feat_5</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>8.0</td>\n",
" <td>10.0</td>\n",
" <td>9.0</td>\n",
" <td>6.0</td>\n",
" <td>a</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>9.0</td>\n",
" <td>15.0</td>\n",
" <td>11.0</td>\n",
" <td>8.0</td>\n",
" <td>b</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" feat_1 feat_2 feat_3 feat_4 feat_5\n",
"0 8.0 10.0 9.0 6.0 a\n",
"1 9.0 15.0 11.0 8.0 b"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train = pd.DataFrame(X, columns = [f\"feat_{idx+1}\" for idx in range(X.shape[1])])\n",
"num_cols = train.columns.tolist()\n",
"# Add a synthetic categorical column split ~50% / 25% / 25% across a / b / c. Sizes come from\n",
"# the frame, so this stays valid if n_samples changes; a seeded private RNG keeps it reproducible.\n",
"n_rows = len(train)\n",
"n_a, n_b = n_rows // 2, n_rows // 4\n",
"categories = [\"a\"] * n_a + [\"b\"] * n_b + [\"c\"] * (n_rows - n_a - n_b)\n",
"random.Random(42).shuffle(categories)\n",
"train[\"feat_5\"] = categories\n",
"cat_cols = [\"feat_5\"]\n",
"train.head(2)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dc9d29c5-15e1-46b1-882a-15e22b563103",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>target_1</th>\n",
" <th>target_2</th>\n",
" <th>target_3</th>\n",
" <th>target_4</th>\n",
" <th>target_5</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" target_1 target_2 target_3 target_4 target_5\n",
"0 0 0 0 0 0\n",
"1 0 1 0 0 0"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One column per binary target label, named target_1 ... target_k\n",
"target = pd.DataFrame(y, columns = [f\"target_{i}\" for i in range(1, y.shape[1] + 1)])\n",
"y_names = list(target.columns)\n",
"target.head(2)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8eb9e37e-15bb-4052-ac06-846215d5fe9e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>feat_1</th>\n",
" <th>feat_2</th>\n",
" <th>feat_3</th>\n",
" <th>feat_4</th>\n",
" <th>feat_5</th>\n",
" <th>target_1</th>\n",
" <th>target_2</th>\n",
" <th>target_3</th>\n",
" <th>target_4</th>\n",
" <th>target_5</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>8.0</td>\n",
" <td>10.0</td>\n",
" <td>9.0</td>\n",
" <td>6.0</td>\n",
" <td>a</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>9.0</td>\n",
" <td>15.0</td>\n",
" <td>11.0</td>\n",
" <td>8.0</td>\n",
" <td>b</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" feat_1 feat_2 feat_3 feat_4 feat_5 target_1 target_2 target_3 \\\n",
"0 8.0 10.0 9.0 6.0 a 0 0 0 \n",
"1 9.0 15.0 11.0 8.0 b 0 1 0 \n",
"\n",
" target_4 target_5 \n",
"0 0 0 \n",
"1 0 0 "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Attach the target columns. Dropping any target columns already present first makes this\n",
"# cell idempotent: re-running it no longer appends a duplicate set of target_* columns.\n",
"train = pd.concat([train.drop(columns = y_names, errors = \"ignore\"), target], axis = 1)\n",
"train.head(2)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "23d06c08-c66f-42f6-acbd-17a466b0d091",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Hold out a random 20% of rows for validation; the seed makes the split reproducible across runs\n",
"splits = RandomSplitter(valid_pct=0.2, seed=42)(range_of(train))\n",
"\n",
"# Categorify encodes the categorical column(s); Normalize standardizes the continuous ones\n",
"to = TabularPandas(train, procs=[Categorify,Normalize],\n",
" cat_names = cat_cols,\n",
" cont_names = num_cols,\n",
" y_names = y_names,\n",
" splits=splits)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "16ffdc44-9983-4455-9077-6bf429d0e172",
"metadata": {},
"outputs": [],
"source": [
"# Let fastai choose the device (GPU when available, otherwise CPU) instead of hard-coding \"cuda\",\n",
"# which raises on CPU-only machines\n",
"dls = to.dataloaders(bs=16, device=default_device())"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c31466d2-ae19-4811-ae99-daf24ae9f73e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([16, 1]), torch.Size([16, 4]), torch.Size([16, 5]))"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect one training batch: (categorical inputs, continuous inputs, targets)\n",
"xb_cat, xb_cont, yb = dls.one_batch()\n",
"tuple(t.shape for t in (xb_cat, xb_cont, yb))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "de8521e5-18bd-426b-ad97-a61ea8f0c257",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"SuggestedLRs(valley=0.0010000000474974513)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Multi-label setup: each target column is an independent binary label, so pair a\n",
"# logits-based BCE loss with fastai's thresholded multi-label accuracy\n",
"learn = tabular_learner(dls, loss_func = BCEWithLogitsLossFlat(), metrics = [accuracy_multi])\n",
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "8f4db69b-0cb6-4742-989f-33b33c18e940",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy_multi</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.705741</td>\n",
" <td>0.687174</td>\n",
" <td>0.610000</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.672940</td>\n",
" <td>0.658388</td>\n",
" <td>0.660000</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.635036</td>\n",
" <td>0.623204</td>\n",
" <td>0.650000</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.610399</td>\n",
" <td>0.598147</td>\n",
" <td>0.660000</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.596583</td>\n",
" <td>0.580337</td>\n",
" <td>0.680000</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Train for 5 epochs with the one-cycle schedule, peaking at 1e-3 (the lr_find valley above)\n",
"learn.fit_one_cycle(n_epoch=5, lr_max=1e-3)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment