Skip to content

Instantly share code, notes, and snippets.

@ogrisel
Last active March 18, 2021 21:40
Show Gist options
  • Star 0 (you must be signed in to star a gist)
  • Fork 1 (you must be signed in to fork a gist)
  • Save ogrisel/bd2cf3350fff5bbd5a0899fa6baf3267 to your computer and use it in GitHub Desktop.
import numpy as np
from sklearn.utils.extmath import _incremental_mean_and_var
from sklearn.utils.sparsefuncs import _incr_mean_var_axis0
from sklearn.utils.sparsefuncs import _csc_mean_var_axis0
from sklearn.utils.sparsefuncs import _csr_mean_var_axis0
from scipy.sparse import csr_matrix
from scipy.sparse import csc_matrix
# Compare scikit-learn's dense and sparse (CSR / CSC) weighted mean/variance
# helpers on the same data, once in float64 and once in float32.
# The single feature holds the constant value 100. wherever it is observed
# (a few entries are set to NaN, i.e. missing), so every helper should report
# a mean of 100 and a variance of 0 regardless of the random sample weights;
# any other printed value points at a precision or NaN-handling problem.
for dtype in [np.float64, np.float32]:
    print(f"## dtype={dtype.__name__}")
    rng = np.random.RandomState(0)
    X = np.full(
        shape=(1000, 1),
        fill_value=100.,
        dtype=dtype,
    )
    n_samples, n_features = X.shape

    # Add some missing records which should be ignored:
    missing_indices = rng.choice(np.arange(n_samples), 10, replace=False)
    X[missing_indices, 0] = np.nan

    # Random positive weights:
    sample_weight = rng.rand(n_samples).astype(dtype)

    # Sparse copies of the same data, built once and shared by the helpers
    # that consume each format.
    X_csr = csr_matrix(X)
    X_csc = csc_matrix(X)

    # Dense incremental update starting from an empty (zero-count) state.
    # Fresh zero arrays are passed to every call below, in case a helper
    # updates its "last_*" inputs in place.
    mean, var, _ = _incremental_mean_and_var(
        X,
        last_mean=np.zeros(n_features, dtype=dtype),
        last_variance=np.zeros(n_features, dtype=dtype),
        last_sample_count=0,
        sample_weight=sample_weight,
    )
    print(_incremental_mean_and_var.__name__, mean, var)

    mean, var = _csr_mean_var_axis0(X_csr, weights=sample_weight)
    print(_csr_mean_var_axis0.__name__, mean, var)

    # last_n is sized like last_mean / last_var (one previous count per
    # feature) instead of being hard-coded to a single entry, so the script
    # stays valid if X is given more than one column.
    mean, var, _ = _incr_mean_var_axis0(
        X_csr,
        last_mean=np.zeros(n_features, dtype=dtype),
        last_var=np.zeros(n_features, dtype=dtype),
        last_n=np.zeros(n_features, dtype=dtype),
        weights=sample_weight,
    )
    print(_incr_mean_var_axis0.__name__, "csr", mean, var)

    mean, var = _csc_mean_var_axis0(X_csc, weights=sample_weight)
    print(_csc_mean_var_axis0.__name__, mean, var)

    mean, var, _ = _incr_mean_var_axis0(
        X_csc,
        last_mean=np.zeros(n_features, dtype=dtype),
        last_var=np.zeros(n_features, dtype=dtype),
        last_n=np.zeros(n_features, dtype=dtype),
        weights=sample_weight,
    )
    print(_incr_mean_var_axis0.__name__, "csc", mean, var)
@ogrisel
Copy link
Author

ogrisel commented Feb 24, 2021

I have edited the gist to add some NaN values and to print the value of the mean.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment