Skip to content

Instantly share code, notes, and snippets.

@khaotik
Last active November 29, 2016 17:08
Show Gist options
  • Save khaotik/9c7c72e562a919e241c9a09522fc4df5 to your computer and use it in GitHub Desktop.
Save khaotik/9c7c72e562a919e241c9a09522fc4df5 to your computer and use it in GitHub Desktop.
Benchmark for investigating the overhead of Theano's `scan`.
from __future__ import print_function
from time import time
import timeit

import numpy as np
import theano as th
import theano.tensor as T
# Profiling switch passed as ``profile=`` to every scan/function built below.
# A false value is normalised to None instead of being passed as False --
# presumably so theano.function falls back to its configured default
# profiling behaviour; TODO confirm against the theano.function docs.
USE_PROFILER = True
USE_PROFILER = USE_PROFILER or None
def timed_test(fn, n=10, shape=(1000, 10000), dtype=np.float32):
    """Return the elapsed seconds spent calling ``fn`` ``n`` times.

    One random input array is generated before the timer starts and is
    reused for every call, so array generation is excluded from the
    measurement.

    Args:
        fn: callable taking a single array argument (in this script, a
            compiled theano function).
        n: number of calls to time.
        shape: shape of the random standard-normal input; defaults to the
            original hard-coded (1000, 10000).
        dtype: dtype the input is cast to; defaults to float32.

    Returns:
        Elapsed time in seconds, as a float.
    """
    xval = np.random.randn(*shape).astype(dtype)
    # timeit.default_timer is the recommended benchmark clock on both
    # Python 2 and 3 (time.perf_counter on Python 3). time.time() can have
    # coarse resolution on some platforms and may jump if the system clock
    # is adjusted mid-measurement.
    start = timeit.default_timer()
    for _ in range(n):
        fn(xval)
    return timeit.default_timer() - start
# A single symbolic matrix input; scan walks over its leading dimension, so
# each scan step receives one row.
x = T.matrix()

# Variant A: scan whose step is a vector-vector T.dot, i.e. the inner
# product of each row with itself.
y1, _ = th.scan(
    fn=lambda row: T.dot(row, row),
    sequences=x,
    profile=USE_PROFILER,
)
# Variant B: scan whose step computes the same per-row quantity as an
# elementwise square followed by a sum.
y2, _ = th.scan(
    fn=lambda row: T.sum(row * row),
    sequences=x,
    profile=USE_PROFILER,
)
fn_dot_scan = th.function([x], y1, profile=USE_PROFILER)
fn_sum_scan = th.function([x], y2, profile=USE_PROFILER)

# Variant C: the same per-row sum of squares as one vectorized expression
# with no scan at all -- the baseline the scan variants are compared to.
y3 = T.sum(x * x, axis=-1)
fn_plain = th.function([x], y3, profile=USE_PROFILER)
print('Using device %s' % th.config.device)
# NOTE(review): the third column times fn_plain, which is the vectorized,
# scan-free expression rather than an unrolled loop, despite its label.
print(' scan with dot | scan with sum | full unrolled')
# Ten rounds; each round times the three compiled variants in a fixed order
# (dot-scan, sum-scan, plain). The space flag in the first "% 1.12f" field
# pads positive values with a leading blank, matching the header's indent.
for _ in range(10):
    time1, time2, time3 = (
        timed_test(compiled)
        for compiled in (fn_dot_scan, fn_sum_scan, fn_plain)
    )
    print('% 1.12f | %1.12f | %1.12f ' % (time1, time2, time3))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment