Last active
December 6, 2020 12:58
-
-
Save RobotOptimist/1adeb410287b0dde2a8bbfd77d38b228 to your computer and use it in GitHub Desktop.
ML.NET Multiple Linear Regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
R&D Spend | Administration | Marketing Spend | State | Profit | |
---|---|---|---|---|---|
165349.2 | 136897.8 | 471784.1 | New York | 192261.83 | |
162597.7 | 151377.59 | 443898.53 | California | 191792.06 | |
153441.51 | 101145.55 | 407934.54 | Florida | 191050.39 | |
144372.41 | 118671.85 | 383199.62 | New York | 182901.99 | |
142107.34 | 91391.77 | 366168.42 | Florida | 166187.94 | |
131876.9 | 99814.71 | 362861.36 | New York | 156991.12 | |
134615.46 | 147198.87 | 127716.82 | California | 156122.51 | |
130298.13 | 145530.06 | 323876.68 | Florida | 155752.6 | |
120542.52 | 148718.95 | 311613.29 | New York | 152211.77 | |
123334.88 | 108679.17 | 304981.62 | California | 149759.96 | |
101913.08 | 110594.11 | 229160.95 | Florida | 146121.95 | |
100671.96 | 91790.61 | 249744.55 | California | 144259.4 | |
93863.75 | 127320.38 | 249839.44 | Florida | 141585.52 | |
91992.39 | 135495.07 | 252664.93 | California | 134307.35 | |
119943.24 | 156547.42 | 256512.92 | Florida | 132602.65 | |
114523.61 | 122616.84 | 261776.23 | New York | 129917.04 | |
78013.11 | 121597.55 | 264346.06 | California | 126992.93 | |
94657.16 | 145077.58 | 282574.31 | New York | 125370.37 | |
91749.16 | 114175.79 | 294919.57 | Florida | 124266.9 | |
86419.7 | 153514.11 | 0 | New York | 122776.86 | |
76253.86 | 113867.3 | 298664.47 | California | 118474.03 | |
78389.47 | 153773.43 | 299737.29 | New York | 111313.02 | |
73994.56 | 122782.75 | 303319.26 | Florida | 110352.25 | |
67532.53 | 105751.03 | 304768.73 | Florida | 108733.99 | |
77044.01 | 99281.34 | 140574.81 | New York | 108552.04 | |
64664.71 | 139553.16 | 137962.62 | California | 107404.34 | |
75328.87 | 144135.98 | 134050.07 | Florida | 105733.54 | |
72107.6 | 127864.55 | 353183.81 | New York | 105008.31 | |
66051.52 | 182645.56 | 118148.2 | Florida | 103282.38 | |
65605.48 | 153032.06 | 107138.38 | New York | 101004.64 | |
61994.48 | 115641.28 | 91131.24 | Florida | 99937.59 | |
61136.38 | 152701.92 | 88218.23 | New York | 97483.56 | |
63408.86 | 129219.61 | 46085.25 | California | 97427.84 | |
55493.95 | 103057.49 | 214634.81 | Florida | 96778.92 | |
46426.07 | 157693.92 | 210797.67 | California | 96712.8 | |
46014.02 | 85047.44 | 205517.64 | New York | 96479.51 | |
28663.76 | 127056.21 | 201126.82 | Florida | 90708.19 | |
44069.95 | 51283.14 | 197029.42 | California | 89949.14 | |
20229.59 | 65947.93 | 185265.1 | New York | 81229.06 | |
38558.51 | 82982.09 | 174999.3 | California | 81005.76 | |
28754.33 | 118546.05 | 172795.67 | California | 78239.91 | |
27892.92 | 84710.77 | 164470.71 | Florida | 77798.83 | |
23640.93 | 96189.63 | 148001.11 | California | 71498.49 | |
15505.73 | 127382.3 | 35534.17 | New York | 69758.98 | |
22177.74 | 154806.14 | 28334.72 | California | 65200.33 | |
1000.23 | 124153.04 | 1903.93 | New York | 64926.08 | |
1315.46 | 115816.21 | 297114.46 | Florida | 49490.75 | |
0 | 135426.92 | 0 | California | 42559.73 | |
542.05 | 51743.15 | 0 | New York | 35673.41 | |
0 | 116983.8 | 45173.06 | California | 14681.4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
FROM jmacivor/dotnet-binder:0.1.1 | |
ARG NB_USER=jovyan | |
ARG NB_UID=1000 | |
USER $NB_USER | |
ENV HOME=/home/$NB_USER | |
WORKDIR $HOME | |
COPY multiple_linear_regression.ipynb $HOME/multiple_linear_regression.ipynb | |
COPY 50_Startups.csv $HOME/50_Startups.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# ML.NET Multiple Linear Regression\n", | |
"\n", | |
"This notebook will swiftly take you to the best result I could manage with this data in ML.NET. \n", | |
"However, I encourage you to experiment with the code and see if you can do better! Try different regression trainers or try passing options to the regression trainers to improve their accuracy. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"// ML.NET Nuget packages installation\n", | |
"#r \"nuget:Microsoft.ML,1.5.0\"\n", | |
"#r \"nuget:Microsoft.ML.Mkl.Components,1.5.0\"\n", | |
"#r \"nuget:Microsoft.ML.FastTree,1.5.0\"\n", | |
"//Install XPlot package\n", | |
"#r \"nuget:XPlot.Plotly,2.0.0\" \n", | |
"using Microsoft.ML; \n", | |
"using Microsoft.ML.Data;\n", | |
"using Microsoft.ML.Trainers.FastTree;\n", | |
"using XPlot.Plotly;\n", | |
"using System;\n", | |
"using System.Linq;" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Installed package Microsoft.ML.Mkl.Components version 1.5.0" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Installed package Microsoft.ML version 1.5.0" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Installed package XPlot.Plotly version 2.0.0" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Installed package Microsoft.ML.FastTree version 1.5.0" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Load up the data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"public class ModelInput\n", | |
"{\n", | |
" [ColumnName(\"R&D Spend\"), LoadColumn(0)]\n", | |
" public float R_D_Spend { get; set; }\n", | |
"\n", | |
"\n", | |
" [ColumnName(\"Administration\"), LoadColumn(1)]\n", | |
" public float Administration { get; set; }\n", | |
"\n", | |
"\n", | |
" [ColumnName(\"Marketing Spend\"), LoadColumn(2)]\n", | |
" public float Marketing_Spend { get; set; }\n", | |
"\n", | |
"\n", | |
" [ColumnName(\"State\"), LoadColumn(3)]\n", | |
" public string State { get; set; }\n", | |
"\n", | |
"\n", | |
" [ColumnName(\"Profit\"), LoadColumn(4)]\n", | |
" public float Profit { get; set; }\n", | |
"\n", | |
"\n", | |
"}\n", | |
"\n", | |
"public class ModelOutput\n", | |
"{\n", | |
" public float Score { get; set; }\n", | |
"}" | |
], | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"var mlContext = new MLContext(seed: 1);\n", | |
"\n", | |
"IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(\n", | |
" path: @\"./50_Startups.csv\",\n", | |
" hasHeader: true,\n", | |
" separatorChar: ',',\n", | |
" allowQuoting: true,\n", | |
" allowSparse: false);\n", | |
"\n", | |
"var split = mlContext.Data.TrainTestSplit(trainingDataView, testFraction: 0.2);" | |
], | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Create the Pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"var dataProcessPipeline = mlContext.Transforms.Categorical.OneHotEncoding(new[] { new InputOutputColumnPair(\"State\", \"State\") })\n", | |
" .Append(mlContext.Transforms.Concatenate(\"Features\", new[] { \"State\", \"R&D Spend\", \"Administration\", \"Marketing Spend\" }));\n", | |
"\n", | |
"var trainer = mlContext.Regression.Trainers.Ols(labelColumnName: \"Profit\", featureColumnName: \"Features\");\n", | |
"var trainingPipeline = dataProcessPipeline.Append(trainer);" | |
], | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Train the model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"Console.WriteLine(\"=============== Training model ===============\");\n", | |
"\n", | |
"ITransformer model = trainingPipeline.Fit(split.TrainSet); \n", | |
"\n", | |
"Console.WriteLine(\"=============== End of training process ===============\");" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "=============== Training model ===============\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "=============== End of training process ===============\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Evaluate" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"Console.WriteLine(\"=============== Cross-validating to get model's accuracy metrics ===============\");\n", | |
"var crossValidationResults = mlContext.Regression.CrossValidate(split.TrainSet, trainingPipeline, numberOfFolds: 5, labelColumnName: \"Profit\");\n", | |
"var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);\n", | |
"var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);\n", | |
"var RMS = crossValidationResults.Select(r => r.Metrics.RootMeanSquaredError);\n", | |
"var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);\n", | |
"var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);\n", | |
"\n", | |
"Console.WriteLine($\"*************************************************************************************************************\");\n", | |
"Console.WriteLine($\"* Metrics for Regression model \");\n", | |
"Console.WriteLine($\"*------------------------------------------------------------------------------------------------------------\");\n", | |
"Console.WriteLine($\"* Average L1 Loss: {L1.Average():0.###} \");\n", | |
"Console.WriteLine($\"* Average L2 Loss: {L2.Average():0.###} \");\n", | |
"Console.WriteLine($\"* Average RMS: {RMS.Average():0.###} \");\n", | |
"Console.WriteLine($\"* Average Loss Function: {lossFunction.Average():0.###} \");\n", | |
"Console.WriteLine($\"* Average R-squared: {R2.Average():0.###} \");\n", | |
"Console.WriteLine($\"*************************************************************************************************************\");" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "=============== Cross-validating to get model's accuracy metrics ===============\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "*************************************************************************************************************\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "* Metrics for Regression model \r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "*------------------------------------------------------------------------------------------------------------\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "* Average L1 Loss: 8034.512 \r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "* Average L2 Loss: 121276536.187 \r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "* Average RMS: 10470.8 \r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "* Average Loss Function: 121276539.908 \r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "* Average R-squared: 0.909 \r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "*************************************************************************************************************\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Create predictions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"source": [ | |
"var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(model);\n", | |
"var testSet = mlContext.Data.CreateEnumerable<ModelInput>(split.TestSet, reuseRowObject:false);\n", | |
"foreach (var ts in testSet) \n", | |
"{\n", | |
" var testInput = new ModelInput() \n", | |
" {\n", | |
" State = ts.State,\n", | |
" Marketing_Spend = ts.Marketing_Spend,\n", | |
" R_D_Spend = ts.R_D_Spend,\n", | |
" Administration = ts.Administration,\n", | |
" Profit = 0.0F\n", | |
" };\n", | |
" var prediction = predEngine.Predict(testInput).Score;\n", | |
" var actual = ts.Profit;\n", | |
" Console.WriteLine($\"Prediction: {prediction}, Actual: {actual}\");\n", | |
"}\n", | |
"" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Prediction: 179669.16, Actual: 191050.39\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Prediction: 127416.75, Actual: 141585.52\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Prediction: 126266.45, Actual: 134307.34\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Prediction: 100436.875, Actual: 103282.38\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Prediction: 96505.125, Actual: 96778.92\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Prediction: 88131.81, Actual: 96712.8\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Prediction: 90226.22, Actual: 96479.51\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Prediction: 88884.82, Actual: 89949.14\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": "Prediction: 46510.44, Actual: 42559.73\r\n" | |
}, | |
"execution_count": 1, | |
"metadata": {} | |
} | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".NET (C#)", | |
"language": "C#", | |
"name": ".net-csharp" | |
}, | |
"language_info": { | |
"file_extension": ".cs", | |
"mimetype": "text/x-csharp", | |
"name": "C#", | |
"pygments_lexer": "csharp", | |
"version": "8.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment