Skip to content

Instantly share code, notes, and snippets.

@Piyush3dB
Created June 30, 2015 14:03
Show Gist options
  • Save Piyush3dB/01df75af9889414de1b6 to your computer and use it in GitHub Desktop.
Save Piyush3dB/01df75af9889414de1b6 to your computer and use it in GitHub Desktop.
Blahut-Arimoto algorithm implementation in Matlab
function [C r] = BlahutArimoto(p)
disp('BlahutArimoto')
% Capacity of discrete memoryless channel
% Blahut-Arimoto algorithm
% Input
% p: m x n matrix
% p is the transition matrix for a channel with m inputs and n outputs
%
% The input matrix p should contain no zero row and no zero column.
%
% p(i,j) is the condition probability that the channel output
% is j given that the input is i
% (i=1,2,...,m and j = 1,2,...,n)
%
%
% Output
% capacity : capacity in bits
% r: channel input distribution which achieves capacity
%
% For example, the transition matrix for the erasure channel is
% can be calculated as
% e = 0.5;
% p = [1-e e 0; 0 e 1-e]; % conditional prob. for erasure channel
% The capacity can be calculated by BlahutArimoto(p), and is equal to 1-e
%
% Check that the entries of input matrix p are non-negative
if ~isempty(find(p < 0))
disp('Error: some entry in the input matrix is negative')
C = 0; return;
end
% Check that the input matrix p does not have zero column
column_sum = sum(p);
if ~isempty(find(column_sum == 0))
disp('Error: there is a zero column in the input matrix');
C = 0; return;
end
% Check that the input matrix p does not have zero row
row_sum = sum(p,2);
if ~isempty(find(row_sum == 0))
disp('Error: there is a zero row in the input matrix');
C = 0; return;
else
p = diag(sum(p,2))^(-1) * p; % Make sure that the row sums are 1
end
[m n] = size(p);
r = ones(1,m)/m; % initial distribution for channel input
q = zeros(m,n);
error_tolerance = 1e-5/m;
r1 = [];
for i = 1:m
p(i,:) = p(i,:)/sum(p(i,:));
end
for iter = 1:10000
for j = 1:n
q(:,j) = r'.*p(:,j);
q(:,j) = q(:,j)/sum(q(:,j));
end
for i = 1:m
r1(i) = prod(q(i,:).^p(i,:));
end
r1 = r1/sum(r1);
if norm(r1 - r) < error_tolerance
break
else
r = r1;
end
end
C = 0;
for i = 1:m
for j = 1:n
if r(i) > 0 && q(i,j) > 0
C = C+ r(i)*p(i,j)* log(q(i,j)/r(i));
end
end
end
C = C/log(2); % Capacity in bits
@kumar-archit
Copy link

Isn't this 2 user case (where input alphabet is of size m and output alphabet is of size n)?

@HypeMachine
Copy link

in line 84 "C = C+ r(i)p(i,j) log(q(i,j)/r(i));" u must use log2() function instead log(), log() function returns value of ln

@suaybarslan
Copy link

suaybarslan commented Jul 11, 2022

isnt the line 69 supposed to be r1(i) = r(i)*prod((p(i,:)./sum(r'.*p(:,i))).^p(i,:));

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment