@jasmainak
Forked from wmvanvliet/haufe.ipynb
Created July 14, 2017 09:25
Testing the Haufe trick
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test the Haufe trick"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"from sklearn.linear_model import LinearRegression\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook tests the computation of patterns from weights, following equation (6) of Haufe et al. 2014 (Neuroimage).\n",
"\n",
"$ A = \\Sigma_X\\, W\\, \\Sigma_{\\,\\hat{Y}}^{-1} $\n",
"\n",
"(The original equation uses $\\hat{s}$ instead of $\\hat{Y}$, but we'll be following scikit-learn's notation here.)\n",
"\n",
"\n",
"We will test this by generating data from a known \"pattern\" (i.e. forward model) that satisfies all the assumptions required for the equation to hold."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Fix random seed for consistency\n",
"np.random.seed(42)\n",
"\n",
"\n",
"def _gen_data(noise_scale=2):\n",
"    \"\"\"Generate some testing data.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    noise_scale : float\n",
"        The amount of noise (in standard deviations) to add to the data.\n",
"\n",
"    Returns\n",
"    -------\n",
"    X : ndarray, shape (n_samples, n_features)\n",
"        The measured data.\n",
"    Y : ndarray, shape (n_samples, n_targets)\n",
"        The latent variables generating the data.\n",
"    A : ndarray, shape (n_features, n_targets)\n",
"        The forward model, mapping the latent variables (=Y) to the measured\n",
"        data (=X).\n",
"    \"\"\"\n",
"    N = 1000  # Number of samples\n",
"    M = 5  # Number of features\n",
"\n",
"    # Y has 3 targets and the following covariance:\n",
"    cov_Y = np.array([\n",
"        [10, 1, 2],\n",
"        [1, 5, 1],\n",
"        [2, 1, 3],\n",
"    ])\n",
"    mean_Y = np.array([1, -3, 7])\n",
"    Y = np.random.multivariate_normal(mean_Y, cov_Y, size=N)\n",
"    Y += [1, 4, 2]  # Add an extra offset\n",
"\n",
"    # The pattern (=forward model)\n",
"    A = np.array([\n",
"        [1, 10, -3],\n",
"        [4, 1, 8],\n",
"        [3, -2, 4],\n",
"        [1, 1, 1],\n",
"        [7, 6, 0],\n",
"    ]).astype(float)\n",
"\n",
"    X = Y.dot(A.T)\n",
"    X += noise_scale * np.random.randn(N, M)\n",
"    X += [5, 2, 6, 3, 9]  # Add an extra offset\n",
"\n",
"    return X, Y, A"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Generate data without any noise, so we can perfectly reconstruct the pattern A from the measured data\n",
"X, Y, A = _gen_data(noise_scale=0)\n",
"\n",
"# This data is not normalized (i.e. not zero-mean and not unit variance)\n",
"assert (np.abs(X.mean(axis=0)) > 0.1).all()\n",
"assert (X.std(axis=0) > 1.1).all()\n",
"assert (np.abs(Y.mean(axis=0)) > 0.1).all()\n",
"assert (Y.std(axis=0) > 1.1).all()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now proceed by fitting a standard linear regression model to the data and applying the equation."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"model = LinearRegression(normalize=False, fit_intercept=True).fit(X, Y)\n",
"W = model.coef_\n",
"\n",
"def haufe_trick(W, X, Y):\n",
"    \"\"\"Perform the Haufe trick.\"\"\"\n",
"    # Computing the covariance of X and Y involves removing the mean\n",
"    X_ = X - X.mean(axis=0)\n",
"    Y_ = Y - Y.mean(axis=0)\n",
"    cov_X = X_.T.dot(X_)\n",
"    cov_Y = Y_.T.dot(Y_)\n",
"\n",
"    # The Haufe trick. Note that the covariances are unnormalized: the\n",
"    # missing 1/N factors cancel between cov_X and the inverse of cov_Y.\n",
"    A_hat = cov_X.dot(W.T).dot(np.linalg.pinv(cov_Y))\n",
"    return A_hat\n",
"\n",
"A_hat = haufe_trick(W, X, Y)"
]
},
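{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (not part of the original derivation): `np.cov` normalizes the covariance by $N - 1$, whereas `haufe_trick` above uses unnormalized covariances. The scale factors cancel between $\\Sigma_X$ and $\\Sigma_{\\,\\hat{Y}}^{-1}$, so both versions should yield the same pattern."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Same computation using np.cov: the (N - 1) normalization factors\n",
"# cancel between cov(X) and the pseudo-inverse of cov(Y).\n",
"A_hat2 = np.cov(X.T).dot(W.T).dot(np.linalg.pinv(np.cov(Y.T)))\n",
"print('Scaling cancels?', np.allclose(A_hat, A_hat2))"
]
},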
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How did we do?"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Real pattern\n",
"[[ 1. 10. -3.]\n",
" [ 4. 1. 8.]\n",
" [ 3. -2. 4.]\n",
" [ 1. 1. 1.]\n",
" [ 7. 6. 0.]]\n",
"Estimated pattern\n",
"[[ 1.00000000e+00 1.00000000e+01 -3.00000000e+00]\n",
" [ 4.00000000e+00 1.00000000e+00 8.00000000e+00]\n",
" [ 3.00000000e+00 -2.00000000e+00 4.00000000e+00]\n",
" [ 1.00000000e+00 1.00000000e+00 1.00000000e+00]\n",
" [ 7.00000000e+00 6.00000000e+00 -8.88178420e-15]]\n",
"Are they equal? True\n"
]
}
],
"source": [
"print('Real pattern')\n",
"print(A)\n",
"print('Estimated pattern')\n",
"print(A_hat)\n",
"print('Are they equal?', np.allclose(A, A_hat))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The thing about normalization\n",
"\n",
"In the above case, the linear regression model did not normalize the data. What happens if we do?"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Real pattern\n",
"[[ 1. 10. -3.]\n",
" [ 4. 1. 8.]\n",
" [ 3. -2. 4.]\n",
" [ 1. 1. 1.]\n",
" [ 7. 6. 0.]]\n",
"Estimated pattern\n",
"[[ 1.00000000e+00 1.00000000e+01 -3.00000000e+00]\n",
" [ 4.00000000e+00 1.00000000e+00 8.00000000e+00]\n",
" [ 3.00000000e+00 -2.00000000e+00 4.00000000e+00]\n",
" [ 1.00000000e+00 1.00000000e+00 1.00000000e+00]\n",
" [ 7.00000000e+00 6.00000000e+00 1.24344979e-14]]\n",
"Are they equal? True\n"
]
}
],
"source": [
"model = LinearRegression(normalize=True, fit_intercept=True).fit(X, Y)\n",
"W = model.coef_\n",
"A_hat = haufe_trick(W, X, Y)\n",
"\n",
"print('Real pattern')\n",
"print(A)\n",
"print('Estimated pattern')\n",
"print(A_hat)\n",
"print('Are they equal?', np.allclose(A, A_hat))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It still works! That is because scikit-learn undoes the normalization before storing the fitted filter weights in `coef_`, so the weights are expressed in the original units of the data."
]
}
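,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a final check (a sketch assuming the old `normalize` parameter of `LinearRegression`, which was removed in later scikit-learn releases), we can verify that the stored weights are the same whether or not normalization was used:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# scikit-learn rescales coef_ back to the original units of X, so the\n",
"# weights should match (up to numerical precision) with and without\n",
"# internal normalization.\n",
"W_raw = LinearRegression(normalize=False, fit_intercept=True).fit(X, Y).coef_\n",
"W_norm = LinearRegression(normalize=True, fit_intercept=True).fit(X, Y).coef_\n",
"print('Same weights?', np.allclose(W_raw, W_norm))"
]
}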
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}