Skip to content

Instantly share code, notes, and snippets.

@khaotik
Created March 14, 2017 10:19
Show Gist options
  • Save khaotik/e648b990b9799dd05e002887573f0b02 to your computer and use it in GitHub Desktop.
Save khaotik/e648b990b9799dd05e002887573f0b02 to your computer and use it in GitHub Desktop.
circular convolution with generalized elemwise op
from itertools import product
from random import randint
from time import time
import numpy as np
import theano as th
T = th.tensor
from theano.tensor.padding import idx, at_idx
from theano.tensor.signal import conv
# --- problem size ---------------------------------------------------------
N = 256    # side length of the square benchmark input; also baked into image_shape below
KSIZE = 3  # side length of the square kernel

# Symbolic 2D input and a shared random kernel stored in Theano's floatX dtype.
x = T.matrix()
ker = th.shared(
    np.random.rand(KSIZE, KSIZE).astype(th.config.floatX)
)

# NOTE(review): idx / at_idx are imported from theano.tensor.padding; presumably
# idx(x, axis) yields each element's own index along `axis` and
# at_idx(x, rows, cols) reads x at those positions -- confirm against that module.
iy, ix = idx(x, 0), idx(x, 1)
ny, nx = T.shape(x)[0], T.shape(x)[1]

# Offset of the kernel centre; taps are addressed relative to it.
mid = KSIZE // 2


def _wrapped_tap(dy, dx):
    """Return x read at offset (dy - mid, dx - mid), wrapped modulo its shape, times ker[dy, dx]."""
    rows = (iy + dy - mid) % ny
    cols = (ix + dx - mid) % nx
    return at_idx(x, rows, cols) * ker[dy, dx]


# "New" formulation: circular KSIZE x KSIZE filtering of the 2D tensor as a sum
# of one wrapped tap per kernel entry (dx is the outer loop, dy the inner one).
y_new = sum(
    _wrapped_tap(dy, dx)
    for dx in range(KSIZE)
    for dy in range(KSIZE)
)

# "Old" formulation: pad x circularly by `mid` on every side (last rows/columns
# in front, first rows/columns behind), then run an ordinary 'valid' conv2d.
_rows_padded = T.join(0, x[-mid:], x, x[:mid])
xp = T.join(1, _rows_padded[:, -mid:], _rows_padded, _rows_padded[:, :mid])

# image_shape is built from N, so this path is told the unpadded input is N x N.
_padded_side = N + 2 * mid
# NOTE(review): Theano's conv2d is documented as a true convolution (kernel
# flipped), while the sum above indexes ker[dy, dx] without flipping -- the two
# outputs may differ for a non-symmetric kernel; confirm if results are compared.
y_old = conv.conv2d(
    xp,
    ker,
    border_mode='valid',
    image_shape=(_padded_side, _padded_side),
    filter_shape=(KSIZE, KSIZE),
)

# Compile both graphs into callables that take the input matrix.
fn_old = th.function([x], y_old)
fn_new = th.function([x], y_new)
# Deterministic N x N test input: 0, 1, ..., N**2 - 1 laid out row by row.
xval = np.arange(N**2, dtype='float32').reshape(N, N)

# Exclude the first call of each function from the timings.
fn_old(xval)
fn_new(xval)

# Time 100 calls of the conv2d-based function; the last result stays in yval1.
beg = time()
for _ in range(100):
    yval1 = fn_old(xval)
t_old = time() - beg

# Time 100 calls of the index-based function. Its result goes into yval2
# (the original overwrote yval1), so both outputs remain available afterwards.
beg = time()
for _ in range(100):
    yval2 = fn_new(xval)
t_new = time() - beg

# Total wall-clock seconds for the 100 calls of each variant.
print('old: %f new %f' % (t_old, t_new))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment