Skip to content

Instantly share code, notes, and snippets.

@mstimberg
Created December 5, 2013 13:52
Show Gist options
  • Save mstimberg/7805469 to your computer and use it in GitHub Desktop.
Save mstimberg/7805469 to your computer and use it in GitHub Desktop.
Theano example: benchmarking `freq` as a vector vs. a scalar constant (Theano vs. numpy)
import time
import numpy as np
import theano
import theano.tensor as tt
# Parameters shared by both benchmarks below (the update rule reads them as
# module-level globals).
b = 1.2               # constant offset added inside the update rule
dt = 0.1*0.001        # time step used for exp(-dt/tau) and the time grid
                      # (presumably 0.1 ms expressed in seconds -- confirm)
tau = 20*0.001        # time constant of the exp(-dt/tau) decay factor
                      # (presumably 20 ms expressed in seconds -- confirm)
freq = 1000.0         # frequency of the sinusoidal term (presumably Hz);
                      # rebound to a vector later for the vector benchmark
N = 10**6             # number of elements in every benchmark vector
def get_theano_func_constant():
    """Compile the update rule with ``freq`` fixed as a scalar constant.

    The module-level ``freq``, ``b``, ``dt`` and ``tau`` are plain Python
    numbers read when this function runs, so their current values are
    embedded in the symbolic graph rather than passed as inputs.

    Returns:
        A compiled Theano function ``f(a, v, t)``. It takes two float64
        vectors ``a`` and ``v`` (``tt.dvector``) and a float64 scalar ``t``
        (``tt.dscalar``), and returns
        ``a*sin(2*freq*pi*t) + b + v*exp(-dt/tau)
        + (-a*sin(2*freq*pi*t) - b)*exp(-dt/tau)``.
    """
    amplitude = tt.dvector('a')
    state = tt.dvector('v')
    now = tt.dscalar('t')

    sine = tt.sin(2.0*freq*np.pi*now)
    decay = tt.exp(-dt/tau)
    updated = (amplitude*sine + b + state*decay +
               (-amplitude*sine - b)*decay)

    return theano.function([amplitude, state, now], updated)
def get_theano_func_vector():
    """Compile the update rule with ``freq`` supplied as a per-element vector.

    Unlike ``get_theano_func_constant``, the frequency is a symbolic input
    named ``'freq'``, so each call passes it explicitly. ``b``, ``dt`` and
    ``tau`` are still module-level Python numbers, read when this function
    runs and embedded in the graph.

    Returns:
        A compiled Theano function ``f(a, v, freq, t)``. It takes three
        float64 vectors ``a``, ``v`` and ``freq`` (``tt.dvector``) and a
        float64 scalar ``t`` (``tt.dscalar``), and returns
        ``a*sin(2*freq*pi*t) + b + v*exp(-dt/tau)
        + (-a*sin(2*freq*pi*t) - b)*exp(-dt/tau)``.
    """
    amplitude = tt.dvector('a')
    state = tt.dvector('v')
    frequency = tt.dvector('freq')
    now = tt.dscalar('t')

    sine = tt.sin(2.0*frequency*np.pi*now)
    decay = tt.exp(-dt/tau)
    updated = (amplitude*sine + b + state*decay +
               (-amplitude*sine - b)*decay)

    # Input order defines the positional call signature: (a, v, freq, t).
    return theano.function([amplitude, state, frequency, now], updated)
# Test with freq as a scalar constant
# Fresh state vector and per-element amplitudes for this run.
v = np.random.randn(N)
a = np.linspace(.05, 0.75, N)
# Build/compile the Theano function before starting the clock so only the
# per-step execution is timed.
theano_func = get_theano_func_constant()
start = time.time()
for t in np.arange(0, 0.01, dt):
    # Write the result back into v in place; the next step reads it.
    v[:] = theano_func(a, v, t)
stop = time.time()
# Single-argument print(...) is valid on Python 3 and prints the same text on
# Python 2 (where the parentheses are just grouping).
print('Theano (freq is a constant scalar): %.2fs' % (stop - start))
# The same update rule in plain numpy, written exactly as in the Theano graph
# so both timings evaluate the same arithmetic. This loop continues from the v
# left by the Theano loop; only the timing is compared.
start = time.time()
for t in np.arange(0, 0.01, dt):
    v[:] = (a*np.sin(2.0*freq*np.pi*t) + b + v*np.exp(-dt/tau) +
            (-a*np.sin(2.0*freq*np.pi*t) - b)*np.exp(-dt/tau))
stop = time.time()
print('numpy (freq is a constant scalar): %.2fs' % (stop - start))
# Test with freq as a vector
# Re-initialise the state and amplitudes so this run starts from fresh data.
v = np.random.randn(N)
a = np.linspace(.05, 0.75, N)
# Rebind the module-level freq to an N-element vector. The numpy loop below
# picks it up through ordinary global lookup; the Theano function receives it
# as an explicit argument.
freq = np.linspace(1000, 8000, N)
# Build/compile the Theano function before starting the clock so only the
# per-step execution is timed.
theano_func = get_theano_func_vector()
start = time.time()
for t in np.arange(0, 0.01, dt):
    # Write the result back into v in place; the next step reads it.
    v[:] = theano_func(a, v, freq, t)
stop = time.time()
# Single-argument print(...) is valid on Python 3 and prints the same text on
# Python 2 (where the parentheses are just grouping).
print('Theano (freq is a vector): %.2fs' % (stop - start))
# The same update rule in plain numpy, written exactly as in the Theano graph
# so both timings evaluate the same arithmetic. This loop continues from the v
# left by the Theano loop; only the timing is compared.
start = time.time()
for t in np.arange(0, 0.01, dt):
    v[:] = (a*np.sin(2.0*freq*np.pi*t) + b + v*np.exp(-dt/tau) +
            (-a*np.sin(2.0*freq*np.pi*t) - b)*np.exp(-dt/tau))
stop = time.time()
print('numpy (freq is a vector): %.2fs' % (stop - start))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment