Last active
March 15, 2024 13:46
-
-
Save nelvadas/92ae8452cd3857e50fc8cef47635ac24 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class InsuranceCostPredictorApp | |
{ | |
public static void main(String[] args ) throws IOException { | |
var regressionFactory = new RegressionFactory(); | |
var csvLoader = new CSVLoader<>(regressionFactory); | |
//Load data | |
String[] insuranceHeaders = new String[]{"Age", "Gender", "BodyMassIndex", "Children", "Smoke", "Department", "PremiumInsurance"}; | |
DataSource<Regressor> insuranceDataSource = csvLoader.loadDataSource(Paths.get("src/main/resources/insurance.csv"), "PremiumInsurance", insuranceHeaders); | |
//split train and test data | |
var splitter = new TrainTestSplitter<>(insuranceDataSource, 0.95, 0L); | |
Dataset<Regressor> trainingDataset = new MutableDataset<>(splitter.getTrain()); | |
Dataset<Regressor> testingDataset = new MutableDataset<>(splitter.getTest()); | |
//Create a linear Model | |
LinearSGDTrainer trainer = new LinearSGDTrainer(new SquaredLoss(), new AdaGrad(0.5), 5, 77); | |
Model<Regressor> model = trainer.train(trainingDataset); | |
//Evaluate Model on test DataSet | |
RegressionEvaluator evaluator = new RegressionEvaluator(); | |
Regressor dimension0 = new Regressor("DIM-0", Double.NaN); | |
RegressionEvaluation score = evaluator.evaluate(model, testingDataset); | |
// Metrics | |
System.out.println("Testing Data Set"); | |
System.out.println("Root Mean Squared Error :" + score.rmse(dimension0)); | |
System.out.println("R-squared:" + Math.abs(score.r2(dimension0))); | |
//predict Test values | |
Regressor outputPlaceHolder = RegressionFactory.UNKNOWN_REGRESSOR; | |
String[] featureNames = {"Age","Gender","BodyMassIndex","Children","Smoke","Department"}; | |
double[] featureValues = new double[]{25, 1.0,33,0,0,3}; | |
Example<Regressor> sample = new ArrayExample<>(outputPlaceHolder,featureNames,featureValues); | |
Prediction<Regressor> prediction = model.predict(sample); | |
double result = prediction.getOutput().getValues()[0]; | |
System.out.println("Predicted price =>"+result); | |
} | |
} | |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment