Skip to content

Instantly share code, notes, and snippets.

@GaelVaroquaux
Last active July 18, 2021 12:35
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save GaelVaroquaux/047d13d738d89ddcd4bc297edcd53233 to your computer and use it in GitHub Desktop.
Save GaelVaroquaux/047d13d738d89ddcd4bc297edcd53233 to your computer and use it in GitHub Desktop.
Linear deconfounding in a fit-transform API
"""
A scikit-learn like transformer to remove a confounding effect on X.
"""
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.linear_model import LinearRegression
import numpy as np
class DeConfounder(BaseEstimator, TransformerMixin):
    """A transformer removing the effect of y on X.

    ``fit`` regresses X on the confound ``y`` using a clone of
    ``confound_model``. ``transform`` subtracts from X the part that the
    fitted model predicts from ``y``, returning the residuals.

    Parameters
    ----------
    confound_model : estimator, default=LinearRegression()
        Regressor used to model X as a function of y. It is cloned in
        ``fit``, so the instance passed here is never fitted itself.

    Attributes
    ----------
    confound_model_ : estimator
        The fitted clone of ``confound_model``.

    Notes
    -----
    Unlike most scikit-learn transformers, ``transform`` requires ``y``: the
    confound must be known for every sample being cleaned.
    """

    def __init__(self, confound_model=LinearRegression()):
        # The default instance is shared by every DeConfounder created without
        # an explicit model; this is safe because ``fit`` clones it first.
        self.confound_model = confound_model

    @staticmethod
    def _confound_matrix(y):
        """Return ``y`` as an array, adding a trailing axis if it is 1-D."""
        # np.asarray accepts lists and pandas Series; the latter no longer
        # support ``y[:, np.newaxis]`` multi-dimensional indexing.
        y = np.asarray(y)
        if y.ndim == 1:
            y = y[:, np.newaxis]
        return y

    def fit(self, X, y):
        """Fit the confound model predicting X from y.

        Parameters
        ----------
        X : array-like
            Data to deconfound; used as the regression target.
        y : array-like
            Confound values; a 1-D input is treated as a single feature.

        Returns
        -------
        self : DeConfounder
            The fitted transformer.
        """
        confound_model = clone(self.confound_model)
        confound_model.fit(self._confound_matrix(y), X)
        self.confound_model_ = confound_model
        return self

    def transform(self, X, y):
        """Remove from X the component predicted from y.

        Parameters
        ----------
        X : array-like
            Data to clean.
        y : array-like
            Confound values for the samples in X.

        Returns
        -------
        X_clean
            ``X`` minus the fitted model's prediction from ``y``.
        """
        X_confounds = self.confound_model_.predict(self._confound_matrix(y))
        return X - X_confounds

    def fit_transform(self, X, y):
        """Fit on (X, y), then return the deconfounded X.

        Overrides ``TransformerMixin.fit_transform``, which calls
        ``transform(X)`` without ``y`` and would raise ``TypeError`` here.
        """
        return self.fit(X, y).transform(X, y)
def test_deconfounder():
    """Check that deconfounded data is orthogonal to the confound."""
    rng = np.random.RandomState(0)
    model = DeConfounder()

    # In-sample: X and y are drawn independently; after fitting and
    # transforming on the same data, the residuals must be orthogonal to y.
    features = rng.normal(size=(100, 10))
    target = rng.normal(size=100)
    model.fit(features, target)
    residuals = model.transform(features, target)
    np.testing.assert_almost_equal(residuals.T.dot(target), 0)

    # Out-of-sample: X is an exact linear function of y. Fit on the first
    # 90 samples only, then clean all 100.
    target = rng.normal(size=100)
    weights = rng.normal(size=10)
    features = np.outer(target, weights)
    n_train = len(target) - 10
    model.fit(features[:n_train], target[:n_train])
    residuals = model.transform(features, target)
    np.testing.assert_almost_equal(residuals.T.dot(target), 0)
if __name__ == '__main__':
    # Executed directly as a script: run the self-test.
    test_deconfounder()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment