Skip to content

Instantly share code, notes, and snippets.

@JesseLivezey
Last active June 13, 2016 20:22
Show Gist options
  • Save JesseLivezey/12e0e320960a58c278138d402f724fca to your computer and use it in GitHub Desktop.
Save JesseLivezey/12e0e320960a58c278138d402f724fca to your computer and use it in GitHub Desktop.
CorrMM OMP and BLAS timings script
#!/usr/bin/env python
import argparse, time, os, subprocess
import numpy as np
import theano
from theano.tensor.nnet.corr import (CorrMM, CorrMM_gradInputs,
CorrMM_gradWeights)
from theano.compat.python2x import OrderedDict
# Number of timed calls averaged by time_op (after one untimed warm-up call).
n_avg = 10
# Leading dimension shared by every image and map shape below.
n_images = 128
# The three shape lists are index-aligned: entry 0 is used for the 'image'
# data type, entry 1 for every other data type (see time_op).
# NOTE(review): layout is presumably (batch, channels, rows, cols) -- confirm
# against the CorrMM documentation.
image_shapes = [(n_images, 3, 128, 128), (n_images, 85, 2, 258)]
filter_shapes = [(96, 3, 5, 5), (64, 85, 2, 20)]
# Last two dims equal image size - filter size + 1 (128-5+1=124, 2-2+1=1,
# 258-20+1=239); second dim equals the first dim of the matching filter.
map_shapes = [(n_images, 96, 124, 124), (n_images, 64, 1, 239)]
# Human-readable labels printed next to each shape pair in the report.
shape_names = ['Imagenet-like', 'Spectrogram-like']
# Not referenced anywhere in the visible part of this file.
op_names = ['Legacy', 'CorrMM']
def time_op(data_type, direction):
    """Time one CorrMM-family op and report the thread settings in effect.

    Random float32 images, filters and maps are stored in Theano shared
    variables, and a no-input Theano function is compiled whose only effect
    is to overwrite one of those shared variables with the op's output.

    Parameters
    ----------
    data_type : str
        ``'image'`` selects the first entry of ``image_shapes`` /
        ``filter_shapes`` / ``map_shapes``; any other value selects the
        second entry.
    direction : str
        ``'gradInputs'`` times ``CorrMM_gradInputs`` (filters, maps ->
        images), ``'gradWeights'`` times ``CorrMM_gradWeights`` (images,
        maps -> filters); any other value times the forward ``CorrMM``
        (images, filters -> maps).

    Returns
    -------
    tuple
        ``(avg_time, mkl_threads, omp_threads)`` where ``avg_time`` is the
        mean wall-clock seconds per call over ``n_avg`` calls, and the other
        two are the raw ``MKL_NUM_THREADS`` / ``OMP_NUM_THREADS`` environment
        values (strings, or ``None`` when unset).
    """
    shape_idx = 0 if data_type == 'image' else 1
    imgshp = image_shapes[shape_idx]
    filshp = filter_shapes[shape_idx]
    mapshp = map_shapes[shape_idx]

    images = np.random.randn(*imgshp).astype('float32')
    filters = np.random.randn(*filshp).astype('float32')
    maps = np.random.randn(*mapshp).astype('float32')

    images_sym = theano.shared(images)
    filters_sym = theano.shared(filters)
    maps_sym = theano.shared(maps)

    # Exactly one op is selected per call so that the measured time belongs
    # to the requested direction only.
    updates = OrderedDict()
    if direction == 'gradInputs':
        op = CorrMM_gradInputs()
        updates[images_sym] = op(filters_sym, maps_sym)
    elif direction == 'gradWeights':
        op = CorrMM_gradWeights()
        updates[filters_sym] = op(images_sym, maps_sym)
    else:
        op = CorrMM()
        updates[maps_sym] = op(images_sym, filters_sym)

    omp_threads = os.environ.get('OMP_NUM_THREADS', None)
    mkl_threads = os.environ.get('MKL_NUM_THREADS', None)

    f = theano.function([], [], updates=updates)
    # Untimed warm-up call so one-off costs are not part of the average.
    f()
    start = time.time()
    for _ in range(n_avg):
        f()
    avg_time = (time.time() - start) / float(n_avg)
    return avg_time, mkl_threads, omp_threads
# Shell fragments used to build each benchmark command line.  Every timing
# runs in a fresh interpreter so the thread-count environment variables are
# in place before the numerical libraries are loaded.
omp_flags = """export OMP_NUM_THREADS={0}; """
blas_flags = """export MKL_NUM_THREADS={0}; """
theano_flags = """THEANO_FLAGS='floatX=float32,device=cpu,openmp={0}' """
base_cmd = """python -c 'from conv_timing import time_op; print(time_op({0}, {1}))'"""


def parse(rval):
    """Parse the tuple printed by a ``time_op`` subprocess.

    Parameters
    ----------
    rval : bytes or str
        Captured stdout of the subprocess.  The last non-blank line must be
        the printed ``(avg_time, mkl_threads, omp_threads)`` tuple.

    Returns
    -------
    list
        ``[avg_time, mkl_threads, omp_threads]`` with the time as ``float``
        and each thread count as ``int`` or ``None``.

    Raises
    ------
    ValueError
        If the output contains no non-blank line.
    """
    import ast  # local import keeps this block self-contained

    # check_output returns bytes on Python 3; normalise to text first.
    if isinstance(rval, bytes):
        rval = rval.decode('utf-8')
    lines = [line.strip() for line in rval.splitlines() if line.strip()]
    if not lines:
        raise ValueError('no output to parse from timing subprocess')
    # Only the last line is read, so anything printed before the result
    # (e.g. library notices) does not break parsing.
    values = ast.literal_eval(lines[-1])
    results = [float(values[0])]
    for value in values[1:]:
        results.append(None if value is None else int(value))
    return results


def print_results(blas, omp, base_t, t, notes):
    """Print one table row: thread counts, time in ms, speedup and notes.

    ``base_t`` and ``t`` are times in seconds; speedup is ``base_t / t``.
    """
    print('{0: <10}{1: <10}{2: <10}{3: <10.2f}{4: <10}'.format(
        blas, omp, int(t*1000), float(base_t)/t, notes))


def thread_counts(max_threads):
    """Return the thread counts to sweep, ending with ``max_threads``.

    The result is every power of two from 2 up to (but not including)
    ``max_threads``, followed by ``max_threads`` itself, e.g. ``8`` gives
    ``[2, 4, 8]`` and ``6`` gives ``[2, 4, 6]``.

    Raises
    ------
    ValueError
        If ``max_threads`` is smaller than 1.
    """
    if max_threads < 1:
        raise ValueError('max_threads must be a positive integer')
    # floor(log2(max_threads)) computed exactly on integers.
    log_threads = max_threads.bit_length() - 1
    if pow(2, log_threads) == max_threads:
        # max_threads is itself a power of two; it is appended explicitly.
        log_threads -= 1
    return [pow(2, k) for k in range(1, log_threads + 1)] + [max_threads]


def run_timing(env_flags, use_openmp, data_type, direction):
    """Run ``time_op`` in a subprocess and return the parsed result.

    ``env_flags`` is the already-formatted ``export ...; `` prefix,
    ``use_openmp`` is substituted into Theano's ``openmp`` flag, and
    ``data_type`` / ``direction`` are quoted Python string literals that are
    pasted into the ``python -c`` command.
    """
    cmd = (env_flags + theano_flags.format(use_openmp) +
           base_cmd.format(data_type, direction))
    return parse(subprocess.check_output(cmd, shell=True))


def main():
    """Sweep BLAS and OpenMP thread counts and print a timing table."""
    parser = argparse.ArgumentParser(description='Time conv w.r.t. threads.')
    parser.add_argument('max_threads', type=int,
                        help='max number of threads')
    args = parser.parse_args()
    max_threads = args.max_threads
    if max_threads < 1:
        parser.error('max_threads must be a positive integer')
    threads = thread_counts(max_threads)

    # The inner double quotes make these valid string literals once they
    # are substituted into the ``python -c`` command.
    data_types = ['"image"', '"spect"']
    directions = ['"forward"', '"gradInputs"', '"gradWeights"']

    for d in directions:
        for ii, dt in enumerate(data_types):
            print('Op: {}'.format(d.split('"')[1]))
            print('Image shape, filter shape')
            print('{} {}'.format(image_shapes[ii], filter_shapes[ii]))
            print(shape_names[ii])
            print('{0: <10}{1: <10}{2: <10}{3: <10}{4: <10}'.format(
                'MKL', 'OpenMP', 'Time(ms)', 'Speedup', 'Notes'))

            # Baseline for all speedups: BLAS at max_threads, OpenMP off.
            base_t, base_blas, base_omp = run_timing(
                blas_flags.format(max_threads), False, dt, d)

            t, blas, omp = run_timing(blas_flags.format(1), False, dt, d)
            print_results(blas, omp, base_t, t, 'single thread')
            print_results(base_blas, base_omp, base_t, base_t,
                          'BLAS only, baseline')

            t, blas, omp = run_timing(
                omp_flags.format(max_threads), True, dt, d)
            print_results(blas, omp, base_t, t, 'only OMP set')

            # max_threads is skipped here because it is the baseline row.
            for jj, n in enumerate(threads[:-1]):
                t, blas, omp = run_timing(
                    omp_flags.format(1) + blas_flags.format(n), False, dt, d)
                # Label only the first row of the sweep; later rows get a
                # ditto mark.
                notes = 'BLAS only' if jj == 0 else '"'
                print_results(blas, omp, base_t, t, notes)

            for jj, n in enumerate(threads):
                t, blas, omp = run_timing(
                    omp_flags.format(n) + blas_flags.format(1), True, dt, d)
                notes = 'OMP only' if jj == 0 else '"'
                print_results(blas, omp, base_t, t, notes)
            print('\n')


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment