Skip to content

Instantly share code, notes, and snippets.

@teamtom
Last active November 1, 2022 21:58
Show Gist options
  • Save teamtom/3a317c300f484b0f5573f15ca4ad3a0f to your computer and use it in GitHub Desktop.
Save teamtom/3a317c300f484b0f5573f15ca4ad3a0f to your computer and use it in GitHub Desktop.
Using a simple nn.Sequential model with fastai's tabular Learner
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'fastai'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-1-7dbd506d2f31>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mfrom\u001b[0m \u001b[0mfastai\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtabular\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mall\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[1;33m*\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mdf\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'iris.csv'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0msplit\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mTrainTestSplitter\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrandom_state\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m42\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdf\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'fastai'"
]
}
],
"source": [
"from fastai.tabular.all import *\n",
"import torch.nn as nn\n",
"\n",
"SEED = 42\n",
"set_seed(SEED)  # seed python/numpy/torch so weight init and batch shuffling are reproducible\n",
"\n",
"df = pd.read_csv('iris.csv')\n",
"split = TrainTestSplitter(random_state=SEED)(df)  # reproducible train/valid row indices\n",
"df.species = pd.Categorical(df.species)  # make the target an explicit categorical column\n",
"\n",
"# Every column except the target is a continuous feature; select by name so the\n",
"# notebook does not depend on 'species' being the last column of the CSV.\n",
"cont_names = [c for c in df.columns if c != 'species']\n",
"dls = TabularPandas(\n",
"    df, splits=split, procs=[Normalize],\n",
"    cat_names=[], cont_names=cont_names,\n",
"    y_names='species', y_block=CategoryBlock(),\n",
").dataloaders(bs=8)\n",
"\n",
"class NNet(nn.Module):\n",
"    \"\"\"Small MLP classifier for batches from a fastai tabular DataLoader.\n",
"\n",
"    fastai calls a tabular model as model(x_cat, x_cont). This model uses only\n",
"    the continuous features and returns raw (unnormalised) logits, because\n",
"    CrossEntropyLossFlat applies log-softmax internally.\n",
"\n",
"    Args:\n",
"        n_in: number of continuous input features.\n",
"        n_hidden: width of the hidden layer.\n",
"        n_out: number of target classes.\n",
"    \"\"\"\n",
"    def __init__(self, n_in=4, n_hidden=10, n_out=3):\n",
"        super().__init__()\n",
"        # No final Softmax: the loss expects logits, and softmax-then-cross-entropy\n",
"        # would normalise twice and flatten the gradients.\n",
"        self.nnet = nn.Sequential(\n",
"            nn.Linear(n_in, n_hidden),\n",
"            nn.ReLU(),\n",
"            nn.Linear(n_hidden, n_out),\n",
"        )\n",
"\n",
"    def forward(self, x_cat, x_cont):\n",
"        # x_cat holds categorical codes; unused because this model has no embeddings.\n",
"        return self.nnet(x_cont)\n",
"\n",
"model = NNet()\n",
"# Learner calls loss_func(preds, targets), so it needs a loss *instance*, not the class.\n",
"learn = Learner(dls, model=model, metrics=accuracy, loss_func=CrossEntropyLossFlat())\n",
"learn.fit(10, lr=0.1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.6 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "570feb405e2e27c949193ac68f46852414290d515b0ba6e5d90d076ed2284471"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment