Skip to content

Instantly share code, notes, and snippets.

@vigsterkr
Created February 1, 2020 09:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vigsterkr/48a5f0523528bbab85bec04464ca2b6c to your computer and use it in GitHub Desktop.
Save vigsterkr/48a5f0523528bbab85bec04464ca2b6c to your computer and use it in GitHub Desktop.
ShogunML with SciRuby stack
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"false"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"require 'daru'"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<b> Daru::DataFrame(150x5) </b>\n",
"<table>\n",
" <thead>\n",
" \n",
" <tr>\n",
" <th></th>\n",
" \n",
" <th>sepal_length</th>\n",
" \n",
" <th>sepal_width</th>\n",
" \n",
" <th>petal_length</th>\n",
" \n",
" <th>petal_width</th>\n",
" \n",
" <th>species</th>\n",
" \n",
" </tr>\n",
" \n",
"</thead>\n",
" <tbody>\n",
" \n",
" <tr>\n",
" <td>0</td>\n",
" \n",
" <td>5.1</td>\n",
" \n",
" <td>3.5</td>\n",
" \n",
" <td>1.4</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>1</td>\n",
" \n",
" <td>4.9</td>\n",
" \n",
" <td>3.0</td>\n",
" \n",
" <td>1.4</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>2</td>\n",
" \n",
" <td>4.7</td>\n",
" \n",
" <td>3.2</td>\n",
" \n",
" <td>1.3</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>3</td>\n",
" \n",
" <td>4.6</td>\n",
" \n",
" <td>3.1</td>\n",
" \n",
" <td>1.5</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>4</td>\n",
" \n",
" <td>5.0</td>\n",
" \n",
" <td>3.6</td>\n",
" \n",
" <td>1.4</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>5</td>\n",
" \n",
" <td>5.4</td>\n",
" \n",
" <td>3.9</td>\n",
" \n",
" <td>1.7</td>\n",
" \n",
" <td>0.4</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>6</td>\n",
" \n",
" <td>4.6</td>\n",
" \n",
" <td>3.4</td>\n",
" \n",
" <td>1.4</td>\n",
" \n",
" <td>0.3</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>7</td>\n",
" \n",
" <td>5.0</td>\n",
" \n",
" <td>3.4</td>\n",
" \n",
" <td>1.5</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>8</td>\n",
" \n",
" <td>4.4</td>\n",
" \n",
" <td>2.9</td>\n",
" \n",
" <td>1.4</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>9</td>\n",
" \n",
" <td>4.9</td>\n",
" \n",
" <td>3.1</td>\n",
" \n",
" <td>1.5</td>\n",
" \n",
" <td>0.1</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>10</td>\n",
" \n",
" <td>5.4</td>\n",
" \n",
" <td>3.7</td>\n",
" \n",
" <td>1.5</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>11</td>\n",
" \n",
" <td>4.8</td>\n",
" \n",
" <td>3.4</td>\n",
" \n",
" <td>1.6</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>12</td>\n",
" \n",
" <td>4.8</td>\n",
" \n",
" <td>3.0</td>\n",
" \n",
" <td>1.4</td>\n",
" \n",
" <td>0.1</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>13</td>\n",
" \n",
" <td>4.3</td>\n",
" \n",
" <td>3.0</td>\n",
" \n",
" <td>1.1</td>\n",
" \n",
" <td>0.1</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>14</td>\n",
" \n",
" <td>5.8</td>\n",
" \n",
" <td>4.0</td>\n",
" \n",
" <td>1.2</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>15</td>\n",
" \n",
" <td>5.7</td>\n",
" \n",
" <td>4.4</td>\n",
" \n",
" <td>1.5</td>\n",
" \n",
" <td>0.4</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>16</td>\n",
" \n",
" <td>5.4</td>\n",
" \n",
" <td>3.9</td>\n",
" \n",
" <td>1.3</td>\n",
" \n",
" <td>0.4</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>17</td>\n",
" \n",
" <td>5.1</td>\n",
" \n",
" <td>3.5</td>\n",
" \n",
" <td>1.4</td>\n",
" \n",
" <td>0.3</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>18</td>\n",
" \n",
" <td>5.7</td>\n",
" \n",
" <td>3.8</td>\n",
" \n",
" <td>1.7</td>\n",
" \n",
" <td>0.3</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>19</td>\n",
" \n",
" <td>5.1</td>\n",
" \n",
" <td>3.8</td>\n",
" \n",
" <td>1.5</td>\n",
" \n",
" <td>0.3</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>20</td>\n",
" \n",
" <td>5.4</td>\n",
" \n",
" <td>3.4</td>\n",
" \n",
" <td>1.7</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>21</td>\n",
" \n",
" <td>5.1</td>\n",
" \n",
" <td>3.7</td>\n",
" \n",
" <td>1.5</td>\n",
" \n",
" <td>0.4</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>22</td>\n",
" \n",
" <td>4.6</td>\n",
" \n",
" <td>3.6</td>\n",
" \n",
" <td>1.0</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>23</td>\n",
" \n",
" <td>5.1</td>\n",
" \n",
" <td>3.3</td>\n",
" \n",
" <td>1.7</td>\n",
" \n",
" <td>0.5</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>24</td>\n",
" \n",
" <td>4.8</td>\n",
" \n",
" <td>3.4</td>\n",
" \n",
" <td>1.9</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>25</td>\n",
" \n",
" <td>5.0</td>\n",
" \n",
" <td>3.0</td>\n",
" \n",
" <td>1.6</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>26</td>\n",
" \n",
" <td>5.0</td>\n",
" \n",
" <td>3.4</td>\n",
" \n",
" <td>1.6</td>\n",
" \n",
" <td>0.4</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>27</td>\n",
" \n",
" <td>5.2</td>\n",
" \n",
" <td>3.5</td>\n",
" \n",
" <td>1.5</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>28</td>\n",
" \n",
" <td>5.2</td>\n",
" \n",
" <td>3.4</td>\n",
" \n",
" <td>1.4</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
" <tr>\n",
" <td>29</td>\n",
" \n",
" <td>4.7</td>\n",
" \n",
" <td>3.2</td>\n",
" \n",
" <td>1.6</td>\n",
" \n",
" <td>0.2</td>\n",
" \n",
" <td>setosa</td>\n",
" \n",
" </tr>\n",
" \n",
"\n",
" \n",
" <tr>\n",
" \n",
" <td>...</td>\n",
" \n",
" <td>...</td>\n",
" \n",
" <td>...</td>\n",
" \n",
" <td>...</td>\n",
" \n",
" <td>...</td>\n",
" \n",
" <td>...</td>\n",
" \n",
" </tr>\n",
"\n",
" \n",
"\n",
" <tr>\n",
" <td>149</td>\n",
" \n",
" <td>5.9</td>\n",
" \n",
" <td>3.0</td>\n",
" \n",
" <td>5.1</td>\n",
" \n",
" <td>1.8</td>\n",
" \n",
" <td>virginica</td>\n",
" \n",
" </tr>\n",
" \n",
"</tbody>\n",
"</table>"
],
"text/plain": [
"#<Daru::DataFrame(150x5)>\n",
" sepal_leng sepal_widt petal_leng petal_widt species\n",
" 0 5.1 3.5 1.4 0.2 setosa\n",
" 1 4.9 3.0 1.4 0.2 setosa\n",
" 2 4.7 3.2 1.3 0.2 setosa\n",
" 3 4.6 3.1 1.5 0.2 setosa\n",
" 4 5.0 3.6 1.4 0.2 setosa\n",
" 5 5.4 3.9 1.7 0.4 setosa\n",
" 6 4.6 3.4 1.4 0.3 setosa\n",
" 7 5.0 3.4 1.5 0.2 setosa\n",
" 8 4.4 2.9 1.4 0.2 setosa\n",
" 9 4.9 3.1 1.5 0.1 setosa\n",
" 10 5.4 3.7 1.5 0.2 setosa\n",
" 11 4.8 3.4 1.6 0.2 setosa\n",
" 12 4.8 3.0 1.4 0.1 setosa\n",
" 13 4.3 3.0 1.1 0.1 setosa\n",
" 14 5.8 4.0 1.2 0.2 setosa\n",
" ... ... ... ... ... ..."
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = Daru::DataFrame.from_csv \"https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/639388c2cbc2120a14dcf466e85730eb8be498bb/iris.csv\"\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Shogun ML\n",
"\n",
"You need to compile Shogun and its Ruby interface first:\n",
"```\n",
"git clone https://github.com/shogun-toolbox/shogun.git\n",
"cd shogun\n",
"mkdir build\n",
"cd build\n",
"cmake -G\"Ninja\" -DINTERFACE_RUBY=ON ..\n",
"ninja\n",
"```\n",
"\n",
"Once you've built it, either install the generated binaries with `ninja install` or simply set the `RUBYLIB` environment variable before you start the Jupyter notebook. For example, while still in the `build` directory, run the following command:\n",
"```\n",
"export RUBYLIB=$PWD/src/interfaces/ruby:$RUBYLIB\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"false"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"require 'shogun'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Prepare the data for the ShogunML model: `X` contains the features and `y` contains the labels."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/lib/gems/2.5.0/gems/nmatrix-0.2.4/lib/nmatrix/monkeys.rb:49: warning: constant ::Fixnum is deprecated\n",
"<main>: warning: already initialized constant X\n"
]
},
{
"data": {
"text/plain": [
"#<Shogun::Labels:0x0000564db87f5178 @__swigtype__=\"_p_std__shared_ptrT_shogun__Labels_t\">"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = Shogun::features(df['sepal_length','sepal_width', 'petal_length', 'petal_width'].to_nmatrix.transpose)\n",
"y = Shogun::labels(df.species.to_category.to_ints.to_ary)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a multiclass linear classifier using `MulticlassLibLinear`, Shogun's wrapper around LibLinear's native multiclass (Crammer-Singer) SVM solver"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"classifier = Shogun::machine(\"MulticlassLibLinear\")\n",
"classifier.put(\"C\", 1.0)\n",
"classifier.put(\"labels\", y)\n",
"classifier.put(\"use_bias\", true)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the model using the `X` features"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"#<Shogun::MulticlassLabels:0x0000564db8822ab0 @__swigtype__=\"_p_std__shared_ptrT_shogun__MulticlassLabels_t\">"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"classifier.train(X)\n",
"y_pred = classifier.apply_multiclass(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Measure the model's performance on the training data (note: please create a train/test split to measure the real performance of your model!)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.98"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eval = Shogun::evaluation(\"MulticlassAccuracy\")\n",
"accuracy = eval.evaluate(y, y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"binary_classifier = Shogun::machine(\"LibLinear\")\n",
"strategy = Shogun::multiclass_strategy(\"MulticlassOneVsRestStrategy\")\n",
"mc_classifier = Shogun::machine(\"LinearMulticlassMachine\")\n",
"mc_classifier.put(\"multiclass_strategy\", strategy)\n",
"mc_classifier.put(\"machine\", binary_classifier)\n",
"mc_classifier.put(\"labels\", y)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"true"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mc_classifier.train(X)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.94"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_mc_pred = mc_classifier.apply_multiclass(X)\n",
"accuracy = eval.evaluate(y, y_mc_pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Ruby 2.5.5",
"language": "ruby",
"name": "ruby"
},
"language_info": {
"file_extension": ".rb",
"mimetype": "application/x-ruby",
"name": "ruby",
"version": "2.5.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment