Skip to content

Instantly share code, notes, and snippets.

@jvlmdr
Last active January 13, 2016 09:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jvlmdr/9d0f214aca8f627dacde to your computer and use it in GitHub Desktop.
Save jvlmdr/9d0f214aca8f627dacde to your computer and use it in GitHub Desktop.
function benchmark()
%BENCHMARK Run the correctness tests, then time each multiply variant.
%   Prints the mean time per call (seconds, via time_func) of mul1, mul2,
%   blkdiagmul1, blkdiagmul2 and blkdiagmul3 on random data with
%   m1 = 32, m2 = 48, k = 24.
test();
m1 = 32;
m2 = 48;
k = 24;
x = randn(m1, m2, k);
af = covar1(randn(m1, m2, k));
t = time_func(@() mul1(af, x));
fprintf('mul1: %.3g\n', t);
t = time_func(@() mul2(af, x));
fprintf('mul2: %.3g\n', t);
t = time_func(@() blkdiagmul1(af, x));
fprintf('blkdiagmul1: %.3g\n', t);
% Channel-first layouts for blkdiagmul2/3. The permutes happen outside the
% timed closures so only the multiply itself is measured.
xr = permute(x, [3 1 2]);
ar = permute(af, [3 4 1 2]);
t = time_func(@() blkdiagmul2(ar, xr));
fprintf('blkdiagmul2: %.3g\n', t);
% The closure above captured the previous ar by value; reassigning is safe.
ar = permute(af, [4 3 1 2]);
t = time_func(@() blkdiagmul3(ar, xr));
fprintf('blkdiagmul3: %.3g\n', t);
end
function t = time_func(f, n, nburn)
%TIME_FUNC Mean wall-clock time of one call to f, in seconds.
%   t = time_func(f) calls f 10 times untimed (warm-up), then returns the
%   average time per call over 100 timed calls.
%   t = time_func(f, n, nburn) uses n timed calls and nburn warm-up calls.
if nargin < 2
    n = 100;
end
if nargin < 3
    nburn = 10;
end
if ~(isscalar(n) && n >= 1)
    error('time_func:badN', 'n must be a positive integer');
end
% burn in
for iter = 1:nburn
    f();
end
% Use a timer ID so a tic() inside f cannot reset our measurement start.
tstart = tic();
for iter = 1:n
    f();
end
t = toc(tstart) / n;
end
function test()
%TEST Check the fast multiply variants against the direct reference mul0.
%   Uses small random data (m1 = 4, m2 = 5, k = 3). Raises an assertion
%   error on mismatch and prints 'tests passed' otherwise.
m1 = 4;
m2 = 5;
k = 3;
x = randn(m1, m2, k);
af = covar1(randn(m1, m2, k));
a = ifft2(af, 'symmetric');
% The direct reference is O((m1*m2*k)^2), so compute it once and reuse it.
b0 = mul0(a, x);
assert(equal(mul1(af, x), b0));
assert(equal(mul2(af, x), b0));
% benchmark also times blkdiagmul2/3 on channel-first layouts; check that
% they agree with blkdiagmul1 once permuted back to [m1, m2, k].
ref = blkdiagmul1(af, x);
xr = permute(x, [3, 1, 2]);
b2 = blkdiagmul2(permute(af, [3, 4, 1, 2]), xr);
assert(equal(permute(b2, [2, 3, 1]), ref));
b3 = blkdiagmul3(permute(af, [4, 3, 1, 2]), xr);
assert(equal(permute(b3, [2, 3, 1]), ref));
fprintf('tests passed\n');
end
function sf = covar1(x)
%COVAR1 Per-frequency cross-spectra between the channels of x.
%   x holds 2-D signals stacked along dimension 3. With X = fft2(x), the
%   result satisfies sf(:, :, p, q) = conj(X(:, :, q)) .* X(:, :, p).
spec = fft2(x);
% Channel index moved to dim 4 so the product broadcasts over (p, q).
spec_conj_q = conj(permute(spec, [1, 2, 4, 3]));
sf = bsxfun(@times, spec_conj_q, spec);
end
function sf = covar2(x)
%COVAR2 Per-frequency cross-spectra for channel-first input.
%   size(x) is [k, m1, m2]: channel first, then the two spatial dims.
%   Let X be the 2-D DFT of x over its spatial dims (2 and 3). Returns sf
%   of size [k, k, m1, m2] with sf(p, q, i, j) = X(p, i, j) * conj(X(q, i, j)).
%   This holds the values covar1 gives on the [m1, m2, k] layout, in the
%   [k, k, m1, m2] layout that blkdiagmul2 consumes.
% fft2 would transform dims 1-2 (channel and m1), so transform dims 2 and 3.
xf = fft(fft(x, [], 2), [], 3);
% (p, 1, i, j) times (1, q, i, j) broadcasts to (p, q, i, j).
sf = bsxfun(@times, permute(xf, [1, 4, 2, 3]), permute(conj(xf), [4, 1, 2, 3]));
end
function b = mul1(af, x)
%MUL1 Multiply x by a per-frequency k-by-k block operator (explicit loop).
%   af is indexed (i, j, p, q) in the Fourier domain; x is [m1, m2, k].
%   At every frequency (i, j) this forms squeeze(af(i, j, :, :)) * X(i, j, :)'
%   with X = fft2(x), then returns the inverse 2-D DFT. The 'symmetric' flag
%   yields a real result; it presumes the product spectrum is conjugate-
%   symmetric (e.g. af from covar1 and real x).
[m1, m2, k] = size(x);
npix = m1 * m2;
% Channel dims first, so each frequency's block / vector is contiguous.
blocks = reshape(permute(af, [3, 4, 1, 2]), [k, k, npix]);
xcols = reshape(permute(fft2(x), [3, 1, 2]), [k, npix]);
bcols = zeros(k, npix);
for pix = 1:npix
    bcols(:, pix) = blocks(:, :, pix) * xcols(:, pix);
end
bf = permute(reshape(bcols, [k, m1, m2]), [2, 3, 1]);
b = ifft2(bf, 'symmetric');
end
function b = mul2(af, x)
%MUL2 Same operator as mul1, vectorized via blkdiagmul1 instead of a loop.
%   af is [m1, m2, k, k] in the Fourier domain and x is [m1, m2, k].
%   Computes bf(:, :, p) = sum_q af(:, :, p, q) .* X(:, :, q) with
%   X = fft2(x), then returns the inverse 2-D DFT of bf.
xf = fft2(x);
bf = blkdiagmul1(af, xf);
% 'symmetric' returns a real result; assumes bf is conjugate-symmetric
% (presumably real x and af from covar1 — confirm for other callers).
b = ifft2(bf, 'symmetric');
end
function b = blkdiagmul1(a, x)
%BLKDIAGMUL1 Apply a k-by-k block at every (i, j), layout [m1, m2, k].
%   a is expected to be [m1, m2, k, k] and x to be [m1, m2, k]. Returns b
%   with b(i, j, p) = sum_q a(i, j, p, q) * x(i, j, q), vectorized via bsxfun.
[m1, m2, k] = size(x);
xq = permute(x, [1, 2, 4, 3]);   % q moved to dim 4: [m1, m2, 1, k]
prods = bsxfun(@times, a, xq);   % indexed (i, j, p, q)
b = reshape(sum(prods, 4), [m1, m2, k]);
end
function b = blkdiagmul2(a, x)
%BLKDIAGMUL2 Block multiply in channel-first layout, blocks as (p, q, i, j).
%   a is expected to be [k, k, m1, m2] and x to be [k, m1, m2]. Returns b
%   with b(p, i, j) = sum_q a(p, q, i, j) * x(q, i, j).
[k, m1, m2] = size(x);
xrow = permute(x, [4, 1, 2, 3]);   % q moved to dim 2: [1, k, m1, m2]
prods = bsxfun(@times, a, xrow);   % indexed (p, q, i, j)
b = reshape(sum(prods, 2), [k, m1, m2]);
end
function b = blkdiagmul3(a, x)
%BLKDIAGMUL3 Block multiply in channel-first layout, blocks as (q, p, i, j).
%   a is expected to be [k, k, m1, m2] (transposed blocks) and x to be
%   [k, m1, m2]. Returns b with b(p, i, j) = sum_q a(q, p, i, j) * x(q, i, j).
[k, m1, m2] = size(x);
xcol = permute(x, [1, 4, 2, 3]);   % q stays in dim 1: [k, 1, m1, m2]
prods = bsxfun(@times, a, xcol);   % indexed (q, p, i, j)
b = reshape(sum(prods, 1), [k, m1, m2]);
end
function b = mul0(a, x)
%MUL0 Direct (slow) circular 2-D convolution with a matrix-valued kernel.
%   x is [m1, m2, k]; a is indexed (d1, d2, p, q). Returns b with
%   b(u1, u2, p) = sum over t1, t2, q of
%       a(1 + mod(u1-t1, m1), 1 + mod(u2-t2, m2), p, q) * x(t1, t2, q).
%   Six nested loops, O((m1*m2*k)^2): meant only for small test inputs.
[m1, m2, k] = size(x);
b = zeros(m1, m2, k);
for u1 = 1:m1
    for u2 = 1:m2
        for p = 1:k
            acc = b(u1, u2, p);
            for t1 = 1:m1
                % Circular shift index along dim 1, 1-based.
                s1 = 1 + mod(u1 - t1, m1);
                for t2 = 1:m2
                    s2 = 1 + mod(u2 - t2, m2);
                    for q = 1:k
                        acc = acc + a(s1, s2, p, q) * x(t1, t2, q);
                    end
                end
            end
            b(u1, u2, p) = acc;
        end
    end
end
end
function c = equal(x, y, abstol, reltol)
%EQUAL True if x and y have the same size and nearly equal values.
%   c = equal(x, y) computes d = norm(x(:) - y(:)) and returns true when
%   d <= 1e-9 (absolute) or d <= 1e-6 * norm(y(:)) (relative to y).
%   c = equal(x, y, abstol, reltol) overrides those two tolerances.
%   Arrays of different size, including different ndims, are never equal.
if nargin < 3
    abstol = 1e-9;
end
if nargin < 4
    reltol = 1e-6;
end
% Comparing the size vectors also catches an ndims mismatch.
if ~isequal(size(x), size(y))
    c = false;
    return;
end
d = norm(x(:) - y(:));
% Relative test scales the tolerance by norm(y) rather than dividing, so
% norm(y) == 0 is safe. (The previous d*norm(y) form was inverted.)
c = (d <= abstol) || (d <= reltol * norm(y(:)));
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment