Skip to content

Instantly share code, notes, and snippets.

@Elijas
Last active October 25, 2019 00:54
Show Gist options
  • Save Elijas/322836170de2110385a213ce586f56cc to your computer and use it in GitHub Desktop.
Save Elijas/322836170de2110385a213ce586f56cc to your computer and use it in GitHub Desktop.
Proof of concept: saving an ML model to Redis
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"from io import BytesIO\n",
"\n",
"import joblib\n",
"import redis as redis\n",
"from sklearn.datasets import load_digits\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import train_test_split\n",
"import numpy as np \n",
"import matplotlib.pyplot as plt\n",
"from sklearn.utils.testing import ignore_warnings\n",
"from sklearn.exceptions import ConvergenceWarning\n",
"\n",
"\n",
"class Timer:\n",
"    \"\"\"Context manager that prints the elapsed time spent inside its ``with`` block.\n",
"\n",
"    After the block exits, ``interval_ms`` holds the elapsed time in milliseconds.\n",
"    ``start`` and ``end`` are raw ``time.perf_counter()`` readings: only their\n",
"    difference is meaningful, they are not Unix timestamps.\n",
"    \"\"\"\n",
"\n",
"    def __enter__(self):\n",
"        # perf_counter is monotonic and high-resolution; time.time() follows the\n",
"        # system clock, which can jump and is coarse on some platforms.\n",
"        self.start = time.perf_counter()\n",
"        return self\n",
"\n",
"    def __exit__(self, *args):\n",
"        # Implicitly returns None, so exceptions raised inside the block propagate.\n",
"        self.end = time.perf_counter()\n",
"        self.interval_ms = (self.end - self.start) * 1e3\n",
"        print(f'(time elapsed: {int(self.interval_ms):d} ms)')\n",
"\n",
"\n",
"def serialize(obj: object) -> bytes:\n",
"    \"\"\"Dump *obj* with joblib into an in-memory buffer and return the raw bytes.\n",
"\n",
"    Prints the size of the resulting payload as a side effect.\n",
"    \"\"\"\n",
"    with BytesIO() as buffer:\n",
"        joblib.dump(obj, buffer)\n",
"        # getvalue() returns the whole buffer regardless of the write position.\n",
"        payload = buffer.getvalue()\n",
"    print(f'Model was packed to {len(payload)} bytes')\n",
"    return payload\n",
"\n",
"\n",
"def deserialize(bytes_data: bytes) -> object:\n",
"    \"\"\"Rebuild an object from bytes produced by ``serialize`` and return it.\n",
"\n",
"    Raises ValueError when *bytes_data* is None or empty (e.g. the result of\n",
"    ``redis_conn.get`` for a key that does not exist) instead of handing an\n",
"    empty buffer to joblib.\n",
"\n",
"    WARNING: joblib.load unpickles its input, which can execute arbitrary code.\n",
"    Only pass bytes from a trusted source, such as a Redis instance you control.\n",
"    \"\"\"\n",
"    if not bytes_data:\n",
"        raise ValueError('No bytes to deserialize (was the Redis key missing?)')\n",
"    obj = joblib.load(BytesIO(bytes_data))\n",
"    print(f'Bytes were unpacked to an object of type {obj.__class__.__name__}')\n",
"    return obj\n",
"\n",
"def train_logistic_regression_model(x_train, y_train):\n",
"    \"\"\"Fit a LogisticRegression classifier on the given training data and return it.\n",
"\n",
"    ConvergenceWarnings raised while fitting are suppressed; other warnings\n",
"    still surface.\n",
"    \"\"\"\n",
"    # Imported locally so this function does not depend on extra cell-level imports.\n",
"    import warnings\n",
"\n",
"    # max_iter must be an int: scikit-learn's parameter validation (1.2+)\n",
"    # rejects the float 1e3.\n",
"    # NOTE(review): multi_class is deprecated in recent scikit-learn releases\n",
"    # ('auto' is already the default there); drop it when targeting them.\n",
"    model = LogisticRegression(solver='lbfgs', multi_class='auto', max_iter=1000)\n",
"    # Stdlib warning filter instead of sklearn.utils.testing.ignore_warnings,\n",
"    # a test-only helper deprecated in scikit-learn 0.22 and removed in 0.24.\n",
"    with warnings.catch_warnings():\n",
"        warnings.simplefilter('ignore', category=ConvergenceWarning)\n",
"        model.fit(x_train, y_train)\n",
"    return model\n",
"\n",
"def predict_number(model, x, image_shape=(8, 8)):\n",
"    \"\"\"Show the flattened digit image *x* and print the model's prediction for it.\n",
"\n",
"    *x* may be any array-like (list or ndarray) holding one flattened image.\n",
"    *image_shape* is the (rows, cols) used to display it; 8x8 matches the\n",
"    load_digits data used in this notebook.\n",
"    Returns the predicted label so callers can also use it programmatically.\n",
"    \"\"\"\n",
"    # Convert once so both reshapes below work for plain lists as well as ndarrays.\n",
"    pixels = np.asarray(x)\n",
"    plt.imshow(pixels.reshape(image_shape), cmap=plt.cm.gray)\n",
"    # predict() takes a 2-D (n_samples, n_features) array, hence a single row.\n",
"    prediction = model.predict(pixels.reshape(1, -1))[0]\n",
"    print(f'Predicted number: {prediction}')\n",
"    return prediction"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model weight matrix shape: (10, 64)\n"
]
}
],
"source": [
"# Load the digits dataset and hold out 20% of it for testing.\n",
"digits_dataset = load_digits()\n",
"x_train, x_test, y_train, y_test = train_test_split(\n",
"    digits_dataset.data,\n",
"    digits_dataset.target,\n",
"    test_size=0.2,\n",
"    random_state=0,  # fixed seed so the split is reproducible\n",
")\n",
"\n",
"model: LogisticRegression = train_logistic_regression_model(x_train, y_train)\n",
"\n",
"# Client for the Redis server that stores the model. Port 26379 differs from\n",
"# Redis's default 6379 (it is Sentinel's default port); presumably a port\n",
"# mapping to a plain Redis server -- confirm for your setup.\n",
"redis_conn = redis.StrictRedis(host='localhost', port=26379, db=0)\n",
"print(f'Model weight matrix shape: {model.coef_.shape}')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model was packed to 6058 bytes\n",
"(time elapsed: 19 ms)\n"
]
}
],
"source": [
"# Serialize the trained model and store the bytes in Redis, timing both steps.\n",
"with Timer():\n",
"    bytes_data: bytes = serialize(model)\n",
"    redis_conn.set(name='model-001', value=bytes_data)\n",
"\n",
"# Drop the in-memory copies so the next cell has to restore the model from Redis.\n",
"del model\n",
"del bytes_data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Bytes were unpacked to an object of type LogisticRegression\n",
"(time elapsed: 4 ms)\n",
"Predicted number: 9\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAK3klEQVR4nO3d34tc9RnH8c+nq6GNvwKtLZoNXQUJSKEbCQEJaBrbEqtoL3qRgEKlkCvF0IJob0z/AdleFCFEXcFUaaMGEasVdLFCa82PbWvcWNKQkG20Ucrij0pD9OnFTkq0q3tm5pzvOfv0/YLg7uyw32eI75zZszPn64gQgDy+0PYAAOpF1EAyRA0kQ9RAMkQNJHNOE9/UdspT6suWLSu63qWXXlpsreXLlxdb6/Tp08XWOn78eLG1JOnDDz8stlZEeKHbG4k6q5KRSdL27duLrTU+Pl5srbm5uWJrbdu2rdhakjQ9PV10vYXw9BtIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKZS1LY32X7D9mHbdzc9FIDBLRq17RFJv5B0vaQrJW2xfWXTgwEYTJUj9TpJhyPiSESckvSYpJubHQvAoKpEvVLS2W91me3d9gm2t9rea3tvXcMB6F+Vd2kt9Pau/3lrZUTskLRDyvvWS2ApqHKknpW06qzPRyWdaGYcAMOqEvWrkq6wfZntZZI2S3qq2bEADGrRp98Rcdr27ZKekzQi6cGIONj4ZAAGUunKJxHxjKRnGp4FQA14RRmQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDDt09GFiYqLoemNjY8XW2rNnT7G17r333mJrbdiwodhaEjt0AGgAUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyVTZoeNB2ydtv1ZiIADDqXKknpS0qeE5ANRk0agj4iVJ/ywwC4Aa1PYuLdtbJW2t6/sBGExtUbPtDtANnP0GkiFqIJkqv9J6VNLvJa22PWv7R82PBWBQVfbS2lJiEAD14Ok3kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAzb7vRhxYoVRdebmpoqtlbpx1ZK1sf1eThSA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQTJVrlK2y/aLtGdsHbd9ZYjAAg6ny2u/Tkn4SEfttXyBpn+3nI+L1hmcDMIAq2+68GRH7ex+/J2lG0sqmBwMwmL7epWV7TNIaSa8s8DW23QE6oHLUts+X9LikbRHx7qe/zrY7QDdUOvtt+1zNB70rIp5odiQAw6hy9tuSHpA0ExH3NT8SgGFUOVKvl3SrpI22p3t/vtfwXAAGVGXbnZclucAsAGrAK8qAZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIa9tPowOTlZdL0NGzYUW2t8fLzYWseOHSu21oEDB4qt1RUcqYFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZKpcePCLtv9o+0+9bXd+VmIwAIOp8jLRf0vaGBHv9y4V/LLt30TEHxqeDcAAqlx4MCS93/v03N4fLtYPdFTVi/mP2J6WdFLS8xGx4LY7tvfa3lv3kACqqxR1RHwUEeOSRiWts/2NBe6zIyLWRsTauocEUF1fZ78jYk7SlKRNjUwDYGhVzn5fbHtF7+MvSfq2pENNDwZgMFXOfl8i6WHbI5r/R+BXEfF0s2MBGFSVs99/1vye1ACWAF5RBiRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAynn9nZc3f1OatmfhMU1NTKdeSpO3btxdbKyK80O0cqYFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKZy1L0L+h+wzUUHgQ7r50h9p6SZpgYBUI+q2+6MSrpB0s5mxwEwrKpH6glJd0n6+LPuwF5aQDdU2aHjRkknI2Lf592PvbSAbqhypF4v6SbbRyU9Jmmj7UcanQ
rAwBaNOiLuiYjRiBiTtFnSCxFxS+OTARgIv6cGkqmyQd5/RcSU5reyBdBRHKmBZIgaSIaogWSIGkiGqIFkiBpIhqiBZJb8tjsltzmZmJgotpYkzc3NFV2vlJKPa3x8vNhaknT06NFia7HtDvB/gqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWQqXc6odyXR9yR9JOk0lwEGuqufa5R9KyLeaWwSALXg6TeQTNWoQ9Jvbe+zvXWhO7DtDtANVZ9+r4+IE7a/Kul524ci4qWz7xAROyTtkMq+9RLAJ1U6UkfEid5/T0p6UtK6JocCMLgqG+SdZ/uCMx9L+q6k15oeDMBgqjz9/pqkJ22fuf8vI+LZRqcCMLBFo46II5K+WWAWADXgV1pAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMv289bKTpqamiq01PT1dbC1J2rNnT8q1LrroomJrldwGpys4UgPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kEylqG2vsL3b9iHbM7avbnowAIOp+trvn0t6NiJ+YHuZpOUNzgRgCItGbftCSddI+qEkRcQpSaeaHQvAoKo8/b5c0tuSHrJ9wPbO3vW/P4Ftd4BuqBL1OZKuknR/RKyR9IGkuz99p4jYERFr2eYWaFeVqGclzUbEK73Pd2s+cgAdtGjUEfGWpOO2V/duuk7S641OBWBgVc9+3yFpV+/M9xFJtzU3EoBhVIo6IqYl8bMysATwijIgGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGknFE1P9N7fq/aQeMjY0VXW9ycrLYWtdee22xtY4dO1ZsrdJ/ZyVFhBe6nSM1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZDMolHbXm17+qw/79reVmI4AP1b9BplEfGGpHFJsj0i6e+Snmx4LgAD6vfp93WS/hYR5V68C6AvVS8RfMZmSY8u9AXbWyVtHXoiAEOpfKTuXfP7Jkm/XujrbLsDdEM/T7+vl7Q/Iv7R1DAAhtdP1Fv0GU+9AXRHpahtL5f0HUlPNDsOgGFV3XbnX5K+3PAsAGrAK8qAZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKapbXfeltTv2zO/Iumd2ofphqyPjcfVnq9HxMULfaGRqAdhe2/Wd3hlfWw8rm7i6TeQDFEDyXQp6h1tD9CgrI+Nx9VBnfmZGkA9unSkBlADogaS6UTUtjfZfsP2Ydt3tz1PHWyvsv2i7RnbB23f2fZMdbI9YvuA7afbnqVOtlfY3m37UO/v7uq2Z+pX6z9T9zYI+KvmL5c0K+lVSVsi4vVWBxuS7UskXRIR+21fIGmfpO8v9cd1hu0fS1or6cKIuLHteepi+2FJv4uInb0r6C6PiLm25+pHF47U6yQdjogjEXFK0mOSbm55pqFFxJsRsb/38XuSZiStbHeqetgelXSDpJ1tz1In2xdKukbSA5IUEaeWWtBSN6JeKen4WZ/PKsn//GfYHpO0RtIr7U5SmwlJd0n6uO1Bana5pLclPdT70WKn7fPaHqpfXYjaC9yW5vdsts+X9LikbRHxbtvzDMv2jZJORsS+tmdpwDmSrpJ0f0SskfSBpCV3jqcLUc9KWnXW56OSTrQ0S61sn6v5oHdFRJbLK6+XdJPto5r/UWmj7UfaHak2s5JmI+LMM6rdmo98SelC1K9KusL2Zb0TE5slPdXyTEOzbc3/bDYTEfe1PU9dIuKeiBiNiDHN/129EBG3tDxWLSLiLUnHba/u3XSdpCV3YrPfDfJqFxGnbd8u6TlJI5IejIiDLY9Vh/WSbpX0F9vTvdt+GhHPtDgTFneHpF29A8wRSbe1PE/fWv+VFoB6deHpN4AaETWQDFEDyRA1kAxRA8kQNZAMUQPJ/AemPZurwbW7YQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Fetch the stored bytes from Redis and rebuild the model, timing both steps.\n",
"with Timer():\n",
"    bytes_data: bytes = redis_conn.get(name='model-001')\n",
"    model: LogisticRegression = deserialize(bytes_data)\n",
"\n",
"# Classify one held-out test image with the restored model.\n",
"predict_number(model, x_test[7])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment