Skip to content

Instantly share code, notes, and snippets.

@tecno14
Created October 5, 2021 11:03
Show Gist options
  • Save tecno14/7fe0c8f7579e62988fe97b9e0b0cdf4a to your computer and use it in GitHub Desktop.
Save tecno14/7fe0c8f7579e62988fe97b9e0b0cdf4a to your computer and use it in GitHub Desktop.
StandardScaler in C#: takes a list of objects of a class that has a parameterless constructor
using System;
using System.Data;
using System.Linq;
using System.Collections.Generic;
using System.Reflection;
using System.ComponentModel;
namespace PricePrediction.MachineLearning
{
/// <summary>
/// Standardize features by removing the mean and scaling to unit variance.
/// Each object of <typeparamref name="T"/> is one sample; its properties (as enumerated by
/// <see cref="TypeDescriptor.GetProperties(Type)"/>) are the feature columns.
/// more : https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html
/// </summary>
/// <typeparam name="T">Sample type; needs a parameterless constructor so standardized copies can be created.</typeparam>
public class StandardScaler<T> where T : new()
{
    // Per-column means learned by Fit; null until a Fit call has completed successfully.
    private List<double> _mean;

    // Per-column standard deviations learned by Fit, as returned by Calculations.StandardDeviation.
    // NOTE(review): scikit-learn uses the population (ddof=0) standard deviation; confirm the helper matches.
    private List<double> _standardDeviation;

    /// <summary>
    /// Fit the scaler on <paramref name="listOfObjects"/>, then return standardized copies of the same objects.
    /// </summary>
    /// <param name="listOfObjects">Samples to learn the statistics from and then transform.</param>
    /// <returns>New <typeparamref name="T"/> instances holding the standardized values.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="listOfObjects"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="listOfObjects"/> is empty.</exception>
    public List<T> FitTransform(List<T> listOfObjects)
    {
        return Fit(listOfObjects).Transform(listOfObjects);
    }

    /// <summary>
    /// Learn the mean and standard deviation of every column of <paramref name="listOfObjects"/>,
    /// replacing any previously fitted statistics.
    /// </summary>
    /// <param name="listOfObjects">Samples to learn the statistics from.</param>
    /// <returns>This instance, so calls can be chained (e.g. <c>Fit(x).Transform(y)</c>).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="listOfObjects"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="listOfObjects"/> is empty.</exception>
    public StandardScaler<T> Fit(List<T> listOfObjects)
    {
        if (listOfObjects is null)
            throw new ArgumentNullException(nameof(listOfObjects));
        if (listOfObjects.Count < 1)
            throw new ArgumentException("no data", nameof(listOfObjects));

        var columns = listOfObjects.ToArraysOfColumns<T>();
        var mean = new List<double>(columns.Length);
        var standardDeviation = new List<double>(columns.Length);
        foreach (var column in columns)
        {
            mean.Add(column.Average());
            standardDeviation.Add(Calculations.StandardDeviation(column));
        }

        // Publish the statistics only after every column was computed, so a failed Fit
        // leaves the previous state (fitted or not) intact instead of half-reset.
        _mean = mean;
        _standardDeviation = standardDeviation;
        return this;
    }

    /// <summary>
    /// Standardize <paramref name="listOfObjects"/> with the statistics learned by <see cref="Fit"/>:
    /// every value becomes <c>(value - mean) / standardDeviation</c> for its column.
    /// </summary>
    /// <param name="listOfObjects">Samples to standardize; the originals are not modified.</param>
    /// <returns>New <typeparamref name="T"/> instances holding the standardized values (empty for empty input).</returns>
    /// <exception cref="InvalidOperationException"><see cref="Fit"/> has not completed successfully yet.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="listOfObjects"/> is null.</exception>
    public List<T> Transform(List<T> listOfObjects)
    {
        if (_mean is null || _standardDeviation is null)
            throw new InvalidOperationException("This StandardScaler instance is not fitted yet. Call 'Fit' with appropriate arguments before using this estimator.");
        if (listOfObjects is null)
            throw new ArgumentNullException(nameof(listOfObjects));
        if (listOfObjects.Count == 0)
            return new List<T>();

        var columns = listOfObjects.ToArraysOfColumns<T>();
        for (int c = 0; c < columns.Length; c++)
        {
            double mean = _mean[c];
            double scale = ScaleOf(_standardDeviation[c]);
            var column = columns[c];
            for (int r = 0; r < column.Length; r++)
                column[r] = (column[r] - mean) / scale;
        }
        return ToListOfObject(columns);
    }

    // A zero standard deviation (constant column) or an undefined one (NaN, e.g. a single sample)
    // would turn the whole column into NaN/Infinity. Like scikit-learn, use a scale of 1 so the
    // column is only centered.
    private static double ScaleOf(double standardDeviation) => standardDeviation > 0 ? standardDeviation : 1.0;

    // Rebuild one new T per row from column-major data: arr[j][i] is property j of object i.
    // NOTE(review): assumes ToArraysOfColumns yields one column per property, in TypeDescriptor
    // order, and that every property accepts a double value — confirm against that helper.
    private static List<T> ToListOfObject(double[][] arr)
    {
        var properties = TypeDescriptor.GetProperties(typeof(T));
        var objectsCount = arr[0].Length;
        var res = new List<T>(objectsCount);
        for (int i = 0; i < objectsCount; i++)
        {
            T o = new();
            for (int j = 0; j < properties.Count; j++)
                properties[j].SetValue(o, arr[j][i]);
            res.Add(o);
        }
        return res;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment