Skip to content

Instantly share code, notes, and snippets.

@nelvadas
Last active March 15, 2024 13:46
Show Gist options
  • Save nelvadas/92ae8452cd3857e50fc8cef47635ac24 to your computer and use it in GitHub Desktop.
Save nelvadas/92ae8452cd3857e50fc8cef47635ac24 to your computer and use it in GitHub Desktop.
/**
 * Small demo that trains a linear regression model on an insurance CSV file,
 * evaluates it on a held-out split, and prints a prediction for one
 * hand-written sample.
 *
 * <p>The pipeline is: load CSV -> train/test split -> train a linear SGD
 * regressor -> report RMSE and R-squared on the test split -> predict one sample.
 */
public class InsuranceCostPredictorApp
{
    /**
     * Runs the whole load / train / evaluate / predict pipeline and writes the
     * results to standard output.
     *
     * @param args command-line arguments; not used
     * @throws IOException if the CSV file cannot be read
     */
    public static void main(String[] args ) throws IOException {
        var regressionFactory = new RegressionFactory();
        var csvLoader = new CSVLoader<>(regressionFactory);

        // Load data. The column names are supplied explicitly; "PremiumInsurance"
        // is the response column, every other column is treated as a feature.
        // NOTE(review): supplying headers presumably means the CSV file itself has
        // no header row - confirm against the data file.
        String[] insuranceHeaders = new String[]{"Age", "Gender", "BodyMassIndex", "Children", "Smoke", "Department", "PremiumInsurance"};
        DataSource<Regressor> insuranceDataSource = csvLoader.loadDataSource(
                Paths.get("src/main/resources/insurance.csv"), "PremiumInsurance", insuranceHeaders);

        // Split train and test data: 0.95 of the rows go to training, fixed seed 0
        // so the split is reproducible between runs.
        var splitter = new TrainTestSplitter<>(insuranceDataSource, 0.95, 0L);
        Dataset<Regressor> trainingDataset = new MutableDataset<>(splitter.getTrain());
        Dataset<Regressor> testingDataset = new MutableDataset<>(splitter.getTest());

        // Create a linear model: squared loss, AdaGrad with initial learning rate 0.5.
        // The trailing arguments (5, 77) are the epoch count and RNG seed per the
        // Tribuo LinearSGDTrainer constructor.
        LinearSGDTrainer trainer = new LinearSGDTrainer(new SquaredLoss(), new AdaGrad(0.5), 5, 77);
        Model<Regressor> model = trainer.train(trainingDataset);

        // Evaluate the model on the test data set.
        RegressionEvaluator evaluator = new RegressionEvaluator();
        // The metric accessors are keyed by a Regressor carrying the dimension name;
        // the NaN value is only a placeholder.
        // NOTE(review): "DIM-0" is assumed to be the name given to the single
        // output dimension by the loader - confirm.
        Regressor dimension0 = new Regressor("DIM-0", Double.NaN);
        RegressionEvaluation score = evaluator.evaluate(model, testingDataset);

        // Metrics
        System.out.println("Testing Data Set");
        System.out.println("Root Mean Squared Error :" + score.rmse(dimension0));
        // Print R-squared as-is. It is negative when the model is worse than
        // predicting the mean, and wrapping it in Math.abs would report such a
        // model as a good fit.
        System.out.println("R-squared:" + score.r2(dimension0));

        // Predict a value for one hand-written sample. The output placeholder marks
        // the example as having no known response.
        Regressor outputPlaceHolder = RegressionFactory.UNKNOWN_REGRESSOR;
        String[] featureNames = {"Age","Gender","BodyMassIndex","Children","Smoke","Department"};
        double[] featureValues = new double[]{25, 1.0,33,0,0,3};
        Example<Regressor> sample = new ArrayExample<>(outputPlaceHolder,featureNames,featureValues);
        Prediction<Regressor> prediction = model.predict(sample);
        // Single-dimension regression: the first value is the predicted premium.
        double result = prediction.getOutput().getValues()[0];
        System.out.println("Predicted price =>"+result);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment