Skip to content

Instantly share code, notes, and snippets.

@hadarl
Last active October 14, 2015 14:35
Show Gist options
  • Save hadarl/38b1fda6d2297c894bc5 to your computer and use it in GitHub Desktop.
Save hadarl/38b1fda6d2297c894bc5 to your computer and use it in GitHub Desktop.
information bottleneck code
% calc_info_curve(pXY, betaVec)
%
% Script: builds a small joint distribution p(x,y), runs IB for a range of
% beta values and plots the resulting mutual-information values.
% Info(i,1) holds the first output of info_curve_point (IX_Xhat) and
% Info(i,2) the second (IXhat_Y) for betaVec(i).

% Joint distribution: 9 values of x (rows) by 2 values of y (columns),
% normalised so that all entries sum to 1.  Left unsuppressed on purpose so
% the matrices are echoed to the command window.
pXY = [(9:-1:1)', (1:9)']
pXY = pXY ./ sum(pXY(:))

figure;
imagesc(pXY.');
xlabel('x');
ylabel('y');

% beta grid: logspace( power of min beta, power of max beta, number of beta values )
n = 30;
betaVec = logspace(0, 3, n);
Info = zeros(n, 2);

% Start from the identity mapping between x and xhat.
xDim = size(pXY, 1);
p0Xhat_X = eye(xDim);

% Walk from the largest beta down to the smallest; each run is started from
% the mapping returned by the previous (larger-beta) run.
for i = n:-1:1
    [pXhat_X, pY_Xhat] = IB(pXY, betaVec(i), p0Xhat_X);
    [Info(i,1), Info(i,2)] = info_curve_point(pXhat_X, pY_Xhat, pXY);
    p0Xhat_X = pXhat_X;

    % Show the current mapping; the pause lets the figure refresh.
    imagesc(pXhat_X);
    pause(0.1);
end

figure;
plot(betaVec, Info(:,1));
title('I(X;Xhat)');

figure;
plot(betaVec, Info(:,2));
title('I(Xhat;Y)');

figure;
plot(Info(:,1), Info(:,2));
title('I(Xhat;Y) as a function of I(X;Xhat)');
function [DKL_sum] = DKL2(p,q)
% DKL2  Kullback-Leibler divergence D(p||q), measured in bits.
%
%   DKL_sum = DKL2(p,q) flattens p and q into column vectors and returns
%   sum(p .* log2(p ./ q)).
%
%   p, q : arrays with the same number of elements (any shape).
%
%   Elements where p is exactly 0 contribute 0 to the sum (the usual
%   0*log(0) = 0 convention).  Without this, 0*log2(0) evaluates to
%   0*(-Inf) = NaN and poisons the whole result as soon as a distribution
%   contains an exact zero.  Elements with p > 0 and q == 0 still give Inf.
p = p(:);
q = q(:);
terms = p.*log2(p./q);
% Zero-probability outcomes carry no weight in the expectation.
terms(p == 0) = 0;
DKL_sum = sum(terms);
function [pXhat_X, pY_Xhat] = IB(pXY, beta, p0Xhat_X)
% IB  Repeat IB_iteration until its objective value L stops changing.
%
%   [pXhat_X, pY_Xhat] = IB(pXY, beta, p0Xhat_X) starts from the mapping
%   p0Xhat_X and keeps calling IB_iteration, feeding each result back in,
%   until the absolute change in L between two consecutive calls is no
%   longer greater than 1e-6.  The outputs are those of the last call.
%
%   The per-iteration values of L and of the two info_curve_point outputs
%   are recorded in LIB, InfoXXhat and InfoXhatY; they are only consumed by
%   the optional plotting code kept in the comment block at the bottom.

%{
%% Stand-alone setup (enable to run this file without arguments)
%% Constructing the joint distribution p(x,y)
pXY = [9:-1:1; 1:9]'
[xDim, yDim] = size(pXY);
pXY = pXY/sum(pXY(:))
figure;
imagesc(pXY')
xlabel('x')
ylabel('y')
beta = 100;
p0Xhat_X = eye(xDim);
% p0Xhat_X = ones(xDim)/xDim;
%}

%%
tol = 0.000001;

% History buffers (they grow automatically past 10 entries if needed).
LIB = zeros(10,1);
InfoXXhat = zeros(10,1);
InfoXhatY = zeros(10,1);

[pXhat_X, pY_Xhat, L] = IB_iteration(pXY, beta, p0Xhat_X);
nIter = 1;
while true
    % Record the state produced by the previous update.
    LIB(nIter) = L;
    [InfoXXhat(nIter), InfoXhatY(nIter)] = info_curve_point(pXhat_X, pY_Xhat, pXY);

    % One more update, started from the current mapping.
    prevL = L;
    [pXhat_X, pY_Xhat, L] = IB_iteration(pXY, beta, pXhat_X);
    nIter = nIter + 1;

    % Written as ~(... > tol) so that a NaN difference also stops the loop.
    if ~(abs(L - prevL) > tol)
        break;
    end
end

% Record the final state as well.
LIB(nIter) = L;
[InfoXXhat(nIter), InfoXhatY(nIter)] = info_curve_point(pXhat_X, pY_Xhat, pXY);

%%
%{
% Optional convergence plots
figure;
plot(LIB(1:nIter))
figure;
plot(InfoXXhat(1:nIter))
hold on
plot(InfoXhatY(1:nIter),'k')
plot(InfoXXhat(1:nIter) - beta*InfoXhatY(1:nIter),'r')
%}
function [pXhat_X, pY_Xhat, L] = IB_iteration(pXY, beta, p0Xhat_X)
% IB_ITERATION  One update step of the information-bottleneck mapping.
%
%   [pXhat_X, pY_Xhat, L] = IB_iteration(pXY, beta, p0Xhat_X)
%
%   Inputs
%     pXY      : joint distribution, rows index x and columns index y.
%     beta     : weight of the IXhat_Y term in the objective L.
%     p0Xhat_X : current mapping, rows index xhat and columns index x.
%                The code assumes it is xDim-by-xDim (as many xhat values
%                as x values).
%
%   Outputs
%     pXhat_X  : updated mapping, same layout as p0Xhat_X; each column is
%                normalised to sum to 1.
%     pY_Xhat  : distribution of y given xhat implied by the updated
%                mapping, rows index y and columns index xhat.
%     L        : IX_Xhat - beta*IXhat_Y, evaluated with info_curve_point
%                on the updated distributions.
[xDim, yDim] = size(pXY);

% Marginal of x as a column vector (same form as in info_curve_point).
pX = sum(pXY,2);

% Marginal of xhat under the current mapping (row vector).
p0Xhat = pX'*p0Xhat_X';

%Bayes: invert the mapping; result has rows = x, columns = xhat.
p0X_Xhat = (p0Xhat_X.*repmat(pX',[xDim 1])./repmat(p0Xhat',[1 xDim]))';

% Conditional of y given x (rows = y, columns = x) and the y-distribution
% of every xhat (rows = y, columns = xhat).
pY_X = (pXY./repmat(pX,[1 yDim]))';
p0Y_Xhat = pY_X * p0X_Xhat;

% DKL_X_Xhat(i,j) = divergence between the y-distribution of x = i and the
% y-distribution of xhat = j.
DKL_X_Xhat = zeros(xDim);
for i = 1:xDim
    for j = 1:xDim
        DKL_X_Xhat(i,j) = DKL2(pY_X(:,i),p0Y_Xhat(:,j));
    end
end

% Subtract, for every x, its smallest divergence before exponentiating.
% This multiplies all weights of that x by the same constant, which cancels
% in the normalisation below, but it keeps the largest weight at exp(0) = 1
% so that a large beta cannot underflow every weight of an x to 0 (which
% would give 0/0 = NaN).
DKL_shifted = DKL_X_Xhat - repmat(min(DKL_X_Xhat,[],2),[1 xDim]);
unnorm_pXhat_X = repmat(p0Xhat,[xDim 1]).*exp(-beta*DKL_shifted);

% different rows correspond to different Xhats:
unnorm_pXhat_X = unnorm_pXhat_X';

% Normaliser for each x: the sum over xhat, i.e. over each column.
ZX_beta = repmat(sum(unnorm_pXhat_X,1),[xDim 1]);
pXhat_X = unnorm_pXhat_X./ZX_beta;

% Recompute the xhat marginal, the inverted mapping and the y-distribution
% of every xhat with the updated mapping.
pXhat = pX'*pXhat_X';
pX_Xhat = (pXhat_X.*repmat(pX',[xDim 1])./repmat(pXhat',[1 xDim]))';
pY_Xhat = pY_X * pX_Xhat;

[IX_Xhat, IXhat_Y] = info_curve_point(pXhat_X, pY_Xhat, pXY);
L = IX_Xhat - beta*IXhat_Y;
function [IX_Xhat, IXhat_Y] = info_curve_point(pXhat_X, pY_Xhat, pXY)
% INFO_CURVE_POINT  Two divergence values computed from a mapping.
%
%   [IX_Xhat, IXhat_Y] = info_curve_point(pXhat_X, pY_Xhat, pXY)
%
%   pXhat_X : rows index xhat, columns index x.
%   pY_Xhat : rows index y, columns index xhat.
%   pXY     : rows index x, columns index y.
%
%   IX_Xhat is DKL2 between the joint of (xhat, x) and the product of its
%   marginals; IXhat_Y is DKL2 between the joint of (y, xhat) and the
%   product of its marginals.
nX = size(pXY, 1);
nY = size(pXY, 2);

% Marginals of x (column vector) and of xhat (row vector).
margX = sum(pXY, 2);
margXhat = margX' * pXhat_X';

% Joint of (xhat, x): scale every column of the mapping by the weight of
% its x.  ones(nX,1)*margX' stacks nX copies of the row margX'.
jointXhatX = pXhat_X .* (ones(nX, 1) * margX');
IX_Xhat = DKL2(jointXhatX, margXhat' * margX');

% Joint of (y, xhat): scale every column of pY_Xhat by the weight of its
% xhat, then compare against the product of the y and xhat marginals.
jointYXhat = pY_Xhat .* (ones(nY, 1) * margXhat);
margY = margXhat * pY_Xhat';
IXhat_Y = DKL2(jointYXhat, margY' * margXhat);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment