@mnarayan
Last active May 4, 2019 16:04
Investigating Canonical Correlations

Summary

Canonical Correlation Analysis (CCA) is a dimensionality reduction technique that finds linear projections of two sets of multivariate features, X and Y, measured on the same observations (rows), such that the projected features are maximally correlated.
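For reference, here is one standard way to state the objective. X is n × p, Y is n × r, the hats denote sample covariances as in the function's help text, and λ_j(·) is the j-th largest eigenvalue. The first pair of weights (a, b) maximizes the correlation between Xa and Yb. Later pairs repeat the maximization subject to being uncorrelated with the earlier canonical variates, so that there are at most min(p, r) pairs.

```latex
\rho_1 = \max_{\mathbf{a}\in\mathbb{R}^{p},\ \mathbf{b}\in\mathbb{R}^{r}}
  \frac{\mathbf{a}^{\top}\hat{\Sigma}_{XY}\,\mathbf{b}}
       {\sqrt{\mathbf{a}^{\top}\hat{\Sigma}_{XX}\,\mathbf{a}}\,\sqrt{\mathbf{b}^{\top}\hat{\Sigma}_{YY}\,\mathbf{b}}},
\qquad
\rho_j^{2} = \lambda_j\!\left(\hat{\Sigma}_{XX}^{-1/2}\,\hat{\Sigma}_{XY}\,\hat{\Sigma}_{YY}^{-1}\,\hat{\Sigma}_{YX}\,\hat{\Sigma}_{XX}^{-1/2}\right),
\quad j = 1,\dots,\min(p,r)
```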

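Because CCA is a supervised technique that chooses its projections to maximize the observed correlation between X and Y, it is easy to obtain extremely high in-sample canonical correlations that do not generalize to new observations, i.e. that reflect overfitting.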
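The following minimal Python/NumPy sketch shows how easily in-sample canonical correlations overfit. It is an illustration and not part of the gist. X and Y are drawn independently, so every population canonical correlation is zero. The sketch swaps in the QR-plus-SVD formulation, in which the canonical correlations are the singular values of Q_X^T Q_Y, instead of the eigendecomposition used in the MATLAB code. The helper name `canonical_correlations` is hypothetical.

```python
import numpy as np

def canonical_correlations(X, Y):
    """Sample canonical correlations via QR + SVD (cosines of principal angles).

    Hypothetical helper; assumes the centered X and Y have full column rank.
    """
    Qx, _ = np.linalg.qr(X - X.mean(axis=0))
    Qy, _ = np.linalg.qr(Y - Y.mean(axis=0))
    # Singular values of Qx^T Qy are the canonical correlations, largest first
    return np.linalg.svd(Qx.T @ Qy, compute_uv=False)

rng = np.random.default_rng(0)
p, r = 20, 20
rho_by_n = {}
for n in (50, 2000):
    X = rng.standard_normal((n, p))
    Y = rng.standard_normal((n, r))  # independent of X: all population canonical correlations are 0
    rho_by_n[n] = canonical_correlations(X, Y)
    print(f"n = {n:4d}: largest in-sample canonical correlation = {rho_by_n[n][0]:.2f}")
```

With only 50 observations for 20 + 20 features, the largest in-sample correlation is typically close to 1 even though X and Y are unrelated. With 2000 observations it shrinks toward 0.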

The MATLAB function in sample_canonical_correlations.m is designed to investigate out-of-sample canonical correlations. If one partitions the rows (observations) into training and test sets, then one can:

    1. run ordinary CCA on the training set;
    2. apply the canonical weights (W_X, W_Y) estimated on the training set to the test set, to obtain test-set canonical variates and out-of-sample canonical correlations;
    3. compare in-sample vs. out-of-sample canonical correlations.
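The three steps can be sketched in Python/NumPy as follows. This is a minimal illustration under stated assumptions, not a port of the gist's MATLAB function:

- CCA is fit through an SVD of the whitened cross-covariance. This is equivalent to the eigendecomposition of R_X described in the help text.
- The helper names `fit_cca` and `out_of_sample_rho` are hypothetical.
- The data are synthetic, with one shared signal whose population canonical correlation is 0.8.

As in the MATLAB function, the test data are centered with the training means.

```python
import numpy as np

def inv_sqrt_psd(S, eps=1e-10):
    """Inverse square root of a symmetric PSD matrix; near-zero eigenvalues are dropped."""
    w, V = np.linalg.eigh(S)
    keep = w > eps
    return (V[:, keep] / np.sqrt(w[keep])) @ V[:, keep].T

def fit_cca(X, Y):
    """Step 1 (hypothetical helper): ordinary CCA on the training rows."""
    n = X.shape[0]
    mu_x, mu_y = X.mean(axis=0), Y.mean(axis=0)
    Xc, Yc = X - mu_x, Y - mu_y
    Kx = inv_sqrt_psd(Xc.T @ Xc / (n - 1))   # Shat_XX^{-1/2}
    Ky = inv_sqrt_psd(Yc.T @ Yc / (n - 1))   # Shat_YY^{-1/2}
    Sxy = Xc.T @ Yc / (n - 1)
    # SVD of the whitened cross-covariance; singular values = canonical correlations
    E_x, rho, E_yT = np.linalg.svd(Kx @ Sxy @ Ky, full_matrices=False)
    return mu_x, mu_y, Kx @ E_x, Ky @ E_yT.T, rho

def out_of_sample_rho(X, Y, mu_x, mu_y, W_x, W_y):
    """Step 2 (hypothetical helper): correlate test-set canonical variates."""
    U = (X - mu_x) @ W_x   # centered with the training means
    V = (Y - mu_y) @ W_y
    return (U * V).sum(axis=0) / np.sqrt((U**2).sum(axis=0) * (V**2).sum(axis=0))

# Synthetic data: one shared signal z, population first canonical correlation 0.8
rng = np.random.default_rng(1)
n, p, r = 200, 10, 5
z = rng.standard_normal(n)
X = rng.standard_normal((n, p))
X[:, 0] += 2 * z
Y = rng.standard_normal((n, r))
Y[:, 0] += 2 * z
idx = rng.permutation(n)
train, test = idx[:100], idx[100:]

mu_x, mu_y, W_x, W_y, rho_train = fit_cca(X[train], Y[train])
rho_test = out_of_sample_rho(X[test], Y[test], mu_x, mu_y, W_x, W_y)
# Step 3: compare in-sample and out-of-sample correlations
for j, (a, b) in enumerate(zip(rho_train, rho_test), start=1):
    print(f"component {j}: in-sample {a:.2f}, out-of-sample {b:.2f}")
```

In a run like this, the first correlation should hold up on the test rows. The trailing pure-noise components tend to look stronger in-sample than out-of-sample.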

Consider the Combo-17 galaxy dataset, which consists of 6 features capturing redshift (Y) and 22 features of luminosity and brightness (X) measured on 3438 galaxies. The maximum number of canonical variates is therefore min(6, 22) = 6. Here is a comparison of in-sample and out-of-sample canonical correlations using 10-fold CV. The first canonical correlation is strong and has very low CV error, while correlations 3-6 tend to be stronger in the training set than in the test set. In-sample correlations are over-optimistic even in this relatively low-dimensional example, with over 3000 observations and only 6 and 22 features.

[Figure: CV Error in Canonical Correlations]
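A K-fold version of the same comparison averages in-sample and out-of-sample correlations per component across folds. The sketch below is hedged in several ways:

- It uses synthetic data as a stand-in, since the Combo-17 data are not bundled here.
- The helpers `fit_cca` and `out_of_sample_rho` are hypothetical and repeated so that the block is self-contained.
- It splits folds by hand with NumPy instead of MATLAB's `cvpartition`.

"Gap" here means the mean in-sample correlation minus the mean out-of-sample correlation.

```python
import numpy as np

def inv_sqrt_psd(S, eps=1e-10):
    """Inverse square root of a symmetric PSD matrix; near-zero eigenvalues are dropped."""
    w, V = np.linalg.eigh(S)
    keep = w > eps
    return (V[:, keep] / np.sqrt(w[keep])) @ V[:, keep].T

def fit_cca(X, Y):
    """In-sample CCA (hypothetical helper): training means, weights, correlations."""
    n = X.shape[0]
    mu_x, mu_y = X.mean(axis=0), Y.mean(axis=0)
    Xc, Yc = X - mu_x, Y - mu_y
    Kx = inv_sqrt_psd(Xc.T @ Xc / (n - 1))
    Ky = inv_sqrt_psd(Yc.T @ Yc / (n - 1))
    E_x, rho, E_yT = np.linalg.svd(Kx @ (Xc.T @ Yc / (n - 1)) @ Ky, full_matrices=False)
    return mu_x, mu_y, Kx @ E_x, Ky @ E_yT.T, rho

def out_of_sample_rho(X, Y, mu_x, mu_y, W_x, W_y):
    """Correlate held-out canonical variates, centering with the training means."""
    U, V = (X - mu_x) @ W_x, (Y - mu_y) @ W_y
    return (U * V).sum(axis=0) / np.sqrt((U**2).sum(axis=0) * (V**2).sum(axis=0))

# Synthetic stand-in data: one shared signal between the first columns of X and Y
rng = np.random.default_rng(2)
n, p, r, n_folds = 300, 12, 6, 10
z = rng.standard_normal(n)
X = rng.standard_normal((n, p))
X[:, 0] += 2 * z
Y = rng.standard_normal((n, r))
Y[:, 0] += 2 * z

rho_train, rho_test = [], []
for test in np.array_split(rng.permutation(n), n_folds):
    train = np.setdiff1d(np.arange(n), test)
    mu_x, mu_y, W_x, W_y, rho_in = fit_cca(X[train], Y[train])
    rho_train.append(rho_in)
    rho_test.append(out_of_sample_rho(X[test], Y[test], mu_x, mu_y, W_x, W_y))

mean_train = np.mean(rho_train, axis=0)  # average in-sample correlation per component
mean_test = np.mean(rho_test, axis=0)    # average out-of-sample correlation per component
gap = mean_train - mean_test             # over-optimism of the in-sample estimate
for j in range(len(gap)):
    print(f"component {j + 1}: train {mean_train[j]:.2f}, test {mean_test[j]:.2f}, gap {gap[j]:.2f}")
```

On data like this, the gap for the first (real) component is usually small. The noise components show a clearly positive gap, which matches the qualitative pattern described for Combo-17.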

Usage

Use help sample_canonical_correlations to see an example of how to use the function.

function [cca_rho cca_v cca_cv] = sample_canonical_correlation(X,Y,varargin)
% SAMPLE_CANONICAL_CORRELATION
%
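% Usage: [cca_rho, cca_v, cca_cv] = sample_canonical_correlation(X, Y)
%        [cca_rho, cca_v, cca_cv] = sample_canonical_correlation(X, Y, options)
%
% Inputs:
% - X is the test set data matrix of n_samples x p features
% - Y is the test set data matrix of n_samples x r features
% - options (optional struct) with fields
%   - options.W_X is the linear projection matrix for X (p x t)
%   - options.W_Y is the linear projection matrix for Y (r x t)
%   - options.mu_X is the training mean of X (1 x p)
%   - options.mu_Y is the training mean of Y (1 x r)
%   If options is omitted, CCA is fit in-sample on (X, Y) itself.
%
% Outputs:
% - cca_rho is a t x 1 vector of canonical correlations on (X, Y)
% - cca_v is an n_samples x t x 2 array holding the canonical variates
%   U = (X - mu_X) * W_X and V = (Y - mu_Y) * W_Y
% - cca_cv is a 1 x t vector of cumulative sums of squared canonical
%   correlations, cca_cv(k) = sum(cca_rho(1:k).^2)
%
% where,
% (X * W_X, Y * W_Y) produces up to t canonical variates,
% and t <= min(p,r)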
%
% Description:
%
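% W_X should be Shat_XX^{-1/2} times the leading t eigenvectors of R_X
% (sorted by decreasing eigenvalue), where
% R_X = Shat_XX^{-1/2} * Shat_XY * Shat_YY^{-1} * Shat_YX * Shat_XX^{-1/2}
%
% W_Y should be Shat_YY^{-1/2} times the leading t eigenvectors of R_Y, where
% R_Y = Shat_YY^{-1/2} * Shat_YX * Shat_XX^{-1} * Shat_XY * Shat_YY^{-1/2}
%
% The nonzero eigenvalues of R_X and R_Y are the squared canonical
% correlations. The coefficient matrices A and B returned by canoncorr
% satisfy this up to sign, so they can be passed directly as W_X and W_Y.
%
% The sample covariances (Shat_XX, Shat_YY, Shat_XY) are expected to have been obtained on the training data.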
%
%
% References
% Izenman, A. J. (2008). Modern Multivariate Statistical Techniques. Springer.
%
% Copyright 2017, Manjari Narayan
%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Combo17-Galaxy Example
%%%%%%%%%%%%%%%%%%%%%%%%%%%
% if ~exist('Combo17Galaxy.csv','file')
%   websave('Combo17Galaxy.csv','https://astrostatistics.psu.edu/datasets/COMBO17.csv');
% end
% Combo17 = readtable('Combo17Galaxy.csv','ReadVariableNames',true);
% Yidx = [1 2 4 5 6 8 9]; Xidx = [10:2:16 30:2:65];
% exclude_idx = find(sum(isnan(table2array(Combo17)),2)>0)';
% include_idx = setdiff(1:height(Combo17),exclude_idx);
% n_samples = length(include_idx);
% X = table2array(Combo17(include_idx,Xidx));
% Y = table2array(Combo17(include_idx,Yidx));
% cvobj = cvpartition(n_samples,'Kfold',10);
%
% for foldNo=1:cvobj.NumTestSets
% cca_opts = struct();
% cca_opts.mu_X = mean(X(cvobj.training(foldNo),:));
% cca_opts.mu_Y = mean(Y(cvobj.training(foldNo),:));
% [cca_opts.W_X cca_opts.W_Y rho_train{foldNo}] = ...
% canoncorr(X(cvobj.training(foldNo),:),Y(cvobj.training(foldNo),:));
% [cca_rho{foldNo} cca_v{foldNo} cca_cv{foldNo}] = ...
% sample_canonical_correlation(...
% X(cvobj.test(foldNo),:), ...
% Y(cvobj.test(foldNo),:), ...
% cca_opts);
% end
%% Compare in-sample correlations rho_train with out-of-sample correlations cca_rho
% %%%%%%%%%%%%%%%%%%%%%%%%%%%
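options = struct();
if nargin == 2
    % No training estimates supplied: fit CCA in-sample on (X, Y) itself
    n = size(X,1);
    options.mu_X = mean(X);
    options.mu_Y = mean(Y);
    Shat_XX = cov(X);
    Shatinv_XX = pinv(Shat_XX);
    Shat_YY = cov(Y);
    Shatinv_YY = pinv(Shat_YY);
    % Cross-covariance, normalized by (n-1) to match cov()
    Shat_XY = bsxfun(@minus, X, options.mu_X)' * ...
        bsxfun(@minus, Y, options.mu_Y) / (n-1);
    % real() discards round-off imaginary parts from sqrtm
    Sinvsqrt_XX = real(sqrtm(Shatinv_XX));
    Sinvsqrt_YY = real(sqrtm(Shatinv_YY));
    R_X = Sinvsqrt_XX*Shat_XY*Shatinv_YY*Shat_XY'*Sinvsqrt_XX;
    R_Y = Sinvsqrt_YY*Shat_XY'*Shatinv_XX*Shat_XY*Sinvsqrt_YY;
    t = min(size(X,2),size(Y,2));
    % eig does not return eigenvalues in descending order, so sort them;
    % symmetrize first so round-off cannot produce complex output
    [V, D] = eig((R_X + R_X')/2);
    [~, order] = sort(diag(D), 'descend');
    % Canonical weights are Shat_XX^{-1/2} times the leading eigenvectors
    options.W_X = Sinvsqrt_XX * V(:, order(1:t));
    [V, D] = eig((R_Y + R_Y')/2);
    [~, order] = sort(diag(D), 'descend');
    options.W_Y = Sinvsqrt_YY * V(:, order(1:t));
    % Eigenvectors are defined only up to sign: flip columns of W_Y so
    % that each in-sample canonical correlation is non-negative
    s = sign(diag(options.W_X' * Shat_XY * options.W_Y));
    s(s == 0) = 1;
    options.W_Y = bsxfun(@times, options.W_Y, s');
else
    options = varargin{1};
end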
Xcenter = bsxfun(@minus, X, options.mu_X);
Ycenter = bsxfun(@minus, Y, options.mu_Y);
% Canonical variates (U, V), centered with the training means
U_CCA = Xcenter * options.W_X;
V_CCA = Ycenter * options.W_Y;
cca_v = cat(3, U_CCA, V_CCA);
% Correlation of each pair of canonical variates (U_j, V_j)
k = min(size(U_CCA,2), size(V_CCA,2));
cca_sdX = sqrt(diag(U_CCA' * U_CCA));
cca_sdY = sqrt(diag(V_CCA' * V_CCA));
cca_cov = U_CCA' * V_CCA;
cca_rho = diag(cca_cov) ./ (cca_sdX(1:k) .* cca_sdY(1:k));
t = min([size(X,2) size(Y,2) k]);
disp('Top canonical correlations (up to 3)')
disp(cca_rho(1:min(3,t)))
% Cumulative sum of squared canonical correlations (1 x t)
cca_cv = cumsum(cca_rho(1:t).^2)';
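The link between R_X, R_Y and the canonical weights can be cross-checked outside MATLAB. The NumPy sketch below is a check of the linear algebra, not a translation of the function. It uses synthetic data and a hypothetical helper `leading_eigvecs`, and it checks three things:

- The eigenvalues of R_X, taken in descending order, are the squared canonical correlations.
- The weights W_X = Shat_XX^{-1/2} E_X give canonical variates with unit variance.
- After the sign alignment of W_Y, each pair of variates has correlation equal to the corresponding singular value of the whitened cross-covariance.

```python
import numpy as np

def inv_sqrt_psd(S):
    """Inverse square root of a symmetric positive definite matrix via eigh."""
    w, V = np.linalg.eigh(S)
    return (V / np.sqrt(w)) @ V.T

def leading_eigvecs(R, t):
    """Top-t eigenpairs of a symmetric matrix (eigh returns ascending order)."""
    w, E = np.linalg.eigh((R + R.T) / 2)
    order = np.argsort(w)[::-1][:t]
    return w[order], E[:, order]

rng = np.random.default_rng(3)
n, p, r = 500, 6, 4
X = rng.standard_normal((n, p))
# Synthetic Y that depends on the first r columns of X plus noise
Y = 0.5 * X[:, :r] @ rng.standard_normal((r, r)) + rng.standard_normal((n, r))

Xc, Yc = X - X.mean(axis=0), Y - Y.mean(axis=0)
Shat_XX, Shat_YY = Xc.T @ Xc / (n - 1), Yc.T @ Yc / (n - 1)
Shat_XY = Xc.T @ Yc / (n - 1)
K_X, K_Y = inv_sqrt_psd(Shat_XX), inv_sqrt_psd(Shat_YY)

R_X = K_X @ Shat_XY @ np.linalg.inv(Shat_YY) @ Shat_XY.T @ K_X
R_Y = K_Y @ Shat_XY.T @ np.linalg.inv(Shat_XX) @ Shat_XY @ K_Y
t = min(p, r)

lam, E_X = leading_eigvecs(R_X, t)
_, E_Y = leading_eigvecs(R_Y, t)
W_X = K_X @ E_X  # canonical weights for X
W_Y = K_Y @ E_Y  # canonical weights for Y (sign still arbitrary)
s = np.sign(np.diag(W_X.T @ Shat_XY @ W_Y))
s[s == 0] = 1
W_Y = W_Y * s    # flip signs so each pair has non-negative covariance

U, V = Xc @ W_X, Yc @ W_Y
rho_pairs = np.array([np.corrcoef(U[:, j], V[:, j])[0, 1] for j in range(t)])
rho_eig = np.sqrt(np.clip(lam, 0, None))
rho_svd = np.linalg.svd(K_X @ Shat_XY @ K_Y, compute_uv=False)
print(np.round(rho_pairs, 3))
```

If the eigenvectors were used directly as weights without the Shat^{-1/2} factor, or taken in ascending eigenvalue order, these checks would generally fail.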