Last active
December 21, 2020 01:59
-
-
Save RobotOptimist/1b7d5fd7bd386e03cb83335a043176e9 to your computer and use it in GitHub Desktop.
ML.NET polynomial regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Binder image that serves the ML.NET polynomial regression notebook and its data file. | |
FROM jmacivor/dotnet-binder:0.1.1 | |
ARG NB_USER=jovyan | |
ARG NB_UID=1000 | |
USER $NB_USER | |
ENV HOME=/home/$NB_USER | |
WORKDIR $HOME | |
# COPY creates files owned by root regardless of USER, so assign them to the notebook user to keep them writable. | |
COPY --chown=$NB_UID mlnet_polynomial_regression.ipynb $HOME/mlnet_polynomial_regression.ipynb | |
COPY --chown=$NB_UID Position_Salaries.csv $HOME/Position_Salaries.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Polynomial Regression" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Using Statements, Classes and Loading the Data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"#r \"nuget:Microsoft.ML,1.5.0\"\n", | |
"#r \"nuget:Microsoft.ML.FastTree,1.5.0\"\n", | |
"#r \"nuget:Microsoft.ML.Mkl.Components,1.5.0\"\n", | |
"//Install XPlot package\n", | |
"#r \"nuget:XPlot.Plotly,2.0.0\" \n", | |
"\n", | |
"using System;\n", | |
"using System.Collections.Generic;\n", | |
"using System.Linq;\n", | |
"using System.Composition;\n", | |
"using Microsoft.ML;\n", | |
"using Microsoft.ML.Data;\n", | |
"using Microsoft.ML.Trainers;\n", | |
"using Microsoft.ML.Transforms;\n", | |
"using XPlot.Plotly;" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Installed package XPlot.Plotly version 2.0.0" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Installed package Microsoft.ML.FastTree version 1.5.0" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Installed package Microsoft.ML version 1.5.0" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Installed package Microsoft.ML.Mkl.Components version 1.5.0" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"/// <summary>One row of Position_Salaries.csv, bound to text columns by index.</summary>\n", | |
"public class ModelInput\n", | |
"{\n", | |
"    /// <summary>Job title (text column 0).</summary>\n", | |
"    [LoadColumn(0)]\n", | |
"    [ColumnName(\"Position\")]\n", | |
"    public string Position { get; set; }\n", | |
"\n", | |
"    /// <summary>Seniority level (text column 1).</summary>\n", | |
"    [LoadColumn(1)]\n", | |
"    [ColumnName(\"Level\")]\n", | |
"    public float Level { get; set; }\n", | |
"\n", | |
"    /// <summary>Salary (text column 2).</summary>\n", | |
"    [LoadColumn(2)]\n", | |
"    [ColumnName(\"Salary\")]\n", | |
"    public float Salary { get; set; }\n", | |
"}\n", | |
"\n", | |
"/// <summary>Prediction row; Score holds the predicted value.</summary>\n", | |
"public class ModelOutput\n", | |
"{\n", | |
"    public float Score { get; set; }\n", | |
"}" | |
], | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"// ML.NET context shared by the cells that follow.\n", | |
"var mlContext = new MLContext();\n", | |
"\n", | |
"// Load the comma-separated salary table (first row is a header) as ModelInput rows.\n", | |
"var trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(\n", | |
"    \"./Position_Salaries.csv\",\n", | |
"    separatorChar: ',',\n", | |
"    hasHeader: true,\n", | |
"    allowQuoting: true,\n", | |
"    allowSparse: false);" | |
], | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Create the Pipeline" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Create the transform classes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"/// <summary>Output row of the polynomial feature mapping.</summary>\n", | |
"public class TransformOutput\n", | |
"{\n", | |
"    /// <summary>Polynomial terms of Level, declared to ML.NET as a vector of size 5.</summary>\n", | |
"    [VectorType(5)]\n", | |
"    public float[] Features { get; set; }\n", | |
"\n", | |
"    /// <summary>Salary carried over from the input row.</summary>\n", | |
"    public float Salary { get; set; }\n", | |
"}" | |
], | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Create the custom transform" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"/// <summary>\n", | |
"/// Custom ML.NET mapping that expands Level into the powers Level^0 .. Level^degree\n", | |
"/// and copies Salary through unchanged.\n", | |
"/// </summary>\n", | |
"/// <remarks>\n", | |
"/// The contract name below must equal the contractName passed to Transforms.CustomMapping;\n", | |
"/// ML.NET uses it to find this factory again when a saved model is loaded.\n", | |
"/// </remarks>\n", | |
"[CustomMappingFactoryAttribute(\"PolynomialFeatures\")]\n", | |
"public class PolynomialFeatures : CustomMappingFactory<ModelInput, TransformOutput>\n", | |
"{\n", | |
"    private readonly int _degree;\n", | |
"\n", | |
"    /// <summary>\n", | |
"    /// Creates the degree-4 mapping, whose 5 terms match the [VectorType(5)] on TransformOutput.Features.\n", | |
"    /// A parameterless constructor is needed so the factory can be instantiated by contract name.\n", | |
"    /// </summary>\n", | |
"    public PolynomialFeatures() : this(4)\n", | |
"    {\n", | |
"    }\n", | |
"\n", | |
"    /// <summary>Creates a mapping that emits degree + 1 polynomial terms.</summary>\n", | |
"    /// <param name=\"degree\">Highest power of Level to generate; must not be negative.</param>\n", | |
"    /// <exception cref=\"ArgumentOutOfRangeException\">Thrown when degree is negative.</exception>\n", | |
"    public PolynomialFeatures(int degree)\n", | |
"    {\n", | |
"        if (degree < 0)\n", | |
"        {\n", | |
"            throw new ArgumentOutOfRangeException(nameof(degree), degree, \"Degree must not be negative.\");\n", | |
"        }\n", | |
"\n", | |
"        _degree = degree;\n", | |
"    }\n", | |
"\n", | |
"    /// <summary>Fills output.Features with [Level^0, Level^1, ..., Level^degree] and copies Salary.</summary>\n", | |
"    public void Transform(ModelInput input, TransformOutput output)\n", | |
"    {\n", | |
"        output.Features = Enumerable.Range(0, _degree + 1).Select(i => (float)(Math.Pow(input.Level, i))).ToArray();\n", | |
"        output.Salary = input.Salary;\n", | |
"    }\n", | |
"\n", | |
"    /// <summary>Returns the row-level action that ML.NET invokes for every input row.</summary>\n", | |
"    public override Action<ModelInput, TransformOutput> GetMapping()\n", | |
"    {\n", | |
"        return Transform;\n", | |
"    }\n", | |
"}\n", | |
"\n", | |
"// Smoke test: run the degree-2 mapping directly over levels 1..10 and show the result.\n", | |
"var polyFeaturesTest = new PolynomialFeatures(2);\n", | |
"var testInputs = Enumerable.Range(1, 10).Select(i => new ModelInput() {Level = i});\n", | |
"var testOutputs = testInputs.Select(ti => \n", | |
"{\n", | |
"    var testOutput = new TransformOutput();\n", | |
"    polyFeaturesTest.Transform(ti, testOutput);\n", | |
"    return testOutput;\n", | |
"});\n", | |
"display(testOutputs)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": "<table><thead><tr><th><i>index</i></th><th>Features</th><th>Salary</th></tr></thead><tbody><tr><td>0</td><td><div class=\"dni-plaintext\">[ 1, 1, 1 ]</div></td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td>1</td><td><div class=\"dni-plaintext\">[ 1, 2, 4 ]</div></td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td>2</td><td><div class=\"dni-plaintext\">[ 1, 3, 9 ]</div></td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td>3</td><td><div class=\"dni-plaintext\">[ 1, 4, 16 ]</div></td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td>4</td><td><div class=\"dni-plaintext\">[ 1, 5, 25 ]</div></td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td>5</td><td><div class=\"dni-plaintext\">[ 1, 6, 36 ]</div></td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td>6</td><td><div class=\"dni-plaintext\">[ 1, 7, 49 ]</div></td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td>7</td><td><div class=\"dni-plaintext\">[ 1, 8, 64 ]</div></td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td>8</td><td><div class=\"dni-plaintext\">[ 1, 9, 81 ]</div></td><td><div class=\"dni-plaintext\">0</div></td></tr><tr><td>9</td><td><div class=\"dni-plaintext\">[ 1, 10, 100 ]</div></td><td><div class=\"dni-plaintext\">0</div></td></tr></tbody></table>" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Create the pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"// Data process configuration with pipeline data transformations\n", | |
"// Expand Level into polynomial terms up to degree 4 (five features, matching TransformOutput).\n", | |
"var polyFeatures = new PolynomialFeatures(4);\n", | |
"var dataProcessPipeline = mlContext.Transforms\n", | |
"    .CustomMapping<ModelInput, TransformOutput>(polyFeatures.GetMapping(), contractName: \"PolynomialFeatures\")\n", | |
"    // The raw input columns are not needed once the polynomial features exist.\n", | |
"    .Append(mlContext.Transforms.DropColumns(\"Position\", \"Level\"))\n", | |
"    .Append(mlContext.Transforms.Concatenate(\"Features\", new[] {\"Features\"}));\n", | |
"\n", | |
"// Ordinary least squares over the polynomial terms gives the polynomial regression.\n", | |
"var trainer = mlContext.Regression.Trainers.Ols(featureColumnName: \"Features\", labelColumnName: \"Salary\");\n", | |
"var trainingPipeline = dataProcessPipeline.Append(trainer);" | |
], | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Train the model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"Console.WriteLine(\"=============== Training model ===============\");\n", | |
"\n", | |
"// Fit once and reuse the fitted model for the preview below, instead of training the pipeline twice.\n", | |
"var model = trainingPipeline.Fit(trainingDataView);\n", | |
"\n", | |
"// Show the transformed rows (polynomial features and salary) that the trainer saw.\n", | |
"var transformedDataView = model.Transform(trainingDataView);\n", | |
"var transformedData = mlContext.Data.CreateEnumerable<TransformOutput>(transformedDataView, reuseRowObject: false);\n", | |
"display(transformedData);\n", | |
"\n", | |
"Console.WriteLine(\"=============== End of training process ===============\");" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "=============== Training model ===============\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": "<table><thead><tr><th><i>index</i></th><th>Features</th><th>Salary</th></tr></thead><tbody><tr><td>0</td><td><div class=\"dni-plaintext\">[ 1, 1, 1, 1, 1 ]</div></td><td><div class=\"dni-plaintext\">45000</div></td></tr><tr><td>1</td><td><div class=\"dni-plaintext\">[ 1, 2, 4, 8, 16 ]</div></td><td><div class=\"dni-plaintext\">50000</div></td></tr><tr><td>2</td><td><div class=\"dni-plaintext\">[ 1, 3, 9, 27, 81 ]</div></td><td><div class=\"dni-plaintext\">60000</div></td></tr><tr><td>3</td><td><div class=\"dni-plaintext\">[ 1, 4, 16, 64, 256 ]</div></td><td><div class=\"dni-plaintext\">80000</div></td></tr><tr><td>4</td><td><div class=\"dni-plaintext\">[ 1, 5, 25, 125, 625 ]</div></td><td><div class=\"dni-plaintext\">110000</div></td></tr><tr><td>5</td><td><div class=\"dni-plaintext\">[ 1, 6, 36, 216, 1296 ]</div></td><td><div class=\"dni-plaintext\">150000</div></td></tr><tr><td>6</td><td><div class=\"dni-plaintext\">[ 1, 7, 49, 343, 2401 ]</div></td><td><div class=\"dni-plaintext\">200000</div></td></tr><tr><td>7</td><td><div class=\"dni-plaintext\">[ 1, 8, 64, 512, 4096 ]</div></td><td><div class=\"dni-plaintext\">300000</div></td></tr><tr><td>8</td><td><div class=\"dni-plaintext\">[ 1, 9, 81, 729, 6561 ]</div></td><td><div class=\"dni-plaintext\">500000</div></td></tr><tr><td>9</td><td><div class=\"dni-plaintext\">[ 1, 10, 100, 1000, 10000 ]</div></td><td><div class=\"dni-plaintext\">1000000</div></td></tr></tbody></table>" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "=============== End of training process ===============\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Evaluate Trainer\n", | |
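"These metrics are not meaningful for this polynomial regression: with only 10 rows split into 5 folds, each fold is scored on just a couple of points, and R-squared degenerates to -∞. They are shown anyway for your entertainment." | |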
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"// 5-fold cross-validation of the whole pipeline against the Salary label.\n", | |
"var crossValidationResults = mlContext.Regression.CrossValidate(\n", | |
"    trainingDataView,\n", | |
"    trainingPipeline,\n", | |
"    numberOfFolds: 5,\n", | |
"    labelColumnName: \"Salary\");\n", | |
"\n", | |
"// Per-fold metric sequences; each one is averaged in the report below.\n", | |
"var L1 = crossValidationResults.Select(fold => fold.Metrics.MeanAbsoluteError);\n", | |
"var L2 = crossValidationResults.Select(fold => fold.Metrics.MeanSquaredError);\n", | |
"var RMS = crossValidationResults.Select(fold => fold.Metrics.RootMeanSquaredError);\n", | |
"var lossFunction = crossValidationResults.Select(fold => fold.Metrics.LossFunction);\n", | |
"var R2 = crossValidationResults.Select(fold => fold.Metrics.RSquared);\n", | |
"\n", | |
"Console.WriteLine(\"*************************************************************************************************************\");\n", | |
"Console.WriteLine(\"* Metrics for Regression model \");\n", | |
"Console.WriteLine(\"*------------------------------------------------------------------------------------------------------------\");\n", | |
"Console.WriteLine($\"* Average L1 Loss: {L1.Average():0.###} \");\n", | |
"Console.WriteLine($\"* Average L2 Loss: {L2.Average():0.###} \");\n", | |
"Console.WriteLine($\"* Average RMS: {RMS.Average():0.###} \");\n", | |
"Console.WriteLine($\"* Average Loss Function: {lossFunction.Average():0.###} \");\n", | |
"Console.WriteLine($\"* Average R-squared: {R2.Average():0.###} \");\n", | |
"Console.WriteLine(\"*************************************************************************************************************\");\n", | |
"" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "*************************************************************************************************************\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "* Metrics for Regression model \r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "*------------------------------------------------------------------------------------------------------------\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "* Average L1 Loss: 28752.58 \r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "* Average L2 Loss: 2536167340.344 \r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "* Average RMS: 33389.23 \r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "* Average Loss Function: 2536167335.8 \r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "* Average R-squared: -∞ \r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "*************************************************************************************************************\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Show some results" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"// Sanity check: predict a single level first.\n", | |
"var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(model);\n", | |
"var predModel = new ModelInput() { Level = 10f };\n", | |
"display(predEngine.Predict(predModel));\n", | |
"\n", | |
"// Predict levels 1..11: the ten levels present in the data plus one extrapolated level.\n", | |
"var predictedLevels = Enumerable.Range(1, 11).ToArray();\n", | |
"var predictions = predictedLevels.Select(level => predEngine.Predict(new ModelInput() { Level = level }).Score).ToArray();\n", | |
"var actualSalaries = transformedData.Select(td => td.Salary).ToArray();\n", | |
"\n", | |
"var actual = new Graph.Scatter()\n", | |
"{\n", | |
"    // Rows are levels 1..N in file order; size x from the data so x and y always have equal length.\n", | |
"    x = Enumerable.Range(1, actualSalaries.Length).ToArray(),\n", | |
"    y = actualSalaries,\n", | |
"    mode = \"markers\",\n", | |
"    name = \"Actual\"\n", | |
"};\n", | |
"\n", | |
"var predicted = new Graph.Scatter()\n", | |
"{\n", | |
"    x = predictedLevels,\n", | |
"    y = predictions,\n", | |
"    // Plotly's trace mode for a connected line is \"lines\"; \"line\" is not a valid mode.\n", | |
"    mode = \"lines\",\n", | |
"    name = \"Predicted\"\n", | |
"};\n", | |
"\n", | |
"var chart = Chart.Plot(new[] {actual, predicted});\n", | |
"var layout = new Layout.Layout(){barmode = \"group\", title=\"\"};\n", | |
"chart.WithLayout(layout);\n", | |
"chart.WithXTitle(\"Level\");\n", | |
"chart.WithYTitle(\"Salary\");\n", | |
"chart.WithLegend(true);\n", | |
"chart.Width = 700;\n", | |
"chart.Height = 400;\n", | |
"\n", | |
"display(chart);" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": "<table><thead><tr><th>Score</th></tr></thead><tbody><tr><td><div class=\"dni-plaintext\">988917.75</div></td></tr></tbody></table>" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": "<div id=\"e026095d-4c9b-441e-9de1-0ac943dc669d\" style=\"width: 700px; height: 400px;\"></div>\r\n<script type=\"text/javascript\">\r\n\r\nvar renderPlotly = function() {\r\n var xplotRequire = require.config({context:'xplot-3.0.1',paths:{plotly:'https://cdn.plot.ly/plotly-1.49.2.min'}}) || require;\r\n xplotRequire(['plotly'], function(Plotly) {\r\n\n var data = [{\"type\":\"scatter\",\"x\":[1,2,3,4,5,6,7,8,9,10,11],\"y\":[45000.0,50000.0,60000.0,80000.0,110000.0,150000.0,200000.0,300000.0,500000.0,1000000.0],\"mode\":\"markers\",\"name\":\"Actual\"},{\"type\":\"scatter\",\"x\":[1,2,3,4,5,6,7,8,9,10,11],\"y\":[53358.9,31757.156,58640.023,94633.22,121727.22,143277.1,184003.72,289992.22,528692.25,988917.75,1780850.8],\"mode\":\"line\",\"name\":\"Predicted\"}];\n var layout = {\"title\":\"\",\"showlegend\":true,\"xaxis\":{\"title\":\"Level\",\"_isSubplotObj\":true},\"yaxis\":{\"title\":\"Salary\",\"_isSubplotObj\":true},\"barmode\":\"group\"};\n Plotly.newPlot('e026095d-4c9b-441e-9de1-0ac943dc669d', data, layout);\n \r\n});\r\n};\r\n// ensure `require` is available globally\r\nif ((typeof(require) !== typeof(Function)) || (typeof(require.config) !== typeof(Function))) {\r\n let require_script = document.createElement('script');\r\n require_script.setAttribute('src', 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js');\r\n require_script.setAttribute('type', 'text/javascript');\r\n \r\n \r\n require_script.onload = function() {\r\n renderPlotly();\r\n };\r\n\r\n document.getElementsByTagName('head')[0].appendChild(require_script);\r\n}\r\nelse {\r\n renderPlotly();\r\n}\r\n\r\n</script>\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
} | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".NET (C#)", | |
"language": "C#", | |
"name": ".net-csharp" | |
}, | |
"language_info": { | |
"file_extension": ".cs", | |
"mimetype": "text/x-csharp", | |
"name": "C#", | |
"pygments_lexer": "csharp", | |
"version": "8.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Position | Level | Salary | |
---|---|---|---|
Business Analyst | 1 | 45000 | |
Junior Consultant | 2 | 50000 | |
Senior Consultant | 3 | 60000 | |
Manager | 4 | 80000 | |
Country Manager | 5 | 110000 | |
Region Manager | 6 | 150000 | |
Partner | 7 | 200000 | |
Senior Partner | 8 | 300000 | |
C-level | 9 | 500000 | |
CEO | 10 | 1000000 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment