@RobotOptimist
Last active November 22, 2020 19:56
MLNET Simple Regression
FROM jmacivor/dotnet-binder:0.1.1
ARG NB_USER=jovyan
ARG NB_UID=1000
USER $NB_USER
ENV HOME=/home/$NB_USER
WORKDIR $HOME
COPY ml_net_simple_regression.ipynb $HOME/ml_net_simple_regression.ipynb
COPY Salary_Data.csv $HOME/Salary_Data.csv
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MLNET Simple Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import the Libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\r\n",
"<div>\r\n",
" <div id='dotnet-interactive-this-cell-65284.Microsoft.DotNet.Interactive.Http.HttpPort' style='display: none'>\r\n",
" The below script needs to be able to find the current output cell; this is an easy method to get it.\r\n",
" </div>\r\n",
" <script type='text/javascript'>\r\n",
"async function probeAddresses(probingAddresses) {\r\n",
" function timeout(ms, promise) {\r\n",
" return new Promise(function (resolve, reject) {\r\n",
" setTimeout(function () {\r\n",
" reject(new Error('timeout'))\r\n",
" }, ms)\r\n",
" promise.then(resolve, reject)\r\n",
" })\r\n",
" }\r\n",
"\r\n",
" if (Array.isArray(probingAddresses)) {\r\n",
" for (let i = 0; i < probingAddresses.length; i++) {\r\n",
"\r\n",
" let rootUrl = probingAddresses[i];\r\n",
"\r\n",
" if (!rootUrl.endsWith('/')) {\r\n",
" rootUrl = `${rootUrl}/`;\r\n",
" }\r\n",
"\r\n",
" try {\r\n",
" let response = await timeout(1000, fetch(`${rootUrl}discovery`, {\r\n",
" method: 'POST',\r\n",
" cache: 'no-cache',\r\n",
" mode: 'cors',\r\n",
" timeout: 1000,\r\n",
" headers: {\r\n",
" 'Content-Type': 'text/plain'\r\n",
" },\r\n",
" body: probingAddresses[i]\r\n",
" }));\r\n",
"\r\n",
" if (response.status == 200) {\r\n",
" return rootUrl;\r\n",
" }\r\n",
" }\r\n",
" catch (e) { }\r\n",
" }\r\n",
" }\r\n",
"}\r\n",
"\r\n",
"function loadDotnetInteractiveApi() {\r\n",
" probeAddresses([\"http://10.20.201.1:1000/\", \"http://172.30.176.1:1000/\", \"http://192.168.0.194:1000/\", \"http://127.0.0.1:1000/\"])\r\n",
" .then((root) => {\r\n",
" // use probing to find host url and api resources\r\n",
" // load interactive helpers and language services\r\n",
" let dotnetInteractiveRequire = require.config({\r\n",
" context: '65284.Microsoft.DotNet.Interactive.Http.HttpPort',\r\n",
" paths:\r\n",
" {\r\n",
" 'dotnet-interactive': `${root}resources`\r\n",
" }\r\n",
" }) || require;\r\n",
"\r\n",
" window.dotnetInteractiveRequire = dotnetInteractiveRequire;\r\n",
"\r\n",
" window.configureRequireFromExtension = function(extensionName, extensionCacheBuster) {\r\n",
" let paths = {};\r\n",
" paths[extensionName] = `${root}extensions/${extensionName}/resources/`;\r\n",
" \r\n",
" let internalRequire = require.config({\r\n",
" context: extensionCacheBuster,\r\n",
" paths: paths,\r\n",
" urlArgs: `cacheBuster=${extensionCacheBuster}`\r\n",
" }) || require;\r\n",
"\r\n",
" return internalRequire\r\n",
" };\r\n",
" \r\n",
" dotnetInteractiveRequire([\r\n",
" 'dotnet-interactive/dotnet-interactive'\r\n",
" ],\r\n",
" function (dotnet) {\r\n",
" dotnet.init(window);\r\n",
" },\r\n",
" function (error) {\r\n",
" console.log(error);\r\n",
" }\r\n",
" );\r\n",
" })\r\n",
" .catch(error => {console.log(error);});\r\n",
" }\r\n",
"\r\n",
"// ensure `require` is available globally\r\n",
"if ((typeof(require) !== typeof(Function)) || (typeof(require.config) !== typeof(Function))) {\r\n",
" let require_script = document.createElement('script');\r\n",
" require_script.setAttribute('src', 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js');\r\n",
" require_script.setAttribute('type', 'text/javascript');\r\n",
" \r\n",
" \r\n",
" require_script.onload = function() {\r\n",
" loadDotnetInteractiveApi();\r\n",
" };\r\n",
"\r\n",
" document.getElementsByTagName('head')[0].appendChild(require_script);\r\n",
"}\r\n",
"else {\r\n",
" loadDotnetInteractiveApi();\r\n",
"}\r\n",
"\r\n",
" </script>\r\n",
"</div>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"Installed package Microsoft.ML.Mkl.Components version 1.5.0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"Installed package Microsoft.ML version 1.5.0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"Installed package XPlot.Plotly version 2.0.0"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"// Install the ML.NET NuGet packages (Mkl.Components provides the OLS trainer)\n",
"#r \"nuget:Microsoft.ML,1.5.0\"\n",
"#r \"nuget:Microsoft.ML.Mkl.Components,1.5.0\"\n",
"// Install the XPlot package for charts\n",
"#r \"nuget:XPlot.Plotly,2.0.0\" \n",
"using Microsoft.ML; \n",
"using Microsoft.ML.Data;\n",
"using XPlot.Plotly;\n",
"using System;\n",
"using System.Linq;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Declare the Data classes"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"public class ModelInput\n",
"{\n",
" [ColumnName(\"YearsExperience\"), LoadColumn(0)]\n",
" public float YearsExperience { get; set; }\n",
"\n",
" [ColumnName(\"Salary\"), LoadColumn(1)]\n",
" public float Salary { get; set; }\n",
"}\n",
"\n",
"public class ModelOutput\n",
"{ \n",
" public float Score { get; set; }\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"var mlContext = new MLContext(seed: 1);\n",
"\n",
"IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(\n",
" path: \"./Salary_Data.csv\",\n",
" hasHeader: true,\n",
" separatorChar: ',',\n",
" allowQuoting: true,\n",
" allowSparse: false);\n",
"\n",
"var split = mlContext.Data.TrainTestSplit(trainingDataView, testFraction: 0.2);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the Training Data with XPlot"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div id=\"76074f94-c37e-4b23-8eda-5cec54c226ca\" style=\"width: 900px; height: 500px;\"></div>\r\n",
"<script type=\"text/javascript\">\r\n",
"\r\n",
"var renderPlotly = function() {\r\n",
" var xplotRequire = require.config({context:'xplot-3.0.1',paths:{plotly:'https://cdn.plot.ly/plotly-1.49.2.min'}}) || require;\r\n",
" xplotRequire(['plotly'], function(Plotly) {\r\n",
"\n",
" var data = [{\"type\":\"scatter\",\"x\":[1.1,1.3,2.0,2.2,2.9,3.0,3.2,3.2,3.7,3.9,4.0,4.5,4.9,5.1,5.3,5.9,6.0,6.8,7.1,7.9,8.2,8.7,9.0,9.5,9.6,10.5],\"y\":[39343.0,46205.0,43525.0,39891.0,56642.0,60150.0,54445.0,64445.0,57189.0,63218.0,55794.0,61111.0,67938.0,66029.0,83088.0,81363.0,93940.0,91738.0,98273.0,101302.0,113812.0,109431.0,105582.0,116969.0,112635.0,121872.0],\"mode\":\"markers\"}];\n",
" var layout = {\"title\":\"Years Vs Salary\"};\n",
" Plotly.newPlot('76074f94-c37e-4b23-8eda-5cec54c226ca', data, layout);\n",
" \r\n",
"});\r\n",
"};\r\n",
"// ensure `require` is available globally\r\n",
"if ((typeof(require) !== typeof(Function)) || (typeof(require.config) !== typeof(Function))) {\r\n",
" let require_script = document.createElement('script');\r\n",
" require_script.setAttribute('src', 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js');\r\n",
" require_script.setAttribute('type', 'text/javascript');\r\n",
" \r\n",
" \r\n",
" require_script.onload = function() {\r\n",
" renderPlotly();\r\n",
" };\r\n",
"\r\n",
" document.getElementsByTagName('head')[0].appendChild(require_script);\r\n",
"}\r\n",
"else {\r\n",
" renderPlotly();\r\n",
"}\r\n",
"\r\n",
"</script>\r\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"var years = split.TrainSet.GetColumn<float>(\"YearsExperience\").ToArray();\n",
"var salary = split.TrainSet.GetColumn<float>(\"Salary\").ToArray();\n",
"\n",
"var yearsChart = Chart.Plot(new Graph.Scatter\n",
"{ \n",
" x = years,\n",
" y = salary,\n",
" mode = \"markers\"\n",
"});\n",
"\n",
"yearsChart.WithTitle(\"Years Vs Salary\");\n",
"display(yearsChart);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create the Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"(1,27): warning CS1701: Assuming assembly reference 'Microsoft.AspNetCore.Html.Abstractions, Version=2.2.0.0, Culture=neutral, PublicKeyToken=adb9793829ddae60' used by 'Microsoft.DotNet.Interactive.Formatting' matches identity 'Microsoft.AspNetCore.Html.Abstractions, Version=3.1.9.0, Culture=neutral, PublicKeyToken=adb9793829ddae60' of 'Microsoft.AspNetCore.Html.Abstractions', you may need to supply runtime policy\n",
"\n"
]
}
],
"source": [
"var dataProcessPipeline = mlContext.Transforms.Concatenate(\"Features\", new[] { \"YearsExperience\" })\n",
" .Append(mlContext.Transforms.NormalizeMinMax(\"Features\", \"Features\"));\n",
"// Set the training algorithm \n",
"var trainer = mlContext.Regression.Trainers.Ols(labelColumnName: \"Salary\", featureColumnName: \"Features\");\n",
"\n",
"// Build training pipeline\n",
"IEstimator<ITransformer> trainingPipeline = dataProcessPipeline.Append(trainer);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the Model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"ITransformer mlModel = trainingPipeline.Fit(split.TrainSet);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate the pipeline"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"*************************************************************************************************************\n",
"* Metrics for Regression model \n",
"*------------------------------------------------------------------------------------------------------------\n",
"* Average L1 Loss: 4812.021 \n",
"* Average L2 Loss: 34218924.822 \n",
"* Average RMS: 5821.137 \n",
"* Average Loss Function: 34218925.104 \n",
"* Average R-squared: 0.942 \n",
"*************************************************************************************************************\n"
]
}
],
"source": [
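"// Cross-validate the full pipeline on the entire dataset (including the rows in split.TestSet)\n",
"var crossValidationResults = mlContext.Regression.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: \"Salary\");\n",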
"var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);\n",
"var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);\n",
"var RMS = crossValidationResults.Select(r => r.Metrics.RootMeanSquaredError);\n",
"var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);\n",
"var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);\n",
"\n",
"Console.WriteLine($\"*************************************************************************************************************\");\n",
"Console.WriteLine($\"* Metrics for Regression model \");\n",
"Console.WriteLine($\"*------------------------------------------------------------------------------------------------------------\");\n",
"Console.WriteLine($\"* Average L1 Loss: {L1.Average():0.###} \");\n",
"Console.WriteLine($\"* Average L2 Loss: {L2.Average():0.###} \");\n",
"Console.WriteLine($\"* Average RMS: {RMS.Average():0.###} \");\n",
"Console.WriteLine($\"* Average Loss Function: {lossFunction.Average():0.###} \");\n",
"Console.WriteLine($\"* Average R-squared: {R2.Average():0.###} \");\n",
"Console.WriteLine($\"*************************************************************************************************************\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test the Model"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Predicted Salary: 41099.438\n",
"\n",
"\n",
"\n",
"\n",
"Predicted Salary: 64416.47\n",
"\n",
"\n",
"\n",
"\n",
"Predicted Salary: 65349.15\n",
"\n",
"\n",
"\n",
"\n",
"Predicted Salary: 123175.375\n",
"\n",
"\n"
]
},
{
"data": {
"text/html": [
"<div id=\"ca381dff-4ee3-49bb-a523-625191663e2c\" style=\"width: 700px; height: 400px;\"></div>\r\n",
"<script type=\"text/javascript\">\r\n",
"\r\n",
"var renderPlotly = function() {\r\n",
" var xplotRequire = require.config({context:'xplot-3.0.1',paths:{plotly:'https://cdn.plot.ly/plotly-1.49.2.min'}}) || require;\r\n",
" xplotRequire(['plotly'], function(Plotly) {\r\n",
"\n",
" var data = [{\"type\":\"scatter\",\"x\":[1.5,4.0,4.1,10.3],\"y\":[37731.0,56957.0,57081.0,122391.0],\"mode\":\"markers\",\"name\":\"Actual\"},{\"type\":\"scatter\",\"x\":[1.5,4.0,4.1,10.3],\"y\":[41099.438,64416.47,65349.15,123175.375],\"mode\":\"line\",\"name\":\"Predicted\"}];\n",
" var layout = {\"title\":\"\",\"showlegend\":true,\"xaxis\":{\"title\":\"Years\",\"_isSubplotObj\":true},\"yaxis\":{\"title\":\"Salary\",\"_isSubplotObj\":true},\"barmode\":\"group\"};\n",
" Plotly.newPlot('ca381dff-4ee3-49bb-a523-625191663e2c', data, layout);\n",
" \r\n",
"});\r\n",
"};\r\n",
"// ensure `require` is available globally\r\n",
"if ((typeof(require) !== typeof(Function)) || (typeof(require.config) !== typeof(Function))) {\r\n",
" let require_script = document.createElement('script');\r\n",
" require_script.setAttribute('src', 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js');\r\n",
" require_script.setAttribute('type', 'text/javascript');\r\n",
" \r\n",
" \r\n",
" require_script.onload = function() {\r\n",
" renderPlotly();\r\n",
" };\r\n",
"\r\n",
" document.getElementsByTagName('head')[0].appendChild(require_script);\r\n",
"}\r\n",
"else {\r\n",
" renderPlotly();\r\n",
"}\r\n",
"\r\n",
"</script>\r\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);\n",
"var enumerableTestSet = mlContext.Data.CreateEnumerable<ModelInput>(split.TestSet, reuseRowObject: false)\n",
" .Select(ts => new ModelInput() { YearsExperience = ts.YearsExperience }); \n",
"// Materialize the predictions once so the engine is not re-run on every enumeration\n",
"var predictionResults = enumerableTestSet.Select(ts => predEngine.Predict(ts)).ToArray();\n",
"\n",
"foreach (var predictionResult in predictionResults)\n",
"{\n",
" Console.WriteLine($\"\\n\\nPredicted Salary: {predictionResult.Score}\\n\\n\");\n",
"}\n",
"\n",
"var testYears = enumerableTestSet.Select(ts => ts.YearsExperience).ToArray();\n",
"var actualSalaries = mlContext.Data.CreateEnumerable<ModelInput>(split.TestSet, reuseRowObject: false).Select(ts => ts.Salary).ToArray();\n",
"var predictedSalaries = predictionResults.Select(r => r.Score).ToArray();\n",
"\n",
"var actual = new Graph.Scatter()\n",
"{\n",
" x = testYears,\n",
" y = actualSalaries,\n",
" mode = \"markers\",\n",
" name = \"Actual\"\n",
"};\n",
"\n",
"var predicted = new Graph.Scatter()\n",
"{\n",
" x = testYears,\n",
" y = predictedSalaries,\n",
" mode = \"lines\",\n",
" name = \"Predicted\"\n",
"};\n",
"\n",
"var chart = Chart.Plot(new[] {actual, predicted});\n",
"var layout = new Layout.Layout(){barmode = \"group\", title=\"\"};\n",
"chart.WithLayout(layout);\n",
"chart.WithXTitle(\"Years\");\n",
"chart.WithYTitle(\"Salary\");\n",
"chart.WithLegend(true);\n",
"chart.Width = 700;\n",
"chart.Height = 400;\n",
"\n",
"display(chart);\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".NET (C#)",
"language": "C#",
"name": ".net-csharp"
},
"language_info": {
"file_extension": ".cs",
"mimetype": "text/x-csharp",
"name": "C#",
"pygments_lexer": "csharp",
"version": "8.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
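The notebook's evaluation cell averages four metrics over the five cross-validation folds. Average L1 Loss comes from `MeanAbsoluteError`, Average L2 Loss from `MeanSquaredError`, and Average RMS from `RootMeanSquaredError`, which is the square root of the MSE. Average R-squared is one minus the ratio of the residual sum of squares to the total sum of squares.

The sketch below recomputes those standard definitions by hand for one set of actual and predicted values. It illustrates what the metric names mean and is not ML.NET's implementation. `RegressionMetrics` is a hypothetical helper. The code is plain C# (statements plus a local function) and does not use ML.NET, so it should run as a .NET Interactive cell or as C# 9+ top-level statements.

```csharp
using System;
using System.Linq;

// Hypothetical helper: standard regression metrics for one set of predictions,
// computed in double precision.
(double Mae, double Mse, double Rmse, double RSquared) RegressionMetrics(
    double[] actual, double[] predicted)
{
    if (actual.Length != predicted.Length || actual.Length == 0)
        throw new ArgumentException("Inputs must be non-empty and the same length.");

    // Mean absolute error (the notebook's "L1 Loss").
    double mae = actual.Zip(predicted, (a, p) => Math.Abs(a - p)).Average();
    // Mean squared error (the notebook's "L2 Loss").
    double mse = actual.Zip(predicted, (a, p) => (a - p) * (a - p)).Average();

    // R-squared = 1 - SS_res / SS_tot, with SS_tot measured around the mean label.
    // It is undefined (division by zero) when every label is identical.
    double mean = actual.Average();
    double ssRes = actual.Zip(predicted, (a, p) => (a - p) * (a - p)).Sum();
    double ssTot = actual.Sum(a => (a - mean) * (a - mean));

    return (mae, mse, Math.Sqrt(mse), 1.0 - ssRes / ssTot);
}

// Example: three labels where one prediction is off by 1.
var m = RegressionMetrics(new[] { 1.0, 2.0, 3.0 }, new[] { 1.0, 2.0, 4.0 });
Console.WriteLine($"MAE={m.Mae:0.###} MSE={m.Mse:0.###} RMSE={m.Rmse:0.###} R2={m.RSquared:0.###}");
```

To compare with the notebook's output, feed the fold's labels and `Score` values into the same definitions. Note that the notebook's figures are averages of per-fold values, not metrics computed over the pooled predictions.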
YearsExperience,Salary
1.1,39343.00
1.3,46205.00
1.5,37731.00
2.0,43525.00
2.2,39891.00
2.9,56642.00
3.0,60150.00
3.2,54445.00
3.2,64445.00
3.7,57189.00
3.9,63218.00
4.0,55794.00
4.0,56957.00
4.1,57081.00
4.5,61111.00
4.9,67938.00
5.1,66029.00
5.3,83088.00
5.9,81363.00
6.0,93940.00
6.8,91738.00
7.1,98273.00
7.9,101302.00
8.2,113812.00
8.7,109431.00
9.0,105582.00
9.5,116969.00
9.6,112635.00
10.3,122391.00
10.5,121872.00
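The notebook feeds a single feature, `YearsExperience`, into `mlContext.Regression.Trainers.Ols`. With one feature, ordinary least squares has a closed form:

- slope = Σ(x − x̄)(y − ȳ) / Σ(x − x̄)²
- intercept = ȳ − slope · x̄

The sketch below is a hand-rolled illustration of that formula, not ML.NET's solver. `FitSimpleOls` is a hypothetical helper, and the sample rows are copied from Salary_Data.csv.

The sketch also checks one property of the pipeline. Dividing the feature by its maximum is assumed here as a simple stand-in for the `NormalizeMinMax` step. That rescaling changes the coefficients but not the predictions, because OLS with an intercept absorbs any affine rescaling of the feature.

```csharp
using System;
using System.Linq;

// Hypothetical helper: closed-form least squares for y ≈ intercept + slope * x.
// Requires at least two distinct x values (otherwise sxx is zero).
(double Slope, double Intercept) FitSimpleOls(double[] x, double[] y)
{
    double meanX = x.Average();
    double meanY = y.Average();
    double sxy = 0, sxx = 0;
    for (int i = 0; i < x.Length; i++)
    {
        sxy += (x[i] - meanX) * (y[i] - meanY);
        sxx += (x[i] - meanX) * (x[i] - meanX);
    }
    double slope = sxy / sxx;
    return (slope, meanY - slope * meanX);
}

// A few (YearsExperience, Salary) rows copied from Salary_Data.csv.
double[] sampleYears = { 1.1, 2.0, 3.2, 4.5, 6.0, 7.9, 9.6 };
double[] sampleSalary = { 39343, 43525, 54445, 61111, 93940, 101302, 112635 };

var raw = FitSimpleOls(sampleYears, sampleSalary);

// Rescale the feature into (0, 1] by dividing by its maximum, then refit.
double maxYears = sampleYears.Max();
var scaled = FitSimpleOls(sampleYears.Select(v => v / maxYears).ToArray(), sampleSalary);

// Predictions at 5 years agree even though the slopes differ.
double rawPrediction = raw.Intercept + raw.Slope * 5.0;
double scaledPrediction = scaled.Intercept + scaled.Slope * (5.0 / maxYears);
Console.WriteLine($"slope={raw.Slope:0.##}, intercept={raw.Intercept:0.##}");
Console.WriteLine($"prediction at 5 years: {rawPrediction:0.##} (raw) vs {scaledPrediction:0.##} (scaled)");
```

Because only seven rows are used here, the coefficients will not match the model the notebook trains on its 80% split. The point is the formula and the invariance to feature scaling, not the specific numbers.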