Last active
October 25, 2019 00:54
-
-
Save Elijas/322836170de2110385a213ce586f56cc to your computer and use it in GitHub Desktop.
Proof of Concept: ML model saving to Redis
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import time\n", | |
"from io import BytesIO\n", | |
"\n", | |
"import joblib\n", | |
"import redis as redis\n", | |
"from sklearn.datasets import load_digits\n", | |
"from sklearn.linear_model import LogisticRegression\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"import numpy as np \n", | |
"import matplotlib.pyplot as plt\n", | |
"from sklearn.utils.testing import ignore_warnings\n", | |
"from sklearn.exceptions import ConvergenceWarning\n", | |
"\n", | |
"\n", | |
"class Timer:\n", | |
" def __enter__(self):\n", | |
" self.start = time.time()\n", | |
" return self\n", | |
"\n", | |
" def __exit__(self, *args):\n", | |
" self.end = time.time()\n", | |
" self.interval_ms = (self.end - self.start) * 1e3\n", | |
" print(f'(time elapsed: {int(self.interval_ms):d} ms)')\n", | |
"\n", | |
"\n", | |
"def serialize(obj: object) -> bytes:\n", | |
" bytes_container = BytesIO()\n", | |
" joblib.dump(obj, bytes_container)\n", | |
" bytes_container.seek(0) # update to enable reading\n", | |
" bytes_data = bytes_container.read()\n", | |
" print(f'Model was packed to {len(bytes_data)} bytes')\n", | |
" return bytes_data\n", | |
"\n", | |
"\n", | |
"def deserialize(bytes_data: bytes) -> object:\n", | |
" bytes_container = BytesIO(bytes_data)\n", | |
" obj = joblib.load(bytes_container)\n", | |
" print(f'Bytes were unpacked to an object of type {obj.__class__.__name__}')\n", | |
" return obj\n", | |
"\n", | |
"@ignore_warnings(category=ConvergenceWarning)\n", | |
"def train_logistic_regression_model(x_train, y_train):\n", | |
" model = LogisticRegression(solver='lbfgs', multi_class='auto', max_iter=1e3)\n", | |
" model.fit(x_train, y_train)\n", | |
" return model\n", | |
"\n", | |
"def predict_number(model, x):\n", | |
" plt.imshow(np.reshape(x, (8,8)), cmap=plt.cm.gray)\n", | |
" print(f'Predicted number: {model.predict(x.reshape(1,-1))[0]}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Model weight matrix shape: (10, 64)\n" | |
] | |
} | |
], | |
"source": [ | |
"digits_dataset = load_digits()\n", | |
"x_train, x_test, y_train, y_test = train_test_split(digits_dataset.data, digits_dataset.target, test_size=0.2, random_state=0)\n", | |
"\n", | |
"model: LogisticRegression = train_logistic_regression_model(x_train, y_train)\n", | |
"redis_conn = redis.StrictRedis(host=\"localhost\", port=26379, db=0)\n", | |
"print(f'Model weight matrix shape: {model.coef_.shape}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Model was packed to 6058 bytes\n", | |
"(time elapsed: 19 ms)\n" | |
] | |
} | |
], | |
"source": [ | |
"with Timer():\n", | |
" bytes_data: bytes = serialize(model)\n", | |
" redis_conn.set('model-001', bytes_data)\n", | |
"\n", | |
"del model, bytes_data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Bytes were unpacked to an object of type LogisticRegression\n", | |
"(time elapsed: 4 ms)\n", | |
"Predicted number: 9\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAK3klEQVR4nO3d34tc9RnH8c+nq6GNvwKtLZoNXQUJSKEbCQEJaBrbEqtoL3qRgEKlkCvF0IJob0z/AdleFCFEXcFUaaMGEasVdLFCa82PbWvcWNKQkG20Ucrij0pD9OnFTkq0q3tm5pzvOfv0/YLg7uyw32eI75zZszPn64gQgDy+0PYAAOpF1EAyRA0kQ9RAMkQNJHNOE9/UdspT6suWLSu63qWXXlpsreXLlxdb6/Tp08XWOn78eLG1JOnDDz8stlZEeKHbG4k6q5KRSdL27duLrTU+Pl5srbm5uWJrbdu2rdhakjQ9PV10vYXw9BtIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKZS1LY32X7D9mHbdzc9FIDBLRq17RFJv5B0vaQrJW2xfWXTgwEYTJUj9TpJhyPiSESckvSYpJubHQvAoKpEvVLS2W91me3d9gm2t9rea3tvXcMB6F+Vd2kt9Pau/3lrZUTskLRDyvvWS2ApqHKknpW06qzPRyWdaGYcAMOqEvWrkq6wfZntZZI2S3qq2bEADGrRp98Rcdr27ZKekzQi6cGIONj4ZAAGUunKJxHxjKRnGp4FQA14RRmQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDDt09GFiYqLoemNjY8XW2rNnT7G17r333mJrbdiwodhaEjt0AGgAUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyVTZoeNB2ydtv1ZiIADDqXKknpS0qeE5ANRk0agj4iVJ/ywwC4Aa1PYuLdtbJW2t6/sBGExtUbPtDtANnP0GkiFqIJkqv9J6VNLvJa22PWv7R82PBWBQVfbS2lJiEAD14Ok3kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAzb7vRhxYoVRdebmpoqtlbpx1ZK1sf1eThSA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQTJVrlK2y/aLtGdsHbd9ZYjAAg6ny2u/Tkn4SEfttXyBpn+3nI+L1hmcDMIAq2+68GRH7ex+/J2lG0sqmBwMwmL7epWV7TNIaSa8s8DW23QE6oHLUts+X9LikbRHx7qe/zrY7QDdUOvtt+1zNB70rIp5odiQAw6hy9tuSHpA0ExH3NT8SgGFUOVKvl3SrpI22p3t/vtfwXAAGVGXbnZclucAsAGrAK8qAZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIa9tPowOTlZdL0NGzYUW2t8fLzYWseOHSu21oEDB4qt1RUcqYFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZKpcePCLtv9o+0+9bXd+VmIwAIOp8jLRf0vaGBHv9y4V/LLt30TEHxqeDcAAqlx4MCS93/v03N4fLtYPdFTVi/mP2J6WdFLS8xGx4LY7tvfa3lv3kACqqxR1RHwUEeOSRiWts/2NBe6zIyLWRsTauocEUF1fZ78jYk7SlKRNjUwDYGhVzn5fbHtF7+MvSfq2pENNDwZgMFXOfl8i6WHbI5r/R+BXEfF0s2MBGFSVs99/1vye1ACWAF5RBiRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAynn9nZc3f1OatmfhMU1NTKdeSpO3btxdbKyK80O0cqYFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKZy1L0L+h+wzUUHgQ7r50h9p6SZpgYBUI+q2+6MSrpB0s5mxwEwrKpH6glJd0n6+LPuwF5aQDdU2aHjRkknI2Lf592PvbSAbqhypF4v6SbbRyU9Jmmj7UcanQrAwBaNOiLuiYjRiBiTtFnSCxFxS+OTARgIv6cGkqmyQd5/RcSU5reyBdBRHKmBZIgaSIaogWSIGkiGqIFkiBpIhqiBZJb8tjsltzmZmJgotpYkzc3NFV2vlJKPa3x8vNhaknT06NFia7HtDvB/gqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWQqXc6odyXR9yR9JOk0lwEGuqufa5R9KyLeaWwSALXg6TeQTNWoQ9Jvbe+zvXWhO7DtDtANVZ9+r4+IE7a/Kul524ci4qWz7xAROyTtkMq+9RLAJ1U6UkfEid5/T0p6UtK6JocCMLgqG+SdZ/uCMx9L+q6k15oeDMBgqjz9/pqkJ22fuf8vI+LZRqcCMLBFo46II5K+WWAWADXgV1pAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMv289bKTpqamiq01PT1dbC1J2rNnT8q1LrroomJrldwGpys4UgPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kEylqG2vsL3b9iHbM7avbnowAIOp+trvn0t6NiJ+YHuZpOUNzgRgCItGbftCSddI+qEkRcQpSaeaHQvAoKo8/b5c0tuSHrJ9wPbO3vW/P4Ftd4BuqBL1OZKuknR/RKyR9IGkuz99p4jYERFr2eYWaFeVqGclzUbEK73Pd2s+cgAdtGjUEfGWpOO2V/duuk7S641OBWBgVc9+3yFpV+/M9xFJtzU3EoBhVIo6IqYl8bMysATwijIgGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGknFE1P9N7fq/aQeMjY0VXW9ycrLYWtdee22xtY4dO1ZsrdJ/ZyVFhBe6nSM1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZDMolHbXm17+qw/79reVmI4AP1b9BplEfGGpHFJsj0i6e+Snmx4LgAD6vfp93WS/hYR5V68C6AvVS8RfMZmSY8u9AXbWyVtHXoiAEOpfKTuXfP7Jkm/XujrbLsDdEM/T7+vl7Q/Iv7R1DAAhtdP1Fv0GU+9AXRHpahtL5f0HUlPNDsOgGFV3XbnX5K+3PAsAGrAK8qAZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKapbXfeltTv2zO/Iumd2ofphqyPjcfVnq9HxMULfaGRqAdhe2/Wd3hlfWw8rm7i6TeQDFEDyXQp6h1tD9CgrI+Nx9VBnfmZGkA9unSkBlADogaS6UTUtjfZfsP2Ydt3tz1PHWyvsv2i7RnbB23f2fZMdbI9YvuA7afbnqVOtlfY3m37UO/v7uq2Z+pX6z9T9zYI+KvmL5c0K+lVSVsi4vVWBxuS7UskXRIR+21fIGmfpO8v9cd1hu0fS1or6cKIuLHteepi+2FJv4uInb0r6C6PiLm25+pHF47U6yQdjogjEXFK0mOSbm55pqFFxJsRsb/38XuSZiStbHeqetgelXSDpJ1tz1In2xdKukbSA5IUEaeWWtBSN6JeKen4WZ/PKsn//GfYHpO0RtIr7U5SmwlJd0n6uO1Bana5pLclPdT70WKn7fPaHqpfXYjaC9yW5vdsts+X9LikbRHxbtvzDMv2jZJORsS+tmdpwDmSrpJ0f0SskfSBpCV3jqcLUc9KWnXW56OSTrQ0S61sn6v5oHdFRJbLK6+XdJPto5r/UWmj7UfaHak2s5JmI+LMM6rdmo98SelC1K9KusL2Zb0TE5slPdXyTEOzbc3/bDYTEfe1PU9dIuKeiBiNiDHN/129EBG3tDxWLSLiLUnHba/u3XSdpCV3YrPfDfJqFxGnbd8u6TlJI5IejIiDLY9Vh/WSbpX0F9vTvdt+GhHPtDgTFneHpF29A8wRSbe1PE/fWv+VFoB6deHpN4AaETWQDFEDyRA1kAxRA8kQNZAMUQPJ/AemPZurwbW7YQAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"with Timer():\n", | |
" bytes_data: bytes = redis_conn.get('model-001')\n", | |
" model: LogisticRegression = deserialize(bytes_data)\n", | |
"\n", | |
"predict_number(model, x_test[7])" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment