@jkrukowski
Last active May 31, 2019 13:15
liniear-regression.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "liniear-regression.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "swift",
"display_name": "Swift"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/jkrukowski/d6510f8a97ec3c3aae42ffd41c468a76/liniear-regression.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CoEfOhVWQI9T",
"colab_type": "text"
},
"source": [
"# Linear Regression with Swift for TensorFlow"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w9oJuFnoUKuS",
"colab_type": "text"
},
"source": [
"## Imports\n",
"### Import necessary libraries"
]
},
{
"cell_type": "code",
"metadata": {
"id": "glNzWvo49gIl",
"colab_type": "code",
"colab": {}
},
"source": [
"import TensorFlow\n",
"import Python\n",
"%include \"EnableIPythonDisplay.swift\"\n",
"IPythonDisplay.shell.enable_matplotlib(\"inline\")\n",
"let plt = Python.import(\"matplotlib.pyplot\")\n",
"let np = Python.import(\"numpy\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "S9aaSkrQVw0u",
"colab_type": "text"
},
"source": [
"## Data\n",
"### Data for linear regression"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Qgtq7-6OV2oS",
"colab_type": "code",
"colab": {}
},
"source": [
"let x: [Float] = [1.0, 2.0, 3.0]\n",
"let y: [Float] = [2.0, 4.0, 6.0]"
],
"execution_count": 0,
"outputs": []
},
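{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The data follow $y = 2x$ exactly, so an ordinary least-squares fit should recover a slope of 2 and an intercept of 0. As a reference for both fits below, the next cell is a minimal plain-Swift sketch of the closed-form least-squares solution for a single feature. `closedFormFit` is a hypothetical helper, not a TensorFlow or NumPy API, and it reuses the `x` and `y` arrays defined above."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"// Hypothetical helper: closed-form ordinary least squares for one feature,\n",
"// fitting y = slope * x + intercept\n",
"func closedFormFit(x: [Float], y: [Float]) -> (slope: Float, intercept: Float) {\n",
"    precondition(x.count == y.count && x.count > 1, \"need at least two paired points\")\n",
"    let n = Float(x.count)\n",
"    let meanX = x.reduce(0, +) / n\n",
"    let meanY = y.reduce(0, +) / n\n",
"    // slope = sum((x - meanX) * (y - meanY)) / sum((x - meanX)^2)\n",
"    var sxy: Float = 0\n",
"    var sxx: Float = 0\n",
"    for (xi, yi) in zip(x, y) {\n",
"        sxy += (xi - meanX) * (yi - meanY)\n",
"        sxx += (xi - meanX) * (xi - meanX)\n",
"    }\n",
"    let slope = sxy / sxx\n",
"    // The least-squares line passes through (meanX, meanY)\n",
"    return (slope, meanY - slope * meanX)\n",
"}\n",
"\n",
"let reference = closedFormFit(x: x, y: y)\n",
"print(\"slope:\", reference.slope, \"intercept:\", reference.intercept)"
],
"execution_count": 0,
"outputs": []
},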
{
"cell_type": "markdown",
"metadata": {
"id": "tdNVW4siUNX5",
"colab_type": "text"
},
"source": [
"## Helper Functions\n",
"### Functions for plotting and numpy linear regression\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "DS-OLs3aTjvF",
"colab_type": "code",
"colab": {}
},
"source": [
"func plotData(\n",
" x: [Float], \n",
" y: [Float], \n",
" fitLine: PythonObject\n",
") {\n",
" plt.plot(x, y, \"yo\", x, fitLine, \"--k\")\n",
" plt.xlim(0, 8)\n",
" plt.ylim(0, 8)\n",
" plt.show()\n",
"}\n",
"\n",
"func linearRegression(\n",
" x: [Float], \n",
" y: [Float]\n",
") -> PythonObject {\n",
" let fit = np.polyfit(x, y, 1)\n",
" return np.poly1d(fit)(x)\n",
"}"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "NFk9Lyd2UyMb",
"colab_type": "text"
},
"source": [
"## NumPy Linear Regression"
]
},
{
"cell_type": "code",
"metadata": {
"id": "3cHkRZ0jA8JK",
"colab_type": "code",
"outputId": "361466d2-3884-40d0-a35b-3c5e0079f10c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 269
}
},
"source": [
"let regressionFunction = linearRegression(x: x, y: y)\n",
"plotData(x: x, y: y, fitLine: regressionFunction)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFdRJREFUeJzt3XtwVfW5xvHvm8RGbsdrDCgi4NCI\n1AoaKdWDShEHR6tWKZAiHYtOUDiI9nRaRYsG0Wk7jmMvDj2RNIoGtBKdqq0eLUhrOx7bhAYVkVqD\nUFCSWEW5aJTkPX9kaxGT7BXdK2v/kuczkzHZWRufYcjDj3et31rm7oiISDhykg4gIiJdo+IWEQmM\niltEJDAqbhGRwKi4RUQCo+IWEQlMpOI2s2vMbL2ZvWhmK8zswLiDiYhI+9IWt5kdBVwFFLv7l4Bc\nYHrcwUREpH1RRyV5QB8zywP6Aq/HF0lERDqTl+4Ad99mZrcBW4D3gCfd/cn9jzOzUqAUoF+/ficf\nd9xxmc4qItJj1dbWvunuBVGOtXRb3s3sEKAamAbsAB4EVrr7fR29p7i42GtqaqInFhHp5cys1t2L\noxwbZVRyFrDJ3Zvc/UPgIeDUzxNQREQ+uyjFvQUYZ2Z9zcyAicCGeGOJiEhH0ha3uz8HrATWAi+k\n3lMecy4REelA2pOTAO5+I3BjzFlERCQC7ZwUEQmMiltEJDAqbhGRwKi4RUQCo+IWEQmMiltEJDAq\nbhGRwKi4RUQCo+IWEQmMiltEJDAqbhGRwKi4RUQCo+IWEQmMiltEJDAqbhGRwKi4RUQCo+IWEQmM\niltEJDBpi9vMisysbp+Pd83s6u4IJyIin5b2mZPuvhEYDWBmucA24OGYc4mISAe6OiqZCLzq7pvj\nCCMiIul1tbinAyviCCIiItFELm4z+wJwPvBgB98vNbMaM6tpamrKVD4REdlPV1bc5wBr3b2hvW+6\ne7m7F7t7cUFBQWbSiYjIp3SluEvQmEREJHGRitvM+gGTgIfijSMiIumkvRwQwN13A4fFnEVERCLQ\nzkkRkcCouEVEAqPiFhEJjIpbRCQwKm4RkcCouEVEAqPiFhEJjIpbRCQwKm4RkcCouEVEAqPiFhEJ\njIpbRCQwKm4RkcCouEVEAqPiFhEJjIpbRCQwKm4RkcCouEVEAqPiFhEJTNSHBR9sZivN7GUz22Bm\nX407mIShoaGKZ58dypo1OTz77FAaGqqSjiTS40V6WDDwU+AJd59iZl8A+saYSQLR0FDFxo2ltLbu\nAaC5eTMbN5YCUFg4I8loIj1a2hW3mR0EnA5UALj7B+6+I+5gkv3q66+npWUPVVXw1lttr7W27qG+\n/vpkg4n0cFFGJcOAJqDSzP5mZkvNrN/+B5lZqZnVmFlNU1NTxoNK9mlu3sJvfgNLl8Lq1Z98XUTi\nE6W484CTgCXuPgbYDVy7/0HuXu7uxe5eXFBQkOGYko3y84cwdixMmwYXXfTJ10UkPlGKeyuw1d2f\nS329krYil15s9+7dDBu2mMGD+3LFFZCT+pOUk9OX4cNvSTacSA+XtrjdfTvwTzMrSr00EXgp1lSS\n1VpaWvjGN77B9773BEVF5eTnHwMY+fnHUFRUrhOTIjGLelXJPKAqdUVJPfCd+CJJtlu0aBFPPfUU\nd911F4WFM1TUIt0sUnG7ex1QHHMWCcDjjz/OokWLuPTSS7nsssuSjiPSK2nnpES2efNmLrnkEr78\n5S9z5513YmZJRxLplVTcEtm2bds47LDDqK6upm9f7cESSUrUGbcIp556Khs2bCA3NzfpKCK9mlbc\nklZVVRVlZWW0traqtEWygIpbOrV+/XpKS0tZtWoVLS0tSccREVTc0omdO3dy8cUXM2DAAB544AEO\nOOCApCOJCJpxSwfcncsuu4xXXnmFVatWMWjQ
oKQjiUiKVtzSrnXr1vHwww9z6623cuaZZyYdR0T2\noRW3tGv06NHU1dUxcuTIpKOIyH604pZPaGpq4tFHHwVg1KhR5OToj4hIttFPpXyspaWFkpISpk2b\nxvbt25OOIyId0KhEPnbTTTexatUqKioqGDhwYNJxRKQDWnELAL/97W9ZvHgxs2bNYtasWUnHEZFO\nqLiFpqYmZs6cyejRo/nFL36RdBwRSUOjEqGgoIDbbruNM844gz59+iQdR0TSUHH3cq+//jpHHnmk\nxiMiAdGopBe79957GTFiBLW1tUlHEZEuUHH3Ui+88AKzZ8/mlFNO4cQTT0w6joh0QaRRiZm9BuwE\nWoC97q7HmAXs3Xff5eKLL+aggw7i/vvvJy9PEzORkHTlJ3aCu78ZWxLpFu7OrFmzqK+vZ/Xq1bpe\nWyRAGpX0Mi0tLRx11FH86Ec/4vTTT086joh8Bubu6Q8y2wS8DTjwP+5e3s4xpUApwJAhQ07evHlz\nhqPK5+XuHz/gd9/PRSR5ZlYbdQwddcX9n+5+EnAOMNfMPrVUc/dydy929+KCgoIuxJXu0NjYyPjx\n41m7di2ASlskYJGK2923pf7bCDwMjI0zlGTWRzePqq2t1TMjRXqAtMVtZv3MbMBHnwNnAy/GHUwy\nZ+HChaxevZolS5bo0j+RHiDKVSWFwMOpf1rnAcvd/YlYU0nGPPbYY9x6661cfvnlXHrppUnHEZEM\nSFvc7l4PaJkWqMrKSsaMGcPPf/7zpKOISIZo50UP98ADD/DWW29x4IEHJh1FRDJE13H3UD/72c9o\naGggLy+PI444Iuk4IpJBKu4e6J577mH+/PksXbo06SgiEgMVdw+zbt06rrjiCiZMmMAPfvCDpOOI\nSAxU3D3IO++8w5QpUzjkkENYsWKFbh4l0kPpJ7sHWbBgAZs2beLpp5+msLAw6TgiEhMVdw9y8803\nM2nSJMaPH590FBGJkUYlPcDGjRtpbm7m0EMP5cILL0w6jojETMUduO3btzNhwgQ9M1KkF1FxB2zv\n3r2UlJSwY8cOrr322qTjiEg30Yw7YDfccANr1qxh2bJlnHDCCUnHEZFuohV3oB555BF+/OMfM3v2\nbGbOnJl0HBHpRiruQB177LGUlJRwxx13JB1FRLqZRiWB2bt3L7m5uYwaNYrly5cnHUdEEqAVd2Cu\nuOIKLr/8cqI8K1REeiYVd0AqKyupqKhg0KBBemakSC+m4g5EXV0dc+bMYeLEiZSVlSUdR0QSpOIO\nwI4dO5gyZQqHHXYYy5cv1wN/RXq5yCcnzSwXqAG2uft58UUSgIaGKurrr6e5eQt///sRvP32Hh59\n9Ak9FEFEurTing9siCuI/FtDQxUbN5bS3LwZcL74xQaWL9/LscduSjqaiGSBSMVtZoOBcwE9UqUb\n1NdfT2vrHurq4MEHwR3y89+jvv76pKOJSBaIuuK+A/g+0NrRAWZWamY1ZlbT1NSUkXC9VXPzFv71\nL1i0CB59FN5//9+vi4ikLW4zOw9odPfazo5z93J3L3b34oKCgowF7I1yc49m0SJ47z0oK4M+fdpe\nz88fkmwwEckKUVbcpwHnm9lrwP3A18zsvlhT9XL3338Czz8P3/0uDBvW9lpOTl+GD78l2WAikhXS\nFre7X+fug919KDAdWO3ul8SerJd65ZVX+OUvH+fb357IeecdAxj5+cdQVFROYeGMpOOJSBbQvUqy\nzIgRI1izZg1jx44lPz8/6TgikoW6tAHH3dfoGu54vPfeezzzzDMAjB8/XqUtIh3SzsksMXfuXCZM\nmMCmTbpWW0Q6p+LOAhUVFVRWVrJgwQKGfXQ2UkSkAyruhK1du5a5c+cyadIkbrzxxqTjiEgAVNwJ\n2rlzJ1OmTKGgoICqqirdPEpEItFVJQnq378/8+bNY9y4cWjTkohEpeJOyK5du+jfvz/XXHNN0lFE\nJDAalSRg
zZo1DB06lOeeey7pKCISIBV3N3vjjTeYPn06BQUFHH/88UnHEZEAaVTSjT788EOmTZvG\nzp07WbVqFQMGDEg6kogESMXdjRYsWMAzzzxDVVUVo0aNSjqOiARKo5Ju0traSlNTE3PmzOFb3/pW\n0nFEJGBacXeTnJwcKisraW3t8FkUIiKRaMUdsz179vDNb36Tl156CTPTJhsR+dxU3DFyd6688kqq\nq6vZskWPHRORzFBxx+iuu+5i2bJlLFy4kMmTJycdR0R6CBV3TGpra5k3bx5nn302P/zhD5OOIyI9\niIo7Jj/5yU844ogjdPMoEck4XVUSk2XLlrFlyxYOP/zwpKOISA+TdsVtZgea2V/MbJ2ZrTezsu4I\nFqrq6mrefvtt8vPzGTFiRNJxRKQHijIqaQa+5u4nAqOByWY2Lt5YYVq9ejVTp05l0aJFSUcRkR4s\n7ajE3R3YlfrygNSHxxkqRNu2bWP69OkUFRVx8803Jx1HRHqwSCcnzSzXzOqARuApd//U/UjNrNTM\nasyspqmpKdM5s9pHN4/as2cP1dXV9O/fP+lIItKDRSpud29x99HAYGCsmX2pnWPK3b3Y3Yt729Nc\nFi9ezJ///GcqKioYOXJk0nFEpIfr0lUl7r7DzJ4GJgMvxhMpPFdeeSWFhYVMmzYt6Sgi0gtEuaqk\nwMwOTn3eB5gEvBx3sBBs376dlpYWBg4cyJw5c5KOIyK9RJRRySDgaTN7HvgrbTPux+KNlf12797N\npEmTmDFjRtJRRKSXiXJVyfPAmG7IEoyPbh61fv16br/99qTjiEgvo52Tn0F5eTn33nsvZWVlTJo0\nKek4ItLL6F4lXVRTU8NVV13F5MmTueGGG5KOIyK9kIq7i8yMcePGcd9995GTo98+Eel+GpVE5O6Y\nGSeffDJ/+MMfko4jIr2YlowR3XLLLcyfP5+Wlpako4hIL6fijuD3v/89Cxcu5M0339R4REQSpxZK\nY+vWrZSUlDBy5EjKy8sxs6QjiUgvp+LuxAcffMDUqVN5//33qa6upl+/fklHEhHRycnO1NXVsW7d\nOu6++26OO+64pOOIiAAq7k6NHTuWV199lYEDByYdRUTkYxqVtGPjxo1UVlYCqLRFJOtoxb2fXbt2\ncdFFF9HY2MgFF1zAoYcemnQkEZFPUHHvw92ZPXs2GzZs4Mknn1Rpi0hWUnHvY8mSJSxfvpzFixdz\n1llnJR1HRKRdmnGnbN26lWuuuYZzzz2X6667Luk4IiId0oo7ZfDgwaxcuZLTTjtNuyNFJKv1+oZq\nbW3lxRfbHp/59a9/XXNtEcl6vb64Fy9ezJgxYz4ubxGRbJd2VGJmRwPLgELAgXJ3/2ncweLS0FBF\nff31NDdvoa6ugJtuamLGjBmMGjUq6WgiIpFEmXHvBf7b3dea2QCg1syecveXYs6WcQ0NVWzcWEpr\n6x4aG+HGGxsZOtQoKztTN48SkWCkHZW4+xvuvjb1+U5gA3BU3MHiUF9/Pa2te/jwQygrg7174aab\nnIaGm5OOJiISWZeuKjGzobQ98f25dr5XCpQCDBkyJAPRMq+5eQsAOTlwyikwdSoMGfLv10VEQhD5\n5KSZ9Qeqgavd/d39v+/u5e5e7O7FBQUFmcyYMfn5bX+h5ObCpZfCGWd88nURkRBEKm4zO4C20q5y\n94fijRSf4cNvISen7ydey8npy/DhtySUSESk69IWt7WdtasANrj77fFHik9h4QyKisrJzz8GMPLz\nj6GoqJzCwhlJRxMRiSzKjPs0YCbwgpnVpV5b4O6/iy9WfAoLZ6ioRSRoaYvb3f8E6Fo5EZEs0et3\nToqIhEbFLSISGBW3iEhgVNwiIoFRcYuIBEbFLSISGBW3iEhgVNwiIoFRcYuIBEbFLSISGBW3iEhg\nVNwiIoFRcYuIBEbFLSISGBW3iEhgVNwiIoFRcYuIBEbFLSISmCgPC/6VmT
Wa2YvdEUhERDoXZcV9\nNzA55hwiIhJR2uJ29z8Cb3VDFhERiUAzbhGRwGSsuM2s1MxqzKymqakpU7+siIjsJ2PF7e7l7l7s\n7sUFBQWZ+mVFRGQ/GpWIiAQmyuWAK4BngSIz22pml8UfS0REOpKX7gB3L+mOICIiEo1GJSIigVFx\ni4gERsUtIhIYFbeISGBU3CIigVFxi4gERsUtIhIYFbeISGBU3CIigVFxi4gERsUtIhIYFbeISGBU\n3CIigVFxi4gERsUtIhIYFbeISGBU3CIigVFxi4gERsUtIhKYSMVtZpPNbKOZ/cPMro07lIiIdCzK\nU95zgTuBc4DjgRIzOz7uYCIi0r4oK+6xwD/cvd7dPwDuBy6IN5aIiHQkL8IxRwH/3OfrrcBX9j/I\nzEqB0tSXzWb24uePF6vDgTeTDhGBcmaWcmaWcmZOUdQDoxR3JO5eDpQDmFmNuxdn6teOQwgZQTkz\nTTkzSzkzx8xqoh4bZVSyDTh6n68Hp14TEZEERCnuvwIjzGyYmX0BmA48Em8sERHpSNpRibvvNbP/\nAv4XyAV+5e7r07ytPBPhYhZCRlDOTFPOzFLOzImc0dw9ziAiIpJh2jkpIhIYFbeISGAyWtwhbI03\ns1+ZWWO2X2duZkeb2dNm9pKZrTez+Ulnao+ZHWhmfzGzdamcZUln6oiZ5ZrZ38zssaSzdMTMXjOz\nF8ysriuXh3U3MzvYzFaa2ctmtsHMvpp0pv2ZWVHq9/Gjj3fN7Oqkc7XHzK5J/fy8aGYrzOzATo/P\n1Iw7tTX+78Ak2jbp/BUocfeXMvI/yBAzOx3YBSxz9y8lnacjZjYIGOTua81sAFALXJiFv58G9HP3\nXWZ2APAnYL67/1/C0T7FzL4LFAP/4e7nJZ2nPWb2GlDs7lm9WcTM7gGecfelqavN+rr7jqRzdSTV\nT9uAr7j75qTz7MvMjqLt5+Z4d3/PzH4N/M7d7+7oPZlccQexNd7d/wi8lXSOdNz9DXdfm/p8J7CB\ntl2sWcXb7Ep9eUDqI+vOeJvZYOBcYGnSWUJnZgcBpwMVAO7+QTaXdspE4NVsK+195AF9zCwP6Au8\n3tnBmSzu9rbGZ13RhMjMhgJjgOeSTdK+1AiiDmgEnnL3bMx5B/B9oDXpIGk48KSZ1aZuI5GNhgFN\nQGVq9LTUzPolHSqN6cCKpEO0x923AbcBW4A3gHfc/cnO3qOTk1nOzPoD1cDV7v5u0nna4+4t7j6a\ntl21Y80sq0ZQZnYe0OjutUlnieA/3f0k2u7GOTc12ss2ecBJwBJ3HwPsBrLynBZAapRzPvBg0lna\nY2aH0DadGAYcCfQzs0s6e08mi1tb4zMsNTOuBqrc/aGk86ST+ufy08DkpLPs5zTg/NT8+H7ga2Z2\nX7KR2pdafeHujcDDtI0gs81WYOs+/7JaSVuRZ6tzgLXu3pB0kA6cBWxy9yZ3/xB4CDi1szdksri1\nNT6DUif9KoAN7n570nk6YmYFZnZw6vM+tJ2cfjnZVJ/k7te5+2B3H0rbn8vV7t7piiYJZtYvdSKa\n1OjhbCDrrn5y9+3AP83so7vZTQSy6qT5fkrI0jFJyhZgnJn1Tf3cT6TtnFaHMnl3wM+yNb7bmdkK\n4EzgcDPbCtzo7hXJpmrXacBM4IXU/Bhggbv/LsFM7RkE3JM6a58D/Nrds/ZyuyxXCDzc9rNLHrDc\n3Z9INlKH5gFVqUVaPfCdhPO0K/UX4CRgdtJZOuLuz5nZSmAtsBf4G2m2v2vLu4hIYHRyUkQkMCpu\nEZHAqLhFRAKj4hYRCYyKW0QkMCpuEZHAqLhFRALz/y/yMmPOv3ufAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
},
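{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"`linearRegression` returns the fitted y-values rather than the line's coefficients. To see the slope and intercept NumPy estimated, the coefficients can be printed directly: `np.polyfit` orders them from the highest degree down, so a degree-1 fit returns `[slope, intercept]`. The printed values may differ from 2 and 0 by floating-point rounding."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"// Coefficients of the degree-1 NumPy fit, highest degree first: [slope, intercept]\n",
"let coefficients = np.polyfit(x, y, 1)\n",
"print(\"slope:\", coefficients[0], \"intercept:\", coefficients[1])"
],
"execution_count": 0,
"outputs": []
},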
{
"cell_type": "markdown",
"metadata": {
"id": "IcSlAmqpXFQ4",
"colab_type": "text"
},
"source": [
"## Swift for TensorFlow Linear Regression\n",
"### Convert Swift arrays to `Tensor`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "JGndhznPvOAr",
"colab_type": "code",
"colab": {}
},
"source": [
"let dataX = Tensor<Float>(x)\n",
"let dataY = Tensor<Float>(y)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "wLAFXRftvPnP",
"colab_type": "text"
},
"source": [
"### S4TF Linear Model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "pLeqd32UDGJi",
"colab_type": "code",
"colab": {}
},
"source": [
"struct LiniearModel: Layer {\n",
" var w: Tensor<Float>\n",
" var b: Tensor<Float>\n",
" \n",
" init() {\n",
" self.w = Tensor<Float>(randomUniform: [1])\n",
" self.b = Tensor<Float>(randomUniform: [1])\n",
" }\n",
"\n",
" @differentiable\n",
" func call(_ input: Tensor<Float>) -> Tensor<Float> {\n",
" return w * input + b\n",
" }\n",
"}"
],
"execution_count": 0,
"outputs": []
},
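{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Because `w` and `b` each hold a single element, `w * input + b` broadcasts them across every element of the input, so one call maps all three x-values at once. A forward pass through a freshly initialized model (the name `untrainedModel` is only illustrative) shows predictions from random parameters before any training."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"// Predictions from random, untrained parameters (values change on every run)\n",
"let untrainedModel = LiniearModel()\n",
"print(untrainedModel(dataX))"
],
"execution_count": 0,
"outputs": []
},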
{
"cell_type": "markdown",
"metadata": {
"id": "hxfk4CYxv39O",
"colab_type": "text"
},
"source": [
"### Train Function"
]
},
{
"cell_type": "code",
"metadata": {
"id": "gq2SD-H1v21s",
"colab_type": "code",
"colab": {}
},
"source": [
"func train(\n",
" x: Tensor<Float>, \n",
" y: Tensor<Float>, \n",
" model: inout LiniearModel, \n",
" optimizer: SGD<LiniearModel>,\n",
" epoch: Int\n",
") {\n",
" for _ in 0..<epoch {\n",
" let grad = gradient(at: model) { m -> Tensor<Float> in\n",
" let predicted = m(x)\n",
" return meanSquaredError(predicted: predicted, expected: y)\n",
" }\n",
" optimizer.update(&model.allDifferentiableVariables, along: grad)\n",
" } \n",
"}"
],
"execution_count": 0,
"outputs": []
},
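{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"For this one-feature model, the gradients that `gradient(at:)` computes have a simple closed form. With residuals $r_i = w x_i + b - y_i$ and loss $L = \\frac{1}{n}\\sum_i r_i^2$, they are $\\partial L/\\partial w = \\frac{2}{n}\\sum_i r_i x_i$ and $\\partial L/\\partial b = \\frac{2}{n}\\sum_i r_i$. The next cell is a hypothetical plain-Swift sketch of the same full-batch loop without automatic differentiation, assuming plain SGD with no momentum or decay. `manualGradientDescent` is illustrative only, and it starts from $w = b = 0$ rather than from random values, so its numbers will not match the S4TF run exactly."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"// Hypothetical plain-Swift equivalent of `train`: full-batch gradient descent\n",
"// on the mean squared error of the model w * x + b\n",
"func manualGradientDescent(\n",
"    x: [Float],\n",
"    y: [Float],\n",
"    initialW: Float,\n",
"    initialB: Float,\n",
"    learningRate: Float,\n",
"    epochs: Int\n",
") -> (w: Float, b: Float) {\n",
"    var w = initialW\n",
"    var b = initialB\n",
"    let n = Float(x.count)\n",
"    for _ in 0..<epochs {\n",
"        var gradW: Float = 0\n",
"        var gradB: Float = 0\n",
"        for (xi, yi) in zip(x, y) {\n",
"            let residual = w * xi + b - yi\n",
"            // Partial derivatives of (1/n) * residual^2 with respect to w and b\n",
"            gradW += 2 * residual * xi / n\n",
"            gradB += 2 * residual / n\n",
"        }\n",
"        // Plain SGD step: move each parameter against its gradient\n",
"        w -= learningRate * gradW\n",
"        b -= learningRate * gradB\n",
"    }\n",
"    return (w, b)\n",
"}\n",
"\n",
"let manual = manualGradientDescent(\n",
"    x: x, y: y, initialW: 0, initialB: 0, learningRate: 0.1, epochs: 10\n",
")\n",
"print(\"w:\", manual.w, \"b:\", manual.b)"
],
"execution_count": 0,
"outputs": []
},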
{
"cell_type": "markdown",
"metadata": {
"id": "7Ob1Qu_jvZ_P",
"colab_type": "text"
},
"source": [
"### Training Loop"
]
},
{
"cell_type": "code",
"metadata": {
"id": "yuNumQaFGFIf",
"colab_type": "code",
"colab": {}
},
"source": [
"var model = LiniearModel()\n",
"\n",
"let optimizer = SGD(\n",
" for: model, \n",
" learningRate: 0.1\n",
")\n",
"\n",
"// try different epoch values\n",
"\n",
"train(\n",
" x: dataX, \n",
" y: dataY, \n",
" model: &model, \n",
" optimizer: optimizer, \n",
" epoch: 10\n",
")"
],
"execution_count": 0,
"outputs": []
},
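{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"After training, the learned slope and intercept can be compared with the values implied by the data ($y = 2x$, so $w = 2$ and $b = 0$), together with the final mean squared error. This cell only uses calls that already appear in this notebook; the printed numbers depend on the random initialization and on the number of epochs passed to `train`."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"// Learned parameters after training\n",
"print(\"w:\", model.w, \"b:\", model.b)\n",
"\n",
"// Final training loss on the full dataset\n",
"let finalLoss = meanSquaredError(predicted: model(dataX), expected: dataY)\n",
"print(\"loss:\", finalLoss)"
],
"execution_count": 0,
"outputs": []
},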
{
"cell_type": "markdown",
"metadata": {
"id": "c_ap6q_i-Gp0",
"colab_type": "text"
},
"source": [
"### Plot the result"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Qqq1qHCvIjHX",
"colab_type": "code",
"outputId": "ed410152-35bf-45c8-dce0-8d06c1fa5cc6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 269
}
},
"source": [
"func swiftLinearRegression(\n",
" x: Tensor<Float>, \n",
" model: LiniearModel\n",
") -> PythonObject {\n",
" let result = model(x)\n",
" return result.makeNumpyArray()\n",
"}\n",
"\n",
"let swiftRegressionFunction = swiftLinearRegression(x: dataX, model: model)\n",
"plotData(x: x, y: y, fitLine: swiftRegressionFunction)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFctJREFUeJzt3X1w1dWdx/HPNyHGEGmlGlIUFZBt\nqMqI7i1gFRUQSxSi1GU3FNsK7WSlLGJXdLF20Ha0T9ZO7VRtWZ7qGGEUgS5alDriaGdYNCBdEYyD\nqVIUk1i1PoAgyXf/yLWDmHB/0Xtz7kner5mM9+F3r59huB9Ozvmd+zN3FwAgHgWhAwAAOofiBoDI\nUNwAEBmKGwAiQ3EDQGQobgCITKLiNrPvmtlzZrbVzJaZ2ZG5DgYAaF/G4jaz4yVdJSnl7qdJKpRU\nnetgAID2JZ0q6SWpxMx6Seot6dXcRQIAHE6vTAe4+ytm9nNJOyXtlbTO3dcdepyZ1UiqkaTS0tJ/\nHjp0aLazAkC3tWnTptfdvSzJsZZpy7uZ9ZX0gKR/k/SWpPslrXD3ezp6TSqV8rq6uuSJAaCHM7NN\n7p5KcmySqZILJP3F3Zvd/QNJKyV9+dMEBAB8ckmKe6ekUWbW28xM0jhJ23MbCwDQkYzF7e4bJa2Q\ntFnSs+nXLMhxLgBABzIuTkqSu98o6cYcZwEAJMDOSQCIDMUNAJGhuAEgMhQ3AESG4gaAyFDcABAZ\nihsAIkNxA0BkKG4AiAzFDQCRobgBIDIUNwBEhuIGgMhQ3AAQGYobACJDcQNAZChuAIgMxQ0AkclY\n3GZWYWZbDvp528yu7opwAICPy3jNSXevlzRcksysUNIrklblOBcAoAOdnSoZJ+lFd385F2EAAJl1\ntrirJS3LRRAAQDKJi9vMjpBUJen+Dp6vMbM6M6trbm7OVj4AwCE6M+KulLTZ3Rvbe9LdF7h7yt1T\nZWVl2UkHAPiYzhT3VDFNAgDBJSpuMyuVNF7SytzGAQBkkvF0QEly9/ckHZPjLACABNg5CQCRobgB\nIDIUNwBEhuIGgMhQ3AAQGYobACJDcQNAZChuAIgMxQ0AkaG4ASAyFDcARIbiBoDIUNwAEBmKGwAi\nQ3EDQGQobgCIDMUNAJGhuAEgMhQ3AEQm6cWCjzazFWb2vJltN7Ozch0McWhsrNWGDQP1+OMF2rBh\noBoba0NHArq9RBcLlnS7pIfd/V/M7AhJvXOYCZFobKxVfX2NWlv3SJL27XtZ9fU1kqTy8mkhowHd\nWsYRt5l9VtK5khZJkrvvd/e3ch0M+a+h4YZ/lPaHWlv3qKHhhkCJgJ4hyVTJIEnNkpaY2TNmttDM\nSg89yMxqzKzOzOqam5uzHhT5Z9++nZ16HEB2JCnuXpLOlHSXu58h6T1J8w49yN0XuHvK3VNlZWVZ\njol8VFx8YqceB5AdSYp7l6Rd7r4xfX+F2oocPdzgwbeooOCjyx0FBb01ePAtgRIBPUPG4nb31yT9\n1cwq0g+Nk7Qtp6kQhfLyaaqoWKCiohMlmYqLT1JFxQIWJoEcS3pWyWxJtekzShokTc9dJMSkvHya\nPve5f9WBAwdUUlISOg7QIyQqbnffIimV4yyIzJ49e1RYWKji4mIVFRWFjgP0GOycxCfS0tKi6upq\nVVZWqrW1NXQcoEehuNFp7q6rrrpKa9as0WWXXaaCAv4aAV2JTxw67bbbbtOdd96puXPnatasWaHj\nAD0OxY1OWbFiha699lpNmTJFP/3pT0PHAXokihud8sUvflFTp07V3XffzRQJEAifPCTyt7/9Te6u\nU089Vffee6+OPPLI0JGAHoviRkZNTU0aMWKEbriBL48C8gHFjcPas2ePJk2apN27d+uSSy4JHQeA\nku+cRA/U0tKiadOm6emnn9aqVas0cuTI0JEA
iOLGYcydO1erV6/W7bffzmgbyCMUNzo0duxYlZSU\n6KqrrgodBcBBzN2z/qapVMrr6uqy/r7oGq+//rqOPfbY0DGAHsXMNrl7ou+EYnESH7FhwwYNGjRI\nq1atCh0FQAcobvzDjh07VFVVpc9//vMaPXp06DgAOkBxQ1Lb9EhlZaUkae3atUyVAHmMxUlo//79\nqqqq0q5du/TYY49pyJAhoSMBOAxG3FBRUZEmTZqke+65R2eddVboOAAyYMTdwzU1Nalfv366/vrr\nQ0cBkFCiEbeZvWRmz5rZFjPjPL9u4le/+pUqKipUX18fOgqATujMiHuMu7+esyToUqtXr9bVV1+t\nSy65hDltIDLMcfdATz31lL72ta/pS1/6kmpra1VYWBg6EoBOSFrcLmmdmW0ys5r2DjCzGjOrM7O6\n5ubm7CVEVr388suaOHGi+vfvrzVr1qh3796hIwHopKTFfY67nympUtIsMzv30APcfYG7p9w9VVZW\nltWQyJ7y8nJNnjxZa9euVb9+/ULHAfAJJJrjdvdX0v9tMrNVkkZIeiKXwZBd77//vvbu3au+ffvq\nt7/9beg4AD6FjCNuMys1sz4f3pZ0oaStuQ6G7GltbdU3vvENnXPOOXr//fdDxwHwKSUZcZdLWmVm\nHx5/r7s/nNNUyKp58+bp/vvv16233sq1IoFuIGNxu3uDpNO7IAty4M4779Stt96qWbNm6Zprrgkd\nB0AWcDpgN7Z27VrNnj1bkyZN0u233670b00AIkdxd2Onn366rrjiCi1btoxztYFuhOLuhhobG3Xg\nwAEdd9xxWrRokUpLS0NHApBFFHc388Ybb+j888/XjBkzQkcBkCMUdzeyb98+TZ48WQ0NDfr2t78d\nOg6AHOFrXbuJ1tZWXXHFFXriiSe0bNkynXvuxza3AugmGHF3EzfeeKOWL1+un/zkJ6qurg4dB0AO\nMeLuJiZNmqSWlhZdd911oaMAyDGKO3K7du3SgAEDNGLECI0YMSJ0HABdgKmSiG3atElDhw7VXXfd\nFToKgC5EcUfqw+/VPuaYY3TppZeGjgOgCzFVEqE333xTlZWV2rt3rx599FH1798/dCQAXYjijkxr\na6suu+wy7dixQ4888ohOPfXU0JEAdDGmSiJTUFCgGTNmaMmSJRozZkzoOAACYMQdkZ07d+rEE0/U\n5ZdfHjoKgIAYcUdi4cKF+sIXvqCNGzeGjgIgMIo7Ag8//LCuvPJKjRkzRmeeeWboOAACo7jz3JYt\nWzRlyhQNGzZM9913n4qKikJHAhBY4uI2s0Ize8bMHsxlILRpbKzVQw+doPHjz1Bp6V4tXfot9enT\nJ3QsAHmgM4uTcyRtl/SZHGVBWmNjrerra1RSskeVldKYMS36+9//S42NfVVePi10PACBJRpxm9kA\nSRdLWpjbOJCk+vrvqbFxjwoKpBkzpEGDpNbWPWpouCF0NAB5IOlUyS8lXSeptaMDzKzGzOrMrK65\nuTkr4Xoid9ctt+zUzJnSu+9+9Ll9+3aGCQUgr2QsbjObKKnJ3Tcd7jh3X+DuKXdPlZWVZS1gT3PT\nTTdp3Tpp0iTpqKM++lxx8YlhQgHIK0lG3GdLqjKzlyQtlzTWzO7JaaoeasmSJfrhD3+o6upzdcUV\nJR95rqCgtwYPviVQMgD5JGNxu/v17j7A3QdKqpb0mLuzdS/LnnzySdXU1Gj8+PG6++5HNXTof6u4\n+CRJpuLik1RRsYCFSQCS2PKeN4YPH66ZM2fq5ptvVlFRkcrLp1HUANpl7p71N02lUl5XV5f19+2O\nXnvtNfXp00elpaWhowAIyMw2uXsqybHsnAzo7bff1le+8hVVVVUpF/+AAuiemCoJ5IMPPtCUKVO0\nbds2PfTQQzKz0JEARILiDsDddeWVV2rdunVavHixLrzwwtCRAESEqZIAbrvtNi1evFjz58/X9OnT\nQ8cBEBlG
3AFMnjxZb731lm666abQUQBEiBF3F9qxY4fcXSeffLJuvvlm5rUBfCIUdxfZunWrUqmU\n5s+fHzoKgMhR3F3g1Vdf1UUXXaTS0lLV1NSEjgMgcsxx59g777yjiy++WG+++aaefPJJnXDCCaEj\nAYgcxZ1jl19+uZ599lk9+OCDGj58eOg4ALoBijvH5syZo69+9auaMGFC6CgAugmKO0eef/55DR06\nVGPHjg0dBUA3w+JkDtTW1uqUU07RQw89FDoKgG6I4s6y9evXa/r06TrvvPN0wQUXhI4DoBuiuLNo\n27Ztmjx5soYMGaKVK1equLg4dCQA3RDFnSXvvPOOLrroIpWUlGjt2rXq27dv6EgAuikWJ7OkT58+\nuvbaazVq1CiddNJJoeMA6MYo7k/pwIEDevHFF1VRUaFZs2aFjgOgB8g4VWJmR5rZU2b2ZzN7zsx+\n0BXBYuDumj17tlKplHbt2hU6DoAeIskc9z5JY939dEnDJU0ws1G5jRWHn/3sZ/rNb36jWbNmacCA\nAaHjAOghMk6VeNvFEN9N3y1K//T4CyQuX75c8+bNU3V1tX70ox+FjgOgB0l0VomZFZrZFklNkv7o\n7hvbOabGzOrMrK65uTnbOfPKM888o29+85saPXq0li5dqoICTs4B0HUSNY67t7j7cEkDJI0ws9Pa\nOWaBu6fcPVVWVpbtnHnltNNO07x587R69WrO1QbQ5Tp1Vom7v2Vm6yVNkLQ1N5HyV2Njo8xM/fr1\n0w9+wBotgDCSnFVSZmZHp2+XSBov6flcB8s37733niZOnKhx48appaUldBwAPViSEXd/Sb8zs0K1\nFf197v5gbmPll5aWFk2dOlWbN2/WqlWrVFhYGDoSgB4syVkl/yfpjC7IkpfcXXPmzNGaNWv061//\nWlVVVaEjAejhOB0ig4ULF+qOO+7Q3Llz2RkJIC+w5T2DyZMna/fu3fr+978fOgoASGLE3aFt27Zp\n//79OvbYYzV//nzO1QaQN2ijdrzwwgsaPXq0vvOd74SOAgAfQ3EfoqmpSZWVlSosLNT1118fOg4A\nfAxz3AfZs2ePqqqqtHv3bq1fv14nn3xy6EgA8DEU90Fmzpypp556SitXrtTIkSNDxwGAdlHcB5k7\nd67OO+88XXrppaGjAECHKG5Jmzdv1hlnnKFhw4Zp2LBhoeMAwGH1+MXJlStXKpVKaenSpaGjAEAi\nPbq4N2zYoGnTpmnUqFGqrq4OHQcAEumxxb1jxw5VVVVpwIAB+v3vf6+SkpLQkQAgkR5Z3Pv379fE\niRPl7lq7dq26+4UfAHQvPXJx8ogjjtCPf/xjlZeXa8iQIaHjAECn9Mjiltq+PAoAYtQjp0oAIGYU\nNwBEhuIGgMgkuVjwCWa23sy2mdlzZjanK4LlSmNjrTZsGKjHHy/Qhg0D1dhYGzoSAHRKksXJA5Ku\ncffNZtZH0iYz+6O7b8txtqxrbKxVfX2NWlv3SJL27XtZ9fU1kqTy8mkhowFAYhlH3O6+2903p2+/\nI2m7pONzHSwXGhpu+Edpf6i1dY8aGm4IlAgAOq9Tc9xmNlBtV3zf2M5zNWZWZ2Z1zc3N2UmXZfv2\n7ezU4wCQjxIXt5kdJekBSVe7+9uHPu/uC9w95e6pfN2JWFx8YqceB4B8lKi4zaxIbaVd6+4rcxsp\ndwYPvkUFBb0/8lhBQW8NHnxLoEQA0HlJzioxSYskbXf3X+Q+Uu6Ul09TRcUCFRefJMlUXHySKioW\nsDAJICpJzio5W9LXJT1rZlvSj33P3f+Qu1i5U14+jaIGELWMxe3uf5JkXZAFAJAAOycBIDIUNwBE\nhuIGgMhQ3AAQGYobACJDcQNAZChuAIgMxQ0AkaG4ASAyFDcARIbiBoDIUNwAEBmKGwAiQ3EDQGQo\nbgCIDMUNAJGhuAEgMhQ3AEQmycWCF5tZk5lt7YpAAIDDSzLiXippQo5zAA
ASyljc7v6EpDe6IAsA\nIAHmuAEgMlkrbjOrMbM6M6trbm7O1tsCAA6RteJ29wXunnL3VFlZWbbeFgBwCKZKACAySU4HXCZp\ng6QKM9tlZt/KfSwAQEd6ZTrA3ad2RRAAQDJMlQBAZChuAIgMxQ0AkaG4ASAyFDcARIbiBoDIUNwA\nEBmKGwAiQ3EDQGQobgCIDMUNAJGhuAEgMhQ3AESG4gaAyFDcABAZihsAIkNxA0BkKG4AiAzFDQCR\nSVTcZjbBzOrNbIeZzct1KABAx5Jc5b1Q0h2SKiWdImmqmZ2S62AAgPYlGXGPkLTD3Rvcfb+k5ZIu\nyW0sAEBHeiU45nhJfz3o/i5JIw89yMxqJNWk7+4zs62fPl5OHSvp9dAhEiBndpEzu8iZPRVJD0xS\n3Im4+wJJCyTJzOrcPZWt986FGDJK5Mw2cmYXObPHzOqSHptkquQVSSccdH9A+jEAQABJivtpSf9k\nZoPM7AhJ1ZL+J7exAAAdyThV4u4HzOw/JD0iqVDSYnd/LsPLFmQjXI7FkFEiZ7aRM7vImT2JM5q7\n5zIIACDL2DkJAJGhuAEgMlkt7hi2xpvZYjNryvfzzM3sBDNbb2bbzOw5M5sTOlN7zOxIM3vKzP6c\nzvmD0Jk6YmaFZvaMmT0YOktHzOwlM3vWzLZ05vSwrmZmR5vZCjN73sy2m9lZoTMdyswq0n+OH/68\nbWZXh87VHjP7bvrzs9XMlpnZkYc9Pltz3Omt8S9IGq+2TTpPS5rq7tuy8j/IEjM7V9K7ku5299NC\n5+mImfWX1N/dN5tZH0mbJF2ah3+eJqnU3d81syJJf5I0x93/N3C0jzGz/5SUkvQZd58YOk97zOwl\nSSl3z+vNImb2O0lPuvvC9Nlmvd39rdC5OpLup1ckjXT3l0PnOZiZHa+2z80p7r7XzO6T9Ad3X9rR\na7I54o5ia7y7PyHpjdA5MnH33e6+OX37HUnb1baLNa94m3fTd4vSP3m34m1mAyRdLGlh6CyxM7PP\nSjpX0iJJcvf9+VzaaeMkvZhvpX2QXpJKzKyXpN6SXj3cwdks7va2xudd0cTIzAZKOkPSxrBJ2pee\ngtgiqUnSH909H3P+UtJ1klpDB8nAJa0zs03pr5HIR4MkNUtakp56WmhmpaFDZVAtaVnoEO1x91ck\n/VzSTkm7Jf3d3dcd7jUsTuY5MztK0gOSrnb3t0PnaY+7t7j7cLXtqh1hZnk1BWVmEyU1ufum0FkS\nOMfdz1Tbt3HOSk/t5Zteks6UdJe7nyHpPUl5uaYlSempnCpJ94fO0h4z66u22YlBko6TVGpmlx/u\nNdksbrbGZ1l6zvgBSbXuvjJ0nkzSvy6vlzQhdJZDnC2pKj1/vFzSWDO7J2yk9qVHX3L3Jkmr1DYF\nmW92Sdp10G9WK9RW5PmqUtJmd28MHaQDF0j6i7s3u/sHklZK+vLhXpDN4mZrfBalF/0WSdru7r8I\nnacjZlZmZkenb5eobXH6+bCpPsrdr3f3Ae4+UG1/Lx9z98OOaEIws9L0QrTSUw8XSsq7s5/c/TVJ\nfzWzD7/NbpykvFo0P8RU5ek0SdpOSaPMrHf6cz9ObWtaHcrmtwN+kq3xXc7Mlkk6X9KxZrZL0o3u\nvihsqnadLenrkp5Nzx9L0vfc/Q8BM7Wnv6TfpVftCyTd5+55e7pdniuXtKrts6teku5194fDRurQ\nbEm16UFag6TpgfO0K/0P4HhJ/x46S0fcfaOZrZC0WdIBSc8ow/Z3trwDQGRYnASAyFDcABAZihsA\nIkNxA0BkKG4AiAzFDQCRobgBIDL/DyzH9iS46ug+AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
},
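{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Both fitted lines are NumPy arrays at this point, so the gap between the NumPy fit and the S4TF fit can be measured directly. A gap that is large relative to the data suggests training has not converged yet; increasing the number of epochs in the training cell is a natural first thing to try. `maxGap` is an illustrative name."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"// Largest absolute difference between the NumPy fit and the S4TF fit\n",
"let maxGap = np.max(np.abs(np.subtract(regressionFunction, swiftRegressionFunction)))\n",
"print(\"max |numpy - s4tf|:\", maxGap)"
],
"execution_count": 0,
"outputs": []
},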
{
"cell_type": "code",
"metadata": {
"id": "fw9W7yzi-JPQ",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}