Skip to content

Instantly share code, notes, and snippets.

@mfigurnov
Last active August 29, 2015 14:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mfigurnov/2863e30ce3f73a473d32 to your computer and use it in GitHub Desktop.
Save mfigurnov/2863e30ce3f73a473d32 to your computer and use it in GitHub Desktop.
function [ ret ] = my_nnpool( input, pool, varargin )
%MY_NNPOOL  Loop-based spatial max/average pooling and its derivative.
%   Y = MY_NNPOOL(X, POOL) pools every slice X(:,:,d,n) with a window of
%   POOL(1) rows by POOL(2) columns (a scalar POOL gives a square window).
%   Y is a single array of size pooledHeight x pooledWidth x D x N, where
%   D = SIZE(X,3) and N = SIZE(X,4).
%
%   DZDX = MY_NNPOOL(X, POOL, DZDY) back-propagates DZDY (one value per
%   pooled output element) to the input. DZDX is a single array of size
%   SIZE(X,1) x SIZE(X,2) x D x N. The pooled size is read from DZDY;
%   presumably it should equal the forward output size -- not checked.
%
%   A non-char third argument selects the backward form. Either form may
%   be followed by an optional literal 'verbose' flag (accepted and
%   ignored) and then by name/value options parsed with VL_ARGPARSE:
%     'stride'  scalar or [strideY strideX], each >= 1.          Default 1.
%     'pad'     scalar or [top bottom left right], each >= 0 and
%               smaller than the window along that axis.       Default 0.
%     'method'  'max' or 'avg' (case-insensitive).              Default 'max'.
%
%   Windows are clipped to the data, so padded positions never contribute;
%   'avg' divides by the number of unpadded elements in each window.
%   NOTE(review): VL_ARGPARSE is not defined in this file (presumably
%   MatConvNet's option parser) -- confirm it is on the path.

opts.stride = 1 ;
opts.pad = 0 ;
opts.method = 'max' ;

% A non-char third argument is DZDY and selects the backward pass.
backMode = numel(varargin) > 0 && ~ischar(varargin{1}) ;
if backMode
  dzdy = varargin{1} ;
  optArgs = varargin(2:end) ;
else
  optArgs = varargin ;
end
% Drop an optional leading 'verbose' flag. The isempty guard lets the
% plain backward call MY_NNPOOL(X, POOL, DZDY) through; without it the
% flag test indexed past the end of VARARGIN.
if ~isempty(optArgs) && ischar(optArgs{1}) && strcmpi(optArgs{1}, 'verbose')
  optArgs = optArgs(2:end) ;
end
opts = vl_argparse(opts, optArgs) ;

if numel(pool) == 1
  windowHeight = pool ;
  windowWidth = pool ;
elseif numel(pool) == 2
  windowHeight = pool(1) ;
  windowWidth = pool(2) ;
else
  error('SIZE has neither one nor two elements.');
end

height = size(input, 1) ;
width = size(input, 2) ;
D = size(input, 3) ;
N = size(input, 4) ;

if numel(opts.stride) == 1
  strideY = opts.stride ;
  strideX = opts.stride ;
elseif numel(opts.stride) == 2
  strideY = opts.stride(1) ;
  strideX = opts.stride(2) ;
else
  error('STRIDE has neither one nor two elements.');
end
% This also rules out a zero stride.
if strideX < 1 || strideY < 1
  error('At least one element of STRIDE is smaller than one.');
end

if numel(opts.pad) == 1
  padTop = opts.pad ;
  padBottom = opts.pad ;
  padLeft = opts.pad ;
  padRight = opts.pad ;
elseif numel(opts.pad) == 4
  padTop = opts.pad(1) ;
  padBottom = opts.pad(2) ;
  padLeft = opts.pad(3) ;
  padRight = opts.pad(4) ;
else
  error('PAD has neither one nor four elements.');
end

% Note: padding is not counted in this size check.
if height < windowHeight || width < windowWidth
  error('Pooling SIZE is larger than the DATA.');
end
if windowHeight == 0 || windowWidth == 0
  error('A dimension of the pooling SIZE is void.');
end
if padLeft < 0 || padRight < 0 || padTop < 0 || padBottom < 0
  error('An element of PAD is negative.');
end
% Keeping every pad strictly below the window size guarantees each window
% overlaps at least one real (unpadded) element.
if padLeft >= windowWidth || padRight >= windowWidth || padTop >= windowHeight || padBottom >= windowHeight
  error('A padding value is larger or equal than the size of the pooling window.');
end

if ~backMode
  pooledHeight = floor((height + (padTop + padBottom) - windowHeight)/strideY) + 1 ;
  pooledWidth = floor((width + (padLeft + padRight) - windowWidth)/strideX) + 1 ;
else
  pooledHeight = size(dzdy, 1) ;
  pooledWidth = size(dzdy, 2) ;
end

% Window bounds depend only on the output row/column, so compute them once.
[rowFirst, rowLast] = windowRanges(pooledHeight, strideY, padTop, windowHeight, height) ;
[colFirst, colLast] = windowRanges(pooledWidth, strideX, padLeft, windowWidth, width) ;

if ~backMode
  ret = zeros(pooledHeight, pooledWidth, D, N, 'single');
  if strcmpi(opts.method, 'max')
    for n = 1:N
      for d = 1:D
        for y = 1:pooledHeight
          for x = 1:pooledWidth
            values = input(rowFirst(y):rowLast(y), colFirst(x):colLast(x), d, n);
            ret(y, x, d, n) = max(values(:));
          end
        end
      end
    end
  elseif strcmpi(opts.method, 'avg')
    for n = 1:N
      for d = 1:D
        for y = 1:pooledHeight
          for x = 1:pooledWidth
            values = input(rowFirst(y):rowLast(y), colFirst(x):colLast(x), d, n);
            % Average over the clipped window only (padding excluded).
            ret(y, x, d, n) = sum(values(:)) / numel(values);
          end
        end
      end
    end
  else
    error('METHOD is not a supported method.');
  end
else % backward mode
  ret = zeros(height, width, D, N, 'single');
  if strcmpi(opts.method, 'max')
    for n = 1:N
      for d = 1:D
        for py = 1:pooledHeight
          y1 = rowFirst(py) ;
          y2 = rowLast(py) ;
          for px = 1:pooledWidth
            x1 = colFirst(px) ;
            x2 = colLast(px) ;
            % Route the gradient to the first strict maximum, scanning rows
            % outer and columns inner. Starting from -Inf (not the first
            % element) keeps a leading NaN from capturing the gradient, so
            % the chosen element matches the value MAX returned in the
            % forward pass. All-NaN/-Inf windows still fall back to (y1,x1).
            bestValue = -Inf ;
            bestY = y1 ;
            bestX = x1 ;
            for y = y1:y2
              for x = x1:x2
                value = input(y, x, d, n);
                if value > bestValue
                  bestValue = value ;
                  bestY = y ;
                  bestX = x ;
                end
              end
            end
            % Accumulate: overlapping windows may pick the same element.
            ret(bestY, bestX, d, n) = ret(bestY, bestX, d, n) + dzdy(py, px, d, n);
          end
        end
      end
    end
  elseif strcmpi(opts.method, 'avg')
    for n = 1:N
      for d = 1:D
        for py = 1:pooledHeight
          y1 = rowFirst(py) ;
          y2 = rowLast(py) ;
          for px = 1:pooledWidth
            x1 = colFirst(px) ;
            x2 = colLast(px) ;
            % Spread the output gradient evenly over the clipped window.
            share = dzdy(py, px, d, n) / ((y2 - y1 + 1) * (x2 - x1 + 1));
            ret(y1:y2, x1:x2, d, n) = ret(y1:y2, x1:x2, d, n) + share;
          end
        end
      end
    end
  else
    error('METHOD is not a supported method.');
  end
end
end

function [first, last] = windowRanges(numPooled, stride, padBefore, windowSize, dataSize)
%WINDOWRANGES  Inclusive 1-based data range covered by each pooling window
%   along one axis, clipped to [1, dataSize] so padded positions are skipped.
%   FIRST(k):LAST(k) is the range for output index k = 1..numPooled.
first = (0:numPooled-1) * stride - padBefore + 1 ;
last = min(first + windowSize - 1, dataSize) ;
first = max(first, 1) ;
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment