Skip to content

Instantly share code, notes, and snippets.

@annaleighsmith
Last active October 22, 2021 17:08
Show Gist options
  • Save annaleighsmith/a4e5f8cfeccf8812536f795eb444a0a6 to your computer and use it in GitHub Desktop.
Save annaleighsmith/a4e5f8cfeccf8812536f795eb444a0a6 to your computer and use it in GitHub Desktop.
linear_regression_helper
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "5faf7d98",
"metadata": {},
"source": [
"## linear regression helper\n",
"\n",
"\n",
"A simple helper that picks columns from __pandas DataFrames__, converts them to __numpy ndarrays__, fits a linear model with __sklearn__, and takes care of formatting the plot aesthetics.\n",
"\n",
"Source and the income.data.csv file are here:\n",
"> https://www.scribbr.com/statistics/simple-linear-regression/"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e7d4f4a8",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"from sklearn.linear_model import LinearRegression"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "de9020f7",
"metadata": {},
"outputs": [],
"source": [
"size_dict = {\n",
" 'square_xs': (4, 4),\n",
" 'square_sm': (6, 6),\n",
" 'square_md': (8, 8),\n",
" 'square_lg': (12, 12),\n",
" 'square_xl': (16, 16),\n",
" 'rect_xs': (4, 2.5),\n",
" 'rect_sm': (6, 3.75),\n",
" 'rect_md': (8, 5),\n",
" 'rect_lg': (12, 7.5),\n",
" 'rect_xl': (16, 10),\n",
"}\n",
"\n",
"\n",
"def calcLinearReg(df_X: pd.DataFrame, df_Y: pd.DataFrame, col_num_X: int, col_num_Y: int):\n",
" X = df_X.iloc[:, col_num_X].values.reshape(-1, 1)\n",
" Y = df_Y.iloc[:, col_num_Y].values.reshape(-1, 1)\n",
" linear_regressor = LinearRegression() # create object for the class\n",
" linear_regressor.fit(X, Y) # perform linear regression\n",
" Y_pred = linear_regressor.predict(X) # make predictions\n",
" return [X, Y, Y_pred]\n",
"\n",
"\n",
"def displayLinearReg(X: np.ndarray, Y: np.ndarray, Y_pred: np.ndarray,\n",
" colorA: str = 'blue', colorB: str = 'black', title: str = '', xlabel: str = '', ylabel: str = '',\n",
" size='square_md', font_sizes: tuple = (14, 12, 12)):\n",
" \n",
" \"\"\"displays matplotlib plot with inputs of numpy arrays\"\"\"\n",
"\n",
" plt.rcParams[\"figure.figsize\"] = size_dict[size]\n",
" fig = plt.figure()\n",
" plt.scatter(X, Y, color=colorA)\n",
" plt.plot(X, Y_pred, color=colorB)\n",
" plt.title(title, fontsize=font_sizes[0])\n",
" plt.xlabel(xlabel, fontsize=font_sizes[1])\n",
" plt.ylabel(ylabel, fontsize=font_sizes[2])\n",
" return plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "110f2162",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['income', 'happiness'], dtype='object')\n"
]
}
],
"source": [
"df = pd.read_csv('income.data.csv')\n",
"df = df.iloc[:, 1:]\n",
"\n",
"# double-check that income and happiness sit at column indices 0 and 1, respectively\n",
"print(df.columns)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ca1a2be8",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"results = calcLinearReg(\n",
" df_X=df, # independent var\n",
" df_Y=df, # dependent var\n",
" col_num_X=0,\n",
" col_num_Y=1\n",
")\n",
"\n",
"displayLinearReg(\n",
" X=results[0],\n",
" Y=results[1],\n",
" Y_pred=results[2],\n",
" colorA='pink',\n",
" colorB='green',\n",
" title='Income and Happiness Linear Regression',\n",
" xlabel='Income',\n",
" ylabel='Happiness',\n",
" size='square_sm',\n",
" font_sizes=(18, 14, 14)\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@annaleighsmith
Copy link
Author

A function that uses sklearn.linear_model and pandas DataFrames to standardize and simplify the production of linear regressions.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment