Skip to content

Instantly share code, notes, and snippets.

@Orbifold
Created August 23, 2018 12:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Orbifold/b67247a0c9299aca15ffa125341009ae to your computer and use it in GitHub Desktop.
Save Orbifold/b67247a0c9299aca15ffa125341009ae to your computer and use it in GitHub Desktop.
Linear regression with MXNet.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"source": [
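"# Linear regression with MXNet\n",
"\n",
"Fit a one-layer Gluon model to 1,000 synthetic examples generated from $y = 2x_1 - 3.4x_2 + 4.2 + \\epsilon$ (with small Gaussian noise $\\epsilon$), then print the learned weights and bias to compare them with the true coefficients."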
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import mxnet as mx\n",
"import numpy as np\n",
"from mxnet import nd, autograd, gluon\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"# Fix the random number generators so Restart & Run All gives the same results every time.\n",
"# mx.random drives the nd.random_normal draws and the parameter initialization;\n",
"# DataLoader(shuffle=True) presumably draws from numpy's RNG - confirm for your MXNet version.\n",
"SEED = 42\n",
"mx.random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"\n",
"data_ctx = mx.cpu()\n",
"model_ctx = mx.cpu()"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"~/conda/lib/python3.6/site-packages/urllib3/contrib/pyopenssl.py:46: DeprecationWarning: OpenSSL.rand is deprecated - you should use os.urandom instead\n",
" import OpenSSL.SSL\n"
]
}
],
"execution_count": 1,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"num_inputs, num_outputs, num_examples = 2, 1, 1000\n",
"\n",
"def real_fn(X):\n",
"    \"\"\"Noise-free target: y = 2*x0 - 3.4*x1 + 4.2, using columns 0 and 1 of X.\"\"\"\n",
"    return 2 * X[:, 0] - 3.4 * X[:, 1] + 4.2\n",
"\n",
"# Synthetic regression data: standard-normal features plus small Gaussian label noise.\n",
"X = nd.random_normal(shape=(num_examples, num_inputs))\n",
"noise = 0.01 * nd.random_normal(shape=(num_examples,))\n",
"y = real_fn(X) + noise"
],
"outputs": [],
"execution_count": 2,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"batch_size = 4\n",
"# Iterate over (X, y) pairs in shuffled mini-batches of batch_size examples.\n",
"train_data = gluon.data.DataLoader(\n",
"    gluon.data.ArrayDataset(X, y),\n",
"    batch_size=batch_size,\n",
"    shuffle=True,\n",
")"
],
"outputs": [],
"execution_count": 3,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"# A single Dense layer is a linear model: num_inputs features -> num_outputs predictions.\n",
"net = gluon.nn.Dense(num_outputs, in_units=num_inputs)"
],
"outputs": [],
"execution_count": 4,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"# The parameters exist with known shapes but are not initialized yet.\n",
"for p in (net.weight, net.bias):\n",
"    print(p)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Parameter dense0_weight (shape=(1, 2), dtype=float32)\n",
"Parameter dense0_bias (shape=(1,), dtype=float32)\n"
]
}
],
"execution_count": 5,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"# Allocate the parameters on model_ctx, with Normal(sigma=1) as the default initializer.\n",
"net.collect_params().initialize(init=mx.init.Normal(sigma=1.), ctx=model_ctx)"
],
"outputs": [],
"execution_count": 6,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"# Sanity-check a forward pass on a single example before training.\n",
"net(nd.array([[0, 1]]))"
],
"outputs": [
{
"output_type": "execute_result",
"execution_count": 7,
"data": {
"text/plain": [
"\n",
"[[-1.3058136]]\n",
"<NDArray 1x1 @cpu(0)>"
]
},
"metadata": {}
}
],
"execution_count": 7,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"# L2Loss yields 0.5 * (prediction - label)**2 for each example in the batch.\n",
"square_loss = gluon.loss.L2Loss()"
],
"outputs": [],
"execution_count": 8,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"learning_rate = 0.0001\n",
"trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': learning_rate})"
],
"outputs": [],
"execution_count": 9,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"epochs = 150\n",
"loss_sequence = []  # mean per-example loss after each epoch, for plotting\n",
"\n",
"for epoch in range(epochs):\n",
"    cumulative_loss = 0.0\n",
"    seen = 0\n",
"    # inner loop over shuffled mini-batches\n",
"    for data, label in train_data:\n",
"        data = data.as_in_context(model_ctx)\n",
"        label = label.as_in_context(model_ctx)\n",
"        with autograd.record():\n",
"            output = net(data)\n",
"            loss = square_loss(output, label)\n",
"        loss.backward()\n",
"        # Normalize by the actual batch size so a smaller final batch is handled correctly.\n",
"        trainer.step(data.shape[0])\n",
"        # Sum (not mean) per-example losses so dividing by `seen` gives a true per-example average.\n",
"        cumulative_loss += nd.sum(loss).asscalar()\n",
"        seen += data.shape[0]\n",
"    avg_loss = cumulative_loss / seen\n",
"    loss_sequence.append(avg_loss)\n",
"    # Log every 10th epoch plus the last one instead of flooding the output.\n",
"    if epoch % 10 == 0 or epoch == epochs - 1:\n",
"        print(\"Epoch %s, loss: %s\" % (epoch, avg_loss))"
],
"outputs": [],
"execution_count": null,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"fig, ax = plt.subplots(figsize=(8, 6))\n",
"ax.plot(loss_sequence)\n",
"\n",
"# The loss decays roughly geometrically, so a log y-axis (with minor grid lines) shows the whole curve.\n",
"ax.set_yscale('log')\n",
"ax.grid(True, which=\"both\")\n",
"ax.set_xlabel('epoch', fontsize=14)\n",
"ax.set_ylabel('average loss', fontsize=14)\n",
"ax.set_title('Training loss', fontsize=14)\n",
"plt.show()"
],
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
],
"image/png": [
"\n"
]
},
"metadata": {}
}
],
"execution_count": 11,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"# collect_params() returns a ParameterDict: a dictionary of Parameter objects keyed by name.\n",
"params = net.collect_params()\n",
"\n",
"print('The type of \"params\" is a ',type(params))\n",
"\n",
"# Print each learned parameter's name and current value.\n",
"for learned in params.values():\n",
"    print(learned.name, learned.data())"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The type of \"params\" is a <class 'mxnet.gluon.parameter.ParameterDict'>\n",
"dense0_weight \n",
"[[ 1.9945382 -3.3480418]]\n",
"<NDArray 1x2 @cpu(0)>\n",
"dense0_bias \n",
"[4.0999794]\n",
"<NDArray 1 @cpu(0)>\n"
]
}
],
"execution_count": 12,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"language": "python",
"display_name": "Python 3"
},
"kernel_info": {
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.6.3",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"nteract": {
"version": "0.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment